Tu2003716 committed on
Commit 3bfe4a8 · verified · 1 Parent(s): e203c18
Files changed (1)
  1. modeling_cocom.py +5 -42
modeling_cocom.py CHANGED
@@ -3,45 +3,11 @@ import torch
 import math
 from peft import get_peft_model, LoraConfig, TaskType
 import os
-from flash_attn.flash_attn_interface import flash_attn_func
-import torch.nn as nn
-import torch
 
 def freeze_model(model):
     for param in model.parameters():
         param.requires_grad = False
 
-class CustomFlashAttention(nn.Module):
-    def __init__(self, embed_dim, num_heads, dropout=0.0):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == embed_dim, "Embedding size must be divisible by the number of heads."
-
-        # Define projection layers
-        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
-
-    def forward(self, hidden_states):
-        batch_size, seq_len, embed_dim = hidden_states.size()
-        qkv = self.qkv_proj(hidden_states) # Project to Q, K, V
-        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
-        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, num_heads, seq_len, head_dim)
-        query, key, value = qkv[0], qkv[1], qkv[2]
-
-        # FlashAttention expects contiguous inputs
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-
-        # Apply FlashAttention
-        attn_output, _ = flash_attn_func(query, key, value, dropout_p=self.dropout, causal=False)
-
-        # Reshape and project back to the original dimension
-        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
-        return self.out_proj(attn_output)
 
 class BERT_Compressor(torch.nn.Module):
     def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
@@ -109,7 +75,7 @@ class COCOMConfig(PretrainedConfig):
                  device_map = "cuda",
                  **kwargs):
         super().__init__(**kwargs)
-
+
         self.decoder_model_name = decoder_model_name # model name of decoder
         self.quantization = quantization # quantization, could be no, int4, int8
         self.generation_top_k = generation_top_k # top k for each query, for pretraining, set to 1
@@ -226,12 +192,6 @@ class COCOM(PreTrainedModel):
         self.sep = cfg.sep
         self.compr_rate = cfg.compr_rate
         self.local_rank = os.getenv('LOCAL_RANK', '0')
-        for layer in self.decoder.encoder.layer:
-            layer.attention.self = CustomFlashAttention(
-                embed_dim=cfg.hidden_size,
-                num_heads=cfg.num_attention_heads,
-                dropout=cfg.attention_probs_dropout_prob,
-            )
 
     def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
         indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
@@ -348,4 +308,7 @@ class COCOM(PreTrainedModel):
             'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
         }
 
-        return self.generate(model_input, max_new_tokens)
+        return self.generate(model_input, max_new_tokens)
+
+
+
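
With this commit, the flash_attn import, the CustomFlashAttention wrapper, and the loop that patched layer.attention.self are gone, so the model relies on whatever attention implementation its Hugging Face classes already provide. For context, flash_attn_func in flash-attn 2.x expects q/k/v shaped (batch, seq_len, num_heads, head_dim) and by default returns a single tensor rather than a tuple, which the removed wrapper's (batch, num_heads, seq_len, head_dim) permutation and tuple unpacking did not match. If flash-style attention is still wanted without the external package, a minimal sketch along the lines below could reuse the same qkv_proj/out_proj structure on top of PyTorch's built-in scaled_dot_product_attention; the class name and its wiring are illustrative assumptions, not part of this repository.

# Minimal sketch, assuming PyTorch >= 2.0 -- NOT part of modeling_cocom.py.
# SDPASelfAttention is a hypothetical name; F.scaled_dot_product_attention
# dispatches to a fused (FlashAttention-style) kernel when one is available,
# so no flash-attn import is required.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPASelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states):
        bsz, seq_len, embed_dim = hidden_states.size()
        # Project to Q, K, V and split heads: (batch, num_heads, seq_len, head_dim)
        qkv = self.qkv_proj(hidden_states).view(bsz, seq_len, 3, self.num_heads, self.head_dim)
        query, key, value = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Fused attention; dropout is only applied in training mode
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )
        # Merge heads and project back to the model dimension
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, embed_dim)
        return self.out_proj(attn_output)

Unlike the removed code, this stays inside core PyTorch, so the module imports cleanly on machines where flash-attn is not installed.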
 
 
 
 