Add: basic similar image classification
This commit is contained in:
parent
74809e7a29
commit
2bbf440b4a
1 changed files with 53 additions and 0 deletions
53
Python/similar-images-classification/main.py
Normal file
53
Python/similar-images-classification/main.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow.keras.applications import vgg19
|
||||||
|
from tensorflow.keras.preprocessing import image
|
||||||
|
from tensorflow.keras.applications.vgg19 import preprocess_input
|
||||||
|
import faiss
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
model = vgg19.VGG19(weights="imagenet", include_top=False, pooling="avg")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(img_path, model):
|
||||||
|
img = image.load_img(img_path, target_size=(224, 224))
|
||||||
|
img = image.img_to_array(img)
|
||||||
|
img = np.expand_dims(img, axis=0)
|
||||||
|
img = preprocess_input(img)
|
||||||
|
feature = model.predict(img)
|
||||||
|
return feature.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
image_dir = "images/"
|
||||||
|
image_paths = [
|
||||||
|
os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".jpg")
|
||||||
|
]
|
||||||
|
features = []
|
||||||
|
|
||||||
|
for image_path in image_paths:
|
||||||
|
img_feature = extract_features(image_path, model)
|
||||||
|
features.append(img_feature)
|
||||||
|
|
||||||
|
features = np.array(features)
|
||||||
|
|
||||||
|
d = features.shape[1]
|
||||||
|
index = faiss.IndexFlatL2(d)
|
||||||
|
index.add(features)
|
||||||
|
|
||||||
|
|
||||||
|
def find_similar_images(query_image_path, index, k=6):
|
||||||
|
query_feature = extract_features(query_image_path, model).reshape(1, -1)
|
||||||
|
_distance, indices = index.search(query_feature, k)
|
||||||
|
return indices.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
query_image_path = "query_image.jpg"
|
||||||
|
similar_image_indices = find_similar_images(query_image_path, index)
|
||||||
|
|
||||||
|
# display the results
|
||||||
|
for idx in similar_image_indices:
|
||||||
|
similar_image_path = image_paths[idx]
|
||||||
|
img = cv2.imread(similar_image_path)
|
||||||
|
cv2.imshow("Similar Image", img)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
cv2.destroyAllWindows()
|
Loading…
Reference in a new issue