Crystalcareai committed on
Commit
cc66ab7
·
verified ·
1 Parent(s): 6a35495

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +12 -10
generate.py CHANGED
@@ -45,15 +45,12 @@ def custom_generate(
45
  ):
46
  if input_ids is None or input_ids.nelement() == 0:
47
  # If input_ids is None or an empty tensor, create a default input tensor
48
- input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
49
- attention_mask = torch.ones_like(input_ids)
50
 
51
  device = input_ids.device
52
  with torch.no_grad():
53
  batch_size = input_ids.shape[0]
54
- if max_new_tokens is None:
55
- raise ValueError("max_new_tokens must be provided.")
56
-
57
  finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
58
  generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
59
 
@@ -156,10 +153,10 @@ def generate(
156
  torch_dtype=torch.bfloat16,
157
  **model_kwargs,
158
  ):
159
- # Set default value for max_new_tokens if not provided
160
- if max_new_tokens is None:
161
- max_new_tokens = 128 # Set a reasonable default value
162
 
 
 
 
163
  # Set model attributes
164
  self.max_thoughts = n_ahead + n_ahead_talk + 1
165
  self.merged_talk_heads = merged_talk_heads
@@ -186,11 +183,16 @@ def generate(
186
  if isinstance(input_ids, str):
187
  input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
188
 
 
 
 
 
 
189
  generated_token_ids = custom_generate(
190
  self,
191
- input_ids=input_ids, # Pass input_ids explicitly
192
  attention_mask=attention_mask,
193
- max_new_tokens=max_new_tokens, # Pass max_new_tokens explicitly
194
  min_length=min_length,
195
  do_sample=do_sample,
196
  early_stopping=early_stopping,
 
45
  ):
46
  if input_ids is None or input_ids.nelement() == 0:
47
  # If input_ids is None or an empty tensor, create a default input tensor
48
+ input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
49
+ attention_mask = torch.ones_like(input_ids).to(self.device)
50
 
51
  device = input_ids.device
52
  with torch.no_grad():
53
  batch_size = input_ids.shape[0]
 
 
 
54
  finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
55
  generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
56
 
 
153
  torch_dtype=torch.bfloat16,
154
  **model_kwargs,
155
  ):
 
 
 
156
 
157
+ if max_new_tokens is None:
158
+ max_new_tokens = 128
159
+
160
  # Set model attributes
161
  self.max_thoughts = n_ahead + n_ahead_talk + 1
162
  self.merged_talk_heads = merged_talk_heads
 
183
  if isinstance(input_ids, str):
184
  input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
185
 
186
+ # Move input_ids and attention_mask to the same device as the model
187
+ input_ids = input_ids.to(self.device)
188
+ if attention_mask is not None:
189
+ attention_mask = attention_mask.to(self.device)
190
+
191
  generated_token_ids = custom_generate(
192
  self,
193
+ input_ids=input_ids,
194
  attention_mask=attention_mask,
195
+ max_new_tokens=max_new_tokens,
196
  min_length=min_length,
197
  do_sample=do_sample,
198
  early_stopping=early_stopping,