# Last commit 37ced70 ("update") by hhguo.
from fireredtts.modules.flow.codec_embedding import HHGCodecEmbedding
from fireredtts.modules.flow.conformer import ConformerDecoderV2
from fireredtts.modules.flow.mel_encoder import MelReduceEncoder
from fireredtts.modules.flow.decoder import ConditionalCFM, ConditionalDecoder
from fireredtts.modules.flow.flow_model import InterpolateRegulator, CrossAttnFlowMatching
from fireredtts.modules.flow.mel_spectrogram import MelSpectrogramExtractor
def get_flow_frontend(flow_config):
    """Build the cross-attention flow-matching model described by ``flow_config``.

    Each sub-module is created from its own section of the config, with the
    section's entries passed as keyword arguments. The pieces are then
    assembled into a ``CrossAttnFlowMatching``.

    Args:
        flow_config: Mapping with the keys ``"output_size"``,
            ``"input_embedding"``, ``"encoder"``, ``"length_regulator"``,
            ``"mel_encoder"`` and ``"decoder"``. The ``"decoder"`` section must
            itself contain ``"estimator"`` (keyword arguments for
            ``ConditionalDecoder``), ``"t_scheduler"`` and
            ``"inference_cfg_rate"``.

    Returns:
        The assembled ``CrossAttnFlowMatching`` instance.

    Raises:
        KeyError: if a required key is missing. This assumes ``flow_config``
            is a plain dict; TODO confirm the config type used by callers.
    """
    # NOTE: config lookups and sub-module construction happen in exactly the
    # same order as in a single nested constructor call. Any order-dependent
    # initialisation (e.g. RNG-driven weight init) is therefore unaffected.
    output_size = flow_config["output_size"]
    input_embedding = HHGCodecEmbedding(**flow_config["input_embedding"])
    encoder = ConformerDecoderV2(**flow_config["encoder"])
    length_regulator = InterpolateRegulator(**flow_config["length_regulator"])
    mel_encoder = MelReduceEncoder(**flow_config["mel_encoder"])

    # The decoder is a conditional flow-matching wrapper around an estimator
    # network; both are configured from the "decoder" section.
    decoder_config = flow_config["decoder"]
    estimator = ConditionalDecoder(**decoder_config["estimator"])
    decoder = ConditionalCFM(
        estimator=estimator,
        t_scheduler=decoder_config["t_scheduler"],
        inference_cfg_rate=decoder_config["inference_cfg_rate"],
    )

    return CrossAttnFlowMatching(
        output_size=output_size,
        input_embedding=input_embedding,
        encoder=encoder,
        length_regulator=length_regulator,
        mel_encoder=mel_encoder,
        decoder=decoder,
    )