from typing import Dict, Any, Tuple

import torch
import transformers
from torch import nn

from ..shikra import ShikraLlamaForCausalLM

PREPROCESSOR = Dict[str, Any]

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


def load_pretrained_shikra(model_args, training_args, **kwargs) -> Tuple[nn.Module, PREPROCESSOR]:
    # Load the Shikra (LLaMA-based) causal LM backbone.
    model = ShikraLlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        **kwargs
    )
    model.config.use_cache = False
    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    # Slow tokenizer with right padding; truncation length follows model_max_length.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        model_max_length=model_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Only the 'v1' prompt version is supported here; the 'v0' branch below is kept
    # from the upstream code but is unreachable because of this assert.
    assert model_args.version == 'v1'
    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
                tokenizer=tokenizer,
                model=model,
            )
        if "llama" in model_args.model_name_or_path:
            tokenizer.add_special_tokens({
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })
    else:
        tokenizer.pad_token = tokenizer.unk_token

    # Build the CLIP vision tower and the multimodal projector.
    model_vision_dict = model.model.initialize_vision_modules(
        vision_tower=model_args.vision_tower,
        mm_vision_select_layer=model_args.mm_vision_select_layer,
        pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter
    )

    dtype = torch.float32
    if training_args.fp16:
        dtype = torch.float16
    if training_args.bf16:
        dtype = torch.bfloat16

    # HACK for quantization: if the vision tower sits on the 'meta' device
    # (e.g. after quantized loading), reload CLIP separately in full precision.
    if model.model.vision_tower[0].device != torch.device('meta'):
        model.model.vision_tower[0].to(dtype=dtype, device=training_args.device)
    else:
        from transformers import CLIPVisionModel
        model.model.vision_tower[0] = CLIPVisionModel.from_pretrained(model_args.vision_tower)  # do not quantize CLIP
        # model.model.vision_tower[0] = CLIPVisionModel.from_pretrained(model_args.vision_tower, **kwargs)  # quantize CLIP

    vision_config = model_vision_dict['vision_config']

    # Optionally train only the multimodal projector (adapter) while freezing
    # everything else, or freeze the projector while the rest is trained.
    model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
    if model_args.tune_mm_mlp_adapter:
        model.requires_grad_(False)
        for p in model.model.mm_projector.parameters():
            p.requires_grad = True

    model.config.freeze_mm_mlp_adapter = model_args.freeze_mm_mlp_adapter
    if model_args.freeze_mm_mlp_adapter:
        for p in model.model.mm_projector.parameters():
            p.requires_grad = False

    model.config.mm_use_im_start_end = model_args.mm_use_im_start_end
    vision_config.use_im_start_end = model_args.mm_use_im_start_end
    # Add the image-related special tokens to the tokenizer and set up the
    # corresponding embeddings.
    model.initialize_vision_tokenizer(
        mm_use_im_start_end=model_args.mm_use_im_start_end,
        tokenizer=tokenizer,
        device=training_args.device,
        tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter,
        pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
    )

    params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    if len(params_no_grad) > 0:
        if training_args.fsdp is not None and len(training_args.fsdp) > 0:
            if len(params_no_grad) < 10:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(
                    len(params_no_grad), params_no_grad))
            else:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(
                    len(params_no_grad), ', '.join(params_no_grad[:10])))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: "
                  "https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")

            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

            # Monkey-patch FSDP.__init__ so that use_orig_params defaults to True,
            # which is required when wrapping models with partially frozen parameters.
            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)

    preprocessor = dict(
        image=model_vision_dict['image_processor'],
        text=tokenizer,
        conv=dict(
            image_token_len=model_args.image_token_len,
            sep_image_conv_front=model_args.sep_image_conv_front,
            use_im_start_end=model_args.mm_use_im_start_end,
        )
    )
    return model, preprocessor
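

# A rough consumption sketch (not from the original file): callers are expected to
# unpack the bundle returned by load_pretrained_shikra() roughly as follows. The
# variable names below are illustrative assumptions only.
#
#   model, preprocessor = load_pretrained_shikra(model_args, training_args)
#   image_processor = preprocessor['image']  # image processor from initialize_vision_modules
#   tokenizer = preprocessor['text']         # the tokenizer configured above
#   conv_cfg = preprocessor['conv']          # image_token_len / sep_image_conv_front / use_im_start_end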


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        # Initialize each newly added token's input/output embedding to the mean of the
        # pre-existing embeddings instead of leaving it randomly initialized.
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
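

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the original module. In the real
    # training pipeline `model_args` / `training_args` come from the repo's argument
    # dataclasses; the concrete values below are assumptions that merely document
    # which attributes load_pretrained_shikra() reads. Run as a module, e.g.
    # `python -m <package>.build_shikra`, so the relative import above resolves.
    from types import SimpleNamespace

    model_args = SimpleNamespace(
        model_name_or_path="/path/to/shikra-checkpoint",  # hypothetical local path
        cache_dir=None,
        model_max_length=2048,                            # assumed context length
        version='v1',                                     # only 'v1' passes the assert above
        freeze_backbone=False,
        vision_tower="openai/clip-vit-large-patch14",     # assumed CLIP checkpoint
        mm_vision_select_layer=-2,                        # assumed feature layer
        pretrain_mm_mlp_adapter=None,
        tune_mm_mlp_adapter=False,
        freeze_mm_mlp_adapter=False,
        mm_use_im_start_end=True,
        image_token_len=256,                              # assumed number of image tokens
        sep_image_conv_front=False,
    )
    training_args = SimpleNamespace(
        fp16=False,
        bf16=True,
        fsdp=None,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

    model, preprocessor = load_pretrained_shikra(model_args, training_args)
    print(type(model).__name__, sorted(preprocessor.keys()))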