edugp commited on
Commit
8842af8
·
1 Parent(s): 0338241

Add gradio app for caption scoring

Browse files
Files changed (2) hide show
  1. app.py +63 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import gradio as gr
4
+ import jax
5
+ from huggingface_hub import snapshot_download
6
+ from PIL import Image
7
+ from transformers import AutoTokenizer
8
+
9
+ LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
10
+ sys.path.append(LOCAL_PATH)
11
+
12
+ from modeling_hybrid_clip import FlaxHybridCLIP
13
+ from test_on_image import prepare_image, prepare_text
14
+
15
+
16
+ def save_file_to_disk(uplaoded_file):
17
+ temp_file = "/tmp/image.jpeg"
18
+ im = Image.fromarray(uplaoded_file)
19
+ im.save(temp_file)
20
+ # with open(temp_file, "wb") as f:
21
+ # f.write(uploaded_file.getbuffer())
22
+ return temp_file
23
+
24
+
25
+ def run_inference(image_path, text, model, tokenizer):
26
+ pixel_values = prepare_image(image_path, model)
27
+ input_text = prepare_text(text, tokenizer)
28
+ model_output = model(
29
+ input_text["input_ids"],
30
+ pixel_values,
31
+ attention_mask=input_text["attention_mask"],
32
+ train=False,
33
+ return_dict=True,
34
+ )
35
+ logits = model_output["logits_per_image"]
36
+ score = jax.nn.sigmoid(logits)[0][0]
37
+ return score
38
+
39
+
40
+ def load_tokenizer_and_model():
41
+ # load the saved model
42
+ tokenizer = AutoTokenizer.from_pretrained(
43
+ "bertin-project/bertin-roberta-base-spanish"
44
+ )
45
+ model = FlaxHybridCLIP.from_pretrained(LOCAL_PATH)
46
+ return tokenizer, model
47
+
48
+
49
+ tokenizer, model = load_tokenizer_and_model()
50
+
51
+
52
+ def score_image_caption_pair(uploaded_file, text_input):
53
+ local_image_path = save_file_to_disk(uploaded_file)
54
+ score = run_inference(
55
+ local_image_path, text_input, model, tokenizer).tolist()
56
+ return {"Score": score}
57
+
58
+
59
+ image = gr.inputs.Image(shape=(299, 299))
60
+ iface = gr.Interface(
61
+ fn=score_image_caption_pair, inputs=[image, "text"], outputs="label"
62
+ )
63
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ flax==0.3.4
2
+ gradio==2.2.2
3
+ huggingface-hub==0.0.12
4
+ jax==0.2.17
5
+ streamlit==0.84.1
6
+ torch==1.9.0
7
+ torchvision==0.10.0
8
+ transformers==4.8.2
9
+ watchdog==2.1.3