diff --git a/Python/similar-images-classification/main.py b/Python/similar-images-classification/main.py index 2cd3eda..a0c1406 100644 --- a/Python/similar-images-classification/main.py +++ b/Python/similar-images-classification/main.py @@ -6,6 +6,40 @@ from tensorflow.keras.applications.vgg19 import preprocess_input import faiss import cv2 + +def query_yes_no(question, default="yes"): + """Ask a yes/no question via raw_input() and return their answer. + + "question" is a string that is presented to the user. + "default" is the presumed answer if the user just hits . + It must be "yes" (the default), "no" or None (meaning + an answer is required of the user). + + The "answer" return value is True for "yes" or False for "no". + """ + valid = {"yes": True, "y": True, "ye": True, + "no": False, "n": False} + if default is None: + prompt = " [y/n] " + elif default == "yes": + prompt = " [Y/n] " + elif default == "no": + prompt = " [y/N] " + else: + raise ValueError("invalid default answer: '%s'" % default) + + while True: + sys.stdout.write(question + prompt) + choice = input().lower() + if default is not None and choice == '': + return valid[default] + elif choice in valid: + return valid[choice] + else: + sys.stdout.write("Please respond with 'yes' or 'no' " + "(or 'y' or 'n').\n") + + model = vgg19.VGG19(weights="imagenet", include_top=False, pooling="avg") @@ -24,15 +58,35 @@ image_paths = [ ] features = [] -for image_path in image_paths: - img_feature = extract_features(image_path, model) - features.append(img_feature) +if os.path.exists("image_index.bin"): + if query_yes_no("Load the index?", default="yes"): + index = faiss.read_index("image_index.bin") + else: + for image_path in image_paths: + img_feature = extract_features(image_path, model) + features.append(img_feature) -features = np.array(features) + features = np.array(features) -d = features.shape[1] -index = faiss.IndexFlatL2(d) -index.add(features) + d = features.shape[1] + index = faiss.IndexFlatL2(d) + index.add(features) + + if query_yes_no("Save the index?", default="yes"): + faiss.write_index(index, "image_index.bin") +else: + for image_path in image_paths: + img_feature = extract_features(image_path, model) + features.append(img_feature) + + features = np.array(features) + + d = features.shape[1] + index = faiss.IndexFlatL2(d) + index.add(features) + + if query_yes_no("Save the index?", default="yes"): + faiss.write_index(index, "image_index.bin") def find_similar_images(query_image_path, index, k=6):