import torch
import torchaudio


class AudioPipeline:
    """Audio-classification pipeline around a feature extractor and a model.

    Calling an instance with an audio file path (or an object exposing the
    path via ``.name``, e.g. an uploaded tempfile) loads the audio, converts
    it to mono, resamples it to the extractor's rate if necessary, runs the
    model, and returns the ``top_k`` most probable labels mapped to their
    softmax probabilities.
    """

    def __init__(self, feature_extractor, model, top_k=5):
        # Store collaborators only; all work happens in __call__.
        self.fe = feature_extractor
        self.model = model
        self.top_k = top_k

    def __call__(self, audio_file):
        # Accept either a plain path string or a file-like object whose
        # filesystem path is available as ``.name``.
        path = audio_file if isinstance(audio_file, str) else audio_file.name
        waveform, sample_rate = torchaudio.load(path)

        # Collapse all channels into a single mono track by averaging.
        waveform = waveform.mean(dim=0)

        # Resample only when the file's rate differs from what the
        # feature extractor expects.
        target_rate = self.fe.sampling_rate
        if sample_rate != target_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=target_rate)
            waveform = resampler(waveform)

        inputs = self.fe(
            waveform,
            sampling_rate=target_rate,
            return_tensors="pt",
            padding=True)

        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Single-item batch: take the first (only) row of probabilities.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

        top_probs, top_ids = torch.topk(probs, self.top_k)
        id2label = self.model.config.id2label
        return {
            id2label[idx.item()]: prob.item()
            for prob, idx in zip(top_probs, top_ids)
        }