meghanaraok commited on
Commit
be58120
·
verified ·
1 Parent(s): efcc711

Upload 3 files

Browse files
LongLAT/models/__init__.py ADDED
File without changes
LongLAT/models/modeling.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+
4
+ import torch
5
+ from torch.nn import BCEWithLogitsLoss, Dropout, Linear
6
+ from transformers import AutoModel, XLNetModel, LongformerModel, LongformerConfig
7
+ from transformers.models.longformer.modeling_longformer import LongformerEncoder, LongformerClassificationHead, LongformerLayer
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ from LongLAT.hilat.models.utils import initial_code_title_vectors
11
+
12
+ logger = logging.getLogger("lwat")
13
+
14
+
15
+ class CodingModelConfig:
16
+ def __init__(self,
17
+ transformer_model_name_or_path,
18
+ transformer_tokenizer_name,
19
+ transformer_layer_update_strategy,
20
+ num_chunks,
21
+ max_seq_length,
22
+ dropout,
23
+ dropout_att,
24
+ d_model,
25
+ label_dictionary,
26
+ num_labels,
27
+ use_code_representation,
28
+ code_max_seq_length,
29
+ code_batch_size,
30
+ multi_head_att,
31
+ chunk_att,
32
+ linear_init_mean,
33
+ linear_init_std,
34
+ document_pooling_strategy,
35
+ multi_head_chunk_attention,
36
+ num_hidden_layers):
37
+ super(CodingModelConfig, self).__init__()
38
+ self.transformer_model_name_or_path = transformer_model_name_or_path
39
+ self.transformer_tokenizer_name = transformer_tokenizer_name
40
+ self.transformer_layer_update_strategy = transformer_layer_update_strategy
41
+ self.num_chunks = num_chunks
42
+ self.max_seq_length = max_seq_length
43
+ self.dropout = dropout
44
+ self.dropout_att = dropout_att
45
+ self.d_model = d_model
46
+ # labels_dictionary is a dataframe with columns: icd9_code, long_title
47
+ self.label_dictionary = label_dictionary
48
+ self.num_labels = num_labels
49
+ self.use_code_representation = use_code_representation
50
+ self.code_max_seq_length = code_max_seq_length
51
+ self.code_batch_size = code_batch_size
52
+ self.multi_head_att = multi_head_att
53
+ self.chunk_att = chunk_att
54
+ self.linear_init_mean = linear_init_mean
55
+ self.linear_init_std = linear_init_std
56
+ self.document_pooling_strategy = document_pooling_strategy
57
+ self.multi_head_chunk_attention = multi_head_chunk_attention
58
+ self.num_hidden_layers = num_hidden_layers
59
+
60
+
61
+ class LableWiseAttentionLayer(torch.nn.Module):
62
+ def __init__(self, coding_model_config, args):
63
+ super(LableWiseAttentionLayer, self).__init__()
64
+
65
+ self.config = coding_model_config
66
+ self.args = args
67
+
68
+ # layers
69
+ self.l1_linear = torch.nn.Linear(self.config.d_model,
70
+ self.config.d_model, bias=False)
71
+ self.tanh = torch.nn.Tanh()
72
+ self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
73
+ self.softmax = torch.nn.Softmax(dim=1)
74
+
75
+ # Mean pooling last hidden state of code title from transformer model as the initial code vectors
76
+ self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
77
+
78
+ def _init_linear_weights(self, mean, std):
79
+ # normalize the l1 weights
80
+ torch.nn.init.normal_(self.l1_linear.weight, mean, std)
81
+ if self.l1_linear.bias is not None:
82
+ self.l1_linear.bias.data.fill_(0)
83
+ # initialize the l2
84
+ if self.config.use_code_representation:
85
+ code_vectors = initial_code_title_vectors(self.config.label_dictionary,
86
+ self.config.transformer_model_name_or_path,
87
+ self.config.transformer_tokenizer_name
88
+ if self.config.transformer_tokenizer_name
89
+ else self.config.transformer_model_name_or_path,
90
+ self.config.code_max_seq_length,
91
+ self.config.code_batch_size,
92
+ self.config.d_model,
93
+ self.args.device)
94
+
95
+ self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
96
+ torch.nn.init.normal_(self.l2_linear.weight, mean, std)
97
+ if self.l2_linear.bias is not None:
98
+ self.l2_linear.bias.data.fill_(0)
99
+
100
+ def forward(self, x):
101
+ # input: (batch_size, max_seq_length, transformer_hidden_size)
102
+ # output: (batch_size, max_seq_length, transformer_hidden_size)
103
+ # Z = Tan(WH)
104
+ l1_output = self.tanh(self.l1_linear(x))
105
+ # softmax(UZ)
106
+ # l2_linear output shape: (batch_size, max_seq_length, num_labels)
107
+ # attention_weight shape: (batch_size, num_labels, max_seq_length)
108
+ attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
109
+ # attention_output shpae: (batch_size, num_labels, transformer_hidden_size)
110
+ attention_output = torch.matmul(attention_weight, x)
111
+
112
+ return attention_output, attention_weight
113
+
114
+ class ChunkAttentionLayer(torch.nn.Module):
115
+ def __init__(self, coding_model_config, args):
116
+ super(ChunkAttentionLayer, self).__init__()
117
+
118
+ self.config = coding_model_config
119
+ self.args = args
120
+
121
+ # layers
122
+ self.l1_linear = torch.nn.Linear(self.config.d_model,
123
+ self.config.d_model, bias=False)
124
+ self.tanh = torch.nn.Tanh()
125
+ self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
126
+ self.softmax = torch.nn.Softmax(dim=1)
127
+
128
+ self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
129
+
130
+ def _init_linear_weights(self, mean, std):
131
+ # initialize the l1
132
+ torch.nn.init.normal_(self.l1_linear.weight, mean, std)
133
+ if self.l1_linear.bias is not None:
134
+ self.l1_linear.bias.data.fill_(0)
135
+ # initialize the l2
136
+ torch.nn.init.normal_(self.l2_linear.weight, mean, std)
137
+ if self.l2_linear.bias is not None:
138
+ self.l2_linear.bias.data.fill_(0)
139
+
140
+ def forward(self, x):
141
+ # input: (batch_size, num_chunks, transformer_hidden_size)
142
+ # output: (batch_size, num_chunks, transformer_hidden_size)
143
+ # Z = Tan(WH)
144
+ l1_output = self.tanh(self.l1_linear(x))
145
+ # softmax(UZ)
146
+ # l2_linear output shape: (batch_size, num_chunks, 1)
147
+ # attention_weight shape: (batch_size, 1, num_chunks)
148
+ attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
149
+ # attention_output shpae: (batch_size, 1, transformer_hidden_size)
150
+ attention_output = torch.matmul(attention_weight, x)
151
+ return attention_output, attention_weight
152
+
153
+
154
+ class CodingModel(torch.nn.Module, PyTorchModelHubMixin):
155
+ def __init__(self, coding_model_config, args, **kwargs):
156
+ super(CodingModel, self).__init__()
157
+ self.coding_model_config = coding_model_config
158
+ # layers
159
+ self.transformer_layer = AutoModel.from_pretrained('yikuan8/Clinical-Longformer')
160
+ if isinstance(self.transformer_layer, XLNetModel):
161
+ self.transformer_layer.config.use_mems_eval = False
162
+ # if torch.cuda.is_available():
163
+ # self.transformer_layer = self.transformer_layer.to(torch.device("cuda:0"))
164
+ # self.transformer_layer.to(torch.device("cuda:0"))
165
+ self.dropout = Dropout(p=self.coding_model_config.dropout)
166
+ # self.longformer_transformer = AutoModel.from_pretrained("yikuan8/Clinical-Longformer")
167
+
168
+ if self.coding_model_config.multi_head_att:
169
+ # initial multi head attention according to the num_chunks
170
+ self.label_wise_attention_layer = torch.nn.ModuleList(
171
+ [LableWiseAttentionLayer(coding_model_config, args)
172
+ for _ in range(self.coding_model_config.num_chunks)])
173
+ else:
174
+ self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
175
+ self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
176
+
177
+ # initial chunk attention
178
+ if self.coding_model_config.chunk_att:
179
+ if self.coding_model_config.multi_head_chunk_attention:
180
+ self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
181
+ for _ in range(self.coding_model_config.num_labels)])
182
+ else:
183
+ self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
184
+
185
+ self.classifier_layer = Linear(self.coding_model_config.d_model,
186
+ self.coding_model_config.num_labels)
187
+ else:
188
+ if self.coding_model_config.document_pooling_strategy == "flat":
189
+ self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
190
+ self.coding_model_config.num_labels)
191
+ else: # max or mean pooling
192
+ self.classifier_layer = Linear(self.coding_model_config.d_model,
193
+ self.coding_model_config.num_labels)
194
+ self.sigmoid = torch.nn.Sigmoid()
195
+
196
+ if self.coding_model_config.transformer_layer_update_strategy == "no":
197
+ self.freeze_all_transformer_layers()
198
+ elif self.coding_model_config.transformer_layer_update_strategy == "last":
199
+ self.freeze_all_transformer_layers()
200
+ self.unfreeze_transformer_last_layers()
201
+
202
+ # initialize the weights of classifier
203
+ self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
204
+
205
+ def _init_linear_weights(self, mean, std):
206
+ torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
207
+
208
+ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
209
+ # input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
210
+ # labels shape: (batch_size, num_labels)
211
+ transformer_output = []
212
+
213
+ # pass chunk by chunk into transformer layer in the batches.
214
+ # input (batch_size, sequence_length)
215
+ # for i in range(self.coding_model_config.num_chunks):
216
+ # l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
217
+ # attention_mask=attention_mask[:, i, :],
218
+ # token_type_ids=token_type_ids[:, i, :])
219
+ # # output hidden state shape: (batch_size, sequence_length, hidden_size)
220
+ # transformer_output.append(l1_output[0])
221
+
222
+ input_ids = input_ids.reshape(input_ids.shape[0], input_ids.shape[1]*input_ids.shape[2])
223
+ global_attention_mask = torch.zeros_like(input_ids)
224
+ global_attention_positions = [1, 510, 1022, 1534, 2046, 2558, 3070, 3582, 4094]
225
+ global_attention_mask[:, global_attention_positions] = 1
226
+ attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
227
+ token_type_ids = token_type_ids.reshape(token_type_ids.shape[0], token_type_ids.shape[1]*token_type_ids.shape[2])
228
+ l1_output = self.transformer_layer(input_ids=input_ids, attention_mask=attention_mask, global_attention_mask= global_attention_mask, token_type_ids = token_type_ids)
229
+
230
+ transformer_output.append(l1_output[0])
231
+ # transpose back chunk and batch size dimensions
232
+ transformer_output = torch.stack(transformer_output)
233
+ transformer_output = transformer_output.transpose(0, 1)
234
+ # dropout transformer output
235
+ l2_dropout = self.dropout(transformer_output)
236
+
237
+ # config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
238
+ # config.num_labels =5
239
+ # config.num_hidden_layers = 1
240
+ # longformer_layer = LongformerLayer(config)
241
+ # # longformer_layer = longformer_layer(config)
242
+ # # longformer_layer = longformer_layer.to(torch.device("cuda:0"))
243
+ # l2_dropout= l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1]*l2_dropout.shape[2], l2_dropout.shape[3])
244
+ # attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
245
+ # is_index_masked = attention_mask < 0
246
+ # is_index_global_attn = attention_mask > 0
247
+ # is_global_attn = is_index_global_attn.flatten().any().item()
248
+ # output = longformer_layer(l2_dropout, attention_mask=attention_mask,output_attentions=True, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn)
249
+ # l2_dropout = self.dropout_att(output[0]) #l2_dropout - torch.Size([4, 4096, 768])
250
+ # l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
251
+ # #l2_dropout - torch.Size([4, 8, 512, 768])
252
+
253
+
254
+ # Label-wise attention layers
255
+ # output: (batch_size, num_chunks, num_labels, hidden_size)
256
+ attention_output = []
257
+ attention_weights = []
258
+
259
+ for i in range(self.coding_model_config.num_chunks):
260
+ # input: (batch_size, max_seq_length, transformer_hidden_size)
261
+ if self.coding_model_config.multi_head_att:
262
+ attention_layer = self.label_wise_attention_layer[i]
263
+ else:
264
+ attention_layer = self.label_wise_attention_layer
265
+ l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
266
+ # l3_attention shape: (batch_size, num_labels, hidden_size) torch.Size([4, 5, 768])
267
+ # attention_weight: (batch_size, num_labels, max_seq_length) torch.Size([4, 5, 512])
268
+ attention_output.append(l3_attention)
269
+ attention_weights.append(attention_weight)
270
+
271
+ attention_output = torch.stack(attention_output)
272
+ attention_output = attention_output.transpose(0, 1) #torch.Size([4, 8, 5, 768])
273
+ attention_weights = torch.stack(attention_weights)
274
+ attention_weights = attention_weights.transpose(0, 1) #torch.Size([4, 8, 5, 512])
275
+
276
+ l3_dropout = self.dropout_att(attention_output) #torch.Size([4, 8, 5, 768])
277
+
278
+ if self.coding_model_config.chunk_att: #set to false
279
+ # Chunk attention layers
280
+ # output: (batch_size, num_labels, hidden_size)
281
+ chunk_attention_output = []
282
+ chunk_attention_weights = []
283
+
284
+ for i in range(self.coding_model_config.num_labels):
285
+ if self.coding_model_config.multi_head_chunk_attention:
286
+ chunk_attention = self.chunk_attention_layer[i]
287
+ else:
288
+ chunk_attention = self.chunk_attention_layer
289
+ l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
290
+ chunk_attention_output.append(l4_chunk_attention.squeeze())
291
+ chunk_attention_weights.append(l4_chunk_attention_weights.squeeze())
292
+
293
+ chunk_attention_output = torch.stack(chunk_attention_output) #torch.Size([5, 4, 768])
294
+ chunk_attention_output = chunk_attention_output.transpose(0, 1) #torch.Size([4, 5, 768])
295
+ chunk_attention_weights = torch.stack(chunk_attention_weights)
296
+ chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
297
+ # output shape: (batch_size, num_labels, hidden_size)
298
+ l4_dropout = self.dropout_att(chunk_attention_output) #torch.Size([4, 5, 768])
299
+ else:
300
+ # output shape: (batch_size, num_labels, hidden_size*num_chunks)
301
+ l4_dropout = l3_dropout.transpose(1, 2)
302
+ if self.coding_model_config.document_pooling_strategy == "flat":
303
+ # Flatten layer. concatenate representation by labels
304
+ l4_dropout = torch.flatten(l4_dropout, start_dim=2)
305
+ elif self.coding_model_config.document_pooling_strategy == "max":
306
+ l4_dropout = torch.amax(l4_dropout, 2)
307
+ elif self.coding_model_config.document_pooling_strategy == "mean":
308
+ l4_dropout = torch.mean(l4_dropout, 2)
309
+ else:
310
+ raise ValueError("Not supported pooling strategy")
311
+
312
+ # classifier layer
313
+ # each code has a binary linear formula
314
+ logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
315
+ #torch.Size([4, 5])
316
+ loss_fct = BCEWithLogitsLoss()
317
+ loss = loss_fct(logits, targets)
318
+
319
+ return {
320
+ "loss": loss,
321
+ "logits": logits,
322
+ "label_attention_weights": attention_weights,
323
+ "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
324
+ }
325
+
326
+ def freeze_all_transformer_layers(self):
327
+ """
328
+ Freeze all layer weight parameters. They will not be updated during training.
329
+ """
330
+ for param in self.transformer_layer.parameters():
331
+ param.requires_grad = False
332
+
333
+ def unfreeze_all_transformer_layers(self):
334
+ """
335
+ Unfreeze all layers weight parameters. They will be updated during training.
336
+ """
337
+ for param in self.transformer_layer.parameters():
338
+ param.requires_grad = True
339
+
340
+ def unfreeze_transformer_last_layers(self):
341
+ for name, param in self.transformer_layer.named_parameters():
342
+ if "layer.11" in name or "pooler" in name:
343
+ param.requires_grad = True
LongLAT/models/utils.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import linecache
3
+ import pickle
4
+ import random
5
+ import subprocess
6
+
7
+ import numpy as np
8
+ import redis
9
+ import torch
10
+ import logging
11
+ import ast
12
+
13
+ from datasets import Dataset
14
+ from tqdm import tqdm
15
+
16
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, roc_curve, auc
17
+ from torch.utils.data import DataLoader
18
+ from transformers import AutoModel, DataCollatorWithPadding, XLNetTokenizer, XLNetTokenizerFast, AutoTokenizer, \
19
+ XLNetModel, is_torch_tpu_available
20
+
21
+ logger = logging.getLogger("lwat")
22
+
23
+
24
+ class MimicIIIDataset(Dataset):
25
+ def __init__(self, data):
26
+ self.input_ids = data["input_ids"]
27
+ self.attention_mask = data["attention_mask"]
28
+ self.token_type_ids = data["token_type_ids"]
29
+ self.labels = data["targets"]
30
+
31
+ def __len__(self):
32
+ return len(self.input_ids)
33
+
34
+ def __getitem__(self, item):
35
+ return {
36
+ "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
37
+ "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
38
+ "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long),
39
+ "targets": torch.tensor(self.labels[item], dtype=torch.float)
40
+ }
41
+
42
+ class LazyMimicIIIDataset(Dataset):
43
+ def __init__(self, filename, task, dataset_type):
44
+ print("lazy load from {}".format(filename))
45
+ self.filename = filename
46
+ self.redis = redis.Redis(unix_socket_path="/tmp/redis.sock")
47
+ self.pipe = self.redis.pipeline()
48
+ self.num_examples = 0
49
+ self.task = task
50
+ self.dataset_type = dataset_type
51
+ with open(filename, 'r') as f:
52
+ for line_num, line in enumerate(f.readlines()):
53
+ self.num_examples += 1
54
+ example = eval(line)
55
+ key = task + '_' + dataset_type + '_' + str(line_num)
56
+ input_ids = eval(example[0])
57
+ attention_mask = eval(example[1])
58
+ token_type_ids = eval(example[2])
59
+ labels = eval(example[3])
60
+ example_tuple = (input_ids, attention_mask, token_type_ids, labels)
61
+
62
+ self.pipe.set(key, pickle.dumps(example_tuple))
63
+ if line_num % 100 == 0:
64
+ self.pipe.execute()
65
+ self.pipe.execute()
66
+ if is_torch_tpu_available():
67
+ import torch_xla.core.xla_model as xm
68
+ xm.rendezvous(tag="featuresGenerated")
69
+
70
+ def __len__(self):
71
+ return self.num_examples
72
+
73
+ def __getitem__(self, item):
74
+ key = self.task + '_' + self.dataset_type + '_' + str(item)
75
+ example = pickle.loads(self.redis.get(key))
76
+
77
+ return {
78
+ "input_ids": torch.tensor(example[0], dtype=torch.long),
79
+ "attention_mask": torch.tensor(example[1], dtype=torch.float),
80
+ "token_type_ids": torch.tensor(example[2], dtype=torch.long),
81
+ "targets": torch.tensor(example[3], dtype=torch.float)
82
+ }
83
+
84
+
85
+ class ICDCodeDataset(Dataset):
86
+ def __init__(self, data):
87
+ self.input_ids = data["input_ids"]
88
+ self.attention_mask = data["attention_mask"]
89
+ self.token_type_ids = data["token_type_ids"]
90
+
91
+ def __len__(self):
92
+ return len(self.input_ids)
93
+
94
+ def __getitem__(self, item):
95
+ return {
96
+ "input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
97
+ "attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
98
+ "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long)
99
+ }
100
+
101
+
102
+ def set_random_seed(random_seed):
103
+ random.seed(random_seed)
104
+ np.random.seed(random_seed)
105
+ torch.manual_seed(random_seed)
106
+ torch.cuda.manual_seed_all(random_seed)
107
+ torch.backends.cudnn.deterministic = True
108
+ torch.backends.cudnn.benchmark = False
109
+
110
+ def tokenize_inputs(text_list, tokenizer, max_seq_len=512):
111
+ """
112
+ Tokenizes the input text input into ids. Appends the appropriate special
113
+ characters to the end of the text to denote end of sentence. Truncate or pad
114
+ the appropriate sequence length.
115
+ """
116
+ # tokenize the text, then truncate sequence to the desired length minus 2 for
117
+ # the 2 special characters
118
+ tokenized_texts = list(map(lambda t: tokenizer.tokenize(t)[:max_seq_len - 2], text_list))
119
+ # convert tokenized text into numeric ids for the appropriate LM
120
+ input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
121
+ # get token type for token_ids_0
122
+ token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
123
+ # append special token to end of sentence: <sep> <cls>
124
+ input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
125
+ # attention mask
126
+ attention_mask = [[1] * len(x) for x in input_ids]
127
+
128
+ # padding to max_length
129
+ def padding_to_max(sequence, value):
130
+ padding_len = max_seq_len - len(sequence)
131
+ padding = [value] * padding_len
132
+ return sequence + padding
133
+
134
+ input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
135
+ attention_mask = [padding_to_max(x, 0) for x in attention_mask]
136
+ token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
137
+
138
+ return input_ids, attention_mask, token_type_ids
139
+
140
+
141
+ def tokenize_dataset(tokenizer, text, labels, max_seq_len):
142
+ if (isinstance(tokenizer, XLNetTokenizer) or isinstance(tokenizer, XLNetTokenizerFast)):
143
+ data = list(map(lambda t: tokenize_inputs(t, tokenizer, max_seq_len=max_seq_len), text))
144
+ input_ids, attention_mask, token_type_ids = zip(*data)
145
+ else:
146
+ tokenizer.model_max_length = max_seq_len
147
+ input_dict = tokenizer(text, padding=True, truncation=True)
148
+ input_ids = input_dict["input_ids"]
149
+ attention_mask = input_dict["attention_mask"]
150
+ token_type_ids = input_dict["token_type_ids"]
151
+
152
+ return {
153
+ "input_ids": input_ids,
154
+ "attention_mask": attention_mask,
155
+ "token_type_ids": token_type_ids,
156
+ "targets": labels
157
+ }
158
+
159
+
160
+ def initial_code_title_vectors(label_dict, transformer_model_name, tokenizer_name, code_max_seq_length, code_batch_size,
161
+ d_model, device):
162
+ logger.info("Generate code title representations from base transformer model")
163
+ model = AutoModel.from_pretrained(transformer_model_name)
164
+ if isinstance(model, XLNetModel):
165
+ model.config.use_mems_eval = False
166
+ #
167
+ # model.config.use_mems_eval = False
168
+ # model.config.reuse_len = 0
169
+ code_titles = label_dict["long_title"].fillna("").tolist()
170
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, padding_side="right")
171
+ data = tokenizer(code_titles, padding=True, truncation=True)
172
+ code_dataset = ICDCodeDataset(data)
173
+
174
+ model.to(device)
175
+
176
+ data_collator = DataCollatorWithPadding(tokenizer, padding="max_length",
177
+ max_length=code_max_seq_length)
178
+ code_param = {"batch_size": code_batch_size, "collate_fn": data_collator}
179
+ code_dataloader = DataLoader(code_dataset, **code_param)
180
+
181
+ code_dataloader_progress_bar = tqdm(code_dataloader, unit="batches",
182
+ desc="Code title representations")
183
+ code_dataloader_progress_bar.clear()
184
+
185
+ # output shape: (num_labels, hidden_size)
186
+ initial_code_vectors = torch.zeros(len(code_dataset), d_model)
187
+
188
+ for i, data in enumerate(code_dataloader_progress_bar):
189
+ input_ids = data["input_ids"].to(device, dtype=torch.long)
190
+ attention_mask = data["attention_mask"].to(device, dtype=torch.float)
191
+ token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
192
+
193
+ # output shape: (batch_size, sequence_length, hidden_size)
194
+ output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
195
+ # Mean pooling. output shape: (batch_size, hidden_size)
196
+ mean_last_hidden_state = torch.mean(output[0], 1)
197
+ # Max pooling. output shape: (batch_size, hidden_size)
198
+ # max_last_hidden_state = torch.max((output[0] * attention_mask.unsqueeze(-1)), 1)[0]
199
+
200
+ initial_code_vectors[i * input_ids.shape[0]:(i + 1) * input_ids.shape[0], :] = mean_last_hidden_state
201
+
202
+ code_dataloader_progress_bar.refresh(True)
203
+ code_dataloader_progress_bar.clear(True)
204
+ code_dataloader_progress_bar.close()
205
+ logger.info("Code representations ready for use. Shape {}".format(initial_code_vectors.shape))
206
+ return initial_code_vectors
207
+
208
+
209
+ def normalise_labels(labels, n_label):
210
+ norm_labels = []
211
+ for label in labels:
212
+ one_hot_vector_label = [0] * n_label
213
+ one_hot_vector_label[label] = 1
214
+ norm_labels.append(one_hot_vector_label)
215
+ return np.asarray(norm_labels)
216
+
217
+
218
+ def segment_tokenize_inputs(text, tokenizer, max_seq_len, num_chunks):
219
+ # input is full text of one document
220
+ tokenized_texts = []
221
+ tokens = tokenizer.tokenize(text)
222
+ start_idx = 0
223
+ seq_len = max_seq_len - 2
224
+ for i in range(num_chunks):
225
+ if start_idx > len(tokens):
226
+ tokenized_texts.append([])
227
+ continue
228
+ tokenized_texts.append(tokens[start_idx:(start_idx + seq_len)])
229
+ start_idx += seq_len
230
+
231
+ # convert tokenized text into numeric ids for the appropriate LM
232
+ input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
233
+ # get token type for token_ids_0
234
+ token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
235
+ # append special token to end of sentence: <sep> <cls>
236
+ input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
237
+ # attention mask
238
+ attention_mask = [[1] * len(x) for x in input_ids]
239
+
240
+ # padding to max_length
241
+ def padding_to_max(sequence, value):
242
+ padding_len = max_seq_len - len(sequence)
243
+ padding = [value] * padding_len
244
+ return sequence + padding
245
+
246
+ input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
247
+ attention_mask = [padding_to_max(x, 0) for x in attention_mask]
248
+ token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
249
+
250
+ return input_ids, attention_mask, token_type_ids
251
+
252
+
253
+ def segment_tokenize_dataset(tokenizer, text, labels, max_seq_len, num_chunks):
254
+ data = list(
255
+ map(lambda t: segment_tokenize_inputs(t, tokenizer, max_seq_len, num_chunks), text))
256
+ input_ids, attention_mask, token_type_ids = zip(*data)
257
+
258
+ return {
259
+ "input_ids": input_ids,
260
+ "attention_mask": attention_mask,
261
+ "token_type_ids": token_type_ids,
262
+ "targets": labels
263
+ }
264
+
265
+
266
+ # The following functions are modified from the relevant codes of https://github.com/aehrc/LAAT
267
+ def roc_auc(true_labels, pred_probs, average="macro"):
268
+ if pred_probs.shape[0] <= 1:
269
+ return
270
+
271
+ fpr = {}
272
+ tpr = {}
273
+ if average == "macro":
274
+ # get AUC for each label individually
275
+ relevant_labels = []
276
+ auc_labels = {}
277
+ for i in range(true_labels.shape[1]):
278
+ # only if there are true positives for this label
279
+ if true_labels[:, i].sum() > 0:
280
+ fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], pred_probs[:, i])
281
+ if len(fpr[i]) > 1 and len(tpr[i]) > 1:
282
+ auc_score = auc(fpr[i], tpr[i])
283
+ if not np.isnan(auc_score):
284
+ auc_labels["auc_%d" % i] = auc_score
285
+ relevant_labels.append(i)
286
+
287
+ # macro-AUC: just average the auc scores
288
+ aucs = []
289
+ for i in relevant_labels:
290
+ aucs.append(auc_labels['auc_%d' % i])
291
+ score = np.mean(aucs)
292
+ else:
293
+ # micro-AUC: just look at each individual prediction
294
+ flat_pred = pred_probs.ravel()
295
+ fpr["micro"], tpr["micro"], _ = roc_curve(true_labels.ravel(), flat_pred)
296
+ score = auc(fpr["micro"], tpr["micro"])
297
+
298
+ return score
299
+
300
+
301
+ def union_size(x, y, axis):
302
+ return np.logical_or(x, y).sum(axis=axis).astype(float)
303
+
304
+
305
+ def intersect_size(x, y, axis):
306
+ return np.logical_and(x, y).sum(axis=axis).astype(float)
307
+
308
+
309
+ def macro_accuracy(true_labels, pred_labels):
310
+ num = intersect_size(true_labels, pred_labels, 0) / (union_size(true_labels, pred_labels, 0) + 1e-10)
311
+ return np.mean(num)
312
+
313
+
314
+ def macro_precision(true_labels, pred_labels):
315
+ num = intersect_size(true_labels, pred_labels, 0) / (pred_labels.sum(axis=0) + 1e-10)
316
+ return np.mean(num)
317
+
318
+
319
+ def macro_recall(true_labels, pred_labels):
320
+ num = intersect_size(true_labels, pred_labels, 0) / (true_labels.sum(axis=0) + 1e-10)
321
+ return np.mean(num)
322
+
323
+
324
+ def macro_f1(true_labels, pred_labels):
325
+ prec = macro_precision(true_labels, pred_labels)
326
+ rec = macro_recall(true_labels, pred_labels)
327
+ if prec + rec == 0:
328
+ f1 = 0.
329
+ else:
330
+ f1 = 2 * (prec * rec) / (prec + rec)
331
+ return prec, rec, f1
332
+
333
+
334
+ def precision_at_k(true_labels, pred_probs, ks=[1, 5, 8, 10, 15]):
335
+ # num true labels in top k predictions / k
336
+ sorted_pred = np.argsort(pred_probs)[:, ::-1]
337
+ output = []
338
+ for k in ks:
339
+ topk = sorted_pred[:, :k]
340
+
341
+ # get precision at k for each example
342
+ vals = []
343
+ for i, tk in enumerate(topk):
344
+ if len(tk) > 0:
345
+ num_true_in_top_k = true_labels[i, tk].sum()
346
+ denom = len(tk)
347
+ vals.append(num_true_in_top_k / float(denom))
348
+
349
+ output.append(np.mean(vals))
350
+ return output
351
+
352
+
353
+ def micro_recall(true_labels, pred_labels):
354
+ flat_true = true_labels.ravel()
355
+ flat_pred = pred_labels.ravel()
356
+ return intersect_size(flat_true, flat_pred, 0) / flat_true.sum(axis=0)
357
+
358
+
359
+ def micro_precision(true_labels, pred_labels):
360
+ flat_true = true_labels.ravel()
361
+ flat_pred = pred_labels.ravel()
362
+ if flat_pred.sum(axis=0) == 0:
363
+ return 0.0
364
+ return intersect_size(flat_true, flat_pred, 0) / flat_pred.sum(axis=0)
365
+
366
+
367
+ def micro_f1(true_labels, pred_labels):
368
+ prec = micro_precision(true_labels, pred_labels)
369
+ rec = micro_recall(true_labels, pred_labels)
370
+ if prec + rec == 0:
371
+ f1 = 0.
372
+ else:
373
+ f1 = 2 * (prec * rec) / (prec + rec)
374
+ return prec, rec, f1
375
+
376
+
377
+ def micro_accuracy(true_labels, pred_labels):
378
+ flat_true = true_labels.ravel()
379
+ flat_pred = pred_labels.ravel()
380
+ return intersect_size(flat_true, flat_pred, 0) / union_size(flat_true, flat_pred, 0)
381
+
382
+
383
+ def calculate_scores(true_labels, logits, average="macro", is_multilabel=True, threshold=0.5):
384
+ def sigmoid(x):
385
+ return 1 / (1 + np.exp(-x))
386
+
387
+ pred_probs = sigmoid(logits)
388
+ pred_labels = np.rint(pred_probs - threshold + 0.5)
389
+
390
+ max_size = min(len(true_labels), len(pred_labels))
391
+ true_labels = true_labels[: max_size]
392
+ pred_labels = pred_labels[: max_size]
393
+ pred_probs = pred_probs[: max_size]
394
+ p_1 = 0
395
+ p_5 = 0
396
+ p_8 = 0
397
+ p_10 = 0
398
+ p_15 = 0
399
+ if pred_probs is not None:
400
+ if not is_multilabel:
401
+ normalised_labels = normalise_labels(true_labels, len(pred_probs[0]))
402
+ auc_score = roc_auc(normalised_labels, pred_probs, average=average)
403
+ accuracy = accuracy_score(true_labels, pred_labels)
404
+ precision = precision_score(true_labels, pred_labels, average=average)
405
+ recall = recall_score(true_labels, pred_labels, average=average)
406
+ f1 = f1_score(true_labels, pred_labels, average=average)
407
+ else:
408
+ if average == "macro":
409
+ accuracy = macro_accuracy(true_labels, pred_labels) # categorical accuracy
410
+ precision, recall, f1 = macro_f1(true_labels, pred_labels)
411
+ p_ks = precision_at_k(true_labels, pred_probs, [1, 5, 8, 10, 15])
412
+ p_1 = p_ks[0]
413
+ p_5 = p_ks[1]
414
+ p_8 = p_ks[2]
415
+ p_10 = p_ks[3]
416
+ p_15 = p_ks[4]
417
+
418
+ else:
419
+ accuracy = micro_accuracy(true_labels, pred_labels)
420
+ precision, recall, f1 = micro_f1(true_labels, pred_labels)
421
+ auc_score = roc_auc(true_labels, pred_probs, average)
422
+
423
+ # Calculate label-wise F1 scores
424
+ labelwise_f1 = f1_score(true_labels, pred_labels, average=None)
425
+ labelwise_f1 = np.array2string(labelwise_f1, separator=',')
426
+
427
+
428
+ else:
429
+ auc_score = -1
430
+
431
+ output = {"{}_precision".format(average): precision, "{}_recall".format(average): recall,
432
+ "{}_f1".format(average): f1, "{}_accuracy".format(average): accuracy,
433
+ "{}_auc".format(average): auc_score, "{}_P@1".format(average): p_1, "{}_P@5".format(average): p_5,
434
+ "{}_P@8".format(average): p_8, "{}_P@10".format(average): p_10, "{}_P@15".format(average): p_15,
435
+ "labelwise_f1": labelwise_f1
436
+ }
437
+
438
+ return output
439
+
440
+