# RxnIM/mllm/models/builder/build_shikra.py
from typing import Dict, Any, Tuple

import torch
import transformers
from torch import nn

from ..shikra import ShikraLlamaForCausalLM

PREPROCESSOR = Dict[str, Any]

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


def load_pretrained_shikra(model_args, training_args, **kwargs) -> Tuple[nn.Module, PREPROCESSOR]:
    # Load the Shikra LLaMA backbone; extra kwargs (e.g. quantization options)
    # are forwarded to `from_pretrained`.
    model = ShikraLlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        **kwargs,
    )
    model.config.use_cache = False
    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    # Matching slow tokenizer, right-padded to the training max length.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        model_max_length=model_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    # Only the 'v1' conversation format is supported by this builder; the 'v0'
    # branch below is therefore unreachable and kept only for reference.
    assert model_args.version == 'v1'
    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
                tokenizer=tokenizer,
                model=model,
            )
        if "llama" in model_args.model_name_or_path:
            tokenizer.add_special_tokens({
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })
    else:
        tokenizer.pad_token = tokenizer.unk_token

    # Build the CLIP vision tower and the multimodal projector.
    model_vision_dict = model.model.initialize_vision_modules(
        vision_tower=model_args.vision_tower,
        mm_vision_select_layer=model_args.mm_vision_select_layer,
        pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
    )
    dtype = torch.float32
    if training_args.fp16:
        dtype = torch.float16
    if training_args.bf16:
        dtype = torch.bfloat16
    # HACK for quantization: a quantized backbone can leave the vision tower on
    # the 'meta' device, in which case it is reloaded unquantized instead of moved.
    if model.model.vision_tower[0].device != torch.device('meta'):
        model.model.vision_tower[0].to(dtype=dtype, device=training_args.device)
    else:
        from transformers import CLIPVisionModel
        model.model.vision_tower[0] = CLIPVisionModel.from_pretrained(model_args.vision_tower)  # do not quantize CLIP
        # model.model.vision_tower[0] = CLIPVisionModel.from_pretrained(model_args.vision_tower, **kwargs)  # quantize CLIP

    vision_config = model_vision_dict['vision_config']

    # Optionally train only the multimodal projector (everything else frozen) ...
    model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
    if model_args.tune_mm_mlp_adapter:
        model.requires_grad_(False)
        for p in model.model.mm_projector.parameters():
            p.requires_grad = True

    # ... or keep the projector frozen while the rest of the model trains.
    model.config.freeze_mm_mlp_adapter = model_args.freeze_mm_mlp_adapter
    if model_args.freeze_mm_mlp_adapter:
        for p in model.model.mm_projector.parameters():
            p.requires_grad = False

    model.config.mm_use_im_start_end = model_args.mm_use_im_start_end
    vision_config.use_im_start_end = model_args.mm_use_im_start_end
    model.initialize_vision_tokenizer(
        mm_use_im_start_end=model_args.mm_use_im_start_end,
        tokenizer=tokenizer,
        device=training_args.device,
        tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter,
        pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
    )

    # Warn when FSDP is requested while some parameters are frozen, and monkey-
    # patch FSDP to default `use_orig_params=True` so the wrap still works.
    params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    if len(params_no_grad) > 0:
        if training_args.fsdp is not None and len(training_args.fsdp) > 0:
            if len(params_no_grad) < 10:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(
                    len(params_no_grad), params_no_grad))
            else:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(
                    len(params_no_grad), ', '.join(params_no_grad[:10])))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires a PyTorch nightly build. See "
                  "https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining for details.")

            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)

    preprocessor = dict(
        image=model_vision_dict['image_processor'],
        text=tokenizer,
        conv=dict(
            image_token_len=model_args.image_token_len,
            sep_image_conv_front=model_args.sep_image_conv_front,
            use_im_start_end=model_args.mm_use_im_start_end,
        )
    )
    return model, preprocessor
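

# The returned `preprocessor` dict is consumed by the dataset / conversation
# pipeline elsewhere in the repo. Roughly (illustrative sketch only, not the
# exact downstream code):
#
#   model, preprocessor = load_pretrained_shikra(model_args, training_args)
#   tokenizer = preprocessor['text']         # LLaMA tokenizer
#   image_processor = preprocessor['image']  # CLIP image processor
#   conv_options = preprocessor['conv']      # image_token_len, use_im_start_end, ...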


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize the tokenizer and the model embeddings after adding special tokens.

    Note: this is the unoptimized version; the resulting embedding size may not
    be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the rows for the new tokens to the mean of the existing
        # embeddings rather than leaving them randomly initialized.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
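

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The real training entry point builds
# `model_args` / `training_args` with HuggingFace argument parsing; the
# SimpleNamespace objects, checkpoint path, and field values below are
# hypothetical placeholders meant only to show which attributes this builder
# reads. It also assumes a CUDA device is available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    model_args = SimpleNamespace(
        model_name_or_path="path/to/shikra-checkpoint",   # hypothetical path
        cache_dir=None,
        freeze_backbone=False,
        model_max_length=2048,
        version="v1",                                     # only 'v1' is accepted
        vision_tower="openai/clip-vit-large-patch14",
        mm_vision_select_layer=-2,
        pretrain_mm_mlp_adapter=None,
        tune_mm_mlp_adapter=False,
        freeze_mm_mlp_adapter=False,
        mm_use_im_start_end=True,
        image_token_len=256,
        sep_image_conv_front=False,
    )
    training_args = SimpleNamespace(
        fp16=False,
        bf16=True,
        device="cuda",
        fsdp=None,
    )

    model, preprocessor = load_pretrained_shikra(model_args, training_args)
    print(type(model).__name__, list(preprocessor.keys()))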