Crystalcareai committed on
Commit
ffe6ef0
·
verified ·
1 Parent(s): c48f22f

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +7 -4
generate.py CHANGED
@@ -11,7 +11,7 @@ def custom_generate(
11
  self,
12
  input_ids,
13
  attention_mask=None,
14
- max_length=None,
15
  min_length=None,
16
  do_sample=None,
17
  early_stopping=None,
@@ -47,7 +47,7 @@ def custom_generate(
47
  with torch.no_grad():
48
  finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
49
 
50
- while not finished_generating.all() and input_ids.shape[1] < max_length:
51
  # Sample the next token
52
  new_ids = self(
53
  input_ids[~finished_generating],
@@ -86,6 +86,9 @@ def custom_generate(
86
  # Check if the end token is generated
87
  if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
88
  finished_generating[answer_idx] = 1
 
 
 
89
 
90
  if streamer is not None:
91
  streamer.put(new_ids_sampled)
@@ -98,7 +101,7 @@ def generate(
98
  self,
99
  input_ids,
100
  attention_mask=None,
101
- max_length=None,
102
  min_length=None,
103
  do_sample=None,
104
  early_stopping=None,
@@ -169,7 +172,7 @@ def generate(
169
  self,
170
  input_ids=input_ids,
171
  attention_mask=attention_mask,
172
- max_length=max_length,
173
  min_length=min_length,
174
  do_sample=do_sample,
175
  early_stopping=early_stopping,
 
11
  self,
12
  input_ids,
13
  attention_mask=None,
14
+ max_new_tokens=None,
15
  min_length=None,
16
  do_sample=None,
17
  early_stopping=None,
 
47
  with torch.no_grad():
48
  finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
49
 
50
+ for cur_token_idx in range(max_new_tokens):
51
  # Sample the next token
52
  new_ids = self(
53
  input_ids[~finished_generating],
 
86
  # Check if the end token is generated
87
  if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
88
  finished_generating[answer_idx] = 1
89
+
90
+ if finished_generating.all():
91
+ break
92
 
93
  if streamer is not None:
94
  streamer.put(new_ids_sampled)
 
101
  self,
102
  input_ids,
103
  attention_mask=None,
104
+ max_new_tokens=None,
105
  min_length=None,
106
  do_sample=None,
107
  early_stopping=None,
 
172
  self,
173
  input_ids=input_ids,
174
  attention_mask=attention_mask,
175
+ max_new_tokens=max_new_tokens,
176
  min_length=min_length,
177
  do_sample=do_sample,
178
  early_stopping=early_stopping,