import os

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer

from .projection_layer import ProjectionBlock
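
# `ProjectionBlock` lives in projection_layer.py and maps image-encoder features
# into the language model's embedding space. Its exact definition is not shown
# in this file; a minimal sketch, assuming a LayerNorm + two-layer MLP, might be:
#
#     class ProjectionBlock(nn.Module):
#         def __init__(self, input_dim, output_dim):
#             super().__init__()
#             self.proj = nn.Sequential(
#                 nn.LayerNorm(input_dim),
#                 nn.Linear(input_dim, output_dim),
#                 nn.GELU(),
#                 nn.Linear(output_dim, output_dim),
#             )
#
#         def forward(self, x):
#             return self.proj(x)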

class MultimodalPhiModel(PreTrainedModel):
    """Phi-3.5 wrapper that splices projected image features into the text embedding sequence."""

    def __init__(self, phi_model, tokenizer, projection):
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        self.base_phi_model = None

    def gradient_checkpointing_enable(self, **kwargs):
        # Delegate gradient checkpointing to the wrapped language model.
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        self.phi_model.gradient_checkpointing_disable()
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        # Load the Phi-3.5 base model and tokenizer, then merge the LoRA adapter
        # stored at `pretrained_model_name_or_path` into the base weights.
        model_name = "microsoft/Phi-3.5-mini-instruct"
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        peft_model = PeftModel.from_pretrained(base_phi_model, pretrained_model_name_or_path)
        phi_model = peft_model.merge_and_unload()

        # The projector maps 512-dim image features (e.g. a CLIP embedding) into
        # Phi-3.5's 3072-dim hidden space.
        input_dim = 512
        output_dim = 3072
        projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
        if os.path.exists(projector_path):
            projector_state_dict = torch.load(projector_path, map_location=phi_model.device)
            projector = ProjectionBlock(input_dim, output_dim)
            projector.load_state_dict(projector_state_dict, strict=False)
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")
            output_dim = phi_model.config.hidden_size
            projector = ProjectionBlock(input_dim, output_dim)

        model = cls(phi_model, tokenizer, projector)
        model.base_phi_model = base_phi_model
        return model
    def save_pretrained(self, save_directory):
        # Persist the merged language model, the projector weights, and the config.
        self.phi_model.save_pretrained(save_directory)
        projector_path = os.path.join(save_directory, "image_projector.pth")
        torch.save(self.image_projection.state_dict(), projector_path)
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        # Project image features into the language model's embedding space.
        return self.image_projection(image_features)
    def forward(self, start_input_ids, end_input_ids, image_features, attention_mask, labels):
        device = next(self.parameters()).device

        # Embed the text tokens that come before and after the image slot.
        start_embeds = self.phi_model.get_input_embeddings()(start_input_ids.to(device))
        end_embeds = self.phi_model.get_input_embeddings()(end_input_ids.to(device))

        if image_features is not None:
            # Project the image features and splice them between the text segments.
            image_embeddings = self.encode(image_features.to(device)).bfloat16()
            input_embeds = torch.cat([start_embeds, image_embeddings, end_embeds], dim=1)
        else:
            input_embeds = torch.cat([start_embeds, end_embeds], dim=1)

        return self.phi_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device),
            labels=labels.to(device) if labels is not None else None,
            return_dict=True,
        )
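

# Usage sketch (illustrative only). The checkpoint path, prompt format, and
# image-feature shape below are assumptions, not part of this module; run it
# via `python -m <package>.<module>` so the relative import above resolves.
if __name__ == "__main__":
    model = MultimodalPhiModel.from_pretrained("./checkpoints/multimodal_phi")  # hypothetical path
    tokenizer = model.tokenizer

    start_ids = tokenizer("<|user|>\n", return_tensors="pt").input_ids
    end_ids = tokenizer("\nDescribe the image.<|end|>\n<|assistant|>", return_tensors="pt").input_ids

    # Stand-in for a real image encoder: one 512-dim feature vector per image.
    image_features = torch.randn(1, 1, 512)

    seq_len = start_ids.shape[1] + image_features.shape[1] + end_ids.shape[1]
    attention_mask = torch.ones(1, seq_len, dtype=torch.long)

    outputs = model(start_ids, end_ids, image_features, attention_mask, labels=None)
    print(outputs.logits.shape)  # (1, seq_len, vocab_size)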