|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Main model for using AudioGen. This will combine all the required components |
|
and provide easy access to the generation API. |
|
""" |
|
|
|
import typing as tp |
|
import torch |
|
|
|
from audiocraft.encodec import CompressionModel |
|
from audiocraft.genmodel import BaseGenModel |
|
from audiocraft.lm import LMModel |
|
from audiocraft.loaders import load_compression_model, load_lm_model |
|
from .utils.audio_utils import f32_pcm, normalize_audio |
|
|
|
|
|
def audio_write(stem_name, |
|
wav, |
|
sample_rate, |
|
format= 'wav', |
|
mp3_rate=320, |
|
ogg_rate= None, |
|
normalize= True, |
|
strategy= 'peak', |
|
peak_clip_headroom_db=1, |
|
rms_headroom_db= 18, |
|
loudness_headroom_db = 14, |
|
loudness_compressor = False, |
|
log_clipping = True, |
|
make_parent_dir = True, |
|
add_suffix = True): |
|
|
|
assert wav.dtype.is_floating_point, "wav is not floating point" |
|
if wav.dim() == 1: |
|
wav = wav[None] |
|
elif wav.dim() > 2: |
|
raise ValueError("Input wav should be at most 2 dimension.") |
|
assert wav.isfinite().all() |
|
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, |
|
rms_headroom_db, loudness_headroom_db, loudness_compressor, |
|
log_clipping=log_clipping, sample_rate=sample_rate, |
|
stem_name=str(stem_name)) |
|
return wav |
|
|
|
|
|
class AudioGen(BaseGenModel): |
|
"""AudioGen main model with convenient generation API. |
|
|
|
Args: |
|
name (str): name of the model. |
|
compression_model (CompressionModel): Compression model |
|
used to map audio to invertible discrete representations. |
|
lm (LMModel): Language model over discrete representations. |
|
max_duration (float, optional): maximum duration the model can produce, |
|
otherwise, inferred from the training params. |
|
""" |
|
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, |
|
max_duration: tp.Optional[float] = None): |
|
|
|
super().__init__(name, compression_model, lm, max_duration) |
|
self.set_generation_params(duration=5) |
|
|
|
@staticmethod |
|
def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): |
|
"""Return pretrained model, we provide a single model for now: |
|
- facebook/audiogen-medium (1.5B), text to sound, |
|
# see: https://huggingface.co/facebook/audiogen-medium |
|
""" |
|
if device is None: |
|
if torch.cuda.device_count(): |
|
device = 'cuda' |
|
else: |
|
device = 'cpu' |
|
|
|
|
|
|
|
compression_model = load_compression_model(name, device=device) |
|
lm = load_lm_model(name, device=device) |
|
assert 'self_wav' not in lm.condition_provider.conditioners, \ |
|
"AudioGen do not support waveform conditioning for now" |
|
return AudioGen(name, compression_model, lm) |
|
|
|
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, |
|
top_p: float = 0.0, temperature: float = 1.0, |
|
duration: float = 10.0, cfg_coef: float = 2.4, |
|
two_step_cfg: bool = False, extend_stride: float = 2): |
|
"""Set the generation parameters for AudioGen. |
|
|
|
Args: |
|
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. |
|
top_k (int, optional): top_k used for sampling. Defaults to 250. |
|
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. |
|
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. |
|
duration (float, optional): Duration of the generated waveform. Defaults to 10.0. |
|
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. |
|
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, |
|
instead of batching together the two. This has some impact on how things |
|
are padded but seems to have little impact in practice. |
|
extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much |
|
should we extend the audio each time. Larger values will mean less context is |
|
preserved, and shorter value will require extra computations. |
|
""" |
|
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." |
|
self.extend_stride = extend_stride |
|
self.duration = duration |
|
self.generation_params = { |
|
'use_sampling': use_sampling, |
|
'temp': temperature, |
|
'top_k': top_k, |
|
'top_p': top_p, |
|
'cfg_coef': cfg_coef, |
|
'two_step_cfg': two_step_cfg, |
|
} |
|
|