import typing as tp from einops import rearrange import numpy as np import torch from torch import nn class EncodecModel(nn.Module): def __init__(self, decoder=None, quantizer=None, frame_rate=None, sample_rate=None, channels=None, causal=False, renormalize=False): super().__init__() self.frame_rate=0 self.sample_rate=0 self.channels=0 self.decoder = decoder self.quantizer = quantizer self.frame_rate = frame_rate self.sample_rate = sample_rate self.channels = channels self.renormalize = renormalize self.causal = causal if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. assert not self.renormalize, 'Causal model does not support renormalize' @property def total_codebooks(self): """Total number of quantizer codebooks available.""" return self.quantizer.total_codebooks @property def num_codebooks(self): """Active number of codebooks used by the quantizer.""" return self.quantizer.num_codebooks def set_num_codebooks(self, n): """Set the active number of codebooks used by the quantizer.""" self.quantizer.set_num_codebooks(n) @property def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() scale = 1e-8 + volume x = x / scale scale = scale.view(-1, 1) else: scale = None return x, scale def postprocess(self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) return x def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): # B,K,T -> B,C,T emb = self.decode_latent(codes) out = self.decoder(emb) out = self.postprocess(out, scale) return out def decode_latent(self, codes: torch.Tensor): """Decode from the discrete codes to continuous latent space.""" return self.quantizer.decode(codes)