Crystalcareai committed
Commit 53a9463 · verified · 1 Parent(s): 27d3137

Update generate.py

Files changed (1)
  1. generate.py +19 -26
generate.py CHANGED
@@ -6,6 +6,7 @@ from transformers.generation.utils import (
 )
 from transformers import TextStreamer
 
+
 def custom_generate(
     self,
     input_ids,
@@ -44,19 +45,19 @@ def custom_generate(
 ):
     if input_ids is None or input_ids.nelement() == 0:
         # If input_ids is None or an empty tensor, create a default input tensor
-        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
+        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
+        attention_mask = torch.ones_like(input_ids)
 
     device = input_ids.device
     with torch.no_grad():
         batch_size = input_ids.shape[0]
+        if max_new_tokens is None:
+            raise ValueError("max_new_tokens must be provided.")
+
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-
-        max_length = input_ids.shape[1] + max_new_tokens
         generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
-        cur_token_idx = 0
-        while cur_token_idx < max_new_tokens:
+        for cur_token_idx in range(max_new_tokens):
             # Sample the next token
             new_ids = self(
                 input_ids[~finished_generating],
@@ -80,7 +81,7 @@ def custom_generate(
             if last_token_idx + 1 >= len(base_answer_ids):
                 # Add padding everywhere
                 new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                          device=device)
+                                         device=device)
                 input_ids = torch.cat([input_ids, new_padding], dim=-1)
                 if attention_mask is not None:
                     attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -88,15 +89,15 @@ def custom_generate(
             if attention_mask is not None:
                 attention_mask[answer_idx, last_token_idx + 1] = 1
             input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+            generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
-            if cur_token_idx < max_new_tokens:
-                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
-
-            if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_new_tokens:
+            if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                 finished_generating[answer_idx] = 1
 
-            cur_token_idx += 1
-
+            # Check if the end token is generated
+            if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                finished_generating[answer_idx] = 1
+
             if finished_generating.all():
                 break
 
@@ -105,6 +106,7 @@ def custom_generate(
 
     return generated_token_ids
 
+
 def generate(
     self,
     input_ids,
@@ -152,12 +154,11 @@ def generate(
     use_weighted_talk_head=True,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16,
-    dynamic_temperature=None,
     **model_kwargs,
 ):
-
+    # Set default value for max_new_tokens if not provided
     if max_new_tokens is None:
-        max_new_tokens = 128
+        max_new_tokens = 20  # Set a reasonable default value
 
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
@@ -185,16 +186,11 @@ def generate(
     if isinstance(input_ids, str):
         input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
 
-    # Move input_ids and attention_mask to the same device as the model
-    input_ids = input_ids.to(self.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(self.device)
-
     generated_token_ids = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,  # Pass input_ids explicitly
         attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=max_new_tokens,  # Pass max_new_tokens explicitly
         min_length=min_length,
         do_sample=do_sample,
         early_stopping=early_stopping,
@@ -227,7 +223,4 @@ def generate(
         **model_kwargs,
     )
 
-    # Convert the generated token IDs tensor to text
-    generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
-
     return generated_token_ids
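
For context, a minimal usage sketch of the updated entry point (not part of the commit). The model id below is a hypothetical placeholder, and the sketch assumes the repository's modeling code wires generate() from generate.py onto the model when it is loaded with trust_remote_code=True. Two behavior changes visible in the diff matter to callers: the .to(self.device) calls were removed, so inputs must be moved to the model's device before the call, and the function now returns only the generated token ids, so decoding happens on the caller's side.

# Usage sketch only; model id and wiring are assumptions, not from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/<model-repo>"  # hypothetical placeholder

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # required so the custom generate() above is used
)
model.tokenizer = tokenizer  # custom_generate() reads self.tokenizer

# The commit removed the .to(self.device) calls inside generate(), so move
# the inputs to the model's device before calling it.
inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt").to(model.device)

# max_new_tokens now defaults to 20 when omitted; pass it explicitly for more.
token_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
)

# generate() now returns only the generated token ids (the internal decode
# step was removed), so decode on the caller side.
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))

Since generated_token_ids is pre-filled with pad_token_id, sequences that finish early come back padded to max_new_tokens; decoding with skip_special_tokens=True strips that padding.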