import os
import re
from datetime import datetime

import torch
import torchaudio
import gradio as gr
from einops import rearrange
from pydub import AudioSegment

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond


def generate_audio(prompt, steps, cfg_scale, sigma_min, sigma_max, generation_time, seed, sampler_type, model_half):
    device = "cuda" if torch.cuda.is_available() else "cpu"

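    # Fetch the pretrained model and its config (downloaded from Hugging Face on first run, then cached).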
    model, model_config = get_pretrained_model("audo/stable-audio-open-1.0")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    model = model.to(device)

print("Model data type before conversion:", next(model.parameters()).dtype) |
|
|
|
|
|
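    # Optionally cast the weights to float16, roughly halving VRAM usage at some cost in precision.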
    if model_half:
        model = model.to(torch.float16)

    print("Model data type after conversion:", next(model.parameters()).dtype)

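    # One conditioning entry: the text prompt plus the time window the model should fill with audio.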
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": generation_time
    }]

    output = generate_diffusion_cond(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=sample_size,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        sampler_type=sampler_type,
        device=device,
        seed=seed
    )

    print("Output data type:", output.dtype)

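    # Fold the batch dimension into time: (batch, channels, samples) -> (channels, batch * samples).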
    output = rearrange(output, "b d n -> d (b n)")

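    # Peak-normalize to [-1, 1], then scale to the int16 range for 16-bit WAV export.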
    # Upcasting to float32 first handles the float16 case too, so no branch on model_half is needed.
    output = output.to(torch.float32)
    output = output.div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767)
    output = output.to(torch.int16).cpu()

    torchaudio.save("temp_output.wav", output, sample_rate)
    audio = AudioSegment.from_wav("temp_output.wav")

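    # Save results under Output/<YYYY-MM-DD>/.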
output_folder = "Output" |
|
date_folder = datetime.now().strftime("%Y-%m-%d") |
|
save_path = os.path.join(output_folder, date_folder) |
|
os.makedirs(save_path, exist_ok=True) |
|
|
|
|
|
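    # Build a filesystem-safe filename from the prompt.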
    filename = re.sub(r'\W+', '_', prompt) + ".mp3"
    full_path = os.path.join(save_path, filename)

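    # If the name is already taken, append _1, _2, ... until it is unique.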
    base_filename = filename
    counter = 1
    while os.path.exists(full_path):
        filename = f"{base_filename[:-4]}_{counter}.mp3"
        full_path = os.path.join(save_path, filename)
        counter += 1

    audio.export(full_path, format="mp3")

    return full_path


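# Gradio callback: wraps generate_audio and reports errors through the UI instead of crashing.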
def audio_generator(prompt, sampler_type, steps, cfg_scale, sigma_min, sigma_max, generation_time, seed, model_half):
    try:
        print("Generating audio with parameters:")
        print("Prompt:", prompt)
        print("Sampler Type:", sampler_type)
        print("Steps:", steps)
        print("CFG Scale:", cfg_scale)
        print("Sigma Min:", sigma_min)
        print("Sigma Max:", sigma_max)
        print("Generation Time:", generation_time)
        print("Seed:", seed)
        print("Model Half Precision:", model_half)

        filename = generate_audio(prompt, steps, cfg_scale, sigma_min, sigma_max, generation_time, seed, sampler_type, model_half)
        return filename, f"Generated: {filename}"
    except Exception as e:
        # The interface declares two outputs, so the error path must also return two values.
        return None, f"Error: {e}"


prompt_textbox = gr.Textbox(lines=5, label="Prompt")
sampler_dropdown = gr.Dropdown(
    label="Sampler Type",
    choices=[
        "dpmpp-3m-sde",
        "dpmpp-2m-sde",
        "k-heun",
        "k-lms",
        "k-dpmpp-2s-ancestral",
        "k-dpm-2",
        "k-dpm-fast"
    ],
    value="dpmpp-3m-sde"
)
steps_slider = gr.Slider(minimum=1, maximum=200, label="Steps", step=1, value=100)
cfg_scale_slider = gr.Slider(minimum=0, maximum=15, label="CFG Scale", step=0.1, value=7)
sigma_min_slider = gr.Slider(minimum=0, maximum=50, label="Sigma Min", step=0.1, value=0.3)
sigma_max_slider = gr.Slider(minimum=0, maximum=1000, label="Sigma Max", step=0.1, value=500)
# stable-audio-open-1.0 generates clips of up to 47 seconds.
generation_time_slider = gr.Slider(minimum=1, maximum=47, label="Generation Time (seconds)", step=1, value=47)
seed_slider = gr.Slider(minimum=-1, maximum=999999, label="Seed", step=1, value=123456)
model_half_checkbox = gr.Checkbox(label="Low VRAM (float16)", value=False)

output_textbox = gr.Textbox(label="Output")

title = "ππ StableAudioWebUI ππ"
description = "[Github Repository](https://github.com/Saganaki22/StableAudioWebUI)"

gr.Interface(
    audio_generator,
    [prompt_textbox, sampler_dropdown, steps_slider, cfg_scale_slider, sigma_min_slider, sigma_max_slider, generation_time_slider, seed_slider, model_half_checkbox],
    [gr.Audio(), output_textbox],
    title=title,
    description=description
).launch()
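# launch() serves the UI on http://127.0.0.1:7860 by default; pass share=True for a temporary public link.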