# @ [email protected]

import random

import numpy as np
import logging
import argparse, copy
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy

from .modules.utils import make_pad_mask

from .modules.embedding import SinePositionalEmbedding, TokenEmbedding
from .modules.transformer import (
    LayerNorm,
    TransformerEncoder,
    TransformerEncoderLayer,
)
from huggingface_hub import PyTorchModelHubMixin
from argparse import Namespace
import typing as tp

def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
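

# Illustrative sketch (not part of the original module): with top_k=3 the two smallest
# logits are pushed to -inf, so after softmax all probability mass sits on the top three
# tokens. The literal values below are arbitrary demonstration inputs.
def _demo_top_k_filtering() -> torch.Tensor:
    toy_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])        # [batch=1, vocab=5]
    filtered = top_k_top_p_filtering(toy_logits.clone(), top_k=3)  # last two entries -> -inf
    return F.softmax(filtered, dim=-1)                             # zero mass outside the top-3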


def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: next-token logits of shape (batch size, vocabulary size).
        top_k: (`optional`) int
            The number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to 10.
        top_p: (`optional`) float
            The cumulative probability threshold for nucleus (top-p) sampling. Must be between 0 and 1. Defaults to 1.0.
        temperature: (`optional`) float
            The value used to modulate the next-token probabilities. Must be strictly positive. Defaults to 1.0.
    """

    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token
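

# Illustrative usage sketch (assumed batch size and vocabulary size, random logits):
# topk_sampling scales by temperature, applies top-k/top-p filtering, then draws one
# token id per row with torch.multinomial.
def _demo_topk_sampling(vocab_size: int = 100) -> torch.Tensor:
    fake_logits = torch.randn(2, vocab_size)  # [batch=2, vocab]
    return topk_sampling(fake_logits, top_k=10, top_p=0.9, temperature=0.8)  # [2, 1] token ids
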

class SSR_Speech(
        nn.Module,
        PyTorchModelHubMixin,
        library_name="ssr_speech",
        repo_url=None,
        tags=None,
    ):
    def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "SSR_Speech":
        # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
        # Won't affect instance initialization
        if args is not None:
            if config is not None:
                raise ValueError("Cannot provide both `args` and `config`.")
            config = vars(args)
        return super().__new__(cls, args=args, config=config, **kwargs)

    def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
        super().__init__()

        # If loaded from HF Hub => convert config.json to Namespace args before initializing
        if args is None:
            if config is None:
                raise ValueError("Either `args` or `config` must be provided.")
            args = Namespace(**config)

        self.args = copy.copy(args)
        if not getattr(self.args, "n_special", False):
            self.args.n_special = 3
        self.args.eos = getattr(self.args, "eos", -1)

        if isinstance(self.args.audio_vocab_size, str):
            self.args.audio_vocab_size = eval(self.args.audio_vocab_size)

        self.n_text_tokens = self.args.text_vocab_size + 1
        assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"

        self.n_audio_tokens = [int(self.args.audio_vocab_size) + self.args.n_special + self.args.max_n_spans] * self.args.n_codebooks # special tokens: empty, EOG, audio pad, EOS, SOS, plus max_n_spans mask-span tokens (see asserts below)
        assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
        assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
        assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token
        assert self.args.eos == self.args.audio_vocab_size + 3, self.args.eos
        assert self.args.sos == self.args.audio_vocab_size + 4, self.args.sos
        assert self.args.mts == self.args.audio_vocab_size + 5, self.args.mts

        self.text_embedding = TokenEmbedding(
            dim_model=self.args.d_model,
            vocab_size=self.n_text_tokens, 
            dropout=self.args.text_embedding_dropout
        )

        self.audio_embedding = nn.ModuleList(
            [
                TokenEmbedding(
                dim_model=self.args.audio_embedding_dim, 
                vocab_size=self.n_audio_tokens[k], 
                dropout=self.args.audio_embedding_dropout
            ) for k in range(self.args.n_codebooks)
            ]
        )
        self.text_positional_embedding = SinePositionalEmbedding(
            self.args.d_model,
            dropout=self.args.text_positional_embedding_dropout,
            scale=False,
            alpha=True, # learnable scalar that scales the magnitude of the positional embedding
        )
        self.audio_positional_embedding = SinePositionalEmbedding(
            self.args.d_model,
            dropout=self.args.audio_positional_embedding_dropout,
            scale=False,
            alpha=True, # learnable scalar that scales the magnitude of the positional embedding
        )

        dec_layer = TransformerEncoderLayer(
            self.args.d_model,
            self.args.nhead,
            dim_feedforward=self.args.d_model * 4,
            dropout=self.args.trm_dropout,
            batch_first=True,
            norm_first=True,
            layer_norm_cls=LayerNorm
        )
        self.decoder = TransformerEncoder(
            dec_layer,
            num_layers=self.args.num_decoder_layers,
            norm=LayerNorm(self.args.d_model),
        )
        
        self.predict_layer = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
            ]
        )
        
        self.accuracy_metrics = nn.ModuleList(
            [MulticlassAccuracy(
                self.n_audio_tokens[k],
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=None,
            ) for k in range(self.args.n_codebooks)]
        )

    def embed_y(self, cated_y):
        # [K,T,B]
        embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D]
        assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
        assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
        embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
        embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
        return embedded_y

        
    def prepare_input_target(self, cated_y, y_lens):
        
        embedded_y = self.embed_y(cated_y) # [B,T,D]
        
        # positional embedding
        y_input = self.audio_positional_embedding(embedded_y)
        
        # make attention mask and padding mask
        y_padding_mask = make_pad_mask(y_lens).to(cated_y.device)
        y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device)
        return y_input, y_padding_mask, y_attention_mask


    def dec_forward(
            self, 
            x_input, 
            x_lens,
            x_attention_mask,
            x_padding_mask,
            y_input,
            new_y_lens,
            y_attention_mask,
            y_padding_mask,
            past=None,
            last_3_tokens=False
        ):
            x_attn_mask = F.pad(
                x_attention_mask,
                (0, new_y_lens.max()),
                value=True,
            ) # x attends to all of x and to none of y, following Figure 3 of the VALL-E paper
            y_attn_mask = F.pad(
                y_attention_mask,
                (x_lens.max(), 0), # y is padded at the front
                value=False,
            ) # y attends to all of x; within y, a lower-triangular mask enforces autoregressive decoding
            xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

            # merge key padding and attention masks
            bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
            xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
            _xy_padding_mask = (
                xy_padding_mask.view(bsz, 1, 1, src_len)
                .expand(-1, self.args.nhead, -1, -1)
                .reshape(bsz * self.args.nhead, 1, src_len)
            )
            # broadcast the 2-D attention mask over batch * heads so it matches the padding mask
            if xy_attn_mask.shape != _xy_padding_mask.shape:
                assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
                xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1)
            xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

            new_attn_mask = torch.zeros_like(xy_attn_mask)
            new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
            xy_attn_mask = new_attn_mask

            xy_input = torch.cat([x_input, y_input], dim=1)

            if past is None: # do not use kv-cache
                out, _ =  self.decoder((xy_input, None), mask=xy_attn_mask)
                return out[:, x_lens.max():], None
            else: # use kvcache
                if past.ndim > 3: # kv-cache is already populated, so only the last token(s) need to be passed; this doesn't work with multi-span speech editing yet
                    if last_3_tokens:
                        xy_input = xy_input[:, -3:]
                        xy_attn_mask = xy_attn_mask[:, -3:]
                    else:
                        xy_input = xy_input[:, -1:]
                        xy_attn_mask = xy_attn_mask[:, -1:]

                out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
                if isinstance(out, tuple): # get rid of stage_embedding
                    out = out[0]

                if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
                    return out[:, x_lens.max():], present
                else: # used kvcache
                    return out, present
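
    # Illustrative sketch (not in the original class): builds the combined text/audio
    # attention mask for toy lengths, mirroring the masking logic in dec_forward above.
    # True marks positions that must NOT be attended to: text rows attend only to text,
    # audio rows attend to all text plus a causal (lower-triangular) view of audio.
    @staticmethod
    def _demo_xy_attn_mask(x_len: int = 2, y_len: int = 3) -> torch.Tensor:
        x_attn = torch.triu(torch.ones(x_len, x_len), diagonal=1).bool()
        x_attn = F.pad(x_attn, (0, y_len), value=True)   # text never attends to audio
        y_attn = torch.triu(torch.ones(y_len, y_len), diagonal=1).bool()
        y_attn = F.pad(y_attn, (x_len, 0), value=False)  # audio attends to all text
        return torch.cat([x_attn, y_attn], dim=0)        # [(x_len + y_len), (x_len + y_len)]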

    def forward(self, batch):
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, K, T),
            where K is the number of codebooks.
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `y`
            before padding.
        """
        x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
        if len(x) == 0:
            return None
        x = x[:, :x_lens.max()] # handles gradient accumulation, where x_lens.max() may be shorter than the padded length of the current slice of x
        y = y[:, :, :y_lens.max()]
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
        assert y_lens.ndim == 1, y_lens.shape
        
        targets = y.clone()
        y = y.permute(1,2,0) # [B,K,T]->[K,T,B]
        # makes attention mask and padding mask for x
        x_padding_mask = make_pad_mask(x_lens).to(x.device)
        x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device)
        x_input = self.text_embedding(x)
        x_input = self.text_positional_embedding(x_input)
        y_input, y_padding_mask, y_attention_mask = self.prepare_input_target(y, y_lens)
        y_out = self.dec_forward(
                    x_input, 
                    x_lens,
                    x_attention_mask,
                    x_padding_mask,
                    y_input,
                    y_lens,
                    y_attention_mask,
                    y_padding_mask
                )
        y_out = y_out[0] # no kv-caching during training
        assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
        
        logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
        assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape

        targets = targets.permute(1,0,2) # [K B T]
        logits = logits.permute(1,0,2,3) # [K B S card]

        logits = logits[:, :, :-1]
        targets = targets[:, :, 1:]
        
        if self.args.predict_mask_token:
            masks = (targets != self.args.audio_pad_token) & (targets != self.args.empty_token)
        else:
            masks = (targets != self.args.audio_pad_token) & (targets != self.args.empty_token) & (targets < self.args.mts)
            
        tmp_masks = masks.clone()
        
        if not self.args.predict_all:
            mts_pos = (targets == self.args.mts).nonzero(as_tuple=False) # positions of the first mask-span token
            for k, b, t in mts_pos:
                tmp_masks[k, b, :t] = False # only positions from the mask-span token onward contribute to the loss
            
            
        assert masks.shape[0] == self.args.n_codebooks, masks.shape
        
        loss = []
        ntokens = []
        top10acc = []

        for k, (logit, target, mask, tmp_mask) in enumerate(zip(logits, targets, masks, tmp_masks)):
            logit = logit.reshape(-1, logit.size(-1)) # B*S card
            target = target.reshape(-1) # B*T
            mask = mask.reshape(-1).bool()
            tmp_mask = tmp_mask.reshape(-1).bool()

            loss.append(F.cross_entropy(logit[tmp_mask], target[tmp_mask], reduction='mean'))
            top10acc.append(self.accuracy_metrics[k](logit[tmp_mask].detach(), target[tmp_mask]))
            ntokens.append(len(target[mask]))
        
        all_ntokens = sum(ntokens)
        if self.args.codebook_weight is not None:
            codebook_weight = eval(self.args.codebook_weight)
        else:
            codebook_weight = [1.] * self.args.n_codebooks
        loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])
        top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)]
        top10acc = sum(top10acc_by_codebook)
        ntokens = torch.tensor(all_ntokens).to(logits.device)

        return {
            "loss": loss,
            "top10acc": top10acc,
            "top10acc_by_codebook": top10acc_by_codebook,
            "effective_ntoken": ntokens,
        }
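
    # Illustrative sketch (hypothetical token ids): the loss masks built in forward drop
    # positions holding the audio pad token or the empty token so they do not contribute
    # to the cross-entropy. Here 7 and 8 stand in for args.audio_pad_token / args.empty_token.
    @staticmethod
    def _demo_loss_mask(audio_pad_token: int = 7, empty_token: int = 8) -> torch.Tensor:
        toy_targets = torch.tensor([[1, 2, 8, 7, 7]])
        return (toy_targets != audio_pad_token) & (toy_targets != empty_token)  # [[True, True, False, False, False]]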

    def rearrange(self, y, non_mask_intervals, mask_intervals):
    
        assert self.args.eos > 0, f"eos={self.args.eos} should > 0"
        
        rearranged_y = []
        sos_tensor = torch.LongTensor([self.args.sos] * self.args.n_codebooks).unsqueeze(-1).to(y.device)
        eos_tensor = torch.LongTensor([self.args.eos] * self.args.n_codebooks).unsqueeze(-1).to(y.device)
        eog_tensor = torch.LongTensor([self.args.eog] * self.args.n_codebooks).unsqueeze(-1).to(y.device)
        for i, item in enumerate(non_mask_intervals):
            if i == 0:
                if item[0] == item[1]: # case: (0,0)
                    rearranged_y.append(sos_tensor)
                else:
                    rearranged_y.append(torch.cat([sos_tensor, y[:, item[0]: item[1]]], dim=-1))
            elif i == len(non_mask_intervals)-1:
                if item[0] == item[1]: # case: (N,N)
                    rearranged_y.append(eos_tensor)
                else:
                    rearranged_y.append(torch.cat([y[:, item[0]: item[1]], eos_tensor], dim=-1))
            else:
                rearranged_y.append(y[:, item[0]: item[1]])
                
        for i, item in enumerate(mask_intervals):
            rearranged_y.append(torch.cat([y[:, item[0]: item[1]], eog_tensor], dim=-1))

        return rearranged_y
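
    # Illustrative sketch (toy tensor and intervals): for a length-10 sequence with a single
    # masked span (3, 6), rearrange returns [sos + y[:, 0:3], y[:, 6:10] + eos, y[:, 3:6] + eog],
    # i.e. all unmasked context first, then each masked span terminated by an EOG token.
    def _demo_rearrange(self) -> list:
        toy_y = torch.zeros(self.args.n_codebooks, 10, dtype=torch.long)
        return self.rearrange(toy_y, non_mask_intervals=[(0, 3), (6, 10)], mask_intervals=[(3, 6)])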

    def get_pattern_sequence(self, tokens: torch.Tensor, n_q: int, special_token: int, delays: tp.Optional[tp.List[int]] = None, 
                         empty_initial: int = 0) -> torch.Tensor:
        """Generate a pattern sequence for delayed codebooks without batch dimension.
    
        Args:
            tokens (torch.Tensor): Input tensor of shape [K, T].
            n_q (int): Number of codebooks.
            delays (Optional[List[int]]): Delay for each codebook. Defaults to increasing delays.
            empty_initial (int): Number of initial empty steps. Defaults to 0.
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
    
        Returns:
            torch.Tensor: Modified tokens based on the pattern.
        """
        K, T = tokens.shape
        assert K == n_q, "Number of codebooks (K) must match n_q"
        if delays is None:
            delays = list(range(n_q))
        max_delay = max(delays)
        pattern_length = T + max_delay + empty_initial
        pattern_tokens = torch.full((K, pattern_length), fill_value=special_token, dtype=tokens.dtype).to(tokens.device)
    
        for t in range(T):
            for q in range(n_q):
                delayed_t = t + delays[q] + empty_initial
                if delayed_t < pattern_length:
                    pattern_tokens[q, delayed_t] = tokens[q, t]
    
        return pattern_tokens

    def revert_pattern_sequence(self, pattern_tokens: torch.Tensor, n_q: int,
                            delays: tp.Optional[tp.List[int]] = None, special_token: int = -1) -> torch.Tensor:
        """Revert the pattern sequence back to the original multi-codebook sequence without batch dimension.
    
        Args:
            pattern_tokens (torch.Tensor): Pattern tensor of shape [K, S].
            n_q (int): Number of codebooks.
            delays (Optional[List[int]]): Delay for each codebook. Defaults to increasing delays.
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
    
        Returns:
            torch.Tensor: Reverted tokens of shape [K, T].
        """
        K, S = pattern_tokens.shape
        assert K == n_q, "Number of codebooks (K) must match n_q"
        if delays is None:
            delays = list(range(n_q))
        T = S - max(delays)
        reverted_tokens = torch.full((K, T), fill_value=special_token, dtype=pattern_tokens.dtype).to(pattern_tokens.device)
    
        for t in range(T):
            for q in range(n_q):
                delayed_t = t + delays[q]
                if delayed_t < S:
                    reverted_tokens[q, t] = pattern_tokens[q, delayed_t]
    
        return reverted_tokens
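
    # Illustrative sketch (toy [K, T] tensor): get_pattern_sequence shifts codebook k right
    # by k steps and fills the gaps with the special token; revert_pattern_sequence undoes
    # the shift exactly, so the round trip below is lossless.
    def _demo_delay_pattern_roundtrip(self) -> bool:
        toy = torch.arange(8).reshape(2, 4)                                 # K=2 codebooks, T=4 steps
        shifted = self.get_pattern_sequence(toy, n_q=2, special_token=-1)   # [2, 5] with delays [0, 1]
        reverted = self.revert_pattern_sequence(shifted, n_q=2, special_token=-1)
        return torch.equal(reverted, toy)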
        
    def shift(self, rearranged_y):
        
        shifted_y = [self.get_pattern_sequence(tokens=cur_y, n_q=self.args.n_codebooks, special_token=self.args.empty_token) for cur_y in rearranged_y] # apply the delay pattern to each span; each element stays [K, S]
        
        return shifted_y
    
    def insert_mask(self, shifted_y):

        num_masks = (len(shifted_y) - 1) // 2
        assert num_masks == (len(shifted_y) - 1) / 2, len(shifted_y)
        emb_inds = list(range(self.args.mts, self.args.mts+ self.args.max_n_spans))
        if self.args.shuffle_mask_embedding:
            random.shuffle(emb_inds)
        emb_inds_use = emb_inds[:num_masks]
        mask_value = emb_inds_use + emb_inds_use
        assert len(shifted_y) == len(mask_value) + 1, len(mask_value)
        
        inserted_y = []
        mask_position = [-1] * (self.args.max_n_spans*2)
        for j in range(len(shifted_y)-1):
            inserted_y.append(shifted_y[j])
            mask_position[j] = sum([item.shape[1] for item in inserted_y]) # each item is of shape [K S], so take shape[1]
            tmp = torch.LongTensor([mask_value[j]] * self.args.n_codebooks).unsqueeze(-1).to(shifted_y[0].device)
            inserted_y.append(tmp)

        inserted_y.append(shifted_y[-1])
        mask_position = [item for item in mask_position if item != -1]
        
        return inserted_y, mask_position
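
    # Illustrative sketch (toy spans): with one masked span, the rearranged-and-shifted list
    # has three elements; insert_mask interleaves a mask-span token (id args.mts + i) after
    # each of the first two elements and records where each mask token lands.
    def _demo_insert_mask(self) -> tuple:
        toy_spans = [torch.zeros(self.args.n_codebooks, 4, dtype=torch.long) for _ in range(3)]
        return self.insert_mask(toy_spans)  # -> (list of [K, S] tensors, mask token positions)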
    
    def cat_y(self, inserted_y):

        cated_y = torch.cat(inserted_y, dim=1)
        assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
        new_y_lens = cated_y.shape[1]
        
        return cated_y, new_y_lens
    
    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        prompt_x: torch.Tensor,
        prompt_x_lens: torch.Tensor,
        y: torch.Tensor,
        prompt: torch.Tensor,
        mask_interval: list[torch.Tensor],
        top_k: int=-100,
        top_p: float=1.0,
        temperature: float=1.0,
        stop_repetition: int=-1,
        kvcache: int=1,
        silence_tokens: list[int]=[1388,1898,131],
        cfg_coef: float=1.5,
        aug_text: bool=False,
        aug_context: bool=False,
        cfg_pretrained: bool=False,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, L).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, K).
          mask_interval:
            a tensor of shape (1, M, 2) containing M (mask_start, mask_end) pairs; the leading dimension is 1 because only single-sample inference is supported for now
          top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering. Default to -100.
          top_p: (`optional`) float
            The cumulative probability threshold for nucleus (top-p) sampling. Defaults to 1.0.
          temperature: (`optional`) float
            The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
          stop_repetition (`optional`) int
            if not -1, down-weight the logit of a silence token once it has repeated this many consecutive times, to avoid generating it again. This only applies to tokens from the first codebook.
          kvcache (`optional`) int
            if 1, use kvcache to speed up sampling
          cfg_coef: float (>= 1.0)
          aug_text: whether to use classifier-free guidance (CFG) on the text input
          aug_context: whether to augment the context by prepending the original audio and text
          cfg_pretrained: whether CFG was used during training (determines how the unconditional text is built)
        """
        
        assert cfg_coef >= 1.0, cfg_coef
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
        assert prompt.ndim == 3, prompt.shape
        prompt = prompt.transpose(2,1)
        assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
        assert prompt.shape[0] == 1 and prompt.shape[1] == self.args.n_codebooks, prompt.shape # there is no padding
        assert mask_interval.shape == torch.Size((1, mask_interval.shape[1], 2)), mask_interval

        # whether to use context
        context_len = sum([item[1] - item[0] for item in mask_interval[0]])
        if aug_context and context_len < 2 * 50:
            aug_context = True
        else:
            aug_context = False
        
        # augment
        if aug_text and not aug_context: # [t, ab, m] [t', ab, m]
            y = y.repeat(2, 1, 1)
            if not cfg_pretrained:
                uncond_x = torch.randint(0, self.n_text_tokens, (1, x.shape[1])).to(x.device)
            else:
                uncond_x = torch.tensor([self.args.text_vocab_size-1], dtype=torch.long).unsqueeze(0).repeat(1, x.shape[1]).to(x.device)
            x = torch.cat([x, uncond_x], dim=0)

        if aug_text and aug_context: # [tc, t, c, ab, m] [tc, t', c, ab, m]
            out_len = prompt.shape[2]
            gt_y = torch.cat([prompt, y], dim=-1)
            y = gt_y.repeat(2, 1, 1)
            gt_x = torch.cat([prompt_x, x], dim=1)
            if not cfg_pretrained:
                uncond_x = torch.randint(0, self.n_text_tokens, (1, gt_x.shape[1])).to(gt_x.device)
            else:
                uncond_x = torch.tensor([self.args.text_vocab_size-1], dtype=torch.long).unsqueeze(0).repeat(1, gt_x.shape[1]).to(gt_x.device)
            x = torch.cat([gt_x, uncond_x], dim=0)
            
        if not aug_text and aug_context: # [tc, t, c, ab, m]
            out_len = prompt.shape[2]
            y = torch.cat([prompt, y], dim=-1)
            x = torch.cat([prompt_x, x], dim=1)


        # make x attention mask and x_input
        x_lens = torch.LongTensor([x.shape[-1]]).to(x_lens.device)
        x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
        x_input = self.text_embedding(x)
        x_input = self.text_positional_embedding(x_input)

        # make initial y_input
        # make mask_interval and non_mask_interval
        y_len = y.shape[2]
        y_lens = torch.LongTensor([y_len]).to(y.device)
        mask_interval = mask_interval[0]
        if aug_context:
            mask_interval = [[item[0]+out_len, item[1]+out_len] for item in mask_interval]
        starts =  [item[0].item() for item in mask_interval] + [y_len]
        ends = [0] + [item[1].item() for item in mask_interval]
        mask_intervals = [
            (item[0].item(), item[1].item()) for item in mask_interval
        ] # note the name change: the input mask_interval becomes mask_intervals, a list of (start, end) tuples
        non_mask_intervals = [
            (ns, ne) for ns, ne in zip(ends, starts)
        ]

        # prepare input sequences
        rearranged_y = self.rearrange(y[0], non_mask_intervals, mask_intervals)
        shifted_y = self.shift(rearranged_y) # each element is [K, S]; only the shifted values are kept, applied directly to the original input y
        inserted_y, mask_position = self.insert_mask(shifted_y)
        cated_y, new_y_lens = self.cat_y(inserted_y) # KT
        num_task = len(mask_position)//2
        cated_y = cated_y[:, :mask_position[num_task]] # of shape [K,T] input of the network
        new_y_lens = torch.LongTensor([mask_position[num_task]]).to(cated_y.device)
        cated_y = cated_y.unsqueeze(0).permute(1,2,0) # B,K,T -> K,T,B
        if aug_text:
            cated_y = cated_y.repeat(1, 1, 2)
        embedded_y = self.embed_y(cated_y) #BTD
        
        if aug_text:
            x_padding_mask = torch.full((2, x_lens[0]), False).to(x.device)
            if cfg_pretrained:
                x_padding_mask[1:, 1:] = True
            past = torch.ones([self.args.num_decoder_layers, 2, 2], device=x.device, dtype=torch.float32) if kvcache else None
        else:
            x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
            past = torch.ones([self.args.num_decoder_layers, 2, 1], device=x.device, dtype=torch.float32) if kvcache else None
            
        emb_inds = list(range(self.args.mts, self.args.mts+ self.args.max_n_spans))
        
        
        generated = []
        logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
            
        for idx in range(num_task):
            cur_generated = []
            prev_token = None
            consec_silence_count = 0
            num_gen = 0
            num_eog = 0

            # add mask token
            mts = torch.LongTensor([emb_inds[idx]] * self.args.n_codebooks).unsqueeze(-1).to(embedded_y.device) # K, 1
            mts_emb = torch.stack([self.audio_embedding[k](mts[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
            mts_emb = mts_emb.sum(dim=0,keepdim=True) # [1,1,D]
            if aug_text:
                mts_emb = mts_emb.repeat(2,1,1)
            embedded_y = torch.cat([embedded_y, mts_emb], dim=1)
            # positional embedding
            y_input = self.audio_positional_embedding(embedded_y) # [B T D]
            # make attention mask and padding mask
            y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
            new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
            if aug_text:
                y_padding_mask = torch.full((2,new_y_lens[0]), False).to(y.device)
            else:
                y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
            
            while True:
                # get model output
                y_out, present = self.dec_forward(
                                        x_input, 
                                        x_lens,
                                        x_attention_mask,
                                        x_padding_mask,
                                        y_input,
                                        new_y_lens,
                                        y_attention_mask,
                                        y_padding_mask,
                                        past=past,
                                        last_3_tokens=False
                                        )
                if past is not None:
                    past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
                y_out = y_out[:, -1:] # only take the last one
                logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
                logits = logits.squeeze() # [K card]
                if aug_text:
                    logits = cfg_coef * logits[0] + (1 - cfg_coef) * logits[1]
                assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
                # filter out mask-span (mts), sos and eos tokens
                for jj in range(self.args.n_codebooks):
                    logits[jj][self.args.eos] = -10000.
                    logits[jj][self.args.sos] = -10000.
                    for mask_tok in range(self.args.mts, self.args.mts + self.args.max_n_spans):
                        logits[jj][mask_tok] = -10000.
                # add first empty tokens
                if num_gen < self.args.n_codebooks - 1:
                    for jj in range(num_gen + 1, self.args.n_codebooks):
                        logits[jj][self.args.empty_token] = 10000.
                # deal with eog token
                if num_eog > 0: # codebook 1 has produced eog token
                    for jj in range(num_eog+1,self.args.n_codebooks):
                        logits[jj][self.args.eog] = -10000
                        logits[jj][self.args.empty_token] = -10000
                    samples = topk_sampling(
                                logits, top_k=top_k, top_p=top_p, temperature=temperature
                            ) # [K, 1]
                    for jj in range(num_eog):
                        samples[jj, 0] = self.args.empty_token
                    samples[num_eog, 0] = self.args.eog
                    num_eog += 1
                else: # codebook 1 did not produce eog token
                    # filter out eog for codebook 2-4
                    for jj in range(1,self.args.n_codebooks):
                        logits[jj][self.args.eog] = -10000
                        
                    # silence repetition handling
                    if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
                        if logits[0, prev_token] < 0:
                            logits[0, prev_token] = logits[0, prev_token] * (consec_silence_count - (stop_repetition-1))
                        else:
                            logits[0, prev_token] = logits[0, prev_token] / (consec_silence_count - (stop_repetition-1))
                    
                    samples = topk_sampling(
                                logits, top_k=top_k, top_p=top_p, temperature=temperature
                            ) # [K, 1]
                    
                    assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
                    
                    if (
                        samples[0,0] == self.args.eog or torch.argmax(logits[0], dim=-1) == self.args.eog or y_input.shape[1] > x_lens[0] * 10
                    ): # the last condition means y is already too long; it shouldn't happen, but is kept as a safeguard
                        samples[0,0] = self.args.eog
                        num_eog += 1
                    
                    if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
                        consec_silence_count += 1
                    else:
                        consec_silence_count = 0
                    prev_token = samples[0,0]

                num_gen += 1
                cur_generated.append(samples.squeeze(-1))
                
                if num_eog == self.args.n_codebooks: # current span is done
                    break
                    
                # prepare input for next token prediction
                samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
                samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
                if aug_text:
                    samples_emb = samples_emb.repeat(2, 1, 1)
                embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
                # positional embedding
                y_input = self.audio_positional_embedding(embedded_y) # [B T D]
                # make attention mask and padding mask
                y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
                new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
                if aug_text:
                    y_padding_mask = torch.full((2,new_y_lens[0]), False).to(y.device)
                else:
                    y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
            generated.append(cur_generated)
        assert len(generated) == num_task, f"len(generated): {len(generated)}, num_task: {num_task}"

        # combine non-masked spans with generated spans
        # first need to shift the generated part back
        flatten_gen = []
        for i, orig_span in enumerate(generated):
            span = torch.stack(orig_span, dim=0) # [T K]
            span = span.transpose(1,0) # [K, T]
            assert span.shape[0] == self.args.n_codebooks, span.shape
            unshifted_span = self.revert_pattern_sequence(pattern_tokens=span, n_q=self.args.n_codebooks, special_token=self.args.empty_token)
            assert unshifted_span.shape[1] == span.shape[1]-self.args.n_codebooks+1, f"unshifted_span:{unshifted_span.shape}, orig_span:{span.shape}"
            unshifted_span = unshifted_span[:,:-1] # remove eog token
            flatten_gen.append(unshifted_span)
                
        res = []
        marks = []
        masks = []
        tmp = 0
        for orig_interval, gen in zip(non_mask_intervals, flatten_gen):
            res.append(y[0, :, orig_interval[0]:orig_interval[1]])
            masks.append((tmp, tmp+orig_interval[1]-orig_interval[0]))
            tmp_mark = [0] * (orig_interval[1] - orig_interval[0])
            marks = [*marks, *tmp_mark]
            res.append(gen)
            tmp += orig_interval[1]-orig_interval[0] + gen.shape[-1]
            tmp_mark = [1] * gen.shape[-1]
            marks = [*marks, *tmp_mark]
        if y.shape[-1] != non_mask_intervals[-1][1] + 1: # edit last tokens or tts
            res.append(y[0, :, non_mask_intervals[-1][0]:non_mask_intervals[-1][1]])
            masks.append((tmp, tmp+non_mask_intervals[-1][1]-non_mask_intervals[-1][0]))
            tmp_mark = [0] * (non_mask_intervals[-1][1] - non_mask_intervals[-1][0])
            marks = [*marks, *tmp_mark]
        res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T]
        marks = torch.LongTensor(marks).unsqueeze(0)
        if aug_context:
            res = res[:, :, out_len:]
            marks = marks[:, out_len:]
            masks = [(item[0]-out_len, item[1]-out_len) for item in masks]
            non_mask_intervals = [(item[0]-out_len, item[1]-out_len) for item in non_mask_intervals]

        return res, marks, masks, non_mask_intervals
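

# Illustrative sketch (toy logits, not from the original file): the classifier-free-guidance
# combination used inside `inference` when aug_text is enabled. The conditional logits
# (row 0, real text) are pushed away from the unconditional logits (row 1, random/pad text)
# by cfg_coef >= 1.0.
def _demo_cfg_combination(cfg_coef: float = 1.5) -> torch.Tensor:
    cond_logits = torch.tensor([2.0, 0.0, -1.0])    # logits conditioned on the real transcript
    uncond_logits = torch.tensor([0.5, 0.5, 0.5])   # logits for the unconditional branch
    return cfg_coef * cond_logits + (1 - cfg_coef) * uncond_logits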


if __name__ == "__main__":
    # debug
    pass