Crystalcareai commited on
Commit
dcc7444
·
verified ·
1 Parent(s): 38421a3

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +56 -199
generate.py CHANGED
@@ -1,214 +1,71 @@
1
  import torch
 
2
 
3
  def custom_generate(
4
  self,
5
  input_ids,
6
  attention_mask=None,
7
  max_new_tokens=None,
8
- min_length=None,
9
- do_sample=None,
10
- early_stopping=None,
11
- num_beams=None,
12
- temperature=None,
13
- top_k=None,
14
- top_p=None,
15
- repetition_penalty=None,
16
- bad_words_ids=None,
17
- bos_token_id=None,
18
  pad_token_id=None,
19
  eos_token_id=None,
20
- streamer=None,
21
- length_penalty=None,
22
- no_repeat_ngram_size=None,
23
- num_return_sequences=None,
24
- decoder_start_token_id=None,
25
- use_cache=None,
26
- num_beam_groups=None,
27
- diversity_penalty=None,
28
- prefix_allowed_tokens_fn=None,
29
- output_attentions=None,
30
- output_hidden_states=None,
31
- output_scores=None,
32
- return_dict_in_generate=None,
33
- forced_bos_token_id=None,
34
- forced_eos_token_id=None,
35
- remove_invalid_values=None,
36
- synced_gpus=None,
37
- **kwargs,
38
  ):
39
  device = input_ids.device
40
- with torch.no_grad():
41
- finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
42
-
43
- if max_new_tokens is None:
44
- max_new_tokens = 50 # Default value if not specified
45
- for cur_token_idx in range(max_new_tokens):
46
- new_ids = self(
47
- input_ids[~finished_generating],
48
- attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
49
- **kwargs
50
- )['logits']
51
-
52
- # Mask out the start and end thought tokens so we don't accidentally sample them
53
- new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
54
-
55
- for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
56
- # Find the index of the last token that is not padding
57
- base_answer_ids = input_ids[answer_idx]
58
- new_answer_ids = new_ids[list_idx]
59
- last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
60
-
61
- new_ids_sampled = torch.multinomial(
62
- torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
63
-
64
- # Assign the new id to the last token
65
- if last_token_idx + 1 >= len(base_answer_ids):
66
- # Add padding everywhere
67
- new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
68
- device=device)
69
- input_ids = torch.cat([input_ids, new_padding], dim=-1)
70
- if attention_mask is not None:
71
- attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
72
-
73
- if attention_mask is not None:
74
- attention_mask[answer_idx, last_token_idx + 1] = 1
75
- input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
76
-
77
- if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
78
- finished_generating[answer_idx] = 1
79
-
80
- # Check if the end token is generated
81
- if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
82
- finished_generating[answer_idx] = 1
83
-
84
- if finished_generating.all():
85
- break
86
-
87
- if streamer is not None:
88
- streamer.put(new_ids_sampled)
89
-
90
- from collections import namedtuple
91
- GenerateOutput = namedtuple("GenerateOutput", ["sequences", "scores", "attentions", "hidden_states"])
92
 
93
- # Convert the generated token IDs to a tensor
94
- generated_token_ids_tensor = input_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- output = GenerateOutput(
97
- sequences=generated_token_ids_tensor,
98
- scores=None,
99
- attentions=None,
100
- hidden_states=None
101
- )
 
 
 
 
102
 
103
  return output
104
-
105
-
106
- def generate(
107
- self,
108
- input_ids,
109
- attention_mask=None,
110
- max_new_tokens=None,
111
- min_length=None,
112
- do_sample=None,
113
- early_stopping=None,
114
- num_beams=None,
115
- temperature=1.1,
116
- streamer=None,
117
- top_k=None,
118
- top_p=None,
119
- repetition_penalty=None,
120
- bad_words_ids=None,
121
- bos_token_id=None,
122
- pad_token_id=None,
123
- eos_token_id=None,
124
- length_penalty=None,
125
- no_repeat_ngram_size=None,
126
- num_return_sequences=None,
127
- decoder_start_token_id=None,
128
- use_cache=None,
129
- num_beam_groups=None,
130
- diversity_penalty=None,
131
- prefix_allowed_tokens_fn=None,
132
- output_attentions=None,
133
- output_hidden_states=None,
134
- output_scores=None,
135
- return_dict_in_generate=None,
136
- forced_bos_token_id=None,
137
- forced_eos_token_id=None,
138
- remove_invalid_values=None,
139
- synced_gpus=None,
140
- n_ahead=4,
141
- n_ahead_talk=4,
142
- merged_talk_heads=True,
143
- merged_lm_and_talk_heads=False,
144
- merged_lm_and_think_heads=True,
145
- use_concat_talk_head=True,
146
- use_shallow_think=True,
147
- use_shallow_talk=False,
148
- use_complex_think_head=False,
149
- use_complex_talk_head=True,
150
- use_weighted_talk_head=True,
151
- trust_remote_code=True,
152
- torch_dtype=torch.bfloat16,
153
- **model_kwargs,
154
- ):
155
- # Set model attributes
156
- self.max_thoughts = n_ahead + n_ahead_talk + 1
157
- self.merged_talk_heads = merged_talk_heads
158
- self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
159
- self.merged_lm_and_think_heads = merged_lm_and_think_heads
160
- self.use_concat_talk_head = use_concat_talk_head
161
- self.use_shallow_think = use_shallow_think
162
- self.use_shallow_talk = use_shallow_talk
163
- self.use_complex_think_head = use_complex_think_head
164
- self.use_complex_talk_head = use_complex_talk_head
165
- self.use_weighted_talk_head = use_weighted_talk_head
166
-
167
- # Set model properties
168
- self.use_end_thought_token = True
169
- self.use_start_thought_token = True
170
- self.n_ahead = n_ahead
171
- self.n_passes = 1
172
- self.eval_mode = True
173
- self.first_run = False
174
- self.rm_initialized = True
175
- self.original_mode = False
176
-
177
- output = custom_generate(
178
- self,
179
- input_ids=input_ids,
180
- attention_mask=attention_mask,
181
- max_new_tokens=max_new_tokens,
182
- min_length=min_length,
183
- do_sample=do_sample,
184
- early_stopping=early_stopping,
185
- num_beams=num_beams,
186
- temperature=temperature,
187
- top_k=top_k,
188
- top_p=top_p,
189
- repetition_penalty=repetition_penalty,
190
- bad_words_ids=bad_words_ids,
191
- bos_token_id=bos_token_id,
192
- pad_token_id=pad_token_id,
193
- eos_token_id=eos_token_id,
194
- length_penalty=length_penalty,
195
- no_repeat_ngram_size=no_repeat_ngram_size,
196
- num_return_sequences=num_return_sequences,
197
- decoder_start_token_id=decoder_start_token_id,
198
- use_cache=use_cache,
199
- num_beam_groups=num_beam_groups,
200
- diversity_penalty=diversity_penalty,
201
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
202
- output_attentions=output_attentions,
203
- output_hidden_states=output_hidden_states,
204
- output_scores=output_scores,
205
- return_dict_in_generate=return_dict_in_generate,
206
- forced_bos_token_id=forced_bos_token_id,
207
- forced_eos_token_id=forced_eos_token_id,
208
- remove_invalid_values=remove_invalid_values,
209
- synced_gpus=synced_gpus,
210
- streamer=streamer,
211
- **model_kwargs,
212
- )
213
-
214
- return output
 
1
  import torch
2
+ from transformers import LogitsProcessorList, StoppingCriteriaList
3
 
4
  def custom_generate(
5
  self,
6
  input_ids,
7
  attention_mask=None,
8
  max_new_tokens=None,
9
+ temperature=1.0,
10
+ do_sample=True,
 
 
 
 
 
 
 
 
11
  pad_token_id=None,
12
  eos_token_id=None,
13
+ **kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  ):
15
  device = input_ids.device
16
+ logits_processor = LogitsProcessorList()
17
+ stopping_criteria = StoppingCriteriaList()
18
+
19
+ if attention_mask is None:
20
+ attention_mask = torch.ones_like(input_ids)
21
+
22
+ # Initialize unfinished sentences to manage the loop for early stopping
23
+ unfinished_sents = input_ids.new(input_ids.shape[0]).fill_(1)
24
+
25
+ cur_len = input_ids.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ while cur_len < max_new_tokens:
28
+ model_outputs = self(
29
+ input_ids=input_ids,
30
+ attention_mask=attention_mask,
31
+ use_cache=True,
32
+ return_dict=True
33
+ )
34
+
35
+ next_token_logits = model_outputs.logits[:, -1, :]
36
+
37
+ # Processing logits to avoid generating undesired tokens
38
+ next_token_logits[:, pad_token_id] = -float('inf') # Never select pad
39
+ next_token_logits[:, eos_token_id] = -float('inf') # Avoid generating end token prematurely
40
+
41
+ # Apply temperature scaling and softmax to generate probabilities
42
+ if do_sample:
43
+ probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
44
+ next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
45
+ else:
46
+ next_token = next_token_logits.argmax(dim=-1)
47
+
48
+ # Update input_ids and attention_mask
49
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
50
+ new_attention = torch.ones_like(input_ids[:, 0:1])
51
+ attention_mask = torch.cat([attention_mask, new_attention], dim=1)
52
+
53
+ # Check unfinished sentences
54
+ unfinished_sents.mul_(next_token.ne(eos_token_id).long())
55
+ if unfinished_sents.max() == 0:
56
+ break
57
+
58
+ cur_len += 1
59
 
60
+ # Optionally return additional information
61
+ if kwargs.get('return_dict_in_generate', False):
62
+ output = {
63
+ "sequences": input_ids,
64
+ "scores": None, # Placeholder for when score calculations are implemented
65
+ "attentions": model_outputs.attentions,
66
+ "hidden_states": model_outputs.hidden_states
67
+ }
68
+ else:
69
+ output = input_ids
70
 
71
  return output