Spaces:
Sleeping
Sleeping
bpietrzak
committed on
Commit
·
e0f9e95
1
Parent(s):
c49f003
Finishing
Browse files- app.py +8 -33
- pipeline.py +33 -0
app.py
CHANGED
@@ -1,49 +1,24 @@
|
|
1 |
-
import torch
|
2 |
-
import torchaudio
|
3 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
4 |
import gradio as gr
|
5 |
-
import json
|
6 |
|
|
|
7 |
|
8 |
-
|
|
|
9 |
|
10 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
11 |
-
|
12 |
)
|
13 |
|
14 |
-
model = AutoModelForAudioClassification.from_pretrained(
|
15 |
-
|
16 |
-
)
|
17 |
|
18 |
-
def audio_pipeline(audio_file):
|
19 |
-
if isinstance(audio_file, str):
|
20 |
-
waveform, sample_rate = torchaudio.load(audio_file)
|
21 |
-
else:
|
22 |
-
waveform, sample_rate = torchaudio.load(audio_file.name)
|
23 |
-
waveform = waveform.mean(dim=0)
|
24 |
-
if sample_rate != feature_extractor.sampling_rate:
|
25 |
-
transform = torchaudio.transforms.Resample(
|
26 |
-
orig_freq=sample_rate,
|
27 |
-
new_freq=feature_extractor.sampling_rate)
|
28 |
-
waveform = transform(waveform)
|
29 |
-
inputs = feature_extractor(waveform,
|
30 |
-
sampling_rate=feature_extractor.sampling_rate,
|
31 |
-
return_tensors="pt",
|
32 |
-
padding=True)
|
33 |
-
with torch.no_grad():
|
34 |
-
logits = model(**inputs).logits
|
35 |
-
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
36 |
-
|
37 |
-
top_probs, top_ids = torch.topk(probs, config['top_k'])
|
38 |
-
top_labels = [model.config.id2label[idx.item()] for idx in top_ids]
|
39 |
-
|
40 |
-
results = {label: prob.item() for label, prob in zip(top_labels, top_probs)}
|
41 |
-
return results
|
42 |
|
43 |
demo = gr.Interface(
|
44 |
fn=audio_pipeline,
|
45 |
inputs=[gr.Audio(type="filepath", label="Upload Audio")],
|
46 |
-
outputs=gr.Label(num_top_classes=
|
47 |
title="Music Mind",
|
48 |
)
|
49 |
|
|
|
|
|
|
|
1 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
2 |
import gradio as gr
|
|
|
3 |
|
4 |
+
from pipeline import AudioPipeline
|
5 |
|
6 |
+
|
7 |
+
# Identifier handed to both ``from_pretrained`` calls below, so the feature
# extractor and the model always come from the same checkpoint
# (presumably a Hugging Face Hub repo id -- confirm).
model_id = 'bjpietrzak/music_mind_distillhubert_gtzan_4e-5_WAdam_CosineCheguler'

# Number of labels to report. Defined once so that the pipeline's ``top_k``
# and the Label widget's ``num_top_classes`` cannot drift apart.
TOP_K = 7

feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_id, do_normalize=True, return_attention_mask=True
)

model = AutoModelForAudioClassification.from_pretrained(model_id)

# Callable object used as the Gradio handler: receives the uploaded audio
# and returns a ``{label: probability}`` mapping.
audio_pipeline = AudioPipeline(feature_extractor, model, top_k=TOP_K)

demo = gr.Interface(
    fn=audio_pipeline,
    # type="filepath" makes Gradio pass the upload to ``fn`` as a path string.
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs=gr.Label(num_top_classes=TOP_K),
    title="Music Mind",
)
|
24 |
|
pipeline.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
|
4 |
+
|
5 |
+
class AudioPipeline:
    """Callable that classifies an audio file and returns its top labels.

    Each call loads the audio with torchaudio, averages it over dim 0,
    resamples it to the feature extractor's sampling rate when the rates
    differ, runs the model without gradient tracking, and returns the most
    probable labels.

    Args:
        feature_extractor: Callable preprocessing object exposing a
            ``sampling_rate`` attribute; it is invoked with the waveform,
            ``sampling_rate``, ``return_tensors="pt"`` and ``padding=True``.
        model: Callable whose result has a ``.logits`` tensor and whose
            ``config.id2label`` maps class indices to label names.
        top_k: Maximum number of labels returned per call.
    """

    def __init__(self, feature_extractor, model, top_k=5):
        self.fe = feature_extractor
        self.model = model
        self.top_k = top_k

    def __call__(self, audio_file):
        """Classify ``audio_file`` and return its most probable labels.

        Args:
            audio_file: Path to an audio file as ``str``, or any object whose
                ``.name`` attribute holds such a path.

        Returns:
            A dict mapping label name to probability (``float``), holding at
            most ``top_k`` entries ordered from most to least probable.

        Raises:
            ValueError: If ``audio_file`` is ``None`` (nothing was provided).
        """
        # Fail with a clear message instead of an AttributeError on ``.name``.
        if audio_file is None:
            raise ValueError("No audio file provided.")

        path = audio_file if isinstance(audio_file, str) else audio_file.name
        waveform, sample_rate = torchaudio.load(path)

        # Average over dim 0 (the channel axis in torchaudio's default
        # channels-first layout) to obtain a single mono signal.
        waveform = waveform.mean(dim=0)

        target_rate = self.fe.sampling_rate
        if sample_rate != target_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=target_rate)
            waveform = transform(waveform)

        inputs = self.fe(waveform,
                         sampling_rate=target_rate,
                         return_tensors="pt",
                         padding=True)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Only the first item of the batch dimension is scored.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(self.top_k, probs.shape[-1])
        top_probs, top_ids = torch.topk(probs, k)

        id2label = self.model.config.id2label
        return {
            id2label[idx.item()]: prob.item()
            for idx, prob in zip(top_ids, top_probs)
        }
|