Accuracy with SIFT + bag of words
Hi,
I'm working on content-based image retrieval (CBIR) using SIFT + bag of words. My goal is, given a query image, to find which image from a large database is most similar to it. I've made some progress: it works perfectly for an identical image, and well for a slightly different image, but once the image becomes a bit more dissimilar, it no longer finds the right image. Even when it does not find the right image, when I plot the matches with drawMatchesKnn, the query image has far more matches with the right image than with the wrong image (the one that was returned as the best match).
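To make the "more matches" comparison concrete, here is a minimal sketch of counting matches between two descriptor sets instead of judging them from a plot. It is written under assumptions: it uses a Lowe-style ratio test (the post does not say how the plotted matches were filtered), it replaces knnMatch with a plain NumPy nearest-neighbor search, and the function name count_ratio_test_matches, the 0.75 threshold, and the random descriptors are all hypothetical stand-ins for real SIFT output.

```python
import numpy as np

def count_ratio_test_matches(des_query, des_train, ratio=0.75):
    """Count query descriptors whose nearest neighbour in des_train is
    clearly closer than the second nearest (ratio test, assumed here)."""
    # Pairwise L2 distances, shape (n_query, n_train)
    diffs = des_query[:, None, :] - des_train[None, :, :]
    dists = np.sqrt(np.einsum('ijk,ijk->ij', diffs, diffs))
    # Two smallest distances per query descriptor (the k=2 neighbours)
    two_smallest = np.sort(dists, axis=1)[:, :2]
    best, second = two_smallest[:, 0], two_smallest[:, 1]
    return int(np.sum(best < ratio * second))

# Synthetic 128-dimensional descriptors standing in for SIFT output
rng = np.random.default_rng(0)
des_a = rng.random((40, 128)).astype(np.float32)
# A "similar image": the same descriptors with a little noise added
des_similar = des_a + rng.normal(0, 0.01, des_a.shape).astype(np.float32)
# An "unrelated image": independent random descriptors
des_unrelated = rng.random((40, 128)).astype(np.float32)

n_similar = count_ratio_test_matches(des_a, des_similar)
n_unrelated = count_ratio_test_matches(des_a, des_unrelated)
print(n_similar, n_unrelated)
```

A count like this gives a per-image number that can be compared directly against the ranking produced by the histograms.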
I have messed with a bunch of the parameters, including the number of SIFT features and the number of clusters, with no noticeable impact. I think the problem is a bug in my code. Does anyone notice anything that might be causing the searches to be inaccurate for dissimilar images? Or have any recommendations on what can be done to improve the accuracy?
This is how I'm doing it:
- Go through every image, get its descriptors with SIFT, and build a list of these descriptors.
- Using k-means, find the centroids of the list of descriptors. This is the "dictionary".
- Go through every image again, and get the k-nearest neighbors (knnMatch with k=1) between each image's descriptors and the centroids. Use each match to create a histogram for each image, using match.trainIdx.
- Normalize each image's histogram by dividing the count of each "word" by the sum of the "words".
- Use knnMatch with k=1 with the query image's descriptors and the centroids. Go through the matches and create a normalized histogram.
- Use knnMatch with k=1 on the query image's histogram, and the histograms of all of the images in the database. This creates a list of matches, ordered by similarity to the query image.
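The histogram and ranking steps above can be sketched in a few lines. This is a minimal illustration under assumptions, not the code from this post: it uses NumPy in place of knnMatch, random arrays in place of SIFT descriptors and k-means centroids, and the helper names nearest_centroid, bow_histogram, and rank_database are hypothetical.

```python
import numpy as np

def nearest_centroid(descriptors, centroids):
    """Index of the closest centroid (L2) for every descriptor."""
    # Squared distances, shape (n_descriptors, n_centroids)
    diffs = descriptors[:, None, :] - centroids[None, :, :]
    dists = np.einsum('ijk,ijk->ij', diffs, diffs)
    return np.argmin(dists, axis=1)

def bow_histogram(descriptors, centroids):
    """Normalized count of how often each visual word occurs."""
    words = nearest_centroid(descriptors, centroids)
    counts = np.bincount(words, minlength=len(centroids)).astype(np.float32)
    return counts / counts.sum()

def rank_database(query_hist, db_hists):
    """Database indices ordered from most to least similar (L2 distance)."""
    dists = np.linalg.norm(db_hists - query_hist, axis=1)
    return np.argsort(dists)

# Synthetic stand-ins: 8 "words" and 3 database images of 50 descriptors each
rng = np.random.default_rng(0)
centroids = rng.random((8, 128)).astype(np.float32)
db_descriptors = [rng.random((50, 128)).astype(np.float32) for _ in range(3)]
db_hists = np.stack([bow_histogram(d, centroids) for d in db_descriptors])

# Query with the descriptors of database image 1
query_hist = bow_histogram(db_descriptors[1], centroids)
ranking = rank_database(query_hist, db_hists)
print(ranking)
```

Returning the full sorted ranking rather than a single nearest neighbor makes it easy to check where the expected image lands when it is not the top result.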
Here is my code. I know it's not well organized and inefficient -- I want to focus on getting this to work before cleaning it up.
import numpy as np
import cv2
import os
from matplotlib import pyplot as plt
sift = cv2.xfeatures2d.SIFT_create()
FLANN_INDEX_KDTREE = 0
index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 100)
search_params = dict(checks = 100)
flann = cv2.FlannBasedMatcher(index_params, search_params)
bf = cv2.BFMatcher()
img1 = cv2.imread('path',0)
db = # load database
kp1, des1 = sift.detectAndCompute(img1,None)
load = False
clusters = 800
if load:
    db.query('DELETE FROM centroids')
    db.query('DELETE FROM histogram')
    descriptors = []
    for file in os.listdir('path'):
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file), 0)
            kp, des = sift.detectAndCompute(img,None)
            if des is None:
                continue
            descriptors.extend(des)
    descriptors = np.float32(descriptors)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
    centroids = cv2.kmeans(descriptors, clusters, None, criteria, 1, cv2.KMEANS_PP_CENTERS)[2]
    db.insert('centroids', d = np.ndarray.dumps(centroids))
    for file in os.listdir('path'):
        counter = np.zeros((clusters,), dtype=np.uint32)
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format ...