Recover
modeling_cocom.py  CHANGED  (+5 -42)
@@ -3,45 +3,11 @@ import torch
 import math
 from peft import get_peft_model, LoraConfig, TaskType
 import os
-from flash_attn.flash_attn_interface import flash_attn_func
-import torch.nn as nn
-import torch
 
 def freeze_model(model):
     for param in model.parameters():
         param.requires_grad = False
 
-class CustomFlashAttention(nn.Module):
-    def __init__(self, embed_dim, num_heads, dropout=0.0):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == embed_dim, "Embedding size must be divisible by the number of heads."
-
-        # Define projection layers
-        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
-
-    def forward(self, hidden_states):
-        batch_size, seq_len, embed_dim = hidden_states.size()
-        qkv = self.qkv_proj(hidden_states) # Project to Q, K, V
-        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
-        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, num_heads, seq_len, head_dim)
-        query, key, value = qkv[0], qkv[1], qkv[2]
-
-        # FlashAttention expects contiguous inputs
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-
-        # Apply FlashAttention
-        attn_output, _ = flash_attn_func(query, key, value, dropout_p=self.dropout, causal=False)
-
-        # Reshape and project back to the original dimension
-        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
-        return self.out_proj(attn_output)
 
 class BERT_Compressor(torch.nn.Module):
     def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
@@ -109,7 +75,7 @@ class COCOMConfig(PretrainedConfig):
                 device_map = "cuda",
                 **kwargs):
         super().__init__(**kwargs)
-
+
         self.decoder_model_name = decoder_model_name # model name of decoder
         self.quantization = quantization # quantization, could be no, int4, int8
         self.generation_top_k = generation_top_k # top k for each query, for pretraining, set to 1
@@ -226,12 +192,6 @@ class COCOM(PreTrainedModel):
         self.sep = cfg.sep
         self.compr_rate = cfg.compr_rate
         self.local_rank = os.getenv('LOCAL_RANK', '0')
-        for layer in self.decoder.encoder.layer:
-            layer.attention.self = CustomFlashAttention(
-                embed_dim=cfg.hidden_size,
-                num_heads=cfg.num_attention_heads,
-                dropout=cfg.attention_probs_dropout_prob,
-            )
 
     def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
         indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
@@ -348,4 +308,7 @@ class COCOM(PreTrainedModel):
             'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
         }
 
-        return self.generate(model_input, max_new_tokens)
+        return self.generate(model_input, max_new_tokens)
+
+
+