Spaces:
Runtime error
Runtime error
Version 0.0
Browse files- app.py +112 -0
- photos.tsv000 +0 -0
- unsplash-25k-photos-embeddings.pkl +3 -0
app.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from sentence_transformers import SentenceTransformer, util
|
3 |
+
from pathlib import Path
|
4 |
+
import pickle
|
5 |
+
import requests
|
6 |
+
from PIL import Image
|
7 |
+
from io import BytesIO
|
8 |
+
import pandas as pd
|
9 |
+
from loguru import logger
|
10 |
+
import torch
|
11 |
+
|
12 |
+
# Display labels for the two search modes; main() offers them as the
# sidebar selectbox options and compares the selection against them.
T2I = "Text 2 Image"
I2I = "Image 2 Image"
|
14 |
+
def get_match(model, query, img_embs):
    """Score one query against the stored image embeddings.

    Args:
        model: Encoder exposing ``encode(items, convert_to_tensor=True)``.
        query: The item to encode (callers in this file pass either a text
            string or a PIL image). It is wrapped in a one-element list
            before being handed to the encoder.
        img_embs: Embeddings to compare the encoded query against.

    Returns:
        Whatever ``util.pytorch_cos_sim`` yields for the encoded query and
        ``img_embs`` (callers treat it as a 2-D tensor and take the top-k
        along dimension 1).
    """
    encoded_query = model.encode([query], convert_to_tensor=True)
    return util.pytorch_cos_sim(encoded_query, img_embs)
|
18 |
+
def text_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    """Render the text-to-image page and display the best matching photos.

    The user types a query and presses "Convert"; the query is scored against
    ``img_emb`` via ``get_match`` and the ``n_top_k_images`` highest-scoring
    photos are downloaded and shown side by side, one per column.

    Args:
        model: Encoder forwarded to ``get_match``.
        img_emb: Image embeddings forwarded to ``get_match``.
        img_names: Mapping from embedding index to photo ID.
        img_urls: DataFrame with ``photo_id``, ``photo_url`` and
            ``photo_image_url`` columns.
        n_top_k_images: Number of matches to retrieve and display.

    Returns:
        None. All output goes to the Streamlit page and the logger.
    """
    st.title("Text to Image")
    st.write("This is the text to image mode. Enter a text to be converted to an image")
    text = st.text_input("Enter the text to be converted to an image")
    # The button is only rendered once a query has been typed, and nothing
    # is searched until it is pressed.
    if not text or not st.button("Convert"):
        return

    st.write("The image with the most similar embedding is:")
    cosine_sim = get_match(model, text, img_emb)
    logger.info(cosine_sim.shape)
    top_k_images_indices = torch.topk(cosine_sim, n_top_k_images, 1).indices
    # Select the single query row explicitly instead of calling squeeze():
    # with k == 1, squeeze() collapses the result to a 0-d tensor whose
    # tolist() is a bare int, which cannot be iterated.
    best_indices = top_k_images_indices[0].tolist()
    logger.info(best_indices)
    images_found = [img_names[index] for index in best_indices]
    cols = st.columns(n_top_k_images)
    for i, image_found in enumerate(images_found):
        logger.success(f"Image match found: {image_found}")
        img_url_best_match = img_urls.loc[img_urls["photo_id"] == image_found]
        logger.info(img_url_best_match.photo_url)
        if len(img_url_best_match) < 1:
            st.error("No image found")
            continue
        thumbnail_url = img_url_best_match.iloc[0]["photo_image_url"] + "?w=320"
        try:
            # A timeout keeps a stalled download from hanging the whole page.
            response = requests.get(thumbnail_url, timeout=10)
            response.raise_for_status()
            # PIL raises an OSError subclass when the payload is not an image.
            matched_image = Image.open(BytesIO(response.content))
        except (requests.RequestException, OSError) as err:
            logger.error(f"Could not load image {image_found}: {err}")
            with cols[i]:
                st.error("Could not load image")
            continue
        with cols[i]:
            st.image(matched_image, caption=f"{i+1}/{n_top_k_images} most similar")
|
42 |
+
|
43 |
+
|
44 |
+
def image_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    """Render the image-to-image page and display the best matching photos.

    The user uploads a picture and presses "Convert"; the picture (converted
    to RGB) is scored against ``img_emb`` via ``get_match`` and the
    ``n_top_k_images`` highest-scoring photos are downloaded and shown side
    by side, one per column.

    Args:
        model: Encoder forwarded to ``get_match``.
        img_emb: Image embeddings forwarded to ``get_match``.
        img_names: Mapping from embedding index to photo ID.
        img_urls: DataFrame with ``photo_id``, ``photo_url`` and
            ``photo_image_url`` columns.
        n_top_k_images: Number of matches to retrieve and display.

    Returns:
        None. All output goes to the Streamlit page and the logger.
    """
    st.title("Image to Image")
    st.write("This is the image to image mode. Enter an image to be converted to an image")
    uploaded_file = st.file_uploader("Upload an image to be converted to an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    query_image = Image.open(BytesIO(uploaded_file.getvalue()))
    st.image(query_image, caption="Uploaded image")
    if not st.button("Convert"):
        return

    st.write("The image with the most similar embedding is:")
    cosine_sim = get_match(model, query_image.convert("RGB"), img_emb)
    logger.info(cosine_sim.shape)
    top_k_images_indices = torch.topk(cosine_sim, n_top_k_images, 1).indices
    # Select the single query row explicitly instead of calling squeeze():
    # with k == 1, squeeze() collapses the result to a 0-d tensor whose
    # tolist() is a bare int, which cannot be iterated.
    best_indices = top_k_images_indices[0].tolist()
    logger.info(best_indices)
    images_found = [img_names[index] for index in best_indices]
    cols = st.columns(n_top_k_images)
    for i, image_found in enumerate(images_found):
        logger.success(f"Image match found: {image_found}")
        img_url_best_match = img_urls.loc[img_urls["photo_id"] == image_found]
        logger.info(img_url_best_match.photo_url)
        if len(img_url_best_match) < 1:
            st.error("No image found")
            continue
        thumbnail_url = img_url_best_match.iloc[0]["photo_image_url"] + "?w=320"
        try:
            # A timeout keeps a stalled download from hanging the whole page.
            response = requests.get(thumbnail_url, timeout=10)
            response.raise_for_status()
            # Kept in its own name so the uploaded query image is not shadowed.
            # PIL raises an OSError subclass when the payload is not an image.
            matched_image = Image.open(BytesIO(response.content))
        except (requests.RequestException, OSError) as err:
            logger.error(f"Could not load image {image_found}: {err}")
            with cols[i]:
                st.error("Could not load image")
            continue
        with cols[i]:
            st.image(matched_image, caption=f"{i+1}/{n_top_k_images} most similar")
|
70 |
+
|
71 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(name):
    """Build the SentenceTransformer identified by ``name``.

    The function is wrapped in ``st.cache`` so the result is cached by
    Streamlit rather than rebuilt on every call with the same ``name``.

    Args:
        name: Model identifier passed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(name)
|
77 |
+
|
78 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_embeddings(filename):
    """Load the pickled image names and embeddings from ``filename``.

    Progress is reported in the Streamlit sidebar before and after the load.
    The function is wrapped in ``st.cache``.

    Args:
        filename: Path of a pickle file whose payload unpacks into exactly
            two items: the image names and the image embeddings.

    Returns:
        A ``(img_names, img_emb)`` tuple as stored in the pickle.
    """
    sidebar = st.sidebar
    sidebar.info("Loading Unsplash-Lite image embeddings")
    with open(filename, "rb") as pickle_file:
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only point this at an embeddings file from a trusted source.
        names, embeddings = pickle.load(pickle_file)
    sidebar.success("Images embeddings loaded")
    return names, embeddings
|
85 |
+
|
86 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_image_url_list(filename):
    """Read the tab-separated photo table at ``filename``.

    The first row of the file is used as the column header. The function is
    wrapped in ``st.cache``.

    Args:
        filename: Path of the tab-separated file to read.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(filename, sep="\t", header=0)
|
90 |
+
|
91 |
+
def main():
    """Build the Streamlit page: load resources, draw settings, run a mode.

    Loads the CLIP model, the photo URL table and the pickled embeddings,
    draws the sidebar settings (search mode and number of results) and then
    hands control to the page function for the selected mode.
    """
    st.title("CLIP Image Search")
    model = load_model("clip-ViT-B-32")
    st.write("Select the mode to search for a match in Unsplash (thumbnail size) dataset. text2image mode needs a text as input and outputs the image with the most similar embedding (following cosine similarity). The Image to image mode is similar, but an input image is used instead of a text query")

    embeddings_path = Path("./unsplash-25k-photos-embeddings.pkl")
    img_urls = load_image_url_list("./photos.tsv000")
    raw_names, img_emb = load_embeddings(embeddings_path)
    # Key each photo ID (the stored name up to its first '.') by its position,
    # which is also the index of its embedding.
    img_names = dict(enumerate(name.split(".")[0] for name in raw_names))

    sidebar = st.sidebar
    sidebar.title("Settings")
    app_mode = sidebar.selectbox("Choose the app mode", [T2I, I2I])
    n_images_to_search = sidebar.number_input("Select the number of images to search", min_value=1, max_value=6)

    # Each mode maps to the sidebar banner to show and the page function to run.
    pages = {
        T2I: ("Text to image mode", text_2_image),
        I2I: ("Image to image mode", image_2_image),
    }
    selected_page = pages.get(app_mode)
    if selected_page is not None:
        banner, render_page = selected_page
        sidebar.info(banner)
        render_page(model, img_emb, img_names, img_urls, n_images_to_search)
|
111 |
+
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|
photos.tsv000
ADDED
The diff for this file is too large to render.
See raw diff
|
|
unsplash-25k-photos-embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e60d49bd6334c29bc15fc2bc18c30b6d047a5584ad67c793eba376e95eaef8e
|
3 |
+
size 51816207
|