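"""Gradio demo for zero-shot audio classification with OpenAI Whisper.

Each candidate class name is scored by how likely Whisper's decoder is to
emit it given the audio, calibrated by subtracting the name's likelihood
under the decoder's internal language model (see classify.py).
"""
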
from typing import Dict

import gradio as gr
import whisper
from whisper.tokenizer import get_tokenizer

import classify

# Cache loaded models so that switching back to a previously used model is fast.
model_cache: Dict[str, whisper.Whisper] = {}


def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
    # Parse the comma-separated class list, trimming any stray whitespace.
    class_names = [name.strip() for name in class_names.split(",")]
    # English-only checkpoints (names ending in ".en") need the monolingual tokenizer.
    tokenizer = get_tokenizer(multilingual=".en" not in model_name)

    if model_name not in model_cache:
        model_cache[model_name] = whisper.load_model(model_name)
    model = model_cache[model_name]

    # Average log-probability of each class name under the decoder alone
    # (no audio conditioning), i.e. under Whisper's internal language model.
    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    # Encode the audio once; the features are reused to score every class name.
    audio_features = classify.calculate_audio_features(audio_path, model)
    # Average log-probability of each class name's tokens given the audio.
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    # Subtract the internal-LM scores to cancel the decoder's text-only prior,
    # then softmax over classes to turn the calibrated scores into probabilities.
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    return {class_name: score for class_name, score in zip(class_names, scores)}
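
# The classifier can also be called directly, without the Gradio UI.
# A sketch (output values are illustrative, not actual model outputs):
#
#     scores = zero_shot_classify(
#         "./data/(dog)1-100032-A-0.wav",
#         "[dog barking],[laughing]",
#         "small",
#     )
#     # e.g. {"[dog barking]": 0.97, "[laughing]": 0.03}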


def main():
    CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking]"
    # Example clips from the ESC-50 dataset; each filename is prefixed with
    # its ground-truth label.
    AUDIO_PATHS = [
        "./data/(dog)1-100032-A-0.wav",
        "./data/(helicopter)1-181071-A-40.wav",
        "./data/(laughing)1-1791-A-26.wav",
        "./data/(chirping_birds)1-34495-A-14.wav",
        "./data/(clock_tick)1-21934-A-38.wav",
    ]
    EXAMPLES = [[audio_path, CLASS_NAMES, "small"] for audio_path in AUDIO_PATHS]

    DESCRIPTION = (
        '<div style="text-align: center;">'
        "<p>This demo allows you to try out zero-shot audio classification using "
        '<a href="https://github.com/openai/whisper">Whisper</a>.</p>'
        '<p>GitHub: <a href="https://github.com/jumon/zac">https://github.com/jumon/zac</a></p>'
        '<p>Example audio files are from the <a href="https://github.com/karolpiczak/ESC-50">'
        "ESC-50</a> dataset (CC BY-NC 3.0).</p></div>"
    )

    # Wire the classifier into a Gradio interface: audio input, a textbox for
    # candidate class names, and a radio selector for the model size.
    demo = gr.Interface(
        fn=zero_shot_classify,
        inputs=[
            gr.Audio(label="Input Audio", show_label=False, source="microphone", type="filepath"),
            gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
            gr.Radio(
                choices=["tiny", "base", "small", "medium", "large"],
                value="small",
                label="Model Name",
            ),
        ],
        outputs="label",
        examples=EXAMPLES,
        title="Zero-shot Audio Classification using Whisper",
        description=DESCRIPTION,
    )

    # launch() serves the app locally; passing share=True would additionally
    # create a temporary public URL (a standard Gradio option).
    demo.launch()


if __name__ == "__main__":
    main()