Bai-YT
Gradio App for ConsistencyTTA V1
66982e9
raw
history blame
4.42 kB
import torch
from torch import nn
from audioldm.variational_autoencoder.modules import Encoder, Decoder
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
class AutoencoderKL(nn.Module):
    """KL-regularized convolutional autoencoder with an attached vocoder.

    ``Encoder``/``Decoder`` are both built from ``ddconfig``.  The encoder
    output is projected by ``quant_conv`` to ``2 * embed_dim`` channels and
    wrapped in a ``DiagonalGaussianDistribution``; latents are projected back
    to ``ddconfig["z_channels"]`` channels by ``post_quant_conv`` before
    decoding.  Decoder outputs can be turned into waveforms with the vocoder
    obtained from ``get_vocoder``.

    Args:
        ddconfig: Keyword arguments forwarded to both ``Encoder`` and
            ``Decoder``; must contain ``"z_channels"``.  Required.
        lossconfig: Accepted but not used by this class.
        image_key: Name of the input representation.  Subband split/merge is
            only applied when this equals ``"stft"``.
        embed_dim: Number of latent channels.  Required.
        time_shuffle: Stored as ``self.time_shuffle``; not used by this class.
        subband: Number of frequency subbands; ``1`` disables splitting.
        ckpt_path: Accepted but not used by this class.
        reload_from_ckpt: Stored as ``self.reload_from_ckpt``.
        ignore_keys: Accepted but not used by this class.
        colorize_nlabels: Accepted but not used by this class.
        monitor: If not ``None``, stored as ``self.monitor``.
        base_learning_rate: Accepted but not used by this class.
        scale_factor: Latent scale; ``get_first_stage_encoding`` multiplies
            by it and ``decode_first_stage`` divides by it.

    Raises:
        TypeError: If ``ddconfig`` or ``embed_dim`` is ``None``.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        image_key="fbank",
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=None,  # was a mutable ``[]`` default; value is unused
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        super().__init__()
        # Both are effectively required: ddconfig is **-unpacked below and
        # embed_dim sizes the projection convs.  Fail early and clearly.
        if ddconfig is None or embed_dim is None:
            raise TypeError(
                "AutoencoderKL requires both `ddconfig` and `embed_dim`."
            )

        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.ema_decoder = None

        # Must be stored: freq_split_subband / freq_merge_subband read it.
        self.image_key = image_key
        self.subband = int(subband)
        if self.subband > 1:
            print(f"Use subband decomposition {self.subband}")

        # quant_conv emits 2 * embed_dim channels that are handed to
        # DiagonalGaussianDistribution (presumably mean and log-variance —
        # confirm against that class).
        self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.ema_post_quant_conv = None

        # NOTE(review): meaning of the ``None`` config argument is defined by
        # get_vocoder; the vocoder is created on CPU here.
        self.vocoder = get_vocoder(None, "cpu")
        self.embed_dim = embed_dim

        if monitor is not None:
            self.monitor = monitor

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.scale_factor = scale_factor
        # Read by forward() to print the latent size exactly once.
        self.flag_first_run = True

    @property
    def device(self):
        """Device of the first registered parameter.

        Raises ``StopIteration`` if the module has no parameters.
        """
        return next(self.parameters()).device

    def freq_split_subband(self, fbank):
        """Split the frequency axis of ``fbank`` into ``self.subband`` channels.

        Returns ``fbank`` unchanged unless ``self.subband != 1`` and
        ``self.image_key == "stft"``.

        Args:
            fbank: 4-D tensor ``(bs, ch, tstep, fbins)`` with ``ch == 1`` and
                ``fbins`` divisible by ``self.subband``.

        Returns:
            Tensor of shape ``(bs, subband, tstep, fbins // subband)``.
        """
        if self.subband == 1 or self.image_key != "stft":
            return fbank
        bs, ch, tstep, fbins = fbank.size()
        assert fbins % self.subband == 0, "fbins must be divisible by subband"
        assert ch == 1, "subband split expects a single input channel"
        # (bs, 1, T, F) -> (bs, T, S, F/S) -> (bs, S, T, F/S)
        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        """Inverse of ``freq_split_subband``.

        Returns ``subband_fbank`` unchanged unless ``self.subband != 1`` and
        ``self.image_key == "stft"``.

        Args:
            subband_fbank: 4-D tensor ``(bs, subband, tstep, fbins_per_band)``.

        Returns:
            Tensor of shape ``(bs, 1, tstep, subband * fbins_per_band)``.
        """
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, _, tstep, _ = subband_fbank.size()
        # (bs, S, T, F/S) -> (bs, T, S, F/S) -> (bs, T, F) -> (bs, 1, T, F)
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior.

        ``x`` is first passed through ``freq_split_subband``; the encoder
        output is projected by ``quant_conv`` to form the distribution's
        parameters.
        """
        x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    @torch.no_grad()
    def encode_first_stage(self, x):
        """``encode`` without gradient tracking."""
        return self.encode(x)

    def decode(self, z, use_ema=False):
        """Decode latents ``z`` and merge subbands back.

        Args:
            z: Latent tensor with ``embed_dim`` channels (unscaled).
            use_ema: Use ``ema_post_quant_conv``/``ema_decoder`` when both are
                available; otherwise a message is printed and the regular
                modules are used.

        Returns:
            Decoder output after ``freq_merge_subband``.
        """
        # getattr tolerates instances that never had the EMA attributes set.
        ema_decoder = getattr(self, "ema_decoder", None)
        ema_post_quant_conv = getattr(self, "ema_post_quant_conv", None)
        # Both EMA modules are needed; a half-populated pair would otherwise
        # end up calling ``None``.
        has_ema = ema_decoder is not None and ema_post_quant_conv is not None

        if use_ema and not has_ema:
            print("VAE does not have EMA modules, but specified use_ema. "
                  "Using the none-EMA modules instead.")

        if use_ema and has_ema:
            dec = ema_decoder(ema_post_quant_conv(z))
        else:
            dec = self.decoder(self.post_quant_conv(z))
        return self.freq_merge_subband(dec)

    def decode_first_stage(self, z, allow_grad=False, use_ema=False):
        """Undo ``scale_factor`` on ``z`` and decode it.

        Gradient tracking is enabled only when ``allow_grad`` is true.
        """
        with torch.set_grad_enabled(allow_grad):
            z = z / self.scale_factor
            return self.decode(z, use_ema)

    def decode_to_waveform(self, dec, allow_grad=False):
        """Convert decoder output to a waveform with ``self.vocoder``.

        Dimension 1 is squeezed and the last two axes are swapped before
        calling ``vocoder_infer`` (presumably ``(batch, 1, time, mel)`` ->
        ``(batch, mel, time)`` — confirm against ``vocoder_infer``).
        """
        dec = dec.squeeze(1).permute(0, 2, 1)
        return vocoder_infer(dec, self.vocoder, allow_grad=allow_grad)

    def forward(self, input, sample_posterior=True):
        """Encode, take a sample (or the mode) of the posterior, and decode.

        Returns:
            Tuple ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False
        return self.decode(z), posterior

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a scaled latent.

        A ``DiagonalGaussianDistribution`` is sampled; a plain tensor is used
        as-is.  The result is multiplied by ``scale_factor``.

        Raises:
            NotImplementedError: For any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z