YongganFu commited on
Commit
5a4f554
·
verified ·
1 Parent(s): a7335b0

Upload HymbaForCausalLM

Browse files
Files changed (2) hide show
  1. config.json +9 -9
  2. modeling_hymba.py +160 -125
config.json CHANGED
@@ -15,14 +15,6 @@
15
  "conv_dim": {
16
  "0": 3200,
17
  "1": 3200,
18
- "2": 3200,
19
- "3": 3200,
20
- "4": 3200,
21
- "5": 3200,
22
- "6": 3200,
23
- "7": 3200,
24
- "8": 3200,
25
- "9": 3200,
26
  "10": 3200,
27
  "11": 3200,
28
  "12": 3200,
@@ -33,6 +25,7 @@
33
  "17": 3200,
34
  "18": 3200,
35
  "19": 3200,
 
36
  "20": 3200,
37
  "21": 3200,
38
  "22": 3200,
@@ -43,8 +36,15 @@
43
  "27": 3200,
44
  "28": 3200,
45
  "29": 3200,
 
46
  "30": 3200,
47
- "31": 3200
 
 
 
 
 
 
48
  },
49
  "eos_token_id": 2,
50
  "global_attn_idx": [
 
15
  "conv_dim": {
16
  "0": 3200,
17
  "1": 3200,
 
 
 
 
 
 
 
 
18
  "10": 3200,
19
  "11": 3200,
20
  "12": 3200,
 
25
  "17": 3200,
26
  "18": 3200,
27
  "19": 3200,
28
+ "2": 3200,
29
  "20": 3200,
30
  "21": 3200,
31
  "22": 3200,
 
36
  "27": 3200,
37
  "28": 3200,
38
  "29": 3200,
39
+ "3": 3200,
40
  "30": 3200,
41
+ "31": 3200,
42
+ "4": 3200,
43
+ "5": 3200,
44
+ "6": 3200,
45
+ "7": 3200,
46
+ "8": 3200,
47
+ "9": 3200
48
  },
49
  "eos_token_id": 2,
50
  "global_attn_idx": [
modeling_hymba.py CHANGED
@@ -1579,145 +1579,133 @@ class HymbaBlock(nn.Module):
1579
  def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None, position_ids=None, kv_last_layer=None, use_cache=False, use_swa=False):
1580
  projected_states = self.in_proj(hidden_states).transpose(1, 2) ## (bs, latent_dim, seq_len)
1581
 
1582
- if (
1583
- self.training and cache_params is None and not self.apply_inner_layernorms
1584
- ): # Doesn't support outputting the states -> used for training
1585
- contextualized_states = mamba_inner_fn(
1586
- projected_states,
1587
- self.conv1d.weight,
1588
- self.conv1d.bias if self.use_conv_bias else None,
1589
- self.x_proj.weight,
1590
- self.dt_proj.weight,
1591
- self.out_proj.weight,
1592
- self.out_proj.bias.float() if self.use_bias else None,
1593
- -torch.exp(self.A_log.float()),
1594
- None, # input-dependent B
1595
- None, # input-dependent C
1596
- self.D.float(),
1597
- delta_bias=self.dt_proj.bias.float(),
1598
- delta_softplus=True,
1599
- )
1600
 
1601
- else:
1602
- batch_size, seq_len, _ = hidden_states.shape
1603
- use_precomputed_states = (
1604
- cache_params is not None
1605
- and cache_params.has_previous_state
1606
- and seq_len == 1
1607
- and cache_params.conv_states[self.layer_idx].shape[0]
1608
- == cache_params.ssm_states[self.layer_idx].shape[0]
1609
- == batch_size
1610
- and use_cache
1611
- )
1612
 
1613
- hidden_states, gate = projected_states.tensor_split((self.latent_dim,), dim=1)
1614
 
1615
- conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1616
 
1617
- if self.reuse_kv:
1618
- query_states, hidden_states = hidden_states.tensor_split((self.attn_hidden_size,), dim=1)
1619
- query_states = query_states.transpose(1,2)
1620
- else:
1621
- query_states, key_states, value_states, hidden_states = hidden_states.tensor_split((self.attn_hidden_size, self.attn_hidden_size + self.k_hidden_size, self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size), dim=1)
1622
-
1623
- query_states = query_states.transpose(1,2)
1624
- key_states = key_states.transpose(1,2)
1625
- value_states = value_states.transpose(1,2)
1626
-
1627
- if use_precomputed_states:
1628
- hidden_states = causal_conv1d_update(
1629
- hidden_states.squeeze(-1),
1630
- cache_params.conv_states[self.layer_idx],
1631
- conv_weights,
1632
- self.conv1d.bias,
1633
- self.activation,
1634
  )
1635
- hidden_states = hidden_states.unsqueeze(-1)
1636
 
1637
- cache_params.mamba_past_length[self.layer_idx] += seq_len
1638
- else:
1639
- if cache_params is not None:
1640
- conv_states = nn.functional.pad(
1641
- hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
1642
- )
1643
 
1644
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
 
 
 
 
1645
 
1646
- cache_params.mamba_past_length[self.layer_idx] += seq_len
1647
-
1648
- hidden_states = causal_conv1d_fn(
1649
- hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
1650
- )
1651
 
1652
- if self.reuse_kv:
1653
- assert kv_last_layer is not None
1654
- attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)
1655
- else:
1656
- attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)
1657
 
1658
- ## Mamba head
1659
- index = 0
1660
- ssm_parameters = self.x_proj[index](hidden_states.transpose(1, 2))
1661
- time_step, B, C = torch.split(
1662
- ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
1663
- )
1664
- time_step, B, C = self._apply_layernorms(time_step, B, C)
1665
 
1666
- if hasattr(self.dt_proj[index], "base_layer"):
1667
- time_proj_bias = self.dt_proj[index].base_layer.bias
1668
- self.dt_proj[index].base_layer.bias = None
1669
- else:
1670
- time_proj_bias = self.dt_proj[index].bias
1671
- self.dt_proj[index].bias = None
1672
- discrete_time_step = self.dt_proj[index](time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
1673
 
1674
- if hasattr(self.dt_proj[index], "base_layer"):
1675
- self.dt_proj[index].base_layer.bias = time_proj_bias
1676
- else:
1677
- self.dt_proj[index].bias = time_proj_bias
1678
-
1679
- A = -torch.exp(self.A_log[index].float())
1680
-
1681
- time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
- if use_precomputed_states:
1683
- scan_outputs = selective_state_update(
1684
- cache_params.ssm_states[self.layer_idx],
1685
- hidden_states[..., 0],
1686
- discrete_time_step[..., 0],
1687
- A,
1688
- B[:, 0],
1689
- C[:, 0],
1690
- self.D[index],
1691
- gate[..., 0],
1692
- time_proj_bias,
1693
- dt_softplus=True,
1694
- ).unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1695
  else:
1696
- outputs = selective_scan_fn(
1697
- hidden_states,
1698
- discrete_time_step,
1699
- A,
1700
- B.transpose(1, 2),
1701
- C.transpose(1, 2),
1702
- self.D[index].float(),
1703
- z=gate,
1704
- delta_bias=time_proj_bias,
1705
- delta_softplus=True,
1706
- return_last_state=True,
1707
- )
1708
-
1709
- if len(outputs) == 3:
1710
- scan_outputs, ssm_state, _ = outputs
1711
- else:
1712
- scan_outputs, ssm_state = outputs
1713
 
1714
- if ssm_state is not None and cache_params is not None:
1715
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1716
-
1717
- scan_outputs = scan_outputs.transpose(1, 2)
1718
 
1719
- hidden_states = (self.pre_avg_layernorm1(attn_outputs) + self.pre_avg_layernorm2(scan_outputs)) / 2
1720
- contextualized_states = self.out_proj(hidden_states)
1721
 
1722
  return contextualized_states, attn_key_value
1723
 
@@ -2037,6 +2025,49 @@ class HymbaPreTrainedModel(PreTrainedModel):
2037
 
2038
 
2039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2040
  HYMBA_INPUTS_DOCSTRING = r"""
2041
  Args: To be added later. Please refer to the forward function.
2042
  """
@@ -2205,7 +2236,11 @@ class HymbaModel(HymbaPreTrainedModel):
2205
 
2206
  if position_ids is not None and position_ids.shape[1] != inputs_embeds.shape[1]:
2207
  position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
2208
-
 
 
 
 
2209
  attention_mask_raw = attention_mask
2210
 
2211
  if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
 
1579
  def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None, position_ids=None, kv_last_layer=None, use_cache=False, use_swa=False):
1580
  projected_states = self.in_proj(hidden_states).transpose(1, 2) ## (bs, latent_dim, seq_len)
1581
 
1582
+ ## Handle padding for Mamba: Set padding tokens to 0
1583
+ if projected_states.shape[-1] > 1 and attention_mask is not None and (attention_mask == 0).any():
1584
+ projected_states = projected_states * attention_mask.unsqueeze(1).to(projected_states)
1585
+
1586
+ batch_size, seq_len, _ = hidden_states.shape
1587
+ use_precomputed_states = (
1588
+ cache_params is not None
1589
+ and cache_params.has_previous_state
1590
+ and seq_len == 1
1591
+ and cache_params.conv_states[self.layer_idx].shape[0]
1592
+ == cache_params.ssm_states[self.layer_idx].shape[0]
1593
+ == batch_size
1594
+ and use_cache
1595
+ )
 
 
 
 
1596
 
1597
+ hidden_states, gate = projected_states.tensor_split((self.latent_dim,), dim=1)
 
 
 
 
 
 
 
 
 
 
1598
 
1599
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
1600
 
1601
+ if self.reuse_kv:
1602
+ query_states, hidden_states = hidden_states.tensor_split((self.attn_hidden_size,), dim=1)
1603
+ query_states = query_states.transpose(1,2)
1604
+ else:
1605
+ query_states, key_states, value_states, hidden_states = hidden_states.tensor_split((self.attn_hidden_size, self.attn_hidden_size + self.k_hidden_size, self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size), dim=1)
1606
+
1607
+ query_states = query_states.transpose(1,2)
1608
+ key_states = key_states.transpose(1,2)
1609
+ value_states = value_states.transpose(1,2)
1610
+
1611
+ if use_precomputed_states:
1612
+ hidden_states = causal_conv1d_update(
1613
+ hidden_states.squeeze(-1),
1614
+ cache_params.conv_states[self.layer_idx],
1615
+ conv_weights,
1616
+ self.conv1d.bias,
1617
+ self.activation,
1618
+ )
1619
+ hidden_states = hidden_states.unsqueeze(-1)
1620
 
1621
+ cache_params.mamba_past_length[self.layer_idx] += seq_len
1622
+ else:
1623
+ if cache_params is not None:
1624
+ conv_states = nn.functional.pad(
1625
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
 
 
 
 
 
 
 
 
 
 
 
 
1626
  )
 
1627
 
1628
+ cache_params.conv_states[self.layer_idx].copy_(conv_states)
 
 
 
 
 
1629
 
1630
+ cache_params.mamba_past_length[self.layer_idx] += seq_len
1631
+
1632
+ hidden_states = causal_conv1d_fn(
1633
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
1634
+ )
1635
 
1636
+ ## Handle padding for Mamba: Set padding tokens to 0
1637
+ if seq_len > 1 and attention_mask is not None and (attention_mask == 0).any():
1638
+ hidden_states = hidden_states * attention_mask.unsqueeze(1).to(hidden_states)
 
 
1639
 
1640
+ if self.reuse_kv:
1641
+ assert kv_last_layer is not None
1642
+ attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)
1643
+ else:
1644
+ attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)
1645
 
1646
+ ## Mamba head
1647
+ index = 0
1648
+ ssm_parameters = self.x_proj[index](hidden_states.transpose(1, 2))
1649
+ time_step, B, C = torch.split(
1650
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
1651
+ )
1652
+ time_step, B, C = self._apply_layernorms(time_step, B, C)
1653
 
1654
+ if hasattr(self.dt_proj[index], "base_layer"):
1655
+ time_proj_bias = self.dt_proj[index].base_layer.bias
1656
+ self.dt_proj[index].base_layer.bias = None
1657
+ else:
1658
+ time_proj_bias = self.dt_proj[index].bias
1659
+ self.dt_proj[index].bias = None
1660
+ discrete_time_step = self.dt_proj[index](time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
1661
 
1662
+ if hasattr(self.dt_proj[index], "base_layer"):
1663
+ self.dt_proj[index].base_layer.bias = time_proj_bias
1664
+ else:
1665
+ self.dt_proj[index].bias = time_proj_bias
1666
+
1667
+ A = -torch.exp(self.A_log[index].float())
1668
+
1669
+ time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1670
+ if use_precomputed_states:
1671
+ scan_outputs = selective_state_update(
1672
+ cache_params.ssm_states[self.layer_idx],
1673
+ hidden_states[..., 0],
1674
+ discrete_time_step[..., 0],
1675
+ A,
1676
+ B[:, 0],
1677
+ C[:, 0],
1678
+ self.D[index],
1679
+ gate[..., 0],
1680
+ time_proj_bias,
1681
+ dt_softplus=True,
1682
+ ).unsqueeze(-1)
1683
+ else:
1684
+ outputs = selective_scan_fn(
1685
+ hidden_states,
1686
+ discrete_time_step,
1687
+ A,
1688
+ B.transpose(1, 2),
1689
+ C.transpose(1, 2),
1690
+ self.D[index].float(),
1691
+ z=gate,
1692
+ delta_bias=time_proj_bias,
1693
+ delta_softplus=True,
1694
+ return_last_state=True,
1695
+ )
1696
+
1697
+ if len(outputs) == 3:
1698
+ scan_outputs, ssm_state, _ = outputs
1699
  else:
1700
+ scan_outputs, ssm_state = outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1701
 
1702
+ if ssm_state is not None and cache_params is not None:
1703
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1704
+
1705
+ scan_outputs = scan_outputs.transpose(1, 2)
1706
 
1707
+ hidden_states = (self.pre_avg_layernorm1(attn_outputs) + self.pre_avg_layernorm2(scan_outputs)) / 2
1708
+ contextualized_states = self.out_proj(hidden_states)
1709
 
1710
  return contextualized_states, attn_key_value
1711
 
 
2025
 
2026
 
2027
 
2028
+ def shift_zeros_to_front(attention_mask, hidden_states, position_ids):
2029
+ """
2030
+ Move all zero entries in 'attention_mask' to the front of the sequence
2031
+ and reorder 'hidden_states' accordingly, preserving the order of zeros
2032
+ and the order of ones.
2033
+
2034
+ Args:
2035
+ attention_mask: (batch_size, seq_len), values in {0, 1}.
2036
+ hidden_states: (batch_size, seq_len, dim).
2037
+
2038
+ Returns:
2039
+ shifted_mask: (batch_size, seq_len) with zeros at the front.
2040
+ shifted_states: (batch_size, seq_len, dim) reordered accordingly.
2041
+ """
2042
+ B, L = attention_mask.shape
2043
+ D = hidden_states.shape[-1]
2044
+
2045
+ shifted_mask = torch.empty_like(attention_mask)
2046
+ shifted_states = torch.empty_like(hidden_states)
2047
+ shifted_position_ids = torch.empty_like(position_ids)
2048
+
2049
+ # Process each batch row independently
2050
+ for b in range(B):
2051
+ row_mask = attention_mask[b] # (seq_len,)
2052
+ row_states = hidden_states[b] # (seq_len, dim)
2053
+ row_pos = position_ids[b] # (seq_len,)
2054
+
2055
+ # Find positions of zeros and ones
2056
+ zero_indices = torch.where(row_mask == 0)[0]
2057
+ one_indices = torch.where(row_mask == 1)[0]
2058
+
2059
+ # Concatenate zero indices (in order) then one indices
2060
+ new_order = torch.cat([zero_indices, one_indices], dim=0)
2061
+
2062
+ # Reorder mask and states
2063
+ shifted_mask[b] = row_mask[new_order]
2064
+ shifted_states[b] = row_states[new_order]
2065
+ shifted_position_ids[b] = row_pos[new_order]
2066
+
2067
+ return shifted_mask, shifted_states, shifted_position_ids
2068
+
2069
+
2070
+
2071
  HYMBA_INPUTS_DOCSTRING = r"""
2072
  Args: To be added later. Please refer to the forward function.
2073
  """
 
2236
 
2237
  if position_ids is not None and position_ids.shape[1] != inputs_embeds.shape[1]:
2238
  position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
2239
+
2240
+ ## Handle paddings: Shift all padding tokens to the beginning of the sequence
2241
+ if inputs_embeds.shape[1] > 1 and attention_mask is not None and (attention_mask == 0).any():
2242
+ attention_mask, inputs_embeds, position_ids = shift_zeros_to_front(attention_mask, inputs_embeds, position_ids)
2243
+
2244
  attention_mask_raw = attention_mask
2245
 
2246
  if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: