# Spaces status (page header captured along with the file, not program code): Runtime error
from typing import Dict | |
import gradio as gr | |
import whisper | |
from whisper.tokenizer import get_tokenizer | |
import classify | |
# Loaded Whisper models keyed by model name, so that a model requested more
# than once is loaded a single time per process and then reused.
model_cache = {}
def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
    """Score an audio clip against comma-separated candidate class names.

    For every candidate label, the average log-probability computed from the
    audio features is reduced by the value returned by
    ``classify.calculate_internal_lm_average_logprobs`` (which is given no
    audio features), and the corrected values are normalised with a softmax
    over the candidates.

    Args:
        audio_path: Path of the audio file to classify.
        class_names: Candidate labels separated by commas, e.g.
            ``"[dog barking],[laughing]"``. Whitespace around each label is
            ignored; empty entries and repeated labels are dropped.
        model_name: Name handed to ``whisper.load_model``. Names without
            ``".en"`` select the multilingual tokenizer.

    Returns:
        A mapping from each candidate label to its softmax score.

    Raises:
        ValueError: If ``audio_path`` is empty/``None`` or if ``class_names``
            contains no non-empty label.
    """
    # Validate before touching the model: loading and inference are expensive
    # and fail with far less helpful errors on missing input.
    if not audio_path:
        raise ValueError("No audio input was provided.")

    # "a, b,,a" -> ["a", "b"]: strip user-typed spaces, skip blanks, and
    # de-duplicate in order so the returned scores still sum to one.
    labels = list(dict.fromkeys(name.strip() for name in class_names.split(",") if name.strip()))
    if not labels:
        raise ValueError("At least one non-empty class name is required.")

    tokenizer = get_tokenizer(multilingual=".en" not in model_name)

    model = model_cache.get(model_name)
    if model is None:
        model = whisper.load_model(model_name)
        model_cache[model_name] = model

    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=labels,
        tokenizer=tokenizer,
    )
    audio_features = classify.calculate_audio_features(audio_path, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=labels,
        tokenizer=tokenizer,
    )

    # Subtract the audio-independent score so that labels are ranked by what
    # the audio contributes rather than by the text prior alone.
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    return dict(zip(labels, scores))
def main():
    """Build the Gradio interface for zero-shot audio classification and launch it."""
    class_names = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking]"
    audio_paths = [
        "./data/(dog)1-100032-A-0.wav",
        "./data/(helicopter)1-181071-A-40.wav",
        "./data/(laughing)1-1791-A-26.wav",
        "./data/(chirping_birds)1-34495-A-14.wav",
        "./data/(clock_tick)1-21934-A-38.wav",
    ]
    # One example row per audio file, in the same order as the interface
    # inputs: audio path, candidate class names, model name.
    examples = [[audio_path, class_names, "small"] for audio_path in audio_paths]

    description = (
        '<div style="text-align: center;">'
        "<p>This demo allows you to try out zero-shot audio classification using "
        "<a href=https://github.com/openai/whisper>Whisper</a>.</p>"
        "<p>Github: <a href=https://github.com/jumon/zac>https://github.com/jumon/zac</a></p>"
        "<p>Example audio files are from the <a href=https://github.com/karolpiczak/ESC-50>ESC-50"
        "</a> dataset (CC BY-NC 3.0).</p></div>"
    )

    demo = gr.Interface(
        fn=zero_shot_classify,
        inputs=[
            # NOTE(review): `source=` is the Gradio 3.x keyword; newer Gradio
            # releases spell it `sources=[...]` - confirm against the pinned version.
            gr.Audio(label="Input Audio", show_label=False, source="microphone", type="filepath"),
            gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
            gr.Radio(
                choices=["tiny", "base", "small", "medium", "large"],
                value="small",
                label="Model Name",
            ),
        ],
        outputs="label",
        examples=examples,
        title="Zero-shot Audio Classification using Whisper",
        description=description,
    )
    demo.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()