Hi,
I'm working on content-based image retrieval (CBIR) using SIFT + bag of words. The goal is, given a query image, to find the image in a large database that is most similar to it. I've made some progress: it works perfectly for an identical image and reasonably well for a slightly similar one, but once the image becomes a bit more dissimilar, it no longer retrieves the right image. Strangely, even when the wrong image is returned, plotting the matches with `drawMatchesKnn` shows that the query image has far more feature matches with the right image than with the wrong one (the one my code reports as the best match).
I've experimented with several parameters, including the number of SIFT features and the number of k-means clusters, with no noticeable impact, so I suspect a bug in my code.
This is how I'm doing it:
- Go through every image, get its descriptors with SIFT, and build a list of these descriptors.
- Using k-means, find the centroids of the list of descriptors. This is the "dictionary".
- Go through every image again and run `knnMatch` with k=1 between each image's descriptors and the centroids. Use each match's `trainIdx` to build a per-image histogram of visual words (a minimal sketch of this step follows the list).
- Normalize each image's histogram by dividing each "word" count by the sum over all "words".
- Run `knnMatch` with k=1 between the query image's descriptors and the centroids, and build a normalized histogram for it the same way.
- Run `knnMatch` with k=1 between the query image's histogram and the histograms of all the images in the database. This gives the matches ordered by similarity to the query image.
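To make the histogram step concrete, here is a minimal sketch of what I mean (`des`, `centroids`, and `clusters` are stand-ins for one image's float32 SIFT descriptors, the `cv2.kmeans` centers, and the vocabulary size):

```python
import numpy as np
import cv2

def bow_histogram(des, centroids, clusters):
    """Assign each descriptor to its nearest visual word; return a normalized histogram."""
    bf = cv2.BFMatcher()
    # k=1: each descriptor is matched to its single nearest centroid
    matches = bf.knnMatch(np.float32(des), np.float32(centroids), k=1)
    hist = np.zeros(clusters, dtype=np.float32)
    for m in matches:
        hist[m[0].trainIdx] += 1.0   # trainIdx indexes the nearest centroid
    return hist / hist.sum()         # normalize so image size doesn't bias the comparison
```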
Here is my code. I know it's disorganized and inefficient; I want to get it working before cleaning it up.
import numpy as np
import cv2
import os
from matplotlib import pyplot as plt

sift = cv2.xfeatures2d.SIFT_create()

FLANN_INDEX_KDTREE = 0
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=100)
search_params = dict(checks=100)
flann = cv2.FlannBasedMatcher(index_params, search_params)  # currently unused; bf does all the matching
bf = cv2.BFMatcher()

# Query image
img1 = cv2.imread('path', 0)
db = # load database
kp1, des1 = sift.detectAndCompute(img1, None)

load = False    # True = rebuild the dictionary and histograms from scratch
clusters = 800  # size of the visual-word vocabulary
if load:
    db.query('DELETE FROM centroids')
    db.query('DELETE FROM histogram')

    # Pass 1: collect the SIFT descriptors of every database image
    descriptors = []
    for file in os.listdir('path'):
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file), 0)
            kp, des = sift.detectAndCompute(img, None)
            if des is None:
                continue
            descriptors.extend(des)
    descriptors = np.float32(descriptors)

    # Cluster the descriptors; the centroids are the visual-word dictionary
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
    centroids = cv2.kmeans(descriptors, clusters, None, criteria, 1, cv2.KMEANS_PP_CENTERS)[2]
    db.insert('centroids', d=np.ndarray.dumps(centroids))

    # Pass 2: build a normalized visual-word histogram for every image
    for file in os.listdir('path'):
        counter = np.zeros((clusters,), dtype=np.uint32)
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file), 0)
            kp, d = sift.detectAndCompute(img, None)
            if d is None:
                continue
            matches = bf.knnMatch(d, centroids, k=1)
            for match in matches:
                counter[match[0].trainIdx] += 1
            counter_sum = np.sum(counter)
            counter = [float(n) / counter_sum for n in counter]
            db.insert('histogram', frame_id=file, count=','.join(np.char.mod('%f', counter)))
# Load the stored histograms (as strings; converted to float32 below)
histograms_db = list(db.query('SELECT * FROM histogram'))
histograms = []
for histogram in histograms_db:
    histogram = histogram['count'].split(',')
    histograms.append(histogram)
histograms = np.array(histograms)

# Build the query image's normalized histogram
counter = np.zeros((clusters,), dtype=np.uint32)
centroids = np.loads(db.query('SELECT * FROM centroids')[0]['d'])
matches = bf.knnMatch(des1, centroids, k=1)
for match in matches:
    counter[match[0].trainIdx] += 1
counter_sum = np.sum(counter)
counter = [float(n) / counter_sum for n in counter]

# Find the database histogram closest to the query histogram
matches = bf.knnMatch(np.float32([counter]), np.float32(histograms), k=1)
for match in matches[0]:
    print "{} {}".format(histograms_db[match.trainIdx]['frame_id'], match.distance)
    name = histograms_db[match.trainIdx]['frame_id']
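Note that with k=1, that last `knnMatch` only ever returns the single nearest histogram. To see a ranked shortlist while debugging (e.g. to check how close the right image scores), something like this should work with the variables above:

```python
# Ranked top-5 shortlist instead of only the single best match (sketch;
# knnMatch returns each match list already sorted by increasing distance).
top = bf.knnMatch(np.float32([counter]), np.float32(histograms), k=5)[0]
for m in top:
    print('{} {}'.format(histograms_db[m.trainIdx]['frame_id'], m.distance))
```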