# ImageCaptioning/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from timm import create_model
from types import SimpleNamespace
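
# GPT-2 has no padding token by default, so the end-of-text token is reused for padding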
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
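
# Masked multi-head self-attention, mirroring GPT-2's attention layer. A
# lower-triangular mask (registered as a buffer of size seq_len x seq_len)
# keeps each position from attending to future tokens.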
class GPT2Attention(nn.Module):
    def __init__(self,config:SimpleNamespace):
        super().__init__()
self.embed_dim = config.embed_dim
self.n_heads = config.num_heads
assert self.embed_dim % self.n_heads == 0, "embedding dim must be divisible by num heads"
self.head_size = self.embed_dim // self.n_heads
self.seq_len = config.seq_len
self.c_attn = nn.Linear(self.embed_dim,self.embed_dim*3)
self.scale = self.head_size ** -0.5
self.register_buffer('mask',torch.tril(torch.ones(1,1,self.seq_len,self.seq_len)))
self.c_proj = nn.Linear(self.embed_dim,self.embed_dim)
self.attn_dropout = nn.Dropout(config.attention_dropout)
self.resid_dropout = nn.Dropout(config.residual_dropout)
def forward(self,x:torch.Tensor)-> torch.Tensor:
b,t,c = x.shape
q,k,v = self.c_attn(x).chunk(3,dim=-1)
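        # split heads: (b, t, c) -> (b, n_heads, t, head_size)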
q = q.view(b,t,self.n_heads,self.head_size).permute(0,2,1,3)
k = k.view(b,t,self.n_heads,self.head_size).permute(0,2,1,3)
v = v.view(b,t,self.n_heads,self.head_size).permute(0,2,1,3)
        qk_t = (q@k.transpose(-2,-1))*self.scale
qk_t = qk_t.masked_fill(self.mask[:,:,:t,:t]==0,float('-inf'))
qk_t = F.softmax(qk_t,dim=-1)
weights = self.attn_dropout(qk_t)
attention = weights@v
attention = attention.permute(0,2,1,3).contiguous().view(b,t,c)
out = self.c_proj(attention)
return self.resid_dropout(out)
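
# Multi-head cross-attention: queries come from the text stream, while keys and
# values come from the image encoder output. No causal mask is applied, since
# every caption token may attend to every image patch.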
class GPT2CrossAttention(nn.Module):
    def __init__(self,config:SimpleNamespace):
        super().__init__()
self.embed_dim = config.embed_dim
self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dim must be divisible by num heads"
self.head_size = self.embed_dim // self.n_heads
self.seq_len = config.seq_len
self.q = nn.Linear(self.embed_dim,self.embed_dim)
self.k = nn.Linear(self.embed_dim,self.embed_dim)
self.v = nn.Linear(self.embed_dim,self.embed_dim)
self.scale = self.head_size ** -0.5
self.c_proj = nn.Linear(self.embed_dim,self.embed_dim)
self.attn_dropout = nn.Dropout(config.attention_dropout)
self.resid_dropout = nn.Dropout(config.residual_dropout)
self.apply(self._init_weights)
def _init_weights(self,module):
if isinstance(module,nn.Linear):
nn.init.normal_(module.weight,mean=0.0,std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor)->torch.Tensor:
b,t,c = q.shape
q,k,v = self.q(q),self.k(k),self.v(v)
q = q.view(b,q.size(1),self.n_heads,self.head_size).permute(0,2,1,3)
k = k.view(b,k.size(1),self.n_heads,self.head_size).permute(0,2,1,3)
v = v.view(b,v.size(1),self.n_heads,self.head_size).permute(0,2,1,3)
        qk_t = (q@k.transpose(-2,-1))*self.scale
qk_t = F.softmax(qk_t,dim=-1)
weights = self.attn_dropout(qk_t)
attention = weights@v
attention = attention.permute(0,2,1,3).contiguous().view(b,t,c)
out = self.c_proj(attention)
return self.resid_dropout(out)
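
# Position-wise feed-forward network (GELU MLP) used inside each decoder block.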
class GPT2MLP(nn.Module):
def __init__(self,config:SimpleNamespace):
super().__init__()
self.embed_dim = config.embed_dim
self.mlp_ratio = config.mlp_ratio
self.mlp_dropout = config.mlp_dropout
self.c_fc = nn.Linear(self.embed_dim,self.embed_dim*self.mlp_ratio)
self.c_proj = nn.Linear(self.embed_dim*self.mlp_ratio,self.embed_dim)
self.act = nn.GELU()
self.dropout = nn.Dropout(self.mlp_dropout)
def forward(self,x:torch.Tensor)->torch.Tensor:
x = self.c_fc(x)
x = self.act(x)
x = self.c_proj(x)
return self.dropout(x)
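
# GPT-2 decoder block extended with cross-attention: pre-LayerNorm self-attention,
# then cross-attention over the image features, then the MLP, each wrapped in a
# residual connection.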
class GPT2Block(nn.Module):
    def __init__(self,config:SimpleNamespace):
        super().__init__()
self.embed_dim = config.embed_dim
self.ln_1 = nn.LayerNorm(self.embed_dim)
self.attn = GPT2Attention(config)
self.ln_2 = nn.LayerNorm(self.embed_dim)
self.mlp = GPT2MLP(config)
self.ln_3 = nn.LayerNorm(self.embed_dim)
self.cross_attn = GPT2CrossAttention(config)
def forward(self,x:torch.Tensor,enc_out:torch.Tensor)->torch.Tensor:
x = x+self.attn(self.ln_1(x))
x = x+self.cross_attn(self.ln_2(x),enc_out,enc_out)
x = x+self.mlp(self.ln_3(x))
return x
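
# Image-captioning model: the patch embedding and the first `depth` blocks of a
# pretrained ViT (vit_base_patch16_224) encode the image, while a GPT-2-style
# decoder with cross-attention generates the caption. The token embedding and
# the LM head share weights, as in GPT-2.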
class VisionGPT2Model(nn.Module):
    def __init__(self,config:SimpleNamespace):
        super().__init__()
self.config = config
vit = create_model('vit_base_patch16_224',pretrained=True,num_classes=0)
self.patch_embed = vit.patch_embed
num_patches = self.patch_embed.num_patches
self.cls_token = vit.cls_token
embed_len = num_patches + vit.num_prefix_tokens
self.pos_embed = vit.pos_embed
self.blocks = nn.ModuleList([vit.blocks[i] for i in range(config.depth)])
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size,config.embed_dim),
wpe = nn.Embedding(config.seq_len,config.embed_dim),
drop = nn.Dropout(config.emb_dropout),
h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),
ln_f = nn.LayerNorm(config.embed_dim),
))
        self.lm_head = nn.Linear(config.embed_dim,config.vocab_size,bias=False)
        # tie the token-embedding and LM-head weights, as in GPT-2
        self.transformer.wte.weight = self.lm_head.weight
    def _pos_embed(self,x:torch.Tensor)->torch.Tensor:
        # prepend the ViT [CLS] token, then add the pretrained positional embedding
        x = torch.cat((self.cls_token.expand(x.shape[0],-1,-1),x),dim=1)
        x = x+self.pos_embed
        return x
    def pretrained_layers_trainable(self,trainable:bool=False)->None:
        # toggle requires_grad on every pretrained layer (ViT and GPT-2 weights)
        layers = [
            self.cls_token,self.patch_embed,self.pos_embed,self.blocks,
            self.transformer.wte,self.transformer.wpe,
            self.transformer.ln_f,self.lm_head
        ]
        gpt_layers = [[
            self.transformer.h[i].ln_1,self.transformer.h[i].ln_2,
            self.transformer.h[i].attn,self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]
        for l in gpt_layers:
            layers.extend(l)
        for layer in layers:
            if not isinstance(layer,nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = trainable
            else:
                layer.requires_grad = trainable
        total_frozen_params = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        print(f"{total_frozen_params =}")
    def unfreeze_gpt_layers(self)->None:
        # make the GPT-2 sub-layers trainable while the ViT stays frozen
        gpt_layers = [[
            self.transformer.h[i].ln_1,self.transformer.h[i].ln_2,
            self.transformer.h[i].attn,self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]
flatten = []
for l in gpt_layers:
flatten.extend(l)
for layer in flatten:
if not isinstance(layer,nn.Parameter):
for p in layer.parameters():
p.requires_grad = True
else:
layer.requires_grad = True
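
    # Load GPT-2 small weights from Hugging Face into the matching decoder
    # parameters; ViT- and cross-attention-specific parameters are skipped.
    # Hugging Face's GPT-2 uses Conv1D layers, so those weight matrices must be
    # transposed to fit our nn.Linear layers.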
    @classmethod
    def from_pretrained(cls,config:SimpleNamespace):
        model = cls(config)
sd = model.state_dict()
keys = sd.keys()
ignore_matches = ['blocks.','cross_attn.','ln_3','cls_token',
'pos_embed','patch_embed.','.attn.mask']
vit_keys = [key for key in keys if any(match in key for match in ignore_matches)]
gpt_keys = [key for key in keys if key not in vit_keys]
gpt2_small = GPT2LMHeadModel.from_pretrained('gpt2')
sd_hf = gpt2_small.state_dict()
hf_keys = sd_hf.keys()
hf_keys = [k for k in hf_keys if not k.endswith('.attn.masked_bias')]
hf_keys = [k for k in hf_keys if not k.endswith('.attn.bias')]
transposed = ['attn.c_attn.weight','attn.c_proj.weight',
'mlp.c_fc.weight','mlp.c_proj.weight']
for k in hf_keys:
if any(match in k for match in ignore_matches):
continue
if any(k.endswith(w) for w in transposed):
assert sd_hf[k].shape[::-1] == sd[k].shape
with torch.no_grad():
sd[k].copy_(sd_hf[k].t())
else:
assert sd_hf[k].shape == sd[k].shape
with torch.no_grad():
sd[k].copy_(sd_hf[k])
model.load_state_dict(sd)
return model
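
    # Runs the ViT encoder and the decoder in lock-step: at depth i, the image
    # passes through ViT block i while the text passes through decoder block i,
    # cross-attending to the current image features. Returns the loss when
    # labels are given, otherwise the logits for the last position only.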
    def forward(self,image:torch.Tensor,input_ids:torch.Tensor,labels:None|torch.Tensor=None)->torch.Tensor:
        image = self.patch_embed(image)
        image = self._pos_embed(image)
        token_embeddings = self.transformer.wte(input_ids)
        pos_embs = torch.arange(0,input_ids.size(1),device=input_ids.device)
        positional_embeddings = self.transformer.wpe(pos_embs)
        hidden = self.transformer.drop(token_embeddings+positional_embeddings)
        for i in range(self.config.depth):
            # advance the image through ViT block i and the text through decoder block i
            image = self.blocks[i](image)
            hidden = self.transformer.h[i](hidden,image)
        hidden = self.transformer.ln_f(hidden)
        if labels is not None:
            lm_logits = self.lm_head(hidden)
            loss = F.cross_entropy(lm_logits.view(-1,lm_logits.shape[-1]),labels.view(-1))
            return loss
        # inference: only the last position's logits are needed for the next token
        lm_logits = self.lm_head(hidden[:,[-1],:])
        return lm_logits
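
    # Autoregressive caption generation: repeatedly feed the sequence back in,
    # pick the next token (greedy if `deter`, otherwise temperature-scaled
    # sampling), and stop at EOS or after `max_tokens` steps.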
    @torch.no_grad()
    def generate(self,image:torch.Tensor,
                 sequence:torch.Tensor,
                 max_tokens:int=50,
                 temp:float=1.0,
                 deter:bool=False) -> torch.Tensor:
        for _ in range(max_tokens):
            out = self(image,sequence)
            out = out[:,-1,:]/temp
            probs = F.softmax(out,dim=-1)
            if deter:
                # greedy decoding
                next_token = torch.argmax(probs,dim=-1,keepdim=True)
            else:
                # temperature-scaled sampling
                next_token = torch.multinomial(probs,num_samples=1)
            sequence = torch.cat([sequence,next_token],dim=1)
            if next_token.item() == tokenizer.eos_token_id:  # assumes batch size 1
                break
        return sequence.cpu().flatten()
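
# A minimal usage sketch, not part of the original file. The config values
# below are assumptions chosen to match GPT-2 small and ViT-B/16 so that
# from_pretrained can copy the pretrained weights shape-for-shape.
if __name__ == '__main__':
    config = SimpleNamespace(
        vocab_size=50257,     # GPT-2 vocabulary size
        embed_dim=768,        # shared width of GPT-2 small and ViT-B/16
        num_heads=12,
        seq_len=1024,         # GPT-2 positional-embedding length
        depth=12,             # number of ViT and decoder blocks used
        attention_dropout=0.1,
        residual_dropout=0.1,
        mlp_ratio=4,
        mlp_dropout=0.1,
        emb_dropout=0.1,
    )
    model = VisionGPT2Model.from_pretrained(config).eval()
    image = torch.randn(1,3,224,224)                   # dummy image batch
    prompt = torch.tensor([[tokenizer.bos_token_id]])  # start from <|endoftext|>
    tokens = model.generate(image,prompt,max_tokens=10)
    print(tokenizer.decode(tokens))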