Linformer intergration

#6
by Plasmarine - opened
Files changed (1) hide show
  1. modeling_cocom.py +62 -220
modeling_cocom.py CHANGED
@@ -1,32 +1,34 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
 
2
  import torch
3
  import math
4
  from peft import get_peft_model, LoraConfig, TaskType
5
  import os
6
 
 
 
7
  def freeze_model(model):
8
  for param in model.parameters():
9
  param.requires_grad = False
10
 
11
 
 
12
  class BERT_Compressor(torch.nn.Module):
13
  def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
14
  super().__init__()
15
- # init model
16
- self.model_name = compr_model_name # base model name of BERT; example: bert-base-ucased
17
  self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
18
  self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
19
- self.compr_rate = compr_rate # compression rate
20
- self.compressing_mode = compr_linear_type # linear layer type, could be either concat or mean.
21
 
22
- if self.compressing_mode == 'concat': # default setting in paper
23
  self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
24
  elif self.compressing_mode == 'mean':
25
  self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
26
  self.linear = self.linear.float16()
27
 
28
  def forward(self, input_ids, attention_mask):
29
- # compressing context using BERT
30
  segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
31
  num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
32
  all_hidden_states_emb = list()
@@ -42,23 +44,18 @@ class BERT_Compressor(torch.nn.Module):
42
  start_idx = segment_idx * self.compr_rate
43
  end_idx = (segment_idx + 1) * self.compr_rate
44
  hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
45
- # Apply mean pooling to get the final embedding for the segment
46
  all_hidden_states_emb.append(hidden_state)
47
- else:
48
- raise NotImplementedError()
49
-
50
  all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
51
  transformed_embeds = self.linear(all_hidden_states_emb_cat)
52
 
53
-
54
  if self.compressing_mode == "mean":
55
  transformed_embeds = torch.mean(transformed_embeds, dim=2)
56
 
57
- # dimention of transformed_embeds: (batch_size*generation_top_k, num_embs, decoder_hidden_size)
58
  return transformed_embeds
59
 
60
- class COCOMConfig(PretrainedConfig):
61
 
 
 
62
  model_type = "COCOM"
63
  def __init__(self,
64
  decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
@@ -71,189 +68,78 @@ class COCOMConfig(PretrainedConfig):
71
  lora = False,
72
  training_form="both",
73
  lora_r=16,
74
- attn_implementation="eager",
75
  device_map = "cuda",
76
  **kwargs):
77
  super().__init__(**kwargs)
78
-
79
- self.decoder_model_name = decoder_model_name # model name of decoder
80
- self.quantization = quantization # quantization, could be no, int4, int8
81
- self.generation_top_k = generation_top_k # top k for each query, for pretraining, set to 1
82
- self.sep = sep # boolean type, whether to use sep token
83
- self.compr_model_name = compr_model_name # model name of compressor
84
- self.compr_rate = compr_rate # compression rate
85
- self.compr_linear_type = compr_linear_type # linear layer type, could be either concat or mean
86
- self.lora = lora # boolean type, whether to use lora trsining
87
- self.training_form = training_form # training form, could be compressor: training only comprssor; both:
88
- self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
89
  self.attn_implementation = attn_implementation
90
  self.device_map = device_map
91
 
 
 
92
  class COCOM(PreTrainedModel):
93
  config_class = COCOMConfig
94
  def __init__(self, cfg):
95
  super().__init__(cfg)
96
- # define models
97
  attn_impl = cfg.attn_implementation
98
- # model could be loaded in three quantization modes: no, int4, int8
99
- if cfg.quantization == "no":
100
- self.decoder = AutoModelForCausalLM.from_pretrained(
101
- cfg.decoder_model_name,
102
- torch_dtype=torch.float16,
103
- attn_implementation=attn_impl,
104
- low_cpu_mem_usage = True,
105
- device_map =cfg.device_map
106
- )
107
- elif cfg.quantization == "int4":
108
- quant_config = BitsAndBytesConfig(
109
- load_in_4bit=True,
110
- bnb_4bit_quant_type='nf4',
111
- bnb_4bit_compute_dtype='float16',
112
- low_cpu_mem_usage = True,
113
- )
114
- self.decoder = AutoModelForCausalLM.from_pretrained(
115
- cfg.decoder_model_name,
116
- quantization_config=quant_config,
117
- attn_implementation=attn_impl,
118
- torch_dtype=torch.float16,
119
- resume_download=True,
120
- low_cpu_mem_usage = True,
121
- trust_remote_code=True,
122
- device_map =cfg.device_map
123
- )
124
- elif cfg.quantization == "int8":
125
- quant_config = BitsAndBytesConfig(
126
- load_in_8bit=True,
127
- llm_int8_enable_fp32_cpu_offload=True,
128
- bnb_4bit_compute_dtype='float16',
129
- low_cpu_mem_usage = True,
130
- )
131
- self.decoder = AutoModelForCausalLM.from_pretrained(
132
- cfg.decoder_model_name,
133
- quantization_config=quant_config,
134
- attn_implementation=attn_impl,
135
- torch_dtype=torch.float16,
136
- resume_download=True,
137
- low_cpu_mem_usage = True,
138
- trust_remote_code=True,
139
- device_map =cfg.device_map
140
- )
141
- else:
142
- raise NotImplementedError()
143
 
144
- # when compr_model_name is not set, then means using a decoder-based compressor, otherwise a bert based compressor
145
- if cfg.compr_model_name is not None:
146
- # case bert based compressor
147
- self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
148
- else:
149
- # case decoder based compressor
150
- self.compr = None
151
 
152
- # set lora adaptors
 
153
  if cfg.lora:
154
- peft_config = LoraConfig(
155
- task_type="CAUSAL_LM",
156
- r=cfg.lora_r,
157
- lora_alpha=2* cfg.lora_r,
158
- target_modules='all-linear',
159
- lora_dropout=0.1,
160
- )
161
- self.decoder = get_peft_model(self.decoder, peft_config)
162
- self.decoder.print_trainable_parameters()
163
-
164
- # for training_form=compressor, then freeze the decoder for BERT-based
165
- self.training_form = cfg.training_form
166
- if self.training_form == "compressor" and self.compr is not None:
167
- freeze_model(self.decoder)
168
 
169
  self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
170
-
171
- # define special tokens
172
  self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
173
- self.decoder_tokenizer.mem_token = '<MEM>' # Memory token
174
- self.decoder_tokenizer.ae_token = '<AE>' # token for autoencoding on decoder side
175
- self.decoder_tokenizer.enc_token = '<ENC>' # token for autoencoding on compressor side
176
- self.decoder_tokenizer.sep_token = '<SEP>' # sep token between document
177
-
178
- self.decoder_tokenizer.mem_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<MEM>')
179
- self.decoder_tokenizer.ae_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<AE>')
180
- self.decoder_tokenizer.sep_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<SEP>')
181
- # if pad token ecist then use pad token, othrwise bos token
182
- if self.decoder_tokenizer.pad_token_id is None:
183
- self.decoder_tokenizer.pad_token_id = self.decoder_tokenizer.bos_token_id
184
-
185
- # resize the tokenizer embedding
186
- self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
187
- self.decoder.generation_config.top_p=None
188
- self.decoder.generation_config.temperature=None
189
- self.compr_model_name = cfg.compr_model_name
190
- # other settings
191
- self.generation_top_k = cfg.generation_top_k
192
- self.sep = cfg.sep
193
- self.compr_rate = cfg.compr_rate
194
- self.local_rank = os.getenv('LOCAL_RANK', '0')
195
-
196
- def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
197
- indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
198
- if self.compr:
199
- compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
200
- input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
201
- else:
202
- compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
203
- input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
204
- return input_embeds
205
-
206
- def compr_decoder(self, input_ids, attention_mask):
207
- emb = self.decoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
208
- mask = input_ids == self.decoder_tokenizer.mem_token_id
209
- return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
210
-
211
-
212
- def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
213
- # Embed the decoder input
214
- inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
215
- num_embs = compressed_embs.size(1)
216
- if self.sep:
217
- slot_len = num_embs + 1
218
- else:
219
- slot_len = num_embs
220
- # get first mem_token inidices
221
- first_mem_token_indices = torch.argmax((dec_input_ids == self.decoder_tokenizer.mem_token_id).int(), dim=1)
222
- batch_size = inputs_embeds.size(0)
223
- # for each example in batch, replace them with compressed embeddings
224
- for i in range(batch_size):
225
- for j in range(indices[i], indices[i + 1]):
226
- start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
227
- inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
228
- return inputs_embeds
229
-
230
-
231
- def forward(self,
232
- enc_input_ids: torch.LongTensor = None,
233
- enc_attention_mask: torch.LongTensor = None,
234
- dec_input_ids: torch.LongTensor = None,
235
- dec_attention_mask: torch.LongTensor = None,
236
- labels: torch.LongTensor = None):
237
 
238
- # enc_input_ids: stores the contexts, should be flattened from all queries before input, dimention (batch_size*generation_top_k, token_length)
239
- # enc_attention_mask: attention mask of enc_input_ids
240
- # dec_input_ids: stores the prompts (including mem tokens), dimention (batch_size, token_length)
241
- # dec_attention_mask: attention mask of dec_input_ids
 
 
 
 
 
 
242
 
243
- # Perform compression with gradient tracking
 
 
 
 
 
 
 
 
 
 
 
244
  inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
245
-
246
- # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
247
- if (self.training_form == "compressor") and (self.compr is None):
248
- inputs_embeds = inputs_embeds.detach()
249
-
250
- # decoding
251
  decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
252
-
253
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
254
 
255
-
256
-
257
  def generate(self, model_input, max_new_tokens=128):
258
  device = self.decoder.device
259
  enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
@@ -263,52 +149,8 @@ class COCOM(PreTrainedModel):
263
  attention_mask=dec_attention_mask.to(device),
264
  do_sample=False,
265
  top_p=None,
266
- max_new_tokens=max_new_tokens
267
- )
268
  decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
269
  return decoded
270
-
271
- def generate_from_text(self, contexts, questions, max_new_tokens=128):
272
- # for each question in list give input a list of contexts of equal length
273
- # first make sure that every list in contexts are having the same length
274
- assert len(contexts) == len(questions)
275
- assert all([len(context) == len(contexts[0]) for context in contexts])
276
-
277
- # prepare inp_enc for compression
278
- # first flatten the contexts
279
- self.generation_top_k = len(contexts[0])
280
- flat_contexts = sum(contexts, [])
281
- #tokenize the contexts, depending if compr exist or not
282
- if self.compr is not None:
283
- enc_input = self.compr.tokenizer(flat_contexts, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=self.compr_rate)
284
- num_mem_tokens = math.ceil(enc_input['input_ids'].size(1) / self.compr_rate)
285
- else:
286
- # first need to add special token in flat_contexts
287
- flat_contexts = [self.decoder_tokenizer.enc_token + self.decoder_tokenizer.bos_token + context + self.decoder_tokenizer.bos_token for context in flat_contexts]
288
- enc_input = self.decoder_tokenizer(flat_contexts, truncation=True, return_tensors='pt', padding="longest")
289
- num_mem_tokens = math.ceil((enc_input['input_ids'].size(1)-3) / self.compr_rate)
290
- mem_tokens = torch.full((enc_input['input_ids'].size(0), num_mem_tokens), self.decoder_tokenizer.mem_token_id, dtype=torch.long)
291
- enc_input['input_ids'] = torch.cat([mem_tokens, enc_input['input_ids']], dim=1)
292
- enc_input['attention_mask'] = torch.cat([torch.ones_like(mem_tokens), enc_input['attention_mask']], dim=1)
293
-
294
-
295
- # prepare inp_dec
296
- mem_tokens = self.decoder_tokenizer.mem_token * num_mem_tokens
297
- if self.sep:
298
- mem_tokens += self.decoder_tokenizer.sep_token
299
-
300
- instr = [self.decoder_tokenizer.bos_token + mem_tokens* self.generation_top_k + '[INST]' + question + '\n[/INST]\n' for question in questions]
301
- inp_dec = self.decoder_tokenizer(instr, truncation=True, return_tensors='pt', padding="longest")
302
-
303
- # generate
304
- model_input = {
305
- 'enc_input_ids': enc_input['input_ids'].to(self.decoder.device),
306
- 'enc_attention_mask': enc_input['attention_mask'].to(self.decoder.device),
307
- 'dec_input_ids': inp_dec['input_ids'].to(self.decoder.device),
308
- 'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
309
- }
310
-
311
- return self.generate(model_input, max_new_tokens)
312
-
313
 
314
-
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel, LongformerForCausalLM, LongformerTokenizer
2
+ from linformer.attention import LinformerSelfAttention
3
  import torch
4
  import math
5
  from peft import get_peft_model, LoraConfig, TaskType
6
  import os
7
 
8
+
9
+ # Freeze model function (unchanged)
10
  def freeze_model(model):
11
  for param in model.parameters():
12
  param.requires_grad = False
13
 
14
 
15
+ # BERT_Compressor remains the same as you are not modifying it for Linformer
16
  class BERT_Compressor(torch.nn.Module):
17
  def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
18
  super().__init__()
19
+ self.model_name = compr_model_name
 
20
  self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
21
  self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
22
+ self.compr_rate = compr_rate
23
+ self.compressing_mode = compr_linear_type
24
 
25
+ if self.compressing_mode == 'concat':
26
  self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
27
  elif self.compressing_mode == 'mean':
28
  self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
29
  self.linear = self.linear.float16()
30
 
31
  def forward(self, input_ids, attention_mask):
 
32
  segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
33
  num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
34
  all_hidden_states_emb = list()
 
44
  start_idx = segment_idx * self.compr_rate
45
  end_idx = (segment_idx + 1) * self.compr_rate
46
  hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
 
47
  all_hidden_states_emb.append(hidden_state)
 
 
 
48
  all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
49
  transformed_embeds = self.linear(all_hidden_states_emb_cat)
50
 
 
51
  if self.compressing_mode == "mean":
52
  transformed_embeds = torch.mean(transformed_embeds, dim=2)
53
 
 
54
  return transformed_embeds
55
 
 
56
 
57
+ # Modify COCOMConfig to support Linformer
58
+ class COCOMConfig(PretrainedConfig):
59
  model_type = "COCOM"
60
  def __init__(self,
61
  decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
 
68
  lora = False,
69
  training_form="both",
70
  lora_r=16,
71
+ attn_implementation="linformer", # Change default to Linformer
72
  device_map = "cuda",
73
  **kwargs):
74
  super().__init__(**kwargs)
75
+ self.decoder_model_name = decoder_model_name
76
+ self.quantization = quantization
77
+ self.generation_top_k = generation_top_k
78
+ self.sep = sep
79
+ self.compr_model_name = compr_model_name
80
+ self.compr_rate = compr_rate
81
+ self.compr_linear_type = compr_linear_type
82
+ self.lora = lora
83
+ self.training_form = training_form
84
+ self.lora_r = lora_r
 
85
  self.attn_implementation = attn_implementation
86
  self.device_map = device_map
87
 
88
+
89
+ # Modify COCOM model to use Linformer in the attention layer
90
  class COCOM(PreTrainedModel):
91
  config_class = COCOMConfig
92
  def __init__(self, cfg):
93
  super().__init__(cfg)
 
94
  attn_impl = cfg.attn_implementation
95
+
96
+ # Load the model (decoder) in standard quantization or Linformer
97
+ self.decoder = AutoModelForCausalLM.from_pretrained(
98
+ cfg.decoder_model_name,
99
+ torch_dtype=torch.float16,
100
+ low_cpu_mem_usage=True,
101
+ device_map=cfg.device_map
102
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ # Replace decoder's attention mechanism with LinformerSelfAttention if configured
105
+ if attn_impl == 'linformer':
106
+ self._replace_attention_with_linformer()
 
 
 
 
107
 
108
+ # Initialize other parts of the model (compression, LoRA, etc.)
109
+ self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
110
  if cfg.lora:
111
+ self._apply_lora(cfg.lora_r)
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
 
 
114
  self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ def _replace_attention_with_linformer(self):
117
+ # Replace all attention layers with LinformerSelfAttention in the model
118
+ for layer in self.decoder.transformer.h:
119
+ layer.attn = LinformerSelfAttention(
120
+ dim=layer.attn.attn.in_proj_weight.shape[0],
121
+ num_heads=layer.attn.num_attention_heads,
122
+ dropout=0.1,
123
+ n_heads=layer.attn.num_attention_heads,
124
+ d_head=layer.attn.attn.in_proj_weight.shape[0] // layer.attn.num_attention_heads
125
+ )
126
 
127
+ def _apply_lora(self, lora_r):
128
+ # Apply LoRA as per your configuration
129
+ peft_config = LoraConfig(
130
+ task_type="CAUSAL_LM",
131
+ r=lora_r,
132
+ lora_alpha=2 * lora_r,
133
+ target_modules='all-linear',
134
+ lora_dropout=0.1,
135
+ )
136
+ self.decoder = get_peft_model(self.decoder, peft_config)
137
+
138
+ def forward(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask, labels):
139
  inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
 
 
 
 
 
 
140
  decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
 
141
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
142
 
 
 
143
  def generate(self, model_input, max_new_tokens=128):
144
  device = self.decoder.device
145
  enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
 
149
  attention_mask=dec_attention_mask.to(device),
150
  do_sample=False,
151
  top_p=None,
152
+ max_new_tokens=min(max_new_tokens, 4096)
153
+ )
154
  decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
155
  return decoded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156