tinyllava committed on
Commit
0d36298
·
verified ·
1 Parent(s): 58845c3

upload python file

Browse files
Files changed (2) hide show
  1. generate_model.py +4 -3
  2. modeling_tinyllava_phi.py +419 -0
generate_model.py CHANGED
@@ -649,7 +649,8 @@ def generate(
649
  outputs = outputs.strip()
650
 
651
  return outputs, generation_time
652
- def tinyllava_elm_generate_parser():
 
653
  """Argument Parser"""
654
 
655
  class KwargsParser(argparse.Action):
@@ -674,7 +675,7 @@ def tinyllava_elm_generate_parser():
674
  converted_v = kwarg_v
675
  getattr(namespace, self.dest)[kwarg_k] = converted_v
676
 
677
- parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
678
  parser.add_argument(
679
  '--model',
680
  dest='model',
@@ -704,7 +705,7 @@ def tinyllava_elm_generate_parser():
704
 
705
 
706
  if __name__ == '__main__':
707
- args = tinyllava_elm_generate_parser()
708
 
709
  output_text, genertaion_time = generate(
710
  prompt=args.prompt,
 
649
  outputs = outputs.strip()
650
 
651
  return outputs, generation_time
652
+
653
+ def tinyllava_phi_generate_parser():
654
  """Argument Parser"""
655
 
656
  class KwargsParser(argparse.Action):
 
675
  converted_v = kwarg_v
676
  getattr(namespace, self.dest)[kwarg_k] = converted_v
677
 
678
+ parser = argparse.ArgumentParser('TinyLLaVA-Phi Generate Module')
679
  parser.add_argument(
680
  '--model',
681
  dest='model',
 
705
 
706
 
707
  if __name__ == '__main__':
708
+ args = tinyllava_phi_generate_parser()
709
 
710
  output_text, genertaion_time = generate(
711
  prompt=args.prompt,
modeling_tinyllava_phi.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+ import ast
4
+ import re
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn, Tensor
9
+ from torch.nn import functional as F
10
+
11
+ from transformers import PreTrainedModel
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast
13
+ from transformers.generation.utils import GenerateOutput
14
+ from transformers import CLIPVisionModel, CLIPImageProcessor, SiglipVisionModel, SiglipImageProcessor
15
+
16
+ from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
17
+
18
+ from transformers import AutoConfig, AutoModelForCausalLM, PhiForCausalLM
19
+
20
+ # from tinyllava.utils.data_utils import get_value_from_kwargs
21
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
22
+ WORKER_HEART_BEAT_INTERVAL = 15
23
+
24
+ LOGDIR = "."
25
+ #
26
+ # For licensing see accompanying LICENSE file.
27
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
28
+ #
29
+ from transformers.utils import logging
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ # this import has to be relative, otherwise, when setting trust_remote_code=True
34
+ # huggingface transformers won't be able to load the module correctly
35
+ from numbers import Number
36
+ from typing import List, Optional, Union
37
+
38
+
39
+
40
+
41
ACT_TYPE = {
    'relu': nn.ReLU,
    'gelu': nn.GELU
}


class Connector(nn.Module):
    """MLP that maps vision features to the language model's hidden size.

    The layout is derived from ``config.connector_type``, which must look like
    ``mlp<depth>x_<act>`` (for example ``mlp2x_gelu``), where ``<act>`` is a key
    of ``ACT_TYPE``.  The network is one
    ``Linear(vision_hidden_size, hidden_size)`` followed by ``depth - 1``
    repetitions of ``activation -> Linear(hidden_size, hidden_size)``.

    Args:
        config: Object exposing ``connector_type``, ``vision_hidden_size`` and
            ``hidden_size``.

    Raises:
        ValueError: If ``connector_type`` is missing, does not follow the
            ``mlp<depth>x_<act>`` pattern, or names an unknown activation.
    """

    def __init__(self, config=None):
        super().__init__()
        connector_type = getattr(config, 'connector_type', None)
        # Parse depth and activation from the same match so the two can never
        # disagree (the activation used to be re-derived with a separate split).
        match = re.match(r'^mlp(\d+)x_(\w+)$', connector_type or '')
        if match is None or match.group(2) not in ACT_TYPE:
            raise ValueError(
                f"Unsupported connector_type {connector_type!r}; expected "
                f"'mlp<depth>x_<act>' with <act> in {sorted(ACT_TYPE)}"
            )
        mlp_depth = int(match.group(1))
        activation = ACT_TYPE[match.group(2)]

        modules = [nn.Linear(config.vision_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(activation())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))

        self._connector = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the projection MLP to ``x`` and return the result."""
        return self._connector(x)
61
+
62
class VisionTower(nn.Module):
    """Wrapper around a CLIP or SigLIP vision encoder and its image processor.

    Args:
        cfg: Vision config used to build the encoder.  Its
            ``model_name_or_path`` attribute is also used to load the matching
            image processor.
        model_name_or_path: Selects the encoder family: a value containing
            ``'clip'`` builds a ``CLIPVisionModel``, anything else builds a
            ``SiglipVisionModel``.
    """

    def __init__(self, cfg, model_name_or_path = 'clip'):
        super().__init__()
        if 'clip' in model_name_or_path:
            self._vision_tower = CLIPVisionModel(cfg)
            self._image_processor = CLIPImageProcessor.from_pretrained(cfg.model_name_or_path)
        else:
            self._vision_tower = SiglipVisionModel(cfg)
            self._image_processor = SiglipImageProcessor.from_pretrained(cfg.model_name_or_path)

        self.config = cfg

    def forward(self, x, **kwargs):
        """Encode ``x`` and return the selected hidden-state features.

        Keyword Args:
            vision_feature_layer: Index into the encoder's ``hidden_states``
                (default ``-2``).
            vision_feature_select_strategy: ``'patch'`` (default) drops index 0
                along dim 1; ``'cls_patch'`` keeps every position.

        Raises:
            ValueError: If the select strategy is neither ``'patch'`` nor
                ``'cls_patch'``.
        """
        feature_layer = kwargs.get('vision_feature_layer', -2)
        select_strategy = kwargs.get('vision_feature_select_strategy', 'patch')
        # Reject a bad strategy before paying for a full encoder forward pass.
        if select_strategy not in ('patch', 'cls_patch'):
            raise ValueError(f"Unexpected select feature: {select_strategy}")

        outputs = self._vision_tower(x, output_hidden_states=True)
        image_features = outputs.hidden_states[feature_layer]

        if select_strategy == 'patch':
            # NOTE(review): index 0 is presumably the CLS token for CLIP;
            # confirm this slice is also intended for SigLIP encoders.
            image_features = image_features[:, 1:]

        return image_features

    @property
    def vision_tower(self):
        """The wrapped vision encoder module."""
        return self._vision_tower

    @vision_tower.setter
    def vision_tower(self, vision_tower):
        self._vision_tower = vision_tower
94
+
95
def get_value_from_kwargs(kwargs, name):
    """Remove ``name`` from ``kwargs`` and return its value.

    Args:
        kwargs: Mutable mapping to pop from (modified in place).
        name: Key to remove.

    Returns:
        The value stored under ``name``, or ``None`` when the key is absent.
    """
    return kwargs.pop(name, None)
100
+
101
+
102
class TinyLlavaPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for TinyLLaVA: class flags and weight init."""

    config_class = TinyLlavaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialise ``module`` from a normal distribution.

        The standard deviation is ``config.initializer_range`` when the config
        has that attribute, otherwise ``config.text_config.initializer_range``.
        """
        if hasattr(self.config, "initializer_range"):
            init_std = self.config.initializer_range
        else:
            init_std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=init_std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_idx = module.padding_idx
            if pad_idx is not None:
                # Keep the padding row at zero.
                module.weight.data[pad_idx].zero_()

    @property
    def _supports_sdpa(self):
        """Mirror the wrapped language model's ``_supports_sdpa`` flag."""
        language_model = self.language_model
        return language_model._supports_sdpa
132
+
133
+
134
class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
    """TinyLLaVA: a Phi causal LM fed with connector-projected vision features.

    Image placeholder tokens (``IMAGE_TOKEN_INDEX``) in the prompt are replaced
    by the projected image features before the sequence is handed to the
    language model as ``inputs_embeds``.
    """

    def __init__(self, config: TinyLlavaConfig):
        super().__init__(config)

        self.language_model = PhiForCausalLM(config.text_config)
        self.vision_tower = VisionTower(config.vision_config, config.vision_model_name_or_path)
        self.connector = Connector(config)
        self.post_init()

    # ------------------------------------------------------------------
    # Thin delegations to the wrapped language model.
    # ------------------------------------------------------------------
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        """Resize the language model's token embeddings and sync vocab sizes.

        Returns:
            The resized embedding module returned by the language model.
        """
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model, merging image features in first if needed.

        When ``inputs_embeds`` is not supplied, the inputs go through
        ``prepare_inputs_labels_for_multimodal``, which may replace
        ``input_ids`` with embeddings that include the image features.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )
        # Call the module (not ``.forward``) so registered hooks are honoured.
        return self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from prompt token ids, optionally conditioned on images.

        Args:
            inputs: Prompt token ids (may also be passed as ``input_ids=``).
            images: Optional image tensor forwarded to ``encode_images``.
            image_sizes: Optional, passed through to the multimodal preparation.
            **kwargs: Forwarded to the language model's ``generate``.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is passed by the caller.
            ValueError: If no prompt token ids are provided.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
        if inputs is None:
            # Accept the conventional keyword as an alias for ``inputs``.
            inputs = kwargs.pop("input_ids", None)
        if inputs is None:
            raise ValueError("`inputs` (prompt token ids) must be provided")

        inputs_embeds = None
        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        if inputs_embeds is None:
            # Either no images were given, or the multimodal preparation
            # short-circuited and handed the token ids back untouched. Embed
            # them here so the prompt is not silently dropped.
            inputs_embeds = self.language_model.get_input_embeddings()(inputs)

        return self.language_model.generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def encode_images(self, images):
        """Encode ``images`` with the vision tower and project via the connector."""
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.vision_tower(
            images,
            vision_feature_layer=self.config.vision_feature_layer,
            vision_feature_select_strategy=self.config.vision_feature_select_strategy,
        )
        return self.connector(image_features)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Build per-step generation inputs, re-attaching ``images``/``image_sizes``."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = self.language_model.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        """Splice image features into the token embeddings.

        Every ``IMAGE_TOKEN_INDEX`` position in ``input_ids`` is replaced by the
        next entry of the encoded image features; the matching label positions
        are filled with ``IGNORE_INDEX``.  Sequences are then truncated to
        ``config.tokenizer_model_max_length`` (if set) and padded to a common
        length on the side given by ``config.tokenizer_padding_side``.

        Returns:
            ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``.  If there is no vision tower, no images,
            or ``input_ids`` has a single position per row, the inputs are
            returned unchanged with ``inputs_embeds=None``.  Otherwise
            ``input_ids`` is ``None`` and ``inputs_embeds`` holds the merged
            embeddings; ``position_ids``, ``attention_mask`` and ``labels`` are
            ``None`` whenever the caller passed ``None`` for them.

        Raises:
            NotImplementedError: If ``config.tune_mm_mlp_adapter`` is truthy.
        """
        vision_tower = self.vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        embed_tokens = self.language_model.get_input_embeddings()

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            # Plain Python ints from here on: one host sync per sample instead
            # of using a 0-d tensor as a loop bound and in comparisons.
            image_positions = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
            num_images = len(image_positions)
            if num_images == 0:
                # A sample without image tokens still consumes one image
                # feature entry; the empty slice adds no rows to the sequence.
                # NOTE(review): presumably this keeps the vision path in the
                # autograd graph for text-only samples; confirm.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Sentinels at both ends so consecutive pairs delimit text chunks.
            image_token_indices = [-1] + image_positions + [cur_input_ids.shape[0]]
            cur_labels = labels[batch_idx]
            cur_input_ids_noim = []
            cur_labels_noim = []
            for start, end in zip(image_token_indices[:-1], image_token_indices[1:]):
                cur_input_ids_noim.append(cur_input_ids[start + 1:end])
                cur_labels_noim.append(cur_labels[start + 1:end])
            split_sizes = [chunk.shape[0] for chunk in cur_labels_noim]
            # Embed all text chunks in one call, then split back per chunk.
            cur_input_embeds = embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        pad_left = getattr(self.config, 'tokenizer_padding_side', 'right') == "left"

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_label) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]),
                dtype=cur_new_embed.dtype,
                device=cur_new_embed.device,
            )
            if pad_left:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_label
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_label
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Hand back None for anything the caller originally passed as None.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels


# Make the model loadable through the transformers Auto* factories.
AutoConfig.register("tinyllava", TinyLlavaConfig)
AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)