Gijs Wijngaard
Test
47bcf45
raw
history blame
3.06 kB
from functools import partial
import gradio as gr
import spaces
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast
# Hugging Face Hub repos for each selectable model: (captioning model, tokenizer).
_MODEL_REPOS = {
    "AudioCaps": (
        "wsntxxn/effb2-trm-audiocaps-captioning",
        "wsntxxn/audiocaps-simple-tokenizer",
    ),
    "Clotho": (
        "wsntxxn/effb2-trm-clotho-captioning",
        "wsntxxn/clotho-simple-tokenizer",
    ),
}


def load_model(model_name,
               device):
    """Load a captioning model and its matching tokenizer.

    Args:
        model_name: Which checkpoint to load; one of ``"AudioCaps"`` or
            ``"Clotho"``.
        device: Device the model is moved to (anything ``Module.to`` accepts).

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the known models.
    """
    try:
        model_repo, tokenizer_repo = _MODEL_REPOS[model_name]
    except KeyError:
        # Fail loudly with a clear message instead of falling through to an
        # UnboundLocalError on the return statement.
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(_MODEL_REPOS)}"
        ) from None
    # trust_remote_code runs Python shipped in the model repo; only the fixed
    # repos in _MODEL_REPOS can ever be loaded here.
    model = AutoModel.from_pretrained(
        model_repo,
        trust_remote_code=True
    ).to(device)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_repo)
    return model, tokenizer
@spaces.GPU
def infer(file, runner):
    """Caption one audio clip with ``runner``'s model.

    Args:
        file: Value of a ``gr.Audio`` input in numpy mode, i.e. a
            ``(sample_rate, samples)`` tuple, or ``None`` when the user
            clicked Run without providing any audio.
        runner: An ``InferRunner`` supplying ``model``, ``tokenizer``,
            ``target_sr`` and ``device``.

    Returns:
        The decoded caption string.

    Raises:
        gr.Error: If no audio was provided (displayed to the user by Gradio).
    """
    if file is None:
        # Without this guard the tuple unpacking below raises a TypeError.
        raise gr.Error("Please upload or record an audio clip first.")
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Scale integer PCM samples (int16 / int32) into floats in [-1, 1).
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Down-mix multi-channel audio to mono by averaging over axis 1.
    # NOTE(review): assumes Gradio's (samples, channels) layout.
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Cast to float32 before resampling so resample always gets a float32 input.
    wav = resample(wav.float(), sr, runner.target_sr)
    wav_len = len(wav)
    # Add a batch dimension and place the input on the same device as the model
    # (InferRunner moves the model to runner.device, which may be CUDA).
    wav = wav.unsqueeze(0).to(runner.device)
    with torch.no_grad():
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    return runner.tokenizer.decode(word_idx, skip_special_tokens=True)
# def input_toggle(input_type):
# if input_type == "file":
# return gr.update(visible=True), gr.update(visible=False)
# elif input_type == "mic":
# return gr.update(visible=False), gr.update(visible=True)
class InferRunner:
    """Holds one loaded captioning model, its tokenizer and its sample rate."""

    def __init__(self, model_name):
        # Use the GPU whenever torch can see one; otherwise fall back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.change_model(model_name)

    def change_model(self, model_name):
        """Load ``model_name`` onto ``self.device`` and refresh ``target_sr``."""
        loaded_model, loaded_tokenizer = load_model(model_name, self.device)
        self.model = loaded_model
        self.tokenizer = loaded_tokenizer
        # Sample rate the model expects; infer() resamples input audio to it.
        self.target_sr = loaded_model.config.sample_rate
def change_model(radio):
    """Switch the module-level ``infer_runner`` to the model named by ``radio``.

    ``radio`` is the value of the model-selection radio button.
    """
    # ``infer_runner`` is only read (never rebound) here, so no ``global``
    # declaration is required.
    runner = infer_runner
    runner.change_model(radio)
# One runner per model name. Each request uses the model selected in its own
# session. A single shared runner would let one user's radio change swap the
# model for everyone. It could also pair one model with the other model's
# tokenizer while a swap is in progress.
_runners = {}


def _get_runner(model_name):
    """Return the InferRunner for ``model_name``, loading it on first use."""
    runner = _runners.get(model_name)
    if runner is None:
        # Two concurrent first requests may both load; the later one wins,
        # which is wasteful but harmless.
        runner = _runners[model_name] = InferRunner(model_name)
    return runner


def _caption(audio, model_name):
    """Click handler: caption ``audio`` with the model chosen in this session."""
    return infer(audio, runner=_get_runner(model_name))


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight Audio Captioning")
    with gr.Row():
        gr.Markdown("""
Audio Captioning Demo
""")
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            # Load the default model eagerly at startup, as before; the other
            # model is loaded the first time a user selects it and clicks Run.
            infer_runner = _get_runner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
    # The radio value is passed per request instead of mutating shared state
    # from a radio.change handler.
    btn.click(
        fn=_caption,
        inputs=[file, radio],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()