"""Gradio demo for scribble-based image retrieval.

A user draws a scribble on a canvas. The scribble is inverted, embedded with a
ViT "scribble encoder", and compared by cosine similarity against embeddings of
a user-uploaded directory of candidate images (computed with a ViT "image
encoder"). The best matches are shown in a gallery with their similarity.
"""

from typing import Optional

import gradio as gr
import torch
from datasets import Dataset
from PIL import Image, ImageOps, UnidentifiedImageError
from torch.nn import CosineSimilarity
from transformers import ViTImageProcessor, ViTModel

# Fall back to CPU so the demo can still start on a machine without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Maximum number of matches returned to the gallery.
TOP_K = 15

image_processor = ViTImageProcessor.from_pretrained("vit-base-patch16-224")
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval().to(DEVICE)
# NOTE(review): "scibble" looks like a typo for "scribble"; kept verbatim because it
# must match the checkpoint directory on disk — confirm before renaming.
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval().to(DEVICE)

# Candidate images and their embeddings; set by load_candidates_in_cache().
candidates: Optional[Dataset] = None
# Default dim=1: compares one scribble embedding row against every candidate row.
cosinesimilarity = CosineSimilarity()


def load_candidates(candidate_dir):
    """Load candidate images and compute their image-encoder embeddings.

    Args:
        candidate_dir: Iterable of uploaded file objects (as produced by a
            ``gr.File(file_count="directory")`` component), each exposing a
            ``.name`` attribute with a filesystem path. ``None`` is treated
            as an empty selection.

    Returns:
        A ``datasets.Dataset`` with an ``image`` column (RGB, resized to
        224x224) and an ``image_embedding`` column (the encoder's
        ``pooler_output`` for that image).

    Raises:
        gr.Error: if no file in ``candidate_dir`` could be opened as an image.
    """

    def preprocess(examples):
        images = [image.convert("RGB") for image in examples["image"]]
        pixel_values = image_processor(images, return_tensors="pt")["pixel_values"].to(DEVICE)
        # Move to CPU explicitly so the Arrow-backed dataset stores host arrays.
        examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"].cpu()
        return examples

    records = []
    for uploaded in candidate_dir or []:
        try:
            # Context manager guarantees the file handle is released.
            with Image.open(uploaded.name) as image:
                records.append(dict(image=image.convert("RGB").resize((224, 224))))
        except UnidentifiedImageError:
            # Uploaded directories may contain non-image files; skip them.
            continue
    if not records:
        raise gr.Error("No readable images found in the selected directory.")

    dataset = Dataset.from_list(records)
    # Inference only: disabling autograd avoids keeping activations around.
    with torch.no_grad():
        dataset = dataset.map(preprocess, batched=True, batch_size=1024)
    return dataset


def load_candidates_in_cache(candidate_files):
    """Embed ``candidate_files`` and store the result in the module-level cache."""
    global candidates
    candidates = load_candidates(candidate_files)


def scribble_matching(input_img: Image.Image):
    """Rank the cached candidates by similarity to a drawn scribble.

    Args:
        input_img: The canvas drawing as a PIL image. It is inverted before
            encoding (presumably dark strokes on a light canvas — confirm
            against how the scribble encoder was trained).

    Returns:
        A list of ``(image, caption)`` pairs for ``gr.Gallery``. The first pair
        is the inverted scribble captioned ``"preview"``. It is followed by up
        to ``TOP_K`` candidates in descending cosine similarity, each
        captioned with its score to three decimals.

    Raises:
        gr.Error: if nothing was drawn or no candidates have been loaded yet.
    """
    if input_img is None:
        raise gr.Error("Draw a scribble first.")
    if candidates is None or len(candidates) == 0:
        raise gr.Error("Load candidate images first.")

    # ImageOps.invert only supports "L"/"RGB"; normalise in case alpha is present.
    scribble = ImageOps.invert(input_img.convert("RGB"))

    with torch.no_grad():
        pixel_values = image_processor(scribble, return_tensors="pt")["pixel_values"].to(DEVICE)
        scribble_embedding = scribble_encoder(pixel_values)["pooler_output"].to("cpu")

    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
    sim = cosinesimilarity(scribble_embedding, image_embeddings)
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    predicts = torch.topk(sim, k=min(TOP_K, sim.numel()))

    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{score:.3f}" for score in predicts.values.tolist()]
    return list(zip([scribble] + output_imgs, ["preview"] + labels))


def main():
    """Build the Gradio UI and launch the demo server (blocking)."""
    with gr.Blocks() as demo:
        with gr.Row():
            input_img = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
        with gr.Row():
            candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
        btn = gr.Button("Scribble Matching", variant="primary")

        load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
        btn.click(fn=scribble_matching, inputs=[input_img], outputs=[prediction_gallery])

    demo.launch(debug=True)


if __name__ == "__main__":
    main()