Crystalcareai committed
Commit 54fa971 · verified · 1 Parent(s): 49663aa

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +12 -44
modeling_quiet.py CHANGED
@@ -1334,6 +1334,17 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        max_length: Optional[int] = None,
+        num_return_sequences: Optional[int] = 1,
+        no_repeat_ngram_size: Optional[int] = 2,
+        early_stopping: Optional[bool] = True,
+        num_beams: Optional[int] = 1,
+        temperature: Optional[float] = 1.0,
+        repetition_penalty: Optional[float] = 1.2,
+        length_penalty: Optional[float] = 1.0,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        streamer: Optional[TextStreamer] = None,
     ):
         batch_size, seq_len = input_ids.shape
 
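Note: the keyword arguments added in this hunk mirror the standard transformers generation settings. As a rough usage sketch (not taken from this repo), assuming the usual flow where such settings are passed to GenerationMixin.generate(), which QuietForCausalLM inherits per the hunk header, rather than to the forward signature directly; the checkpoint id and prompt below are placeholders:

# Hypothetical usage sketch: standard transformers generation with the same
# settings as the newly added keyword arguments. The checkpoint id is a
# placeholder, not a real repository.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "your-org/quiet-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
streamer = TextStreamer(tokenizer)  # prints tokens as they are generated

output_ids = model.generate(
    **inputs,
    do_sample=True,
    max_length=64,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    early_stopping=True,
    num_beams=1,
    temperature=1.0,
    repetition_penalty=1.2,
    length_penalty=1.0,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    streamer=streamer,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))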
@@ -2128,50 +2139,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         del start_embedding
         del end_embedding
         torch.cuda.empty_cache()
-
-        if not self.training:
-            # Inference mode
-            if max_length is None:
-                max_length = self.config.max_length
-
-            finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
-            for cur_token_idx in range(max_length):
-                outputs = self.model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_values=past_key_values,
-                    inputs_embeds=inputs_embeds,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                    return_dict=return_dict,
-                )
-                hidden_states = outputs[0]
-                logits = self.lm_head(hidden_states)
-
-                # Mask out the start and end thought tokens
-                logits[:, :, self.start_token_id] = -float("inf")
-                logits[:, :, self.end_token_id] = -float("inf")
-
-                for batch_idx in range(batch_size):
-                    if not finished_generating[batch_idx]:
-                        last_token_idx = (input_ids[batch_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-                        new_id_sampled = torch.multinomial(
-                            torch.nn.functional.softmax(logits[batch_idx, last_token_idx] / temperature, dim=-1), 1
-                        )
-                        if last_token_idx + 1 >= input_ids.shape[1]:
-                            # Add padding
-                            new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device)
-                            input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                            attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-                        attention_mask[batch_idx, last_token_idx + 1] = 1
-                        input_ids[batch_idx, last_token_idx + 1] = new_id_sampled
-                        if new_id_sampled == self.tokenizer.eos_token_id or new_id_sampled == self.tokenizer.bos_token_id or new_id_sampled == self.tokenizer.pad_token_id:
-                            finished_generating[batch_idx] = True
-
-                if finished_generating.all():
-                    break
+
 
         return CausalLMOutputWithPast(
             loss=loss if loss is not None else None,
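For reference, the block removed in this hunk performed token-by-token decoding at inference time: it masked out the start/end thought tokens, applied temperature to the logits at the last non-pad position, and drew the next token with multinomial sampling. A minimal standalone sketch of that sampling step, with illustrative tensor names and token ids rather than anything from this repository:

# Standalone sketch of the sampling step the removed block performed:
# mask the thought-delimiter tokens, apply temperature, then draw one
# token per sequence via multinomial sampling. Names and ids are illustrative.
import torch

def sample_next_token(logits: torch.Tensor,
                      start_token_id: int,
                      end_token_id: int,
                      temperature: float = 1.0) -> torch.Tensor:
    """logits: (batch, vocab) scores for the last position of each sequence."""
    logits = logits.clone()
    # Never emit the thought delimiters during plain decoding.
    logits[:, start_token_id] = -float("inf")
    logits[:, end_token_id] = -float("inf")
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # shape (batch, 1)

# Toy example: vocabulary of 10 tokens where ids 8 and 9 are the
# start/end thought tokens.
toy_logits = torch.randn(2, 10)
next_ids = sample_next_token(toy_logits, start_token_id=8, end_token_id=9)
print(next_ids.shape)  # torch.Size([2, 1])

The batch bookkeeping from the removed loop (extending input_ids and attention_mask with padding, and the finished_generating flags for EOS/BOS/pad) is omitted from this sketch.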
 