Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+12 -44)
@@ -1334,6 +1334,17 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        max_length: Optional[int] = None,
+        num_return_sequences: Optional[int] = 1,
+        no_repeat_ngram_size: Optional[int] = 2,
+        early_stopping: Optional[bool] = True,
+        num_beams: Optional[int] = 1,
+        temperature: Optional[float] = 1.0,
+        repetition_penalty: Optional[float] = 1.2,
+        length_penalty: Optional[float] = 1.0,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        streamer: Optional[TextStreamer] = None,
     ):
         batch_size, seq_len = input_ids.shape

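The eleven keyword arguments added above (which, together with the blank line added in the next hunk, account for the +12 in the diffstat) mirror standard generation-control parameters in transformers; all but streamer correspond to fields of GenerationConfig, with streamer passed to generate() at call time instead. A minimal sketch of how the method body could fold them into a config object, assuming the usual self.config defaults; the helper name build_generation_config is hypothetical and not part of this commit:

from transformers import GenerationConfig

def build_generation_config(self, max_length=None, num_return_sequences=1,
                            no_repeat_ngram_size=2, early_stopping=True,
                            num_beams=1, temperature=1.0, repetition_penalty=1.2,
                            length_penalty=1.0, pad_token_id=None, eos_token_id=None):
    # Hypothetical helper (not in this commit): bundle the new keyword
    # arguments into a GenerationConfig, falling back to the model config
    # when the caller does not override the maximum length or pad/eos ids.
    return GenerationConfig(
        max_length=max_length if max_length is not None else self.config.max_length,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        pad_token_id=pad_token_id if pad_token_id is not None else self.config.pad_token_id,
        eos_token_id=eos_token_id if eos_token_id is not None else self.config.eos_token_id,
    )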
@@ -2128,50 +2139,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         del start_embedding
         del end_embedding
         torch.cuda.empty_cache()
-
-        if not self.training:
-            # Inference mode
-            if max_length is None:
-                max_length = self.config.max_length
-
-            finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
-            for cur_token_idx in range(max_length):
-                outputs = self.model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_values=past_key_values,
-                    inputs_embeds=inputs_embeds,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                    return_dict=return_dict,
-                )
-                hidden_states = outputs[0]
-                logits = self.lm_head(hidden_states)
-
-                # Mask out the start and end thought tokens
-                logits[:, :, self.start_token_id] = -float("inf")
-                logits[:, :, self.end_token_id] = -float("inf")
-
-                for batch_idx in range(batch_size):
-                    if not finished_generating[batch_idx]:
-                        last_token_idx = (input_ids[batch_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-                        new_id_sampled = torch.multinomial(
-                            torch.nn.functional.softmax(logits[batch_idx, last_token_idx] / temperature, dim=-1), 1
-                        )
-                        if last_token_idx + 1 >= input_ids.shape[1]:
-                            # Add padding
-                            new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device)
-                            input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                            attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-                        attention_mask[batch_idx, last_token_idx + 1] = 1
-                        input_ids[batch_idx, last_token_idx + 1] = new_id_sampled
-                        if new_id_sampled == self.tokenizer.eos_token_id or new_id_sampled == self.tokenizer.bos_token_id or new_id_sampled == self.tokenizer.pad_token_id:
-                            finished_generating[batch_idx] = True
-
-                if finished_generating.all():
-                    break
+

         return CausalLMOutputWithPast(
             loss=loss if loss is not None else None,
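With the hand-rolled sampling loop above removed, inference presumably runs through the generate() method that QuietForCausalLM already inherits from GenerationMixin (see the hunk header), which accepts the same keyword arguments that were added to the signature in the first hunk. A minimal usage sketch under that assumption; the checkpoint path and prompt are placeholders, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_path = "path/to/quiet-checkpoint"  # placeholder, not a real repo id
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("Example prompt", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=True,                      # the deleted loop sampled with torch.multinomial
    max_length=256,
    temperature=1.0,
    repetition_penalty=1.2,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    streamer=TextStreamer(tokenizer),    # streams decoded tokens as they are produced
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))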