File size: 15,856 Bytes
bbcbb48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
import torch
import math 
from peft import get_peft_model, LoraConfig, TaskType
import os

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``, in place.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``parameters()``).
    """
    for weight in model.parameters():
        weight.requires_grad_(False)


class BERT_Compressor(torch.nn.Module):
    """Compress tokenized contexts into a short sequence of decoder-space embeddings.

    A pretrained encoder (loaded with ``AutoModel``) encodes the context. Its
    last-layer hidden states are split into consecutive segments of
    ``compr_rate`` tokens, and each segment is mapped to one vector of size
    ``decoder_hidden_size`` by a linear layer:

    * ``'concat'`` (default setting in the paper): the segment's token states
      are concatenated and projected.
    * ``'mean'``: every token state is projected and the projections of a
      segment are averaged.
    """

    def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
        """Load the encoder and tokenizer and build the projection layer.

        Args:
            compr_model_name: Hugging Face name/path of the encoder, e.g. ``bert-base-uncased``.
            compr_rate: number of encoder tokens folded into one output embedding.
            compr_linear_type: ``'concat'`` or ``'mean'``.
            decoder_hidden_size: width of the produced embeddings.

        Raises:
            NotImplementedError: if ``compr_linear_type`` is neither ``'concat'`` nor ``'mean'``.
        """
        super().__init__()
        # Fail fast, before downloading any weights. Previously an unknown mode
        # surfaced later as a confusing missing-``self.linear`` error.
        if compr_linear_type not in ('concat', 'mean'):
            raise NotImplementedError(
                f"compr_linear_type must be 'concat' or 'mean', got {compr_linear_type!r}"
            )

        self.model_name = compr_model_name  # base model name of the encoder
        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
        self.compr_rate = compr_rate  # compression rate
        self.compressing_mode = compr_linear_type  # 'concat' or 'mean'

        encoder_hidden_size = self.model.config.hidden_size
        if self.compressing_mode == 'concat':
            # One segment = compr_rate concatenated token states.
            in_features = encoder_hidden_size * self.compr_rate
        else:
            # 'mean': the projection is applied per token, then averaged.
            in_features = encoder_hidden_size
        self.linear = torch.nn.Linear(in_features, decoder_hidden_size).bfloat16()

    def forward(self, input_ids, attention_mask):
        """Compress a batch of contexts.

        Args:
            input_ids: encoder token ids, indexed as ``(batch, seq_len)``.
            attention_mask: encoder attention mask, passed to the encoder unchanged.

        Returns:
            Tensor of shape ``(batch, ceil(seq_len / compr_rate), decoder_hidden_size)``.
            When called from COCOM, ``batch`` is presumably
            ``batch_size * generation_top_k`` (confirm against the caller).

        Raises:
            ValueError: in ``'concat'`` mode when ``seq_len`` is not a multiple of
                ``compr_rate``. Tokenize with ``pad_to_multiple_of=compr_rate``.
            NotImplementedError: if ``compressing_mode`` is unknown.
        """
        encoder_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden = encoder_outputs.hidden_states[-1]  # (batch, seq_len, encoder_hidden)
        rate = self.compr_rate
        batch_size, seq_len, hidden_size = last_hidden.shape

        if self.compressing_mode == 'concat':
            if seq_len % rate != 0:
                raise ValueError(
                    f"'concat' compression needs a sequence length divisible by compr_rate "
                    f"({rate}), got {seq_len}; tokenize with pad_to_multiple_of=compr_rate"
                )
            # A row-major reshape concatenates the token states of each segment in
            # order. This equals flattening every slice [s*rate:(s+1)*rate] and stacking.
            segments = last_hidden.reshape(batch_size, seq_len // rate, rate * hidden_size)
            return self.linear(segments)

        if self.compressing_mode == 'mean':
            # Project each token, then average within its segment. A trailing
            # segment shorter than ``rate`` is averaged over the tokens it has.
            pooled = [
                self.linear(last_hidden[:, start:start + rate, :]).mean(dim=1)
                for start in range(0, seq_len, rate)
            ]
            return torch.stack(pooled, dim=1)

        raise NotImplementedError()

class COCOMConfig(PretrainedConfig):
    """Configuration for the COCOM model (context compressor + decoder LLM).

    Attributes:
        decoder_model_name: Hugging Face name/path of the decoder.
        quantization: how the decoder is loaded: ``'no'``, ``'int4'`` or ``'int8'``.
        generation_top_k: number of contexts per query (set to 1 for pretraining).
        sep: whether a ``<SEP>`` token follows each block of memory tokens.
        compr_model_name: encoder used as compressor; ``None`` makes the decoder
            compress the contexts itself.
        compr_rate: number of context tokens folded into one memory embedding.
        compr_linear_type: projection type of the encoder compressor, ``'concat'`` or ``'mean'``.
        lora: whether LoRA adapters are added to the decoder.
        training_form: ``'compressor'`` trains only the compressor; ``'both'``
            (the default) presumably trains compressor and decoder together — confirm.
        lora_r: LoRA rank (16 throughout the experiments).
    """

    model_type = "COCOM"

    def __init__(
        self,
        decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
        quantization='no',
        generation_top_k=1,
        sep=False,
        compr_model_name="bert-base-uncased",
        compr_rate=64,
        compr_linear_type='concat',
        lora=False,
        training_form="both",
        lora_r=16,
        **kwargs,
    ):
        # Generic options (e.g. name_or_path) are handled by the base config first.
        super().__init__(**kwargs)

        # Keep the assignment order stable: it determines the key order of to_dict().
        self.decoder_model_name = decoder_model_name
        self.quantization = quantization
        self.generation_top_k = generation_top_k
        self.sep = sep
        self.compr_model_name = compr_model_name
        self.compr_rate = compr_rate
        self.compr_linear_type = compr_linear_type
        self.lora = lora
        self.training_form = training_form
        self.lora_r = lora_r

class COCOM(PreTrainedModel):
    """Decoder LLM that reads retrieved contexts as compressed ``<MEM>`` embeddings.

    Contexts are compressed either by a separate encoder (``BERT_Compressor``,
    used when ``cfg.compr_model_name`` is set) or by the decoder itself (its
    last hidden states at ``<MEM>`` positions). The compressed vectors then
    overwrite the ``<MEM>`` placeholder embeddings of the decoder prompt
    before decoding.
    """

    config_class = COCOMConfig

    def __init__(self, cfg):
        """Load the decoder, the optional encoder compressor and the tokenizer.

        Args:
            cfg: a ``COCOMConfig``.

        Raises:
            NotImplementedError: if ``cfg.quantization`` is not ``'no'``, ``'int4'`` or ``'int8'``.
        """
        super().__init__(cfg)
        # The decoder can be loaded in three quantization modes: no, int4, int8.
        if cfg.quantization == "no":
            self.decoder = AutoModelForCausalLM.from_pretrained(
                cfg.decoder_model_name,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                low_cpu_mem_usage=True,
            )
        elif cfg.quantization in ("int4", "int8"):
            # ``low_cpu_mem_usage`` is a ``from_pretrained`` option rather than a
            # ``BitsAndBytesConfig`` parameter, so it is passed to ``from_pretrained`` only.
            if cfg.quantization == "int4":
                quant_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type='nf4',
                    bnb_4bit_compute_dtype='bfloat16',
                )
            else:
                quant_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_enable_fp32_cpu_offload=True,
                    bnb_4bit_compute_dtype='bfloat16',
                )
            self.decoder = AutoModelForCausalLM.from_pretrained(
                cfg.decoder_model_name,
                quantization_config=quant_config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                resume_download=True,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        else:
            raise NotImplementedError()

        # A compressor model name selects an encoder-based compressor;
        # ``None`` means the decoder compresses the contexts itself.
        if cfg.compr_model_name is not None:
            self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
        else:
            self.compr = None

        # Optional LoRA adapters on all linear layers of the decoder.
        if cfg.lora:
            peft_config = LoraConfig(
                task_type="CAUSAL_LM",
                r=cfg.lora_r,
                lora_alpha=2 * cfg.lora_r,
                target_modules='all-linear',
                lora_dropout=0.1,
            )
            self.decoder = get_peft_model(self.decoder, peft_config)
            self.decoder.print_trainable_parameters()

        # training_form == "compressor" with an encoder compressor: freeze the decoder.
        self.training_form = cfg.training_form
        if self.training_form == "compressor" and self.compr is not None:
            freeze_model(self.decoder)

        self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')

        # Special tokens.
        self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
        self.decoder_tokenizer.mem_token = '<MEM>'  # memory placeholder, replaced by a compressed embedding
        self.decoder_tokenizer.ae_token = '<AE>'  # token for autoencoding on the decoder side
        self.decoder_tokenizer.enc_token = '<ENC>'  # token for autoencoding on the compressor side
        self.decoder_tokenizer.sep_token = '<SEP>'  # separator between documents

        self.decoder_tokenizer.mem_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<MEM>')
        self.decoder_tokenizer.ae_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<AE>')
        self.decoder_tokenizer.sep_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<SEP>')
        # Use the BOS token for padding when the tokenizer defines no pad token.
        if self.decoder_tokenizer.pad_token_id is None:
            self.decoder_tokenizer.pad_token_id = self.decoder_tokenizer.bos_token_id

        # Grow the embedding matrix to cover the added special tokens.
        self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
        # Clear sampling parameters; generate() below decodes greedily.
        self.decoder.generation_config.top_p = None
        self.decoder.generation_config.temperature = None
        self.compr_model_name = cfg.compr_model_name
        # other settings
        self.generation_top_k = cfg.generation_top_k
        self.sep = cfg.sep
        self.compr_rate = cfg.compr_rate
        self.local_rank = os.getenv('LOCAL_RANK', '0')

    def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
        """Compress the encoder inputs and splice them into the decoder input embeddings.

        Each decoder row ``i`` receives the compressed contexts
        ``indices[i]:indices[i + 1]``, i.e. ``generation_top_k`` consecutive
        rows of the flattened encoder batch.

        Returns:
            The decoder input embeddings with ``<MEM>`` slots overwritten.
        """
        indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
        if self.compr is not None:
            compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
        else:
            compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
        return self.replace_embeddings(compressed_embs, dec_input_ids, indices)

    def compr_decoder(self, input_ids, attention_mask):
        """Compress with the decoder: its last hidden states at the ``<MEM>`` positions.

        Returns a tensor reshaped to ``(batch, num_mem_tokens, hidden)``. This
        reshape requires every row to contain the same number of ``<MEM>`` tokens.
        """
        emb = self.decoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
        mask = input_ids == self.decoder_tokenizer.mem_token_id
        return emb[mask].reshape(emb.size(0), -1, emb.size(-1))

    def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
        """Embed ``dec_input_ids`` and overwrite its ``<MEM>`` slots with ``compressed_embs``.

        Starting at the first ``<MEM>`` token of decoder row ``i``, context
        ``j`` in ``indices[i]:indices[i + 1]`` fills ``num_embs`` positions.
        With ``self.sep`` one extra position (the ``<SEP>`` token) is skipped
        after each context.

        Raises:
            ValueError: if some decoder row contains no ``<MEM>`` token.
        """
        inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
        num_embs = compressed_embs.size(1)
        slot_len = num_embs + 1 if self.sep else num_embs
        is_mem = dec_input_ids == self.decoder_tokenizer.mem_token_id
        # argmax yields 0 for a row without any <MEM> token, which would silently
        # overwrite the beginning of that prompt; reject such input instead.
        if not bool(is_mem.any(dim=1).all()):
            raise ValueError("every decoder input row must contain at least one <MEM> token")
        first_mem_token_indices = torch.argmax(is_mem.int(), dim=1).tolist()
        for i, first_mem in enumerate(first_mem_token_indices):
            for offset, j in enumerate(range(indices[i], indices[i + 1])):
                start_idx = first_mem + offset * slot_len
                inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
        return inputs_embeds

    def forward(self,
            enc_input_ids: torch.LongTensor = None,
            enc_attention_mask: torch.LongTensor = None,
            dec_input_ids: torch.LongTensor = None,
            dec_attention_mask: torch.LongTensor = None,
            labels: torch.LongTensor = None):
        """Compress the contexts and run the decoder on the spliced prompt.

        Args:
            enc_input_ids: contexts, flattened over all queries before input,
                ``(batch_size * generation_top_k, token_length)``.
            enc_attention_mask: attention mask of ``enc_input_ids``.
            dec_input_ids: prompts including ``<MEM>`` tokens, ``(batch_size, token_length)``.
            dec_attention_mask: attention mask of ``dec_input_ids``.
            labels: passed to the decoder to compute the loss.

        Returns:
            dict with the decoder's ``loss`` and ``logits``.
        """
        # Compression runs with gradient tracking.
        inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)

        # NOTE(review): with a decoder-based compressor the same decoder both compresses
        # and decodes. Detaching here blocks gradients into the compression pass, so only
        # the decoding pass receives gradients. Confirm this matches the intent of
        # training_form == "compressor".
        if (self.training_form == "compressor") and (self.compr is None):
            inputs_embeds = inputs_embeds.detach()

        decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)

        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}

    def generate(self, model_input, max_new_tokens=128):
        """Greedy-decode answers from tokenized inputs.

        Args:
            model_input: dict with ``enc_input_ids``, ``enc_attention_mask``,
                ``dec_input_ids`` and ``dec_attention_mask`` tensors.
            max_new_tokens: generation length limit.

        Returns:
            list of decoded strings, with special tokens skipped.
        """
        device = self.decoder.device
        enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
        inputs_embeds = self.compress_and_replace_emb(enc_input_ids.to(device), enc_attention_mask.to(device), dec_input_ids.to(device))
        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds.to(device),
            attention_mask=dec_attention_mask.to(device),
            do_sample=False,
            top_p=None,
            max_new_tokens=max_new_tokens
            )
        decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded

    def generate_from_text(self, contexts, questions, max_new_tokens=128):
        """Tokenize raw contexts/questions and generate answers.

        Args:
            contexts: one list of context strings per question; all lists must
                have the same length, which becomes the per-query top-k.
            questions: question strings, one per entry of ``contexts``.
            max_new_tokens: generation length limit.

        Returns:
            list of decoded answers.

        ``self.generation_top_k`` is overridden only for the duration of this
        call and restored afterwards, so later ``forward`` calls keep using
        the configured value.
        """
        assert len(contexts) == len(questions)
        assert all(len(context) == len(contexts[0]) for context in contexts)

        top_k = len(contexts[0])
        # Flatten per-query context lists in order (linear time, unlike sum(..., [])).
        flat_contexts = [context for query_contexts in contexts for context in query_contexts]

        # Tokenize the contexts for whichever compressor is in use.
        if self.compr is not None:
            enc_input = self.compr.tokenizer(flat_contexts, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=self.compr_rate)
            num_mem_tokens = math.ceil(enc_input['input_ids'].size(1) / self.compr_rate)
        else:
            # The decoder compressor reads <MEM> tokens placed in front of the marked context.
            flat_contexts = [self.decoder_tokenizer.enc_token + self.decoder_tokenizer.bos_token + context + self.decoder_tokenizer.bos_token for context in flat_contexts]
            enc_input = self.decoder_tokenizer(flat_contexts, truncation=True, return_tensors='pt', padding="longest")
            num_mem_tokens = math.ceil((enc_input['input_ids'].size(1) - 3) / self.compr_rate)
            mem_ids = torch.full((enc_input['input_ids'].size(0), num_mem_tokens), self.decoder_tokenizer.mem_token_id, dtype=torch.long)
            enc_input['input_ids'] = torch.cat([mem_ids, enc_input['input_ids']], dim=1)
            enc_input['attention_mask'] = torch.cat([torch.ones_like(mem_ids), enc_input['attention_mask']], dim=1)

        # Decoder prompt: one block of <MEM> placeholders (+ optional <SEP>) per context.
        mem_block = self.decoder_tokenizer.mem_token * num_mem_tokens
        if self.sep:
            mem_block += self.decoder_tokenizer.sep_token

        instr = [self.decoder_tokenizer.bos_token + mem_block * top_k + '[INST]' + question + '\n[/INST]\n' for question in questions]
        inp_dec = self.decoder_tokenizer(instr, truncation=True, return_tensors='pt', padding="longest")

        model_input = {
            'enc_input_ids': enc_input['input_ids'],
            'enc_attention_mask': enc_input['attention_mask'],
            'dec_input_ids': inp_dec['input_ids'],
            'dec_attention_mask': inp_dec['attention_mask']
        }

        # compress_and_replace_emb reads generation_top_k; override it only for this call.
        previous_top_k = self.generation_top_k
        self.generation_top_k = top_k
        try:
            return self.generate(model_input, max_new_tokens)
        finally:
            self.generation_top_k = previous_top_k