From 2bbf440b4afabbb688051427c72da07f8010f36c Mon Sep 17 00:00:00 2001 From: Aroy-Art Date: Sun, 1 Sep 2024 13:28:24 +0200 Subject: [PATCH] Add: basic similar image classification --- Python/similar-images-classification/main.py | 53 ++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 Python/similar-images-classification/main.py diff --git a/Python/similar-images-classification/main.py b/Python/similar-images-classification/main.py new file mode 100644 index 0000000..2cd3eda --- /dev/null +++ b/Python/similar-images-classification/main.py @@ -0,0 +1,53 @@ +import os +import numpy as np +from tensorflow.keras.applications import vgg19 +from tensorflow.keras.preprocessing import image +from tensorflow.keras.applications.vgg19 import preprocess_input +import faiss +import cv2 + +model = vgg19.VGG19(weights="imagenet", include_top=False, pooling="avg") + + +def extract_features(img_path, model): + img = image.load_img(img_path, target_size=(224, 224)) + img = image.img_to_array(img) + img = np.expand_dims(img, axis=0) + img = preprocess_input(img) + feature = model.predict(img) + return feature.flatten() + + +image_dir = "images/" +image_paths = [ + os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".jpg") +] +features = [] + +for image_path in image_paths: + img_feature = extract_features(image_path, model) + features.append(img_feature) + +features = np.array(features) + +d = features.shape[1] +index = faiss.IndexFlatL2(d) +index.add(features) + + +def find_similar_images(query_image_path, index, k=6): + query_feature = extract_features(query_image_path, model).reshape(1, -1) + _distance, indices = index.search(query_feature, k) + return indices.flatten() + + +query_image_path = "query_image.jpg" +similar_image_indices = find_similar_images(query_image_path, index) + +# display the results +for idx in similar_image_indices: + similar_image_path = image_paths[idx] + img = cv2.imread(similar_image_path) + cv2.imshow("Similar Image", img) + cv2.waitKey(0) + cv2.destroyAllWindows()