File size: 20,301 Bytes
3e1d9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 |
from typing import List, Optional, Tuple, Union
from PIL import Image
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torchvision.transforms.functional as TF
from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM, CLIPVisionModel, CLIPImageProcessor,AutoImageProcessor, DeformableDetrModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
# Image placeholder tokens used in multimodal prompts.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

# Paired opening ("/st") and closing ("/ed") tags for reaction markup.
Rxn_st, Rxn_ed = "<Rxn/st>", "<Rxn/ed>"  # Reaction
Rct_st, Rct_ed = "<Rct/st>", "<Rct/ed>"  # Reactant
# NOTE(review): Prd_st carries a trailing space, unlike every other tag. It is
# kept byte-for-byte because existing tokenizers/data may depend on it -- confirm.
Prd_st, Prd_ed = "<Prd/st> ", "<Prd/ed>"  # Product
Cnd_st, Cnd_ed = "<Cnd/st>", "<Cnd/ed>"

# Bracketed entity tags.
Mol = "[Str]"  # Molecule
Txt = "[Txt]"  # Text
Sol = "[Sol]"
Age = "[Age]"
Tem = "[Tem]"
Yld = "[Yld]"
Obj = "[Obj]"

# Every reaction-markup token, in the order it is handed to tokenizer.add_tokens.
rxn_tokens = [
    Rxn_st, Rxn_ed,
    Rct_st, Rct_ed,
    Prd_st, Prd_ed,
    Cnd_st, Cnd_ed,
    Mol, Txt, Sol, Age, Tem, Yld, Obj,
]
# Zero-padded three-digit strings "001" .. "999".
number_tokens = [str(n).zfill(3) for n in range(1, 1000)]
# Identifier tokens "<ID_1>" .. "<ID_50>".
ID_tokens = ["<ID_{}>".format(n) for n in range(1, 51)]
def resize_batch(images, size):
    """
    Resize a batch of images to the given size with bicubic interpolation.
    Args:
    - images: a torch.Tensor of shape (B, C, H, W), or any iterable of
      tensors (each resized independently, so input sizes may differ)
    - size (tuple): Desired output size (new_h, new_w)
    Returns:
    - torch.Tensor: Resized images of shape (B, C, new_h, new_w)
    """
    # torchvision accepts PIL integer constants (Image.BICUBIC) only through a
    # deprecated compatibility path; use the InterpolationMode enum instead.
    interpolation = TF.InterpolationMode.BICUBIC
    if isinstance(images, torch.Tensor) and images.dim() == 4:
        # Tensor resize operates on the trailing [..., H, W] dims, so the whole
        # batch can be resized in a single call instead of image by image.
        return TF.resize(images, size, interpolation=interpolation)
    # Fallback for lists / other layouts: resize each element, then stack along
    # a new leading batch dimension.
    return torch.stack([TF.resize(image, size, interpolation=interpolation) for image in images])
class VisionLanguageAdapter(nn.Module):
    """Cross-attention adapter that pools image features into a fixed set of queries.

    ``num_queries`` learned query embeddings (plus a learned positional encoding)
    attend over the incoming image features. The attended output is
    layer-normalised and linearly projected to ``output_dim``.

    Args:
        feature_dim: Channel size of the image features and of the queries.
        num_queries: Number of learned queries, i.e. the output sequence length.
        num_heads: Number of attention heads (must divide ``feature_dim``).
        output_dim: Width of the final projection. Defaults to 5120, the value
            that used to be hard-coded.
    """

    def __init__(self, feature_dim=1280, num_queries=256, num_heads=16, output_dim=5120):
        super().__init__()
        self.num_queries = num_queries
        # Parameters are created in the same order as before so that random
        # initialisation under a fixed seed is unchanged.
        self.query_embeds = nn.Parameter(torch.randn(num_queries, feature_dim))
        self.cross_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads, batch_first=True)
        self.positional_encoding = nn.Parameter(torch.randn(num_queries, feature_dim))
        self.layer_norm = nn.LayerNorm(feature_dim)
        self.linear = nn.Linear(feature_dim, output_dim)

    def forward(self, image_features):
        """Attend the learned queries over ``image_features``.

        Args:
            image_features: ``(seq_len, feature_dim)`` for a single sample or
                ``(batch, seq_len, feature_dim)`` for a batch.

        Returns:
            ``(num_queries, output_dim)`` for unbatched input, otherwise
            ``(batch, num_queries, output_dim)``.
        """
        queries = self.query_embeds + self.positional_encoding
        was_unbatched = image_features.dim() == 2
        if was_unbatched:
            # Give unbatched input a batch dimension of 1 so that one code path
            # serves both cases.
            image_features = image_features.unsqueeze(0)
        queries = queries.unsqueeze(0).expand(image_features.size(0), -1, -1)

        attn_output, _ = self.cross_attention(query=queries, key=image_features, value=image_features)
        attn_output = self.linear(self.layer_norm(attn_output))
        return attn_output.squeeze(0) if was_unbatched else attn_output
class ShikraConfig(LlamaConfig):
    """LLaMA configuration registered under the ``"shikra"`` model type.

    It declares no fields of its own. Multimodal settings such as
    ``mm_vision_tower`` and ``use_mm_proj`` are attached to the config object
    as plain attributes at runtime.
    """

    model_type = "shikra"
class ShikraLlamaModel(LlamaModel):
    """LLaMA decoder that splices projected image features into its input embeddings.

    Images are encoded by a Deformable-DETR vision tower, held in a one-element
    ``nn.ModuleList``. The tower's last hidden state is mapped to the LLaMA
    hidden size by ``mm_projector`` and written over the image placeholder
    tokens before the regular ``LlamaModel.forward`` runs.
    """

    config_class = ShikraConfig

    def __init__(self, config: LlamaConfig, mm_vision_tower=None, mm_hidden_size=None):
        # ``mm_vision_tower`` and ``mm_hidden_size`` are accepted for signature
        # compatibility only; the values actually used come from ``config``.
        super(ShikraLlamaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # HACK: for FSDP -- the tower is wrapped in a ModuleList.
            self.vision_tower = nn.ModuleList([DeformableDetrModel.from_pretrained(config.mm_vision_tower)])

        if hasattr(config, "use_mm_proj"):
            # 256 matches the ``mm_hidden_size`` written by initialize_vision_modules.
            self.mm_projector = nn.Linear(256, config.hidden_size)

    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
                                  pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False):
        """Load the vision tower and its image processor, and set up the projector.

        Args:
            vision_tower: Name or path of a pretrained Deformable-DETR checkpoint.
            mm_vision_select_layer: Stored on the config. forward() currently
                always uses the tower's last hidden state.
            pretrain_mm_mlp_adapter: Optional path to saved projector weights.
            tune_mm_mlp_adapter: Not used here; kept for interface compatibility.

        Returns:
            dict with ``image_processor``, ``image_token_len`` (300) and
            ``vision_config``.
        """
        self.config.mm_vision_tower = vision_tower
        image_processor = AutoImageProcessor.from_pretrained(vision_tower)

        detr = DeformableDetrModel.from_pretrained(vision_tower)
        if not hasattr(self, 'vision_tower'):
            # Wrap the model in a ModuleList (FSDP hack, see __init__).
            self.vision_tower = nn.ModuleList([detr])
        else:
            # Replace the existing entry in the ModuleList.
            self.vision_tower[0] = detr

        # Enable gradients on the tower and cast it to fp16. forward() still runs
        # the tower under torch.no_grad(), so it receives no gradient there.
        self.vision_tower[0].requires_grad_(True)
        self.vision_tower[0] = self.vision_tower[0].to(torch.float16)
        vision_config = self.vision_tower[0].config
        # Number of image tokens reserved per image in the prompt
        # (presumably DETR's query count -- confirm against the tower config).
        num_patches = 300

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = 256
        self.config.mm_vision_select_layer = mm_vision_select_layer

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)

        if pretrain_mm_mlp_adapter is not None:
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            # Keep only the last dotted component of each key (e.g. "weight", "bias").
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

        return dict(
            image_processor=image_processor,
            image_token_len=num_patches,
            vision_config=vision_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the tokens, splice in image features, then run the LLaMA decoder.

        Image features are merged only when all of the following hold:
        a vision tower is present, ``images`` is given, and either the input
        has more than one token or the model is in training mode.

        Raises:
            ValueError: if ``images`` are given without ``input_ids``, so the
                image tokens cannot be located, or if a sample's image tokens
                are malformed.
        """
        # Set by initialize_vision_tokenizer when tuning the projector. Its
        # presence makes the text embeddings around the images be detached.
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = getattr(self, 'vision_tower', None)
        # Check ``images`` before touching ``input_ids``. Calls that pass only
        # ``inputs_embeds`` (input_ids=None) must not crash on ``input_ids.shape``.
        if vision_tower is not None and images is not None:
            if input_ids is None:
                raise ValueError("`input_ids` are required when `images` are passed, to locate the image tokens.")
            if input_ids.shape[1] != 1 or self.training:
                # TODO: this is a modified multimodal LLM -- Haotian Liu
                vision_tower = vision_tower[0]  # HACK: for FSDP
                image_features = self._encode_images(vision_tower, images)
                inputs_embeds = self._merge_image_features(
                    input_ids, inputs_embeds, image_features, vision_tower.config, orig_embeds_params)

        return super(ShikraLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def _encode_images(self, vision_tower, images):
        """Run images through the frozen vision tower and project them to the LM width."""
        # Images are resized to 1333x1333 before being fed to the tower.
        images = resize_batch(images, (1333, 1333))
        # resize_batch always returns one stacked tensor, so the whole batch is
        # encoded in a single call.
        with torch.no_grad():
            image_features = vision_tower(images).last_hidden_state
        # The projector runs outside no_grad so that it stays trainable.
        return self.mm_projector(image_features)

    def _merge_image_features(self, input_ids, inputs_embeds, image_features, tower_config, orig_embeds_params):
        """Overwrite each sample's image placeholder tokens with image features.

        ``image_features`` is consumed in order, one entry per image found in the
        batch. When ``orig_embeds_params`` is set, the text embeddings outside the
        image regions are detached.
        """
        detach_text = orig_embeds_params is not None
        # Zero-valued projector term added to text-only samples so the projector
        # still takes part in the graph for those samples.
        dummy_image_features = self.mm_projector(
            torch.zeros(300, self.mm_projector.in_features, device=inputs_embeds.device, dtype=inputs_embeds.dtype))
        dummy_term = (0. * dummy_image_features).sum()

        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
            if (cur_input_ids == tower_config.im_patch_token).sum() == 0:
                # Multimodal model, but this sample carries no image.
                new_input_embeds.append(cur_input_embeds + dummy_term)
                continue

            if tower_config.use_im_start_end:
                im_start, im_end = tower_config.im_start_token, tower_config.im_end_token
                if (cur_input_ids == im_start).sum() != (cur_input_ids == im_end).sum():
                    raise ValueError("The number of image start tokens and image end tokens should be the same.")
                image_start_tokens = torch.where(cur_input_ids == im_start)[0]
                if len(image_start_tokens) == 0:
                    raise ValueError("Image patch tokens were found without an image start token.")
                # Rebuild the sample as alternating text / image segments, so that
                # every image in a multi-image sample is spliced in.
                segments = []
                text_start = 0
                for image_start_token_pos in image_start_tokens:
                    cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
                    num_patches = cur_image_features.shape[0]
                    image_end_token_pos = image_start_token_pos + num_patches + 1
                    if cur_input_ids[image_end_token_pos] != im_end:
                        raise ValueError("The image end token should follow the image start token.")
                    text = cur_input_embeds[text_start:image_start_token_pos]
                    segments.append(text.detach() if detach_text else text)
                    # The start/end token embeddings are never detached.
                    segments.append(cur_input_embeds[image_start_token_pos:image_start_token_pos + 1])
                    segments.append(cur_image_features)
                    segments.append(cur_input_embeds[image_end_token_pos:image_end_token_pos + 1])
                    text_start = image_end_token_pos + 1
                    cur_image_idx += 1
                tail = cur_input_embeds[text_start:]
                segments.append(tail.detach() if detach_text else tail)
                new_input_embeds.append(torch.cat(segments, dim=0))
            else:
                cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
                num_patches = cur_image_features.shape[0]
                patch_mask = cur_input_ids == tower_config.im_patch_token
                if patch_mask.sum() != num_patches:
                    raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
                masked_indices = torch.where(patch_mask)[0]
                mask_index_start = masked_indices[0]
                expected = torch.arange(mask_index_start, mask_index_start + num_patches,
                                        device=masked_indices.device, dtype=masked_indices.dtype)
                if (masked_indices != expected).any():
                    raise ValueError("The image patch tokens should be consecutive.")
                before = cur_input_embeds[:mask_index_start]
                after = cur_input_embeds[mask_index_start + num_patches:]
                if detach_text:
                    before, after = before.detach(), after.detach()
                new_input_embeds.append(torch.cat((before, cur_image_features, after), dim=0))
                # Advance to the next image so each sample gets its own features.
                cur_image_idx += 1

        return torch.stack(new_input_embeds, dim=0)
class ShikraLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head on top of ShikraLlamaModel (LLaMA with spliced image features)."""

    config_class = ShikraConfig

    def __init__(self, config: ShikraConfig):
        # Deliberately call the initializer that follows LlamaForCausalLM in the
        # MRO, skipping LlamaForCausalLM.__init__, and build the submodules here
        # so that ``self.model`` is a ShikraLlamaModel.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = ShikraLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal decoder and the LM head.

        When ``labels`` are given, the loss is the next-token cross-entropy
        (logits at position t are scored against labels at t + 1).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            images=images
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the per-step model inputs for generation, forwarding ``images``."""
        if past_key_values:
            # With a cache present, only the newest token needs to be fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
                                    tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
        """Register the extra tokens with ``tokenizer`` and resize the embeddings.

        Adds the reaction-markup tokens, the ID tokens and the image patch token.
        If ``mm_use_im_start_end`` is set, it also adds the image start/end tokens.
        The resulting token ids are recorded on the vision tower's config.

        Raises:
            ValueError: if ``pretrain_mm_mlp_adapter`` is given but the start/end
                tokens were not both newly added, or if the stored embedding
                shape does not fit.
        """
        vision_config = self.model.vision_tower[0].config
        vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens(rxn_tokens)
        tokenizer.add_tokens(ID_tokens)
        # number_tokens are currently not added to the tokenizer.
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                # Initialise the newly added rows with the mean of all earlier rows.
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_mm_mlp_adapter:
                # Snapshot of the input embedding matrix. Its presence makes the
                # model detach text embeddings around the images.
                self.model.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_mm_mlp_adapter:
                # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                # Explicit check rather than ``assert`` so it is not stripped under -O.
                # (It also guarantees ``input_embeddings`` was bound above.)
                if num_new_tokens != 2:
                    raise ValueError(
                        f"Expected 2 newly added image start/end tokens, got {num_new_tokens}.")
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")

        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]