import gradio as gr
import torch
import pickle
import numpy as np
import pandas as pd
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
from sklearn.metrics.pairwise import cosine_similarity
import csv
from PIL import Image
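
# RCLIP Image Retrieval demo: a Gradio app that embeds a text query with the
# RCLIP dual encoder (a CLIP model fine-tuned on the ROCO dataset), compares it
# against precomputed image embeddings with cosine similarity, and displays the
# top-k matching images with their captions and scores.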

model_path = "kaveh/rclip"  # RCLIP vision-text dual encoder on the Hugging Face Hub
embeddings_file = './image_embeddings_8_clip14_cxrbert.pkl'  # precomputed image embeddings (pickle)
csv_path = "./captions.csv"  # tab-separated rows of image id and caption

def load_image_ids(csv_file):
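    """Read the captions file (tab-separated: image id, caption) and return parallel lists."""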
    ids = []
    captions = []
    with open(csv_file, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            ids.append(row[0])
            captions.append(row[1])
    return ids, captions

def load_embeddings(embeddings_file):
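    """Load the precomputed image embeddings from a pickle file."""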
    with open(embeddings_file, 'rb') as f:
        image_embeddings = pickle.load(f)
    return image_embeddings


def find_similar_images(query_embedding, image_embeddings, k=2):
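    """Rank images by cosine similarity to the query embedding; return the top-k indices and scores."""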
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeddings)
    closest_indices = np.argsort(similarities[0])[::-1][:k]
    scores = similarities[0][closest_indices].tolist()
    return closest_indices, scores


def main(query, k=2):
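    """Embed the text query with RCLIP and retrieve the k most similar images.

    Note: the model, processor, and image embeddings are reloaded on every call,
    which keeps the function self-contained but adds latency to each search.
    """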
    # Load RCLIP model
    model = VisionTextDualEncoderModel.from_pretrained(model_path)
    processor = VisionTextDualEncoderProcessor.from_pretrained(model_path)

    # Load image embeddings 
    image_embeddings = load_embeddings(embeddings_file)

    # Embed the query
    inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs)[0].numpy()
    
    # Get image names
    ids, captions = load_image_ids(csv_path)
    
    # Find similar images
    k = int(k)  # the Gradio slider may pass a float; range() and indexing below need an int
    similar_image_indices, scores = find_similar_images(query_embedding, image_embeddings, k=k)

    # Return the results
    similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices]
    similar_image_captions = [captions[index] for index in similar_image_indices]
    similar_images = [Image.open(i) for i in similar_image_names]

    results = pd.DataFrame(
        [list(range(1, k + 1)), similar_image_names, similar_image_captions, scores],
        index=["#", "path", "caption", "score"],
    ).T
    return similar_images, results


# Define the Gradio interface
examples = [
    ["Chest X-ray photos", 5],
    ["Orthopantogram (OPG)", 5],
    ["Brain Scan", 5],
    ["tomography", 5],
]

title = "RCLIP Image Retrieval"
description = "CLIP model fine-tuned on the ROCO dataset"

with gr.Blocks(title=title) as demo:
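    # Header row: page title and description on the left, Teesside University logo on the right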
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# "+title)
            gr.Markdown(description)
        gr.HTML(value="<img src=\"https://newresults.co.uk/wp-content/uploads/2022/02/teesside-university-logo.png\" alt=\"teesside logo\" width=\"120\" height=\"70\">", show_label=False, scale=1)
        #Image.open("./data/teesside university logo.png"), height=70, show_label=False, container=False)
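    # Search controls: query textbox, search button, and a slider for the number of results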
    with gr.Column(variant="compact"):
        with gr.Row(variant="compact"):
            query = gr.Textbox(label="Enter your query", show_label=False, placeholder="Enter your query", scale=5)
            btn = gr.Button("Search query", variant="primary", scale=1)
        
        n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True)
        
    
    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="auto", preview=True)
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()
    btn.click(main, [query, n_s], [gallery, df])
    
    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, n_s])
    
    
demo.launch(debug=True)