|
from dataclasses import dataclass, field
import inspect
import logging
from typing import Optional, List, Union, Dict, Tuple, Any

from transformers.configuration_utils import PretrainedConfig

import mlx.core as mx
|
|
|
|
|
|
|
class FloatTensor:
    """Minimal wrapper holding an `mx.array` with `float32` dtype."""

    def __init__(self, data=None):
        if data is not None:
            self.tensor = mx.array(data, dtype=mx.float32)
        else:
            self.tensor = None

    def __repr__(self):
        return repr(self.tensor)


class LongTensor:
    """Minimal wrapper holding an `mx.array` with `int64` dtype."""

    def __init__(self, data=None):
        if data is not None:
            self.tensor = mx.array(data, dtype=mx.int64)
        else:
            self.tensor = None

    def assign(self, data):
        self.tensor = mx.array(data, dtype=mx.int64)

    def __repr__(self):
        return repr(self.tensor)
|
|
|
@dataclass
class BaseModelOutputWithPast:
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`mx.array` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(mx.array))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(mx.array)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`, and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks, and optionally, if
            `config.is_encoder_decoder=True`, in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(mx.array)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `mx.array` (one for the output of the embeddings, if the model has an embedding layer, plus
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(mx.array)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `mx.array` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None
    hidden_states: Optional[Tuple[FloatTensor, ...]] = None
    attentions: Optional[Tuple[FloatTensor, ...]] = None
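
# A minimal construction sketch (hypothetical shapes, not part of the original docstring):
#
#   hidden = mx.zeros((1, 5, 768))                           # (batch, seq_len, hidden_size)
#   out = BaseModelOutputWithPast(last_hidden_state=hidden)
#   out.past_key_values                                      # None unless use_cache=True
#
# The optional tuples are populated only when the corresponding output flags are set.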
|
|
|
|
|
@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def update(
        self,
        key_states: mx.array,
        value_states: mx.array,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`mx.array`):
                The new key states to cache.
            value_states (`mx.array`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new
                types of cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without a size limit -> the whole cache is usable.
        # Cache with a size limit -> if the cached length plus the new input length exceeds the maximum
        # length, part of the cache will be evicted, so not all of it is usable.
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length
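
    # Worked example (hypothetical numbers): with get_max_length() == 8,
    # get_seq_length(layer_idx) == 6 and new_seq_length == 4, the cache would
    # overflow (6 + 4 > 8), so only 8 - 4 = 4 cached positions remain usable;
    # with no overflow (or no maximum length), the full previous length is usable.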
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    @property
    def seen_tokens(self):
        logging.warning(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
|
|
|
|
|
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[mx.array] = []
        self.value_cache: List[mx.array] = []
        self._seen_tokens = 0  # Number of tokens seen by the cache, tracked via layer-0 updates.

    def __getitem__(self, layer_idx: int) -> Tuple[mx.array, mx.array]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: mx.array,
        value_states: mx.array,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`mx.array`):
                The new key states to cache.
            value_states (`mx.array`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens (counted once per forward pass, on layer 0).
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache: append a new layer entry, or concatenate along the sequence axis.
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = mx.concatenate([self.key_cache[layer_idx], key_states], axis=-2)
            self.value_cache[layer_idx] = mx.concatenate([self.value_cache[layer_idx], value_states], axis=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
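
    # Illustrative shapes (hypothetical): after update(k, v, layer_idx=0) with k and v of shape
    # (1, num_heads, 5, head_dim), followed by a second call with sequence length 1, the layer-0
    # entries have shape (1, num_heads, 6, head_dim) and get_seq_length(0) == 6.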
|
|
|
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[mx.array], Tuple[mx.array]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
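
# A usage sketch (hypothetical shapes; `num_layers` and `compute_new_kv` are illustrative names):
#
#   cache = DynamicCache()
#   for layer_idx in range(num_layers):
#       k, v = compute_new_kv(layer_idx)        # each of shape (batch, num_heads, seq_len, head_dim)
#       keys, values = cache.update(k, v, layer_idx)
#   legacy = cache.to_legacy_cache()            # tuple of (key, value) pairs, one per layer
#   cache2 = DynamicCache.from_legacy_cache(legacy)
#
# Note that `from_legacy_cache` replays `update`, so `cache2._seen_tokens` reflects the layer-0
# sequence length of the legacy cache.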
|
|
|
|
|
@dataclass
class CausalLMOutputWithPast:
    """
    Base class for causal language model (or autoregressive) outputs, mirroring `BaseModelOutputWithPast`
    with additional `loss` and `logits` fields.
    """

    loss: Optional[FloatTensor] = None
    logits: FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None
    hidden_states: Optional[Tuple[FloatTensor, ...]] = None
    attentions: Optional[Tuple[FloatTensor, ...]] = None
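
# A minimal sketch of how a causal LM forward pass might populate this output
# (hypothetical shapes; `vocab_size` and `cache` are illustrative names):
#
#   CausalLMOutputWithPast(
#       loss=None,                                   # set only when labels are provided
#       logits=mx.zeros((1, 5, vocab_size)),         # (batch, seq_len, vocab_size)
#       past_key_values=cache.to_legacy_cache(),     # or a Cache instance, depending on the caller
#   )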