# Import the packages the demo needs
import torch
import gradio as gr
from torchaudio.sox_effects import apply_effects_file
from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForAudioFrameClassification

# Run on GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# SoX effect chain applied to every input recording
EFFECTS = [
    ["remix", "-"],  # merge all channels into one
    ["channels", "1"],  # force a mono channel layout
    ["rate", "16000"],  # resample to 16 kHz, the model's expected rate
    ["gain", "-1.0"],  # attenuate by 1 dB to avoid clipping
    ["silence", "1", "0.1", "0.1%", "-1", "0.1", "0.1%"],  # strip leading and trailing silence
    # ["pad", "0", "1.5"],  # optionally add 1.5 s of silence at the end
    ["trim", "0", "10"],  # keep only the first 10 seconds
]
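# A minimal sanity check for the chain (assumes some local "sample.wav"):
# the result should be mono, 16 kHz, and at most 10 s (160000 samples) long.
# wav, sr = apply_effects_file("sample.wav", EFFECTS)
# assert sr == 16000 and wav.shape[0] == 1 and wav.shape[1] <= 160000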
THRESHOLD = 0.85  # per-frame decision threshold; tune it for your dataset

model_name = "microsoft/unispeech-sat-base-sd"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = UniSpeechSatForAudioFrameClassification.from_pretrained(model_name).to(device)
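# Hedged note: the checkpoint performs frame-level multi-label classification,
# emitting one logit per speaker slot for every encoder frame. To inspect how
# many speaker slots this checkpoint exposes, print its label count:
# print(model.config.num_labels)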
def fn(path):
    # Apply the SoX effect chain to the recorded file
    wav, _ = apply_effects_file(path, EFFECTS)
    # Extract model inputs from the mono 16 kHz waveform
    input_values = feature_extractor(
        wav.squeeze(0), return_tensors="pt", sampling_rate=16000
    ).input_values.to(device)
    with torch.no_grad():
        logits = model(input_values).logits
    # Per-frame, per-speaker speech-activity probabilities
    probabilities = torch.sigmoid(logits[0].cpu())
    # labels is a binary matrix of shape (num_frames, num_speakers);
    # several speakers can be active in the same frame
    labels = (probabilities > THRESHOLD).long()
    return str(labels.tolist())
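# A hedged sketch (not part of the original demo): collapse the frame-level
# label matrix into readable speaker turns. It assumes the wav2vec2-style
# encoder stride of 320 samples at 16 kHz, i.e. roughly 20 ms per frame.
def format_segments(labels, frame_seconds=0.02):
    lines = []
    for speaker in range(labels.shape[1]):
        start = None
        # The trailing 0 acts as a sentinel that closes any open run
        for i, flag in enumerate(labels[:, speaker].tolist() + [0]):
            if flag and start is None:
                start = i
            elif not flag and start is not None:
                lines.append(
                    f"speaker {speaker}: {start * frame_seconds:.2f}s-{i * frame_seconds:.2f}s"
                )
                start = None
    return "\n".join(lines) if lines else "no speech detected"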
# A single microphone input; the recording is passed to fn as a file path
inputs = [
    gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker #1"),
]
output = gr.outputs.Textbox(label="Output Text")
# Build and launch the Gradio interface
gr.Interface(
    fn=fn,
    inputs=inputs,
    outputs=output,
    theme="grass",
    title="Speaker diarization using UniSpeech-SAT and X-Vectors",
).launch(enable_queue=True)