Crystalcareai committed
Commit b983c45 · verified · 1 Parent(s): 5a3b899

Update generate.py

Files changed (1):
generate.py +200 -57
generate.py CHANGED
@@ -1,71 +1,214 @@
 import torch
-from transformers import LogitsProcessorList, StoppingCriteriaList

-def generate(
+def custom_generate(
     self,
     input_ids,
     attention_mask=None,
     max_new_tokens=None,
-    temperature=1.0,
-    do_sample=True,
+    min_length=None,
+    do_sample=None,
+    early_stopping=None,
+    num_beams=None,
+    temperature=None,
+    top_k=None,
+    top_p=None,
+    repetition_penalty=None,
+    bad_words_ids=None,
+    bos_token_id=None,
     pad_token_id=None,
     eos_token_id=None,
-    **kwargs
+    streamer=None,
+    length_penalty=None,
+    no_repeat_ngram_size=None,
+    num_return_sequences=None,
+    decoder_start_token_id=None,
+    use_cache=None,
+    num_beam_groups=None,
+    diversity_penalty=None,
+    prefix_allowed_tokens_fn=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    output_scores=None,
+    return_dict_in_generate=None,
+    forced_bos_token_id=None,
+    forced_eos_token_id=None,
+    remove_invalid_values=None,
+    synced_gpus=None,
+    **kwargs,
 ):
     device = input_ids.device
-    logits_processor = LogitsProcessorList()
-    stopping_criteria = StoppingCriteriaList()
-
-    if attention_mask is None:
-        attention_mask = torch.ones_like(input_ids)
-
-    # Initialize unfinished sentences to manage the loop for early stopping
-    unfinished_sents = input_ids.new(input_ids.shape[0]).fill_(1)
-
-    cur_len = input_ids.shape[1]
+    with torch.no_grad():
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
+
+        if max_new_tokens is None:
+            max_new_tokens = 50  # Default value if not specified
+        for cur_token_idx in range(max_new_tokens):
+            new_ids = self(
+                input_ids[~finished_generating],
+                attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
+                **kwargs
+            )['logits']
+
+            # Mask out the start and end thought tokens so we don't accidentally sample them
+            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+
+            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                # Find the index of the last token that is not padding
+                base_answer_ids = input_ids[answer_idx]
+                new_answer_ids = new_ids[list_idx]
+                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                new_ids_sampled = torch.multinomial(
+                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
+
+                # Assign the new id to the last token
+                if last_token_idx + 1 >= len(base_answer_ids):
+                    # Add padding everywhere
+                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                             device=device)
+                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                    if attention_mask is not None:
+                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+
+                if attention_mask is not None:
+                    attention_mask[answer_idx, last_token_idx + 1] = 1
+                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+
+                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                    finished_generating[answer_idx] = 1
+
+                # Check if the end token is generated
+                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                    finished_generating[answer_idx] = 1
+
+            if finished_generating.all():
+                break
+
+            if streamer is not None:
+                streamer.put(new_ids_sampled)
+
+        from collections import namedtuple
+        GenerateOutput = namedtuple("GenerateOutput", ["sequences", "scores", "attentions", "hidden_states"])

-    while cur_len < max_new_tokens:
-        model_outputs = self(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            use_cache=True,
-            return_dict=True
-        )
-
-        next_token_logits = model_outputs.logits[:, -1, :]
-
-        # Processing logits to avoid generating undesired tokens
-        next_token_logits[:, pad_token_id] = -float('inf')  # Never select pad
-        next_token_logits[:, eos_token_id] = -float('inf')  # Avoid generating end token prematurely
-
-        # Apply temperature scaling and softmax to generate probabilities
-        if do_sample:
-            probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
-            next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
-        else:
-            next_token = next_token_logits.argmax(dim=-1)
-
-        # Update input_ids and attention_mask
-        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
-        new_attention = torch.ones_like(input_ids[:, 0:1])
-        attention_mask = torch.cat([attention_mask, new_attention], dim=1)
-
-        # Check unfinished sentences
-        unfinished_sents.mul_(next_token.ne(eos_token_id).long())
-        if unfinished_sents.max() == 0:
-            break
-
-        cur_len += 1
+        # Convert the generated token IDs to a tensor
+        generated_token_ids_tensor = input_ids

-    # Optionally return additional information
-    if kwargs.get('return_dict_in_generate', False):
-        output = {
-            "sequences": input_ids,
-            "scores": None,  # Placeholder for when score calculations are implemented
-            "attentions": model_outputs.attentions,
-            "hidden_states": model_outputs.hidden_states
-        }
-    else:
-        output = input_ids
+        output = GenerateOutput(
+            sequences=generated_token_ids_tensor,
+            scores=None,
+            attentions=None,
+            hidden_states=None
+        )

     return output
+
+
+def generate(
+    self,
+    input_ids,
+    attention_mask=None,
+    max_new_tokens=None,
+    min_length=None,
+    do_sample=None,
+    early_stopping=None,
+    num_beams=None,
+    temperature=1.1,
+    streamer=None,
+    top_k=None,
+    top_p=None,
+    repetition_penalty=None,
+    bad_words_ids=None,
+    bos_token_id=None,
+    pad_token_id=None,
+    eos_token_id=None,
+    length_penalty=None,
+    no_repeat_ngram_size=None,
+    num_return_sequences=None,
+    decoder_start_token_id=None,
+    use_cache=None,
+    num_beam_groups=None,
+    diversity_penalty=None,
+    prefix_allowed_tokens_fn=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    output_scores=None,
+    return_dict_in_generate=None,
+    forced_bos_token_id=None,
+    forced_eos_token_id=None,
+    remove_invalid_values=None,
+    synced_gpus=None,
+    n_ahead=4,
+    n_ahead_talk=4,
+    merged_talk_heads=True,
+    merged_lm_and_talk_heads=False,
+    merged_lm_and_think_heads=True,
+    use_concat_talk_head=True,
+    use_shallow_think=True,
+    use_shallow_talk=False,
+    use_complex_think_head=False,
+    use_complex_talk_head=True,
+    use_weighted_talk_head=True,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+    **model_kwargs,
+):
+    # Set model attributes
+    self.max_thoughts = n_ahead + n_ahead_talk + 1
+    self.merged_talk_heads = merged_talk_heads
+    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
+    self.merged_lm_and_think_heads = merged_lm_and_think_heads
+    self.use_concat_talk_head = use_concat_talk_head
+    self.use_shallow_think = use_shallow_think
+    self.use_shallow_talk = use_shallow_talk
+    self.use_complex_think_head = use_complex_think_head
+    self.use_complex_talk_head = use_complex_talk_head
+    self.use_weighted_talk_head = use_weighted_talk_head
+
+    # Set model properties
+    self.use_end_thought_token = True
+    self.use_start_thought_token = True
+    self.n_ahead = n_ahead
+    self.n_passes = 1
+    self.eval_mode = True
+    self.first_run = False
+    self.rm_initialized = True
+    self.original_mode = False
+
+    output = custom_generate(
+        self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        max_new_tokens=max_new_tokens,
+        min_length=min_length,
+        do_sample=do_sample,
+        early_stopping=early_stopping,
+        num_beams=num_beams,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        bad_words_ids=bad_words_ids,
+        bos_token_id=bos_token_id,
+        pad_token_id=pad_token_id,
+        eos_token_id=eos_token_id,
+        length_penalty=length_penalty,
+        no_repeat_ngram_size=no_repeat_ngram_size,
+        num_return_sequences=num_return_sequences,
+        decoder_start_token_id=decoder_start_token_id,
+        use_cache=use_cache,
+        num_beam_groups=num_beam_groups,
+        diversity_penalty=diversity_penalty,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_scores=output_scores,
+        return_dict_in_generate=return_dict_in_generate,
+        forced_bos_token_id=forced_bos_token_id,
+        forced_eos_token_id=forced_eos_token_id,
+        remove_invalid_values=remove_invalid_values,
+        synced_gpus=synced_gpus,
+        streamer=streamer,
+        **model_kwargs,
+    )
+
+    return output
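
For reference, here is a minimal sketch of how the patched generate() might be called once the repo is loaded with trust_remote_code=True. The repo id, prompt, and pad-token fallback below are illustrative assumptions, not part of this commit; the one hard requirement visible in the diff is that the model object expose a tokenizer attribute, since custom_generate reads self.tokenizer for padding, thought-token masking, and end-of-sequence checks.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/your-quiet-star-checkpoint"  # hypothetical repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: checkpoint defines no pad token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,      # picks up the generate() shipped with the repo
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer     # custom_generate reads self.tokenizer

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt").to(model.device)
out = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=1.1,             # matches the default in the patched generate()
)
print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))

Two behaviors are worth noting when calling this: the return value is the lightweight GenerateOutput namedtuple, so the token IDs live in out.sequences rather than in out itself, and custom_generate always samples with temperature scaling (do_sample and the other HF-style knobs are accepted for signature compatibility but unused), so temperature is the main sampling control.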