# Scraped Hugging Face file-view header (uploader: Dionyssos; commit d9889a1; 2.73 kB).
# Commit message: duplicate xN_DRAW - for long gen
import typing as tp
from einops import rearrange
import numpy as np
import torch
from torch import nn
class EncodecModel(nn.Module):
    """Decode-side codec wrapper: quantizer codes -> latent -> decoder output.

    The model holds a ``quantizer`` (used to turn discrete codes back into a
    continuous latent) and a ``decoder`` (applied to that latent). It can
    optionally rescale the decoded output by a per-item volume computed in
    :meth:`preprocess`.

    Args:
        decoder: Callable applied to the latent returned by
            ``quantizer.decode``.
        quantizer: Object exposing ``total_codebooks``, ``num_codebooks``,
            ``bins``, ``set_num_codebooks(n)`` and ``decode(codes)``.
        frame_rate: Stored as-is; not read by this class.
            NOTE(review): presumably latent frames per second -- confirm.
        sample_rate: Stored as-is; not read by this class.
            NOTE(review): presumably the audio sample rate in Hz -- confirm.
        channels: Stored as-is; not read by this class.
        causal: Whether the model is causal. Incompatible with
            ``renormalize``.
        renormalize: Whether :meth:`preprocess` divides the input by its
            volume and :meth:`postprocess` accepts a scale to undo it.

    Raises:
        AssertionError: If both ``causal`` and ``renormalize`` are true.
    """

    def __init__(self,
                 decoder=None,
                 quantizer=None,
                 frame_rate=None,
                 sample_rate=None,
                 channels=None,
                 causal: bool = False,
                 renormalize: bool = False) -> None:
        super().__init__()
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal and self.renormalize:
            # We force disabling here to avoid handling linear overlap of
            # segments as supported in the original EnCodec codebase.
            # Raised explicitly (rather than via ``assert``) so the check is
            # still enforced when Python runs with ``-O``.
            raise AssertionError('Causal model does not support renormalize')

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook (the quantizer's ``bins``)."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally normalize ``x`` by its volume.

        Args:
            x: Tensor with at least three dimensions; dim 1 and dim 2 are
                reduced. NOTE(review): presumably (batch, channels, time)
                -- confirm against callers.

        Returns:
            ``(x, None)`` unchanged when ``renormalize`` is false. Otherwise
            ``(x / scale, scale.view(-1, 1))`` where ``scale`` is the root
            mean square over dim 2 of the dim-1 mean of ``x``, plus ``1e-8``.
        """
        if not self.renormalize:
            return x, None
        mono = x.mean(dim=1, keepdim=True)
        volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
        # The epsilon keeps the division finite for all-zero input.
        scale = 1e-8 + volume
        return x / scale, scale.view(-1, 1)

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the normalization of :meth:`preprocess`.

        Args:
            x: Tensor to rescale.
            scale: Scale as returned by :meth:`preprocess`, or ``None``.

        Returns:
            ``x`` itself when ``scale`` is ``None``; otherwise ``x``
            multiplied by ``scale`` reshaped to ``(-1, 1, 1)``.

        Raises:
            AssertionError: If a scale is given but ``renormalize`` is false.
        """
        if scale is None:
            return x
        if not self.renormalize:
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError('scale given but renormalize is disabled')
        return x * scale.view(-1, 1, 1)

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode discrete codes through the quantizer and the decoder.

        Args:
            codes: Discrete codes passed to :meth:`decode_latent`.
            scale: Optional scale forwarded to :meth:`postprocess`.

        Returns:
            The decoder output after :meth:`postprocess`.
        """
        # B,K,T -> B,C,T
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        return self.postprocess(out, scale)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)