Spaces:
Sleeping
Sleeping
import pandas as pd | |
import numpy as np | |
import clip | |
import gradio as gr | |
from utils import * | |
import os | |
from pathlib import Path

from PIL import Image

# Load the open CLIP model together with its matching image preprocessing
# pipeline. `device` is not defined in this file; presumably it comes from the
# star import of `utils` -- TODO confirm.
model, preprocess = clip.load("ViT-B/32", device=device)

# Download the precomputed dataset files from GitHub Releases when missing.
# wget does not create the parent directory of its -O target, so make sure it
# exists before the first download.
Path('unsplash-dataset').mkdir(parents=True, exist_ok=True)
if not Path('unsplash-dataset/photo_ids.csv').exists():
    os.system('''wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/photo_ids.csv -O unsplash-dataset/photo_ids.csv''')
if not Path('unsplash-dataset/features.npy').exists():
    # The output option has to be the single token "-O"; written as "- O" it is
    # not parsed as an option and the file never lands at the expected path.
    os.system('''wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/features.npy -O unsplash-dataset/features.npy''')

# Load the photo IDs; list position i corresponds to row i of the feature matrix
# (find_best_matches indexes photo_ids with row indices of photo_features).
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the feature vectors
photo_features = np.load("unsplash-dataset/features.npy")

# Convert features to Tensors: Float32 on CPU and Float16 on GPU
# (on GPU the stored dtype is kept as-is; presumably the file holds float16).
photo_features = torch.from_numpy(photo_features)
if device == "cpu":
    photo_features = photo_features.float()
photo_features = photo_features.to(device)

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")
def encode_search_query(net, search_query):
    """Encode a text query into a length-normalised CLIP text embedding.

    Args:
        net: Model object exposing ``encode_text`` (the loaded CLIP model).
        search_query: Query text passed to ``clip.tokenize``.

    Returns:
        The text features returned by ``net.encode_text``, divided by their
        norm along the last dimension.
    """
    with torch.no_grad():
        # Queries are typed by users; truncate over-long text to CLIP's context
        # length instead of letting clip.tokenize raise.
        tokenized_query = clip.tokenize(search_query, truncate=True)

        # Encode and normalize the search query using CLIP
        text_encoded = net.encode_text(tokenized_query.to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the feature vector
    return text_encoded
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Rank photos against a query embedding and return the top photo IDs.

    Args:
        text_features: Query features with a single row (the matrix product
            with ``photo_features`` must yield exactly one column).
        photo_features: One feature row per photo.
        photo_ids: Sequence of IDs, aligned with the rows of ``photo_features``.
        results_count: Maximum number of IDs to return.

    Returns:
        A list of photo IDs ordered from highest to lowest score.
    """
    # Dot product of every photo row with the query; this equals the cosine
    # similarity when both sides are already normalised.
    scores = photo_features @ text_features.T
    scores = scores.squeeze(1)

    # Ascending sort of the negated scores == descending sort of the scores.
    ranking = (-scores).argsort()
    top_indices = ranking[:results_count]

    return [photo_ids[idx] for idx in top_indices]
def search_unslash(net, search_query, photo_features, photo_ids, results_count=10):
    """Run a text-only search: encode the query, then rank the photos.

    Args:
        net: Model passed through to ``encode_search_query``.
        search_query: Query text.
        photo_features: Photo feature rows to search over.
        photo_ids: IDs aligned with ``photo_features``.
        results_count: Maximum number of IDs to return.

    Returns:
        The list produced by ``find_best_matches`` for the encoded query.
    """
    return find_best_matches(
        encode_search_query(net, search_query),
        photo_features,
        photo_ids,
        results_count,
    )
def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None, photo_weight=0.5):
    """Search the loaded photos by text, optionally blended with a reference photo.

    The reference photo is taken from ``query_photo_id`` when given, otherwise
    from ``query_photo``. Its features are scaled by ``photo_weight``, added to
    the text features, and the sum is re-normalised before ranking.

    Args:
        query_text: Query text; may be empty or None when a photo is supplied.
        query_photo: Optional image accepted by the CLIP ``preprocess``
            transform (a PIL image). Ignored when ``query_photo_id`` is set.
        query_photo_id: Optional ID of a photo contained in ``photo_ids``.
        photo_weight: Factor applied to the photo features before mixing.

    Returns:
        Up to 10 photo IDs, best match first. An empty list when no query was
        given at all or when ``query_photo_id`` is not a known photo ID.
    """
    if not query_text and query_photo is None and not query_photo_id:
        return []

    # clip.tokenize cannot handle None; an absent text is encoded like "".
    text_features = encode_search_query(model, query_text or "")

    if query_photo_id:
        # Find the feature vector for the specified photo ID
        try:
            query_photo_index = photo_ids.index(query_photo_id)
        except ValueError:
            # Unknown ID: report "no results" instead of crashing the handler.
            return []
        query_photo_features = photo_features[query_photo_index]
    elif query_photo is not None:
        # preprocess() already yields a channel-first tensor; the model only
        # needs a leading batch dimension and the tensor on the model's device.
        image_batch = preprocess(query_photo).unsqueeze(0).to(device)
        with torch.no_grad():
            query_photo_features = model.encode_image(image_batch)
        query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
    else:
        # Text-only search: reuse the features encoded above rather than
        # encoding the same query a second time.
        print("Result...")
        return find_best_matches(text_features, photo_features, photo_ids, 10)

    # Combine the text and photo queries and normalize again
    search_features = text_features + query_photo_features * photo_weight
    search_features /= search_features.norm(dim=-1, keepdim=True)

    # Find the best match
    return find_best_matches(search_features, photo_features, photo_ids, 10)
def fn_query_on_load():
    """Return the example query used as the search box's initial value."""
    default_query = "Dogs playing during sunset"
    return default_query
def _as_pil_image(img):
    """Turn the value delivered by a ``gr.Image`` input into a PIL image."""
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        # gr.Image hands handlers a numpy array unless configured otherwise.
        return Image.fromarray(img)
    # Anything else is treated as a path / file object, as the code did before.
    return Image.open(img)


# NOTE(review): the nesting of rows/columns below was reconstructed from a
# copy of the file that had lost its indentation -- confirm against the
# intended layout.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # CLIP Image Search Engine!
            ### Enter search query or/and select image to find the similar images
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()

        with gr.Column(visible=True) as input_image_col:
            search_image = gr.Image(label='Select from results', interactive=False)
            # ID of the photo currently shown in search_image (None if unset).
            search_image_id = gr.State(None)

    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results.. ',
                                   value=[], columns=5, rows=2)
        # Photo IDs in the same order as the gallery entries.
        output_image_ids = gr.State([])

    def clear_data():
        """Reset the query image, gallery, text box and selected photo ID."""
        return {
            search_image: None,
            output_images: None,
            search_text: None,
            search_image_id: None,
            input_image_col: gr.update(visible=True)
        }

    clear_btn.click(clear_data, None, [search_image, output_images, search_text, search_image_id, input_image_col])

    def on_select(evt: gr.SelectData, output_image_ids):
        """Use the gallery entry that was clicked as the new query image."""
        selected_id = output_image_ids[evt.index]
        return {
            search_image: f"https://unsplash.com/photos/{selected_id}/download?w=320",
            search_image_id: selected_id,
            input_image_col: gr.update(visible=True)
        }

    output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])

    def func_search(query, img, img_id):
        """Run the search for the current inputs and fill the gallery.

        A selected photo ID takes precedence over an image value, which takes
        precedence over a text-only query. Returns empty results when nothing
        usable was provided or nothing was found.
        """
        if img_id:
            best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
        elif img is not None:
            best_photo_ids = search_by_text_and_photo(query, query_photo=_as_pil_image(img))
        elif query:
            best_photo_ids = search_by_text_and_photo(query)
        else:
            best_photo_ids = []

        if not best_photo_ids:
            print("Invalid Search Request")
            return {
                output_image_ids: [],
                output_images: []
            }

        # NOTE(review): w=20 requests very small thumbnails while on_select
        # uses w=320 -- confirm this width is intentional.
        img_urls = [f"https://unsplash.com/photos/{p_id}/download?w=20" for p_id in best_photo_ids]
        # filter_invalid_urls is not defined in this file (presumably from
        # utils); its result is read via the 'image_ids' and 'image_urls' keys.
        valid_images = filter_invalid_urls(img_urls, best_photo_ids)
        return {
            output_image_ids: valid_images['image_ids'],
            output_images: valid_images['image_urls']
        }

    submit_btn.click(
        func_search,
        [search_text, search_image, search_image_id],
        [output_images, output_image_ids]
    )

    def on_upload():
        """Forget the previously selected photo ID when a new image is uploaded."""
        # No event-data parameter: an upload does not carry selection data, so
        # asking for gr.SelectData here would fail when the event fires.
        return {
            search_image_id: None
        }

    search_image.upload(on_upload, None, search_image_id)
# Launch the app
app.launch()