import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import clip
import gradio as gr
from PIL import Image

from utils import *

# Use the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the OpenAI CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)

# Download the precomputed Unsplash data from the GitHub releases
Path("unsplash-dataset").mkdir(exist_ok=True)

if not Path('unsplash-dataset/photo_ids.csv').exists():
    os.system('wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/photo_ids.csv -O unsplash-dataset/photo_ids.csv')

if not Path('unsplash-dataset/features.npy').exists():
    os.system('wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/features.npy -O unsplash-dataset/features.npy')

# Load the photo IDs
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the precomputed feature vectors
photo_features = np.load("unsplash-dataset/features.npy")

# Convert the features to tensors: float32 on CPU and float16 on GPU
if device == "cpu":
    photo_features = torch.from_numpy(photo_features).float().to(device)
else:
    photo_features = torch.from_numpy(photo_features).to(device)

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")


def encode_search_query(net, search_query):
    # Tokenize, encode and normalize the search query using CLIP
    with torch.no_grad():
        tokenized_query = clip.tokenize(search_query)
        text_encoded = net.encode_text(tokenized_query.to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Return the normalized text feature vector
    return text_encoded


def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    # Compute the cosine similarity between the search query and each photo
    # (both sets of vectors are already normalized)
    similarities = (photo_features @ text_features.T).squeeze(1)

    # Sort the photos by their similarity score
    best_photo_idx = (-similarities).argsort()

    # Return the photo IDs of the best matches
    return [photo_ids[i] for i in best_photo_idx[:results_count]]


def search_unsplash(net, search_query, photo_features, photo_ids, results_count=10):
    # Encode the search query
    text_features = encode_search_query(net, search_query)

    # Find the best matches
    return find_best_matches(text_features, photo_features, photo_ids, results_count)


def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None, photo_weight=0.5):
    # Nothing to search for
    if not query_text and query_photo is None and not query_photo_id:
        return []

    # Encode the text query
    text_features = encode_search_query(model, query_text)

    if query_photo_id:
        # Look up the feature vector of the selected photo by its ID
        query_photo_index = photo_ids.index(query_photo_id)
        query_photo_features = photo_features[query_photo_index]

        # Combine the text and photo queries and normalize again
        search_features = text_features + query_photo_features * photo_weight
        search_features /= search_features.norm(dim=-1, keepdim=True)

        # Find the best matches
        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
    elif query_photo is not None:
        # Preprocess the uploaded photo and encode it with CLIP
        query_photo = preprocess(query_photo).unsqueeze(0).to(device)
        with torch.no_grad():
            query_photo_features = model.encode_image(query_photo)
            query_photo_features /= query_photo_features.norm(dim=-1, keepdim=True)

        # Combine the text and photo queries and normalize again
        search_features = text_features + query_photo_features * photo_weight
        search_features /= search_features.norm(dim=-1, keepdim=True)

        # Find the best matches
        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
    else:
        # Text-only search
        best_photo_ids = search_unsplash(model, query_text, photo_features, photo_ids, 10)

    return best_photo_ids


def fn_query_on_load():
    return "Dogs playing during sunset"


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # CLIP Image Search Engine!
            ### Enter a search query and/or select an image to find similar images
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
        with gr.Column(visible=True) as input_image_col:
            search_image = gr.Image(label='Select from results', type='filepath', interactive=False)
            search_image_id = gr.State(None)

    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results', value=[], columns=5, rows=2)
        output_image_ids = gr.State([])

    def clear_data():
        return {
            search_image: None,
            output_images: None,
            search_text: None,
            search_image_id: None,
            input_image_col: gr.update(visible=True)
        }

    clear_btn.click(clear_data, None, [search_image, output_images, search_text, search_image_id, input_image_col])

    def on_select(evt: gr.SelectData, output_image_ids):
        # Show the selected result as the new image query and remember its Unsplash ID
        return {
            search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=320",
            search_image_id: output_image_ids[evt.index],
            input_image_col: gr.update(visible=True)
        }

    output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])

    def func_search(query, img, img_id):
        best_photo_ids = []
        if img_id:
            best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
        elif img is not None:
            img = Image.open(img)
            best_photo_ids = search_by_text_and_photo(query, query_photo=img)
        elif query:
            best_photo_ids = search_by_text_and_photo(query)

        if len(best_photo_ids) == 0:
            print("Invalid search request")
            return {
                output_image_ids: [],
                output_images: []
            }
        else:
            # Build the download URLs for the matching photos and drop any that are no longer available
            img_urls = [f"https://unsplash.com/photos/{p_id}/download?w=20" for p_id in best_photo_ids]
            valid_images = filter_invalid_urls(img_urls, best_photo_ids)

            return {
                output_image_ids: valid_images['image_ids'],
                output_images: valid_images['image_urls']
            }

    submit_btn.click(
        func_search,
        [search_text, search_image, search_image_id],
        [output_images, output_image_ids]
    )

    def on_upload():
        # A freshly uploaded image has no Unsplash ID
        return {
            search_image_id: None
        }

    search_image.upload(on_upload, None, search_image_id)

# Launch the app
app.launch()
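# ---------------------------------------------------------------------------
# Note on the `utils` dependency: `from utils import *` above must provide at
# least `filter_invalid_urls`, which is not shown in this file. The commented
# sketch below is an assumption, not the original implementation; it simply
# keeps the URL/ID pairs whose download URL still responds successfully, and
# returns them under the 'image_urls' and 'image_ids' keys expected by
# func_search above.
#
#   import requests
#
#   def filter_invalid_urls(image_urls, image_ids):
#       """Keep only the URL/ID pairs whose URL answers with a non-error status."""
#       valid = {"image_urls": [], "image_ids": []}
#       for url, photo_id in zip(image_urls, image_ids):
#           try:
#               if requests.head(url, allow_redirects=True, timeout=5).ok:
#                   valid["image_urls"].append(url)
#                   valid["image_ids"].append(photo_id)
#           except requests.RequestException:
#               continue
#       return valid
# ---------------------------------------------------------------------------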