import time

import torch
import torchaudio
import gradio as gr

from sgmse.model import ScoreModel
from sgmse.util.other import pad_spec
# Define parameters based on the configuration in enhancement.py
args = {
    "test_dir": "./test_data",          # example input directory, adjust as needed
    "enhanced_dir": "./enhanced_data",  # example output directory, adjust as needed
    "ckpt": "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
    "corrector": "ald",
    "corrector_steps": 1,
    "snr": 0.5,
    "N": 30,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
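
# Sampler notes (assumptions based on the sgmse repository's defaults): "N" is the
# number of reverse diffusion steps, so larger values trade speed for quality;
# "ald" selects the annealed Langevin dynamics corrector, with "snr" controlling
# its step size and "corrector_steps" the corrector iterations per predictor step.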
# Load the model to the correct device; Lightning's load_from_checkpoint also
# accepts a URL. Switch to eval mode so the network behaves deterministically
# during inference.
model = ScoreModel.load_from_checkpoint(args["ckpt"]).to(args["device"])
model.eval()


def enhance_speech(audio_file):
    start_time = time.time()  # Start the timer

    # Load the audio file; with type="filepath", Gradio passes a path on disk
    y, sr = torchaudio.load(audio_file)
    print(f"Loaded audio in {time.time() - start_time:.2f}s")
    T_orig = y.size(1)

    # Normalize to unit peak amplitude (undone after enhancement)
    norm_factor = y.abs().max()
    y = y / norm_factor

    # Prepare DNN input: STFT, then the model's forward spectral transform,
    # with a leading batch dimension
    Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args["device"]))), 0)
    print(f"Transformed input in {time.time() - start_time:.2f}s")
    Y = pad_spec(Y, mode="zero_pad")  # zero-pad the time axis to the length the network expects

    # Reverse sampling with the predictor-corrector sampler; gradients are not
    # needed here, so run under no_grad to save memory
    sampler = model.get_pc_sampler(
        'reverse_diffusion', args["corrector"], Y.to(args["device"]),
        N=args["N"], corrector_steps=args["corrector_steps"], snr=args["snr"]
    )
    with torch.no_grad():
        sample, _ = sampler()

    # Backward transform to the time domain, cropped to the original length
    x_hat = model.to_audio(sample.squeeze(), T_orig)

    # Undo the input normalization
    x_hat = x_hat * norm_factor

    # Save to a temporary path (writable in a Hugging Face Space);
    # torchaudio.save expects a 2-D (channels, frames) tensor
    output_file = "/tmp/enhanced_output.wav"
    if x_hat.dim() == 1:
        x_hat = x_hat.unsqueeze(0)
    torchaudio.save(output_file, x_hat.cpu(), sr)
    print(f"Processed audio in {time.time() - start_time:.2f}s")

    # Return the path to the enhanced file for Gradio to serve
    return output_file
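

# Quick local sanity check (a sketch; "noisy.wav" is a hypothetical file you
# would place in the example test_dir yourself):
#   print(enhance_speech("./test_data/noisy.wav"))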

# Gradio interface setup: pass file paths in and out so torchaudio handles I/O
inputs = gr.Audio(label="Input Audio", type="filepath")
outputs = gr.Audio(label="Enhanced Audio", type="filepath")
title = "Speech Enhancement using SGMSE"
description = "This Gradio demo uses the SGMSE model for speech enhancement. Upload your audio file to enhance it."
article = "<p style='text-align: center'><a href='https://huggingface.co/SP-UHH/speech-enhancement-sgmse' target='_blank'>Model Card</a></p>"

# Launch the Gradio interface
gr.Interface(
    fn=enhance_speech,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
).launch()
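
# Note: for long-running GPU inference in a shared Space, enabling Gradio's
# request queue can help avoid request timeouts, e.g. by chaining .queue()
# before .launch() (a sketch; depends on the installed Gradio version).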