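# Lightweight audio captioning demo: a Gradio app that generates text captions
# for audio clips using the wsntxxn/effb2-trm checkpoints trained on AudioCaps
# or Clotho.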
from functools import partial

import gradio as gr
import spaces
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast


def load_model(model_name, device):
    # Map the dataset name selected in the UI to its captioning checkpoint
    # and the matching tokenizer on the Hugging Face Hub.
    if model_name == "AudioCaps":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-audiocaps-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/audiocaps-simple-tokenizer"
        )
    elif model_name == "Clotho":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-clotho-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/clotho-simple-tokenizer"
        )
    else:
        raise ValueError(f"Unsupported model name: {model_name}")
    return model, tokenizer
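
# Hypothetical standalone usage (not part of the app flow): load the AudioCaps
# checkpoint on CPU and inspect the sample rate it expects.
#   model, tokenizer = load_model("AudioCaps", torch.device("cpu"))
#   print(model.config.sample_rate)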


def infer(file, runner):
    # Gradio's Audio component passes a (sample_rate, numpy array) tuple.
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Normalize integer PCM (int16 / int32) to floats in [-1, 1].
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0).to(runner.device)
    with torch.no_grad():
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap
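
# Hypothetical direct call using InferRunner (defined below); the UI wires this
# up via gr.Button instead. Captions a one-second 440 Hz int16 test tone, the
# same format gradio provides.
#   import numpy as np
#   runner = InferRunner("AudioCaps")
#   tone = (0.5 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000) * 2**15).astype(np.int16)
#   print(infer((16000, tone), runner=runner))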


# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)


class InferRunner:
    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate

    def change_model(self, model_name):
        # Swap the checkpoint and tokenizer in place, keeping the same device.
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate


def change_model(radio):
    # Radio callback: mutate the shared runner so the bound click handler
    # picks up the newly selected model.
    global infer_runner
    infer_runner.change_model(radio)


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight Audio Captioning")
    with gr.Row():
        gr.Markdown(
            "Generate a text caption for an audio clip, using a model trained "
            "on either AudioCaps or Clotho."
        )
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
        # partial() binds the shared runner; change_model() mutates it in place,
        # so clicks always run the currently selected model.
        btn.click(
            fn=partial(infer, runner=infer_runner),
            inputs=[file],
            outputs=output
        )

demo.launch()