Crystalcareai committed · verified
Commit 7874fb0 · Parent(s): f321869

Update generate.py

Files changed (1):
  1. generate.py +12 -49
generate.py CHANGED
@@ -1,12 +1,10 @@
 import torch
-from transformers.utils import logging
 from transformers.generation.utils import (
     GenerationMixin,
     validate_stopping_criteria,
     StoppingCriteriaList,
 )
-
-logger = logging.get_logger(__name__)
+from transformers import TextStreamer
 
 
 def custom_generate(
@@ -45,8 +43,9 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
+    device = input_ids.device
     with torch.no_grad():
-        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
 
         while not finished_generating.all() and input_ids.shape[1] < max_length:
             # Sample the next token
@@ -72,7 +71,7 @@ def custom_generate(
                 if last_token_idx + 1 >= len(base_answer_ids):
                     # Add padding everywhere
                     new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=input_ids.device)
+                                             device=device)
                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
                     if attention_mask is not None:
                         attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -92,8 +91,7 @@ def custom_generate(
                 streamer.put(new_ids_sampled)
 
     generated_token_ids = input_ids.tolist()
-
-    return generated_token_ids
+    return generated_token_ids, attention_mask
 
 
 def generate(
@@ -105,8 +103,7 @@ def generate(
     do_sample=None,
     early_stopping=None,
     num_beams=None,
-    temperature= 0.9,
-    streamer=None,
+    temperature=1.1,
     top_k=None,
     top_p=None,
     repetition_penalty=None,
@@ -126,49 +123,15 @@ def generate(
     output_hidden_states=None,
     output_scores=None,
     return_dict_in_generate=None,
-    forced_bos_token_id= True,
-    forced_eos_token_id= True,
+    forced_bos_token_id=None,
+    forced_eos_token_id=None,
     remove_invalid_values=None,
     synced_gpus=None,
-    n_ahead=12,
-    n_ahead_talk=4,
-    merged_talk_heads=True,
-    merged_lm_and_talk_heads=False,
-    merged_lm_and_think_heads=True,
-    use_concat_talk_head=True,
-    use_shallow_think=True,
-    use_shallow_talk=False,
-    use_complex_think_head=False,
-    use_complex_talk_head=True,
-    use_weighted_talk_head=True,
-    trust_remote_code=True,
-    torch_dtype=None,
     **model_kwargs,
 ):
-    # Set model attributes
-    self.max_thoughts = n_ahead + n_ahead_talk + 1
-    self.merged_talk_heads = merged_talk_heads
-    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
-    self.merged_lm_and_think_heads = merged_lm_and_think_heads
-    self.use_concat_talk_head = use_concat_talk_head
-    self.use_shallow_think = use_shallow_think
-    self.use_shallow_talk = use_shallow_talk
-    self.use_complex_think_head = use_complex_think_head
-    self.use_complex_talk_head = use_complex_talk_head
-    self.use_weighted_talk_head = use_weighted_talk_head
-
-    # Set model properties
-    self.use_end_thought_token = True
-    self.use_start_thought_token = True
-    self.n_ahead = n_ahead
-    self.n_passes = 1
-    self.eval_mode = True
-    self.first_run = False
-    self.rm_initialized = True
-    self.original_mode = False
-
-    # Generate using the custom generate function
-    generated_token_ids = custom_generate(
+    streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
+
+    generated_token_ids, attention_mask = custom_generate(
         self,
         input_ids=input_ids,
         attention_mask=attention_mask,
@@ -205,4 +168,4 @@ def generate(
         **model_kwargs,
     )
 
-    return generated_token_ids
+    return generated_token_ids, attention_mask
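
Taken together, the commit strips the Quiet-STaR head-configuration knobs (n_ahead, merged_talk_heads, and the rest, along with the block of self.* assignments) out of generate(), hardcodes a TextStreamer built from self.tokenizer so tokens print as they are sampled, raises the default temperature from 0.9 to 1.1, resets forced_bos_token_id/forced_eos_token_id from True to None, and has both functions return a (token_ids, attention_mask) tuple instead of a bare token-id list. A minimal calling sketch against the updated signature follows; the repo id, the trust_remote_code wiring, and the manual self.tokenizer assignment are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/quiet-star-model"  # hypothetical repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,  # loads the repo's custom generate.py
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer  # custom_generate reads self.tokenizer

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")

# As of this commit, generate() streams text through the hardcoded
# TextStreamer and returns a (token_ids, attention_mask) tuple, where
# token_ids is a plain Python list produced by input_ids.tolist().
token_ids, attention_mask = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=128,
)
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))

Note that the tuple return is a breaking change: callers that previously wrote generated = model.generate(...) must now unpack two values, and since the streamer parameter was removed, streaming can no longer be switched off from the call site.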