Update modeling_quiet.py
modeling_quiet.py  CHANGED  +22 -59
@@ -1891,12 +1891,16 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         torch.cuda.empty_cache()


-        return
-
-
-
-
-
+        return self.infer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )


@@ -1904,59 +1908,18 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "inputs_embeds": inputs_embeds,
+        }

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
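Taken together, the commit appears to do two things. In forward(), the inline construction of the return value is replaced by a straight delegation to self.infer(...), passing through the usual arguments (input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict). And prepare_inputs_for_generation() drops the stock cache-length bookkeeping (the past_length / max_cache_length cropping and the on-the-fly position_ids) in favour of a much simpler rule: default the attention mask to all-ones and, once a KV cache exists, feed only the most recent token. The simplified version also no longer returns position_ids or use_cache, so those presumably have to be derived inside self.infer. Below is a minimal, self-contained sketch of the new input-preparation behaviour; the free-standing prepare_inputs helper and the toy cache object are ours, not part of the repository.

    # Sketch of the simplified prepare_inputs_for_generation logic from the diff,
    # pulled out as a free function so it can be exercised without loading the model.
    # The dictionary keys and argument names mirror the diff; everything else is illustrative.
    import torch


    def prepare_inputs(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None):
        # Default to attending over every token when no mask is supplied.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # Once a KV cache exists, only the newest token needs to be fed forward.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
        }


    # First decoding step: no cache yet, so the full prompt is passed through.
    prompt = torch.tensor([[1, 42, 7, 9]])
    step0 = prepare_inputs(prompt)
    assert step0["input_ids"].shape == (1, 4)

    # Later steps: with a (placeholder) cache present, only the last token is kept.
    step1 = prepare_inputs(torch.tensor([[1, 42, 7, 9, 13]]), past_key_values=[("k", "v")])
    assert step1["input_ids"].shape == (1, 1)

Cropping to the last token is the older, simpler convention used by many Hugging Face causal-LM implementations; the trade-off is that the max-cache-length cropping and position-id handling removed here must now happen elsewhere if they are still needed.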