Crystalcareai committed
Update generate.py

generate.py  +48 -58  CHANGED
@@ -45,56 +45,55 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
-
-
-
-    while not finished_generating.all() and input_ids.shape[1] < max_length:
-        # Sample the next token
-        new_ids = self(
-            input_ids[~finished_generating],
-            attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
-            **kwargs
-        )['logits']
-
-        # Mask out the start and end thought tokens so we don't accidentally sample them
-        new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
-        for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-            # Find the index of the last token that is not padding
-            base_answer_ids = input_ids[answer_idx]
-            new_answer_ids = new_ids[list_idx]
-            last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-            new_ids_sampled = torch.multinomial(
-                torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)
-
-            # Assign the new id to the last token
-            if last_token_idx + 1 >= len(base_answer_ids):
-                # Add padding everywhere
-                new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                         device=input_ids.device)
-                input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                if attention_mask is not None:
-                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
-            if attention_mask is not None:
-                attention_mask[answer_idx, last_token_idx + 1] = 1
-            input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-            if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-                finished_generating[answer_idx] = 1
-
-            # Check if the end token is generated
-            if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
-                finished_generating[answer_idx] = 1
-
-            if streamer is not None:
-                streamer.put(new_ids_sampled)
-
-    generated_token_ids = input_ids.tolist()
-
-    return generated_token_ids
 
 def generate(
@@ -158,15 +157,6 @@ def generate(
     self.use_complex_talk_head = use_complex_talk_head
     self.use_weighted_talk_head = use_weighted_talk_head
 
-    # Set model properties
-    self.use_end_thought_token = True
-    self.use_start_thought_token = True
-    self.n_ahead = n_ahead
-    self.n_passes = 1
-    self.eval_mode = True
-    self.first_run = False
-    self.rm_initialized = True
-    self.original_mode = False
 
     # Generate using the custom generate function
     generated_token_ids = custom_generate(
@@ -45,56 +45,55 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
+    with torch.no_grad():
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+
+        while not finished_generating.all() and input_ids.shape[1] < max_length:
+            # Sample the next token
+            new_ids = self(
+                input_ids[~finished_generating],
+                attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
+                **kwargs
+            )['logits']
+
+            # Mask out the start and end thought tokens so we don't accidentally sample them
+            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+
+            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                # Find the index of the last token that is not padding
+                base_answer_ids = input_ids[answer_idx]
+                new_answer_ids = new_ids[list_idx]
+                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                new_ids_sampled = torch.multinomial(
+                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
+
+                # Assign the new id to the last token
+                if last_token_idx + 1 >= len(base_answer_ids):
+                    # Add padding everywhere
+                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                             device=input_ids.device)
+                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                    if attention_mask is not None:
+                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+
+                if attention_mask is not None:
+                    attention_mask[answer_idx, last_token_idx + 1] = 1
+                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+
+                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                    finished_generating[answer_idx] = 1
+
+                # Check if the end token is generated
+                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                    finished_generating[answer_idx] = 1
+
+                if streamer is not None:
+                    streamer.put(new_ids_sampled)
+
+        generated_token_ids = input_ids.tolist()
+
+        return generated_token_ids
 
 def generate(
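For orientation, a minimal sketch of driving the rewritten loop directly. The positional model argument, the keyword names, and the checkpoint id are assumptions inferred from how the body uses self, input_ids, attention_mask, max_length, temperature, and streamer; the commit itself does not show the full signature.

# Hypothetical usage sketch, not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from generate import custom_generate  # this file, assuming it is importable as a module

tokenizer = AutoTokenizer.from_pretrained("placeholder/checkpoint")  # placeholder id
model = AutoModelForCausalLM.from_pretrained("placeholder/checkpoint", trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # the loop relies on pad_token_id being set
model.tokenizer = tokenizer  # the loop reads self.tokenizer for pad/eos ids and vocab_size

batch = tokenizer(["2 + 2 ="], return_tensors="pt", padding=True)
token_ids = custom_generate(
    model,                                 # plays the role of `self` inside the loop
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_length=64,                         # hard stop for the while loop
    temperature=0.9,                       # now a direct argument (was kwargs.get("temperature", 1.0))
    streamer=None,
)
print(tokenizer.batch_decode(token_ids, skip_special_tokens=True))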
@@ -158,15 +157,6 @@ def generate(
     self.use_complex_talk_head = use_complex_talk_head
     self.use_weighted_talk_head = use_weighted_talk_head
 
 
     # Generate using the custom generate function
     generated_token_ids = custom_generate(