Crystalcareai committed
Commit 504e404 · verified · 1 Parent(s): 1c368e2

Update generate.py

Files changed (1)
generate.py +5 -3
generate.py CHANGED
```diff
@@ -56,7 +56,7 @@ def custom_generate(
     generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
     cur_token_idx = 0
-    while cur_token_idx < max_length:
+    while cur_token_idx < max_new_tokens:
         # Sample the next token
         new_ids = self(
             input_ids[~finished_generating],
@@ -88,9 +88,11 @@ def custom_generate(
             if attention_mask is not None:
                 attention_mask[answer_idx, last_token_idx + 1] = 1
             input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-            generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
-            if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_length:
+            if cur_token_idx < max_new_tokens:
+                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
+
+            if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_new_tokens:
                 finished_generating[answer_idx] = 1
 
             cur_token_idx += 1
```
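The change replaces `max_length` with `max_new_tokens` as both the loop bound and the stop check, matching the size that `generated_token_ids` is allocated with at line 56, and adds a bounds guard before the buffer write. Previously, when `max_length` exceeded `max_new_tokens`, the write to `generated_token_ids[answer_idx, cur_token_idx]` could index past the end of the buffer. A minimal standalone sketch of the fixed pattern, assuming a toy `sample_next_token` stub and made-up token ids in place of the real model call and tokenizer:

```python
import torch

# Hypothetical stand-ins for the tokenizer's special ids (illustration only).
pad_token_id, eos_token_id = 0, 2

def sample_next_token() -> int:
    # Hypothetical stub: the real code samples from the model's logits.
    return int(torch.randint(3, 100, (1,)).item())

batch_size, max_new_tokens = 2, 8
# The buffer has exactly max_new_tokens columns, so the loop bound, the write
# guard, and the stop check must all use max_new_tokens, not max_length.
generated_token_ids = torch.full((batch_size, max_new_tokens), pad_token_id, dtype=torch.long)
finished_generating = torch.zeros(batch_size, dtype=torch.bool)

cur_token_idx = 0
while cur_token_idx < max_new_tokens:  # was: cur_token_idx < max_length
    for answer_idx in range(batch_size):
        if finished_generating[answer_idx]:
            continue
        new_id = sample_next_token()
        if cur_token_idx < max_new_tokens:  # guard added by the commit
            generated_token_ids[answer_idx, cur_token_idx] = new_id
        if new_id == eos_token_id or cur_token_idx + 1 == max_new_tokens:
            finished_generating[answer_idx] = True
    cur_token_idx += 1

print(generated_token_ids)
```

Under the corrected `while` bound the extra `if cur_token_idx < max_new_tokens` guard is technically redundant, but it mirrors the commit and keeps the buffer write safe if the loop bound ever diverges from the buffer size again.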