Crystalcareai committed
Commit ea97f7f (verified) · Parent: eacd98b

Update generate.py

Files changed (1): generate.py (+20 -9)
generate.py CHANGED
@@ -56,9 +56,12 @@ def custom_generate(
     with torch.no_grad():
         batch_size = input_ids.shape[0]
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
+
+        max_length = input_ids.shape[1] + max_new_tokens
         generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
-        for cur_token_idx in range(max_new_tokens):
+        cur_token_idx = 0
+        while cur_token_idx < max_length:
             # Sample the next token
             new_ids = self(
                 input_ids[~finished_generating],
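This first hunk replaces the fixed `for cur_token_idx in range(max_new_tokens)` loop with a counter-driven `while` loop bounded by the new `max_length = input_ids.shape[1] + max_new_tokens`. A standalone sketch with hypothetical sizes (not part of the commit) shows how the two bounds differ; note that `generated_token_ids` is still allocated with only `max_new_tokens` columns, so the loop relies on the stop check in the next hunk to finish every row in time:

# Standalone sketch, hypothetical sizes; not part of the commit.
prompt_len, max_new_tokens = 12, 4        # stand-ins for input_ids.shape[1] and the budget

old_steps = len(range(max_new_tokens))    # old loop: exactly 4 iterations

max_length = prompt_len + max_new_tokens  # new bound: 16
cur_token_idx = 0
new_steps = 0
while cur_token_idx < max_length:         # new loop: up to 16 iterations,
    new_steps += 1                        # unless every row finishes early
    cur_token_idx += 1

print(old_steps, new_steps)               # -> 4 16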
@@ -96,13 +99,11 @@ def custom_generate(
                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                 generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
-                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_length:
                     finished_generating[answer_idx] = 1
 
-                # Check if the end token is generated
-                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
-                    finished_generating[answer_idx] = 1
-
+            cur_token_idx += 1
+
             if finished_generating.all():
                 break
 
@@ -112,7 +113,13 @@ def custom_generate(
     print("Generated Token IDs shape:", generated_token_ids.shape)
     print("Generated Token IDs:", generated_token_ids)
 
-    return generated_token_ids
+    # Decode the generated token IDs into text
+    generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
+    print("Generated Text:")
+    print(generated_text)
+
+    return generated_token_ids, generated_text
+
 
 def generate(
     self,
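`custom_generate` now also decodes and returns text alongside the token ids. Only the first batch row is decoded, and `skip_special_tokens=True` drops the pad tokens that pre-fill `generated_token_ids`. A runnable sketch of the same `tokenizer.decode` call, using the public `gpt2` checkpoint purely as an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only
ids = tokenizer("hello world", return_tensors="pt").input_ids
# Mirrors the diff: decode row 0 of the batch, dropping special tokens.
text = tokenizer.decode(ids[0], skip_special_tokens=True)
print(text)  # -> hello world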
@@ -161,6 +168,7 @@ def generate(
     use_weighted_talk_head=True,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16,
+    dynamic_temperature=None,
     **model_kwargs,
 ):
@@ -198,7 +206,7 @@ def generate(
     if attention_mask is not None:
         attention_mask = attention_mask.to(self.device)
 
-    generated_token_ids = custom_generate(
+    generated_token_ids, generated_text = custom_generate(
         self,
         input_ids=input_ids,
         attention_mask=attention_mask,
@@ -235,4 +243,7 @@ def generate(
         **model_kwargs,
     )
 
-    return generated_token_ids
+    if dynamic_temperature is not None:
+        return generated_text
+    else:
+        return generated_token_ids
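The remaining hunks thread the new `dynamic_temperature` keyword through `generate`: `custom_generate` now returns a `(token_ids, text)` pair, and `generate` hands back the decoded string when `dynamic_temperature` is set and the token-id tensor otherwise. In this revision the keyword only selects the return type; nothing in the diff reads its value. A hedged caller-side sketch, where `model`, `input_ids`, and `attention_mask` are assumed to exist and `dynamic_temperature=1.0` is an arbitrary example value:

# Hypothetical caller of the updated generate(); not part of the commit.
token_ids = model.generate(input_ids, attention_mask=attention_mask,
                           max_new_tokens=64)      # default: token-id tensor
text = model.generate(input_ids, attention_mask=attention_mask,
                      max_new_tokens=64,
                      dynamic_temperature=1.0)     # now: decoded string
print(type(token_ids), type(text))                 # -> Tensor, str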
 
 
 
 