import os
import sys

import cv2
import faiss
import numpy as np
from tensorflow.keras.applications import vgg19
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.preprocessing import image
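
# Simple content-based image retrieval: extract VGG19 (ImageNet) features for
# every .jpg in image_dir, index them with FAISS, and display the k images
# closest in L2 distance to a query image using OpenCV.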


def query_yes_no(question, default="yes"):
    """Ask a yes/no question via input() and return the answer.

    "question" is a string that is presented to the user.
    "default" is the presumed answer if the user just hits <Enter>.
    It must be "yes" (the default), "no" or None (meaning
    an answer is required of the user).

    The return value is True for "yes" or False for "no".
    """
    valid = {"yes": True, "y": True, "ye": True,
             "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' "
                             "(or 'y' or 'n').\n")
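

# Feature extractor: VGG19 pre-trained on ImageNet. include_top=False drops the
# classifier head, and pooling="avg" global-average-pools the final conv block,
# giving one flat 512-dimensional descriptor per image.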
model = vgg19.VGG19(weights="imagenet", include_top=False, pooling="avg")


def extract_features(img_path, model):
    """Load an image, preprocess it for VGG19, and return a flat feature vector."""
    img = image.load_img(img_path, target_size=(224, 224))
    img = image.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    feature = model.predict(img)
    return feature.flatten()
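
# Note: the 224x224 resize and vgg19.preprocess_input apply the same input
# normalisation VGG19 was trained with, so the query image later goes through
# exactly the same pipeline as the indexed images.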


image_dir = "images/"
image_paths = [
    os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".jpg")
]
features = []
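
# The index is cached on disk as "image_index.bin": if it exists, the user can
# load it instead of re-extracting features for every image. IndexFlatL2 is an
# exact (brute-force) L2 nearest-neighbour index over the feature vectors.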
if os.path.exists("image_index.bin"):
    if query_yes_no("Load the index?", default="yes"):
        index = faiss.read_index("image_index.bin")
    else:
        for image_path in image_paths:
            img_feature = extract_features(image_path, model)
            features.append(img_feature)

        features = np.array(features)

        d = features.shape[1]
        index = faiss.IndexFlatL2(d)
        index.add(features)

        if query_yes_no("Save the index?", default="yes"):
            faiss.write_index(index, "image_index.bin")
else:
    for image_path in image_paths:
        img_feature = extract_features(image_path, model)
        features.append(img_feature)

    features = np.array(features)

    d = features.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(features)

    if query_yes_no("Save the index?", default="yes"):
        faiss.write_index(index, "image_index.bin")
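

# Query step: embed one image with the same extractor and look up its k nearest
# neighbours (smallest L2 distances) in the index.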
def find_similar_images(query_image_path, index, k=6):
    """Return the L2 distances and indices of the k images closest to the query."""
    query_feature = extract_features(query_image_path, model).reshape(1, -1)
    distance, indices = index.search(query_feature, k)
    return distance.flatten(), indices.flatten()


query_image_path = "query_image.jpg"
similar_image_distance, similar_image_indices = find_similar_images(query_image_path, index)

# Display the results, one window per match.
for i, idx in enumerate(similar_image_indices):
    similar_image_path = image_paths[idx]
    img = cv2.imread(similar_image_path)
    # Rough similarity score: normalise each L2 distance by the largest of the
    # returned distances, so the worst of the k matches reads as 0%.
    similarity = 1 - similar_image_distance[i] / np.max(similar_image_distance)
    cv2.imshow(f"Similarity: {similarity * 100:.2f}% ({i+1}/{len(similar_image_indices)})", img)
    cv2.waitKey(0)
cv2.destroyAllWindows()