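"""Gradio demo for score-based speech enhancement with SGMSE.

Loads a pretrained ScoreModel checkpoint from the sp-uhh Hugging Face repo
and enhances uploaded audio via predictor-corrector reverse diffusion.
"""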
import torch
import torchaudio
import gradio as gr
from sgmse.model import ScoreModel
from sgmse.util.other import pad_spec
import tempfile
import time

# Inference parameters, following the configuration in the repo's enhancement.py
args = {
    "ckpt": "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
    "corrector": "ald",    # annealed Langevin dynamics corrector
    "corrector_steps": 1,  # corrector steps per predictor step
    "snr": 0.5,            # step-size parameter for the corrector
    "N": 30,               # number of reverse diffusion steps
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

# Load the pretrained checkpoint and move the model to the target device.
# Recent PyTorch Lightning versions resolve remote checkpoint URLs via
# fsspec; if this fails, download the .ckpt locally and point "ckpt" at it.
model = ScoreModel.load_from_checkpoint(args["ckpt"]).to(args["device"])
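model.eval()  # inference only: disable dropout and other training behavior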

@torch.no_grad()  # no gradients needed at inference time
def enhance_speech(audio_file):
    start_time = time.time()

    # Load the audio file (Gradio passes a file path)
    y, sr = torchaudio.load(audio_file)
    print(f"Loaded audio in {time.time() - start_time:.2f}s")
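    # Assumption: the published SGMSE VoiceBank-DEMAND checkpoint expects
    # mono 16 kHz input, so downmix and resample uploads to match.
    if y.size(0) > 1:
        y = y.mean(dim=0, keepdim=True)  # stereo -> mono
    if sr != 16000:
        y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=16000)
        sr = 16000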
    T_orig = y.size(1)

    # Normalize to unit peak (clamp guards against an all-silent input)
    norm_factor = y.abs().max().clamp(min=1e-8)
    y = y / norm_factor

    # Prepare DNN input
    Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args["device"]))), 0)
    print(f"Transformed input in {time.time() - start_time:.2f}s")

    Y = pad_spec(Y, mode="zero_pad")  # zero-pad the time dimension to a length the network accepts

    # Reverse diffusion with the predictor-corrector sampler
    sampler = model.get_pc_sampler(
        'reverse_diffusion', args["corrector"], Y.to(args["device"]),
        N=args["N"], corrector_steps=args["corrector_steps"], snr=args["snr"]
    )
    sample, _ = sampler()

    # Inverse transform back to the time domain, trimmed to the original length
    x_hat = model.to_audio(sample.squeeze(), T_orig)

    # Renormalize
    x_hat = x_hat * norm_factor

    # Write to a unique temporary file so concurrent requests don't clobber
    # each other (a fixed /tmp path would be overwritten by parallel users)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        output_file = f.name
    if x_hat.dim() == 1:
        x_hat = x_hat.unsqueeze(0)  # torchaudio.save expects (channels, time)
    torchaudio.save(output_file, x_hat.cpu(), sr)

    print(f"Processed audio in {time.time() - start_time:.2f}s")
    
    # Return the path to the enhanced file for Gradio to handle
    return output_file

# Gradio interface setup
inputs = gr.Audio(label="Input Audio", type="filepath")
outputs = gr.Audio(label="Enhanced Audio", type="filepath")
title = "Speech Enhancement using SGMSE"
description = "This Gradio demo uses the SGMSE model for speech enhancement. Upload your audio file to enhance it."
article = "<p style='text-align: center'><a href='https://huggingface.co/SP-UHH/speech-enhancement-sgmse' target='_blank'>Model Card</a></p>"

# Launch the Gradio interface
gr.Interface(fn=enhance_speech, inputs=inputs, outputs=outputs, title=title, description=description, article=article).launch()
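
# Note: on CPU hardware the 30-step diffusion sampler can take a while per
# clip; calling .queue() on the Interface before .launch() avoids Gradio
# request timeouts for long-running inference.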