anicolson committed
Commit f26c980 · 1 Parent(s): f1d30ab

Delete modelling_variable.py

Files changed (1)
  1. modelling_variable.py +0 -425
modelling_variable.py DELETED
@@ -1,425 +0,0 @@
-import os
-from typing import Any, Optional, Tuple, Union
-
-import torch
-import transformers
-from torch.nn import CrossEntropyLoss
-from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
-from transformers.modeling_utils import PreTrainedModel
-from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import \
-    VisionEncoderDecoderConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class CvtWithProjectionHeadConfig(transformers.CvtConfig):
-    def __init__(self, projection_size: Optional[int] = None, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self.projection_size = projection_size
-
-
-class ModelOutputWithProjectionEmbedding(transformers.modeling_outputs.ModelOutput):
-    last_hidden_state: torch.FloatTensor
-    attention_mask: torch.FloatTensor
-
-
-class CvtProjectionHead(torch.nn.Module):
-
-    def __init__(self, config) -> None:
-        super().__init__()
-
-        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/models/cvt/modeling_cvt.py#L657
-        self.layer_norm = torch.nn.LayerNorm(config.embed_dim[-1], eps=config.layer_norm_eps)
-
-        # No bias, as this follows layer normalisation with bias:
-        self.projection = torch.nn.Linear(config.embed_dim[-1], config.projection_size, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.layer_norm(x)
-        x = self.projection(x)
-        return x
-
-
-class VariableCvtWithProjectionHead(transformers.CvtPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-
-        self.cvt = transformers.CvtModel(config, add_pooling_layer=False)
-        self.projection_head = CvtProjectionHead(config)
-
-        # Initialize weights and apply final processing:
-        self.post_init()
-
-    def forward(
-        self,
-        pixel_values: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, ModelOutputWithProjectionEmbedding]:
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # Flatten the batch and study_id dimensions:
-        outputs = self.cvt(
-            pixel_values.view(-1, *pixel_values.shape[2:]),
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        # Flatten h x w (outputs[0] is last_hidden_state for both tuple and dict returns):
-        last_hidden_state = torch.flatten(outputs[0], 2)
-
-        # Project the features for each spatial position to the decoder's hidden size:
-        projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
-
-        # Concatenate the features for each chest X-ray:
-        projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])
-
-        # Derive the attention mask from the pixel values (an image whose first value is zero
-        # is treated as padding):
-        attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)
-
-        if not return_dict:
-            return projection, attention_mask
-
-        return ModelOutputWithProjectionEmbedding(
-            last_hidden_state=projection, attention_mask=attention_mask,
-        )
-
-
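For orientation, a minimal shape sketch of the encoder above. The projection size, image size, and image count are illustrative assumptions, not values taken from this repository:

config = CvtWithProjectionHeadConfig(projection_size=768)  # Default CvT hyperparameters otherwise.
encoder = VariableCvtWithProjectionHead(config)

# Two studies, up to three 224 x 224 chest X-rays each; the third image of study 1 is zero padding:
pixel_values = torch.randn(2, 3, 3, 224, 224)
pixel_values[1, 2] = 0.0

out = encoder(pixel_values)
print(out.last_hidden_state.shape)  # torch.Size([2, 588, 768]): 3 images x 14 x 14 spatial positions.
print(out.attention_mask.shape)     # torch.Size([2, 588]): positions of the padded image are masked out.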
-class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
-
-    config_class = VisionEncoderDecoderConfig
-    base_model_prefix = "vision_encoder_decoder"
-    main_input_name = "pixel_values"
-    supports_gradient_checkpointing = True
-
-    def __init__(
-        self,
-        config: Optional[PretrainedConfig] = None,
-        encoder: Optional[PreTrainedModel] = None,
-        decoder: Optional[PreTrainedModel] = None,
-    ):
-
-        if decoder:
-            assert decoder.config.add_cross_attention, '"add_cross_attention" must be True for the given decoder'
-            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
-
-        if config is None and (encoder is None or decoder is None):
-            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
-        if config is None:
-            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
-        else:
-            if not isinstance(config, self.config_class):
-                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
-
-        config.tie_word_embeddings = False
-
-        # Initialize with config:
-        PreTrainedModel.__init__(self, config)
-
-        # Encoder:
-        if encoder is None:
-            encoder = VariableCvtWithProjectionHead(config=config.encoder)
-
-        # Decoder:
-        if decoder is None:
-            decoder = transformers.BertLMHeadModel(config=config.decoder)
-
-        self.encoder = encoder
-        self.decoder = decoder
-
-        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
-            logger.warning(
-                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
-                f" {self.config.encoder}"
-            )
-        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
-            logger.warning(
-                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
-                f" {self.config.decoder}"
-            )
-
-        self.encoder.config = self.config.encoder
-        self.decoder.config = self.config.decoder
-
-        # config.add_cross_attention = True
-        # config.is_decoder = True
-
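A hedged construction sketch for this class; the decoder hyperparameters are illustrative assumptions:

encoder_config = CvtWithProjectionHeadConfig(projection_size=768)
decoder_config = transformers.BertConfig(is_decoder=True, add_cross_attention=True)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = VariableCXREncoderDecoderModel(config=config)

Here projection_size matches BertConfig's default hidden_size (768), so the projected encoder features feed the decoder's cross-attention without a further resize.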
-    def forward(
-        self,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.BoolTensor] = None,
-        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
-
-        kwargs_decoder = {
-            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
-        }
-
-        if encoder_outputs is None:
-            if pixel_values is None:
-                raise ValueError("You have to specify pixel_values")
-
-            encoder_outputs = self.encoder(
-                pixel_values,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs_encoder,
-            )  # CvT does not support output_attentions.
-
-        elif isinstance(encoder_outputs, tuple):
-            encoder_outputs = BaseModelOutput(*encoder_outputs)
-
-        encoder_hidden_states = encoder_outputs[0]
-
-        # Note: the decoder call assumes that encoder_outputs has an attention_mask attribute,
-        # i.e., that the encoder was run with return_dict=True:
-        decoder_outputs = self.decoder(
-            input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_outputs.attention_mask,
-            inputs_embeds=decoder_inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            use_cache=use_cache,
-            past_key_values=past_key_values,
-            return_dict=return_dict,
-            **kwargs_decoder,
-        )
-
-        # Loss:
-        loss = None
-        if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
-
-        if not return_dict:
-            if loss is not None:
-                return (loss,) + decoder_outputs + encoder_outputs
-            else:
-                return decoder_outputs + encoder_outputs
-
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=decoder_outputs.logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            # encoder_hidden_states=encoder_outputs.hidden_states,
-            # encoder_attentions=encoder_outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        special_token_ids,
-        past_key_values=None,
-        attention_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs,
-    ):
-        """
-        Modification of:
-            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
-        """
-
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
-        decoder_attention_mask = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
-
-        if not past_key_values:
-            token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids)
-        else:
-            token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids)
-
-        input_dict = {
-            'attention_mask': attention_mask,
-            'decoder_attention_mask': decoder_attention_mask,
-            'decoder_input_ids': decoder_inputs['input_ids'],
-            'decoder_token_type_ids': token_type_ids,
-            'encoder_outputs': encoder_outputs,
-            'past_key_values': decoder_inputs['past_key_values'],
-            'use_cache': use_cache,
-        }
-        return input_dict
-
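Because special_token_ids is a non-standard argument in the signature above, it has to be threaded through generate() as a model kwarg. A hedged sketch, assuming a tokenizer and the model/pixel_values from the earlier sketches:

output_ids = model.generate(
    pixel_values=pixel_values,
    special_token_ids=[tokenizer.sep_token_id],  # Boundary between findings and impression.
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    max_length=256,
)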
-    def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
-        """
-        Extract token type identifiers from the token identifiers.
-
-        Argument/s:
-            token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the separation between sections.
-            token_type_id_sections - token type identifier for each section.
-
-        Returns:
-            token_type_ids - token type identifiers.
-        """
-
-        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
-
-        mbatch_size, seq_len = token_ids.shape
-        token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
-
-        for i, j in enumerate(special_token_ids):
-            # Find the first occurrence of the special token that indicates the boundary between sections:
-            cols = (token_ids == j).int().argmax(dim=1)
-            rows = torch.arange(mbatch_size, device=token_ids.device)
-
-            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
-            cols += 1
-
-            # Ensure that the column index is not out of bounds. If 0, the token_id is not present.
-            # This is safe, as index 0 is always a special token (now equal to 1 due to the +1):
-            rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
-            cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
-
-            # Indices that correspond to the next section:
-            if rows.nelement() != 0:
-                ids = torch.stack([
-                    torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
-                        y, seq_len, device=token_ids.device,
-                    )
-                ])
-
-                token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
-
-        return token_type_ids
-
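A worked example of the section boundary logic, with illustrative identifiers (5-8 are ordinary tokens, 3 is the separator, 2 is eos):

token_ids = torch.tensor([[1, 5, 6, 3, 7, 8, 2]])
model.token_ids_to_token_type_ids(token_ids, special_token_ids=[3])
# tensor([[0, 0, 0, 0, 1, 1, 1]]): the separator itself stays in section 0,
# matching the BERT token-type convention linked above.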
-    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
-        """
-        Extract token type identifiers from the token identifiers if past != None.
-
-        Argument/s:
-            token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the separation between sections.
-            token_type_id_sections - token type identifier for each section.
-
-        Returns:
-            token_type_ids - token type identifiers.
-        """
-
-        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
-        token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
-
-        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
-        token_ids = token_ids[:, :-1]
-
-        for i, j in enumerate(special_token_ids):
-
-            # Find the first occurrence of the special token, which indicates the boundary between sections:
-            exists = torch.any(token_ids == j, dim=1, keepdim=True)
-            token_type_ids[exists] = token_type_id_sections[i + 1]
-
-        return token_type_ids
-
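During cached generation, only the type of the newest token is needed. Continuing the illustrative identifiers from the previous example:

token_ids = torch.tensor([[1, 5, 6, 3, 7]])  # All tokens generated so far.
model.token_ids_to_token_type_ids_past(token_ids, special_token_ids=[3])
# tensor([[1]]): the separator appears before the current token, so the
# current token belongs to section 1.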
-    def tokenize_report_teacher_forcing(self, findings, impression, tokenizer: PreTrainedTokenizerFast, max_len: int):
-        """
-        Tokenize the reports and create the inputs and targets for teacher forcing.
-
-        Argument/s:
-            findings - batch of findings sections (list of strings).
-            impression - batch of impression sections (list of strings).
-            tokenizer - Hugging Face tokenizer.
-            max_len - maximum number of tokens.
-
-        Returns:
-            decoder_input_ids - the token identifiers for the input of the decoder.
-            decoder_attention_mask - the attention mask for the decoder_input_ids.
-            label_ids - the label token identifiers for the decoder.
-        """
-
-        # Prepare the sections for the tokenizer by placing special tokens between each section:
-        report = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
-                  zip(findings, impression)]
-
-        # Tokenize the report:
-        tokenized = tokenizer(
-            report,
-            padding='longest',
-            truncation=True,
-            max_length=max_len + 1,  # +1 to account for the shift between input and target.
-            return_tensors='pt',
-            return_token_type_ids=False,
-            add_special_tokens=False,
-        ).to(self.device)
-
-        # Modify for language modelling:
-        batch_dict = {
-
-            # Labels for the decoder (shifted right by one for autoregression):
-            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
-
-            # Remove the last token identifier to match the sequence length of the labels:
-            'decoder_input_ids': tokenized['input_ids'][:, :-1],
-
-            # Attention mask for the decoder_input_ids (the first position is dropped so that the
-            # mask length matches decoder_input_ids and the final eos_token_id is not attended to):
-            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
-        }
-
-        return batch_dict
-
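A hedged training-step sketch tying the method above to forward(); the report strings and tokenizer are assumptions:

findings = ['Heart size is normal.', 'No acute osseous abnormality.']
impression = ['No acute cardiopulmonary process.', 'Normal study.']

batch = model.tokenize_report_teacher_forcing(findings, impression, tokenizer, max_len=256)
outputs = model(
    pixel_values=pixel_values,
    decoder_input_ids=batch['decoder_input_ids'],
    decoder_attention_mask=batch['decoder_attention_mask'],
    labels=batch['label_ids'],
)
outputs.loss.backward()

Note that the loss in forward() uses CrossEntropyLoss with its default ignore_index of -100, while label_ids keep the pad token identifier, so padded positions contribute to the loss unless they are masked upstream.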
-    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
-        """
-        Split the token identifiers into sections, then convert the token identifiers into strings.
-
-        Argument/s:
-            token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the end of each section.
-            tokenizer - Hugging Face tokenizer.
-
-        Returns:
-            sections - a tuple with one list of decoded strings per section.
-        """
-
-        _, seq_len = token_ids.shape
-
-        # The number of sections is the same as the number of special_token_ids:
-        num_sections = len(special_token_ids)
-
-        sections = {k: [] for k in range(num_sections)}
-
-        for i in token_ids:
-            prev_col = 0
-            for j, k in enumerate(special_token_ids):
-
-                # The maximum sequence length was exceeded, thus no more tokens:
-                if prev_col >= seq_len:
-                    sections[j].append('')
-                    continue
-
-                # Find the first occurrence of the special token that indicates the boundary between sections:
-                col = (i == k).int().argmax().item()
-
-                # If equal to 0, the token was not found; set the column to the sequence length (as the
-                # decoder exceeded the maximum sequence length):
-                if col == 0:
-                    col = seq_len
-
-                # Extract the section token identifiers:
-                section_token_ids = i[prev_col:col]
-                prev_col = col
-                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
-
-                sections[j].append(section_string)
-
-        return tuple(sections.values())
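Finally, a usage sketch for recovering the two report sections from generated identifiers; the tokenizer attributes and output_ids follow from the earlier generation sketch:

findings, impression = model.split_and_decode_sections(
    output_ids,
    special_token_ids=[tokenizer.sep_token_id, tokenizer.eos_token_id],
    tokenizer=tokenizer,
)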