AhmedSSabir committed on
Commit
c2e8ec7
·
1 Parent(s): 8e8e671

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from doctest import OutputChecker
3
+ import sys
4
+ import argparse
5
+ #import torch
6
+ import re
7
+ import os
8
+ import gradio as gr
9
+ import requests
10
+ from sentence_transformers import SentenceTransformer, util
11
+ import torch
12
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
13
+ from transformers import T5Tokenizer, AutoModelForCausalLM
14
+ import torch
15
+
16
+ from transformers import BertJapaneseTokenizer, BertModel
17
+ import torch
18
+
19
+
20
class SentenceBertJapanese:
    """Sentence encoder that mean-pools the token embeddings of a Japanese BERT model.

    Args:
        model_name_or_path: Identifier or path handed to ``from_pretrained``
            for both the tokenizer and the model.
        device: Device string such as ``"cuda"`` or ``"cpu"``. When ``None``,
            CUDA is used if it is available, otherwise the CPU.
    """

    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, counting only positions kept by the mask.

        Args:
            model_output: Model output whose first element holds all token
                embeddings.
            attention_mask: Mask tensor; it is broadcast to the shape of the
                token embeddings (assumes a 0/1 mask of shape (batch, seq) --
                TODO confirm against the tokenizer output).

        Returns:
            Tensor with the masked mean over dimension 1 of the token
            embeddings.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp the denominator so a fully masked row does not divide by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def encode(self, sentences, batch_size=8):
        """Embed ``sentences`` in batches and return one CPU tensor of embeddings.

        Args:
            sentences: Sequence of sentences; it is sliced into batches.
            batch_size: Number of sentences tokenized and encoded per step.

        Returns:
            Tensor with one row per sentence, on the CPU. For an empty
            ``sentences`` an empty ``(0, hidden_size)`` tensor is returned
            instead of failing while combining zero batches.
        """
        pooled_batches = []
        # Inference only: without no_grad each batch would keep its autograd
        # graph alive and the returned embeddings would require grad.
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                batch = sentences[start:start + batch_size]
                encoded_input = self.tokenizer.batch_encode_plus(
                    batch, padding="longest", truncation=True, return_tensors="pt"
                ).to(self.device)
                model_output = self.model(**encoded_input)
                pooled = self._mean_pooling(model_output, encoded_input["attention_mask"])
                pooled_batches.append(pooled.to("cpu"))

        if not pooled_batches:
            return torch.empty((0, self.model.config.hidden_size))
        # Concatenating the per-batch matrices yields the same rows as
        # stacking every sentence embedding individually.
        return torch.cat(pooled_batches)
52
+
53
+
54
# Sentence-embedding model used for caption / visual-context similarity.
# Alternatives tried earlier: SentenceTransformer('stsb-distilbert-base') and
# SentenceBertJapanese("sonoisa/sentence-bert-base-ja-mean-tokens-v2").
_SBERT_MODEL_NAME = "colorfulscoop/sbert-base-ja"
model_sbert = SentenceTransformer(_SBERT_MODEL_NAME)
62
+
63
+
64
+ #batch_size = 1
65
+ #scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
66
+
67
+ #import torch
68
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
69
+ import numpy as np
70
+ import re
71
+
72
def Sort_Tuple(tup):
    """Sort ``tup`` in place by each item's second element; return it reversed.

    The input list itself is left in ascending order (the sort is in place);
    the returned list is a new list holding the same items in descending
    order.
    """
    def second_field(item):
        return item[1]

    tup.sort(key=second_field)
    return list(reversed(tup))
77
+
78
+
79
def softmax(x):
    """Return the softmax of ``x``, normalized over all of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing to ``inf``
    (and the quotient from becoming ``nan``) for large inputs.

    Args:
        x: Array-like of scores.

    Returns:
        NumPy array of the same shape as ``x`` whose elements sum to 1. An
        empty input yields an empty array.
    """
    values = np.asarray(x)
    if values.size == 0:
        # np.max would raise on an empty array; there is nothing to normalize.
        return np.exp(values)
    exps = np.exp(values - np.max(values))
    return exps / np.sum(exps)
82
+
83
# Load the pre-trained causal language model and its tokenizer.
# Alternatives tried earlier: GPT2LMHeadModel / GPT2Tokenizer with 'distilgpt2',
# 'colorfulscoop/gpt2-small-ja' (with T5Tokenizer), and
# gr.Interface.load('huggingface/distilgpt2').
_LM_MODEL_NAME = "rinna/japanese-gpt-1b"
tokenizer = T5Tokenizer.from_pretrained(_LM_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(_LM_MODEL_NAME)
99
+
100
+
101
def cloze_prob(text):
    """Return the language-model probability of the last word of ``text`` given its stem.

    The stem is everything before the final whitespace-delimited word. The
    critical word's tokens are taken as whatever the encoding of the whole
    text has beyond the length of the stem's encoding, so they may not
    correspond exactly to the word itself.

    NOTE(review): for text without whitespace the stem is the empty string,
    so every token past the length of the empty string's encoding is scored.

    Args:
        text: Sentence to score.

    Returns:
        The product of the conditional probabilities of the critical-word
        tokens (1.0 when there are no such tokens).
    """
    whole_text_encoding = tokenizer.encode(text)
    stem = ' '.join(text.split()[:-1])
    stem_encoding = tokenizer.encode(stem)
    cw_encoding = whole_text_encoding[len(stem_encoding):]

    # Run the entire sentence through the model once, then look back at what
    # it predicted at each position leading into a critical-word token.
    tokens_tensor = torch.tensor([whole_text_encoding])
    with torch.no_grad():
        predictions = model(tokens_tensor)[0]

    first_position = -1 - len(cw_encoding)
    conditional_logprobs = []
    for offset, token_id in enumerate(cw_encoding):
        # The logits at position p score the token at position p + 1.
        position_logits = predictions[-1][first_position + offset]
        # log_softmax over the vocabulary is numerically stable and avoids
        # copying every logit into Python one element at a time.
        log_dist = torch.log_softmax(position_logits.double(), dim=-1)
        conditional_logprobs.append(log_dist[token_id].item())

    # Summing log-probabilities and exponentiating gives the product of the
    # per-token conditional probabilities.
    return np.exp(np.sum(conditional_logprobs))
140
+
141
+
142
+
143
+
144
+
145
def cos_sim(a, b):
    """Return the inner product of ``a`` and ``b`` divided by the product of their norms."""
    dot = np.inner(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot / norm_product
147
+
148
def get_sim(x):
    """Return ``str(x)`` with its first two and last two characters removed.

    For a value whose string form is wrapped in two pairs of brackets, such
    as ``[[0.5]]``, this leaves just the inner text.
    """
    text = str(x)
    return text[2:-2]
152
+
153
+
154
+
155
+
156
+
157
+ #def Visual_re_ranker(caption, visual_context_label, visual_context_prob):
158
def _belief_revision_score(lm_prob, similarity, context_prob):
    """Compute ``lm_prob ** (((1 - sim) / (1 + sim)) ** (1 - context_prob))``.

    The similarity is clamped to [-1, 1] first. A value slightly above 1
    would make the ratio negative, so the fractional power would return a
    complex number. A value of exactly -1 would divide by zero; in that case
    the ratio is treated as infinite.
    """
    similarity = min(1.0, max(-1.0, similarity))
    denominator = 1.0 + similarity
    ratio = (1.0 - similarity) / denominator if denominator else float("inf")
    return pow(lm_prob, pow(ratio, 1.0 - context_prob))


def Visual_re_ranker(caption_man, caption_woman, visual_context_label, visual_context_prob):
    """Score the two captions against the visual context.

    For each caption, its cloze probability from ``cloze_prob`` is revised by
    the cosine similarity between the caption embedding and the embedding of
    ``visual_context_label``, together with ``visual_context_prob``.

    Args:
        caption_man: Caption scored under the key "彼".
        caption_woman: Caption scored under the key "彼女".
        visual_context_label: Text label of the visual context.
        visual_context_prob: Value convertible to ``float``
            (presumably a probability in [0, 1]; the range is not
            validated -- confirm with the caller).

    Returns:
        Dict mapping "彼" and "彼女" to the revised score of each caption.
    """
    context_prob = float(visual_context_prob)
    label_emb = model_sbert.encode(visual_context_label, convert_to_tensor=True)

    scores = {}
    for key, caption in (("彼", caption_man), ("彼女", caption_woman)):
        caption_emb = model_sbert.encode(caption, convert_to_tensor=True)
        # Read the single similarity value directly rather than parsing it
        # back out of the array's printed form, which drops precision.
        similarity = util.pytorch_cos_sim(caption_emb, label_emb).item()
        lm_prob = float(cloze_prob(caption))
        scores[key] = float(_belief_revision_score(lm_prob, similarity, context_prob))
    return scores
186
+
187
+
188
+
189
# Gradio UI. The four text boxes map, in order, to the parameters of
# Visual_re_ranker. The two default captions differ only in the pronoun
# (彼 / 彼女); both end with "。" so the punctuation does not skew the comparison.
# English example inputs: "a man is blow drying his hair in the bathroom",
# "a woman is blow drying her hair in the bathroom", "hair spray", "0.7385".
demo = gr.Interface(
    fn=Visual_re_ranker,
    description="Demo for Women Wearing Lipstick: Measuring the Bias Between Object and Its Related Gender",
    inputs=[
        gr.Textbox(value="ハイデルベルク大学は彼の出身大学である。"),
        gr.Textbox(value="ハイデルベルク大学は彼女の出身大学である。"),
        gr.Textbox(value="大学"),
        gr.Textbox(value="0.7458009"),
    ],
    outputs="label",
)
demo.launch()