import torch
from torch import nn
from audioldm.variational_autoencoder.modules import Encoder, Decoder
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
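

# KL-regularized variational autoencoder used as AudioLDM's first-stage model:
# it compresses mel-spectrogram ("fbank") or STFT features into a diagonal
# Gaussian latent and decodes latents back to spectrograms (and, through the
# HiFi-GAN vocoder, to waveforms).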
class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        image_key="fbank",
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=None,
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1,
    ):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.ema_decoder = None

        # image_key selects between "fbank" and "stft" inputs; it is read by
        # the subband helpers below.
        self.image_key = image_key
        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        # quant_conv maps encoder features to (mean, logvar) moments;
        # post_quant_conv maps a sampled latent back to the decoder's
        # expected channel count.
        self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.ema_post_quant_conv = None
        self.vocoder = get_vocoder(None, "cpu")
        self.embed_dim = embed_dim

        if monitor is not None:
            self.monitor = monitor

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.scale_factor = scale_factor
        # Print the latent size once on the first forward pass.
        self.flag_first_run = True

    @property
    def device(self):
        return next(self.parameters()).device

    def freq_split_subband(self, fbank):
        # Split the frequency axis into `self.subband` channels; only applies
        # to STFT inputs.
        if self.subband == 1 or self.image_key != "stft":
            return fbank

        bs, ch, tstep, fbins = fbank.size()

        assert fbank.size(-1) % self.subband == 0
        assert ch == 1

        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        # Inverse of freq_split_subband: fold the subband channels back into
        # a single-channel spectrogram.
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, sub_ch, tstep, fbins = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
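
    # Shape sketch (illustrative): with subband=4 and an STFT input of shape
    # (bs, 1, tstep, 256), freq_split_subband returns (bs, 4, tstep, 64) and
    # freq_merge_subband restores (bs, 1, tstep, 256).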

    def encode(self, x):
        x = self.freq_split_subband(x)
        h = self.encoder(x)
        # `moments` holds the concatenated mean and log-variance channels.
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.encode(x)

    def decode(self, z, use_ema=False):
        have_ema = getattr(self, "ema_decoder", None) is not None
        if use_ema and not have_ema:
            print(
                "VAE does not have EMA modules, but use_ema was specified. "
                "Using the non-EMA modules instead."
            )
        if use_ema and have_ema:
            z = self.ema_post_quant_conv(z)
            dec = self.ema_decoder(z)
        else:
            z = self.post_quant_conv(z)
            dec = self.decoder(z)
        return self.freq_merge_subband(dec)

    def decode_first_stage(self, z, allow_grad=False, use_ema=False):
        with torch.set_grad_enabled(allow_grad):
            # Undo the scaling applied in get_first_stage_encoding.
            z = z / self.scale_factor
            return self.decode(z, use_ema)

    def decode_to_waveform(self, dec, allow_grad=False):
        # Drop the channel axis and move mel bins before time, giving the
        # (batch, mel_bins, time) layout the HiFi-GAN vocoder expects.
        dec = dec.squeeze(1).permute(0, 2, 1)
        wav_reconstruction = vocoder_infer(dec, self.vocoder, allow_grad=allow_grad)
        return wav_reconstruction

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()

        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False

        return self.decode(z), posterior

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z
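

# A minimal smoke-test sketch. The ddconfig below follows the usual
# latent-diffusion Encoder/Decoder keyword signature; the exact values are
# illustrative assumptions, not AudioLDM's shipped configuration, and
# get_vocoder must be able to load its HiFi-GAN checkpoint on this machine.
if __name__ == "__main__":
    ddconfig = {
        "double_z": True,
        "z_channels": 8,
        "resolution": 256,
        "in_channels": 1,
        "out_ch": 1,
        "ch": 128,
        "ch_mult": (1, 2, 4),
        "num_res_blocks": 2,
        "attn_resolutions": (),
        "dropout": 0.0,
    }
    vae = AutoencoderKL(ddconfig=ddconfig, embed_dim=8)
    fbank = torch.randn(2, 1, 256, 64)  # (batch, channel, time, mel_bins)
    reconstruction, posterior = vae(fbank)
    print(reconstruction.shape)  # should match the input shape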