# Eden-Multimodal / models /multimodel_phi.py
# Himank Jain
# fixed tokenizer imports
# 39e3e15
import torch
import torch.nn as nn
import os
from peft import PeftModel
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
from .projection_layer import ProjectionBlock
class MultimodalPhiModel(PreTrainedModel):
    """Phi causal language model that can take projected image features as input.

    The wrapper holds three collaborators: the language model (``phi_model``),
    a projection module (``image_projection``) that maps image features into
    the language model's embedding space, and the tokenizer. ``forward``
    splices the projected image features between two token segments and runs
    the language model on the resulting embedding sequence.
    """

    # Base checkpoint that the PEFT adapter is applied to in ``from_pretrained``.
    _BASE_MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
    # First argument given to ``ProjectionBlock`` (its ``input_dim``).
    _IMAGE_FEATURE_DIM = 512
    # File name of the projector weights inside a checkpoint directory.
    _PROJECTOR_FILENAME = "image_projector.pth"

    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on the wrapped language model.

        All keyword arguments are forwarded unchanged to
        ``phi_model.gradient_checkpointing_enable``.
        """
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the wrapped language model."""
        self.phi_model.gradient_checkpointing_disable()

    def __init__(self, phi_model, tokenizer, projection):
        """Assemble the wrapper from already-constructed components.

        Args:
            phi_model: Language model; its ``config`` becomes this model's config.
            tokenizer: Tokenizer kept alongside the model for callers to use.
            projection: Module applied to image features in ``encode``.
        """
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        # Filled in by ``from_pretrained``; stays ``None`` for direct construction.
        self.base_phi_model = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        """Build the model from a PEFT adapter checkpoint plus optional projector weights.

        The base Phi checkpoint is loaded in bfloat16, the adapter found at
        ``pretrained_model_name_or_path`` is applied and merged into it, and
        projector weights are read from ``image_projector.pth`` in the same
        location when that file exists. Otherwise the projector is left at its
        freshly initialised weights.

        Args:
            pretrained_model_name_or_path: Location of the PEFT adapter. The
                projector file is looked up with ``os.path.join`` on this
                value, so it is only found when this is a local directory.
            *model_args: Accepted for signature compatibility; not used.
            debug: Accepted for signature compatibility; not used.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            A ``MultimodalPhiModel`` whose ``base_phi_model`` attribute is set
            to the loaded base model.
        """
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            cls._BASE_MODEL_NAME,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(cls._BASE_MODEL_NAME, trust_remote_code=True)

        peft_model = PeftModel.from_pretrained(base_phi_model, pretrained_model_name_or_path)
        phi_model = peft_model.merge_and_unload()

        # Size the projector from the model config in both branches so the
        # projector always matches the language model's embedding width.
        input_dim = cls._IMAGE_FEATURE_DIM
        output_dim = phi_model.config.hidden_size
        projector = ProjectionBlock(input_dim, output_dim)

        projector_path = os.path.join(pretrained_model_name_or_path, cls._PROJECTOR_FILENAME)
        if os.path.exists(projector_path):
            # weights_only=True restricts unpickling to tensors and plain
            # containers, which is all a saved state_dict holds.
            projector_state_dict = torch.load(
                projector_path,
                map_location=phi_model.device,
                weights_only=True,
            )
            # Loading stays non-strict (best effort), but mismatches are
            # reported so a partially initialised projector is not silent.
            load_result = projector.load_state_dict(projector_state_dict, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    "Projector state dict mismatch: "
                    f"missing_keys={list(load_result.missing_keys)}, "
                    f"unexpected_keys={list(load_result.unexpected_keys)}"
                )
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")

        model = cls(phi_model, tokenizer, projector)
        # NOTE(review): merge_and_unload presumably returns the same underlying
        # module that was passed to PeftModel, in which case this registers one
        # module under two attribute names - confirm this attribute is needed.
        model.base_phi_model = base_phi_model
        return model

    def save_pretrained(self, save_directory, **kwargs):
        """Save the language model, the projector weights and the config.

        Args:
            save_directory: Target directory. The projector is written there
                as ``image_projector.pth``.
            **kwargs: Extra options (for example ``safe_serialization``)
                forwarded to ``phi_model.save_pretrained``. Accepting them
                keeps this override call-compatible with callers that use the
                wider ``PreTrainedModel.save_pretrained`` keyword interface.
        """
        # A caller-supplied state_dict describes this wrapper (its keys carry
        # the wrapper's attribute prefixes), not the inner model, so it must
        # not be forwarded; each part is saved from its live weights instead.
        kwargs.pop("state_dict", None)
        self.phi_model.save_pretrained(save_directory, **kwargs)

        projector_path = os.path.join(save_directory, self._PROJECTOR_FILENAME)
        torch.save(self.image_projection.state_dict(), projector_path)
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        """Return ``image_features`` passed through the image projection module."""
        image_projections = self.image_projection(image_features)
        return image_projections

    def forward(self, start_input_ids, end_input_ids, image_features=None, attention_mask=None, labels=None):
        """Run the language model on ``start tokens [+ image] + end tokens``.

        Args:
            start_input_ids: Token ids placed before the image embeddings.
            end_input_ids: Token ids placed after the image embeddings.
            image_features: Optional image features; when given they are
                projected with ``encode`` and inserted between the two token
                segments along dimension 1.
            attention_mask: Optional mask passed to the language model.
                Presumably it must span the full concatenated sequence
                (start + image + end) - confirm against the data collator.
            labels: Optional labels passed to the language model; the same
                length assumption applies.

        Returns:
            The language model's output object (``return_dict=True``).
        """
        device = next(self.parameters()).device
        embed_tokens = self.phi_model.get_input_embeddings()
        start_embeds = embed_tokens(start_input_ids.to(device))
        end_embeds = embed_tokens(end_input_ids.to(device))

        segments = [start_embeds, end_embeds]
        if image_features is not None:
            # Match the token-embedding dtype so the concatenated sequence has
            # one dtype whatever precision the language model was loaded in.
            image_embeddings = self.encode(image_features.to(device)).to(start_embeds.dtype)
            segments.insert(1, image_embeddings)
        input_embeds = torch.cat(segments, dim=1)

        # Move the optional tensors to the model device only when present.
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)

        outputs = self.phi_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs