haritzpuerto committed on
Commit
9f26017
1 Parent(s): 70a4de4

Create inference.py

Files changed (1)
  1. inference.py +403 -0
inference.py ADDED
@@ -0,0 +1,403 @@
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
from torch import nn

from transformers import BertPreTrainedModel, AutoTokenizer
from transformers.modeling_outputs import TokenClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertPooler, BertEncoder


@dataclass
class PredictionRequest:
    input_question: str
    input_predictions: List[Tuple[str, float]]  # one (answer, score) pair per QA agent


class MetaQA:
    def __init__(self, path_to_model):
        self.metaqa_model = MetaQA_Model.from_pretrained(path_to_model)
        self.tokenizer = AutoTokenizer.from_pretrained(path_to_model)

    def run_metaqa(self, request: PredictionRequest):
        '''
        Runs MetaQA on a single instance.
        '''
        # Encode instance
        encoded_instance = self._encode_metaQA_instance(request)
        if encoded_instance is None:
            raise ValueError("The encoded instance exceeds the maximum input length of the model.")
        input_ids, token_ids, attention_masks, ans_sc = encoded_instance
        # Run model (keyword arguments so each tensor reaches the right parameter)
        logits = self.metaqa_model(
            input_ids=input_ids,
            attention_mask=attention_masks,
            token_type_ids=token_ids,
            ans_sc=ans_sc,
        ).logits
        # Get predictions
        (pred, agent_name, metaqa_score, agent_score) = self._get_predictions(logits.detach().numpy(), request.input_predictions)
        return (pred, agent_name, metaqa_score, agent_score)

    def _encode_metaQA_instance(self, request: PredictionRequest, max_len=512):
        '''
        Creates input ids, token ids, attention masks, and answer scores for an instance of MetaQA.
        '''
        # Create input ids, token ids, and masks
        list_input_ids = []
        list_token_ids = []
        list_attention_masks = []
        list_ans_sc = []

        # Process question
        ## input ids
        list_input_ids.extend(self.tokenizer.encode("[CLS]", add_special_tokens=False))  # [CLS]
        list_input_ids.extend(self.tokenizer.encode(request.input_question, add_special_tokens=False))  # Query token ids
        list_input_ids.extend(self.tokenizer.encode("[SEP]", add_special_tokens=False))  # [SEP]
        ## token ids
        list_token_ids.extend(len(list_input_ids) * [0])
        ## ans_sc ids
        list_ans_sc.extend(len(list_input_ids) * [0])

        # Process qa_agents predictions
        for qa_agent_pred in request.input_predictions:
            ## input ids
            list_input_ids.append(1)  # [RANK]
            ans_input_ids = self.tokenizer.encode(qa_agent_pred[0], add_special_tokens=False)
            list_input_ids.extend(ans_input_ids)
            ## token ids
            list_token_ids.extend((len(ans_input_ids) + 1) * [1])  # +1 to account for [RANK]
            ## ans_sc ids
            ans_score = qa_agent_pred[1]
            list_ans_sc.extend((len(ans_input_ids) + 1) * [ans_score])  # +1 to account for [RANK]
        # Last [SEP]
        ## input ids
        list_input_ids.extend(self.tokenizer.encode("[SEP]", add_special_tokens=False))  # last [SEP]
        ## token ids
        list_token_ids.append(1)
        ## ans_sc ids
        list_ans_sc.append(0)
        ## attention masks
        list_attention_masks.extend(len(list_input_ids) * [1])

        # The instance must fit into the model input; check before padding
        if len(list_input_ids) > max_len:
            return None

        # PADDING
        len_padding = max_len - len(list_input_ids)
        ## input ids
        list_input_ids.extend([0] * len_padding)  # [PAD]
        ## token ids
        list_token_ids.extend((len(list_input_ids) - len(list_token_ids)) * [1])
        ## ans_sc ids
        list_ans_sc.extend((len(list_input_ids) - len(list_ans_sc)) * [0])
        ## attention masks
        list_attention_masks.extend((len(list_input_ids) - len(list_attention_masks)) * [0])

        list_input_ids = torch.Tensor(list_input_ids).unsqueeze(0).long()
        list_token_ids = torch.Tensor(list_token_ids).unsqueeze(0).long()
        list_attention_masks = torch.Tensor(list_attention_masks).unsqueeze(0).long()
        list_ans_sc = torch.Tensor(list_ans_sc).unsqueeze(0).float()  # answer scores are continuous, keep them as floats

        return (list_input_ids, list_token_ids, list_attention_masks, list_ans_sc)

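    # Note on the layout produced by _encode_metaQA_instance above (illustrative sketch,
    # inferred from the code rather than from separate documentation):
    #
    #   [CLS] question [SEP] [RANK] answer_1 [RANK] answer_2 ... [RANK] answer_N [SEP] [PAD] ...
    #
    # Token id 1 serves as the [RANK] marker in front of each agent's answer, token type ids
    # are 0 over the question segment and 1 afterwards, and list_ans_sc repeats each agent's
    # confidence score over its [RANK] + answer tokens so that ans_sc_proj can embed it.
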
    def _get_predictions(self, logits, input_predictions):
        top_k = lambda a, k: np.argsort(-a)[:k]
        for idx in top_k(logits[0][:, 1], self.metaqa_model.num_agents):
            pred = input_predictions[idx][0]
            if pred != '':
                agent_name = self.metaqa_model.config.agents[idx]
                metaqa_score = logits[0][idx][1]
                agent_score = input_predictions[idx][1]
                return (pred, agent_name, metaqa_score, agent_score)
        # no valid prediction found, return the best prediction
        idx = top_k(logits[0][:, 1], 1)[0]
        pred = input_predictions[idx][0]
        metaqa_score = logits[0][idx][1]
        agent_name = self.metaqa_model.config.agents[idx]
        agent_score = input_predictions[idx][1]
        return (pred, agent_name, metaqa_score, agent_score)


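# MetaQA_Model stacks the meta-QA head on top of the BERT encoder below: one linear
# "MoSeN" head per QA agent scores the pooled [CLS] representation, and the answer-selection
# MLP (ans_sel) scores each agent's [RANK] token embedding concatenated with its MoSeN logit,
# yielding logits of shape (batch_size, num_agents, 2).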
class MetaQA_Model(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = MetaQABertModel(config)
        self.num_agents = config.num_agents

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.list_MoSeN = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(self.num_agents)])
        self.input_size_ans_sel = 1 + config.hidden_size
        interm_size = int(config.hidden_size / 2)
        self.ans_sel = nn.Sequential(nn.Linear(self.input_size_ans_sel, interm_size),
                                     nn.ReLU(),
                                     nn.Dropout(config.hidden_dropout_prob),
                                     nn.Linear(interm_size, 2))

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        ans_sc=None,
        agent_sc=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            ans_sc=ans_sc,
            agent_sc=agent_sc,
        )
        # domain classification
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        list_domains_logits = []
        for MoSeN in self.list_MoSeN:
            domain_logits = MoSeN(pooled_output)
            list_domains_logits.append(domain_logits)
        domain_logits = torch.stack(list_domains_logits)
        # shape = (num_agents, batch_size, 1)
        # we have to transpose the shape to (batch_size, num_agents, 1)
        domain_logits = domain_logits.transpose(0, 1)

        # ans classifier
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
        # select the [RANK] token embeddings
        idx_rank = (input_ids == 1).nonzero()  # (batch_size x num_agents, 2)
        idx_rank = idx_rank[:, 1].view(-1, self.num_agents)
        list_emb = []
        for i in range(idx_rank.shape[0]):
            rank_emb = sequence_output[i][idx_rank[i], :]
            # rank_emb shape = (num_agents, hidden_size)
            list_emb.append(rank_emb)

        rank_emb = torch.stack(list_emb)

        rank_emb = self.dropout(rank_emb)
        rank_emb = torch.cat((rank_emb, domain_logits), dim=2)
        # rank_emb shape = (batch_size, num_agents, hidden_size+1)
        logits = self.ans_sel(rank_emb)  # (batch_size, num_agents, 2)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class MetaQABertModel(BertPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = MetaQABertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        ans_sc=None,
        agent_sc=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            ans_sc=ans_sc,
            agent_sc=agent_sc,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class MetaQABertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.ans_sc_proj = nn.Linear(1, config.hidden_size)
        self.agent_sc_proj = nn.Linear(1, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
            persistent=False,
        )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0,
        ans_sc=None, agent_sc=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        if ans_sc is not None:
            ans_sc_emb = self.ans_sc_proj(ans_sc.unsqueeze(2))
            embeddings += ans_sc_emb
        if agent_sc is not None:
            agent_sc_emb = self.agent_sc_proj(agent_sc.unsqueeze(2))
            embeddings += agent_sc_emb

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
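

# Minimal usage sketch (illustrative only): the checkpoint path, question, and agent
# predictions below are hypothetical placeholders, and the checkpoint's config is assumed
# to define `num_agents` and `agents` as used by the classes above.
if __name__ == "__main__":
    metaqa = MetaQA("path/to/metaqa_checkpoint")  # hypothetical local path or Hub model id
    request = PredictionRequest(
        input_question="Who wrote the novel Dune?",
        # one (answer, confidence) pair per QA agent, in the agent order the model was trained with
        input_predictions=[("Frank Herbert", 0.91), ("Brian Herbert", 0.40), ("", 0.05)],
    )
    pred, agent_name, metaqa_score, agent_score = metaqa.run_metaqa(request)
    print(f"Answer: {pred} (agent: {agent_name}, MetaQA score: {metaqa_score:.3f}, agent score: {agent_score:.3f})")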