# NOTE(review): lines below were scrape residue from a Hugging Face Spaces
# page listing ("Spaces: / Sleeping / Sleeping") — preserved as a comment so
# the file remains valid Python.
import torch | |
import torchaudio | |
class AudioPipeline:
    """Audio-classification pipeline around a Hugging Face-style model.

    Loads an audio file, down-mixes to mono, resamples to the feature
    extractor's expected rate, runs the model, and returns the top-k
    predicted labels with their softmax probabilities.

    Parameters
    ----------
    feature_extractor : object
        Preprocessor exposing ``sampling_rate`` and a call that returns
        model-ready tensors (presumably a HF ``FeatureExtractor`` — the
        exact type is not visible from this file).
    model : object
        Classifier whose forward pass returns an object with ``.logits``
        and whose config provides ``id2label``.
    top_k : int, optional
        Number of top predictions to return (default 5). Clamped to the
        number of classes at call time.
    """

    def __init__(self, feature_extractor, model, top_k=5):
        self.fe = feature_extractor
        self.model = model
        self.top_k = top_k

    def __call__(self, audio_file):
        """Classify *audio_file*.

        Parameters
        ----------
        audio_file : str or file-like
            A filesystem path, or an object exposing ``.name`` (e.g. a
            Gradio/tempfile upload).

        Returns
        -------
        dict[str, float]
            Mapping of label name -> probability for the top-k classes,
            ordered by descending probability.
        """
        # Resolve to a path once instead of duplicating torchaudio.load
        # in both branches (original called it twice).
        path = audio_file if isinstance(audio_file, str) else audio_file.name
        waveform, sample_rate = torchaudio.load(path)

        # Down-mix multi-channel audio to mono by averaging channels
        # (torchaudio loads as (channels, samples)).
        waveform = waveform.mean(dim=0)

        # Resample only when the file's rate differs from the rate the
        # feature extractor expects.
        if sample_rate != self.fe.sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.fe.sampling_rate,
            )
            waveform = resampler(waveform)

        inputs = self.fe(
            waveform,
            sampling_rate=self.fe.sampling_rate,
            return_tensors="pt",
            padding=True,
        )

        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Fix: clamp k so torch.topk never raises when the model has
        # fewer classes than self.top_k.
        k = min(self.top_k, probs.numel())
        top_probs, top_ids = torch.topk(probs, k)
        top_labels = [self.model.config.id2label[i.item()] for i in top_ids]
        return {label: p.item() for label, p in zip(top_labels, top_probs)}