import logging
import multiprocessing
from pathlib import Path
import typing as tp

import flashy
import omegaconf
import torch
from torch import nn

from . import base, builders
from .. import models, quantization
from ..utils import checkpoint
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_pool_executor


class CompressionSolver(base.StandardSolver):
    """Solver for the compression task.

    The compression task combines a set of perceptual and objective losses
    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
    to perform high fidelity audio reconstruction.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.rng: torch.Generator
        self.adv_losses = builders.get_adversarial_losses(self.cfg)
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        loss_weights = dict()
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ['adv', 'feat']:
                for adv_name, _ in self.adv_losses.items():
                    loss_weights[f'{loss_name}_{adv_name}'] = weight
            elif weight > 0:
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                loss_weights[loss_name] = weight
            else:
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
        self.register_stateful('adv_losses')
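
    # Illustrative sketch of the loss bookkeeping above (the config values and the
    # adversary name are made up for the example, not taken from a real config):
    # with ``cfg.losses == {'adv': 4., 'feat': 4., 'l1': 0.1, 'msspec': 2., 'sisnr': 0.}``
    # and a single adversarial loss registered under 'msstftd', the loop builds
    #   loss_weights == {'adv_msstftd': 4., 'feat_msstftd': 4., 'l1': 0.1, 'msspec': 2.}
    # while 'sisnr' (weight 0) lands in ``info_losses`` and is only reported, not optimized.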

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        return None

    def build_model(self):
        """Instantiate model and optimizer."""
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation."""
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)

        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()

        pendings = []
        ctx = multiprocessing.get_context('spawn')
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for idx, batch in enumerate(lp):
                x = batch.to(self.device)
                with torch.no_grad():
                    qres = self.model(x)

                y_pred = qres.x.cpu()
                y = batch.cpu()
                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
            for pending in metrics_lp:
                metrics = pending.result()
                metrics = average(metrics)

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
        return metrics

    def generate(self):
        """Generate stage."""
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                qres = self.model(reference)
            assert isinstance(qres, quantization.QuantizedResult)

            reference = reference.cpu()
            estimate = qres.x.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        model = models.CompressionModel.get_pretrained(name)
        if isinstance(model, models.DAC):
            raise RuntimeError("Cannot fine-tune a DAC model.")
        elif isinstance(model, models.HFEncodecCompressionModel):
            self.logger.warning('Trying to automatically convert a HuggingFace model '
                                'to AudioCraft, this might fail!')
            state = model.model.state_dict()
            new_state = {}
            for k, v in state.items():
                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
                    layer = int(k.split('.')[2])
                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                        k = k.replace('.conv.', '.convtr.')
                k = k.replace('encoder.layers.', 'encoder.model.')
                k = k.replace('decoder.layers.', 'decoder.model.')
                k = k.replace('conv.', 'conv.conv.')
                k = k.replace('convtr.', 'convtr.convtr.')
                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
                k = k.replace('.codebook.', '._codebook.')
                new_state[k] = v
            state = new_state
        elif isinstance(model, models.EncodecModel):
            state = model.state_dict()
        else:
            raise RuntimeError(f"Cannot fine-tune model type {type(model)}.")
        return {
            'best_state': {'model': state}
        }
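
    # Sketch of the HuggingFace -> AudioCraft key remapping performed above
    # (the input keys are illustrative examples, not an exhaustive list):
    #   'encoder.layers.0.conv.weight'      -> 'encoder.model.0.conv.conv.weight'
    #   'decoder.layers.2.conv.weight'      -> 'decoder.model.2.convtr.convtr.weight'
    #       (only when that decoder layer holds a ConvTranspose1d)
    #   'quantizer.layers.0.codebook.embed' -> 'quantizer.vq.layers.0._codebook.embed'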

    @staticmethod
    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
        This method is a convenient endpoint to load a CompressionModel to use in other solvers.

        Args:
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                This also supports pre-trained models by using a path of the form //pretrained/NAME.
                See `model_from_pretrained` for a list of supported pretrained models.
            device (torch.device or str): Device on which the model is loaded.
        """
        checkpoint_path = str(checkpoint_path)
        if checkpoint_path.startswith('//pretrained/'):
            name = checkpoint_path.split('/', 3)[-1]
            return models.CompressionModel.get_pretrained(name, device)
        logger = logging.getLogger(__name__)
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
        cfg = state['xp.cfg']
        cfg.device = device
        compression_model = models.builders.get_compression_model(cfg).to(device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

        assert 'best_state' in state and state['best_state'] != {}
        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
        compression_model.load_state_dict(state['best_state']['model'])
        compression_model.eval()
        logger.info("Compression model loaded!")
        return compression_model
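
    # Typical usage (a minimal sketch; the local checkpoint path is a placeholder):
    #
    #   model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_24khz')
    #   model = CompressionSolver.model_from_checkpoint('/checkpoints/my_xp/checkpoint.th', device='cuda')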

    @staticmethod
    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                      checkpoint_path: tp.Union[Path, str],
                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.
        """
        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
        return compression_model
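
    # Sketch of wrapped usage (``cfg`` stands for a solver configuration carrying the
    # wrapping options; the checkpoint path is a placeholder):
    #
    #   model = CompressionSolver.wrapped_model_from_checkpoint(cfg, '//pretrained/facebook/encodec_24khz')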


import typing as tp

import torch
import julius

from .unet import DiffusionUnet
from ..modules.diffusion_schedule import NoiseSchedule
from .encodec import CompressionModel
from ..solvers.compression import CompressionSolver
from .loaders import load_compression_model, load_diffusion_models


class DiffusionProcess:
    """Sampling for a diffusion model.

    Args:
        model (DiffusionUnet): Diffusion U-Net model.
        noise_schedule (NoiseSchedule): Noise schedule for the diffusion process.
    """
    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
        self.model = model
        self.schedule = noise_schedule

    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Perform one diffusion process to generate one of the bands.

        Args:
            condition (torch.Tensor): The embeddings from the compression model.
            initial_noise (torch.Tensor): The initial noise to start the process.
            step_list (list of int, optional): Subset of Markov chain steps to use; defaults to 50 linearly spaced steps.
        """
        return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
                                                 condition=condition)


class MultiBandDiffusion:
    """Sample from multiple diffusion models.

    Args:
        DPs (list of DiffusionProcess): Diffusion processes, one per frequency band.
        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
    """
    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
        self.DPs = DPs
        self.codec_model = codec_model
        self.device = next(self.codec_model.parameters()).device

    @property
    def sample_rate(self) -> int:
        return self.codec_model.sample_rate

    @staticmethod
    def get_mbd_musicgen(device=None):
        """Load our diffusion models trained for MusicGen."""
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        path = 'facebook/multiband-diffusion'
        filename = 'mbd_musicgen_32khz.th'
        name = 'facebook/musicgen-small'
        codec_model = load_compression_model(name, device=device)
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = []
        for i in range(len(models)):
            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @staticmethod
    def get_mbd_24khz(bw: float = 3.0,
                      device: tp.Optional[tp.Union[torch.device, str]] = None,
                      n_q: tp.Optional[int] = None):
        """Get the pretrained models for MultiBandDiffusion.

        Args:
            bw (float): Bandwidth of the compression model, in kbps.
            device (torch.device or str, optional): Device on which the models are loaded.
            n_q (int, optional): Number of quantizers to use within the compression model.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
        if n_q is not None:
            assert n_q in [2, 4, 8]
            assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
                f"bandwidth and number of codebooks mismatch: with n_q = {n_q}, bw should be {n_q * (1.5 / 2)}"
        n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
        codec_model = CompressionSolver.model_from_checkpoint(
            '//pretrained/facebook/encodec_24khz', device=device)
        codec_model.set_num_codebooks(n_q)
        codec_model = codec_model.to(device)
        path = 'facebook/multiband-diffusion'
        filename = f'mbd_comp_{n_q}.pt'
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = []
        for i in range(len(models)):
            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
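
    # Example usage (a minimal sketch; ``wav`` and ``sr`` are placeholders for a
    # waveform tensor of shape [B, C, T] and its sample rate):
    #
    #   mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)
    #   out = mbd.regenerate(wav, sample_rate=sr)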

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.

        Args:
            wav (torch.Tensor): The audio from which to extract the conditioning.
            sample_rate (int): Sample rate of the audio.
        """
        if sample_rate != self.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.get_emb(codes)
        return emb

    @torch.no_grad()
    def get_emb(self, codes: torch.Tensor):
        """Get the latent representation from the discrete codes.

        Args:
            codes (torch.Tensor): Discrete tokens.
        """
        emb = self.codec_model.decode_latent(codes)
        return emb

    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Generate waveform audio from the latent embeddings of the compression model.

        Args:
            emb (torch.Tensor): Conditioning embeddings.
            size (torch.Size, optional): Size of the output; if None, it is computed
                from the typical upsampling factor of the model.
            step_list (list of int, optional): List of Markov chain steps; defaults to 50 linearly spaced steps.
        """
        if size is None:
            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
        assert size[0] == emb.size(0)
        out = torch.zeros(size).to(self.device)
        for DP in self.DPs:
            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
        return out
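
    # Note on sizes (illustrative, assuming the 24 kHz EnCodec codec with its usual
    # 75 Hz frame rate): the default upsampling factor is 24000 / 75 = 320, so a
    # conditioning of T frames yields a waveform of 320 * T samples per channel.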

    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
        """Match the EQ to the EnCodec output by matching the standard deviation of frequency bands.

        Args:
            wav (torch.Tensor): Audio to equalize.
            ref (torch.Tensor): Reference audio from which we match the spectrogram.
            n_bands (int): Number of bands of the EQ.
            strictness (float): How strict the matching is; 0 is no matching, 1 is exact matching.
        """
        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
        bands = split(wav)
        bands_ref = split(ref)
        out = torch.zeros_like(ref)
        for i in range(n_bands):
            out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
        return out

    def regenerate(self, wav: torch.Tensor, sample_rate: int):
        """Regenerate a waveform through compression and diffusion regeneration.

        Args:
            wav (torch.Tensor): Original 'ground truth' audio.
            sample_rate (int): Sample rate of the input (and output) wav.
        """
        if sample_rate != self.codec_model.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
        size = wav.size()
        out = self.generate(emb, size=size)
        if sample_rate != self.codec_model.sample_rate:
            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
        return out

    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
        """Generate waveform audio with diffusion from the discrete codes.

        Args:
            tokens (torch.Tensor): Discrete codes.
            n_bands (int): Number of bands for the EQ matching.
        """
        wav_encodec = self.codec_model.decode(tokens)
        condition = self.get_emb(tokens)
        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
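
    # Minimal usage sketch (``wav`` is a placeholder waveform already at the codec
    # sample rate): decode discrete tokens with diffusion instead of the EnCodec decoder.
    #
    #   mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)
    #   codes, _ = mbd.codec_model.encode(wav)   # [B, n_q, T] discrete tokens
    #   out = mbd.tokens_to_wav(codes)           # diffusion decoding + EQ matching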