update

Files changed:
- README.md +31 -1
- configuration_moonshine.py +0 -32
- modeling_moonshine.py +0 -512
- special_tokens_map.json +0 -1
- tokenizer_config.json +0 -0
README.md CHANGED

@@ -6,7 +6,7 @@ library_name: transformers
 pipeline_tag: automatic-speech-recognition
 arxiv: https://arxiv.org/abs/2410.15608
 ---
-#
+# Moonshine
 
 [[Blog]](https://petewarden.com/2024/10/21/introducing-moonshine-the-new-state-of-the-art-for-speech-to-text/) [[Paper]](https://arxiv.org/abs/2410.15608) [[Installation]](https://github.com/usefulsensors/moonshine/blob/main/README.md) [[Podcast]](https://notebooklm.google.com/notebook/d787d6c2-7d7b-478c-b7d5-a0be4c74ae19/audio)
 
@@ -14,6 +14,36 @@ This is the model card for running the automatic speech recognition (ASR) models
 
 Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2410.15608). Note, a lot of the text has been copied verbatim from the [model card](https://github.com/openai/whisper/blob/main/model-card.md) for the Whisper model developed by OpenAI, because both models serve identical purposes, and carry identical risks.
 
+## Usage
+
+Moonshine is supported in Hugging Face 🤗 Transformers. To run the model, first install the Transformers library. For this example, we'll also install 🤗 Datasets to load a toy audio dataset from the Hugging Face Hub, and 🤗 Accelerate to reduce the model loading time:
+
+```bash
+pip install --upgrade pip
+pip install --upgrade transformers datasets[audio]
+```
+
+```python
+from transformers import MoonshineForConditionalGeneration, AutoProcessor
+from datasets import load_dataset, Audio
+import torch
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+model = MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-tiny').to(device).to(torch_dtype)
+processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
+
+dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
+sample = dataset[0]["audio"]
+
+inputs = processor(sample["array"], return_tensors="pt").to(device).to(torch_dtype)
+
+generated_ids = model.generate(**inputs)
+print(processor.decode(generated_ids[0], skip_special_tokens=True))
+```
+
 ## Model Details
 
 The Moonshine models are trained for the speech recognition task, capable of transcribing English speech audio into English text. Useful Sensors developed the models to support their business direction of developing real time speech transcription products based on low cost hardware. There are 2 models of different sizes and capabilities, summarized in the following table.
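As an aside to the new Usage section (not part of the diff above), the same checkpoint can also be driven through Transformers' high-level `pipeline` API. This is a minimal sketch, assuming a Transformers release with Moonshine support; `sample.wav` is a placeholder path for any 16 kHz recording.

```python
# Sketch only (not part of this commit): ASR via the high-level pipeline API.
import torch
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="UsefulSensors/moonshine-tiny",
    torch_dtype=torch.float32,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

# Accepts a path/URL to an audio file or a raw float array sampled at 16 kHz.
# "sample.wav" is a placeholder for your own recording.
print(asr("sample.wav")["text"])
```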
configuration_moonshine.py DELETED

@@ -1,32 +0,0 @@
-from transformers import PretrainedConfig
-from typing import List
-
-
-class MoonshineConfig(PretrainedConfig):
-    model_type = "moonshine"
-
-    def __init__(
-        self,
-        dim: int = 288,
-        inner_dim: int = None,
-        enc_depth: int = 8,
-        dec_depth: int = 8,
-        n_head: int = 8,
-        dec_voc_size: int = 32768,
-        enc_ff_swiglu: bool = False,
-        dec_ff_swiglu: bool = True,
-        **kwargs
-    ):
-        if inner_dim is None:
-            inner_dim = dim
-        if inner_dim % n_head != 0:
-            raise ValueError("`inner_dim` must be divisible by `n_head`")
-        self.dim = dim
-        self.inner_dim = inner_dim
-        self.enc_depth = enc_depth
-        self.dec_depth = dec_depth
-        self.n_head = n_head
-        self.dec_voc_size = dec_voc_size
-        self.enc_ff_swiglu = enc_ff_swiglu
-        self.dec_ff_swiglu = dec_ff_swiglu
-        super().__init__(**kwargs)
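For reference, here is a minimal sketch of how the removed configuration class behaved, assuming the pre-deletion configuration_moonshine.py has been saved locally (e.g. checked out from the previous revision of this repo). The `inner_dim` fallback to `dim` and the divisibility check come straight from the code above.

```python
# Sketch only: assumes the deleted configuration_moonshine.py is importable locally.
from configuration_moonshine import MoonshineConfig

cfg = MoonshineConfig()                      # defaults: dim=288, 8 encoder / 8 decoder layers
print(cfg.dim, cfg.inner_dim, cfg.n_head)    # 288 288 8 -> inner_dim falls back to dim

try:
    MoonshineConfig(inner_dim=290, n_head=8)  # 290 is not divisible by 8
except ValueError as err:
    print(err)                                # "`inner_dim` must be divisible by `n_head`"
```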
modeling_moonshine.py DELETED

@@ -1,512 +0,0 @@
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from torch import nn
-from transformers import PreTrainedModel
-
-import math
-import torch
-
-from .configuration_moonshine import MoonshineConfig
-
-
-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim, base=10000):
-        super().__init__()
-
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, t):
-        freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq)
-        freqs = torch.stack((freqs, freqs), dim=-1)
-        return rearrange(freqs, "... d r -> ... (d r)")
-
-
-def rotate_half(x):
-    x = rearrange(x, "... (d r) -> ... d r", r=2)
-    x1, x2 = x.unbind(dim=-1)
-    x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, "... d r -> ... (d r)")
-
-
-def apply_rotary_pos_emb(t, freqs):
-    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
-
-    freqs = freqs[-seq_len:, :]
-
-    # partial rotary embeddings, Wang et al. GPT-J
-    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
-    t = t * freqs.cos() + rotate_half(t) * freqs.sin()
-    out = torch.cat((t, t_unrotated), dim=-1)
-
-    return out.type(orig_dtype)
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, dim, inner_dim, n_head):
-        super().__init__()
-        self.n_head = n_head
-        self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(dim, inner_dim, bias=False)
-        self.to_out = nn.Linear(inner_dim, dim, bias=False)
-        self.softmax = nn.Softmax(dim=-1)
-
-    # Scaled dot product attention
-    def sdp_attention(self, q, k_t, v, mask=None):
-        d_tensor = v.shape[3]
-
-        op = (q @ k_t) / math.sqrt(d_tensor)
-        if mask is not None:
-            op = op.masked_fill(mask, -torch.finfo(op.dtype).max)
-        score = self.softmax(op)
-        out = score @ v
-
-        # concat and pass to linear layer
-        out = rearrange(out, "b h n d -> b n (h d)")
-        return self.to_out(out)
-
-    def forward(self, q, k, v, rot_pos_emb=None, mask=None):
-        # dot product with weight matrices
-        q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
-
-        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
-        k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
-        v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
-
-        # apply RoPE
-        if rot_pos_emb is not None:
-            q = apply_rotary_pos_emb(q, rot_pos_emb)
-            k = apply_rotary_pos_emb(k, rot_pos_emb)
-
-        k_t = k.transpose(2, 3)
-
-        return self.sdp_attention(q, k_t, v, mask), k_t, v
-
-
-class MultiHeadCausalSelfAttentionWithKVCache(MultiHeadAttention):
-    def __init__(self, dim, inner_dim, n_head):
-        super().__init__(dim, inner_dim, n_head)
-
-    def forward(self, q, k, v, k_cache, v_cache, rot_pos_emb, mask):
-        # dot product with weight matrices
-        q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
-
-        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
-        k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
-        v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
-
-        # apply RoPE
-        q = apply_rotary_pos_emb(q, rot_pos_emb)
-        k = apply_rotary_pos_emb(k, rot_pos_emb)
-
-        k_t = k.transpose(2, 3)
-
-        # Append new rows to K and V caches.
-        k_t = torch.concat((k_cache, k_t), dim=3)
-        v = torch.concat((v_cache, v), dim=2)
-
-        return super().sdp_attention(q, k_t, v, mask=mask), k_t, v
-
-
-class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
-    def __init__(self, dim, inner_dim, n_head):
-        super().__init__(dim, inner_dim, n_head)
-
-    def forward(self, q, k_cache, v_cache, mask):
-        q = self.to_q(q)
-        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
-
-        return super().sdp_attention(q, k_cache, v_cache, mask=mask)
-
-
-class FFLinearGelu(nn.Module):
-    def __init__(self, dim, ff_mult=4):
-        super().__init__()
-
-        self.ff = nn.Sequential(
-            nn.Linear(dim, dim * ff_mult, bias=True),
-            nn.GELU(),
-            nn.Linear(dim * ff_mult, dim, bias=True),
-        )
-
-    def forward(self, x):
-        return self.ff(x)
-
-
-class FFSwiGLU(nn.Module):
-    def __init__(self, dim, ff_mult=4):
-        super().__init__()
-
-        self.ff_proj = nn.Linear(dim, dim * ff_mult, bias=True)
-        self.ff_noact = nn.Linear(dim, dim * ff_mult, bias=True)
-        self.ff_act = nn.SiLU()
-        self.ff_out = nn.Linear(dim * ff_mult, dim, bias=True)
-
-    def forward(self, x):
-        gate = self.ff_act(self.ff_proj(x))
-        x_noact = self.ff_noact(x)
-        x = x_noact * gate
-        return self.ff_out(x)
-
-
-class EncoderLayer(nn.Module):
-    def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
-        super().__init__()
-
-        self.norm1 = nn.LayerNorm(dim, bias=False)
-
-        self.attention = MultiHeadAttention(dim, inner_dim=inner_dim, n_head=n_head)
-
-        self.norm2 = nn.LayerNorm(dim, bias=False)
-
-        self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
-
-    def forward(self, x, rot_pos_emb, mask):
-        _x = x
-        x = self.norm1(x)
-        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask)
-        x = x + _x
-
-        _x = x
-        x = self.norm2(x)
-        x = self.ff(x)
-
-        x = x + _x
-        return x
-
-
-class Encoder(nn.Module):
-    def __init__(self, dim, inner_dim, n_head, n_layers, ff_swiglu):
-        super().__init__()
-        rot_embed_dim = max(inner_dim / n_head / 2, 32)
-        self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
-
-        self.layers = nn.ModuleList(
-            [EncoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
-        )
-        self.post_norm = nn.LayerNorm(dim, bias=False)
-
-    def forward(self, x, mask):
-        pos = torch.arange(x.shape[-2], device=x.device)
-        rot_pos_emb = self.rot_pos_emb(pos)
-
-        for idx, layer in enumerate(self.layers):
-            x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask)
-        return self.post_norm(x)
-
-
-class DecoderLayer(nn.Module):
-    def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
-        super().__init__()
-
-        self.norm1 = nn.LayerNorm(dim, bias=False)
-
-        self.self_attention = MultiHeadCausalSelfAttentionWithKVCache(
-            dim, inner_dim=inner_dim, n_head=n_head
-        )
-
-        self.norm2 = nn.LayerNorm(dim, bias=False)
-        self.cross_attention = MultiHeadCrossAttentionWithKVCache(
-            dim, inner_dim=inner_dim, n_head=n_head
-        )
-
-        self.norm3 = nn.LayerNorm(dim, bias=False)
-        self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
-
-    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask):
-        dim = x.size()[1]
-        causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
-        _x = x
-        x = self.norm1(x)
-        x, new_k_cache, new_v_cache = self.self_attention(
-            q=x,
-            k=x,
-            v=x,
-            k_cache=k_cache,
-            v_cache=v_cache,
-            rot_pos_emb=rot_pos_emb,
-            mask=causal_mask,
-        )
-        x = x + _x
-
-        _x = x
-        x = self.norm2(x)
-        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask)
-        x = x + _x
-
-        _x = x
-        x = self.norm3(x)
-        x = self.ff(x)
-        x = x + _x
-
-        return x, new_k_cache, new_v_cache
-
-
-class Decoder(nn.Module):
-    def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
-        super().__init__()
-
-        self.n_head = n_head
-        self.d_head = inner_dim // n_head
-
-        rot_embed_dim = max(inner_dim / n_head / 2, 32)
-        self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
-
-        self.layers = nn.ModuleList(
-            [DecoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
-        )
-        self.final_norm = nn.LayerNorm(dim, bias=False)
-        self.token_embedding = nn.Embedding(dec_voc_size, dim)
-
-    def forward(self, x, input_mask, *args):
-        pos = torch.arange(x.shape[1], device=x.device)
-        rot_pos_emb = self.rot_pos_emb(pos)
-        x = self.token_embedding(x)
-
-        k_cache_new = []
-        v_cache_new = []
-
-        n_layer = len(self.layers)
-        k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
-            args[i : i + n_layer] for i in range(0, 4 * n_layer, n_layer)
-        ]
-        for idx, layer in enumerate(self.layers):
-            x, new_k_line, new_v_line = layer(
-                x[:, -1:],
-                k_cache=k_cache[idx],
-                v_cache=v_cache[idx],
-                x_attn_k_cache=x_attn_k_cache[idx],
-                x_attn_v_cache=x_attn_v_cache[idx],
-                rot_pos_emb=rot_pos_emb,
-                input_mask=input_mask,
-            )
-            k_cache_new.append(new_k_line)
-            v_cache_new.append(new_v_line)
-
-        x = self.final_norm(x)
-
-        return x @ self.token_embedding.weight.t(), *k_cache_new, *v_cache_new
-
-
-class InitialDecoderLayer(nn.Module):
-    def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
-        super().__init__()
-
-        self.norm1 = nn.LayerNorm(dim, bias=False)
-
-        self.self_attention = MultiHeadAttention(
-            dim, inner_dim=inner_dim, n_head=n_head
-        )
-
-        self.norm2 = nn.LayerNorm(dim, bias=False)
-        self.cross_attention = MultiHeadAttention(
-            dim, inner_dim=inner_dim, n_head=n_head
-        )
-
-        self.norm3 = nn.LayerNorm(dim, bias=False)
-        self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
-
-    def forward(self, x, context, rot_pos_emb, input_mask):
-        dim = x.size()[1]
-        causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
-        _x = x
-        x = self.norm1(x)
-        x, new_k_cache, new_v_cache = self.self_attention(
-            q=x,
-            k=x,
-            v=x,
-            rot_pos_emb=rot_pos_emb,
-            mask=causal_mask,
-        )
-        x = x + _x
-
-        _x = x
-        x = self.norm2(x)
-        x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
-            q=x, k=context, v=context, mask=input_mask,
-        )
-        x = x + _x
-
-        _x = x
-        x = self.norm3(x)
-        x = self.ff(x)
-        x = x + _x
-
-        return x, new_k_cache, new_v_cache, x_attn_k_cache, x_attn_v_cache
-
-
-class DecoderInitial(Decoder):
-    def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
-        super().__init__(dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu)
-        self.layers = nn.ModuleList(
-            [
-                InitialDecoderLayer(dim, inner_dim, n_head, ff_swiglu)
-                for _ in range(n_layers)
-            ]
-        )
-
-    def forward(self, x, enc_src, input_mask):
-        pos = torch.arange(x.shape[1], device=x.device)
-        rot_pos_emb = self.rot_pos_emb(pos)
-        x = self.token_embedding(x)
-
-        # Shape [n_layers, batch_size, n_head, seq_len, inner_dim]. Cache K transposed.
-        n_layer = len(self.layers)
-        k_cache = []
-        v_cache = []
-        x_attn_k_cache = []
-        x_attn_v_cache = []
-
-        for idx, layer in enumerate(self.layers):
-            x, new_k_line, new_v_line, new_x_attn_k_line, new_x_attn_v_line = layer(
-                x,
-                enc_src,
-                rot_pos_emb,
-                input_mask,
-            )
-
-            k_cache.append(new_k_line)
-            v_cache.append(new_v_line)
-            x_attn_k_cache.append(new_x_attn_k_line)
-            x_attn_v_cache.append(new_x_attn_v_line)
-
-        x = self.final_norm(x)
-
-        return (
-            x @ self.token_embedding.weight.t(),
-            *k_cache,
-            *v_cache,
-            *x_attn_k_cache,
-            *x_attn_v_cache,
-        )
-
-
-class AudioPreprocessor(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.audio_preprocess = nn.Sequential(
-            nn.Conv1d(1, dim, 127, 64, bias=False),
-            nn.Tanh(),
-            nn.GroupNorm(1, dim),
-            nn.Conv1d(dim, 2 * dim, 7, 3),
-            nn.GELU(),
-            nn.Conv1d(2 * dim, dim, 3, 2),
-            nn.GELU(),
-            Rearrange("... c s -> ... s c"),
-        )
-
-    def forward(self, src):
-        assert (
-            src.shape[-1] >= 1023
-        ), f"src shape[-1] {src.shape[-1]} should be at least 1023"
-        src = src.reshape((-1, 1, src.shape[-1]))
-        return self.audio_preprocess(src)
-
-
-class MoonshineModelTorch(nn.Module):
-    def __init__(
-        self,
-        dim,
-        inner_dim,
-        enc_depth,
-        dec_depth,
-        n_head=8,
-        dec_voc_size=32768,
-        enc_ff_swiglu=False,
-        dec_ff_swiglu=False,
-    ):
-        super().__init__()
-        self.preprocessor = AudioPreprocessor(dim)
-        self.encoder = Encoder(
-            dim, inner_dim, n_head, enc_depth, ff_swiglu=enc_ff_swiglu
-        )
-        self.decoder_initial = DecoderInitial(
-            dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
-        )
-        self.decoder = Decoder(
-            dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
-        )
-        self.dec_depth = dec_depth
-        self.n_head = n_head
-        self.d_head = inner_dim // n_head
-
-    def generate(self, src, mask):
-        preprocessed = self.preprocessor(src)
-        batch_size = preprocessed.shape[0]
-
-        # Get max sequence length based on number of unmasked inputs for each sample in batch.
-        token_limit_factor = 6.5 / 16000.0  # Maximum of 6.5 tokens per second.
-        if mask is not None:
-            seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor
-        else:
-            token_limit = torch.tensor([src.shape[-1] * token_limit_factor])
-            seq_lens = torch.stack([token_limit for _ in range(batch_size)])
-        seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze()
-
-        # Preprocess mask so that it matches preprocessed audio.
-        if mask is not None:
-            mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
-            mask = ~mask.reshape((batch_size, 1, 1, -1))
-            mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1]))
-
-        enc = self.encoder(preprocessed, mask)
-
-        sot_token = 1
-        eot_token = 2
-
-        sot_array = [[sot_token] for _ in range(batch_size)]
-        seq = torch.as_tensor(sot_array).to(src.device)
-
-        vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask)
-        logits = vals[0]
-        k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
-            vals[i : i + self.dec_depth]
-            for i in range(1, 1 + self.dec_depth * 4, self.dec_depth)
-        ]
-
-        sample = logits[:, -1].argmax(dim=-1, keepdim=True)
-        seq = torch.cat((seq, sample), dim=-1)
-
-        eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device)
-        while not torch.all(eot_mask):
-            vals = self.decoder(
-                seq,
-                mask,
-                *k_cache,
-                *v_cache,
-                *x_attn_k_cache,
-                *x_attn_v_cache,
-            )
-            logits = vals[0]
-            k_cache = vals[1 : self.dec_depth + 1]
-            v_cache = vals[self.dec_depth + 1 :]
-            logits = logits[:, -1]  # get last token
-            sample = logits.argmax(dim=-1, keepdim=True)
-            # For each sample in batch detect EOT or token limit reached.
-            eot_mask = eot_mask | (sample.squeeze() == eot_token)
-            eot_mask = eot_mask | (seq.shape[-1] >= seq_lens)
-            sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
-            seq = torch.cat((seq, sample), dim=-1)
-
-        return seq
-
-
-class MoonshineModel(PreTrainedModel):
-    config_class = MoonshineConfig
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = MoonshineModelTorch(
-            dim=config.dim,
-            inner_dim=config.inner_dim,
-            enc_depth=config.enc_depth,
-            dec_depth=config.dec_depth,
-            n_head=config.n_head,
-            dec_voc_size=config.dec_voc_size,
-            enc_ff_swiglu=config.enc_ff_swiglu,
-            dec_ff_swiglu=config.dec_ff_swiglu,
-        )
-
-    def forward(self, tensor, input_mask=None):
-        return self.model.generate(tensor, input_mask)
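The removed file above is the standalone PyTorch implementation (audio preprocessor, RoPE-based encoder/decoder, greedy KV-cached decoding) that the Transformers-native path in the README now replaces. Below is a minimal sketch of how it could still be exercised, assuming both deleted files are saved into a local package (a hypothetical `moonshine_hf/` directory with an `__init__.py`) so the relative import of `MoonshineConfig` resolves; weights are randomly initialized here, so only the interface and shapes are meaningful.

```python
# Sketch only: assumes the two deleted files are saved in a local package
# named moonshine_hf/ (hypothetical) so `from .configuration_moonshine import
# MoonshineConfig` inside modeling_moonshine.py resolves. No pretrained
# weights are loaded, so the decoded tokens are meaningless.
import torch

from moonshine_hf.configuration_moonshine import MoonshineConfig
from moonshine_hf.modeling_moonshine import MoonshineModel

config = MoonshineConfig()            # tiny-sized defaults: dim=288, 8 enc / 8 dec layers
model = MoonshineModel(config).eval()

# One second of fake 16 kHz audio; the preprocessor asserts at least 1023 samples.
audio = torch.randn(1, 16000)

with torch.no_grad():
    tokens = model(audio)             # forward() delegates to the greedy generate() loop

print(tokens.shape)                   # (1, n_tokens): starts with SOT (1), ends with EOT (2)
```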
special_tokens_map.json DELETED

@@ -1 +0,0 @@
-{}
tokenizer_config.json DELETED

The diff for this file is too large to render. See raw diff.