from ...hparams import HParams from .base import Chain, Choice, Permutation from .custom import RandomGaussianNoise, RandomRIR class Distorter(Chain): def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"): # Lazy import from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb if training: permutation = Permutation( RandomRIR(hp.rir_dir), RandomReverb(), RandomGaussianNoise(), RandomOverdrive(), RandomEqualizer(), Choice( RandomLowpassDistorter(), RandomBandpassDistorter(), ), ) if mode == "denoiser": super().__init__(permutation) else: # 80%: distortion, 20%: clean super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2])) else: super().__init__( RandomRIR(hp.rir_dir, deterministic=True), RandomReverb(deterministic=True), )