# music_mind_app/pipeline.py
# Last commit: e0f9e95 "Finishing" (bpietrzak); file size 1.23 kB
import os

import torch
import torchaudio
class AudioPipeline:
    """Classify one audio clip and return its top-k labels with probabilities.

    Wraps a feature extractor and a classification model (presumably Hugging
    Face ``transformers`` objects -- confirm against the caller).

    Args:
        feature_extractor: Callable that turns a 1-D waveform into model
            inputs; must expose a ``sampling_rate`` attribute.
        model: Callable that accepts the extractor's outputs as keyword
            arguments and returns an object with ``.logits``; must expose
            ``model.config.id2label``.
        top_k: Maximum number of labels to return.
    """

    def __init__(self, feature_extractor, model, top_k=5):
        self.fe = feature_extractor
        self.model = model
        self.top_k = top_k

    def __call__(self, audio_file):
        """Return ``{label: probability}`` for the most likely labels.

        Args:
            audio_file: A filesystem path (``str`` or ``os.PathLike``), or a
                file-like object whose ``.name`` attribute is a path (for
                example a ``tempfile.NamedTemporaryFile``).

        Returns:
            Dict mapping each label to its softmax probability, ordered from
            most to least likely, with at most ``min(top_k, num_labels)``
            entries.
        """
        # Path objects must be used directly: ``Path.name`` is only the last
        # path component, so treating a Path like a file object would load
        # the wrong file.
        if isinstance(audio_file, (str, os.PathLike)):
            path = audio_file
        else:
            path = audio_file.name
        waveform, sample_rate = torchaudio.load(path)

        # torchaudio returns (channels, frames); average the channels to mono.
        waveform = waveform.mean(dim=0)

        target_rate = self.fe.sampling_rate
        if sample_rate != target_rate:
            resample = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=target_rate,
            )
            waveform = resample(waveform)

        # Feature extractors document NumPy/list input; convert explicitly
        # instead of relying on implicit tensor -> array coercion.
        inputs = self.fe(
            waveform.numpy(),
            sampling_rate=target_rate,
            return_tensors="pt",
            padding=True,
        )

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Exactly one clip was passed in, so take the first (only) batch row.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes; clamp it.
        k = min(self.top_k, probs.shape[-1])
        top_probs, top_ids = torch.topk(probs, k)

        id2label = self.model.config.id2label
        return {
            id2label[idx.item()]: prob.item()
            for idx, prob in zip(top_ids, top_probs)
        }