import copy

import numpy as np
import torch
import torch.utils.checkpoint
from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.lora import LoraConfig
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoTokenizer
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.utils import (
    replace_return_docstrings,
)

from designvlm.crello_dataset import int_wrap
from designvlm.loading import LoraArguments
from yosematvlm.configuration_internlm_xcomposer2 import InternLMXcomposer2Config
from yosematvlm.modeling_internlm_xcomposer2 import (
    FROM_TOKEN_2,
    InternLMXComposer2ForCausalLM,
)
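
# Special token ids used when building training targets. MASK_ID is the standard
# ignore_index of CrossEntropyLoss, so masked positions contribute no loss. EOS
# matches the config's eos_token_id. IM_END_TOKEN appears to be the end-of-turn
# token id of the InternLM2 tokenizer, and SWITCH_TOKEN_LENGTH is presumably the
# number of tokens consumed by a role switch (see mask_human_targets below).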
IM_END_TOKEN = 92542
EOS = 2
SWITCH_TOKEN_LENGTH = 6
MASK_ID = -100


class IXCLayoutConfig(InternLMXcomposer2Config):
    def __init__(
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="flash_attention_2",
        max_length: int = 16384,
        discrete_coordinate_tokens: int | None = None,
        **kwargs,
    ):
        self.discrete_coordinate_tokens = discrete_coordinate_tokens
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            bias=bias,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attn_implementation=attn_implementation,
            **kwargs,
        )
        self.max_length = max_length
        self.use_cache = False

    @classmethod
    def with_internlm_config(
        cls,
        config: InternLMXcomposer2Config,
        discrete_custom_tokens: int | None = None,
        max_length: int = 16384,
    ):
        return cls(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            hidden_act=config.hidden_act,
            max_position_embeddings=config.max_position_embeddings,
            initializer_range=config.initializer_range,
            rms_norm_eps=config.rms_norm_eps,
            use_cache=config.use_cache,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            tie_word_embeddings=config.tie_word_embeddings,
            bias=config.bias,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            attn_implementation=config.attn_implementation,
            discrete_coordinate_tokens=discrete_custom_tokens,
            max_length=max_length,
        )
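
# IXCLayout extends InternLMXComposer2ForCausalLM with optional discrete coordinate
# tokens (see add_coordinate_tokens) and an auxiliary KL divergence loss over those
# tokens (see coordinate_kl_loss) on top of the usual next-token cross-entropy.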
class IXCLayout(InternLMXComposer2ForCausalLM):
    config_class = IXCLayoutConfig

    def __init__(self, config: IXCLayoutConfig):
        super().__init__(config)

        self.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "yosematvlm",
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )

        self.coordinate_token_ids: set[int] = set()
        if config.discrete_coordinate_tokens is not None:
            self.add_coordinate_tokens(config.discrete_coordinate_tokens)
        self.config = config
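
    # forward follows the upstream InternLM-XComposer2 forward pass; when labels are
    # available, the cross-entropy loss is augmented with coordinate_kl_loss.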
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=InternLMXcomposer2Config
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices
                set to `-100` are ignored (masked); the loss is only computed for the tokens with
                labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        infer_mode = "base"
        return_dict, output_attentions, output_hidden_states = self.or_config(
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        if "samples" in kwargs:
            samples = kwargs["samples"]
            text = samples["text_input"]
            image = samples["image"][0]
            bs = len(samples["text_input"][0])
            image_nums = []
            temp_image = []
            for im in image:
                if type(im) is list:
                    image_nums.append(len(im))
                    temp_image.extend(im)
                else:
                    image_nums.append(1)
                    temp_image.append(im)
            image = temp_image
            assert type(image) is list and len(image_nums) == bs

            to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
                image, text, image_nums
            )

            inputs_embeds = to_regress_embeds[:, : self.max_length]
            attention_mask = attention_mask[:, : self.max_length]
            targets = targets[:, : self.max_length]
            im_mask = im_mask[:, : self.max_length].bool()
            labels = targets

        elif inputs_embeds is not None or input_ids is not None:
            im_mask = kwargs.get("im_mask")
            if im_mask is None:
                # Fall back to an all-zeros image mask when none was passed in.
                ref = inputs_embeds if inputs_embeds is not None else input_ids
                im_mask = torch.zeros(ref.shape[:2], device=ref.device)
            im_mask = im_mask.bool()
        else:
            raise ValueError(
                "Either samples, inputs_embeds or input_ids should be provided."
            )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            im_mask=im_mask,
            infer_mode=infer_mode,
        )
        logits: torch.Tensor = self.output(outputs.last_hidden_state).float()

        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            ce_loss: torch.Tensor = loss_fct(shift_logits, shift_labels)
            assert not ce_loss.isnan().any()
            kl_loss = self.coordinate_kl_loss(logits, labels)
            assert not kl_loss.isnan().any()
            loss = ce_loss + kl_loss
        else:
            loss = None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
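
    # For a position whose label y is a coordinate token, the soft target over the
    # vocabulary is q(v) ∝ exp(-0.5 * ((v - y) / label_std_dev) ** 2) on coordinate
    # token ids (plus a small eps before normalisation), and the loss is
    # KL(q || softmax(logits)). Positions whose label is not a coordinate token are
    # masked out and contribute zero.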
    def coordinate_kl_loss(
        self, logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-9
    ) -> torch.Tensor:
        """
        For coordinate tokens, compute the KL loss between the predicted logits and the target labels.
        Instead of a one-hot vector, the target for each coordinate token is a probability distribution:
        a discrete Gaussian centred on the target token with standard deviation `label_std_dev`.

        Args:
            logits: B x T x V; The predicted logits of the model.
            labels: B x T; The target labels of the model.
        """
        label_std_dev = 2.0
        coordinate_token_ids = torch.tensor(
            list(self.coordinate_token_ids), device=labels.device, dtype=labels.dtype
        )
        assert len(self.coordinate_token_ids) > 0

        # Positions whose label is a coordinate token.
        is_label_coordinate = torch.isin(
            labels,
            coordinate_token_ids,
        ).type_as(labels)

        # Vocabulary indices broadcast to B x T x V.
        indices = torch.arange(
            0, self.vocab_size, dtype=labels.dtype, device=logits.device
        ).repeat(labels.shape[0], labels.shape[1], 1)

        # Vocabulary entries that correspond to coordinate tokens.
        is_indice_coordinate_token = torch.stack(
            [
                torch.isin(indices[:, idx], coordinate_token_ids)
                for idx in range(indices.size(1))
            ],
            dim=1,
        )

        total_mask = is_label_coordinate.unsqueeze(-1) * is_indice_coordinate_token

        # Discrete Gaussian centred on the label token.
        gauss_label = torch.exp(
            -0.5 * ((indices - labels.unsqueeze(-1)) / label_std_dev) ** 2
        )

        gauss_label = gauss_label * total_mask + eps

        # Normalise into a proper distribution over the vocabulary.
        gauss_label /= gauss_label.sum(dim=-1, keepdim=True)

        pointwise_kl = (
            nn.functional.kl_div(
                logits.log_softmax(dim=-1), gauss_label, reduction="none"
            ).nan_to_num()
            * total_mask
        )
        kl_loss = pointwise_kl.sum(-1).mean()
        assert kl_loss > 0
        return kl_loss
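
    # interleav_wrap splits each prompt on the "<ImageHere>" placeholder, interleaves
    # the text embeddings with the projected image embeddings, derives loss targets
    # via mask_human_targets, and right-pads every sample in the batch to a common
    # length (capped at self.max_length).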
    def interleav_wrap(
        self,
        img_list: list[torch.Tensor],
        text_list_list: list[list[str]],
        image_nums: list[int],
    ):
        temp_embeds = []
        temp_im_mask = []
        temp_tars = []

        if len(img_list) > 0:
            img_embeds, img_split = self.vit(
                img_list, self.plora_glb_GN, self.plora_sub_GN
            )
            img_embeds = self.vision_proj(img_embeds)
        else:
            img_embeds = None
            img_split = []

        text_list = text_list_list[0]
        for idx, text in enumerate(text_list):
            image_num = image_nums[idx]
            im_id = int(np.sum(image_nums[:idx]))
            images = []
            for i in range(image_nums[idx]):
                st = int(np.sum(img_split[: im_id + i]))
                sp = img_split[im_id + i]
                temp_img = img_embeds[:, st : st + sp]
                images.append(temp_img)

            if image_num == 1 and text.find("<ImageHere>") == -1:
                text = "<ImageHere>" + text
            parts = text.split("<ImageHere>")

            wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
            temp_len = 0
            need_bos = True
            for part_idx, part in enumerate(parts):
                if len(part) > 0:
                    part_tokens = self.tokenizer(
                        part,
                        return_tensors="pt",
                        padding="longest",
                        add_special_tokens=need_bos,
                    ).to(self.device)
                    if need_bos:
                        need_bos = False
                    wrap_tokens.append(part_tokens.input_ids)
                    part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
                    wrap_embeds.append(part_embeds)
                    wrap_im_mask.append(
                        torch.zeros(part_embeds.shape[:2]).to(self.device)
                    )
                    temp_len += part_embeds.shape[1]
                if part_idx < image_num:
                    wrap_embeds.append(images[part_idx])
                    wrap_token = (
                        torch.ones(images[part_idx].shape[:2], dtype=torch.long).to(
                            self.device
                        )
                        * -100
                    )
                    wrap_tokens.append(wrap_token)
                    wrap_im_mask.append(
                        torch.ones(images[part_idx].shape[:2]).to(self.device)
                    )
                    temp_len += images[part_idx].shape[1]
                if temp_len > self.max_length:
                    break
            wrap_tokens = torch.cat(wrap_tokens, dim=1)
            wrap_embeds = torch.cat(wrap_embeds, dim=1)
            wrap_im_mask = torch.cat(wrap_im_mask, dim=1)

            wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)

            temp_embeds.append(wrap_embeds)
            temp_im_mask.append(wrap_im_mask)
            temp_tars.append(wrap_target)

        temp_max_len = np.max([i.shape[1] for i in temp_embeds])
        temp_max_len = min(temp_max_len, self.max_length)

        final_input, final_atts, final_tars, final_mask = [], [], [], []
        pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id
        pad = pad.long().to(self.device)
        pad_emb = self.model.tok_embeddings(pad)

        for idx in range(len(temp_embeds)):
            temp_len = temp_embeds[idx].shape[1]
            if temp_len >= temp_max_len:
                final_input.append(temp_embeds[idx][:, :temp_max_len])
                final_atts.append(
                    torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device)
                )
                final_tars.append(temp_tars[idx][:, :temp_max_len])
                final_mask.append(temp_im_mask[idx][:, :temp_max_len])
            else:
                final_input.append(
                    torch.cat(
                        [
                            temp_embeds[idx],
                            pad_emb.repeat(1, temp_max_len - temp_len, 1),
                        ],
                        dim=1,
                    )
                )
                final_atts.append(
                    torch.cat(
                        [
                            torch.ones(1, temp_len),
                            torch.zeros(1, temp_max_len - temp_len),
                        ],
                        dim=1,
                    )
                    .to(wrap_target.dtype)
                    .to(self.device)
                )
                final_tars.append(
                    torch.cat(
                        [
                            temp_tars[idx],
                            (torch.ones(1, temp_max_len - temp_len) * MASK_ID)
                            .to(wrap_target.dtype)
                            .to(self.device),
                        ],
                        dim=1,
                    )
                )
                final_mask.append(
                    torch.cat(
                        [
                            temp_im_mask[idx],
                            (torch.zeros(1, temp_max_len - temp_len))
                            .to(wrap_target.dtype)
                            .to(self.device),
                        ],
                        dim=1,
                    )
                )

        inputs_embeds = torch.cat(final_input, dim=0)
        attention_mask = torch.cat(final_atts, dim=0)
        targets = torch.cat(final_tars, dim=0)
        im_mask = torch.cat(final_mask, dim=0)

        return inputs_embeds, attention_mask, targets, im_mask
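
    # mask_human_targets keeps the loss on the assistant's replies only: segments up
    # to every other IM_END_TOKEN (plus SWITCH_TOKEN_LENGTH role-switch tokens) are
    # set to MASK_ID, and everything after EOS is masked as well. This appears to
    # follow the upstream InternLM-XComposer2 chat template.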
    def mask_human_targets(self, input_ids: torch.Tensor) -> torch.Tensor:
        target_batch = []
        for bs in range(input_ids.shape[0]):
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            end_count = 0
            last_eoa = 0
            for i, temp_id in enumerate(ids):
                if temp_id == IM_END_TOKEN:
                    if end_count % 2 == 0:
                        targets[last_eoa : i + SWITCH_TOKEN_LENGTH] = MASK_ID
                    else:
                        last_eoa = i + 1
                    end_count += 1
                elif temp_id == EOS:
                    targets[i + 1 :] = MASK_ID
                    break
            target_batch.append(targets.unsqueeze(0))
        target_batch = torch.cat(target_batch, dim=0)
        return target_batch
    @classmethod
    def from_ixc_pretrained(
        cls,
        ixc_pretrained_model_name_or_path: str,
        max_length: int,
        img_size: int,
        discrete_custom_tokens: int | None = 128,
    ) -> torch.nn.Module:
        r"""
        Instantiate a pretrained InternLMXComposer2 model from a pre-trained model configuration.
        """
        config: IXCLayoutConfig = AutoConfig.from_pretrained(
            ixc_pretrained_model_name_or_path,
            trust_remote_code=True,
        )
        config = IXCLayoutConfig.with_internlm_config(
            config,
            max_length=max_length,
            discrete_custom_tokens=discrete_custom_tokens,
        )

        model: IXCLayout = super().from_pretrained(
            ixc_pretrained_model_name_or_path,
            config=config,
            torch_dtype=torch.bfloat16,
        )
        if img_size != 336:
            model.vit.resize_pos()
        model.vit.requires_grad_(False)
        model.vision_proj.requires_grad_(True)

        return model
    def or_config(
        self,
        return_dict: bool | None,
        output_attentions: bool | None,
        output_hidden_states: bool | None,
    ):
        """Resolve the output flags against the config defaults."""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        return return_dict, output_attentions, output_hidden_states
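
    # chat mirrors the upstream InternLM-XComposer2 chat API; generation stops on
    # either the tokenizer's EOS token or FROM_TOKEN_2, and the response is cut at
    # the first FROM_TOKEN_2.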
    @torch.no_grad()
    def chat(
        self,
        query: str,
        image: list[tuple[str, str]] | torch.Tensor = [],
        hd_num: int = 24,
        history: list[tuple[str, str]] = [],
        streamer: BaseStreamer | None = None,
        max_new_tokens: int = 1024,
        do_sample: bool = True,
        num_beams: int = 1,
        temperature: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.005,
        infer_mode: str = "base",
        use_meta: bool = False,
        meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n"
        "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n"
        "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.",
        **kwargs,
    ):
        if not use_meta:
            meta_instruction = ""

        inputs, im_mask, _ = self.interleav_wrap_chat(
            query,
            image,
            history=history,
            meta_instruction=meta_instruction,
            hd_num=hd_num,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}

        eos_token_id = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids([FROM_TOKEN_2])[0],
        ]
        outputs = self.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=eos_token_id,
            repetition_penalty=repetition_penalty,
            im_mask=im_mask,
            infer_mode=infer_mode,
            **kwargs,
        )
        outputs = outputs[0].cpu().tolist()
        response = self.tokenizer.decode(outputs, skip_special_tokens=True)
        response = response.split(FROM_TOKEN_2)[0]
        history = history + [(query, response)]
        return response, history
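
    # add_coordinate_tokens registers int_wrap(-bins) ... int_wrap(2 * bins) as new
    # vocabulary entries, presumably so that coordinates somewhat outside the canvas
    # can still be tokenized.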
    def add_coordinate_tokens(self, bins: int):
        start = -bins
        end = bins * 2
        new_tokens = [int_wrap(idx) for idx in range(start, end + 1)]
        new_token_ids = self._add_tokens(new_tokens)
        self.coordinate_token_ids.update(new_token_ids)
        self.config.discrete_coordinate_tokens = end // 2
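
    # _add_tokens grows the tokenizer and the input embeddings, then rebuilds the
    # (untied) output head, copying the old weights into the first rows; rows for
    # new tokens keep nn.Linear's default initialisation.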
    def _add_tokens(self, new_tokens: list[str]) -> list[int]:
        vocab = self.tokenizer.get_vocab()
        tokens_to_add = [token for token in new_tokens if token not in vocab]
        self.tokenizer.add_tokens(tokens_to_add)
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.vocab_size = len(self.tokenizer)

        new_output = nn.Linear(
            self.model.config.hidden_size,
            self.vocab_size,
            bias=False,
            dtype=self.output.weight.dtype,
            device=self.output.weight.device,
        ).to(self.device)
        new_output.weight.data[: self.output.weight.shape[0]] = self.output.weight.data
        self.output = new_output
        # Return the ids of every requested token, including any that already existed
        # in the vocabulary, so that coordinate_token_ids stays complete.
        return [self.tokenizer.convert_tokens_to_ids(token) for token in new_tokens]
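
# setup_ixclayout builds a training-ready model: it loads the pretrained weights,
# optionally wraps them with LoRA adapters, and optionally enables gradient
# checkpointing for the vision tower.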
def setup_ixclayout(
    model_name_or_path: str,
    max_length: int,
    img_size: int,
    use_lora: bool,
    gradient_checkpointing: bool,
    discrete_custom_tokens: int | None = 128,
) -> IXCLayout | PeftModel:
    model = IXCLayout.from_ixc_pretrained(
        ixc_pretrained_model_name_or_path=model_name_or_path,
        max_length=max_length,
        img_size=img_size,
        discrete_custom_tokens=discrete_custom_tokens,
    )
    if use_lora:
        lora_args = LoraArguments()
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    if gradient_checkpointing:
        model.enable_input_require_grads()
        model.vit.vision_tower.gradient_checkpointing_enable(
            {"use_reentrant": True}
        )

    return model
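
# Example usage, a minimal sketch: the checkpoint path and hyperparameters below are
# placeholders, not values taken from this module.
#
#     model = setup_ixclayout(
#         model_name_or_path="path/to/internlm-xcomposer2-checkpoint",
#         max_length=16384,
#         img_size=336,
#         use_lora=True,
#         gradient_checkpointing=True,
#         discrete_custom_tokens=128,
#     )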