Crystalcareai
commited on
Update generate.py
Browse files- generate.py +4 -0
generate.py
CHANGED
@@ -25,6 +25,7 @@ def custom_generate(
|
|
25 |
bos_token_id=None,
|
26 |
pad_token_id=None,
|
27 |
eos_token_id=None,
|
|
|
28 |
length_penalty=None,
|
29 |
no_repeat_ngram_size=None,
|
30 |
num_return_sequences=None,
|
@@ -51,6 +52,7 @@ def custom_generate(
|
|
51 |
new_ids = self(
|
52 |
input_ids[~finished_generating],
|
53 |
attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
|
|
|
54 |
**kwargs
|
55 |
)['logits']
|
56 |
|
@@ -100,6 +102,7 @@ def generate(
|
|
100 |
early_stopping=None,
|
101 |
num_beams=None,
|
102 |
temperature=1.0,
|
|
|
103 |
top_k=None,
|
104 |
top_p=None,
|
105 |
repetition_penalty=None,
|
@@ -202,6 +205,7 @@ def generate(
|
|
202 |
forced_eos_token_id=forced_eos_token_id,
|
203 |
remove_invalid_values=remove_invalid_values,
|
204 |
synced_gpus=synced_gpus,
|
|
|
205 |
**model_kwargs,
|
206 |
)
|
207 |
|
|
|
25 |
bos_token_id=None,
|
26 |
pad_token_id=None,
|
27 |
eos_token_id=None,
|
28 |
+
streamer=None,
|
29 |
length_penalty=None,
|
30 |
no_repeat_ngram_size=None,
|
31 |
num_return_sequences=None,
|
|
|
52 |
new_ids = self(
|
53 |
input_ids[~finished_generating],
|
54 |
attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
|
55 |
+
streamer=streamer,
|
56 |
**kwargs
|
57 |
)['logits']
|
58 |
|
|
|
102 |
early_stopping=None,
|
103 |
num_beams=None,
|
104 |
temperature=1.0,
|
105 |
+
streamer=None,
|
106 |
top_k=None,
|
107 |
top_p=None,
|
108 |
repetition_penalty=None,
|
|
|
205 |
forced_eos_token_id=forced_eos_token_id,
|
206 |
remove_invalid_values=remove_invalid_values,
|
207 |
synced_gpus=synced_gpus,
|
208 |
+
streamer=streamer,
|
209 |
**model_kwargs,
|
210 |
)
|
211 |
|