Crystalcareai committed
Commit ea46013 · verified · 1 Parent(s): de8dd38

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +22 -59
modeling_quiet.py CHANGED
@@ -1891,12 +1891,16 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             torch.cuda.empty_cache()


-        return CausalLMOutputWithPast(
-            loss=loss if loss is not None else None,
-            logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+        return self.infer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )

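The hunk above swaps the hand-built CausalLMOutputWithPast return at the end of forward() for a call to self.infer(), which is defined elsewhere in modeling_quiet.py. The sketch below is only an illustration of that delegation pattern, not the real QuietForCausalLM code: TinyCausalLM, its layer sizes, and its trivial infer() body are invented for the example; only the forward-delegates-to-infer shape mirrors the diff.

# Hypothetical sketch -- not the actual QuietForCausalLM implementation.
# It shows the delegation pattern from the hunk above: forward() no longer
# assembles CausalLMOutputWithPast itself and instead hands everything to infer().
import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast


class TinyCausalLM(nn.Module):
    """Toy stand-in for the real model; vocab/hidden sizes are invented."""

    def __init__(self, vocab_size=32, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # After the change, forward() just forwards its arguments to infer().
        return self.infer(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    def infer(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
        # infer() owns the decoding logic and wraps the result in the standard
        # transformers output class, as the pre-change forward() used to do inline.
        hidden_states = self.embed(input_ids)
        logits = self.lm_head(hidden_states)
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


out = TinyCausalLM()(torch.tensor([[1, 2, 3]]))
print(out.logits.shape)  # torch.Size([1, 3, 32])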
 
@@ -1904,59 +1908,18 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
-            #     input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "inputs_embeds": inputs_embeds,
+        }

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
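The replacement prepare_inputs_for_generation is much simpler: it synthesizes an all-ones attention mask when none is given, keeps only the last token once a KV cache exists, and returns a plain dict. The snippet below lifts that body into a standalone function (dropping self, with a dummy non-empty tuple standing in for a real cache) purely to show how the returned inputs change between the first and later generation steps.

import torch


def prepare_inputs_for_generation(input_ids, past_key_values=None,
                                  attention_mask=None, inputs_embeds=None, **kwargs):
    # Same logic as the new method in the hunk above, outside the class for demonstration.
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)  # attend to every provided token

    if past_key_values:
        input_ids = input_ids[:, -1:]  # with a cache, only the newest token still needs processing

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "inputs_embeds": inputs_embeds,
    }


prompt = torch.tensor([[5, 6, 7]])

first = prepare_inputs_for_generation(prompt)                                  # step 0: no cache yet
print(first["input_ids"].shape)        # torch.Size([1, 3])

later = prepare_inputs_for_generation(prompt, past_key_values=(("k", "v"),))  # dummy non-empty cache
print(later["input_ids"].shape)        # torch.Size([1, 1])
print(later["attention_mask"].shape)   # torch.Size([1, 3])

Compared with the removed version, the new method no longer derives position_ids, special-cases Cache objects, or crops the attention mask against a maximum cache length; the diff alone does not show those responsibilities being handled elsewhere.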
 