Crystalcareai committed
Commit fe54712 · verified · 1 Parent(s): 28b5873

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +7 -29
modeling_gemmoe.py CHANGED
@@ -96,10 +96,8 @@ class GemmoeRMSNorm(nn.Module):
         normed_x = normed_x.type_as(x)
         return normed_x * (self.weight + 1)
 
-
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
-
 class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -110,10 +108,11 @@ class GemmoeRotaryEmbedding(nn.Module):
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        freq_exponents = (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
-        timescale = self.base ** (freq_exponents / (self.dim / 2))
+        freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
+        timescale = self.base ** freq_exponents
         positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
-        radians_new = positions[..., None] / timescale[None, :]
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
         emb = torch.cat((radians_new, radians_new), dim=-1)
         cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
         sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
@@ -127,7 +126,6 @@ class GemmoeRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
         return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
 
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -136,34 +134,14 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    seq_len, dim = q.shape[-2], q.shape[-1]
+    cos = cos[:seq_len].view(1, 1, seq_len, dim)
+    sin = sin[:seq_len].view(1, 1, seq_len, dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-
 # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemmoe
 class GemmoeMLP(nn.Module):
     def __init__(self, config):
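Note on the `_set_cos_sin_cache` change: the rewrite is algebraically a no-op. `base ** (k / (dim / 2))` and `base ** ((2.0 / dim) * k)` are the same timescale, and indexing with `timescale[None, None, :]` followed by `squeeze(0)` leaves `radians_new` with the same `[seq_len, dim // 2]` shape as `timescale[None, :]` did. A minimal sketch checking that equivalence (the `dim`, `base`, and sequence-length values below are made up for illustration, not taken from the Gemmoe config):

import torch

# Hypothetical sizes, chosen only to exercise the two formulas.
dim, base = 8, 10000.0
k = torch.arange(dim // 2, dtype=torch.float32)
positions = torch.arange(16, dtype=torch.float32)

# Before this commit.
old_timescale = base ** (k / (dim / 2))
old_radians = positions[..., None] / old_timescale[None, :]

# After this commit.
new_timescale = base ** ((2.0 / dim) * k)
new_radians = (positions[..., None] / new_timescale[None, None, :]).squeeze(0)

assert torch.allclose(old_timescale, new_timescale)
assert old_radians.shape == new_radians.shape == (16, dim // 2)
assert torch.allclose(old_radians, new_radians)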
 
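The new `apply_rotary_pos_emb` drops the copied docstring and the `unsqueeze_dim` logic (the argument is kept only for signature compatibility and is now unused) and instead slices and reshapes the cached cos/sin tables itself. The sketch below is an inference from the new body, not something the commit documents: it assumes `forward()` hands back 2-D tables of shape `[seq_len, head_dim]` and that `q`/`k` are laid out as `[batch, heads, seq_len, head_dim]`, which is what the `.view(1, 1, seq_len, dim)` reshape implies. Sizes are placeholders.

import torch

# Placeholder sizes; the real values come from the model config.
batch, heads, seq_len, head_dim = 2, 4, 16, 8

q = torch.randn(batch, heads, seq_len, head_dim)

# Stands in for cos_cached[:seq_len]: a 2-D table of shape [seq_len, head_dim].
cos = torch.randn(seq_len, head_dim)

# Mirrors the new function body: slice and reshape to [1, 1, seq_len, head_dim]
# so the table broadcasts over the batch and head dimensions of q and k.
cos = cos[:q.shape[-2]].view(1, 1, q.shape[-2], q.shape[-1])
assert (q * cos).shape == (batch, heads, seq_len, head_dim)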