Files changed (1) hide show
  1. modeling_cocom.py +2 -2
modeling_cocom.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
2
  import torch
3
  import math
4
  from peft import get_peft_model, LoraConfig, TaskType
@@ -263,7 +263,7 @@ class COCOM(PreTrainedModel):
263
  attention_mask=dec_attention_mask.to(device),
264
  do_sample=False,
265
  top_p=None,
266
- max_new_tokens=max_new_tokens
267
  )
268
  decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
269
  return decoded
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel,LongformerForCausalLM, LongformerTokenizer
2
  import torch
3
  import math
4
  from peft import get_peft_model, LoraConfig, TaskType
 
263
  attention_mask=dec_attention_mask.to(device),
264
  do_sample=False,
265
  top_p=None,
266
+ max_new_tokens=min(max_new_tokens, 4096)
267
  )
268
  decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
269
  return decoded