Crystalcareai commited on
Commit
54bfe84
·
verified ·
1 Parent(s): 33e1ad7

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +4 -0
generate.py CHANGED
@@ -25,6 +25,7 @@ def custom_generate(
25
  bos_token_id=None,
26
  pad_token_id=None,
27
  eos_token_id=None,
 
28
  length_penalty=None,
29
  no_repeat_ngram_size=None,
30
  num_return_sequences=None,
@@ -51,6 +52,7 @@ def custom_generate(
51
  new_ids = self(
52
  input_ids[~finished_generating],
53
  attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
 
54
  **kwargs
55
  )['logits']
56
 
@@ -100,6 +102,7 @@ def generate(
100
  early_stopping=None,
101
  num_beams=None,
102
  temperature=1.0,
 
103
  top_k=None,
104
  top_p=None,
105
  repetition_penalty=None,
@@ -202,6 +205,7 @@ def generate(
202
  forced_eos_token_id=forced_eos_token_id,
203
  remove_invalid_values=remove_invalid_values,
204
  synced_gpus=synced_gpus,
 
205
  **model_kwargs,
206
  )
207
 
 
25
  bos_token_id=None,
26
  pad_token_id=None,
27
  eos_token_id=None,
28
+ streamer=None,
29
  length_penalty=None,
30
  no_repeat_ngram_size=None,
31
  num_return_sequences=None,
 
52
  new_ids = self(
53
  input_ids[~finished_generating],
54
  attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
55
+ streamer=streamer,
56
  **kwargs
57
  )['logits']
58
 
 
102
  early_stopping=None,
103
  num_beams=None,
104
  temperature=1.0,
105
+ streamer=None,
106
  top_k=None,
107
  top_p=None,
108
  repetition_penalty=None,
 
205
  forced_eos_token_id=forced_eos_token_id,
206
  remove_invalid_values=remove_invalid_values,
207
  synced_gpus=synced_gpus,
208
+ streamer=streamer,
209
  **model_kwargs,
210
  )
211