Spaces:
Sleeping
Sleeping
meghanaraok
commited on
Upload 3 files
Browse files- LongLAT/models/__init__.py +0 -0
- LongLAT/models/modeling.py +343 -0
- LongLAT/models/utils.py +440 -0
LongLAT/models/__init__.py
ADDED
File without changes
|
LongLAT/models/modeling.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.nn import BCEWithLogitsLoss, Dropout, Linear
|
6 |
+
from transformers import AutoModel, XLNetModel, LongformerModel, LongformerConfig
|
7 |
+
from transformers.models.longformer.modeling_longformer import LongformerEncoder, LongformerClassificationHead, LongformerLayer
|
8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
9 |
+
|
10 |
+
from LongLAT.hilat.models.utils import initial_code_title_vectors
|
11 |
+
|
12 |
+
logger = logging.getLogger("lwat")
|
13 |
+
|
14 |
+
|
15 |
+
class CodingModelConfig:
|
16 |
+
def __init__(self,
|
17 |
+
transformer_model_name_or_path,
|
18 |
+
transformer_tokenizer_name,
|
19 |
+
transformer_layer_update_strategy,
|
20 |
+
num_chunks,
|
21 |
+
max_seq_length,
|
22 |
+
dropout,
|
23 |
+
dropout_att,
|
24 |
+
d_model,
|
25 |
+
label_dictionary,
|
26 |
+
num_labels,
|
27 |
+
use_code_representation,
|
28 |
+
code_max_seq_length,
|
29 |
+
code_batch_size,
|
30 |
+
multi_head_att,
|
31 |
+
chunk_att,
|
32 |
+
linear_init_mean,
|
33 |
+
linear_init_std,
|
34 |
+
document_pooling_strategy,
|
35 |
+
multi_head_chunk_attention,
|
36 |
+
num_hidden_layers):
|
37 |
+
super(CodingModelConfig, self).__init__()
|
38 |
+
self.transformer_model_name_or_path = transformer_model_name_or_path
|
39 |
+
self.transformer_tokenizer_name = transformer_tokenizer_name
|
40 |
+
self.transformer_layer_update_strategy = transformer_layer_update_strategy
|
41 |
+
self.num_chunks = num_chunks
|
42 |
+
self.max_seq_length = max_seq_length
|
43 |
+
self.dropout = dropout
|
44 |
+
self.dropout_att = dropout_att
|
45 |
+
self.d_model = d_model
|
46 |
+
# labels_dictionary is a dataframe with columns: icd9_code, long_title
|
47 |
+
self.label_dictionary = label_dictionary
|
48 |
+
self.num_labels = num_labels
|
49 |
+
self.use_code_representation = use_code_representation
|
50 |
+
self.code_max_seq_length = code_max_seq_length
|
51 |
+
self.code_batch_size = code_batch_size
|
52 |
+
self.multi_head_att = multi_head_att
|
53 |
+
self.chunk_att = chunk_att
|
54 |
+
self.linear_init_mean = linear_init_mean
|
55 |
+
self.linear_init_std = linear_init_std
|
56 |
+
self.document_pooling_strategy = document_pooling_strategy
|
57 |
+
self.multi_head_chunk_attention = multi_head_chunk_attention
|
58 |
+
self.num_hidden_layers = num_hidden_layers
|
59 |
+
|
60 |
+
|
61 |
+
class LableWiseAttentionLayer(torch.nn.Module):
|
62 |
+
def __init__(self, coding_model_config, args):
|
63 |
+
super(LableWiseAttentionLayer, self).__init__()
|
64 |
+
|
65 |
+
self.config = coding_model_config
|
66 |
+
self.args = args
|
67 |
+
|
68 |
+
# layers
|
69 |
+
self.l1_linear = torch.nn.Linear(self.config.d_model,
|
70 |
+
self.config.d_model, bias=False)
|
71 |
+
self.tanh = torch.nn.Tanh()
|
72 |
+
self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
|
73 |
+
self.softmax = torch.nn.Softmax(dim=1)
|
74 |
+
|
75 |
+
# Mean pooling last hidden state of code title from transformer model as the initial code vectors
|
76 |
+
self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
|
77 |
+
|
78 |
+
def _init_linear_weights(self, mean, std):
|
79 |
+
# normalize the l1 weights
|
80 |
+
torch.nn.init.normal_(self.l1_linear.weight, mean, std)
|
81 |
+
if self.l1_linear.bias is not None:
|
82 |
+
self.l1_linear.bias.data.fill_(0)
|
83 |
+
# initialize the l2
|
84 |
+
if self.config.use_code_representation:
|
85 |
+
code_vectors = initial_code_title_vectors(self.config.label_dictionary,
|
86 |
+
self.config.transformer_model_name_or_path,
|
87 |
+
self.config.transformer_tokenizer_name
|
88 |
+
if self.config.transformer_tokenizer_name
|
89 |
+
else self.config.transformer_model_name_or_path,
|
90 |
+
self.config.code_max_seq_length,
|
91 |
+
self.config.code_batch_size,
|
92 |
+
self.config.d_model,
|
93 |
+
self.args.device)
|
94 |
+
|
95 |
+
self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
|
96 |
+
torch.nn.init.normal_(self.l2_linear.weight, mean, std)
|
97 |
+
if self.l2_linear.bias is not None:
|
98 |
+
self.l2_linear.bias.data.fill_(0)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
# input: (batch_size, max_seq_length, transformer_hidden_size)
|
102 |
+
# output: (batch_size, max_seq_length, transformer_hidden_size)
|
103 |
+
# Z = Tan(WH)
|
104 |
+
l1_output = self.tanh(self.l1_linear(x))
|
105 |
+
# softmax(UZ)
|
106 |
+
# l2_linear output shape: (batch_size, max_seq_length, num_labels)
|
107 |
+
# attention_weight shape: (batch_size, num_labels, max_seq_length)
|
108 |
+
attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
|
109 |
+
# attention_output shpae: (batch_size, num_labels, transformer_hidden_size)
|
110 |
+
attention_output = torch.matmul(attention_weight, x)
|
111 |
+
|
112 |
+
return attention_output, attention_weight
|
113 |
+
|
114 |
+
class ChunkAttentionLayer(torch.nn.Module):
|
115 |
+
def __init__(self, coding_model_config, args):
|
116 |
+
super(ChunkAttentionLayer, self).__init__()
|
117 |
+
|
118 |
+
self.config = coding_model_config
|
119 |
+
self.args = args
|
120 |
+
|
121 |
+
# layers
|
122 |
+
self.l1_linear = torch.nn.Linear(self.config.d_model,
|
123 |
+
self.config.d_model, bias=False)
|
124 |
+
self.tanh = torch.nn.Tanh()
|
125 |
+
self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
|
126 |
+
self.softmax = torch.nn.Softmax(dim=1)
|
127 |
+
|
128 |
+
self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
|
129 |
+
|
130 |
+
def _init_linear_weights(self, mean, std):
|
131 |
+
# initialize the l1
|
132 |
+
torch.nn.init.normal_(self.l1_linear.weight, mean, std)
|
133 |
+
if self.l1_linear.bias is not None:
|
134 |
+
self.l1_linear.bias.data.fill_(0)
|
135 |
+
# initialize the l2
|
136 |
+
torch.nn.init.normal_(self.l2_linear.weight, mean, std)
|
137 |
+
if self.l2_linear.bias is not None:
|
138 |
+
self.l2_linear.bias.data.fill_(0)
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
# input: (batch_size, num_chunks, transformer_hidden_size)
|
142 |
+
# output: (batch_size, num_chunks, transformer_hidden_size)
|
143 |
+
# Z = Tan(WH)
|
144 |
+
l1_output = self.tanh(self.l1_linear(x))
|
145 |
+
# softmax(UZ)
|
146 |
+
# l2_linear output shape: (batch_size, num_chunks, 1)
|
147 |
+
# attention_weight shape: (batch_size, 1, num_chunks)
|
148 |
+
attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
|
149 |
+
# attention_output shpae: (batch_size, 1, transformer_hidden_size)
|
150 |
+
attention_output = torch.matmul(attention_weight, x)
|
151 |
+
return attention_output, attention_weight
|
152 |
+
|
153 |
+
|
154 |
+
class CodingModel(torch.nn.Module, PyTorchModelHubMixin):
|
155 |
+
def __init__(self, coding_model_config, args, **kwargs):
|
156 |
+
super(CodingModel, self).__init__()
|
157 |
+
self.coding_model_config = coding_model_config
|
158 |
+
# layers
|
159 |
+
self.transformer_layer = AutoModel.from_pretrained('yikuan8/Clinical-Longformer')
|
160 |
+
if isinstance(self.transformer_layer, XLNetModel):
|
161 |
+
self.transformer_layer.config.use_mems_eval = False
|
162 |
+
# if torch.cuda.is_available():
|
163 |
+
# self.transformer_layer = self.transformer_layer.to(torch.device("cuda:0"))
|
164 |
+
# self.transformer_layer.to(torch.device("cuda:0"))
|
165 |
+
self.dropout = Dropout(p=self.coding_model_config.dropout)
|
166 |
+
# self.longformer_transformer = AutoModel.from_pretrained("yikuan8/Clinical-Longformer")
|
167 |
+
|
168 |
+
if self.coding_model_config.multi_head_att:
|
169 |
+
# initial multi head attention according to the num_chunks
|
170 |
+
self.label_wise_attention_layer = torch.nn.ModuleList(
|
171 |
+
[LableWiseAttentionLayer(coding_model_config, args)
|
172 |
+
for _ in range(self.coding_model_config.num_chunks)])
|
173 |
+
else:
|
174 |
+
self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
|
175 |
+
self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
|
176 |
+
|
177 |
+
# initial chunk attention
|
178 |
+
if self.coding_model_config.chunk_att:
|
179 |
+
if self.coding_model_config.multi_head_chunk_attention:
|
180 |
+
self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
|
181 |
+
for _ in range(self.coding_model_config.num_labels)])
|
182 |
+
else:
|
183 |
+
self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
|
184 |
+
|
185 |
+
self.classifier_layer = Linear(self.coding_model_config.d_model,
|
186 |
+
self.coding_model_config.num_labels)
|
187 |
+
else:
|
188 |
+
if self.coding_model_config.document_pooling_strategy == "flat":
|
189 |
+
self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
|
190 |
+
self.coding_model_config.num_labels)
|
191 |
+
else: # max or mean pooling
|
192 |
+
self.classifier_layer = Linear(self.coding_model_config.d_model,
|
193 |
+
self.coding_model_config.num_labels)
|
194 |
+
self.sigmoid = torch.nn.Sigmoid()
|
195 |
+
|
196 |
+
if self.coding_model_config.transformer_layer_update_strategy == "no":
|
197 |
+
self.freeze_all_transformer_layers()
|
198 |
+
elif self.coding_model_config.transformer_layer_update_strategy == "last":
|
199 |
+
self.freeze_all_transformer_layers()
|
200 |
+
self.unfreeze_transformer_last_layers()
|
201 |
+
|
202 |
+
# initialize the weights of classifier
|
203 |
+
self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
|
204 |
+
|
205 |
+
def _init_linear_weights(self, mean, std):
|
206 |
+
torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
|
207 |
+
|
208 |
+
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
|
209 |
+
# input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
|
210 |
+
# labels shape: (batch_size, num_labels)
|
211 |
+
transformer_output = []
|
212 |
+
|
213 |
+
# pass chunk by chunk into transformer layer in the batches.
|
214 |
+
# input (batch_size, sequence_length)
|
215 |
+
# for i in range(self.coding_model_config.num_chunks):
|
216 |
+
# l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
|
217 |
+
# attention_mask=attention_mask[:, i, :],
|
218 |
+
# token_type_ids=token_type_ids[:, i, :])
|
219 |
+
# # output hidden state shape: (batch_size, sequence_length, hidden_size)
|
220 |
+
# transformer_output.append(l1_output[0])
|
221 |
+
|
222 |
+
input_ids = input_ids.reshape(input_ids.shape[0], input_ids.shape[1]*input_ids.shape[2])
|
223 |
+
global_attention_mask = torch.zeros_like(input_ids)
|
224 |
+
global_attention_positions = [1, 510, 1022, 1534, 2046, 2558, 3070, 3582, 4094]
|
225 |
+
global_attention_mask[:, global_attention_positions] = 1
|
226 |
+
attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
|
227 |
+
token_type_ids = token_type_ids.reshape(token_type_ids.shape[0], token_type_ids.shape[1]*token_type_ids.shape[2])
|
228 |
+
l1_output = self.transformer_layer(input_ids=input_ids, attention_mask=attention_mask, global_attention_mask= global_attention_mask, token_type_ids = token_type_ids)
|
229 |
+
|
230 |
+
transformer_output.append(l1_output[0])
|
231 |
+
# transpose back chunk and batch size dimensions
|
232 |
+
transformer_output = torch.stack(transformer_output)
|
233 |
+
transformer_output = transformer_output.transpose(0, 1)
|
234 |
+
# dropout transformer output
|
235 |
+
l2_dropout = self.dropout(transformer_output)
|
236 |
+
|
237 |
+
# config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
|
238 |
+
# config.num_labels =5
|
239 |
+
# config.num_hidden_layers = 1
|
240 |
+
# longformer_layer = LongformerLayer(config)
|
241 |
+
# # longformer_layer = longformer_layer(config)
|
242 |
+
# # longformer_layer = longformer_layer.to(torch.device("cuda:0"))
|
243 |
+
# l2_dropout= l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1]*l2_dropout.shape[2], l2_dropout.shape[3])
|
244 |
+
# attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
|
245 |
+
# is_index_masked = attention_mask < 0
|
246 |
+
# is_index_global_attn = attention_mask > 0
|
247 |
+
# is_global_attn = is_index_global_attn.flatten().any().item()
|
248 |
+
# output = longformer_layer(l2_dropout, attention_mask=attention_mask,output_attentions=True, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn)
|
249 |
+
# l2_dropout = self.dropout_att(output[0]) #l2_dropout - torch.Size([4, 4096, 768])
|
250 |
+
# l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
|
251 |
+
# #l2_dropout - torch.Size([4, 8, 512, 768])
|
252 |
+
|
253 |
+
|
254 |
+
# Label-wise attention layers
|
255 |
+
# output: (batch_size, num_chunks, num_labels, hidden_size)
|
256 |
+
attention_output = []
|
257 |
+
attention_weights = []
|
258 |
+
|
259 |
+
for i in range(self.coding_model_config.num_chunks):
|
260 |
+
# input: (batch_size, max_seq_length, transformer_hidden_size)
|
261 |
+
if self.coding_model_config.multi_head_att:
|
262 |
+
attention_layer = self.label_wise_attention_layer[i]
|
263 |
+
else:
|
264 |
+
attention_layer = self.label_wise_attention_layer
|
265 |
+
l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
|
266 |
+
# l3_attention shape: (batch_size, num_labels, hidden_size) torch.Size([4, 5, 768])
|
267 |
+
# attention_weight: (batch_size, num_labels, max_seq_length) torch.Size([4, 5, 512])
|
268 |
+
attention_output.append(l3_attention)
|
269 |
+
attention_weights.append(attention_weight)
|
270 |
+
|
271 |
+
attention_output = torch.stack(attention_output)
|
272 |
+
attention_output = attention_output.transpose(0, 1) #torch.Size([4, 8, 5, 768])
|
273 |
+
attention_weights = torch.stack(attention_weights)
|
274 |
+
attention_weights = attention_weights.transpose(0, 1) #torch.Size([4, 8, 5, 512])
|
275 |
+
|
276 |
+
l3_dropout = self.dropout_att(attention_output) #torch.Size([4, 8, 5, 768])
|
277 |
+
|
278 |
+
if self.coding_model_config.chunk_att: #set to false
|
279 |
+
# Chunk attention layers
|
280 |
+
# output: (batch_size, num_labels, hidden_size)
|
281 |
+
chunk_attention_output = []
|
282 |
+
chunk_attention_weights = []
|
283 |
+
|
284 |
+
for i in range(self.coding_model_config.num_labels):
|
285 |
+
if self.coding_model_config.multi_head_chunk_attention:
|
286 |
+
chunk_attention = self.chunk_attention_layer[i]
|
287 |
+
else:
|
288 |
+
chunk_attention = self.chunk_attention_layer
|
289 |
+
l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
|
290 |
+
chunk_attention_output.append(l4_chunk_attention.squeeze())
|
291 |
+
chunk_attention_weights.append(l4_chunk_attention_weights.squeeze())
|
292 |
+
|
293 |
+
chunk_attention_output = torch.stack(chunk_attention_output) #torch.Size([5, 4, 768])
|
294 |
+
chunk_attention_output = chunk_attention_output.transpose(0, 1) #torch.Size([4, 5, 768])
|
295 |
+
chunk_attention_weights = torch.stack(chunk_attention_weights)
|
296 |
+
chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
|
297 |
+
# output shape: (batch_size, num_labels, hidden_size)
|
298 |
+
l4_dropout = self.dropout_att(chunk_attention_output) #torch.Size([4, 5, 768])
|
299 |
+
else:
|
300 |
+
# output shape: (batch_size, num_labels, hidden_size*num_chunks)
|
301 |
+
l4_dropout = l3_dropout.transpose(1, 2)
|
302 |
+
if self.coding_model_config.document_pooling_strategy == "flat":
|
303 |
+
# Flatten layer. concatenate representation by labels
|
304 |
+
l4_dropout = torch.flatten(l4_dropout, start_dim=2)
|
305 |
+
elif self.coding_model_config.document_pooling_strategy == "max":
|
306 |
+
l4_dropout = torch.amax(l4_dropout, 2)
|
307 |
+
elif self.coding_model_config.document_pooling_strategy == "mean":
|
308 |
+
l4_dropout = torch.mean(l4_dropout, 2)
|
309 |
+
else:
|
310 |
+
raise ValueError("Not supported pooling strategy")
|
311 |
+
|
312 |
+
# classifier layer
|
313 |
+
# each code has a binary linear formula
|
314 |
+
logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
|
315 |
+
#torch.Size([4, 5])
|
316 |
+
loss_fct = BCEWithLogitsLoss()
|
317 |
+
loss = loss_fct(logits, targets)
|
318 |
+
|
319 |
+
return {
|
320 |
+
"loss": loss,
|
321 |
+
"logits": logits,
|
322 |
+
"label_attention_weights": attention_weights,
|
323 |
+
"chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
|
324 |
+
}
|
325 |
+
|
326 |
+
def freeze_all_transformer_layers(self):
|
327 |
+
"""
|
328 |
+
Freeze all layer weight parameters. They will not be updated during training.
|
329 |
+
"""
|
330 |
+
for param in self.transformer_layer.parameters():
|
331 |
+
param.requires_grad = False
|
332 |
+
|
333 |
+
def unfreeze_all_transformer_layers(self):
|
334 |
+
"""
|
335 |
+
Unfreeze all layers weight parameters. They will be updated during training.
|
336 |
+
"""
|
337 |
+
for param in self.transformer_layer.parameters():
|
338 |
+
param.requires_grad = True
|
339 |
+
|
340 |
+
def unfreeze_transformer_last_layers(self):
|
341 |
+
for name, param in self.transformer_layer.named_parameters():
|
342 |
+
if "layer.11" in name or "pooler" in name:
|
343 |
+
param.requires_grad = True
|
LongLAT/models/utils.py
ADDED
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import linecache
|
3 |
+
import pickle
|
4 |
+
import random
|
5 |
+
import subprocess
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import redis
|
9 |
+
import torch
|
10 |
+
import logging
|
11 |
+
import ast
|
12 |
+
|
13 |
+
from datasets import Dataset
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, roc_curve, auc
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from transformers import AutoModel, DataCollatorWithPadding, XLNetTokenizer, XLNetTokenizerFast, AutoTokenizer, \
|
19 |
+
XLNetModel, is_torch_tpu_available
|
20 |
+
|
21 |
+
logger = logging.getLogger("lwat")
|
22 |
+
|
23 |
+
|
24 |
+
class MimicIIIDataset(Dataset):
|
25 |
+
def __init__(self, data):
|
26 |
+
self.input_ids = data["input_ids"]
|
27 |
+
self.attention_mask = data["attention_mask"]
|
28 |
+
self.token_type_ids = data["token_type_ids"]
|
29 |
+
self.labels = data["targets"]
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.input_ids)
|
33 |
+
|
34 |
+
def __getitem__(self, item):
|
35 |
+
return {
|
36 |
+
"input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
|
37 |
+
"attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
|
38 |
+
"token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long),
|
39 |
+
"targets": torch.tensor(self.labels[item], dtype=torch.float)
|
40 |
+
}
|
41 |
+
|
42 |
+
class LazyMimicIIIDataset(Dataset):
|
43 |
+
def __init__(self, filename, task, dataset_type):
|
44 |
+
print("lazy load from {}".format(filename))
|
45 |
+
self.filename = filename
|
46 |
+
self.redis = redis.Redis(unix_socket_path="/tmp/redis.sock")
|
47 |
+
self.pipe = self.redis.pipeline()
|
48 |
+
self.num_examples = 0
|
49 |
+
self.task = task
|
50 |
+
self.dataset_type = dataset_type
|
51 |
+
with open(filename, 'r') as f:
|
52 |
+
for line_num, line in enumerate(f.readlines()):
|
53 |
+
self.num_examples += 1
|
54 |
+
example = eval(line)
|
55 |
+
key = task + '_' + dataset_type + '_' + str(line_num)
|
56 |
+
input_ids = eval(example[0])
|
57 |
+
attention_mask = eval(example[1])
|
58 |
+
token_type_ids = eval(example[2])
|
59 |
+
labels = eval(example[3])
|
60 |
+
example_tuple = (input_ids, attention_mask, token_type_ids, labels)
|
61 |
+
|
62 |
+
self.pipe.set(key, pickle.dumps(example_tuple))
|
63 |
+
if line_num % 100 == 0:
|
64 |
+
self.pipe.execute()
|
65 |
+
self.pipe.execute()
|
66 |
+
if is_torch_tpu_available():
|
67 |
+
import torch_xla.core.xla_model as xm
|
68 |
+
xm.rendezvous(tag="featuresGenerated")
|
69 |
+
|
70 |
+
def __len__(self):
|
71 |
+
return self.num_examples
|
72 |
+
|
73 |
+
def __getitem__(self, item):
|
74 |
+
key = self.task + '_' + self.dataset_type + '_' + str(item)
|
75 |
+
example = pickle.loads(self.redis.get(key))
|
76 |
+
|
77 |
+
return {
|
78 |
+
"input_ids": torch.tensor(example[0], dtype=torch.long),
|
79 |
+
"attention_mask": torch.tensor(example[1], dtype=torch.float),
|
80 |
+
"token_type_ids": torch.tensor(example[2], dtype=torch.long),
|
81 |
+
"targets": torch.tensor(example[3], dtype=torch.float)
|
82 |
+
}
|
83 |
+
|
84 |
+
|
85 |
+
class ICDCodeDataset(Dataset):
|
86 |
+
def __init__(self, data):
|
87 |
+
self.input_ids = data["input_ids"]
|
88 |
+
self.attention_mask = data["attention_mask"]
|
89 |
+
self.token_type_ids = data["token_type_ids"]
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
return len(self.input_ids)
|
93 |
+
|
94 |
+
def __getitem__(self, item):
|
95 |
+
return {
|
96 |
+
"input_ids": torch.tensor(self.input_ids[item], dtype=torch.long),
|
97 |
+
"attention_mask": torch.tensor(self.attention_mask[item], dtype=torch.float),
|
98 |
+
"token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long)
|
99 |
+
}
|
100 |
+
|
101 |
+
|
102 |
+
def set_random_seed(random_seed):
|
103 |
+
random.seed(random_seed)
|
104 |
+
np.random.seed(random_seed)
|
105 |
+
torch.manual_seed(random_seed)
|
106 |
+
torch.cuda.manual_seed_all(random_seed)
|
107 |
+
torch.backends.cudnn.deterministic = True
|
108 |
+
torch.backends.cudnn.benchmark = False
|
109 |
+
|
110 |
+
def tokenize_inputs(text_list, tokenizer, max_seq_len=512):
|
111 |
+
"""
|
112 |
+
Tokenizes the input text input into ids. Appends the appropriate special
|
113 |
+
characters to the end of the text to denote end of sentence. Truncate or pad
|
114 |
+
the appropriate sequence length.
|
115 |
+
"""
|
116 |
+
# tokenize the text, then truncate sequence to the desired length minus 2 for
|
117 |
+
# the 2 special characters
|
118 |
+
tokenized_texts = list(map(lambda t: tokenizer.tokenize(t)[:max_seq_len - 2], text_list))
|
119 |
+
# convert tokenized text into numeric ids for the appropriate LM
|
120 |
+
input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
|
121 |
+
# get token type for token_ids_0
|
122 |
+
token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
|
123 |
+
# append special token to end of sentence: <sep> <cls>
|
124 |
+
input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
|
125 |
+
# attention mask
|
126 |
+
attention_mask = [[1] * len(x) for x in input_ids]
|
127 |
+
|
128 |
+
# padding to max_length
|
129 |
+
def padding_to_max(sequence, value):
|
130 |
+
padding_len = max_seq_len - len(sequence)
|
131 |
+
padding = [value] * padding_len
|
132 |
+
return sequence + padding
|
133 |
+
|
134 |
+
input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
|
135 |
+
attention_mask = [padding_to_max(x, 0) for x in attention_mask]
|
136 |
+
token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
|
137 |
+
|
138 |
+
return input_ids, attention_mask, token_type_ids
|
139 |
+
|
140 |
+
|
141 |
+
def tokenize_dataset(tokenizer, text, labels, max_seq_len):
|
142 |
+
if (isinstance(tokenizer, XLNetTokenizer) or isinstance(tokenizer, XLNetTokenizerFast)):
|
143 |
+
data = list(map(lambda t: tokenize_inputs(t, tokenizer, max_seq_len=max_seq_len), text))
|
144 |
+
input_ids, attention_mask, token_type_ids = zip(*data)
|
145 |
+
else:
|
146 |
+
tokenizer.model_max_length = max_seq_len
|
147 |
+
input_dict = tokenizer(text, padding=True, truncation=True)
|
148 |
+
input_ids = input_dict["input_ids"]
|
149 |
+
attention_mask = input_dict["attention_mask"]
|
150 |
+
token_type_ids = input_dict["token_type_ids"]
|
151 |
+
|
152 |
+
return {
|
153 |
+
"input_ids": input_ids,
|
154 |
+
"attention_mask": attention_mask,
|
155 |
+
"token_type_ids": token_type_ids,
|
156 |
+
"targets": labels
|
157 |
+
}
|
158 |
+
|
159 |
+
|
160 |
+
def initial_code_title_vectors(label_dict, transformer_model_name, tokenizer_name, code_max_seq_length, code_batch_size,
|
161 |
+
d_model, device):
|
162 |
+
logger.info("Generate code title representations from base transformer model")
|
163 |
+
model = AutoModel.from_pretrained(transformer_model_name)
|
164 |
+
if isinstance(model, XLNetModel):
|
165 |
+
model.config.use_mems_eval = False
|
166 |
+
#
|
167 |
+
# model.config.use_mems_eval = False
|
168 |
+
# model.config.reuse_len = 0
|
169 |
+
code_titles = label_dict["long_title"].fillna("").tolist()
|
170 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, padding_side="right")
|
171 |
+
data = tokenizer(code_titles, padding=True, truncation=True)
|
172 |
+
code_dataset = ICDCodeDataset(data)
|
173 |
+
|
174 |
+
model.to(device)
|
175 |
+
|
176 |
+
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length",
|
177 |
+
max_length=code_max_seq_length)
|
178 |
+
code_param = {"batch_size": code_batch_size, "collate_fn": data_collator}
|
179 |
+
code_dataloader = DataLoader(code_dataset, **code_param)
|
180 |
+
|
181 |
+
code_dataloader_progress_bar = tqdm(code_dataloader, unit="batches",
|
182 |
+
desc="Code title representations")
|
183 |
+
code_dataloader_progress_bar.clear()
|
184 |
+
|
185 |
+
# output shape: (num_labels, hidden_size)
|
186 |
+
initial_code_vectors = torch.zeros(len(code_dataset), d_model)
|
187 |
+
|
188 |
+
for i, data in enumerate(code_dataloader_progress_bar):
|
189 |
+
input_ids = data["input_ids"].to(device, dtype=torch.long)
|
190 |
+
attention_mask = data["attention_mask"].to(device, dtype=torch.float)
|
191 |
+
token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
|
192 |
+
|
193 |
+
# output shape: (batch_size, sequence_length, hidden_size)
|
194 |
+
output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
195 |
+
# Mean pooling. output shape: (batch_size, hidden_size)
|
196 |
+
mean_last_hidden_state = torch.mean(output[0], 1)
|
197 |
+
# Max pooling. output shape: (batch_size, hidden_size)
|
198 |
+
# max_last_hidden_state = torch.max((output[0] * attention_mask.unsqueeze(-1)), 1)[0]
|
199 |
+
|
200 |
+
initial_code_vectors[i * input_ids.shape[0]:(i + 1) * input_ids.shape[0], :] = mean_last_hidden_state
|
201 |
+
|
202 |
+
code_dataloader_progress_bar.refresh(True)
|
203 |
+
code_dataloader_progress_bar.clear(True)
|
204 |
+
code_dataloader_progress_bar.close()
|
205 |
+
logger.info("Code representations ready for use. Shape {}".format(initial_code_vectors.shape))
|
206 |
+
return initial_code_vectors
|
207 |
+
|
208 |
+
|
209 |
+
def normalise_labels(labels, n_label):
|
210 |
+
norm_labels = []
|
211 |
+
for label in labels:
|
212 |
+
one_hot_vector_label = [0] * n_label
|
213 |
+
one_hot_vector_label[label] = 1
|
214 |
+
norm_labels.append(one_hot_vector_label)
|
215 |
+
return np.asarray(norm_labels)
|
216 |
+
|
217 |
+
|
218 |
+
def segment_tokenize_inputs(text, tokenizer, max_seq_len, num_chunks):
|
219 |
+
# input is full text of one document
|
220 |
+
tokenized_texts = []
|
221 |
+
tokens = tokenizer.tokenize(text)
|
222 |
+
start_idx = 0
|
223 |
+
seq_len = max_seq_len - 2
|
224 |
+
for i in range(num_chunks):
|
225 |
+
if start_idx > len(tokens):
|
226 |
+
tokenized_texts.append([])
|
227 |
+
continue
|
228 |
+
tokenized_texts.append(tokens[start_idx:(start_idx + seq_len)])
|
229 |
+
start_idx += seq_len
|
230 |
+
|
231 |
+
# convert tokenized text into numeric ids for the appropriate LM
|
232 |
+
input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
|
233 |
+
# get token type for token_ids_0
|
234 |
+
token_type_ids = [tokenizer.create_token_type_ids_from_sequences(x) for x in input_ids]
|
235 |
+
# append special token to end of sentence: <sep> <cls>
|
236 |
+
input_ids = [tokenizer.build_inputs_with_special_tokens(x) for x in input_ids]
|
237 |
+
# attention mask
|
238 |
+
attention_mask = [[1] * len(x) for x in input_ids]
|
239 |
+
|
240 |
+
# padding to max_length
|
241 |
+
def padding_to_max(sequence, value):
|
242 |
+
padding_len = max_seq_len - len(sequence)
|
243 |
+
padding = [value] * padding_len
|
244 |
+
return sequence + padding
|
245 |
+
|
246 |
+
input_ids = [padding_to_max(x, tokenizer.pad_token_id) for x in input_ids]
|
247 |
+
attention_mask = [padding_to_max(x, 0) for x in attention_mask]
|
248 |
+
token_type_ids = [padding_to_max(x, tokenizer.pad_token_type_id) for x in token_type_ids]
|
249 |
+
|
250 |
+
return input_ids, attention_mask, token_type_ids
|
251 |
+
|
252 |
+
|
253 |
+
def segment_tokenize_dataset(tokenizer, text, labels, max_seq_len, num_chunks):
|
254 |
+
data = list(
|
255 |
+
map(lambda t: segment_tokenize_inputs(t, tokenizer, max_seq_len, num_chunks), text))
|
256 |
+
input_ids, attention_mask, token_type_ids = zip(*data)
|
257 |
+
|
258 |
+
return {
|
259 |
+
"input_ids": input_ids,
|
260 |
+
"attention_mask": attention_mask,
|
261 |
+
"token_type_ids": token_type_ids,
|
262 |
+
"targets": labels
|
263 |
+
}
|
264 |
+
|
265 |
+
|
266 |
+
# The following functions are modified from the relevant codes of https://github.com/aehrc/LAAT
|
267 |
+
def roc_auc(true_labels, pred_probs, average="macro"):
|
268 |
+
if pred_probs.shape[0] <= 1:
|
269 |
+
return
|
270 |
+
|
271 |
+
fpr = {}
|
272 |
+
tpr = {}
|
273 |
+
if average == "macro":
|
274 |
+
# get AUC for each label individually
|
275 |
+
relevant_labels = []
|
276 |
+
auc_labels = {}
|
277 |
+
for i in range(true_labels.shape[1]):
|
278 |
+
# only if there are true positives for this label
|
279 |
+
if true_labels[:, i].sum() > 0:
|
280 |
+
fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], pred_probs[:, i])
|
281 |
+
if len(fpr[i]) > 1 and len(tpr[i]) > 1:
|
282 |
+
auc_score = auc(fpr[i], tpr[i])
|
283 |
+
if not np.isnan(auc_score):
|
284 |
+
auc_labels["auc_%d" % i] = auc_score
|
285 |
+
relevant_labels.append(i)
|
286 |
+
|
287 |
+
# macro-AUC: just average the auc scores
|
288 |
+
aucs = []
|
289 |
+
for i in relevant_labels:
|
290 |
+
aucs.append(auc_labels['auc_%d' % i])
|
291 |
+
score = np.mean(aucs)
|
292 |
+
else:
|
293 |
+
# micro-AUC: just look at each individual prediction
|
294 |
+
flat_pred = pred_probs.ravel()
|
295 |
+
fpr["micro"], tpr["micro"], _ = roc_curve(true_labels.ravel(), flat_pred)
|
296 |
+
score = auc(fpr["micro"], tpr["micro"])
|
297 |
+
|
298 |
+
return score
|
299 |
+
|
300 |
+
|
301 |
+
def union_size(x, y, axis):
|
302 |
+
return np.logical_or(x, y).sum(axis=axis).astype(float)
|
303 |
+
|
304 |
+
|
305 |
+
def intersect_size(x, y, axis):
|
306 |
+
return np.logical_and(x, y).sum(axis=axis).astype(float)
|
307 |
+
|
308 |
+
|
309 |
+
def macro_accuracy(true_labels, pred_labels):
|
310 |
+
num = intersect_size(true_labels, pred_labels, 0) / (union_size(true_labels, pred_labels, 0) + 1e-10)
|
311 |
+
return np.mean(num)
|
312 |
+
|
313 |
+
|
314 |
+
def macro_precision(true_labels, pred_labels):
|
315 |
+
num = intersect_size(true_labels, pred_labels, 0) / (pred_labels.sum(axis=0) + 1e-10)
|
316 |
+
return np.mean(num)
|
317 |
+
|
318 |
+
|
319 |
+
def macro_recall(true_labels, pred_labels):
|
320 |
+
num = intersect_size(true_labels, pred_labels, 0) / (true_labels.sum(axis=0) + 1e-10)
|
321 |
+
return np.mean(num)
|
322 |
+
|
323 |
+
|
324 |
+
def macro_f1(true_labels, pred_labels):
|
325 |
+
prec = macro_precision(true_labels, pred_labels)
|
326 |
+
rec = macro_recall(true_labels, pred_labels)
|
327 |
+
if prec + rec == 0:
|
328 |
+
f1 = 0.
|
329 |
+
else:
|
330 |
+
f1 = 2 * (prec * rec) / (prec + rec)
|
331 |
+
return prec, rec, f1
|
332 |
+
|
333 |
+
|
334 |
+
def precision_at_k(true_labels, pred_probs, ks=[1, 5, 8, 10, 15]):
|
335 |
+
# num true labels in top k predictions / k
|
336 |
+
sorted_pred = np.argsort(pred_probs)[:, ::-1]
|
337 |
+
output = []
|
338 |
+
for k in ks:
|
339 |
+
topk = sorted_pred[:, :k]
|
340 |
+
|
341 |
+
# get precision at k for each example
|
342 |
+
vals = []
|
343 |
+
for i, tk in enumerate(topk):
|
344 |
+
if len(tk) > 0:
|
345 |
+
num_true_in_top_k = true_labels[i, tk].sum()
|
346 |
+
denom = len(tk)
|
347 |
+
vals.append(num_true_in_top_k / float(denom))
|
348 |
+
|
349 |
+
output.append(np.mean(vals))
|
350 |
+
return output
|
351 |
+
|
352 |
+
|
353 |
+
def micro_recall(true_labels, pred_labels):
|
354 |
+
flat_true = true_labels.ravel()
|
355 |
+
flat_pred = pred_labels.ravel()
|
356 |
+
return intersect_size(flat_true, flat_pred, 0) / flat_true.sum(axis=0)
|
357 |
+
|
358 |
+
|
359 |
+
def micro_precision(true_labels, pred_labels):
|
360 |
+
flat_true = true_labels.ravel()
|
361 |
+
flat_pred = pred_labels.ravel()
|
362 |
+
if flat_pred.sum(axis=0) == 0:
|
363 |
+
return 0.0
|
364 |
+
return intersect_size(flat_true, flat_pred, 0) / flat_pred.sum(axis=0)
|
365 |
+
|
366 |
+
|
367 |
+
def micro_f1(true_labels, pred_labels):
|
368 |
+
prec = micro_precision(true_labels, pred_labels)
|
369 |
+
rec = micro_recall(true_labels, pred_labels)
|
370 |
+
if prec + rec == 0:
|
371 |
+
f1 = 0.
|
372 |
+
else:
|
373 |
+
f1 = 2 * (prec * rec) / (prec + rec)
|
374 |
+
return prec, rec, f1
|
375 |
+
|
376 |
+
|
377 |
+
def micro_accuracy(true_labels, pred_labels):
|
378 |
+
flat_true = true_labels.ravel()
|
379 |
+
flat_pred = pred_labels.ravel()
|
380 |
+
return intersect_size(flat_true, flat_pred, 0) / union_size(flat_true, flat_pred, 0)
|
381 |
+
|
382 |
+
|
383 |
+
def calculate_scores(true_labels, logits, average="macro", is_multilabel=True, threshold=0.5):
|
384 |
+
def sigmoid(x):
|
385 |
+
return 1 / (1 + np.exp(-x))
|
386 |
+
|
387 |
+
pred_probs = sigmoid(logits)
|
388 |
+
pred_labels = np.rint(pred_probs - threshold + 0.5)
|
389 |
+
|
390 |
+
max_size = min(len(true_labels), len(pred_labels))
|
391 |
+
true_labels = true_labels[: max_size]
|
392 |
+
pred_labels = pred_labels[: max_size]
|
393 |
+
pred_probs = pred_probs[: max_size]
|
394 |
+
p_1 = 0
|
395 |
+
p_5 = 0
|
396 |
+
p_8 = 0
|
397 |
+
p_10 = 0
|
398 |
+
p_15 = 0
|
399 |
+
if pred_probs is not None:
|
400 |
+
if not is_multilabel:
|
401 |
+
normalised_labels = normalise_labels(true_labels, len(pred_probs[0]))
|
402 |
+
auc_score = roc_auc(normalised_labels, pred_probs, average=average)
|
403 |
+
accuracy = accuracy_score(true_labels, pred_labels)
|
404 |
+
precision = precision_score(true_labels, pred_labels, average=average)
|
405 |
+
recall = recall_score(true_labels, pred_labels, average=average)
|
406 |
+
f1 = f1_score(true_labels, pred_labels, average=average)
|
407 |
+
else:
|
408 |
+
if average == "macro":
|
409 |
+
accuracy = macro_accuracy(true_labels, pred_labels) # categorical accuracy
|
410 |
+
precision, recall, f1 = macro_f1(true_labels, pred_labels)
|
411 |
+
p_ks = precision_at_k(true_labels, pred_probs, [1, 5, 8, 10, 15])
|
412 |
+
p_1 = p_ks[0]
|
413 |
+
p_5 = p_ks[1]
|
414 |
+
p_8 = p_ks[2]
|
415 |
+
p_10 = p_ks[3]
|
416 |
+
p_15 = p_ks[4]
|
417 |
+
|
418 |
+
else:
|
419 |
+
accuracy = micro_accuracy(true_labels, pred_labels)
|
420 |
+
precision, recall, f1 = micro_f1(true_labels, pred_labels)
|
421 |
+
auc_score = roc_auc(true_labels, pred_probs, average)
|
422 |
+
|
423 |
+
# Calculate label-wise F1 scores
|
424 |
+
labelwise_f1 = f1_score(true_labels, pred_labels, average=None)
|
425 |
+
labelwise_f1 = np.array2string(labelwise_f1, separator=',')
|
426 |
+
|
427 |
+
|
428 |
+
else:
|
429 |
+
auc_score = -1
|
430 |
+
|
431 |
+
output = {"{}_precision".format(average): precision, "{}_recall".format(average): recall,
|
432 |
+
"{}_f1".format(average): f1, "{}_accuracy".format(average): accuracy,
|
433 |
+
"{}_auc".format(average): auc_score, "{}_P@1".format(average): p_1, "{}_P@5".format(average): p_5,
|
434 |
+
"{}_P@8".format(average): p_8, "{}_P@10".format(average): p_10, "{}_P@15".format(average): p_15,
|
435 |
+
"labelwise_f1": labelwise_f1
|
436 |
+
}
|
437 |
+
|
438 |
+
return output
|
439 |
+
|
440 |
+
|