HachiML commited on
Commit
2b5af3d
·
verified ·
1 Parent(s): 253f120

Upload modeling_mists.py

Browse files
Files changed (1) hide show
  1. modeling_mists.py +405 -0
modeling_mists.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.activations import ACT2FN
10
+ from transformers import Cache
11
+ from transformers.modeling_outputs import ModelOutput
12
+ from transformers.utils import (
13
+ add_start_docstrings,
14
+ add_start_docstrings_to_model_forward,
15
+ logging,
16
+ replace_return_docstrings,
17
+ )
18
+ from transformers import AutoModel, AutoModelForCausalLM
19
+
20
+ from .modeling_moment import MomentEmbeddingModel
21
+ from .configuration_mists import MistsConfig
22
+
23
+
24
+ @dataclass
25
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Mists
26
+ class MistsCausalLMOutputWithPast(ModelOutput):
27
+ loss: Optional[torch.FloatTensor] = None
28
+ logits: torch.FloatTensor = None
29
+ past_key_values: Optional[List[torch.FloatTensor]] = None
30
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
31
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
32
+ time_series_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
33
+
34
+
35
+ class MistsMultiModalProjector(nn.Module):
36
+ def __init__(self, config: MistsConfig):
37
+ super().__init__()
38
+
39
+ # time series towerからのoutputは定型でない。input_maskに合わせてpadding用の学習可能なベクトルを使用し、time series towerからの入力を定型にする。
40
+ self.mask_embedding = nn.Parameter(torch.randn(1, 1, config.time_series_hidden_size))
41
+
42
+ # mlp
43
+ self.linear_1 = nn.Linear(config.time_series_hidden_size, config.text_config.hidden_size, bias=True)
44
+ self.act = ACT2FN[config.projector_hidden_act]
45
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
46
+
47
+ def forward(self, time_series_features, input_mask):
48
+ masked_features = time_series_features * input_mask.unsqueeze(-1) + self.mask_embedding * (1 - input_mask.unsqueeze(-1))
49
+ hidden_states = self.linear_1(masked_features)
50
+ hidden_states = self.act(hidden_states)
51
+ hidden_states = self.linear_2(hidden_states)
52
+ return hidden_states
53
+
54
+
55
+ class MistsPreTrainedModel(PreTrainedModel):
56
+ config_class = MistsConfig
57
+ base_model_prefix = "model"
58
+ supports_gradient_checkpointing = True
59
+ _no_split_modules = ["T5Block"]
60
+ _skip_keys_device_placement = "past_key_values"
61
+ _supports_flash_attn_2 = True
62
+ _supports_sdpa = True
63
+ _supports_cache_class = True
64
+ _supports_static_cache = True
65
+
66
+ def _init_weights(self, module):
67
+ # important: 現状Mistralの初期化コードをそのまま移植している。
68
+ # refers: https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/mistral/modeling_mistral.py#L762
69
+ # 現状のまま事前学習を行うのは望ましくなく、FineTuningと推論のみが可能。
70
+ std = self.config.text_config.initializer_range
71
+ if isinstance(module, nn.Linear):
72
+ module.weight.data.normal_(mean=0.0, std=std)
73
+ if module.bias is not None:
74
+ module.bias.data.zero_()
75
+ elif isinstance(module, nn.Embedding):
76
+ module.weight.data.normal_(mean=0.0, std=std)
77
+ if module.padding_idx is not None:
78
+ module.weight.data[module.padding_idx].zero_()
79
+
80
+
81
+ class MistsForConditionalGeneration(MistsPreTrainedModel):
82
+ def __init__(self, config: MistsConfig):
83
+ super().__init__(config)
84
+
85
+ self.time_series_tower = MomentEmbeddingModel(config.time_series_config)
86
+ self.multi_modal_projector = MistsMultiModalProjector(config)
87
+ self.vocab_size = config.text_config.vocab_size
88
+ self.language_model = AutoModelForCausalLM.from_config(
89
+ config.text_config, attn_implementation=config._attn_implementation
90
+ )
91
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
92
+ self.post_init()
93
+
94
+ def get_time_series_tower(self):
95
+ time_series_tower = getattr(self, 'time_series_tower', None)
96
+ if type(time_series_tower) is list:
97
+ time_series_tower = time_series_tower[0]
98
+ return time_series_tower
99
+
100
+ def get_input_embeddings(self):
101
+ return self.language_model.get_input_embeddings()
102
+
103
+ def set_input_embeddings(self, value):
104
+ self.language_model.set_input_embeddings(value)
105
+
106
+ def get_output_embeddings(self):
107
+ return self.language_model.get_output_embeddings()
108
+
109
+ def set_output_embeddings(self, new_embeddings):
110
+ self.language_model.set_output_embeddings(new_embeddings)
111
+
112
+ def set_decoder(self, decoder):
113
+ self.language_model.set_decoder(decoder)
114
+
115
+ def get_decoder(self):
116
+ return self.language_model.get_decoder()
117
+
118
+ def tie_weights(self):
119
+ return self.language_model.tie_weights()
120
+
121
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
122
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
123
+ # update vocab size
124
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
125
+ self.vocab_size = model_embeds.num_embeddings
126
+ return model_embeds
127
+
128
+ # copy _merge_input_ids_with_image_features from LlabaForConditionalGeneration
129
+ # refers: https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/llava/modeling_llava.py#L277C9-L277C45
130
+ def _merge_input_ids_with_time_series_features(self, time_series_features, inputs_embeds, input_ids, attention_mask, labels):
131
+ num_time_series, num_time_series_patches, embed_dim = time_series_features.shape # num_time_series_patches = n_channels x n_patches
132
+ batch_size, sequence_length = input_ids.shape
133
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
134
+ # 1. Create a mask to know where special time_series tokens are
135
+ special_time_series_token_mask = input_ids == self.config.time_series_token_index
136
+ num_special_time_series_tokens = torch.sum(special_time_series_token_mask, dim=-1)
137
+ # Compute the maximum embed dimension
138
+ max_embed_dim = (num_special_time_series_tokens.max() * (num_time_series_patches - 1)) + sequence_length
139
+ max_embed_dim = int(max_embed_dim.item()) # テンソルから整数値を取得
140
+ if max_embed_dim is None:
141
+ print(f"num_special_time_series_tokens.max(): {num_special_time_series_tokens.max()}")
142
+ print(f"num_time_series_patches: {num_time_series_patches}")
143
+ print(f"sequence_length: {sequence_length}")
144
+ else:
145
+ print(f"max_embed_dim 0: {max_embed_dim}")
146
+ batch_indices, non_time_series_indices = torch.where(input_ids != self.config.time_series_token_index)
147
+
148
+ # 2. Compute the positions where text should be written
149
+ # Calculate new positions for text tokens in merged time_series-text sequence.
150
+ # `special_time_series_token_mask` identifies time_series tokens. Each time_series token will be replaced by `nb_text_tokens_per_time_series - 1` text tokens.
151
+ # `torch.cumsum` computes how each time_series token shifts subsequent text token positions.
152
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
153
+ new_token_positions = torch.cumsum((special_time_series_token_mask * (num_time_series_patches - 1) + 1), -1) - 1
154
+ nb_time_series_pad = max_embed_dim - 1 - new_token_positions[:, -1]
155
+ if left_padding:
156
+ new_token_positions += nb_time_series_pad[:, None] # offset for left padding
157
+ text_to_overwrite = new_token_positions[batch_indices, non_time_series_indices]
158
+
159
+ # 3. Create the full embedding, already padded to the maximum position
160
+ final_embedding = torch.zeros(
161
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
162
+ )
163
+ final_attention_mask = torch.zeros(
164
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
165
+ )
166
+ if labels is not None:
167
+ final_labels = torch.full(
168
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
169
+ )
170
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
171
+ # set the corresponding tensors into their correct target device.
172
+ target_device = inputs_embeds.device
173
+ batch_indices, non_time_series_indices, text_to_overwrite = (
174
+ batch_indices.to(target_device),
175
+ non_time_series_indices.to(target_device),
176
+ text_to_overwrite.to(target_device),
177
+ )
178
+ attention_mask = attention_mask.to(target_device)
179
+
180
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<time_series>", "how", "are"]
181
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the time_series features
182
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_time_series_indices]
183
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_time_series_indices]
184
+ print("max_embed_dim is None: ", (max_embed_dim is None))
185
+ print("max_embed_dim: ", max_embed_dim)
186
+ if labels is not None:
187
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_time_series_indices]
188
+ print("max_embed_dim is None: ", (max_embed_dim is None))
189
+ print("max_embed_dim: ", max_embed_dim)
190
+
191
+ # 5. Fill the embeddings corresponding to the time_series. Anything that is not `text_positions` needs filling (#29835)
192
+ print("inputs_embeds.device: ", inputs_embeds.device)
193
+ print("max_embed_dim: ", max_embed_dim, " is None: ", (max_embed_dim is None))
194
+ time_series_to_overwrite = torch.full(
195
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
196
+ )
197
+ time_series_to_overwrite[batch_indices, text_to_overwrite] = False
198
+ time_series_to_overwrite &= time_series_to_overwrite.cumsum(-1) - 1 >= nb_time_series_pad[:, None].to(target_device)
199
+
200
+ if time_series_to_overwrite.sum() != time_series_features.shape[:-1].numel():
201
+ raise ValueError(
202
+ f"The input provided to the model are wrong. The number of time series tokens is {torch.sum(special_time_series_token_mask)} while"
203
+ f" the number of time series given to the model is {num_time_series}. This prevents correct indexing and breaks batch generation."
204
+ )
205
+
206
+ final_embedding[time_series_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
207
+ final_attention_mask |= time_series_to_overwrite
208
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
209
+
210
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
211
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
212
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
213
+
214
+ final_embedding[batch_indices, indices_to_mask] = 0
215
+
216
+ if labels is None:
217
+ final_labels = None
218
+
219
+ return final_embedding, final_attention_mask, final_labels, position_ids
220
+
221
+ def forward(
222
+ self,
223
+ input_ids: torch.LongTensor = None,
224
+ time_series_values: torch.FloatTensor = None,
225
+ time_series_input_mask: torch.FloatTensor = None,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ position_ids: Optional[torch.LongTensor] = None,
228
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
229
+ inputs_embeds: Optional[torch.FloatTensor] = None,
230
+ # time_series_feature_layer: Optional[int] = None,
231
+ # time_series_feature_select_strategy: Optional[str] = None,
232
+ labels: Optional[torch.LongTensor] = None,
233
+ use_cache: Optional[bool] = None,
234
+ output_attentions: Optional[bool] = None,
235
+ output_hidden_states: Optional[bool] = None,
236
+ return_dict: Optional[bool] = None,
237
+ ) -> Union[Tuple, MistsCausalLMOutputWithPast]:
238
+
239
+ # language_modelの引数で変わる
240
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
241
+ # output_hidden_states = (
242
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
243
+ # )
244
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
245
+ # vision_feature_layer = (
246
+ # vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
247
+ # )
248
+ # vision_feature_select_strategy = (
249
+ # vision_feature_select_strategy
250
+ # if vision_feature_select_strategy is not None
251
+ # else self.config.vision_feature_select_strategy
252
+ # )
253
+
254
+ if inputs_embeds is None:
255
+ # 1. Extra the input embeddings
256
+ inputs_embeds = self.get_input_embeddings()(input_ids)
257
+
258
+ # 2. Merge text and time_series
259
+ if time_series_values is not None and input_ids.shape[1] != 1:
260
+ time_series_outputs = self.time_series_tower(time_series_values, time_series_input_mask)
261
+ time_series_features = self.multi_modal_projector(
262
+ time_series_features=time_series_outputs.hidden_states, # [batch_size, n_patches, d_model]
263
+ input_mask=time_series_outputs.input_mask_patch_view, # [batch_size, n_paches]
264
+ )
265
+
266
+ inputs_embeds = inputs_embeds.to(time_series_features.dtype)
267
+ inputs_embeds, attention_mask, labels, position_ids =self._merge_input_ids_with_time_series_features(
268
+ time_series_features, inputs_embeds, input_ids, attention_mask, labels
269
+ )
270
+
271
+ # In case input_ids.shape[1] == 1 & time_series_values==None & past_key_values != None, we are in the case of
272
+ # generation with cache
273
+ elif past_key_values is not None and time_series_values is not None and input_ids.shape[1] == 1:
274
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
275
+ # that are set to 0
276
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
277
+
278
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
279
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
280
+
281
+ # Get the target length
282
+ target_length = input_ids.shape[1]
283
+ past_length = first_layer_past_key_value.shape[-1]
284
+
285
+ extended_attention_mask = torch.ones(
286
+ (attention_mask.shape[0], past_length),
287
+ dtype=attention_mask.dtype,
288
+ device=attention_mask.device,
289
+ )
290
+
291
+ # Filter out only the tokens that can be un-attended, this can happen
292
+ # if one uses Llava + Fused modules where the cache on the
293
+ # first iteration is already big enough, or if one passes custom cache
294
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
295
+ new_batch_index = batch_index[valid_indices]
296
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
297
+
298
+ # Zero-out the places where we don't need to attend
299
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
300
+
301
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
302
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
303
+
304
+ print("inputs_embeds: ", inputs_embeds.shape)
305
+
306
+ outputs = self.language_model(
307
+ attention_mask=attention_mask,
308
+ position_ids=position_ids,
309
+ past_key_values=past_key_values,
310
+ inputs_embeds=inputs_embeds.to(self.language_model.dtype),
311
+ use_cache=use_cache,
312
+ output_attentions=output_attentions,
313
+ output_hidden_states=output_hidden_states,
314
+ return_dict=return_dict,
315
+ )
316
+
317
+ logits = outputs[0]
318
+
319
+ loss = None
320
+ if labels is not None:
321
+ # Shift so that tokens < n predict n
322
+ if attention_mask is not None:
323
+ shift_attention_mask = attention_mask[..., 1:]
324
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
325
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
326
+ else:
327
+ shift_logits = logits[..., :-1, :].contiguous()
328
+ shift_labels = labels[..., 1:].contiguous()
329
+ # Flatten the tokens
330
+ loss_fct = nn.CrossEntropyLoss()
331
+ loss = loss_fct(
332
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
333
+ )
334
+
335
+ if not return_dict:
336
+ output = (logits,) + outputs[1:]
337
+ return (loss,) + output if loss is not None else output
338
+
339
+ return MistsCausalLMOutputWithPast(
340
+ loss=loss,
341
+ logits=logits,
342
+ past_key_values=outputs.past_key_values,
343
+ hidden_states=outputs.hidden_states,
344
+ attentions=outputs.attentions,
345
+ )
346
+
347
+ def prepare_inputs_for_generation(
348
+ self, input_ids, past_key_values=None, inputs_embeds=None, time_series_values=None, attention_mask=None, **kwargs
349
+ ):
350
+ if past_key_values is not None:
351
+ if isinstance(past_key_values, Cache):
352
+ cache_length = past_key_values.get_seq_length()
353
+ past_length = past_key_values.seen_tokens
354
+ else:
355
+ cache_length = past_length = past_key_values[0][0].shape[2]
356
+
357
+ # Keep only the unprocessed tokens:
358
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
359
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
360
+ # input)
361
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
362
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
363
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
364
+ # input_ids based on the past_length.
365
+ elif past_length < input_ids.shape[1]:
366
+ input_ids = input_ids[:, past_length:]
367
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
368
+ elif self.config.time_series_token_index in input_ids:
369
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
370
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
371
+ # older attention values, as their corresponding values are not part of the input.
372
+ if cache_length < past_length and attention_mask is not None:
373
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
374
+
375
+ position_ids = kwargs.get("position_ids", None)
376
+ if attention_mask is not None and position_ids is None:
377
+ # create position_ids on the fly for batch generation
378
+ position_ids = attention_mask.long().cumsum(-1) - 1
379
+ position_ids.masked_fill_(attention_mask == 0, 1)
380
+ if past_key_values:
381
+ position_ids = position_ids[:, -input_ids.shape[1] :]
382
+
383
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
384
+ if inputs_embeds is not None and past_key_values is None:
385
+ model_inputs = {"inputs_embeds": inputs_embeds}
386
+ else:
387
+ model_inputs = {"input_ids": input_ids}
388
+
389
+ model_inputs.update(
390
+ {
391
+ "position_ids": position_ids,
392
+ "past_key_values": past_key_values,
393
+ "use_cache": kwargs.get("use_cache"),
394
+ "attention_mask": attention_mask,
395
+ "time_series_values": time_series_values,
396
+ }
397
+ )
398
+ return model_inputs
399
+
400
+ def _reorder_cache(self, *args, **kwargs):
401
+ return self.language_model._reorder_cache(*args, **kwargs)
402
+
403
+
404
+
405
+