visheratin commited on
Commit
fa0d319
·
verified ·
1 Parent(s): 4830517

Update modeling file

Browse files
Files changed (1) hide show
  1. modeling_llava.py +144 -11
modeling_llava.py CHANGED
@@ -6,8 +6,6 @@ from typing import List, Optional, Tuple, Union
6
  import torch
7
  import torch.nn.functional as F
8
  import torch.utils.checkpoint
9
- from configuration_llava import LlavaConfig
10
- from configuration_phi import PhiConfig
11
  from torch import nn
12
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
  from transformers import PreTrainedModel, SiglipVisionModel
@@ -34,9 +32,138 @@ except Exception as exp:
34
  print(exp)
35
 
36
 
 
 
 
 
37
  logger = logging.get_logger(__name__)
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
41
  def _get_unpad_data(attention_mask):
42
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -1324,7 +1451,7 @@ class SiglipVisionEncoder(nn.Module):
1324
 
1325
  self.num_tokens = 728
1326
 
1327
- def feature_select(self, image_forward_outs, coord_feature, num_tokens = None):
1328
  image_features = image_forward_outs
1329
  image_features = image_features[:, 1:]
1330
  if num_tokens is None:
@@ -1344,24 +1471,30 @@ class SiglipVisionEncoder(nn.Module):
1344
  image_features = torch.cat(output_list)
1345
  return image_features
1346
 
1347
- def process_image_chunks(self, image_tensor, coord_tensor, num_tokens = None):
1348
  if image_tensor.shape[0] > 50:
1349
  image_forward_out = []
1350
- for i in range(0,image_tensor.shape[0],50):
1351
- part_forward_out = self.vision_tower(image_tensor[i:i+50], output_hidden_states=True).hidden_states[-1]
 
 
1352
  image_forward_out.append(part_forward_out)
1353
  image_forward_out = torch.cat(image_forward_out, dim=0)
1354
  else:
1355
- image_forward_out = self.vision_tower(image_tensor, output_hidden_states=True).hidden_states[-1]
 
 
1356
  coord_feature = self.coord_embed(coord_tensor)
1357
  if len(coord_feature.shape) == 1:
1358
  coord_feature = coord_feature.unsqueeze(0)
1359
- image_feature = self.feature_select(image_forward_out, coord_feature, num_tokens).to(
1360
- image_tensor.dtype
1361
- )
1362
  return image_feature
1363
 
1364
- def forward(self, images: List[torch.Tensor], coords: List[torch.Tensor], num_tokens = None):
 
 
1365
  image_features = []
1366
  for i, image in enumerate(images):
1367
  image_feature = self.process_image_chunks(image, coords[i], num_tokens)
 
6
  import torch
7
  import torch.nn.functional as F
8
  import torch.utils.checkpoint
 
 
9
  from torch import nn
10
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
  from transformers import PreTrainedModel, SiglipVisionModel
 
32
  print(exp)
33
 
34
 
35
+ from transformers.configuration_utils import PretrainedConfig
36
+ from transformers import SiglipVisionConfig
37
+
38
+
39
  logger = logging.get_logger(__name__)
40
 
41
 
42
+ class PhiConfig(PretrainedConfig):
43
+ model_type = "phi"
44
+ keys_to_ignore_at_inference = ["past_key_values"]
45
+
46
+ def __init__(
47
+ self,
48
+ vocab_size=51200,
49
+ hidden_size=2048,
50
+ intermediate_size=8192,
51
+ num_hidden_layers=24,
52
+ num_attention_heads=32,
53
+ num_key_value_heads=None,
54
+ resid_pdrop=0.0,
55
+ embd_pdrop=0.0,
56
+ attention_dropout=0.0,
57
+ hidden_act="gelu_new",
58
+ max_position_embeddings=2048,
59
+ initializer_range=0.02,
60
+ layer_norm_eps=1e-5,
61
+ use_cache=True,
62
+ tie_word_embeddings=False,
63
+ rope_theta=10000.0,
64
+ rope_scaling=None,
65
+ partial_rotary_factor=0.5,
66
+ qk_layernorm=False,
67
+ bos_token_id=1,
68
+ eos_token_id=2,
69
+ **kwargs,
70
+ ):
71
+ self.vocab_size = vocab_size
72
+ self.hidden_size = hidden_size
73
+ self.intermediate_size = intermediate_size
74
+ self.num_hidden_layers = num_hidden_layers
75
+ self.num_attention_heads = num_attention_heads
76
+
77
+ if num_key_value_heads is None:
78
+ num_key_value_heads = num_attention_heads
79
+
80
+ self.num_key_value_heads = num_key_value_heads
81
+ self.resid_pdrop = resid_pdrop
82
+ self.embd_pdrop = embd_pdrop
83
+ self.attention_dropout = attention_dropout
84
+ self.hidden_act = hidden_act
85
+ self.max_position_embeddings = max_position_embeddings
86
+ self.initializer_range = initializer_range
87
+ self.layer_norm_eps = layer_norm_eps
88
+ self.use_cache = use_cache
89
+ self.rope_theta = rope_theta
90
+ self.rope_scaling = rope_scaling
91
+ self.partial_rotary_factor = partial_rotary_factor
92
+ self.qk_layernorm = qk_layernorm
93
+ self._rope_scaling_validation()
94
+
95
+ super().__init__(
96
+ bos_token_id=bos_token_id,
97
+ eos_token_id=eos_token_id,
98
+ tie_word_embeddings=tie_word_embeddings,
99
+ **kwargs,
100
+ )
101
+
102
+ def _rope_scaling_validation(self):
103
+ """
104
+ Validate the `rope_scaling` configuration.
105
+ """
106
+ if self.rope_scaling is None:
107
+ return
108
+
109
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
110
+ raise ValueError(
111
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
112
+ f"got {self.rope_scaling}"
113
+ )
114
+ rope_scaling_type = self.rope_scaling.get("type", None)
115
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
116
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
117
+ raise ValueError(
118
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
119
+ )
120
+ if (
121
+ rope_scaling_factor is None
122
+ or not isinstance(rope_scaling_factor, float)
123
+ or rope_scaling_factor <= 1.0
124
+ ):
125
+ raise ValueError(
126
+ f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
127
+ )
128
+
129
+
130
+ class LlavaConfig(PretrainedConfig):
131
+ model_type = "mc-llava"
132
+ is_composition = False
133
+
134
+ def __init__(
135
+ self,
136
+ text_config=None,
137
+ vision_config=None,
138
+ ignore_index=-100,
139
+ image_token_index=50297,
140
+ projector_hidden_act="gelu",
141
+ projector_tokens_num=1,
142
+ vocab_size=51200,
143
+ **kwargs,
144
+ ):
145
+ self.ignore_index = ignore_index
146
+ self.image_token_index = image_token_index
147
+ self.projector_hidden_act = projector_hidden_act
148
+ self.projector_tokens_num = projector_tokens_num
149
+ self.vocab_size = vocab_size
150
+
151
+ self.text_config = text_config
152
+ if isinstance(self.text_config, dict):
153
+ text_config["model_type"] = (
154
+ text_config["model_type"] if "model_type" in text_config else "phi"
155
+ )
156
+ self.text_config = PhiConfig(**text_config)
157
+ self.vocab_size = self.text_config.vocab_size
158
+
159
+ self.vision_config = vision_config
160
+ if isinstance(self.vision_config, dict):
161
+ self.vision_config = SiglipVisionConfig(**vision_config)
162
+ self.vision_embed_dim = self.vision_config.hidden_size
163
+
164
+ super().__init__(**kwargs)
165
+
166
+
167
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
168
  def _get_unpad_data(attention_mask):
169
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
1451
 
1452
  self.num_tokens = 728
1453
 
1454
+ def feature_select(self, image_forward_outs, coord_feature, num_tokens=None):
1455
  image_features = image_forward_outs
1456
  image_features = image_features[:, 1:]
1457
  if num_tokens is None:
 
1471
  image_features = torch.cat(output_list)
1472
  return image_features
1473
 
1474
+ def process_image_chunks(self, image_tensor, coord_tensor, num_tokens=None):
1475
  if image_tensor.shape[0] > 50:
1476
  image_forward_out = []
1477
+ for i in range(0, image_tensor.shape[0], 50):
1478
+ part_forward_out = self.vision_tower(
1479
+ image_tensor[i : i + 50], output_hidden_states=True
1480
+ ).hidden_states[-1]
1481
  image_forward_out.append(part_forward_out)
1482
  image_forward_out = torch.cat(image_forward_out, dim=0)
1483
  else:
1484
+ image_forward_out = self.vision_tower(
1485
+ image_tensor, output_hidden_states=True
1486
+ ).hidden_states[-1]
1487
  coord_feature = self.coord_embed(coord_tensor)
1488
  if len(coord_feature.shape) == 1:
1489
  coord_feature = coord_feature.unsqueeze(0)
1490
+ image_feature = self.feature_select(
1491
+ image_forward_out, coord_feature, num_tokens
1492
+ ).to(image_tensor.dtype)
1493
  return image_feature
1494
 
1495
+ def forward(
1496
+ self, images: List[torch.Tensor], coords: List[torch.Tensor], num_tokens=None
1497
+ ):
1498
  image_features = []
1499
  for i, image in enumerate(images):
1500
  image_feature = self.process_image_chunks(image, coords[i], num_tokens)