1
0
Fork 0

Add: basic similar image classification

This commit is contained in:
Aroy-Art 2024-09-01 13:28:24 +02:00
parent 74809e7a29
commit 2bbf440b4a
Signed by: Aroy
GPG key ID: 583642324A1D2070

View file

@ -0,0 +1,53 @@
import os
import numpy as np
from tensorflow.keras.applications import vgg19
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg19 import preprocess_input
import faiss
import cv2
model = vgg19.VGG19(weights="imagenet", include_top=False, pooling="avg")
def extract_features(img_path, model):
img = image.load_img(img_path, target_size=(224, 224))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input(img)
feature = model.predict(img)
return feature.flatten()
image_dir = "images/"
image_paths = [
os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".jpg")
]
features = []
for image_path in image_paths:
img_feature = extract_features(image_path, model)
features.append(img_feature)
features = np.array(features)
d = features.shape[1]
index = faiss.IndexFlatL2(d)
index.add(features)
def find_similar_images(query_image_path, index, k=6):
query_feature = extract_features(query_image_path, model).reshape(1, -1)
_distance, indices = index.search(query_feature, k)
return indices.flatten()
query_image_path = "query_image.jpg"
similar_image_indices = find_similar_images(query_image_path, index)
# display the results
for idx in similar_image_indices:
similar_image_path = image_paths[idx]
img = cv2.imread(similar_image_path)
cv2.imshow("Similar Image", img)
cv2.waitKey(0)
cv2.destroyAllWindows()