Upload HymbaForCausalLM
Browse files- config.json +9 -9
- modeling_hymba.py +160 -125
config.json
CHANGED
@@ -15,14 +15,6 @@
|
|
15 |
"conv_dim": {
|
16 |
"0": 3200,
|
17 |
"1": 3200,
|
18 |
-
"2": 3200,
|
19 |
-
"3": 3200,
|
20 |
-
"4": 3200,
|
21 |
-
"5": 3200,
|
22 |
-
"6": 3200,
|
23 |
-
"7": 3200,
|
24 |
-
"8": 3200,
|
25 |
-
"9": 3200,
|
26 |
"10": 3200,
|
27 |
"11": 3200,
|
28 |
"12": 3200,
|
@@ -33,6 +25,7 @@
|
|
33 |
"17": 3200,
|
34 |
"18": 3200,
|
35 |
"19": 3200,
|
|
|
36 |
"20": 3200,
|
37 |
"21": 3200,
|
38 |
"22": 3200,
|
@@ -43,8 +36,15 @@
|
|
43 |
"27": 3200,
|
44 |
"28": 3200,
|
45 |
"29": 3200,
|
|
|
46 |
"30": 3200,
|
47 |
-
"31": 3200
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
},
|
49 |
"eos_token_id": 2,
|
50 |
"global_attn_idx": [
|
|
|
15 |
"conv_dim": {
|
16 |
"0": 3200,
|
17 |
"1": 3200,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
"10": 3200,
|
19 |
"11": 3200,
|
20 |
"12": 3200,
|
|
|
25 |
"17": 3200,
|
26 |
"18": 3200,
|
27 |
"19": 3200,
|
28 |
+
"2": 3200,
|
29 |
"20": 3200,
|
30 |
"21": 3200,
|
31 |
"22": 3200,
|
|
|
36 |
"27": 3200,
|
37 |
"28": 3200,
|
38 |
"29": 3200,
|
39 |
+
"3": 3200,
|
40 |
"30": 3200,
|
41 |
+
"31": 3200,
|
42 |
+
"4": 3200,
|
43 |
+
"5": 3200,
|
44 |
+
"6": 3200,
|
45 |
+
"7": 3200,
|
46 |
+
"8": 3200,
|
47 |
+
"9": 3200
|
48 |
},
|
49 |
"eos_token_id": 2,
|
50 |
"global_attn_idx": [
|
modeling_hymba.py
CHANGED
@@ -1579,145 +1579,133 @@ class HymbaBlock(nn.Module):
|
|
1579 |
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None, position_ids=None, kv_last_layer=None, use_cache=False, use_swa=False):
|
1580 |
projected_states = self.in_proj(hidden_states).transpose(1, 2) ## (bs, latent_dim, seq_len)
|
1581 |
|
1582 |
-
|
1583 |
-
|
1584 |
-
|
1585 |
-
|
1586 |
-
|
1587 |
-
|
1588 |
-
|
1589 |
-
|
1590 |
-
|
1591 |
-
|
1592 |
-
|
1593 |
-
|
1594 |
-
|
1595 |
-
|
1596 |
-
self.D.float(),
|
1597 |
-
delta_bias=self.dt_proj.bias.float(),
|
1598 |
-
delta_softplus=True,
|
1599 |
-
)
|
1600 |
|
1601 |
-
|
1602 |
-
batch_size, seq_len, _ = hidden_states.shape
|
1603 |
-
use_precomputed_states = (
|
1604 |
-
cache_params is not None
|
1605 |
-
and cache_params.has_previous_state
|
1606 |
-
and seq_len == 1
|
1607 |
-
and cache_params.conv_states[self.layer_idx].shape[0]
|
1608 |
-
== cache_params.ssm_states[self.layer_idx].shape[0]
|
1609 |
-
== batch_size
|
1610 |
-
and use_cache
|
1611 |
-
)
|
1612 |
|
1613 |
-
|
1614 |
|
1615 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1616 |
|
1617 |
-
|
1618 |
-
|
1619 |
-
|
1620 |
-
|
1621 |
-
|
1622 |
-
|
1623 |
-
query_states = query_states.transpose(1,2)
|
1624 |
-
key_states = key_states.transpose(1,2)
|
1625 |
-
value_states = value_states.transpose(1,2)
|
1626 |
-
|
1627 |
-
if use_precomputed_states:
|
1628 |
-
hidden_states = causal_conv1d_update(
|
1629 |
-
hidden_states.squeeze(-1),
|
1630 |
-
cache_params.conv_states[self.layer_idx],
|
1631 |
-
conv_weights,
|
1632 |
-
self.conv1d.bias,
|
1633 |
-
self.activation,
|
1634 |
)
|
1635 |
-
hidden_states = hidden_states.unsqueeze(-1)
|
1636 |
|
1637 |
-
cache_params.
|
1638 |
-
else:
|
1639 |
-
if cache_params is not None:
|
1640 |
-
conv_states = nn.functional.pad(
|
1641 |
-
hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
|
1642 |
-
)
|
1643 |
|
1644 |
-
|
|
|
|
|
|
|
|
|
1645 |
|
1646 |
-
|
1647 |
-
|
1648 |
-
|
1649 |
-
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
|
1650 |
-
)
|
1651 |
|
1652 |
-
|
1653 |
-
|
1654 |
-
|
1655 |
-
|
1656 |
-
|
1657 |
|
1658 |
-
|
1659 |
-
|
1660 |
-
|
1661 |
-
|
1662 |
-
|
1663 |
-
|
1664 |
-
|
1665 |
|
1666 |
-
|
1667 |
-
|
1668 |
-
|
1669 |
-
|
1670 |
-
|
1671 |
-
|
1672 |
-
|
1673 |
|
1674 |
-
|
1675 |
-
|
1676 |
-
|
1677 |
-
|
1678 |
-
|
1679 |
-
|
1680 |
-
|
1681 |
-
|
1682 |
-
|
1683 |
-
|
1684 |
-
|
1685 |
-
|
1686 |
-
|
1687 |
-
|
1688 |
-
|
1689 |
-
|
1690 |
-
|
1691 |
-
|
1692 |
-
|
1693 |
-
|
1694 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1695 |
else:
|
1696 |
-
|
1697 |
-
hidden_states,
|
1698 |
-
discrete_time_step,
|
1699 |
-
A,
|
1700 |
-
B.transpose(1, 2),
|
1701 |
-
C.transpose(1, 2),
|
1702 |
-
self.D[index].float(),
|
1703 |
-
z=gate,
|
1704 |
-
delta_bias=time_proj_bias,
|
1705 |
-
delta_softplus=True,
|
1706 |
-
return_last_state=True,
|
1707 |
-
)
|
1708 |
-
|
1709 |
-
if len(outputs) == 3:
|
1710 |
-
scan_outputs, ssm_state, _ = outputs
|
1711 |
-
else:
|
1712 |
-
scan_outputs, ssm_state = outputs
|
1713 |
|
1714 |
-
|
1715 |
-
|
1716 |
-
|
1717 |
-
|
1718 |
|
1719 |
-
|
1720 |
-
|
1721 |
|
1722 |
return contextualized_states, attn_key_value
|
1723 |
|
@@ -2037,6 +2025,49 @@ class HymbaPreTrainedModel(PreTrainedModel):
|
|
2037 |
|
2038 |
|
2039 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2040 |
HYMBA_INPUTS_DOCSTRING = r"""
|
2041 |
Args: To be added later. Please refer to the forward function.
|
2042 |
"""
|
@@ -2205,7 +2236,11 @@ class HymbaModel(HymbaPreTrainedModel):
|
|
2205 |
|
2206 |
if position_ids is not None and position_ids.shape[1] != inputs_embeds.shape[1]:
|
2207 |
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
|
2208 |
-
|
|
|
|
|
|
|
|
|
2209 |
attention_mask_raw = attention_mask
|
2210 |
|
2211 |
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
|
|
1579 |
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None, position_ids=None, kv_last_layer=None, use_cache=False, use_swa=False):
    """Run the attention head and the Mamba (SSM) head of this block with the fused kernels.

    The input is projected once by ``self.in_proj``; the projection is split into the
    attention query/key/value slices, the Mamba input and the Mamba gate. Both heads are
    evaluated, each output is normalized, and their average is passed through
    ``self.out_proj``.

    Args:
        hidden_states: Tensor unpacked as ``(batch_size, seq_len, dim)``.
        cache_params: Optional cache holding ``conv_states``, ``ssm_states`` and
            ``mamba_past_length`` per layer; it is also forwarded to ``self.self_attn``
            as ``past_key_value``.
        attention_mask: Optional mask; entries equal to 0 are treated as padding and the
            corresponding Mamba inputs are zeroed. Presumably ``(batch_size, seq_len)``
            -- TODO confirm against the caller.
        position_ids: Forwarded unchanged to ``self.self_attn``.
        kv_last_layer: Key/value states to reuse; required when ``self.reuse_kv`` is set.
        use_cache: Enables the single-token cached path and is forwarded to attention.
        use_swa: Forwarded unchanged to ``self.self_attn``.

    Returns:
        ``(contextualized_states, attn_key_value)`` where ``attn_key_value`` is the second
        value returned by ``self.self_attn``.
    """
    projected_states = self.in_proj(hidden_states).transpose(1, 2)  ## (bs, latent_dim, seq_len)

    ## Handle padding for Mamba: Set padding tokens to 0
    if projected_states.shape[-1] > 1 and attention_mask is not None and (attention_mask == 0).any():
        projected_states = projected_states * attention_mask.unsqueeze(1).to(projected_states)

    batch_size, seq_len, _ = hidden_states.shape

    # The incremental (one token at a time) kernels are only usable when a previous
    # state exists and the cached states were allocated for this batch size.
    use_precomputed_states = (
        cache_params is not None
        and cache_params.has_previous_state
        and seq_len == 1
        and cache_params.conv_states[self.layer_idx].shape[0]
        == cache_params.ssm_states[self.layer_idx].shape[0]
        == batch_size
        and use_cache
    )

    hidden_states, gate = projected_states.tensor_split((self.latent_dim,), dim=1)

    # Drop the singleton middle dimension of the conv weight for the causal-conv kernels.
    conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))

    # Peel the attention slices off the front of the channel dimension; what remains
    # in ``hidden_states`` is the Mamba input.
    if self.reuse_kv:
        query_states, hidden_states = hidden_states.tensor_split((self.attn_hidden_size,), dim=1)
        query_states = query_states.transpose(1, 2)
    else:
        query_states, key_states, value_states, hidden_states = hidden_states.tensor_split(
            (
                self.attn_hidden_size,
                self.attn_hidden_size + self.k_hidden_size,
                self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size,
            ),
            dim=1,
        )

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

    if use_precomputed_states:
        # Single-step convolution against the cached convolution state.
        hidden_states = causal_conv1d_update(
            hidden_states.squeeze(-1),
            cache_params.conv_states[self.layer_idx],
            conv_weights,
            self.conv1d.bias,
            self.activation,
        )
        hidden_states = hidden_states.unsqueeze(-1)

        cache_params.mamba_past_length[self.layer_idx] += seq_len
    else:
        if cache_params is not None:
            # Store the (left-padded) pre-convolution inputs as the convolution state.
            conv_states = nn.functional.pad(
                hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
            )

            cache_params.conv_states[self.layer_idx].copy_(conv_states)

            cache_params.mamba_past_length[self.layer_idx] += seq_len

        hidden_states = causal_conv1d_fn(
            hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
        )

    ## Handle padding for Mamba: Set padding tokens to 0
    # (seq_len > 1 never coincides with the precomputed-state path, which needs seq_len == 1.)
    if seq_len > 1 and attention_mask is not None and (attention_mask == 0).any():
        hidden_states = hidden_states * attention_mask.unsqueeze(1).to(hidden_states)

    if self.reuse_kv:
        assert kv_last_layer is not None
        attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)
    else:
        attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)

    ## Mamba head
    index = 0
    ssm_parameters = self.x_proj[index](hidden_states.transpose(1, 2))
    time_step, B, C = torch.split(
        ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
    )
    time_step, B, C = self._apply_layernorms(time_step, B, C)

    # The bias is handed to the scan kernels separately (delta_bias / dt_bias), so the
    # projection must run without it. The bias lives on ``base_layer`` when the
    # projection is wrapped (presumably by an adapter library -- TODO confirm).
    dt_proj = self.dt_proj[index]
    bias_owner = dt_proj.base_layer if hasattr(dt_proj, "base_layer") else dt_proj
    time_proj_bias = bias_owner.bias
    bias_owner.bias = None
    try:
        discrete_time_step = dt_proj(time_step).transpose(1, 2)  # [batch, intermediate_size, seq_len]
    finally:
        # Always restore the bias, even if the projection raises; otherwise the
        # module would be left permanently without it.
        bias_owner.bias = time_proj_bias

    A = -torch.exp(self.A_log[index].float())

    time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
    if use_precomputed_states:
        # One recurrent step against the cached SSM state.
        scan_outputs = selective_state_update(
            cache_params.ssm_states[self.layer_idx],
            hidden_states[..., 0],
            discrete_time_step[..., 0],
            A,
            B[:, 0],
            C[:, 0],
            self.D[index],
            gate[..., 0],
            time_proj_bias,
            dt_softplus=True,
        ).unsqueeze(-1)
    else:
        outputs = selective_scan_fn(
            hidden_states,
            discrete_time_step,
            A,
            B.transpose(1, 2),
            C.transpose(1, 2),
            self.D[index].float(),
            z=gate,
            delta_bias=time_proj_bias,
            delta_softplus=True,
            return_last_state=True,
        )

        # The scan may return either two or three values; only the first two are used.
        if len(outputs) == 3:
            scan_outputs, ssm_state, _ = outputs
        else:
            scan_outputs, ssm_state = outputs

        if ssm_state is not None and cache_params is not None:
            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

    scan_outputs = scan_outputs.transpose(1, 2)

    # Fuse the two heads: normalize each output, then average.
    hidden_states = (self.pre_avg_layernorm1(attn_outputs) + self.pre_avg_layernorm2(scan_outputs)) / 2
    contextualized_states = self.out_proj(hidden_states)

    return contextualized_states, attn_key_value
|
1711 |
|
|
|
2025 |
|
2026 |
|
2027 |
|
2028 |
+
def shift_zeros_to_front(attention_mask, hidden_states, position_ids):
    """
    Move all zero entries in 'attention_mask' to the front of the sequence
    and reorder 'hidden_states' and 'position_ids' accordingly, preserving
    the relative order of the zeros and the relative order of the non-zeros.

    Args:
        attention_mask: (batch_size, seq_len); zero marks a padding position.
        hidden_states: (batch_size, seq_len, ...) values to reorder along dim 1.
        position_ids: (batch_size, seq_len), (1, seq_len) or None. A single
            row is broadcast over the batch before reordering.

    Returns:
        shifted_mask: (batch_size, seq_len) with zeros at the front.
        shifted_states: same shape as 'hidden_states', reordered accordingly.
        shifted_position_ids: (batch_size, seq_len) reordered accordingly,
            or None when 'position_ids' is None.
    """
    batch_size = attention_mask.shape[0]

    # A stable sort on the 0/1 "is a real token" key yields, per row, the
    # indices of the zeros (in order) followed by the indices of the
    # non-zeros (in order) -- without a Python loop over the batch.
    is_token = (attention_mask != 0).long()
    new_order = torch.sort(is_token, dim=1, stable=True).indices

    # Row selector that broadcasts against new_order for advanced indexing.
    rows = torch.arange(batch_size, device=new_order.device).unsqueeze(1)

    shifted_mask = attention_mask[rows, new_order]
    shifted_states = hidden_states[rows, new_order]

    if position_ids is None:
        shifted_position_ids = None
    else:
        if position_ids.shape[0] != batch_size:
            # e.g. torch.arange(seq_len).unsqueeze(0) shared by every row
            position_ids = position_ids.expand(batch_size, -1)
        shifted_position_ids = position_ids[rows, new_order]

    return shifted_mask, shifted_states, shifted_position_ids
|
2068 |
+
|
2069 |
+
|
2070 |
+
|
2071 |
HYMBA_INPUTS_DOCSTRING = r"""
|
2072 |
Args: To be added later. Please refer to the forward function.
|
2073 |
"""
|
|
|
2236 |
|
2237 |
if position_ids is not None and position_ids.shape[1] != inputs_embeds.shape[1]:
|
2238 |
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
|
2239 |
+
|
2240 |
+
## Handle paddings: Shift all padding tokens to the beginning of the sequence
|
2241 |
+
if inputs_embeds.shape[1] > 1 and attention_mask is not None and (attention_mask == 0).any():
|
2242 |
+
attention_mask, inputs_embeds, position_ids = shift_zeros_to_front(attention_mask, inputs_embeds, position_ids)
|
2243 |
+
|
2244 |
attention_mask_raw = attention_mask
|
2245 |
|
2246 |
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|