import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

from .projection_layer import ProjectionBlock
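# ProjectionBlock is defined in projection_layer.py. For reference, a minimal
# block compatible with how it is used here (constructed with input_dim and
# output_dim, called on a batch of feature vectors) might look like the
# commented sketch below. This is an illustrative assumption, not the actual
# implementation:
#
#     class ProjectionBlock(nn.Module):
#         def __init__(self, input_dim, output_dim):
#             super().__init__()
#             self.proj = nn.Sequential(
#                 nn.LayerNorm(input_dim),
#                 nn.Linear(input_dim, output_dim),
#                 nn.GELU(),
#                 nn.Linear(output_dim, output_dim),
#             )
#
#         def forward(self, x):
#             return self.proj(x)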

class MultimodalPhiModel(PreTrainedModel):
    """Phi-3.5 language model paired with a projection layer that maps image
    features into the model's token-embedding space."""

    def __init__(self, phi_model, tokenizer, projection):
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        self.base_phi_model = None

    def gradient_checkpointing_enable(self, **kwargs):
        # Delegate gradient checkpointing to the wrapped language model.
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        self.phi_model.gradient_checkpointing_disable()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        # Load the base Phi-3.5 model, apply the fine-tuned LoRA adapter
        # stored at `pretrained_model_name_or_path`, and merge the adapter
        # weights into the base model.
        model_name = "microsoft/Phi-3.5-mini-instruct"
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        peft_model = PeftModel.from_pretrained(base_phi_model, pretrained_model_name_or_path)
        phi_model = peft_model.merge_and_unload()

        # The projector maps 512-d image features into the language model's
        # hidden size (3072 for Phi-3.5-mini-instruct).
        input_dim = 512
        output_dim = phi_model.config.hidden_size
        projector = ProjectionBlock(input_dim, output_dim)

        projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
        if os.path.exists(projector_path):
            projector_state_dict = torch.load(projector_path, map_location=phi_model.device)
            # strict=False tolerates checkpoints saved with extra or missing keys.
            projector.load_state_dict(projector_state_dict, strict=False)
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")

        model = cls(phi_model, tokenizer, projector)
        model.base_phi_model = base_phi_model
        return model

    def save_pretrained(self, save_directory):
        # Save the (merged) language model, the projector weights, and the
        # config side by side so from_pretrained can restore all three.
        self.phi_model.save_pretrained(save_directory)
        projector_path = os.path.join(save_directory, "image_projector.pth")
        torch.save(self.image_projection.state_dict(), projector_path)
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        # Project raw image features into the language model's embedding space.
        return self.image_projection(image_features)

    def forward(self, start_input_ids, end_input_ids, image_features, attention_mask, labels):
        device = next(self.parameters()).device

        # Embed the text tokens that surround the image slot.
        start_embeds = self.phi_model.get_input_embeddings()(start_input_ids.to(device))
        end_embeds = self.phi_model.get_input_embeddings()(end_input_ids.to(device))

        if image_features is not None:
            # Splice the projected image embeddings between the two text
            # segments. attention_mask and labels are expected to already
            # account for the image positions.
            image_embeddings = self.encode(image_features.to(device)).bfloat16()
            input_embeds = torch.cat([start_embeds, image_embeddings, end_embeds], dim=1)
        else:
            input_embeds = torch.cat([start_embeds, end_embeds], dim=1)

        outputs = self.phi_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device),
            labels=labels.to(device) if labels is not None else None,
            return_dict=True,
        )

        return outputs
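

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). It assumes a fine-tuned
    # checkpoint directory containing the LoRA adapter and image_projector.pth,
    # and 512-d image features, e.g. from a CLIP image encoder. The checkpoint
    # path below is a placeholder.
    model = MultimodalPhiModel.from_pretrained("./checkpoints/multimodal-phi")
    tokenizer = model.tokenizer

    start_ids = tokenizer("<|user|>\n", return_tensors="pt").input_ids
    end_ids = tokenizer(
        "\nDescribe the image.<|end|>\n<|assistant|>\n", return_tensors="pt"
    ).input_ids
    # Dummy stand-in for one projected image token worth of features.
    image_features = torch.randn(1, 1, 512)

    # The mask must cover the full spliced sequence: start + image + end.
    seq_len = start_ids.shape[1] + image_features.shape[1] + end_ids.shape[1]
    attention_mask = torch.ones(1, seq_len, dtype=torch.long)

    outputs = model(start_ids, end_ids, image_features, attention_mask, labels=None)
    print(outputs.logits.shape)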