Spaces:
Running
Running
import numpy as np | |
import faiss | |
import torch | |
from torchvision.transforms import ( | |
Compose, | |
Resize, | |
ToTensor, | |
Normalize, | |
InterpolationMode, | |
CenterCrop, | |
) | |
from PIL import Image | |
import gradio as gr | |
print("starting...")
# embs.npz must contain exactly one array: the gallery embeddings, one row per
# line of urls.txt (the row count is checked below). Using the NpzFile as a
# context manager closes the underlying zip archive once the array is read.
with np.load("embs.npz") as _npz:
    (ys,) = _npz.values()
print("loaded embs")
# The checkpoint is used as a callable module (model.eval() / model(x)), not a
# state_dict, so it has to be fully unpickled. torch >= 2.6 defaults to
# weights_only=True, which refuses such checkpoints, hence the explicit flag.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
model = torch.load(
    "style-extractor-v0.3.0.ckpt",
    map_location="cpu",
    weights_only=False,
)
print("loaded extractor")
with open("urls.txt", encoding="utf-8") as f:
    urls = f.read().splitlines()
print("loaded urls")
# Search results are mapped back to URLs by row index, so the two sources must
# line up one-to-one. Raise instead of assert so the check survives `python -O`.
if len(urls) != len(ys):
    raise ValueError(
        f"urls.txt has {len(urls)} lines but embs.npz has {len(ys)} embeddings"
    )
d = ys.shape[1]
# HNSW graph (32 links per node) over uncompressed vectors, using faiss'
# default L2 metric, for approximate nearest-neighbour search.
index = faiss.IndexHNSWFlat(d, 32)
print("building index")
index.add(ys)
print("index built")
def _aspect_error(a, b):
    """Return 1 - min(a, b) / max(a, b): 0 for equal ratios, approaching 1 as they diverge."""
    return 1 - min(a, b) / max(a, b)


def _patch_grid(w, h, area, d):
    """Choose the output size ``(width, height)`` in pixels for a ``w`` x ``h`` image.

    Both sides are positive multiples of ``d`` and their product is at most
    ``area``. The area-preserving size is rounded down or up (in units of
    ``d``) on each side. Among those grids that fit the budget, the one whose
    aspect ratio is closest to ``w / h`` wins. Ties go to the first candidate
    in (down, down), (down, up), (up, down), (up, up) order.

    For extremely elongated images the area-preserving size rounds to zero
    patches along the short side. The grid then falls back to one patch on
    that side and as many patches as fit in ``area`` on the long side (at
    least one), instead of dividing by zero or producing a 0-pixel side.
    """
    s = (area / w / h) ** 0.5
    wd, hd = int(s * w / d), int(s * h / d)
    candidates = [
        (ww * d, hh * d)
        for ww in (wd, wd + 1)
        for hh in (hd, hd + 1)
        # Zero-patch sides are unusable (0-pixel resize, division by zero).
        if ww > 0 and hh > 0 and ww * d * hh * d <= area
    ]
    if not candidates:
        n_long = max(1, int(area // (d * d)))
        return (n_long * d, d) if w >= h else (d, n_long * d)
    return min(candidates, key=lambda wh: _aspect_error(wh[0] / wh[1], w / h))


def MyResize(area, d):
    """Build a transform that resizes a PIL image to at most ``area`` pixels.

    The returned callable keeps the aspect ratio as closely as possible while
    snapping both sides to multiples of ``d`` (see ``_patch_grid``). It
    bicubically resizes the image so that it covers the target size on both
    sides, then center-crops to exactly that size.

    Args:
        area: Upper bound on output width * height, in pixels.
        d: Granularity of each output side, in pixels (presumably the style
            extractor's ViT patch size -- TODO confirm).

    Returns:
        A function ``f(im)`` that applies the resize + center crop to a PIL image.
    """

    def f(im: "Image.Image"):
        w, h = im.size
        wd, hd = _patch_grid(w, h, area, d)
        # Scale by the larger per-side factor so the resized image covers the
        # target on both sides; the crop then trims the excess. The max()
        # guards against float truncation leaving a side one pixel short,
        # which would make CenterCrop zero-pad instead of crop.
        if wd / w > hd / h:
            size = (max(hd, int(h * wd / w)), wd)
        else:
            size = (hd, max(wd, int(w * hd / h)))
        return Compose(
            [
                Resize(size, InterpolationMode.BICUBIC),
                CenterCrop((hd, wd)),
            ]
        )(im)

    return f
# Query preprocessing:
# - aspect-preserving resize to at most (518 * 1.3) ** 2 pixels, with both
#   sides snapped to multiples of 14 (presumably the extractor's patch size);
# - conversion to a tensor;
# - normalisation with the standard ImageNet channel mean/std.
tf = Compose(
    [
        MyResize(area=(518 * 1.3) ** 2, d=14),
        ToTensor(),
        Normalize(
            mean=[0.4850, 0.4560, 0.4060],
            std=[0.2290, 0.2240, 0.2250],
        ),
    ]
)
def get_emb(im: Image):
    """Embed a single query image with the style extractor.

    The image is preprocessed with the module-level ``tf`` and given a leading
    batch dimension of size 1. It is then passed through ``model`` in eval
    mode with autograd disabled. The model output is returned as-is
    (presumably a (1, d) tensor -- TODO confirm).
    """
    # Switch layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    with torch.no_grad():
        batch = tf(im)[None]
        emb = model(batch)
    return emb
n_outputs = 50  # number of search results / Markdown slots in the UI
row_size = 5  # result slots per UI row


def f(im):
    """Search the index for the gallery images whose style is closest to ``im``.

    Args:
        im: Query image from the Gradio ``Image`` component, or ``None`` when
            the search button is pressed before an image has been uploaded.

    Returns:
        Exactly ``n_outputs`` Markdown strings, one per result slot. Each hit
        shows the faiss distance and the matching image. Slots without a hit
        are empty strings.
    """
    if im is None:
        # No query yet: clear every slot instead of crashing in preprocessing.
        return [""] * n_outputs
    D, I = index.search(get_emb(im), n_outputs)
    # faiss reports missing neighbours with label -1; indexing urls[-1] would
    # silently show the last gallery image as a bogus hit, so skip them.
    hits = [
        f"Distance: {d:.1f}\n![]({urls[i]})"
        for d, i in zip(D[0], I[0])
        if i >= 0
    ]
    # Gradio assigns one returned value per output component, so always
    # return exactly n_outputs items.
    return hits + [""] * (n_outputs - len(hits))
print("preparing gradio")
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)"
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")
    # Lay out n_outputs result slots in rows of row_size; the final row holds
    # whatever remains. Slots are labelled #1, #2, ... in creation order.
    outputs = []
    for row_start in range(0, n_outputs, row_size):
        row_end = min(row_start + row_size, n_outputs)
        with gr.Row():
            for slot in range(row_start, row_end):
                outputs.append(gr.Markdown(label=f"#{slot + 1}"))
    # The search callback returns one Markdown string per slot, in slot order.
    btn.click(f, img, outputs)
print("starting gradio")
demo.launch()