Crystalcareai committed (verified)
Commit d00c49d · Parent(s): 890bc4c

Update generate.py

Files changed (1): generate.py (+48 -58)
generate.py CHANGED
@@ -45,56 +45,55 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
-    with torch.no_grad():
-        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-
-        while not finished_generating.all() and input_ids.shape[1] < max_length:
-            # Sample the next token
-            new_ids = self(
-                input_ids[~finished_generating],
-                attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
-                **kwargs
-            )['logits']
-
-            # Mask out the start and end thought tokens so we don't accidentally sample them
-            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
-            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                # Find the index of the last token that is not padding
-                base_answer_ids = input_ids[answer_idx]
-                new_answer_ids = new_ids[list_idx]
-                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-                new_ids_sampled = torch.multinomial(
-                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
-
-                # Assign the new id to the last token
-                if last_token_idx + 1 >= len(base_answer_ids):
-                    # Add padding everywhere
-                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=input_ids.device)
-                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                    if attention_mask is not None:
-                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
-                if attention_mask is not None:
-                    attention_mask[answer_idx, last_token_idx + 1] = 1
-                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-                    finished_generating[answer_idx] = 1
-
-                # Check if the end token is generated
-                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
-                    finished_generating[answer_idx] = 1
-
-                if streamer is not None:
-                    streamer.put(new_ids_sampled)
-
-        generated_token_ids = input_ids.tolist()
-
-        return generated_token_ids
+    with torch.no_grad():
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+
+        while not finished_generating.all() and input_ids.shape[1] < max_length:
+            # Sample the next token
+            new_ids = self(
+                input_ids[~finished_generating],
+                attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
+                **kwargs
+            )['logits']
+
+            # Mask out the start and end thought tokens so we don't accidentally sample them
+            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+
+            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                # Find the index of the last token that is not padding
+                base_answer_ids = input_ids[answer_idx]
+                new_answer_ids = new_ids[list_idx]
+                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                new_ids_sampled = torch.multinomial(
+                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
+
+                # Assign the new id to the last token
+                if last_token_idx + 1 >= len(base_answer_ids):
+                    # Add padding everywhere
+                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                             device=input_ids.device)
+                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                    if attention_mask is not None:
+                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+
+                if attention_mask is not None:
+                    attention_mask[answer_idx, last_token_idx + 1] = 1
+                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+
+                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                    finished_generating[answer_idx] = 1
+
+                # Check if the end token is generated
+                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                    finished_generating[answer_idx] = 1
+
+                if streamer is not None:
+                    streamer.put(new_ids_sampled)
+
+        generated_token_ids = input_ids.tolist()
+
+        return generated_token_ids
 
 
-
 def generate(
@@ -158,15 +157,6 @@ def generate(
     self.use_complex_talk_head = use_complex_talk_head
     self.use_weighted_talk_head = use_weighted_talk_head
 
-    # Set model properties
-    self.use_end_thought_token = True
-    self.use_start_thought_token = True
-    self.n_ahead = n_ahead
-    self.n_passes = 1
-    self.eval_mode = True
-    self.first_run = False
-    self.rm_initialized = True
-    self.original_mode = False
 
     # Generate using the custom generate function
     generated_token_ids = custom_generate(
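In short: custom_generate now takes temperature as an explicit parameter instead of reading kwargs.get("temperature", 1.0), which also keeps a caller-supplied temperature from leaking into the self(...) forward call through **kwargs; the extra stop check matches "</s>" instead of "<|/assistant|>"; and generate no longer force-overrides the model's thought-token and eval-mode attributes. The sampling step itself is unchanged: logits above tokenizer.vocab_size are masked to -inf (so thought tokens are never drawn), then the last non-padding position is scaled by 1/temperature, softmaxed, and sampled with torch.multinomial. A standalone sketch of that step, with an invented vocabulary size:

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # logits: (vocab_size,) scores at the last non-padding position of one sequence.
    # Dividing by temperature < 1.0 sharpens the distribution; > 1.0 flattens it.
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    # Draw one token id from the resulting categorical distribution.
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(32000)  # toy vocabulary size, not taken from the repo
print(sample_next_token(logits, temperature=0.7))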
 
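Both versions also share the ragged-batch bookkeeping around that step: a boolean finished_generating mask selects the rows that still need tokens, each row's write position is one past its last non-padding token, and when that position would fall off the end of the tensor a padding column is appended to the whole batch so it stays rectangular. A toy illustration of those three moves (all values invented):

import torch

pad_id = 0
input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
finished = torch.tensor([True, False])

# 1. Only rows that are still generating go through the forward pass.
active_rows = input_ids[~finished]  # tensor([[ 8,  9, 10]])

# 2. The write position is one past the last non-padding token of a row.
last_idx = (input_ids[1] != pad_id).nonzero(as_tuple=True)[0].max()  # tensor(2)

# 3. If that position falls off the end, grow every row by one pad column;
#    the attention mask (not shown here) marks the new slots as inactive.
if last_idx + 1 >= input_ids.shape[1]:
    pad_col = torch.full((input_ids.size(0), 1), pad_id, dtype=torch.long)
    input_ids = torch.cat([input_ids, pad_col], dim=-1)

print(active_rows.shape, last_idx.item(), input_ids.shape)  # (1, 3) 2 (2, 4)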
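For callers, the practical upshot is that temperature must now reach custom_generate as a real argument, and that generation stops on "</s>", EOS, BOS, or pad. A hypothetical invocation sketch: the repo id and prompt are placeholders, the wiring of generate onto the model through trust_remote_code is assumed rather than shown in this diff, and the final indexing relies on custom_generate returning plain Python lists (input_ids.tolist()):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Crystalcareai/<model>"  # placeholder id, not taken from this commit
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("Q: What is 7 * 8?\nA:", return_tensors="pt")
token_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=128,
    temperature=0.7,  # explicit parameter after this commit, not a kwargs fallback
)
# custom_generate returns input_ids.tolist(), i.e. a list of token-id lists
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))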