"""Gradio demo that generates audio from a text prompt with ConsistencyTTA."""

import os
import random

import torch
import gradio as gr
import soundfile as sf
import numpy as np

from consistencytta import ConsistencyTTA


def seed_all(seed):
    """Seed all random number generators.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), sets
    ``PYTHONHASHSEED``, and switches cuDNN to deterministic mode.

    Args:
        seed: Any value accepted by ``int()``.

    Raises:
        ValueError / TypeError: If ``seed`` cannot be converted to an int or
            is rejected by one of the underlying seeding functions.
    """
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.random.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Prefer CUDA, then Apple MPS, then fall back to CPU.
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
# Sampling rate (Hz) used both for generation and for the written WAV file.
sr = 16000

# Build ConsistencyTTA model (inference only: eval mode, gradients disabled).
consistencytta = ConsistencyTTA().to(device)
consistencytta.eval()
consistencytta.requires_grad_(False)


def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """Generate audio from a given prompt and save it to ``output.wav``.

    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str, optional): Random seed. Defaults to '', which means no
            seed. A value that cannot be used as a seed is reported on stdout
            and otherwise ignored, so generation still proceeds.
        cfg_weight (float, optional): Classifier-free guidance strength,
            forwarded to the model as ``cfg_scale_input``. Defaults to 4.0.
            NOTE(review): the UI slider below defaults to 3.5 - confirm which
            default is intended.

    Returns:
        str: Path of the written audio file (always ``"output.wav"``).
    """
    if seed != '':
        try:
            seed_all(int(seed))
        except (TypeError, ValueError, OverflowError) as err:
            # Best effort: a bad seed must not block generation, but it should
            # not be swallowed silently (nor should KeyboardInterrupt be).
            print(f"Ignoring invalid random seed {seed!r}: {err}")

    # Autocast is only enabled when CUDA is available; on MPS/CPU the model
    # runs in its default precision.
    with torch.no_grad(), torch.autocast(
        device_type="cuda",
        dtype=torch.bfloat16,
        enabled=torch.cuda.is_available(),
    ):
        wav = consistencytta(
            [prompt],
            num_steps=1,
            cfg_scale_input=cfg_weight,
            cfg_scale_post=1.,
            sr=sr,
        )

    # NOTE(review): presumably `wav` is an array laid out (channels, samples),
    # hence the transpose before writing - confirm against ConsistencyTTA.
    # NOTE(review): a fixed output path means concurrent requests overwrite
    # each other's file.
    sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')
    return "output.wav"


# Generate test audio
print("Generating test audio...")
generate("A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")

# Launch Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text",
            value="Several people cheer and scream and speak as water flows hard."
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0.,
            maximum=8.,
            value=3.5,
            label="Classifier-Free Guidance Strength"
        ),
    ],
    outputs="audio",
    title=(
        "ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio "
        "Generation with Consistency Distillation"
    ),
    description=(
        "This is the official demo page for ConsistencyTTA, a model that accelerates "
        "diffusion-based text-to-audio generation hundreds of times with consistency "
        "models.\nHere, the audio is generated within a single non-autoregressive "
        "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint.\nSince "
        "the training dataset does not include speech, the model is not expected to "
        "generate coherent speech.\nHave fun!"
    ),
)
iface.launch(share=True)