1
0
Fork 0

Compare commits

...

2 commits

View file

@ -6,6 +6,40 @@ from tensorflow.keras.applications.vgg19 import preprocess_input
import faiss
import cv2
def query_yes_no(question, default="yes"):
"""Ask a yes/no question via raw_input() and return their answer.
"question" is a string that is presented to the user.
"default" is the presumed answer if the user just hits <Enter>.
It must be "yes" (the default), "no" or None (meaning
an answer is required of the user).
The "answer" return value is True for "yes" or False for "no".
"""
valid = {"yes": True, "y": True, "ye": True,
"no": False, "n": False}
if default is None:
prompt = " [y/n] "
elif default == "yes":
prompt = " [Y/n] "
elif default == "no":
prompt = " [y/N] "
else:
raise ValueError("invalid default answer: '%s'" % default)
while True:
sys.stdout.write(question + prompt)
choice = input().lower()
if default is not None and choice == '':
return valid[default]
elif choice in valid:
return valid[choice]
else:
sys.stdout.write("Please respond with 'yes' or 'no' "
"(or 'y' or 'n').\n")
model = vgg19.VGG19(weights="imagenet", include_top=False, pooling="avg")
@ -24,15 +58,35 @@ image_paths = [
]
features = []
for image_path in image_paths:
img_feature = extract_features(image_path, model)
features.append(img_feature)
if os.path.exists("image_index.bin"):
if query_yes_no("Load the index?", default="yes"):
index = faiss.read_index("image_index.bin")
else:
for image_path in image_paths:
img_feature = extract_features(image_path, model)
features.append(img_feature)
features = np.array(features)
features = np.array(features)
d = features.shape[1]
index = faiss.IndexFlatL2(d)
index.add(features)
d = features.shape[1]
index = faiss.IndexFlatL2(d)
index.add(features)
if query_yes_no("Save the index?", default="yes"):
faiss.write_index(index, "image_index.bin")
else:
for image_path in image_paths:
img_feature = extract_features(image_path, model)
features.append(img_feature)
features = np.array(features)
d = features.shape[1]
index = faiss.IndexFlatL2(d)
index.add(features)
if query_yes_no("Save the index?", default="yes"):
faiss.write_index(index, "image_index.bin")
def find_similar_images(query_image_path, index, k=6):