Spaces:
Sleeping
Sleeping
meghanaraok
commited on
Delete HiLATmain/models/modeling - Copy1.py
Browse files
HiLATmain/models/modeling - Copy1.py
DELETED
@@ -1,337 +0,0 @@
|
|
1 |
-
import collections
|
2 |
-
import logging
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from torch.nn import BCEWithLogitsLoss, Dropout, Linear
|
6 |
-
from transformers import AutoModel, XLNetModel, LongformerModel, LongformerConfig
|
7 |
-
from transformers.models.longformer.modeling_longformer import LongformerEncoder, LongformerClassificationHead, LongformerLayer
|
8 |
-
|
9 |
-
from hilat.models.utils import initial_code_title_vectors
|
10 |
-
|
11 |
-
logger = logging.getLogger("lwat")
|
12 |
-
|
13 |
-
|
14 |
-
class CodingModelConfig:
|
15 |
-
def __init__(self,
|
16 |
-
transformer_model_name_or_path,
|
17 |
-
transformer_tokenizer_name,
|
18 |
-
transformer_layer_update_strategy,
|
19 |
-
num_chunks,
|
20 |
-
max_seq_length,
|
21 |
-
dropout,
|
22 |
-
dropout_att,
|
23 |
-
d_model,
|
24 |
-
label_dictionary,
|
25 |
-
num_labels,
|
26 |
-
use_code_representation,
|
27 |
-
code_max_seq_length,
|
28 |
-
code_batch_size,
|
29 |
-
multi_head_att,
|
30 |
-
chunk_att,
|
31 |
-
linear_init_mean,
|
32 |
-
linear_init_std,
|
33 |
-
document_pooling_strategy,
|
34 |
-
multi_head_chunk_attention):
|
35 |
-
super(CodingModelConfig, self).__init__()
|
36 |
-
self.transformer_model_name_or_path = transformer_model_name_or_path
|
37 |
-
self.transformer_tokenizer_name = transformer_tokenizer_name
|
38 |
-
self.transformer_layer_update_strategy = transformer_layer_update_strategy
|
39 |
-
self.num_chunks = num_chunks
|
40 |
-
self.max_seq_length = max_seq_length
|
41 |
-
self.dropout = dropout
|
42 |
-
self.dropout_att = dropout_att
|
43 |
-
self.d_model = d_model
|
44 |
-
# labels_dictionary is a dataframe with columns: icd9_code, long_title
|
45 |
-
self.label_dictionary = label_dictionary
|
46 |
-
self.num_labels = num_labels
|
47 |
-
self.use_code_representation = use_code_representation
|
48 |
-
self.code_max_seq_length = code_max_seq_length
|
49 |
-
self.code_batch_size = code_batch_size
|
50 |
-
self.multi_head_att = multi_head_att
|
51 |
-
self.chunk_att = chunk_att
|
52 |
-
self.linear_init_mean = linear_init_mean
|
53 |
-
self.linear_init_std = linear_init_std
|
54 |
-
self.document_pooling_strategy = document_pooling_strategy
|
55 |
-
self.multi_head_chunk_attention = multi_head_chunk_attention
|
56 |
-
|
57 |
-
|
58 |
-
class LableWiseAttentionLayer(torch.nn.Module):
|
59 |
-
def __init__(self, coding_model_config, args):
|
60 |
-
super(LableWiseAttentionLayer, self).__init__()
|
61 |
-
|
62 |
-
self.config = coding_model_config
|
63 |
-
self.args = args
|
64 |
-
|
65 |
-
# layers
|
66 |
-
self.l1_linear = torch.nn.Linear(self.config.d_model,
|
67 |
-
self.config.d_model, bias=False)
|
68 |
-
self.tanh = torch.nn.Tanh()
|
69 |
-
self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
|
70 |
-
self.softmax = torch.nn.Softmax(dim=1)
|
71 |
-
|
72 |
-
# Mean pooling last hidden state of code title from transformer model as the initial code vectors
|
73 |
-
self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
|
74 |
-
|
75 |
-
def _init_linear_weights(self, mean, std):
|
76 |
-
# normalize the l1 weights
|
77 |
-
torch.nn.init.normal_(self.l1_linear.weight, mean, std)
|
78 |
-
if self.l1_linear.bias is not None:
|
79 |
-
self.l1_linear.bias.data.fill_(0)
|
80 |
-
# initialize the l2
|
81 |
-
if self.config.use_code_representation:
|
82 |
-
code_vectors = initial_code_title_vectors(self.config.label_dictionary,
|
83 |
-
self.config.transformer_model_name_or_path,
|
84 |
-
self.config.transformer_tokenizer_name
|
85 |
-
if self.config.transformer_tokenizer_name
|
86 |
-
else self.config.transformer_model_name_or_path,
|
87 |
-
self.config.code_max_seq_length,
|
88 |
-
self.config.code_batch_size,
|
89 |
-
self.config.d_model,
|
90 |
-
self.args.device)
|
91 |
-
|
92 |
-
self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
|
93 |
-
torch.nn.init.normal_(self.l2_linear.weight, mean, std)
|
94 |
-
if self.l2_linear.bias is not None:
|
95 |
-
self.l2_linear.bias.data.fill_(0)
|
96 |
-
|
97 |
-
def forward(self, x):
|
98 |
-
# input: (batch_size, max_seq_length, transformer_hidden_size)
|
99 |
-
# output: (batch_size, max_seq_length, transformer_hidden_size)
|
100 |
-
# Z = Tan(WH)
|
101 |
-
l1_output = self.tanh(self.l1_linear(x))
|
102 |
-
# softmax(UZ)
|
103 |
-
# l2_linear output shape: (batch_size, max_seq_length, num_labels)
|
104 |
-
# attention_weight shape: (batch_size, num_labels, max_seq_length)
|
105 |
-
attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
|
106 |
-
# attention_output shpae: (batch_size, num_labels, transformer_hidden_size)
|
107 |
-
attention_output = torch.matmul(attention_weight, x)
|
108 |
-
|
109 |
-
return attention_output, attention_weight
|
110 |
-
|
111 |
-
class ChunkAttentionLayer(torch.nn.Module):
|
112 |
-
def __init__(self, coding_model_config, args):
|
113 |
-
super(ChunkAttentionLayer, self).__init__()
|
114 |
-
|
115 |
-
self.config = coding_model_config
|
116 |
-
self.args = args
|
117 |
-
|
118 |
-
# layers
|
119 |
-
self.l1_linear = torch.nn.Linear(self.config.d_model,
|
120 |
-
self.config.d_model, bias=False)
|
121 |
-
self.tanh = torch.nn.Tanh()
|
122 |
-
self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
|
123 |
-
self.softmax = torch.nn.Softmax(dim=1)
|
124 |
-
|
125 |
-
self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
|
126 |
-
|
127 |
-
def _init_linear_weights(self, mean, std):
|
128 |
-
# initialize the l1
|
129 |
-
torch.nn.init.normal_(self.l1_linear.weight, mean, std)
|
130 |
-
if self.l1_linear.bias is not None:
|
131 |
-
self.l1_linear.bias.data.fill_(0)
|
132 |
-
# initialize the l2
|
133 |
-
torch.nn.init.normal_(self.l2_linear.weight, mean, std)
|
134 |
-
if self.l2_linear.bias is not None:
|
135 |
-
self.l2_linear.bias.data.fill_(0)
|
136 |
-
|
137 |
-
def forward(self, x):
|
138 |
-
# input: (batch_size, num_chunks, transformer_hidden_size)
|
139 |
-
# output: (batch_size, num_chunks, transformer_hidden_size)
|
140 |
-
# Z = Tan(WH)
|
141 |
-
l1_output = self.tanh(self.l1_linear(x))
|
142 |
-
# softmax(UZ)
|
143 |
-
# l2_linear output shape: (batch_size, num_chunks, 1)
|
144 |
-
# attention_weight shape: (batch_size, 1, num_chunks)
|
145 |
-
attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
|
146 |
-
# attention_output shpae: (batch_size, 1, transformer_hidden_size)
|
147 |
-
attention_output = torch.matmul(attention_weight, x)
|
148 |
-
return attention_output, attention_weight
|
149 |
-
|
150 |
-
|
151 |
-
class CodingModel(torch.nn.Module):
|
152 |
-
def __init__(self, coding_model_config, args):
|
153 |
-
super(CodingModel, self).__init__()
|
154 |
-
self.coding_model_config = coding_model_config
|
155 |
-
self.args = args
|
156 |
-
# layers
|
157 |
-
self.transformer_layer = AutoModel.from_pretrained(self.coding_model_config.transformer_model_name_or_path)
|
158 |
-
if isinstance(self.transformer_layer, XLNetModel):
|
159 |
-
self.transformer_layer.config.use_mems_eval = False
|
160 |
-
self.dropout = Dropout(p=self.coding_model_config.dropout)
|
161 |
-
|
162 |
-
if self.coding_model_config.multi_head_att:
|
163 |
-
# initial multi head attention according to the num_chunks
|
164 |
-
self.label_wise_attention_layer = torch.nn.ModuleList(
|
165 |
-
[LableWiseAttentionLayer(coding_model_config, args)
|
166 |
-
for _ in range(self.coding_model_config.num_chunks)])
|
167 |
-
else:
|
168 |
-
self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
|
169 |
-
self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
|
170 |
-
|
171 |
-
# initial chunk attention
|
172 |
-
if self.coding_model_config.chunk_att:
|
173 |
-
if self.coding_model_config.multi_head_chunk_attention:
|
174 |
-
self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
|
175 |
-
for _ in range(self.coding_model_config.num_labels)])
|
176 |
-
else:
|
177 |
-
self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
|
178 |
-
|
179 |
-
self.classifier_layer = Linear(self.coding_model_config.d_model,
|
180 |
-
self.coding_model_config.num_labels)
|
181 |
-
else:
|
182 |
-
if self.coding_model_config.document_pooling_strategy == "flat":
|
183 |
-
self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
|
184 |
-
self.coding_model_config.num_labels)
|
185 |
-
else: # max or mean pooling
|
186 |
-
self.classifier_layer = Linear(self.coding_model_config.d_model,
|
187 |
-
self.coding_model_config.num_labels)
|
188 |
-
self.sigmoid = torch.nn.Sigmoid()
|
189 |
-
|
190 |
-
if self.coding_model_config.transformer_layer_update_strategy == "no":
|
191 |
-
self.freeze_all_transformer_layers()
|
192 |
-
elif self.coding_model_config.transformer_layer_update_strategy == "last":
|
193 |
-
self.freeze_all_transformer_layers()
|
194 |
-
self.unfreeze_transformer_last_layers()
|
195 |
-
|
196 |
-
# initialize the weights of classifier
|
197 |
-
self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
|
198 |
-
|
199 |
-
def _init_linear_weights(self, mean, std):
|
200 |
-
torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
|
201 |
-
|
202 |
-
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
|
203 |
-
# input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
|
204 |
-
# labels shape: (batch_size, num_labels)
|
205 |
-
transformer_output = []
|
206 |
-
|
207 |
-
# pass chunk by chunk into transformer layer in the batches.
|
208 |
-
# input (batch_size, sequence_length)
|
209 |
-
for i in range(self.coding_model_config.num_chunks):
|
210 |
-
l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
|
211 |
-
attention_mask=attention_mask[:, i, :],
|
212 |
-
token_type_ids=token_type_ids[:, i, :])
|
213 |
-
# output hidden state shape: (batch_size, sequence_length, hidden_size)
|
214 |
-
transformer_output.append(l1_output[0])
|
215 |
-
|
216 |
-
# transpose back chunk and batch size dimensions
|
217 |
-
transformer_output = torch.stack(transformer_output)
|
218 |
-
transformer_output = transformer_output.transpose(0, 1)
|
219 |
-
# dropout transformer output
|
220 |
-
l2_dropout = self.dropout(transformer_output)
|
221 |
-
|
222 |
-
# Label-wise attention layers
|
223 |
-
# output: (batch_size, num_chunks, num_labels, hidden_size)
|
224 |
-
attention_output = []
|
225 |
-
attention_weights = []
|
226 |
-
|
227 |
-
for i in range(self.coding_model_config.num_chunks):
|
228 |
-
# input: (batch_size, max_seq_length, transformer_hidden_size)
|
229 |
-
if self.coding_model_config.multi_head_att:
|
230 |
-
attention_layer = self.label_wise_attention_layer[i]
|
231 |
-
else:
|
232 |
-
attention_layer = self.label_wise_attention_layer
|
233 |
-
l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
|
234 |
-
# l3_attention shape: (batch_size, num_labels, hidden_size)
|
235 |
-
# attention_weight: (batch_size, num_labels, max_seq_length)
|
236 |
-
attention_output.append(l3_attention)
|
237 |
-
attention_weights.append(attention_weight)
|
238 |
-
|
239 |
-
attention_output = torch.stack(attention_output)
|
240 |
-
attention_output = attention_output.transpose(0, 1)
|
241 |
-
attention_weights = torch.stack(attention_weights)
|
242 |
-
attention_weights = attention_weights.transpose(0, 1)
|
243 |
-
|
244 |
-
config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
|
245 |
-
config.num_labels =5
|
246 |
-
config.num_hidden_layers = 1
|
247 |
-
longformer_layer = LongformerLayer(config)
|
248 |
-
l2_dropout= l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1]*l2_dropout.shape[2], l2_dropout.shape[3])
|
249 |
-
attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1]*attention_mask.shape[2])
|
250 |
-
is_index_masked = attention_mask < 0
|
251 |
-
output = longformer_layer(l2_dropout, attention_mask=attention_mask,output_attentions=True, is_index_masked=is_index_masked)
|
252 |
-
l3_dropout = self.dropout_att(output[0])
|
253 |
-
l3_dropout = l3_dropout.reshape(l3_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
|
254 |
-
self.softmax = torch.nn.Softmax(dim=1)
|
255 |
-
self.l2_linear = torch.nn.Linear(self.coding_model_config.d_model, self.coding_model_config.num_labels, bias=False)
|
256 |
-
attention_weight = self.softmax(self.l2_linear(l3_dropout)).transpose(1, 2)
|
257 |
-
attention_weight = attention_weight.reshape(attention_weight.shape[0], self.coding_model_config.num_labels, self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length)
|
258 |
-
# attention_weight = attention_weight.permute(0,2,1)
|
259 |
-
l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
|
260 |
-
|
261 |
-
attention_output = []
|
262 |
-
|
263 |
-
for i in range(self.coding_model_config.num_chunks):
|
264 |
-
l3_attention = torch.matmul(attention_weight[:,:,i], l2_dropout[:,i,:])
|
265 |
-
attention_output.append(l3_attention)
|
266 |
-
|
267 |
-
attention_output = torch.stack(attention_output)
|
268 |
-
l3_dropout = self.dropout_att(attention_output)
|
269 |
-
l3_dropout = l3_dropout.transpose(0,1)
|
270 |
-
|
271 |
-
|
272 |
-
if self.coding_model_config.chunk_att:
|
273 |
-
# Chunk attention layers
|
274 |
-
# output: (batch_size, num_labels, hidden_size)
|
275 |
-
chunk_attention_output = []
|
276 |
-
chunk_attention_weights = []
|
277 |
-
|
278 |
-
for i in range(self.coding_model_config.num_labels):
|
279 |
-
if self.coding_model_config.multi_head_chunk_attention:
|
280 |
-
chunk_attention = self.chunk_attention_layer[i]
|
281 |
-
else:
|
282 |
-
chunk_attention = self.chunk_attention_layer
|
283 |
-
l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
|
284 |
-
chunk_attention_output.append(l4_chunk_attention.squeeze())
|
285 |
-
chunk_attention_weights.append(l4_chunk_attention_weights.squeeze())
|
286 |
-
|
287 |
-
chunk_attention_output = torch.stack(chunk_attention_output)
|
288 |
-
chunk_attention_output = chunk_attention_output.transpose(0, 1)
|
289 |
-
chunk_attention_weights = torch.stack(chunk_attention_weights)
|
290 |
-
chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
|
291 |
-
# output shape: (batch_size, num_labels, hidden_size)
|
292 |
-
l4_dropout = self.dropout_att(chunk_attention_output)
|
293 |
-
else:
|
294 |
-
# output shape: (batch_size, num_labels, hidden_size*num_chunks)
|
295 |
-
l4_dropout = l3_dropout.transpose(1, 2)
|
296 |
-
if self.coding_model_config.document_pooling_strategy == "flat":
|
297 |
-
# Flatten layer. concatenate representation by labels
|
298 |
-
l4_dropout = torch.flatten(l4_dropout, start_dim=2)
|
299 |
-
elif self.coding_model_config.document_pooling_strategy == "max":
|
300 |
-
l4_dropout = torch.amax(l4_dropout, 2)
|
301 |
-
elif self.coding_model_config.document_pooling_strategy == "mean":
|
302 |
-
l4_dropout = torch.mean(l4_dropout, 2)
|
303 |
-
else:
|
304 |
-
raise ValueError("Not supported pooling strategy")
|
305 |
-
|
306 |
-
# classifier layer
|
307 |
-
# each code has a binary linear formula
|
308 |
-
logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
|
309 |
-
|
310 |
-
loss_fct = BCEWithLogitsLoss()
|
311 |
-
loss = loss_fct(logits, targets)
|
312 |
-
|
313 |
-
return {
|
314 |
-
"loss": loss,
|
315 |
-
"logits": logits,
|
316 |
-
"label_attention_weights": attention_weights,
|
317 |
-
"chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
|
318 |
-
}
|
319 |
-
|
320 |
-
def freeze_all_transformer_layers(self):
|
321 |
-
"""
|
322 |
-
Freeze all layer weight parameters. They will not be updated during training.
|
323 |
-
"""
|
324 |
-
for param in self.transformer_layer.parameters():
|
325 |
-
param.requires_grad = False
|
326 |
-
|
327 |
-
def unfreeze_all_transformer_layers(self):
|
328 |
-
"""
|
329 |
-
Unfreeze all layers weight parameters. They will be updated during training.
|
330 |
-
"""
|
331 |
-
for param in self.transformer_layer.parameters():
|
332 |
-
param.requires_grad = True
|
333 |
-
|
334 |
-
def unfreeze_transformer_last_layers(self):
|
335 |
-
for name, param in self.transformer_layer.named_parameters():
|
336 |
-
if "layer.11" in name or "pooler" in name:
|
337 |
-
param.requires_grad = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|