"""Inference-only wrapper around the ``facebook/audiogen-medium`` checkpoints."""
import os
import typing as tp

import omegaconf
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf, DictConfig
from torch import nn

from .encodec import EncodecModel
from .lm import LMModel
from .seanet import SEANetDecoder
from .codebooks_patterns import DelayedPatternProvider
from .conditioners import (
    ConditioningProvider, T5Conditioner, ConditioningAttributes
)
from .vq import ResidualVectorQuantizer


# Hub coordinates shared by both checkpoint downloads.
_HF_REPO_ID = 'facebook/audiogen-medium'
# Version string reported to the hub; the original author noted it was taken
# from audiocraft's ``__init__.py`` (``audiocraft.__version__``).
_AUDIOCRAFT_VERSION = '1.3.0a1'


def _delete_param(cfg: DictConfig, full_name: str):
    """Remove the dotted key ``full_name`` from ``cfg`` if it exists.

    Walks every component but the last; when an intermediate component is
    absent the function returns without touching ``cfg``. Struct mode is
    switched off on the parent node for the deletion and switched back on
    afterwards.
    """
    *parents, leaf = full_name.split('.')
    for part in parents:
        if part not in cfg:
            return
        cfg = cfg[part]
    OmegaConf.set_struct(cfg, False)
    if leaf in cfg:
        del cfg[leaf]
    OmegaConf.set_struct(cfg, True)


def dict_from_config(cfg):
    """Convert an OmegaConf node to a plain container, resolving interpolations."""
    return omegaconf.OmegaConf.to_container(cfg, resolve=True)


def _fetch_checkpoint(filename: str) -> str:
    """Return the local path of ``filename`` from the AudioGen hub repository.

    The cache directory is read from ``AUDIOCRAFT_CACHE_DIR`` when that
    environment variable is set.
    """
    return hf_hub_download(
        repo_id=_HF_REPO_ID,
        filename=filename,
        cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
        library_name="audiocraft",
        library_version=_AUDIOCRAFT_VERSION)


# ============================================== DEFINE AUDIOGEN


class AudioGen(nn.Module):
    """Text-to-audio model assembled from a compression model and a token LM.

    Checkpoints: https://huggingface.co/facebook/audiogen-medium
    """

    def __init__(self, duration=0.024, device='cpu'):
        """Fetch both checkpoints and build the two sub-models.

        Args:
            duration: multiplied by ``frame_rate`` to obtain the number of
                tokens requested from the LM in :meth:`generate`.
            device: device used to load and build the LM; it also selects the
                dtype the LM is constructed with (see :meth:`load_lm_model`).
        """
        super().__init__()
        # Must be set before the loaders run: they read self.device.
        self.device = device
        self.load_compression_model()
        self.load_lm_model()
        self.duration = duration

    @property
    def frame_rate(self):
        """Frame rate reported by the compression model."""
        return self.compression_model.frame_rate

    def generate(self, descriptions: tp.Sequence[str]) -> torch.Tensor:
        """Generate peak-normalised audio for each text description.

        Args:
            descriptions: one text prompt per item to generate.

        Returns:
            The decoded audio divided by its maximum absolute value along
            dim 2 (plus a small epsilon so silent output does not divide by
            zero).
        """
        with torch.no_grad():
            attributes = [
                ConditioningAttributes(text={'description': text})
                for text in descriptions
            ]
            # Original author's note: tokens are [bs, 4, 37 * self.lm.n_draw].
            gen_tokens = self.lm.generate(
                conditions=attributes,
                max_gen_len=int(self.duration * self.frame_rate))
            # Original author's note: audio is [bs, 1, 11840].
            x = self.compression_model.decode(gen_tokens, None)
            # Debug trace kept from the original implementation.
            print('GENAUD 5', x.sum(), x.shape)
            # The epsilon belongs inside the denominator: the previous
            # `x / peak + 1e-7` produced NaN for all-zero audio and shifted
            # every sample by 1e-7.
            peak = x.abs().max(dim=2, keepdim=True)[0]
            return x / (peak + 1e-7)

    # == BUILD Fn
    def get_quantizer(self, quantizer, cfg, dimension):
        """Instantiate the quantizer named ``quantizer`` from ``cfg``.

        Args:
            quantizer: key selecting the class; also the name of the ``cfg``
                section holding its keyword arguments.
            cfg: configuration node.
            dimension: injected as ``kwargs['dimension']`` for every
                quantizer except ``'no_quant'``.

        Raises:
            KeyError: if ``quantizer`` is not a known name.
        """
        # NOTE(review): 'no_quant' maps to None, so selecting it fails with a
        # TypeError at the call below; no dummy quantizer is imported here.
        klass = {
            'no_quant': None,
            'rvq': ResidualVectorQuantizer,
        }[quantizer]
        kwargs = dict_from_config(getattr(cfg, quantizer))
        if quantizer != 'no_quant':
            kwargs['dimension'] = dimension
        return klass(**kwargs)

    def get_encodec_autoencoder(self, cfg):
        """Build the SEANet decoder described by ``cfg.seanet``.

        The ``encoder`` sub-section is discarded; the ``decoder`` sub-section
        overrides the keys shared at the ``seanet`` level.
        """
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        kwargs.pop('encoder')  # only the decoder is instantiated
        decoder_overrides = kwargs.pop('decoder')
        return SEANetDecoder(**{**kwargs, **decoder_overrides})

    def get_compression_model(self, cfg):
        """Instantiate a compression model.

        Raises:
            KeyError: if ``cfg.compression_model`` is not ``'encodec'``.
        """
        if cfg.compression_model != 'encodec':
            raise KeyError(f"Unexpected compression model {cfg.compression_model}")
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        quantizer_name = kwargs['quantizer']
        decoder = self.get_encodec_autoencoder(cfg)
        # 128 / 50 / 16000 / 1 / False are hard-coded. The original author
        # recorded them from this checkpoint: frame_rate=50,
        # encoder.dimension=128, and encodec kwargs
        # {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1,
        #  'causal': False}.
        quantizer = self.get_quantizer(quantizer_name, cfg, 128)
        renormalize = kwargs.get('renormalize', False)
        return EncodecModel(decoder=decoder,
                            quantizer=quantizer,
                            frame_rate=50,
                            renormalize=renormalize,
                            sample_rate=16000,
                            channels=1,
                            causal=False
                            ).to(cfg.device)

    def get_lm_model(self, cfg):
        """Instantiate a transformer LM.

        Raises:
            KeyError: if ``cfg.lm_model`` is not a supported name.
        """
        if cfg.lm_model not in ['transformer_lm', 'transformer_lm_magnet']:
            raise KeyError(f"Unexpected LM model {cfg.lm_model}")
        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
        n_q = kwargs['n_q']
        q_modeling = kwargs.pop('q_modeling', None)
        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
        cfg_prob = cls_free_guidance['training_dropout']
        cfg_coef = cls_free_guidance['inference_coef']
        condition_provider = self.get_conditioner_provider(
            kwargs["dim"], cfg).to(self.device)
        # Cross-attention is always enabled here (the original made this
        # conditional on a fuser check that is no longer present).
        kwargs['cross_attention'] = True
        if codebooks_pattern_cfg.modeling is None:
            print('Q MODELING\n=\n=><')
            assert q_modeling is not None, \
                "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
            )
        pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
        # NOTE(review): the model is built on self.device but then moved to
        # cfg.device, which load_lm_model leaves at whatever the checkpoint
        # config says — confirm these agree on the target machine.
        return LMModel(
            pattern_provider=pattern_provider,
            condition_provider=condition_provider,
            cfg_dropout=cfg_prob,
            cfg_coef=cfg_coef,
            attribute_dropout=attribute_dropout,
            dtype=getattr(torch, cfg.dtype),
            device=self.device,
            **kwargs
        ).to(cfg.device)

    def get_conditioner_provider(self, output_dim, cfg):
        """Instantiate the conditioning provider from ``cfg.conditioners``.

        Only ``'t5'`` conditioners are supported.

        Raises:
            ValueError: if an entry declares any other model type.
        """
        cfg = getattr(cfg, 'conditioners')
        dict_cfg = {} if cfg is None else dict_from_config(cfg)
        # 'args' holds provider-level options that are not used here; drop it
        # so the loop below only sees conditioner entries.
        dict_cfg.pop('args', None)
        conditioners = {}
        for cond, cond_cfg in dict_cfg.items():
            model_type = cond_cfg['model']
            if model_type != 't5':
                raise ValueError(f"Unrecognized conditioning model: {model_type}")
            model_args = cond_cfg[model_type]
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
                                                    device=self.device,
                                                    **model_args)
        return ConditioningProvider(conditioners)

    def get_codebooks_pattern_provider(self, n_q, cfg):
        """Instantiate the codebook pattern provider named by ``cfg.modeling``.

        Raises:
            KeyError: if ``cfg.modeling`` is not ``'delay'``.
        """
        pattern_providers = {
            'delay': DelayedPatternProvider,
        }
        name = cfg.modeling
        kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
        klass = pattern_providers[name]
        return klass(n_q, **kwargs)

    # ======================
    def load_compression_model(self):
        """Fetch the compression checkpoint and store it as ``self.compression_model``.

        The model is always built on CPU.
        """
        file = _fetch_checkpoint("compression_state_dict.bin")
        # NOTE(security): torch.load unpickles the file; this is only
        # acceptable because the checkpoint comes from the pinned hub repo.
        pkg = torch.load(file, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])
        cfg.device = 'cpu'
        model = self.get_compression_model(cfg)
        # strict=False: the checkpoint also carries encoder weights that this
        # decoder-only model does not use.
        model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model = model

    def load_lm_model(self):
        """Fetch the LM checkpoint and store the model as ``self.lm``."""
        file = _fetch_checkpoint("state_dict.bin")
        # NOTE(security): torch.load unpickles the file; this is only
        # acceptable because the checkpoint comes from the pinned hub repo.
        pkg = torch.load(file, map_location=self.device)
        cfg = OmegaConf.create(pkg['xp.cfg'])
        # Compare the device *type* so that torch.device('cpu') and 'cpu:0'
        # are recognised as CPU, not only the exact string 'cpu'.
        on_cpu = torch.device(self.device).type == 'cpu'
        cfg.dtype = 'float32' if on_cpu else 'float16'
        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
        _delete_param(cfg, 'conditioners.args.drop_desc_p')
        model = self.get_lm_model(cfg)
        model.load_state_dict(pkg['best_state'])
        model.cfg = cfg
        # NOTE(review): this casts the LM to float32 regardless of the dtype
        # selected above — confirm that is intended.
        self.lm = model.to(torch.float)