diff --git "a/modeling_mixtral.py" "b/modeling_mixtral.py" new file mode 100644--- /dev/null +++ "b/modeling_mixtral.py" @@ -0,0 +1,3570 @@ +""" PyTorch Mixtral model.""" +import importlib +import inspect +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import scattermoe +import stk +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.checkpoint +from megablocks import grouped_gemm_util as gg +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments as MegablocksArguments +from megablocks.layers.dmlp_registry import _REGISTRY +from megablocks.layers.dmoe import ParallelDroplessMLP +from megablocks.layers.glu import memory_optimized_grouped_glu +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import ( + BeamSearchScorer, + ConstrainedBeamSearchScorer, + DisjunctiveConstraint, + LogitsProcessorList, + PhrasalConstraint, + QuantizedCacheConfig, + StoppingCriteriaList, +) +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation.configuration_utils import GenerationConfig, GenerationMode +from transformers.generation.utils import ( + NEED_SETUP_CACHE_CLASSES_MAPPING, + QUANT_BACKEND_CLASSES_MAPPING, + GenerateOutput, +) +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import is_torch_available, logging +from transformers.utils.import_utils import ( + is_hqq_available, + is_quanto_available, + is_torch_fx_available, + is_torchdynamo_compiling, +) + +from smoe.utils.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + +from .configuration_mixtral import MixtralConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + +parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) +is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") + + +def _is_package_available( + pkg_name: str, return_version: bool = False +) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib.metadata.version(pkg_name) + package_exists = True + except importlib.metadata.PackageNotFoundError: + package_exists = False + logger.debug(f"Detected {pkg_name} version {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +def is_flash_attn_2_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn"): + return False + + # Let's add an extra check to see if cuda is available + import torch + + if not torch.cuda.is_available(): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.1.0" + ) + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return 
version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.0.4" + ) + else: + return False + + +def is_flash_attn_greater_or_equal_2_10(): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse( + "2.1.0" + ) + + +def is_flash_attn_available(): + logger.warning( + "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. " + "Please use `is_flash_attn_2_available` instead." + ) + return is_flash_attn_2_available() + + +@dataclass +class MoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + @property + def balance_loss(self): + return self.aux_loss + + @property + def num_dropped_tokens(self): + return [torch.tensor(-1)] * 32 + + @property + def gate_load(self): + return [torch.tensor(-1)] * 32 + + @property + def gate_importance(self): + return [torch.tensor(-1)] * 32 + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. 
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + attn_router_logits: Optional[Tuple[torch.FloatTensor]] = None # 🔍 + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_experts: torch.Tensor = None, + top_k=2, + use_layer_wise_balance=False, +) -> torch.FloatTensor: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts]. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or (isinstance(gate_logits, Iterable) and len(gate_logits) == 0): + return 0 + + # ✨ Here is the fix for balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions. + if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + # For a given token, determine if it was routed to a given expert. 
+ expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1) + ) * (num_experts**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
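# Shapes for reference (illustrative, e.g. dim=128 and max_position_embeddings=2048): `inv_freq`
# has shape (dim // 2,) = (64,), and the call below precomputes `cos_cached` / `sin_cached` of
# shape (2048, 128). `forward` then simply slices the first `seq_len` rows and re-extends the
# cache only when a longer sequence is seen.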
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# fmt: off +# 🔍 Modified from DynamicCache +class MoECache(Cache): + """ + Modified from the `DynamicCache`!!! + A cache that grows dynamically as more tokens are generated. + This cache adds extra support for Attention MoE. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. 
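    A minimal usage sketch (the tensors, indices, and call order below are illustrative only,
    not the exact sequence used by the model):

        cache = MoECache(num_experts=4)
        # append this step's K/V for expert 2 of layer 0
        key, value = cache.update(new_keys, new_values, layer_idx=0, expert_idx=2)
        # bookkeeping used by `generate`: count each new input token once, globally
        cache.add_seen_tokens_total(new_token_num=num_new_tokens)
        cached_len = cache.get_seq_length(layer_idx=0, expert_idx=2)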
+ """ + + def __init__(self, num_experts: int) -> None: + # 🔍 multi-experts support + self.num_experts = num_experts + self.key_cache: List[Dict[int, torch.Tensor]] = [{} for _ in range(num_experts)] + self.value_cache: List[Dict[int, torch.Tensor]] = [{} for _ in range(num_experts)] + self._seen_tokens: List[Dict[int, int]] = [{} for _ in range(num_experts)] # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens_total = 0 # 🔍 the total number of individual tokens that at least one expert has seen, this is for `get_seq_length` globally + + self.attention_mask_cache: List[Dict[int, torch.BoolTensor]] = [{} for _ in range(num_experts)] # 🔍 this is a new cache for attention mask that records the state of previous tokens + + def __getitem__(self, layer_idx: int, expert_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + if expert_idx < self.num_experts: # 🔍 + return (self.key_cache[expert_idx][layer_idx], self.value_cache[expert_idx][layer_idx]) + else: # 🔍 + raise KeyError(f"Cache only has {self.num_experts} experts, attempted to access expert with index {expert_idx}") + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + for expert_idx in range(self.num_experts): # 🔍 + if layer_idx in self.key_cache[expert_idx]: + yield (self.key_cache[expert_idx][layer_idx], self.value_cache[expert_idx][layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + all_index_list = [key for i in range(self.num_experts) for key in self.key_cache[i].keys()] + + if len(all_index_list) == 0: + return 0 + else: + return max(all_index_list) + 1 # 🔍 the maximum layer index among all experts + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + expert_idx: int, # 🔍 + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + expert_idx (`int`): + 🔍 The index of the expert to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `MoECache`. + + Return: + A tuple containing the updated key and value states. 
+ """ + if layer_idx not in self._seen_tokens[expert_idx]: # 🔍 + # Update the number of seen tokens + self._seen_tokens[expert_idx][layer_idx] = key_states.shape[-2] + # Update the cache + self.key_cache[expert_idx][layer_idx] = key_states + self.value_cache[expert_idx][layer_idx] = value_states + + else: # 🔍 + # Update the number of seen tokens + self._seen_tokens[expert_idx][layer_idx] += key_states.shape[-2] + # Update the cache + self.key_cache[expert_idx][layer_idx] = torch.cat([self.key_cache[expert_idx][layer_idx], key_states], dim=-2) + self.value_cache[expert_idx][layer_idx] = torch.cat([self.value_cache[expert_idx][layer_idx], value_states], dim=-2) + + return self.key_cache[expert_idx][layer_idx], self.value_cache[expert_idx][layer_idx] + + def add_seen_tokens_total(self, new_token_num: int = 0) -> None: + """🔍 Add the number of new tokens to the total number of seen tokens.""" + # THIS FUNCTION IS EXCLUSIVE FOR `MoECache`! + self._seen_tokens_total += new_token_num + + def get_seq_length(self, layer_idx: Optional[int] = None, expert_idx: Optional[int] = None) -> Union[List[List[int]], int]: # 🔍 + """Returns the sequence length of the cached states. A layer & expert index can be optionally passed.""" + if layer_idx is not None and expert_idx is not None: # 🔍 return the length for specific layer & expert + if self.num_experts <= expert_idx or layer_idx not in self.key_cache[expert_idx]: # 🔍 + return 0 + else: + return self.key_cache[expert_idx][layer_idx].shape[-2] + + else: # 🔍 return the total number of individual tokens the cache has seen + return self._seen_tokens_total + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. MoECache does not have a maximum length.""" + return None + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = None, expert_idx: Optional[int] = None) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx, expert_idx) # 🔍 + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + # TODO: support for beam search + print("MoECache, reorder_cache", beam_idx) + raise NotImplementedError + + # for layer_idx in range(len(self.key_cache)): + # device = self.key_cache[layer_idx].device + # self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + # device = self.value_cache[layer_idx].device + # self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "MoECache": + raise NotImplementedError + + def update_attention_mask( + self, + new_attention_mask: torch.BoolTensor, + layer_idx: int, + expert_idx: int, + ) -> torch.BoolTensor: + """ + 🔍 Updates the attention mask cache with the new 
`attention_mask`. + + Parameters: + new_attention_mask (`torch.Tensor`): + The new key states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + expert_idx (`int`): + The index of the expert to cache the states for. + + Return: + A tensor containing the updated attention_mask. + """ + # Update the cache + if layer_idx not in self.attention_mask_cache[expert_idx]: # 🔍 no attention mask cached, this is the first stroke + self.attention_mask_cache[expert_idx][layer_idx] = new_attention_mask + else: # 🔍 concatenate along the seq_len dim + self.attention_mask_cache[expert_idx][layer_idx] = torch.cat([self.attention_mask_cache[expert_idx][layer_idx], new_attention_mask], dim=-1) + + return self.attention_mask_cache[expert_idx][layer_idx] + + +# 🔍 Modified from MixtralAttention +class MixtralAttentionMoE(MixtralAttention): + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super(MixtralAttention, self).__init__() # 🔍 init using nn.Module + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + # 🔍 + self.softmax = nn.Softmax(dim=-1) + self.top_k_attn = config.top_k_attn + self.attn_experts = config.attn_experts + self.scale_factor_attn = config.scale_factor_attn + + self.split_ratio = self.attn_experts // self.num_key_value_heads + + self.gate = nn.Linear(self.hidden_size, self.attn_experts, bias=False) + + # 🔍 + self.q_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_key_value_groups * self.head_dim // self.split_ratio, bias=False) for _ in range(self.attn_experts)]) + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.attn_experts)]) + self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.attn_experts)]) + self.o_proj = nn.ModuleList([nn.Linear(self.num_key_value_groups * self.head_dim // self.split_ratio, self.hidden_size, bias=config.add_rescale_bias) for _ in range(self.attn_experts)]) # 🔍 (may add bias for rescaling) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, # 🔍 This should be the Tensor with shape(bsz, seqlen) that represents the padding mask + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[MoECache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + if past_key_value is not None and not isinstance(past_key_value, MoECache): # 🔍 type check + raise TypeError( + "`past_key_value` must be a `MoECache` instance for attention MoE!" + ) + # print("attention_mask", attention_mask, attention_mask.shape) + device = hidden_states.device + dtype = hidden_states.dtype + bsz, q_len, hidden_dim = hidden_states.size() + hidden_states = hidden_states.reshape(-1, hidden_dim) # 🔍 flatten the dim + + # 🔍 topk gating + router_logits = self.gate(hidden_states) # (bsz * q_len, num_key_value_heads) + scores = F.softmax(router_logits, dim=1, dtype=torch.float) + + routing_weights, selected_experts = torch.topk(scores, self.top_k_attn, dim=-1) # (bsz * q_len, top_k_attn) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(dtype) # we cast back to the input dtype + + # 🔍 moe selection + final_attn_output = torch.zeros_like(hidden_states).reshape(-1, hidden_dim) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.attn_experts) # (bsz * q_len, top_k_attn, num_key_value_heads) + expert_mask = expert_mask.permute(2, 1, 0) # (num_key_value_heads, top_k_attn, bsz * q_len) + + # Loop over all available experts in the model and perform the computation on each expert + all_attn_weights = [] if output_attentions else None + for expert_idx in range(self.attn_experts): + # expert_mask[expert_idx]: (top_k_attn, bsz * q_len) + # idx: the topk position. (selected_num) + # top_x: token index. 
(selected_num) + idx, top_x = torch.nonzero(expert_mask[expert_idx], as_tuple=True) + + if top_x.shape[0] == 0 and not self.training: # skip during training will lead to asynchrony among different GPUs and blocks the training! + if output_attentions: + all_attn_weights.append(None) + continue + + # 🔍 Comment (DDZ): This is useless and even lags the speed, so I get it removed. + # in torch it is faster to index using lists than torch tensors + # top_x_list = top_x.tolist() + # idx_list = idx.tolist() + + # 🔍 get routing info for this expert + current_batch_ids = (top_x // q_len) # batch ids for current_state, (selected_num) + each_batch_selected_token_num = torch.bincount(current_batch_ids, minlength=bsz) # (bsz) + this_q_len = each_batch_selected_token_num.max().item() + + # 🔍 get the indices of each token in the hidden_state of this expert + selection_mask = torch.zeros((bsz * q_len,), device=device, dtype=torch.bool) # the selection mask for this expert (this helps specify the position to put for each token) + selection_mask[top_x] = True + selection_mask = selection_mask.reshape(bsz, q_len) + + token_position_indices = torch.cumsum(selection_mask, dim=1) - 1 # the sequence ids of all tokens in the current state, (bsz, q_len) + token_position_indices = token_position_indices.flatten() + + current_seq_ids = token_position_indices[top_x] # sequence ids for current_state, (selected_num) + + # 🔍 initialize hidden_states for this expert + current_state = torch.zeros((bsz, this_q_len, hidden_dim), dtype=dtype, device=device) + current_state[current_batch_ids, current_seq_ids] = hidden_states[top_x] # assign tokens sparsely + + # Normal Attention Forward + # ---------------------------------------------- # + query_states = self.q_proj[expert_idx](current_state) # 🔍 specify expert + key_states = self.k_proj[expert_idx](current_state) # 🔍 specify expert + value_states = self.v_proj[expert_idx](current_state) # 🔍 specify expert + + query_states = query_states.view(bsz, this_q_len, self.num_key_value_groups // self.split_ratio, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups + key_states = key_states.view(bsz, this_q_len, 1, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_key_value_heads -> 1 + value_states = value_states.view(bsz, this_q_len, 1, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_key_value_heads -> 1 + + past_key_values_length = 0 + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + past_key_values_length = past_key_value.get_usable_length(kv_seq_len, self.layer_idx, expert_idx) # 🔍 specify expert index + kv_seq_len += past_key_values_length + + # 🔍 create position_ids for selected tokens + current_position_ids = torch.zeros((bsz, this_q_len), device=device, dtype=torch.long) + current_position_ids[current_batch_ids, current_seq_ids] = position_ids.expand(bsz, q_len).flatten()[top_x] + + if top_x.shape[0] > 0: # apply only when there are tokens + cos, sin = self.rotary_emb(value_states, seq_len=current_position_ids.max().item() + 1) # 🔍 adjust the seq_len to the maximum possible value + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, current_position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, expert_idx, cache_kwargs) # 🔍 specify expert index + + # repeat k/v heads if n_kv_heads < n_heads + # Note (DDZ): here the dim is expanded internally, rather than concat-repeat. (Disable for Attention MoE) + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) # softmax temperature + + if attn_weights.size() != (bsz, self.num_key_value_groups // self.split_ratio, this_q_len, kv_seq_len): # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups + raise ValueError(f"Attention weights should be of size {(bsz, self.num_key_value_groups // self.split_ratio, this_q_len, kv_seq_len)}, but is {attn_weights.size()}") + + # 🔍 create `current_attention_mask` with reduced `seq_len` + # Notice that the `attention_mask` is passed intact during both training & generation, so we need to adjust the `top_x` by `past_key_values_length`. + # However, we don't have the routing information of previous tokens, which makes it impossible to create `current_attention_mask` for previous tokens. + # So here we need an extra "attention mask cache" to record the `attention_mask` for previous tokens, and update for new tokens accordingly during generation. 
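# In short (illustrative summary of the steps below): build a (bsz, this_q_len) boolean mask
# covering only the tokens routed to this expert at this step, merge it with the expert's cached
# mask history via `past_key_value.update_attention_mask(...)`, and finally expand it into the
# usual 4D causal mask so that its last dimension matches this expert's kv_seq_len.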
+ current_attention_mask = torch.zeros((bsz, this_q_len), dtype=torch.bool, device=device) + + if attention_mask is not None: + if past_key_values_length > 0: # 🔍 we need to exclude previous tokens + previous_seen_tokens_total = past_key_value._seen_tokens_total - q_len + temp_attention_mask = attention_mask[:, previous_seen_tokens_total:].flatten() # select along dimension 1 so that we get tokens in this iteration + else: + temp_attention_mask = attention_mask.flatten() # flatten the dim + current_attention_mask[current_batch_ids, current_seq_ids] = temp_attention_mask[top_x] # assign masks sparsely + + else: + current_attention_mask[current_batch_ids, current_seq_ids] = True # assign masks sparsely + + # print("current_attention_mask", current_attention_mask, current_attention_mask.shape) + if past_key_value is not None: # 🔍 we need to update with cached attention mask + current_attention_mask = past_key_value.update_attention_mask(current_attention_mask, self.layer_idx, expert_idx) + + # if self.layer_idx == 0 and expert_idx == 0: + # print("current_attention_mask", current_attention_mask.sum(-1), current_attention_mask.shape, current_attention_mask[0]) + current_attention_mask = _prepare_4d_causal_attention_mask( + current_attention_mask, + (bsz, this_q_len), + current_state, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if current_attention_mask.size() != (bsz, 1, this_q_len, kv_seq_len): # 🔍 q_len -> this_q_len + raise ValueError(f"Attention mask should be of size {(bsz, 1, this_q_len, kv_seq_len)}, but is {current_attention_mask.size()}") + + attn_weights = attn_weights + current_attention_mask # 🔍 + # print("current_attention_mask", current_attention_mask.shape, current_attention_mask[0]) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + # if attn_output.size() != (bsz, self.num_key_value_groups // self.split_ratio, this_q_len, self.head_dim): # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups + # raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, this_q_len, self.num_key_value_groups * self.head_dim // self.split_ratio) # 🔍 q_len -> this_q_len, hidden_size -> num_key_value_groups * head_dim + attn_output = self.o_proj[expert_idx](attn_output) + # ---------------------------------------------- # + + # 🔍 select & rescale the outputs by softmax scores + attn_output = attn_output[current_batch_ids, current_seq_ids] * (routing_weights[top_x, idx, None] * self.scale_factor_attn) + # attn_output = attn_output[current_batch_ids, current_seq_ids] # this line for debug only + + # 🔍 add to the final outputs + final_attn_output.index_add_(0, top_x, attn_output) + + if output_attentions: + all_attn_weights.append(attn_weights) + + # 🔍 reshape + final_attn_output = final_attn_output.reshape(bsz, q_len, hidden_dim) + + if output_attentions: + all_attn_weights = tuple(all_attn_weights) + + return final_attn_output, all_attn_weights, past_key_value, router_logits # 🔍 return an extra `router_logits` + + @torch.no_grad() + def from_vanilla_attention(attention: MixtralAttention, top_k_attn, scale_factor_attn): + # config + layer_idx = 
attention.layer_idx + config = attention.config + config.top_k_attn = top_k_attn + config.scale_factor_attn = scale_factor_attn + + # init + attention_moe = MixtralAttentionMoE(config, layer_idx) + + split = 1 # split the hidden_size, support split=1 --> 8/2, split=2 --> 16/4, split=4 --> 32/8 + # copy weights + num_key_value_groups = attention_moe.num_key_value_groups // split + head_dim = attention_moe.head_dim + + for i in range(config.num_key_value_heads * split): + indices_q_o = [j for j in range(head_dim * num_key_value_groups * i, head_dim * num_key_value_groups * (i + 1))] + indices_k_v = [j for j in range(head_dim * (i // split), head_dim * ((i // split) + 1))] + + print(i, "indices_q_o", indices_q_o) + # print(i, "indices_k_v", indices_k_v) + + attention_moe.q_proj[i].weight.data = attention.q_proj.weight.data[indices_q_o].clone() + attention_moe.k_proj[i].weight.data = attention.k_proj.weight.data[indices_k_v].clone() + attention_moe.v_proj[i].weight.data = attention.v_proj.weight.data[indices_k_v].clone() + attention_moe.o_proj[i].weight.data = attention.o_proj.weight.data[:, indices_q_o].clone() + + return attention_moe + + +# fmt: on + + +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +class MixtralFlashAttention2(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], + dim=-1, + ) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
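# The cast target below is `config._pre_quantization_dtype` for quantized models, and otherwise
# the dtype of the q_proj weights (e.g. float16/bfloat16, an illustrative assumption), so that
# float32 activations produced by upcasted norms are brought back to the dtype flash-attn expects.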
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # print("attention_mask", attention_mask, attention_mask.shape) + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
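# With flash_attn < 2.1 the causal mask is aligned to the top-left corner, so for single-token
# decode steps (query_length == 1) the causal flag is dropped below: the lone query should attend
# to the full key history, which a top-left-aligned causal mask would otherwise cut off.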
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
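# Illustrative description of the helper's role: with left padding, keeping only the last
# `query_length` columns of the mask (below) selects exactly the entries that belong to the query
# tokens; `unpad_input` then packs the non-padding queries into a flat tensor together with their
# cu_seqlens, matching what `flash_attn_varlen_func` expects.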
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + + +class MixtralFlashAttention2MoE(MixtralFlashAttention2): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.top_k_attn = self.config.top_k_attn + self.scale_factor_attn = self.config.scale_factor_attn + + self.gate = nn.Linear(self.hidden_size, self.num_key_value_heads, bias=False) + + self.q_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_key_value_groups * self.head_dim, bias=False) for _ in range(self.num_key_value_heads)]) + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.num_key_value_heads)]) + self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.num_key_value_heads)]) + self.o_proj = nn.ModuleList([nn.Linear(self.num_key_value_groups * self.head_dim, self.hidden_size, bias=self.config.add_rescale_bias) for _ in range(self.num_key_value_heads)]) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + # attention_mask = kwargs.pop("padding_mask") + + if past_key_value is not None and not isinstance(past_key_value, MoECache): # 🔍 type check + raise TypeError( + "`past_key_value` must be a `MoECache` instance for attention MoE!" + ) + + bsz, q_len, hidden_dim = hidden_states.size() + device = hidden_states.device + dtype = hidden_states.dtype + + hidden_states = hidden_states.reshape(-1, hidden_dim) + # gate compute + router_logits = self.gate(hidden_states) + router_scores = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(router_scores, self.top_k_attn, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(dtype) + + final_attn_output = torch.zeros_like(hidden_states).reshape(-1, hidden_dim) + + expert_mask = F.one_hot(selected_experts, num_classes=self.num_heads).permute(2, 1, 0) + + # all_attn_weights = [] if output_attentions else None + + for expert_idx in range(self.num_key_value_heads): + idx, top_x = torch.nonzero(expert_mask[expert_idx], as_tuple=True) + # top_x_list = top_x.tolist() + # idx_list = idx.tolist() + + if top_x.shape[0] == 0 and not self.training: # skip during training will lead to asynchrony among different GPUs and blocks the training! 
+ continue + + # create position_ids for selected tokens + current_batch_ids = (top_x // q_len) + each_batch_selected_token_num = torch.bincount(current_batch_ids, minlength=bsz) # (bsz) + this_q_len = each_batch_selected_token_num.max().item() + + selection_mask = torch.zeros((bsz * q_len,), device=device, dtype=torch.bool) + selection_mask[top_x] = True + selection_mask = selection_mask.reshape(bsz, q_len) + token_position_indices = torch.cumsum(selection_mask, dim=1) - 1 + token_position_indices = token_position_indices.flatten() + current_seq_ids = token_position_indices[top_x] + + + # 🔍 initialize hidden_states for this expert + current_state = torch.zeros((bsz, this_q_len, hidden_dim), dtype=dtype, device=device) + current_state[current_batch_ids, current_seq_ids] = hidden_states[top_x] # assign tokens sparsely + + # for attention forward + # expert_inputs = viewed_hidden_states[None, top_x_list].reshape(-1, self.hidden_size) + + query_states = self.q_proj[expert_idx](current_state) + key_states = self.k_proj[expert_idx](current_state) + value_states = self.v_proj[expert_idx](current_state) + + # seq_len = query_states.numel() // (bsz * self.num_key_value_groups * self.head_dim) + query_states = query_states.view(bsz, -1, self.num_key_value_groups, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, 1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, 1, self.head_dim).transpose(1, 2) + + # for moe kv cache + past_key_values_length = 0 + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + past_key_values_length = past_key_value.get_usable_length(kv_seq_len, self.layer_idx, expert_idx) # 🔍 specify expert index + kv_seq_len += past_key_values_length + + current_position_ids = torch.zeros((bsz, this_q_len), device=hidden_states.device, dtype=torch.long) + current_position_ids[current_batch_ids, current_seq_ids] = position_ids.expand(bsz, q_len).flatten()[top_x] + + if top_x.shape[0] > 0: # apply only when there are tokens + cos, sin = self.rotary_emb(value_states, seq_len=current_position_ids.max().item() + 1) # 🔍 adjust the seq_len to the maximum possible value + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, current_position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, expert_idx, cache_kwargs) # 🔍 specify expert index + + # print("attention_mask", attention_mask.shape, attention_mask) + # for current attention mask + + ''' + current_attention_mask = torch.zeros((bsz, this_q_len), dtype=torch.bool, device=device) + + if attention_mask is not None: + if past_key_values_length > 0: # 🔍 we need to exclude previous tokens + previous_seen_tokens_total = past_key_value._seen_tokens_total - q_len + temp_attention_mask = attention_mask[:, previous_seen_tokens_total:].flatten() # select along dimension 1 so that we get tokens in this iteration + else: + temp_attention_mask = attention_mask.flatten() # flatten the dim + current_attention_mask[current_batch_ids, current_seq_ids] = temp_attention_mask[top_x] # bug here !!! 
+ + else: + current_attention_mask[current_batch_ids, current_seq_ids] = True # assign masks sparsely + + if past_key_value is not None: # 🔍 we need to update with cached attention mask + current_attention_mask = past_key_value.update_attention_mask(current_attention_mask, self.layer_idx, expert_idx) + + + current_attention_mask = _prepare_4d_causal_attention_mask( + current_attention_mask, + (bsz, this_q_len), + current_state, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if current_attention_mask.size() != (bsz, 1, this_q_len, kv_seq_len): # 🔍 q_len -> this_q_len + raise ValueError(f"Attention mask should be of size {(bsz, 1, this_q_len, kv_seq_len)}, but is {current_attention_mask.size()}") + + ''' + + # for sliding window + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + # wait for change! sliding_window=4096 + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], + dim=-1, + ) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # for input dtype + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + dropout_rate = 0.0 if not self.training else self.attention_dropout + + repeat_num = query_states.shape[1] + key_states = repeat_kv(key_states, repeat_num) + value_states = repeat_kv(value_states, repeat_num) + + # print("repeat_num", repeat_num) + # print("query_states shape", query_states.shape, key_states.shape, value_states.shape) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + this_q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, this_q_len, self.num_key_value_groups * self.head_dim).contiguous() + attn_output = self.o_proj[expert_idx](attn_output) + attn_output = attn_output[current_batch_ids, current_seq_ids] * (routing_weights[top_x, idx, None] * self.scale_factor_attn) + + final_attn_output.index_add_(0, top_x, attn_output) + + final_attn_output = final_attn_output.reshape(bsz, q_len, hidden_dim) + + if not output_attentions: + attn_weights = None + + return final_attn_output, attn_weights, past_key_value, router_logits # 🔍 return an extra `router_logits` + + + +class MixtralFlashAttention2MoE_zt(MixtralFlashAttention2): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.top_k_attn = self.config.top_k_attn + self.scale_factor_attn = self.config.scale_factor_attn + # self.num_heads + # self.head_dim + # self.num_key_value_heads + # self.num_key_value_groups # total number of experts + assert self.top_k_attn <= self.num_key_value_groups + # assert self.top_k_attn % self.num_key_value_heads == 0 + self.attn_hsz = self.hidden_size // self.num_key_value_groups * self.top_k_attn + self.kv_repeat_num = self.attn_hsz // (self.num_key_value_heads * self.head_dim) + self.simulated_attn_head_num = self.attn_hsz // self.head_dim + assert self.attn_hsz % (self.num_key_value_heads * self.head_dim) == 0 + assert self.simulated_attn_head_num == self.num_heads * (self.top_k_attn / self.num_key_value_groups) + assert self.kv_repeat_num * self.num_key_value_heads == self.simulated_attn_head_num + + self.gate = nn.Linear(self.hidden_size, self.num_key_value_groups, bias=False) + # tzhu: there are self.num_key_value_groups experts + # each expert has a size of self.attn_hsz + self.q_proj = nn.ModuleList( + [nn.Linear(self.hidden_size, self.attn_hsz) for _ in range(self.num_key_value_groups)] + ) + self.o_proj = nn.ModuleList( + [nn.Linear(self.attn_hsz, self.hidden_size) for _ in range(self.num_key_value_groups)] + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # tzhu: attn-moe on q_proj + viewed_hidden_states = hidden_states.view(bsz * q_len, self.hidden_size) + # router + router_logits = self.gate(viewed_hidden_states) + router_scores = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(router_scores, self.top_k_attn, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + query_states = torch.zeros( + (bsz * q_len, self.attn_hsz), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + # expert_mask: (num_experts, top_k_attn, bsz * q_len) + expert_mask = F.one_hot(selected_experts, num_classes=self.num_heads).permute(2, 1, 0) + for expert_idx in range(self.num_key_value_groups): + expert_layer = self.q_proj[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + top_x_list = top_x.tolist() + idx_list = idx.tolist() + expert_inputs = viewed_hidden_states[None, top_x_list].reshape(-1, self.hidden_size) + # inputs (-1, hidden_size) -> outputs (-1, attn_hsz) + expert_outs = expert_layer(expert_inputs) * routing_weights[top_x_list, idx_list, None] * self.scale_factor_attn + query_states.index_add_(0, top_x, expert_outs.to(query_states.dtype)) + query_states = query_states.view(bsz, q_len, self.attn_hsz) + # query_states = query_states.view( + # bsz, q_len, self.num_heads, self.simulated_attn_head_num + # ).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.simulated_attn_head_num, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], + dim=-1, + ) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.kv_repeat_num) + value_states = repeat_kv(value_states, self.kv_repeat_num) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz * q_len, self.attn_hsz).contiguous() + final_attn_output = torch.zeros( + (bsz * q_len, self.hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + for expert_idx in range(self.num_key_value_groups): + expert_layer = self.o_proj[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + top_x_list = top_x.tolist() + idx_list = idx.tolist() + expert_inputs = attn_output[None, top_x_list].reshape(-1, self.attn_hsz) + expert_outs = expert_layer(expert_inputs) * routing_weights[top_x_list, idx_list, None] * self.scale_factor_attn + final_attn_output.index_add_(0, top_x, expert_outs.to(final_attn_output.dtype)) + final_attn_output = final_attn_output.view(bsz, q_len, self.hidden_size) + + if not output_attentions: + attn_weights = None + + return final_attn_output, attn_weights, past_key_value, router_logits + + + @torch.no_grad() + def from_vanilla_attention(attention: MixtralAttention, top_k_attn, scale_factor_attn): + # config + layer_idx = attention.layer_idx + config = attention.config + config.top_k_attn = top_k_attn + config.scale_factor_attn = scale_factor_attn + + # init + attention_moe = MixtralFlashAttention2MoE(config, layer_idx) + + # copy weights + num_key_value_groups = attention_moe.num_key_value_groups + head_dim = attention_moe.head_dim + + for i in range(num_key_value_groups): + indices_q_o = [] + for j in range(attention_moe.num_key_value_heads): + k = i + j * num_key_value_groups + indices_q_o.extend( + list(range(k * head_dim, (k + 1) * head_dim)) + ) + + print(i, "indices_q_o", indices_q_o) + + attention_moe.q_proj[i].weight.data = attention.q_proj.weight.data[indices_q_o].clone() + attention_moe.o_proj[i].weight.data = attention.o_proj.weight.data[:, indices_q_o].clone() + + return attention_moe + + + + +class MixtralBLockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig, ffn_dim, add_rescale_bias=False): # 🔍 + super().__init__() + self.ffn_dim = ffn_dim # 🔍 + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate + self.w2 = nn.Linear( + self.ffn_dim, self.hidden_dim, bias=add_rescale_bias + ) # 🔍 down (may add bias for rescaling) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( + hidden_states + ) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, +} + +# 🔍 +MISTRAL_ATTENTION_MOE_CLASSES = { + "eager": MixtralAttentionMoE, + "flash_attention_2": MixtralFlashAttention2MoE, +} + + +class SimplifiedSparseGLU(nn.Module): + def __init__(self, args: MegablocksArguments): + super().__init__() + self.args = args + + if args.bf16: + torch_dtype = torch.bfloat16 + elif args.fp16: + torch_dtype = 
torch.float16 + else: + torch_dtype = None + + # gate + self.w1 = nn.Parameter( + torch.empty( + args.ffn_hidden_size * args.moe_num_experts, + args.hidden_size, + dtype=torch_dtype, + ) + ) + # down + self.w2 = nn.Parameter( + torch.empty( + args.ffn_hidden_size * args.moe_num_experts, + args.hidden_size, + dtype=torch_dtype, + ) + ) + # up + self.v1 = nn.Parameter( + torch.empty( + args.ffn_hidden_size * args.moe_num_experts, + args.hidden_size, + dtype=torch_dtype, + ) + ) + + self.act_fn = args.activation_fn + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + "Memory optimized implementation not yet supported with GLU with sparse kernels." + ) + + # TODO (tzhu): test if OOM comes from dtensor conversion + # TODO (tzhu): return x directly to see if it still encounters OOM + # w1, v1, w2 = ( + # resolve_dtensor(self.w1), + # resolve_dtensor(self.v1), + # resolve_dtensor(self.w2), + # ) + + # Compute the GLU. + x1 = stk.ops.sdd(x, self.w1.t(), topo) + x2 = stk.ops.sdd(x, self.v1.t(), topo) + + activation_fn_out = act_fn(x1, self.act_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, self.w2) + + +class SimplifiedGroupedSparseGLU(SimplifiedSparseGLU): + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + # w1, v1, w2 = ( + # resolve_dtensor(self.w1), + # resolve_dtensor(self.v1), + # resolve_dtensor(self.w2), + # ) + + # Re-shape the weights for the grouped GEMMs. + # ne = mpu.experts_per_rank(self.args) + # w1 = self.w1.view(ne, -1, self.args.hidden_size) + # v1 = self.v1.view(ne, -1, self.args.hidden_size) + # w2 = self.w2.view(ne, -1, self.args.hidden_size) + + ne = self.args.moe_num_experts + w1 = self.w1.view(ne, -1, self.args.hidden_size) + v1 = self.v1.view(ne, -1, self.args.hidden_size) + w2 = self.w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.quantize_inputs_num_bits, + self.args.quantize_rematerialize_num_bits, + self.args.activation_fn, + ) + + # Compute the MLP. + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.act_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +_REGISTRY["simplified_glu"] = { + "grouped": SimplifiedGroupedSparseGLU, + "sparse": SimplifiedSparseGLU, +} + + +class SimplifiedParallelDroplessMLP(ParallelDroplessMLP): + def forward(self, x, expert_weights, top_experts): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
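+
+    For example, with `num_experts_per_tok = 2` each token keeps its two highest-probability
+    experts and mixes their outputs with the renormalised gate weights (the `modulelist` path
+    additionally multiplies by `scale_factor`). A minimal sketch of the routing math used in
+    `forward` (illustrative only):
+
+    ```python
+    scores = router_logits.softmax(dim=-1)                  # (tokens, num_experts)
+    weights, experts = scores.topk(top_k, dim=-1)           # top-k experts per token
+    weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalise the kept weights
+    ```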
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # specialized for llama-moe-v2 + self.scale_factor = config.scale_factor + self.moe_type = config.moe_type + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + if self.moe_type == "modulelist": + self.experts = nn.ModuleList( + [ + MixtralBLockSparseTop2MLP( + config, + config.intermediate_size, + add_rescale_bias=config.add_rescale_bias, + ) + for _ in range(self.num_experts) + ] # 🔍 + ) + elif self.moe_type == "megablocks": + if config.add_rescale_bias: + raise NotImplementedError( + "RescaleBias not yet supported with megablocks." + ) + is_fp16 = self.gate.weight.dtype == torch.float16 + is_bf16 = self.gate.weight.dtype == torch.bfloat16 + args = MegablocksArguments( + hidden_size=self.hidden_dim, + ffn_hidden_size=self.ffn_dim, + moe_num_experts=self.num_experts, + moe_top_k=self.top_k, + activation_fn={"silu": F.silu}[config.hidden_act], + mlp_type="simplified_glu", + mlp_impl="sparse", + memory_optimized_mlp=False, + bias=False, + fp16=is_fp16, + bf16=is_bf16, + ) + self.experts = SimplifiedParallelDroplessMLP(args) + elif self.moe_type == "scattermoe": + if config.add_rescale_bias: + raise NotImplementedError( + "RescaleBias not yet supported with scattermoe." + ) + self.experts = scattermoe.mlp.GLUMLP( + input_size=self.hidden_dim, + hidden_size=self.ffn_dim, + num_experts=self.num_experts, + top_k=self.top_k, + activation={"silu": F.silu}[config.hidden_act], + ) + else: + raise NotImplementedError(f"Unsupported moe_type: {self.moe_type}") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + scores = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(scores, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + if self.moe_type == "megablocks": + final_hidden_states = self.experts( + hidden_states, routing_weights, selected_experts + ) + elif self.moe_type == "scattermoe": + final_hidden_states = self.experts( + hidden_states, routing_weights, selected_experts + ) + else: + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if ( + top_x.shape[0] == 0 and not self.training + ): # skip during training will lead to asynchrony among different GPUs and blocks the training! 
+ continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * ( + routing_weights[top_x_list, idx_list, None] * self.scale_factor + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + # 🔍 + self.is_moe = (layer_idx >= config.num_moe_contract_layers) and ( + layer_idx < config.num_hidden_layers - config.num_moe_contract_layers + ) + self.use_attn_moe = config.use_attn_moe + + if self.use_attn_moe: + attn_class = MISTRAL_ATTENTION_MOE_CLASSES[config._attn_implementation] + else: + attn_class = MISTRAL_ATTENTION_CLASSES[config._attn_implementation] + self.self_attn = attn_class(config, layer_idx) + + + if self.is_moe: + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.mlp_residual = ( + MixtralBLockSparseTop2MLP(config, config.intermediate_size_residual) + if config.intermediate_size_residual is not None + else None + ) + + else: + self.block_sparse_moe = MixtralBLockSparseTop2MLP( + config, config.intermediate_size * config.num_local_experts + ) + self.mlp_residual = None + + self.input_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. 
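+                In this variant, enabling this flag also returns the attention router logits
+                alongside the MLP router logits (the former are `None` when attention MoE is disabled).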
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # 🔍 Self Attention + if self.use_attn_moe: + ( + hidden_states, + self_attn_weights, + present_key_value, + attn_router_logits, + ) = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + else: + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + attn_router_logits = None + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states_input = self.post_attention_layernorm(hidden_states) + + # 🔍 + if self.is_moe: + hidden_states, router_logits = self.block_sparse_moe(hidden_states_input) + else: + hidden_states = self.block_sparse_moe(hidden_states_input) + router_logits = None + + if self.mlp_residual is not None: + hidden_states += self.mlp_residual(hidden_states_input) # + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits, attn_router_logits) # 🔍 + + return outputs + + +# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + MixtralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + if self.config.use_attn_moe: # 🔍 + past_key_values = MoECache.from_legacy_cache(past_key_values) + else: # 🔍 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + # 🔍 add total seen tokens, this is VERY important for getting correct `past_key_values_length`! 
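+            # (Added note: with attention MoE each expert caches a different subset of tokens, so the
+            #  number of tokens seen so far cannot be recovered from any single expert's cache;
+            #  `MoECache` therefore tracks the global total explicitly.)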
+ if self.config.use_attn_moe: + past_key_values.add_seen_tokens_total(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if ( + self._use_flash_attention_2 or self.config.use_attn_moe + ): # 🔍 added special case for attention MoE + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + # print("attention_mask" , attention_mask) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + all_attn_router_logits = () if output_router_logits else None # 🔍 + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs: tuple = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-2],) + all_attn_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_logits, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + 
past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + attn_router_logits=all_attn_router_logits, # 🔍 + ) + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MixtralModel): + module.gradient_checkpointing = value + + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + # print("MixtralForCausalLM, cross entropy loss", loss) + + aux_loss = None + if output_router_logits: + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if 
return_dict else outputs[-2]) + if logits is not None + ) + + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.num_experts, + self.num_experts_per_tok, + use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨ + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + # loss_mlp = self.router_aux_loss_coef * aux_loss + # loss = loss + loss_mlp + # print("MixtralForCausalLM, mlp aux_loss", loss_mlp) + + # 🔍 for Attention MoE + ################################# + valid_attn_router_logits = tuple( + logits + for logits in ( + outputs.attn_router_logits if return_dict else outputs[-1] + ) + if logits is not None + ) + + if len(valid_attn_router_logits) > 0: # exist logits that is not None + attn_aux_loss = load_balancing_loss_func( + valid_attn_router_logits, + self.config.attn_experts, + self.config.top_k_attn, + use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨ + ) + if labels is not None: + loss += self.router_aux_loss_coef * attn_aux_loss + # loss_attn = self.router_aux_loss_coef * attn_aux_loss + # loss = loss + loss_attn + # print("MixtralForCausalLM, attn aux_loss", loss_attn) + ################################# + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, MoECache): # 🔍 for MoECache only + cache_length = past_key_values.get_seq_length() + past_length = past_key_values._seen_tokens_total # 🔍 + max_cache_length = past_key_values.get_max_length() + elif isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
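+            # e.g. with max_cache_length == 8, cache_length == 8 and one new token, the condition
+            # 8 + 1 > 8 holds and only the last 8 attention-mask columns are kept.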
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + # TODO: support for beam search + print("MixtralForCausalLM, _reorder_cache", beam_idx) + raise NotImplementedError + + # reordered_past = () + # for layer_past in past_key_values: + # reordered_past += ( + # tuple( + # past_state.index_select(0, beam_idx.to(past_state.device)) + # for past_state in layer_past + # ), + # ) + # return reordered_past + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](../generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config ([`~generation.GenerationConfig`], *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. 
Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logits processor that is already created with the arguments or a
+                generation config is passed, an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criterion that is already created with the arguments or a
+                generation config is passed, an error is thrown. If your stopping criteria depend on the `scores`
+                input, make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This
+                feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step,
+                conditioned on the batch ID `batch_id` and the previously generated tokens `input_ids`. This
+                argument is useful for constrained generation conditioned on the prefix, as described in
+                [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904).
+            synced_gpus (`bool`, *optional*):
+                Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
+                to `True` in a DeepSpeed ZeRO Stage 3 multi-GPU environment to avoid hanging if one GPU finishes
+                generating before the others. Otherwise it'll be set to `False`.
+            assistant_model (`PreTrainedModel`, *optional*):
+                An assistant model that can be used to accelerate generation. The assistant model must have the exact
+                same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant
+                model is much faster than running generation with the model you're calling generate from. As such,
+                the assistant model should be much smaller.
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                The negative prompt needed for some processors such as CFG. The batch size must match the input batch
+                size. This is an experimental feature, subject to breaking API changes in future versions.
+            negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Attention mask for `negative_prompt_ids`.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
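+
+        Example (an illustrative sketch, not from the original code; `tokenizer` and `model` are
+        assumed to be a compatible tokenizer/model pair loaded elsewhere):
+
+        ```python
+        inputs = tokenizer("Hello", return_tensors="pt")
+        output_ids = model.generate(**inputs, max_new_tokens=16, do_sample=False)
+        text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        ```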
+ + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + tokenizer = kwargs.pop( + "tokenizer", None + ) # Pull this out first, we only use it for stopping criteria + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, **kwargs + ) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model) + + # 2. Set generation parameters if not already defined + if synced_gpus is None: + if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: + synced_gpus = True + else: + synced_gpus = False + + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens( + generation_config, kwargs_has_attention_mask, device=device + ) + + # decoder-only models must use left-padding for batched generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + generation_config.pad_token_id is not None + and batch_size > 1 + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + # 4. 
Define other model kwargs + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache + + if ( + not kwargs_has_attention_mask + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config.decoder_start_token_id, + device=inputs_tensor.device, + ) + else: + input_ids = ( + inputs_tensor + if model_input_name == "input_ids" + else model_kwargs.pop("input_ids") + ) + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + has_default_min_length = ( + kwargs.get("min_length") is None + and generation_config.min_length is not None + ) + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + use_dynamic_cache_by_default = False + if ( + generation_config.cache_implementation is not None + and model_kwargs.get("past_key_values") is not None + ): + raise ValueError( + "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation is not None: + if self.config.use_attn_moe: # 🔍 + raise ValueError( + "Attention MoE doesn't support specifying the cache type! You can only use `MoECache`" + ) + if ( + generation_config.cache_implementation + in NEED_SETUP_CACHE_CLASSES_MAPPING + ): + if ( + generation_config.cache_implementation == "static" + and not self._supports_static_cache + ): + raise ValueError( + "This model does not support `cache_implementation='static'`. 
Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs["past_key_values"] = self._get_cache( + generation_config.cache_implementation, + getattr(generation_config, "num_beams", 1) * batch_size, + generation_config.max_length, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs["past_key_values"] = cache_class(cache_config) + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + elif ( + generation_config.cache_implementation is None + and self._supports_default_dynamic_cache() + ): + past = model_kwargs.get("past_key_values", None) + if past is None: + if self.config.use_attn_moe: # 🔍 + model_kwargs["past_key_values"] = MoECache( + self.config.num_key_value_heads + ) + else: # 🔍 + model_kwargs["past_key_values"] = DynamicCache() + use_dynamic_cache_by_default = True + elif isinstance(past, tuple): + if self.config.use_attn_moe: # 🔍 + model_kwargs["past_key_values"] = MoECache.from_legacy_cache(past) + else: # 🔍 + model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache( + past + ) + use_dynamic_cache_by_default = True + + self._validate_generated_length( + generation_config, input_ids_length, has_default_max_length + ) + + # 7. determine generation mode + generation_mode = generation_config.get_generation_mode(assistant_model) + + if streamer is not None and (generation_config.num_beams > 1): + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + + # 8. prepare distribution pre_processing samplers + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 9. 
prepare stopping criteria + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria, + tokenizer=tokenizer, + **kwargs, + ) + + # 10. go into different generation modes + if generation_mode == GenerationMode.ASSISTED_GENERATION: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted generate, " + f"but is {generation_config.num_return_sequences}." + ) + if batch_size > 1: + raise ValueError( + "assisted generate is only supported for batch_size = 1" + ) + if not model_kwargs["use_cache"]: + raise ValueError("assisted generate requires `use_cache=True`") + if generation_config.cache_implementation == "static": + raise ValueError( + "assisted generate is not supported with `static_cache`" + ) + if self._is_stateful: + # In assisted generation we need the ability to confirm whether the model would pick certain tokens, + # which is not possible with stateful models (they can't reset to a previous subset of generated text) + raise ValueError( + f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" + ) + + # 11. Get the candidate generator, given the parameterization + candidate_generator = self._get_candidate_generator( + generation_config=generation_config, + input_ids=input_ids, + inputs_tensor=inputs_tensor, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + ) + + # 12. prepare logits warper (if `do_sample` is `True`) + prepared_logits_warper = ( + self._get_logits_warper( + generation_config, + device=input_ids.device, + ) + if generation_config.do_sample + else None + ) + + # 13. run assisted generate + result = self._assisted_decoding( + input_ids, + candidate_generator=candidate_generator, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: + if not model_kwargs["use_cache"]: + raise ValueError("Contrastive search requires `use_cache=True`") + if self._is_stateful: + # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) + raise ValueError( + f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" + ) + + result = self._contrastive_search( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. prepare logits warper + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + result = self._sample( + input_ids, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode in ( + GenerationMode.BEAM_SAMPLE, + GenerationMode.BEAM_SEARCH, + ): + # 11. prepare logits warper + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) + + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + result = self._beam_search( + input_ids, + beam_scorer, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + max_length=generation_config.max_length, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + result = self._group_beam_search( + input_ids, + beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` " + f"of positive integers, but is {generation_config.force_words_ids}." 
+ ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + result = self._constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + # Convert to legacy cache if needed + if use_dynamic_cache_by_default and generation_config.return_legacy_cache: + if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): + if isinstance(result.past_key_values, DynamicCache): + result.past_key_values = result.past_key_values.to_legacy_cache() + return result + + +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file
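+
+
+# Minimal sketch added for illustration (not part of the original file): `generate()`
+# above defaults to a `DynamicCache` (or `MoECache` when `config.use_attn_moe` is set,
+# whose API is defined elsewhere in this file and is not exercised here) and, when
+# `return_legacy_cache` is requested, converts the result back to the legacy tuple
+# format via `to_legacy_cache()`. The round-trip below uses the stock transformers
+# `DynamicCache`; the tensor shapes are arbitrary values chosen only for the demo.
+if __name__ == "__main__":
+    # legacy format: one (key, value) pair per layer, each of shape
+    # (batch, num_key_value_heads, seq_len, head_dim)
+    _legacy = tuple(
+        (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)
+    )
+    _cache = DynamicCache.from_legacy_cache(_legacy)
+    _roundtrip = _cache.to_legacy_cache()
+    assert _roundtrip[0][0].shape == _legacy[0][0].shape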
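+
+    # Minimal sketch added for illustration (not part of the original file): the
+    # pooling in `MixtralForSequenceClassification.forward` keeps the logits of the
+    # last non-padding token by finding the first `pad_token_id` and stepping back
+    # one position; rows without padding resolve to index -1 (the final position)
+    # because `argmax` of an all-False row is 0. The pad id and shapes below are
+    # assumptions made up for the demo.
+    _pad_token_id = 0
+    _input_ids = torch.tensor([[5, 7, 9, 0, 0],   # first pad at index 3
+                               [3, 4, 6, 8, 2]])  # no padding in this row
+    _seq_lengths = torch.eq(_input_ids, _pad_token_id).int().argmax(-1) - 1
+    _logits = torch.randn(2, 5, 3)                # (batch, seq_len, num_labels)
+    _pooled = _logits[torch.arange(2), _seq_lengths]
+    print(_seq_lengths.tolist(), tuple(_pooled.shape))  # [2, -1] (2, 3)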