import csv
import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel  # currently unused
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

model_path = "kaveh/rclip"
embeddings_file = './image_embeddings_8_clip14_cxrbert.pkl'
csv_path = "./captions.csv"


def load_image_ids(csv_file):
    """Read image ids and captions from a tab-separated file.

    Column 0 of every row is the image id and column 1 is its caption.

    Args:
        csv_file: path to the tab-separated captions file.

    Returns:
        A tuple ``(ids, captions)`` of two parallel lists of strings, in file order.
    """
    ids = []
    captions = []
    # newline='' is what the csv module requires; explicit UTF-8 so captions
    # decode identically regardless of the host's locale.
    with open(csv_file, 'r', encoding='utf-8', newline='') as f:
        for row in csv.reader(f, delimiter='\t'):
            ids.append(row[0])
            captions.append(row[1])
    return ids, captions


def load_embeddings(embeddings_file):
    """Unpickle and return the precomputed image embeddings.

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only point this
    at a file you produced or otherwise trust.
    """
    with open(embeddings_file, 'rb') as f:
        return pickle.load(f)


def find_similar_images(query_embedding, image_embeddings, k=2):
    """Rank stored image embeddings by cosine similarity to a query embedding.

    Args:
        query_embedding: a single embedding vector; it is reshaped to one row.
        image_embeddings: embeddings to search, one per row (presumably shape
            ``(n_images, dim)`` with row i matching row i of the captions
            file -- TODO confirm against the embedding export script).
        k: maximum number of results to return.

    Returns:
        A tuple ``(indices, scores)``, best match first. ``indices`` is a numpy
        array of row positions and ``scores`` is the list of matching cosine
        similarities. Fewer than ``k`` entries are returned if there are fewer rows.
    """
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeddings)[0]
    closest_indices = np.argsort(similarities)[::-1][:k]
    # Read the scores at the chosen positions instead of sorting a second time;
    # the values are the same top-k similarities, in the same order.
    scores = similarities[closest_indices].tolist()
    return closest_indices, scores


@functools.lru_cache(maxsize=1)
def _load_resources():
    """Load the model, processor, embeddings and captions once and cache them.

    Every query used to repeat this expensive I/O; the cache keeps the
    first result for the lifetime of the process.
    """
    model = VisionTextDualEncoderModel.from_pretrained(model_path)
    processor = VisionTextDualEncoderProcessor.from_pretrained(model_path)
    image_embeddings = load_embeddings(embeddings_file)
    ids, captions = load_image_ids(csv_path)
    return model, processor, image_embeddings, ids, captions


def main(query, k=2):
    """Retrieve the top-k images for a free-text query.

    Args:
        query: search text entered by the user.
        k: number of results to return. It is cast with ``int()`` because the
            Gradio slider feeding it is configured with ``step=1.0``.

    Returns:
        A tuple ``(images, table)``. ``images`` is the list of opened PIL images
        for the gallery. ``table`` is a DataFrame with columns ``#``, ``path``,
        ``caption`` and ``score``, with one row per result.
    """
    # Cast once so the search and the result table agree on k.
    k = int(k)
    model, processor, image_embeddings, ids, captions = _load_resources()

    # Embed the query text.
    inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs)[0].numpy()

    similar_image_indices, scores = find_similar_images(query_embedding, image_embeddings, k=k)

    similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices]
    similar_image_captions = [captions[index] for index in similar_image_indices]
    similar_images = [Image.open(path) for path in similar_image_names]

    # Build the table column-wise so each column keeps a proper dtype. Rank
    # numbering follows the actual number of hits, which may be fewer than k.
    table = pd.DataFrame({
        "#": range(1, len(similar_image_names) + 1),
        "path": similar_image_names,
        "caption": similar_image_captions,
        "score": scores,
    })
    return similar_images, table


# Define the Gradio interface
examples = [
    ["Chest X-ray photos", 5],
    ["Orthopantogram (OPG)", 5],
    ["Brain Scan", 5],
    ["tomography", 5],
]

title = "RCLIP Image Retrieval"
description = "CLIP model fine-tuned on the ROCO dataset"

with gr.Blocks(title=title) as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# " + title)
            gr.Markdown(description)
        # NOTE(review): this HTML value looks truncated (probably an <img> logo
        # tag that lost its markup); restore the intended markup.
        gr.HTML(value="\"teesside", show_label=False, scale=1)
    with gr.Column(variant="compact"):
        with gr.Row(variant="compact"):
            query = gr.Textbox(label="Enter your query", show_label=False, placeholder="Enter your query", scale=5)
            btn = gr.Button("Search query", variant="primary", scale=1)
        n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True)
    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="auto", preview=True)
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()
    btn.click(main, [query, n_s], [gallery, df])
    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, n_s])

if __name__ == "__main__":
    # debug expects a bool; the original passed the string 'True'.
    demo.launch(debug=True)