Spaces:
Sleeping
Sleeping
File size: 2,472 Bytes
04573a7 accc6c7 04573a7 a819f61 04573a7 a819f61 04573a7 accc6c7 04573a7 a819f61 04573a7 a819f61 04573a7 006e560 a819f61 04573a7 accc6c7 04573a7 accc6c7 04573a7 f615499 04573a7 accc6c7 a819f61 04573a7 1026aeb 04573a7 a819f61 accc6c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import numpy as np
import faiss
import torch
from torchvision.transforms import (
Compose,
Resize,
ToTensor,
Normalize,
InterpolationMode,
CenterCrop,
)
from PIL import Image
import gradio as gr
print("starting...")

# Pre-computed style embeddings for the database, one row per image.
# Using NpzFile as a context manager closes the archive deterministically;
# the 1-tuple unpacking requires the archive to hold exactly one array.
with np.load("embs.npz") as _embs_npz:
    (ys,) = _embs_npz.values()
print("loaded embs")

# NOTE(review): torch.load unpickles the whole checkpoint object (it is later
# called directly as a model), so only load trusted files. Recent PyTorch
# releases default to weights_only=True, which may reject a pickled module;
# pass weights_only=False there if loading fails.
model = torch.load(
    "style-extractor-v0.3.0.ckpt",
    map_location="cpu",
)
print("loaded extractor")

# One URL per line; line i is the image for row i of `ys` (the index below
# assigns ids 0..n-1 in insertion order and search results index `urls`).
with open("urls.txt", encoding="utf-8") as f:
    urls = f.read().splitlines()
print("loaded urls")

# Explicit check rather than `assert`, which is stripped under `python -O`.
if len(urls) != len(ys):
    raise ValueError(
        f"urls.txt has {len(urls)} lines but embs.npz has {len(ys)} embeddings"
    )

d = ys.shape[1]  # embedding dimensionality
# HNSW graph index over uncompressed vectors with 32 neighbours per node;
# no metric is passed, so FAISS uses its default L2 metric.
index = faiss.IndexHNSWFlat(d, 32)
print("building index")
index.add(ys)
print("index built")
def MyResize(area, d):
    """Build a transform that resizes and centre-crops a PIL image onto a grid.

    The output approximately keeps the input's aspect ratio. Both of its
    sides are positive multiples of ``d``. Its pixel area is at most
    ``area``, except for degenerate (extremely elongated) inputs, where the
    smallest grid that keeps both sides >= ``d`` is used.

    Args:
        area: Target upper bound on output width * height, in pixels.
        d: Granularity; output width and height are multiples of this.

    Returns:
        A callable that takes a PIL image and returns the resized, cropped
        PIL image.
    """

    def aspect_error(a, b):
        # Relative mismatch between two positive ratios (0 when equal).
        return 1 - min(a, b) / max(a, b)

    def _resize(im: Image.Image):
        w, h = im.size
        # Uniform scale that would make the image's area exactly `area`.
        s = (area / w / h) ** 0.5
        # Grid cells per side from flooring the scaled sides. Clamp to 1 so an
        # extremely elongated image cannot produce a zero-length side, which
        # would otherwise divide by zero in the aspect comparison below.
        wp = max(1, int(s * w / d))
        hp = max(1, int(s * h / d))
        # Try growing each side by one cell. Keep candidates within the area
        # budget and pick the one closest to the source aspect ratio. The
        # iteration order matches the original so ties resolve identically.
        candidates = [
            ((wp + i) * d, (hp + j) * d)
            for i in (0, 1)
            for j in (0, 1)
            if (wp + i) * d * (hp + j) * d <= area
        ]
        if not candidates:
            # Only reachable after clamping: nothing fits the budget, so use
            # the minimal grid instead of failing on an empty min().
            candidates = [(wp * d, hp * d)]
        wd, hd = min(candidates, key=lambda wh: aspect_error(wh[0] / wh[1], w / h))

        # Scale uniformly so the image covers the (hd, wd) crop on both axes,
        # then centre-crop. Integer cross-multiplication and floor division
        # avoid float rounding, so the resized image is never smaller than
        # the crop (CenterCrop would otherwise pad).
        if wd * h > hd * w:
            size = (h * wd // w, wd)
        else:
            size = (hd, w * hd // h)
        im = Resize(size, InterpolationMode.BICUBIC)(im)
        return CenterCrop((hd, wd))(im)

    return _resize
# Preprocessing applied to every query image before the style extractor:
#   1. resize + centre-crop to at most (518 * 1.3) ** 2 pixels, with both
#      sides multiples of 14 (see MyResize);
#   2. convert the PIL image to a float tensor;
#   3. normalise each channel with the standard ImageNet mean/std values.
tf = Compose(
    [
        MyResize(area=(518 * 1.3) ** 2, d=14),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def get_emb(im: Image.Image):
    """Return the style embedding of a single PIL image.

    The image is converted to RGB first. RGBA, greyscale and palette images
    would otherwise not match the 3-channel normalisation in ``tf``. It is
    then preprocessed with ``tf`` and passed through ``model`` as a batch of
    one, without gradient tracking.

    Returns:
        Whatever ``model`` returns for a batch of size 1. It is passed
        straight to ``index.search``, so presumably a ``(1, d)`` float
        tensor (TODO confirm).
    """
    model.eval()  # make sure layers such as dropout run in inference mode
    with torch.no_grad():
        return model(tf(im.convert("RGB")).unsqueeze(0))
n_outputs = 50  # nearest neighbours returned (and result cards shown) per query
row_size = 5  # result cards per row in the UI grid


def f(im):
    """Find the ``n_outputs`` database images whose style is closest to ``im``.

    Args:
        im: The query PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Exactly ``n_outputs`` markdown strings, one per UI output slot. Each
        shows the FAISS distance and the matching image. With the default
        L2 metric, FAISS reports squared Euclidean distances. Slots without
        a result are empty strings. This covers FAISS returning label -1
        when it finds fewer than k neighbours, which previously indexed
        ``urls[-1]`` and silently showed the last image.
    """
    if im is None:
        # The search button was clicked without a query image.
        return [""] * n_outputs
    D, I = index.search(get_emb(im), n_outputs)
    return [
        f"Distance: {dist:.1f}\n![]({urls[idx]})" if idx >= 0 else ""
        for dist, idx in zip(D[0], I[0])
    ]
print("preparing gradio")
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)"
    )
    # Query input and the button that triggers the search.
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")

    # Result grid: n_outputs markdown cards, row_size per row. The last row
    # holds the remainder when n_outputs is not a multiple of row_size.
    # Cards are numbered 1..n_outputs in creation order.
    outputs = []
    for row_start in range(0, n_outputs, row_size):
        cards_in_row = min(row_size, n_outputs - row_start)
        with gr.Row():
            outputs.extend(
                gr.Markdown(label=f"#{row_start + col + 1}")
                for col in range(cards_in_row)
            )

    # The i-th string returned by `f` is routed to the i-th card.
    btn.click(fn=f, inputs=img, outputs=outputs)
print("starting gradio")
demo.launch()
|