Files changed (1) hide show
  1. modeling_cocom.py +219 -61
modeling_cocom.py CHANGED
@@ -1,34 +1,32 @@
1
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
2
- from linformer.attention import LinformerSelfAttention
3
  import torch
4
  import math
5
  from peft import get_peft_model, LoraConfig, TaskType
6
  import os
7
 
8
-
9
- # Freeze model function (unchanged)
10
  def freeze_model(model):
11
  for param in model.parameters():
12
  param.requires_grad = False
13
 
14
 
15
- # BERT_Compressor remains the same as you are not modifying it for Linformer
16
  class BERT_Compressor(torch.nn.Module):
17
  def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
18
  super().__init__()
19
- self.model_name = compr_model_name
 
20
  self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
21
  self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
22
- self.compr_rate = compr_rate
23
- self.compressing_mode = compr_linear_type
24
 
25
- if self.compressing_mode == 'concat':
26
  self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
27
  elif self.compressing_mode == 'mean':
28
  self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
29
  self.linear = self.linear.float16()
30
 
31
  def forward(self, input_ids, attention_mask):
 
32
  segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
33
  num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
34
  all_hidden_states_emb = list()
@@ -44,18 +42,23 @@ class BERT_Compressor(torch.nn.Module):
44
  start_idx = segment_idx * self.compr_rate
45
  end_idx = (segment_idx + 1) * self.compr_rate
46
  hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
 
47
  all_hidden_states_emb.append(hidden_state)
 
 
 
48
  all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
49
  transformed_embeds = self.linear(all_hidden_states_emb_cat)
50
 
 
51
  if self.compressing_mode == "mean":
52
  transformed_embeds = torch.mean(transformed_embeds, dim=2)
53
 
 
54
  return transformed_embeds
55
 
56
-
57
- # Modify COCOMConfig to support Linformer
58
  class COCOMConfig(PretrainedConfig):
 
59
  model_type = "COCOM"
60
  def __init__(self,
61
  decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
@@ -68,78 +71,189 @@ class COCOMConfig(PretrainedConfig):
68
  lora = False,
69
  training_form="both",
70
  lora_r=16,
71
- attn_implementation="linformer", # Change default to Linformer
72
  device_map = "cuda",
73
  **kwargs):
74
  super().__init__(**kwargs)
75
- self.decoder_model_name = decoder_model_name
76
- self.quantization = quantization
77
- self.generation_top_k = generation_top_k
78
- self.sep = sep
79
- self.compr_model_name = compr_model_name
80
- self.compr_rate = compr_rate
81
- self.compr_linear_type = compr_linear_type
82
- self.lora = lora
83
- self.training_form = training_form
84
- self.lora_r = lora_r
 
85
  self.attn_implementation = attn_implementation
86
  self.device_map = device_map
87
 
88
-
89
- # Modify COCOM model to use Linformer in the attention layer
90
  class COCOM(PreTrainedModel):
91
  config_class = COCOMConfig
92
  def __init__(self, cfg):
93
  super().__init__(cfg)
 
94
  attn_impl = cfg.attn_implementation
95
-
96
- # Load the model (decoder) in standard quantization or Linformer
97
- self.decoder = AutoModelForCausalLM.from_pretrained(
98
- cfg.decoder_model_name,
99
- torch_dtype=torch.float16,
100
- low_cpu_mem_usage=True,
101
- device_map=cfg.device_map
102
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- # Replace decoder's attention mechanism with LinformerSelfAttention if configured
105
- if attn_impl == 'linformer':
106
- self._replace_attention_with_linformer()
 
 
 
 
107
 
108
- # Initialize other parts of the model (compression, LoRA, etc.)
109
- self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
110
  if cfg.lora:
111
- self._apply_lora(cfg.lora_r)
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
 
 
114
  self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- def _replace_attention_with_linformer(self):
117
- # Replace all attention layers with LinformerSelfAttention in the model
118
- for layer in self.decoder.transformer.h:
119
- layer.attn = LinformerSelfAttention(
120
- dim=layer.attn.attn.in_proj_weight.shape[0],
121
- num_heads=layer.attn.num_attention_heads,
122
- dropout=0.1,
123
- n_heads=layer.attn.num_attention_heads,
124
- d_head=layer.attn.attn.in_proj_weight.shape[0] // layer.attn.num_attention_heads
125
- )
126
 
127
- def _apply_lora(self, lora_r):
128
- # Apply LoRA as per your configuration
129
- peft_config = LoraConfig(
130
- task_type="CAUSAL_LM",
131
- r=lora_r,
132
- lora_alpha=2 * lora_r,
133
- target_modules='all-linear',
134
- lora_dropout=0.1,
135
- )
136
- self.decoder = get_peft_model(self.decoder, peft_config)
137
-
138
- def forward(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask, labels):
139
  inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
 
 
 
 
 
 
140
  decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
 
141
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
142
 
 
 
143
  def generate(self, model_input, max_new_tokens=128):
144
  device = self.decoder.device
145
  enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
@@ -149,8 +263,52 @@ class COCOM(PreTrainedModel):
149
  attention_mask=dec_attention_mask.to(device),
150
  do_sample=False,
151
  top_p=None,
152
- max_new_tokens=min(max_new_tokens, 4096)
153
- )
154
  decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
155
  return decoded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
 
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
 
2
  import torch
3
  import math
4
  from peft import get_peft_model, LoraConfig, TaskType
5
  import os
6
 
 
 
7
  def freeze_model(model):
8
  for param in model.parameters():
9
  param.requires_grad = False
10
 
11
 
 
12
  class BERT_Compressor(torch.nn.Module):
13
  def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
14
  super().__init__()
15
+ # init model
16
+ self.model_name = compr_model_name # base model name of BERT; example: bert-base-ucased
17
  self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
18
  self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
19
+ self.compr_rate = compr_rate # compression rate
20
+ self.compressing_mode = compr_linear_type # linear layer type, could be either concat or mean.
21
 
22
+ if self.compressing_mode == 'concat': # default setting in paper
23
  self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
24
  elif self.compressing_mode == 'mean':
25
  self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
26
  self.linear = self.linear.float16()
27
 
28
  def forward(self, input_ids, attention_mask):
29
+ # compressing context using BERT
30
  segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
31
  num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
32
  all_hidden_states_emb = list()
 
42
  start_idx = segment_idx * self.compr_rate
43
  end_idx = (segment_idx + 1) * self.compr_rate
44
  hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
45
+ # Apply mean pooling to get the final embedding for the segment
46
  all_hidden_states_emb.append(hidden_state)
47
+ else:
48
+ raise NotImplementedError()
49
+
50
  all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
51
  transformed_embeds = self.linear(all_hidden_states_emb_cat)
52
 
53
+
54
  if self.compressing_mode == "mean":
55
  transformed_embeds = torch.mean(transformed_embeds, dim=2)
56
 
57
+ # dimention of transformed_embeds: (batch_size*generation_top_k, num_embs, decoder_hidden_size)
58
  return transformed_embeds
59
 
 
 
60
  class COCOMConfig(PretrainedConfig):
61
+
62
  model_type = "COCOM"
63
  def __init__(self,
64
  decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
 
71
  lora = False,
72
  training_form="both",
73
  lora_r=16,
74
+ attn_implementation="eager",
75
  device_map = "cuda",
76
  **kwargs):
77
  super().__init__(**kwargs)
78
+
79
+ self.decoder_model_name = decoder_model_name # model name of decoder
80
+ self.quantization = quantization # quantization, could be no, int4, int8
81
+ self.generation_top_k = generation_top_k # top k for each query, for pretraining, set to 1
82
+ self.sep = sep # boolean type, whether to use sep token
83
+ self.compr_model_name = compr_model_name # model name of compressor
84
+ self.compr_rate = compr_rate # compression rate
85
+ self.compr_linear_type = compr_linear_type # linear layer type, could be either concat or mean
86
+ self.lora = lora # boolean type, whether to use lora trsining
87
+ self.training_form = training_form # training form, could be compressor: training only comprssor; both:
88
+ self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
89
  self.attn_implementation = attn_implementation
90
  self.device_map = device_map
91
 
 
 
92
  class COCOM(PreTrainedModel):
93
  config_class = COCOMConfig
94
  def __init__(self, cfg):
95
  super().__init__(cfg)
96
+ # define models
97
  attn_impl = cfg.attn_implementation
98
+ # model could be loaded in three quantization modes: no, int4, int8
99
+ if cfg.quantization == "no":
100
+ self.decoder = AutoModelForCausalLM.from_pretrained(
101
+ cfg.decoder_model_name,
102
+ torch_dtype=torch.float16,
103
+ attn_implementation=attn_impl,
104
+ low_cpu_mem_usage = True,
105
+ device_map =cfg.device_map
106
+ )
107
+ elif cfg.quantization == "int4":
108
+ quant_config = BitsAndBytesConfig(
109
+ load_in_4bit=True,
110
+ bnb_4bit_quant_type='nf4',
111
+ bnb_4bit_compute_dtype='float16',
112
+ low_cpu_mem_usage = True,
113
+ )
114
+ self.decoder = AutoModelForCausalLM.from_pretrained(
115
+ cfg.decoder_model_name,
116
+ quantization_config=quant_config,
117
+ attn_implementation=attn_impl,
118
+ torch_dtype=torch.float16,
119
+ resume_download=True,
120
+ low_cpu_mem_usage = True,
121
+ trust_remote_code=True,
122
+ device_map =cfg.device_map
123
+ )
124
+ elif cfg.quantization == "int8":
125
+ quant_config = BitsAndBytesConfig(
126
+ load_in_8bit=True,
127
+ llm_int8_enable_fp32_cpu_offload=True,
128
+ bnb_4bit_compute_dtype='float16',
129
+ low_cpu_mem_usage = True,
130
+ )
131
+ self.decoder = AutoModelForCausalLM.from_pretrained(
132
+ cfg.decoder_model_name,
133
+ quantization_config=quant_config,
134
+ attn_implementation=attn_impl,
135
+ torch_dtype=torch.float16,
136
+ resume_download=True,
137
+ low_cpu_mem_usage = True,
138
+ trust_remote_code=True,
139
+ device_map =cfg.device_map
140
+ )
141
+ else:
142
+ raise NotImplementedError()
143
 
144
+ # when compr_model_name is not set, then means using a decoder-based compressor, otherwise a bert based compressor
145
+ if cfg.compr_model_name is not None:
146
+ # case bert based compressor
147
+ self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
148
+ else:
149
+ # case decoder based compressor
150
+ self.compr = None
151
 
152
+ # set lora adaptors
 
153
  if cfg.lora:
154
+ peft_config = LoraConfig(
155
+ task_type="CAUSAL_LM",
156
+ r=cfg.lora_r,
157
+ lora_alpha=2* cfg.lora_r,
158
+ target_modules='all-linear',
159
+ lora_dropout=0.1,
160
+ )
161
+ self.decoder = get_peft_model(self.decoder, peft_config)
162
+ self.decoder.print_trainable_parameters()
163
+
164
+ # for training_form=compressor, then freeze the decoder for BERT-based
165
+ self.training_form = cfg.training_form
166
+ if self.training_form == "compressor" and self.compr is not None:
167
+ freeze_model(self.decoder)
168
 
169
  self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
170
+
171
+ # define special tokens
172
  self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
173
+ self.decoder_tokenizer.mem_token = '<MEM>' # Memory token
174
+ self.decoder_tokenizer.ae_token = '<AE>' # token for autoencoding on decoder side
175
+ self.decoder_tokenizer.enc_token = '<ENC>' # token for autoencoding on compressor side
176
+ self.decoder_tokenizer.sep_token = '<SEP>' # sep token between document
177
+
178
+ self.decoder_tokenizer.mem_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<MEM>')
179
+ self.decoder_tokenizer.ae_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<AE>')
180
+ self.decoder_tokenizer.sep_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<SEP>')
181
+ # if pad token ecist then use pad token, othrwise bos token
182
+ if self.decoder_tokenizer.pad_token_id is None:
183
+ self.decoder_tokenizer.pad_token_id = self.decoder_tokenizer.bos_token_id
184
+
185
+ # resize the tokenizer embedding
186
+ self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
187
+ self.decoder.generation_config.top_p=None
188
+ self.decoder.generation_config.temperature=None
189
+ self.compr_model_name = cfg.compr_model_name
190
+ # other settings
191
+ self.generation_top_k = cfg.generation_top_k
192
+ self.sep = cfg.sep
193
+ self.compr_rate = cfg.compr_rate
194
+ self.local_rank = os.getenv('LOCAL_RANK', '0')
195
+
196
+ def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
197
+ indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
198
+ if self.compr:
199
+ compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
200
+ input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
201
+ else:
202
+ compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
203
+ input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
204
+ return input_embeds
205
+
206
+ def compr_decoder(self, input_ids, attention_mask):
207
+ emb = self.decoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
208
+ mask = input_ids == self.decoder_tokenizer.mem_token_id
209
+ return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
210
+
211
+
212
+ def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
213
+ # Embed the decoder input
214
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
215
+ num_embs = compressed_embs.size(1)
216
+ if self.sep:
217
+ slot_len = num_embs + 1
218
+ else:
219
+ slot_len = num_embs
220
+ # get first mem_token inidices
221
+ first_mem_token_indices = torch.argmax((dec_input_ids == self.decoder_tokenizer.mem_token_id).int(), dim=1)
222
+ batch_size = inputs_embeds.size(0)
223
+ # for each example in batch, replace them with compressed embeddings
224
+ for i in range(batch_size):
225
+ for j in range(indices[i], indices[i + 1]):
226
+ start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
227
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
228
+ return inputs_embeds
229
+
230
+
231
+ def forward(self,
232
+ enc_input_ids: torch.LongTensor = None,
233
+ enc_attention_mask: torch.LongTensor = None,
234
+ dec_input_ids: torch.LongTensor = None,
235
+ dec_attention_mask: torch.LongTensor = None,
236
+ labels: torch.LongTensor = None):
237
 
238
+ # enc_input_ids: stores the contexts, should be flattened from all queries before input, dimention (batch_size*generation_top_k, token_length)
239
+ # enc_attention_mask: attention mask of enc_input_ids
240
+ # dec_input_ids: stores the prompts (including mem tokens), dimention (batch_size, token_length)
241
+ # dec_attention_mask: attention mask of dec_input_ids
 
 
 
 
 
 
242
 
243
+ # Perform compression with gradient tracking
 
 
 
 
 
 
 
 
 
 
 
244
  inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
245
+
246
+ # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
247
+ if (self.training_form == "compressor") and (self.compr is None):
248
+ inputs_embeds = inputs_embeds.detach()
249
+
250
+ # decoding
251
  decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
252
+
253
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
254
 
255
+
256
+
257
  def generate(self, model_input, max_new_tokens=128):
258
  device = self.decoder.device
259
  enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
 
263
  attention_mask=dec_attention_mask.to(device),
264
  do_sample=False,
265
  top_p=None,
266
+ max_new_tokens=max_new_tokens
267
+ )
268
  decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
269
  return decoded
270
+
271
+ def generate_from_text(self, contexts, questions, max_new_tokens=128):
272
+ # for each question in list give input a list of contexts of equal length
273
+ # first make sure that every list in contexts are having the same length
274
+ assert len(contexts) == len(questions)
275
+ assert all([len(context) == len(contexts[0]) for context in contexts])
276
+
277
+ # prepare inp_enc for compression
278
+ # first flatten the contexts
279
+ self.generation_top_k = len(contexts[0])
280
+ flat_contexts = sum(contexts, [])
281
+ #tokenize the contexts, depending if compr exist or not
282
+ if self.compr is not None:
283
+ enc_input = self.compr.tokenizer(flat_contexts, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=self.compr_rate)
284
+ num_mem_tokens = math.ceil(enc_input['input_ids'].size(1) / self.compr_rate)
285
+ else:
286
+ # first need to add special token in flat_contexts
287
+ flat_contexts = [self.decoder_tokenizer.enc_token + self.decoder_tokenizer.bos_token + context + self.decoder_tokenizer.bos_token for context in flat_contexts]
288
+ enc_input = self.decoder_tokenizer(flat_contexts, truncation=True, return_tensors='pt', padding="longest")
289
+ num_mem_tokens = math.ceil((enc_input['input_ids'].size(1)-3) / self.compr_rate)
290
+ mem_tokens = torch.full((enc_input['input_ids'].size(0), num_mem_tokens), self.decoder_tokenizer.mem_token_id, dtype=torch.long)
291
+ enc_input['input_ids'] = torch.cat([mem_tokens, enc_input['input_ids']], dim=1)
292
+ enc_input['attention_mask'] = torch.cat([torch.ones_like(mem_tokens), enc_input['attention_mask']], dim=1)
293
+
294
+
295
+ # prepare inp_dec
296
+ mem_tokens = self.decoder_tokenizer.mem_token * num_mem_tokens
297
+ if self.sep:
298
+ mem_tokens += self.decoder_tokenizer.sep_token
299
+
300
+ instr = [self.decoder_tokenizer.bos_token + mem_tokens* self.generation_top_k + '[INST]' + question + '\n[/INST]\n' for question in questions]
301
+ inp_dec = self.decoder_tokenizer(instr, truncation=True, return_tensors='pt', padding="longest")
302
+
303
+ # generate
304
+ model_input = {
305
+ 'enc_input_ids': enc_input['input_ids'].to(self.decoder.device),
306
+ 'enc_attention_mask': enc_input['attention_mask'].to(self.decoder.device),
307
+ 'dec_input_ids': inp_dec['input_ids'].to(self.decoder.device),
308
+ 'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
309
+ }
310
+
311
+ return self.generate(model_input, max_new_tokens)
312
+
313
 
314
+