eustlb HF staff commited on
Commit
193333b
·
1 Parent(s): 08138cc
README.md CHANGED
@@ -6,7 +6,7 @@ library_name: transformers
6
  pipeline_tag: automatic-speech-recognition
7
  arxiv: https://arxiv.org/abs/2410.15608
8
  ---
9
- # Model Card: Moonshine
10
 
11
  [[Blog]](https://petewarden.com/2024/10/21/introducing-moonshine-the-new-state-of-the-art-for-speech-to-text/) [[Paper]](https://arxiv.org/abs/2410.15608) [[Installation]](https://github.com/usefulsensors/moonshine/blob/main/README.md) [[Podcast]](https://notebooklm.google.com/notebook/d787d6c2-7d7b-478c-b7d5-a0be4c74ae19/audio)
12
 
@@ -14,6 +14,36 @@ This is the model card for running the automatic speech recognition (ASR) models
14
 
15
  Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2410.15608). Note, a lot of the text has been copied verbatim from the [model card](https://github.com/openai/whisper/blob/main/model-card.md) for the Whisper model developed by OpenAI, because both models serve identical purposes, and carry identical risks.
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  ## Model Details
18
 
19
  The Moonshine models are trained for the speech recognition task, capable of transcribing English speech audio into English text. Useful Sensors developed the models to support their business direction of developing real time speech transcription products based on low cost hardware. There are 2 models of different sizes and capabilities, summarized in the following table.
 
6
  pipeline_tag: automatic-speech-recognition
7
  arxiv: https://arxiv.org/abs/2410.15608
8
  ---
9
+ # Moonshine
10
 
11
  [[Blog]](https://petewarden.com/2024/10/21/introducing-moonshine-the-new-state-of-the-art-for-speech-to-text/) [[Paper]](https://arxiv.org/abs/2410.15608) [[Installation]](https://github.com/usefulsensors/moonshine/blob/main/README.md) [[Podcast]](https://notebooklm.google.com/notebook/d787d6c2-7d7b-478c-b7d5-a0be4c74ae19/audio)
12
 
 
14
 
15
  Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2410.15608). Note, a lot of the text has been copied verbatim from the [model card](https://github.com/openai/whisper/blob/main/model-card.md) for the Whisper model developed by OpenAI, because both models serve identical purposes, and carry identical risks.
16
 
17
+ ## Usage
18
+
19
+ Moonshine is supported in Hugging Face 🤗 Transformers. To run the model, first install the Transformers library. For this example, we'll also install 🤗 Datasets to load toy audio dataset from the Hugging Face Hub, and 🤗 Accelerate to reduce the model loading time:
20
+
21
+ ```bash
22
+ pip install --upgrade pip
23
+ pip install --upgrade transformers datasets[audio]
24
+ ```
25
+
26
+ ```python
27
+ from transformers import MoonshineForConditionalGeneration, AutoProcessor
28
+ from datasets import load_dataset, Audio
29
+ import torch
30
+
31
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
32
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
33
+
34
+ model = MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-tiny').to(device).to(torch_dtype)
35
+ processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
36
+
37
+ dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
38
+ dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
39
+ sample = dataset[0]["audio"]
40
+
41
+ inputs = processor(sample["array"], return_tensors="pt").to(device).to(torch_dtype)
42
+
43
+ generated_ids = model.generate(**inputs)
44
+ print(processor.decode(generated_ids[0], skip_special_tokens=True))
45
+ ```
46
+
47
  ## Model Details
48
 
49
  The Moonshine models are trained for the speech recognition task, capable of transcribing English speech audio into English text. Useful Sensors developed the models to support their business direction of developing real time speech transcription products based on low cost hardware. There are 2 models of different sizes and capabilities, summarized in the following table.
configuration_moonshine.py DELETED
@@ -1,32 +0,0 @@
1
- from transformers import PretrainedConfig
2
- from typing import List
3
-
4
-
5
- class MoonshineConfig(PretrainedConfig):
6
- model_type = "moonshine"
7
-
8
- def __init__(
9
- self,
10
- dim: int = 288,
11
- inner_dim: int = None,
12
- enc_depth: int = 8,
13
- dec_depth: int = 8,
14
- n_head: int = 8,
15
- dec_voc_size: int = 32768,
16
- enc_ff_swiglu: bool = False,
17
- dec_ff_swiglu: bool = True,
18
- **kwargs
19
- ):
20
- if inner_dim is None:
21
- inner_dim = dim
22
- if inner_dim % n_head != 0:
23
- raise ValueError("`inner dim` must be divisible by `n_head`")
24
- self.dim = dim
25
- self.inner_dim = inner_dim
26
- self.enc_depth = enc_depth
27
- self.dec_depth = dec_depth
28
- self.n_head = n_head
29
- self.dec_voc_size = dec_voc_size
30
- self.enc_ff_swiglu = enc_ff_swiglu
31
- self.dec_ff_swiglu = dec_ff_swiglu
32
- super().__init__(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_moonshine.py DELETED
@@ -1,512 +0,0 @@
1
- from einops import rearrange
2
- from einops.layers.torch import Rearrange
3
- from torch import nn
4
- from transformers import PreTrainedModel
5
-
6
- import math
7
- import torch
8
-
9
- from .configuration_moonshine import MoonshineConfig
10
-
11
-
12
- class RotaryEmbedding(nn.Module):
13
- def __init__(self, dim, base=10000):
14
- super().__init__()
15
-
16
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
17
- self.register_buffer("inv_freq", inv_freq, persistent=False)
18
-
19
- def forward(self, t):
20
- freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq)
21
- freqs = torch.stack((freqs, freqs), dim=-1)
22
- return rearrange(freqs, "... d r -> ... (d r)")
23
-
24
-
25
- def rotate_half(x):
26
- x = rearrange(x, "... (d r) -> ... d r", r=2)
27
- x1, x2 = x.unbind(dim=-1)
28
- x = torch.stack((-x2, x1), dim=-1)
29
- return rearrange(x, "... d r -> ... (d r)")
30
-
31
-
32
- def apply_rotary_pos_emb(t, freqs):
33
- rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
34
-
35
- freqs = freqs[-seq_len:, :]
36
-
37
- # partial rotary embeddings, Wang et al. GPT-J
38
- t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
39
- t = t * freqs.cos() + rotate_half(t) * freqs.sin()
40
- out = torch.cat((t, t_unrotated), dim=-1)
41
-
42
- return out.type(orig_dtype)
43
-
44
-
45
- class MultiHeadAttention(nn.Module):
46
- def __init__(self, dim, inner_dim, n_head):
47
- super().__init__()
48
- self.n_head = n_head
49
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
50
- self.to_k = nn.Linear(dim, inner_dim, bias=False)
51
- self.to_v = nn.Linear(dim, inner_dim, bias=False)
52
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
53
- self.softmax = nn.Softmax(dim=-1)
54
-
55
- # Scaled dot product attention
56
- def sdp_attention(self, q, k_t, v, mask=None):
57
- d_tensor = v.shape[3]
58
-
59
- op = (q @ k_t) / math.sqrt(d_tensor)
60
- if mask is not None:
61
- op = op.masked_fill(mask, -torch.finfo(op.dtype).max)
62
- score = self.softmax(op)
63
- out = score @ v
64
-
65
- # concat and pass to linear layer
66
- out = rearrange(out, "b h n d -> b n (h d)")
67
- return self.to_out(out)
68
-
69
- def forward(self, q, k, v, rot_pos_emb=None, mask=None):
70
- # dot product with weight matrices
71
- q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
72
-
73
- q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
74
- k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
75
- v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
76
-
77
- # apply RoPE
78
- if rot_pos_emb is not None:
79
- q = apply_rotary_pos_emb(q, rot_pos_emb)
80
- k = apply_rotary_pos_emb(k, rot_pos_emb)
81
-
82
- k_t = k.transpose(2, 3)
83
-
84
- return self.sdp_attention(q, k_t, v, mask), k_t, v
85
-
86
-
87
- class MultiHeadCausalSelfAttentionWithKVCache(MultiHeadAttention):
88
- def __init__(self, dim, inner_dim, n_head):
89
- super().__init__(dim, inner_dim, n_head)
90
-
91
- def forward(self, q, k, v, k_cache, v_cache, rot_pos_emb, mask):
92
- # dot product with weight matrices
93
- q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
94
-
95
- q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
96
- k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
97
- v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
98
-
99
- # apply RoPE
100
- q = apply_rotary_pos_emb(q, rot_pos_emb)
101
- k = apply_rotary_pos_emb(k, rot_pos_emb)
102
-
103
- k_t = k.transpose(2, 3)
104
-
105
- # Append new rows to K and V caches.
106
- k_t = torch.concat((k_cache, k_t), dim=3)
107
- v = torch.concat((v_cache, v), dim=2)
108
-
109
- return super().sdp_attention(q, k_t, v, mask=mask), k_t, v
110
-
111
-
112
- class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
113
- def __init__(self, dim, inner_dim, n_head):
114
- super().__init__(dim, inner_dim, n_head)
115
-
116
- def forward(self, q, k_cache, v_cache, mask):
117
- q = self.to_q(q)
118
- q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
119
-
120
- return super().sdp_attention(q, k_cache, v_cache, mask=mask)
121
-
122
-
123
- class FFLinearGelu(nn.Module):
124
- def __init__(self, dim, ff_mult=4):
125
- super().__init__()
126
-
127
- self.ff = nn.Sequential(
128
- nn.Linear(dim, dim * ff_mult, bias=True),
129
- nn.GELU(),
130
- nn.Linear(dim * ff_mult, dim, bias=True),
131
- )
132
-
133
- def forward(self, x):
134
- return self.ff(x)
135
-
136
-
137
- class FFSwiGLU(nn.Module):
138
- def __init__(self, dim, ff_mult=4):
139
- super().__init__()
140
-
141
- self.ff_proj = nn.Linear(dim, dim * ff_mult, bias=True)
142
- self.ff_noact = nn.Linear(dim, dim * ff_mult, bias=True)
143
- self.ff_act = nn.SiLU()
144
- self.ff_out = nn.Linear(dim * ff_mult, dim, bias=True)
145
-
146
- def forward(self, x):
147
- gate = self.ff_act(self.ff_proj(x))
148
- x_noact = self.ff_noact(x)
149
- x = x_noact * gate
150
- return self.ff_out(x)
151
-
152
-
153
- class EncoderLayer(nn.Module):
154
- def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
155
- super().__init__()
156
-
157
- self.norm1 = nn.LayerNorm(dim, bias=False)
158
-
159
- self.attention = MultiHeadAttention(dim, inner_dim=inner_dim, n_head=n_head)
160
-
161
- self.norm2 = nn.LayerNorm(dim, bias=False)
162
-
163
- self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
164
-
165
- def forward(self, x, rot_pos_emb, mask):
166
- _x = x
167
- x = self.norm1(x)
168
- x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask)
169
- x = x + _x
170
-
171
- _x = x
172
- x = self.norm2(x)
173
- x = self.ff(x)
174
-
175
- x = x + _x
176
- return x
177
-
178
-
179
- class Encoder(nn.Module):
180
- def __init__(self, dim, inner_dim, n_head, n_layers, ff_swiglu):
181
- super().__init__()
182
- rot_embed_dim = max(inner_dim / n_head / 2, 32)
183
- self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
184
-
185
- self.layers = nn.ModuleList(
186
- [EncoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
187
- )
188
- self.post_norm = nn.LayerNorm(dim, bias=False)
189
-
190
- def forward(self, x, mask):
191
- pos = torch.arange(x.shape[-2], device=x.device)
192
- rot_pos_emb = self.rot_pos_emb(pos)
193
-
194
- for idx, layer in enumerate(self.layers):
195
- x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask)
196
- return self.post_norm(x)
197
-
198
-
199
- class DecoderLayer(nn.Module):
200
- def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
201
- super().__init__()
202
-
203
- self.norm1 = nn.LayerNorm(dim, bias=False)
204
-
205
- self.self_attention = MultiHeadCausalSelfAttentionWithKVCache(
206
- dim, inner_dim=inner_dim, n_head=n_head
207
- )
208
-
209
- self.norm2 = nn.LayerNorm(dim, bias=False)
210
- self.cross_attention = MultiHeadCrossAttentionWithKVCache(
211
- dim, inner_dim=inner_dim, n_head=n_head
212
- )
213
-
214
- self.norm3 = nn.LayerNorm(dim, bias=False)
215
- self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
216
-
217
- def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask):
218
- dim = x.size()[1]
219
- causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
220
- _x = x
221
- x = self.norm1(x)
222
- x, new_k_cache, new_v_cache = self.self_attention(
223
- q=x,
224
- k=x,
225
- v=x,
226
- k_cache=k_cache,
227
- v_cache=v_cache,
228
- rot_pos_emb=rot_pos_emb,
229
- mask=causal_mask,
230
- )
231
- x = x + _x
232
-
233
- _x = x
234
- x = self.norm2(x)
235
- x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask)
236
- x = x + _x
237
-
238
- _x = x
239
- x = self.norm3(x)
240
- x = self.ff(x)
241
- x = x + _x
242
-
243
- return x, new_k_cache, new_v_cache
244
-
245
-
246
- class Decoder(nn.Module):
247
- def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
248
- super().__init__()
249
-
250
- self.n_head = n_head
251
- self.d_head = inner_dim // n_head
252
-
253
- rot_embed_dim = max(inner_dim / n_head / 2, 32)
254
- self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
255
-
256
- self.layers = nn.ModuleList(
257
- [DecoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
258
- )
259
- self.final_norm = nn.LayerNorm(dim, bias=False)
260
- self.token_embedding = nn.Embedding(dec_voc_size, dim)
261
-
262
- def forward(self, x, input_mask, *args):
263
- pos = torch.arange(x.shape[1], device=x.device)
264
- rot_pos_emb = self.rot_pos_emb(pos)
265
- x = self.token_embedding(x)
266
-
267
- k_cache_new = []
268
- v_cache_new = []
269
-
270
- n_layer = len(self.layers)
271
- k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
272
- args[i : i + n_layer] for i in range(0, 4 * n_layer, n_layer)
273
- ]
274
- for idx, layer in enumerate(self.layers):
275
- x, new_k_line, new_v_line = layer(
276
- x[:, -1:],
277
- k_cache=k_cache[idx],
278
- v_cache=v_cache[idx],
279
- x_attn_k_cache=x_attn_k_cache[idx],
280
- x_attn_v_cache=x_attn_v_cache[idx],
281
- rot_pos_emb=rot_pos_emb,
282
- input_mask=input_mask,
283
- )
284
- k_cache_new.append(new_k_line)
285
- v_cache_new.append(new_v_line)
286
-
287
- x = self.final_norm(x)
288
-
289
- return x @ self.token_embedding.weight.t(), *k_cache_new, *v_cache_new
290
-
291
-
292
- class InitialDecoderLayer(nn.Module):
293
- def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
294
- super().__init__()
295
-
296
- self.norm1 = nn.LayerNorm(dim, bias=False)
297
-
298
- self.self_attention = MultiHeadAttention(
299
- dim, inner_dim=inner_dim, n_head=n_head
300
- )
301
-
302
- self.norm2 = nn.LayerNorm(dim, bias=False)
303
- self.cross_attention = MultiHeadAttention(
304
- dim, inner_dim=inner_dim, n_head=n_head
305
- )
306
-
307
- self.norm3 = nn.LayerNorm(dim, bias=False)
308
- self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
309
-
310
- def forward(self, x, context, rot_pos_emb, input_mask):
311
- dim = x.size()[1]
312
- causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
313
- _x = x
314
- x = self.norm1(x)
315
- x, new_k_cache, new_v_cache = self.self_attention(
316
- q=x,
317
- k=x,
318
- v=x,
319
- rot_pos_emb=rot_pos_emb,
320
- mask=causal_mask,
321
- )
322
- x = x + _x
323
-
324
- _x = x
325
- x = self.norm2(x)
326
- x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
327
- q=x, k=context, v=context, mask=input_mask,
328
- )
329
- x = x + _x
330
-
331
- _x = x
332
- x = self.norm3(x)
333
- x = self.ff(x)
334
- x = x + _x
335
-
336
- return x, new_k_cache, new_v_cache, x_attn_k_cache, x_attn_v_cache
337
-
338
-
339
- class DecoderInitial(Decoder):
340
- def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
341
- super().__init__(dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu)
342
- self.layers = nn.ModuleList(
343
- [
344
- InitialDecoderLayer(dim, inner_dim, n_head, ff_swiglu)
345
- for _ in range(n_layers)
346
- ]
347
- )
348
-
349
- def forward(self, x, enc_src, input_mask):
350
- pos = torch.arange(x.shape[1], device=x.device)
351
- rot_pos_emb = self.rot_pos_emb(pos)
352
- x = self.token_embedding(x)
353
-
354
- # Shape [n_layers, batch_size, n_head, seq_len, inner_dim]. Cache K transposed.
355
- n_layer = len(self.layers)
356
- k_cache = []
357
- v_cache = []
358
- x_attn_k_cache = []
359
- x_attn_v_cache = []
360
-
361
- for idx, layer in enumerate(self.layers):
362
- x, new_k_line, new_v_line, new_x_attn_k_line, new_x_attn_v_line = layer(
363
- x,
364
- enc_src,
365
- rot_pos_emb,
366
- input_mask,
367
- )
368
-
369
- k_cache.append(new_k_line)
370
- v_cache.append(new_v_line)
371
- x_attn_k_cache.append(new_x_attn_k_line)
372
- x_attn_v_cache.append(new_x_attn_v_line)
373
-
374
- x = self.final_norm(x)
375
-
376
- return (
377
- x @ self.token_embedding.weight.t(),
378
- *k_cache,
379
- *v_cache,
380
- *x_attn_k_cache,
381
- *x_attn_v_cache,
382
- )
383
-
384
-
385
- class AudioPreprocessor(nn.Module):
386
- def __init__(self, dim):
387
- super().__init__()
388
- self.audio_preprocess = nn.Sequential(
389
- nn.Conv1d(1, dim, 127, 64, bias=False),
390
- nn.Tanh(),
391
- nn.GroupNorm(1, dim),
392
- nn.Conv1d(dim, 2 * dim, 7, 3),
393
- nn.GELU(),
394
- nn.Conv1d(2 * dim, dim, 3, 2),
395
- nn.GELU(),
396
- Rearrange("... c s -> ... s c"),
397
- )
398
-
399
- def forward(self, src):
400
- assert (
401
- src.shape[-1] >= 1023
402
- ), f"src shape[-1] {src.shape[-1]} should be at least 1023"
403
- src = src.reshape((-1, 1, src.shape[-1]))
404
- return self.audio_preprocess(src)
405
-
406
-
407
- class MoonshineModelTorch(nn.Module):
408
- def __init__(
409
- self,
410
- dim,
411
- inner_dim,
412
- enc_depth,
413
- dec_depth,
414
- n_head=8,
415
- dec_voc_size=32768,
416
- enc_ff_swiglu=False,
417
- dec_ff_swiglu=False,
418
- ):
419
- super().__init__()
420
- self.preprocessor = AudioPreprocessor(dim)
421
- self.encoder = Encoder(
422
- dim, inner_dim, n_head, enc_depth, ff_swiglu=enc_ff_swiglu
423
- )
424
- self.decoder_initial = DecoderInitial(
425
- dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
426
- )
427
- self.decoder = Decoder(
428
- dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
429
- )
430
- self.dec_depth = dec_depth
431
- self.n_head = n_head
432
- self.d_head = inner_dim // n_head
433
-
434
- def generate(self, src, mask):
435
- preprocessed = self.preprocessor(src)
436
- batch_size = preprocessed.shape[0]
437
-
438
- # Get max sequence length based on number of unmasked inputs for each sample in batch.
439
- token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second.
440
- if mask is not None:
441
- seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor
442
- else:
443
- token_limit = torch.tensor([src.shape[-1] * token_limit_factor])
444
- seq_lens = torch.stack([token_limit for _ in range(batch_size)])
445
- seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze()
446
-
447
- # Preprocess mask so that it matches preprocessed audio.
448
- if mask is not None:
449
- mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
450
- mask = ~mask.reshape((batch_size, 1, 1, -1))
451
- mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1]))
452
-
453
- enc = self.encoder(preprocessed, mask)
454
-
455
- sot_token = 1
456
- eot_token = 2
457
-
458
- sot_array = [[sot_token] for _ in range(batch_size)]
459
- seq = torch.as_tensor(sot_array).to(src.device)
460
-
461
- vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask)
462
- logits = vals[0]
463
- k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
464
- vals[i : i + self.dec_depth]
465
- for i in range(1, 1 + self.dec_depth * 4, self.dec_depth)
466
- ]
467
-
468
- sample = logits[:, -1].argmax(dim=-1, keepdim=True)
469
- seq = torch.cat((seq, sample), dim=-1)
470
-
471
- eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device)
472
- while not torch.all(eot_mask):
473
- vals = self.decoder(
474
- seq,
475
- mask,
476
- *k_cache,
477
- *v_cache,
478
- *x_attn_k_cache,
479
- *x_attn_v_cache,
480
- )
481
- logits = vals[0]
482
- k_cache = vals[1 : self.dec_depth + 1]
483
- v_cache = vals[self.dec_depth + 1 :]
484
- logits = logits[:, -1] # get last token
485
- sample = logits.argmax(dim=-1, keepdim=True)
486
- # For each sample in batch detect EOT or token limit reached.
487
- eot_mask = eot_mask | (sample.squeeze() == eot_token)
488
- eot_mask = eot_mask | (seq.shape[-1] >= seq_lens)
489
- sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
490
- seq = torch.cat((seq, sample), dim=-1)
491
-
492
- return seq
493
-
494
-
495
- class MoonshineModel(PreTrainedModel):
496
- config_class = MoonshineConfig
497
-
498
- def __init__(self, config):
499
- super().__init__(config)
500
- self.model = MoonshineModelTorch(
501
- dim = config.dim,
502
- inner_dim = config.inner_dim,
503
- enc_depth = config.enc_depth,
504
- dec_depth = config.dec_depth,
505
- n_head = config.n_head,
506
- dec_voc_size = config.dec_voc_size,
507
- enc_ff_swiglu = config.enc_ff_swiglu,
508
- dec_ff_swiglu = config.dec_ff_swiglu,
509
- )
510
-
511
- def forward(self, tensor, input_mask=None):
512
- return self.model.generate(tensor, input_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
special_tokens_map.json DELETED
@@ -1 +0,0 @@
1
- {}
 
 
tokenizer_config.json DELETED
The diff for this file is too large to render. See raw diff