Spaces:
Runtime error
Runtime error
File size: 2,063 Bytes
5f011f5 8287e38 5f011f5 c3831fc 5f011f5 fbb89ca 5f011f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP model, processor and tokenizer, all from the same pretrained checkpoint.
# Only the model is moved to the selected device.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Item table; find_best_matches looks rows up by "Uniq Id" and reads "Image".
photos = pd.read_csv("./items_data.csv")
# Precomputed feature matrix that text embeddings are scored against.
photo_features = np.load("./features.npy")
# Photo ids as a plain list, indexed by row position in photo_features.
photo_ids = list(pd.read_csv("./photo_ids.csv")['photo_id'])
def find_best_matches(text):
    """Return the images whose stored features best match a text query.

    The query is embedded with the module-level CLIP ``model`` and scored
    against every row of ``photo_features`` with a dot product. The three
    highest-scoring photos (fewer if fewer exist) are downloaded and
    returned, best match first.

    Args:
        text: Free-form search query.

    Returns:
        A list of PIL images, ordered from best to worst match.

    Raises:
        requests.RequestException: If an image download fails, times out,
            or returns an HTTP error status.
        IndexError: If a matched photo id has no row in ``photos``.
    """
    top_k = 3
    request_timeout_s = 10

    # Inference
    with torch.no_grad():
        inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
        # The model may live on the GPU; its inputs must be on the same device.
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        text_features = model.get_text_features(**inputs)
        # Unit-normalise so the dot product below is a cosine similarity.
        # NOTE(review): assumes the rows of photo_features are already
        # unit-normalised - confirm against the script that produced them.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # .cpu() is required before .numpy() when the tensor is on CUDA.
        text_encoded = text_features.cpu().numpy()

    similarities = (text_encoded @ photo_features.T).squeeze(0)

    # Rank once instead of re-sorting for every result. A stable sort on the
    # negated scores keeps the lower index first among ties, matching
    # sorted(..., reverse=True) on (score, index) pairs keyed by score.
    best_indices = np.argsort(-similarities, kind="stable")[:top_k]

    matched_images = []
    for idx in best_indices:
        photo_id = photo_ids[idx]
        photo_data = photos[photos["Uniq Id"] == photo_id].iloc[0]
        # A timeout keeps a stalled image host from hanging the request.
        response = requests.get(photo_data["Image"] + "?w=640", timeout=request_timeout_s)
        # Fail with a clear HTTP error instead of trying to decode an error page.
        response.raise_for_status()
        matched_images.append(PILIMAGE.open(BytesIO(response.content)))
    return matched_images
#Gradio app
# NOTE(review): gr.inputs / gr.outputs, Carousel and enable_queue look like the
# legacy Gradio API - confirm the pinned Gradio version still provides them.
iface = gr.Interface(
    fn=find_best_matches,
    inputs=[
        gr.inputs.Textbox(
            lines=1,
            label="Text query",
            placeholder="Introduce the search text...",
        )
    ],
    outputs=gr.outputs.Carousel([gr.outputs.Image(type="pil")]),
    theme="dark",
    enable_queue=True,
)
# launch() does not return the Interface, so it is called as a separate
# statement to keep `iface` bound to the app object itself.
iface.launch()
|