Crystalcareai committed
Update modeling_gemmoe.py
modeling_gemmoe.py CHANGED (+7 -29)
@@ -96,10 +96,8 @@ class GemmoeRMSNorm(nn.Module):
         normed_x = normed_x.type_as(x)
         return normed_x * (self.weight + 1)
 
-
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
-
 class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -110,10 +108,11 @@ class GemmoeRotaryEmbedding(nn.Module):
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        freq_exponents = (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
-        timescale = self.base **
+        freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
+        timescale = self.base ** freq_exponents
         positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
-        radians_new = positions[..., None] / timescale[None, :]
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
         emb = torch.cat((radians_new, radians_new), dim=-1)
         cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
         sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
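The added lines rebuild the standard RoPE inverse-frequency schedule, timescale[i] = base ** (2i / dim), and the squeeze(0) collapses the broadcast result back to a 2-D table. A minimal standalone sketch of what the new cache lines compute, using made-up toy values for dim, base, and seq_len (none of these values are from the commit):

import torch

dim, base, seq_len = 8, 10000.0, 4  # hypothetical toy sizes
# timescale[i] = base ** (2i / dim), the usual RoPE wavelengths
freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype=torch.float32)
timescale = base ** freq_exponents                      # [dim // 2]
positions = torch.arange(seq_len, dtype=torch.float32)  # [seq_len]
# [seq_len, 1] / [1, 1, dim // 2] broadcasts to [1, seq_len, dim // 2]
radians = (positions[..., None] / timescale[None, None, :]).squeeze(0)
emb = torch.cat((radians, radians), dim=-1)             # [seq_len, dim]
print(emb.cos().shape)  # torch.Size([4, 8]); stored as cos_cached/sin_cached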
@@ -127,7 +126,6 @@ class GemmoeRotaryEmbedding(nn.Module):
         self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
         return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
 
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -136,34 +134,14 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    seq_len, dim = q.shape[-2], q.shape[-1]
+    cos = cos[:seq_len].view(1, 1, seq_len, dim)
+    sin = sin[:seq_len].view(1, 1, seq_len, dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-
 # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemmoe
 class GemmoeMLP(nn.Module):
     def __init__(self, config):
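With this change, apply_rotary_pos_emb no longer expects cos/sin already gathered by position_ids; it slices the raw cache to the current sequence length and broadcasts it via view(1, 1, seq_len, dim), so q and k must be laid out [batch, heads, seq_len, head_dim] and the position_ids/unsqueeze_dim arguments are effectively ignored. A minimal sketch of the new path, with hypothetical shapes (all sizes below are made up for illustration):

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq_len, head_dim = 2, 3, 4, 8  # hypothetical toy sizes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos_cached = torch.randn(16, head_dim)  # stand-in for the rotary cache
sin_cached = torch.randn(16, head_dim)

# Mirrors the new function body: slice to seq_len, broadcast over batch/heads.
cos = cos_cached[:seq_len].view(1, 1, seq_len, head_dim)
sin = sin_cached[:seq_len].view(1, 1, seq_len, head_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 3, 4, 8]) for both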