from re import A |
from transformers.file_utils import cached_path, hf_bucket_url |
import os, zipfile |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
import soundfile as sf |
import torch |
import kenlm |
from pyctcdecode import Alphabet, BeamSearchDecoderCTC, LanguageModel |
import os |
from multiprocessing import Pool |
import argparse, subprocess, tempfile |
def extract_audio(filename, channels=1, rate=16000):
    """
    Extract audio from an input file to a temporary WAV file.

    The conversion is delegated to the ``ffmpeg`` command-line tool, which
    must be available on the system.

    Args:
        filename: Path of the media file to convert.
        channels: Channel count passed to ffmpeg's ``-ac`` option.
        rate: Sample rate passed to ffmpeg's ``-ar`` option.

    Returns:
        A ``(path, rate)`` tuple: the path of the temporary WAV file and the
        sample rate that was requested. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        FileNotFoundError: If ``filename`` is not an existing file.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Validate before creating the temporary file so that a bad path does not
    # leave an orphaned, empty .wav behind.
    if not os.path.isfile(filename):
        print("The given file does not exist: {}".format(filename))
        raise FileNotFoundError("Invalid filepath: {}".format(filename))

    # Only the name is needed. Closing the handle immediately avoids leaking
    # a file descriptor and lets ffmpeg reopen the path for writing.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp:
        output_path = temp.name

    # "-y" makes ffmpeg overwrite the (already existing, empty) output file.
    command = ["ffmpeg", "-y", "-i", filename,
               "-ac", str(channels), "-ar", str(rate),
               "-loglevel", "error", output_path]
    use_shell = os.name == "nt"
    try:
        # DEVNULL replaces the previous open(os.devnull), which was never
        # closed; ffmpeg must not read from the parent's stdin.
        subprocess.check_output(command, stdin=subprocess.DEVNULL, shell=use_shell)
    except BaseException:
        # Conversion failed or was interrupted: remove the stale temp file
        # and let the original error propagate.
        if os.path.exists(output_path):
            os.remove(output_path)
        raise
    return output_path, rate
class Wav2Vec:
    """Speech-to-text wrapper around a pretrained wav2vec2 CTC model.

    Loads the ``nguyenvulebinh/wav2vec2-base-vietnamese-250h`` processor and
    model, downloads and unpacks the accompanying ``vi_lm_4grams.bin`` KenLM
    file into ``./cache/``, and builds a pyctcdecode beam-search decoder that
    uses that language model.
    """

    MODEL_NAME = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
    LM_ARCHIVE = 'vi_lm_4grams.bin.zip'
    LM_BINARY = 'vi_lm_4grams.bin'

    def __init__(self, device=None):
        """Load processor, acoustic model and language-model decoder.

        Args:
            device: Torch device string to run the model on. When ``None``,
                ``"cuda"`` is used if a GPU is available, otherwise ``"cpu"``
                (previously ``"cuda"`` was hard-coded and failed on CPU-only
                machines).
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        cache_dir = './cache/'
        self.processor = Wav2Vec2Processor.from_pretrained(self.MODEL_NAME, cache_dir=cache_dir)

        # Fetch the zipped n-gram language model and unpack it next to the
        # other cached artifacts.
        lm_url = hf_bucket_url(self.MODEL_NAME, filename=self.LM_ARCHIVE)
        lm_archive = cached_path(lm_url, cache_dir=cache_dir)
        with zipfile.ZipFile(lm_archive, 'r') as zip_ref:
            zip_ref.extractall(cache_dir)
        lm_file = os.path.join(cache_dir, self.LM_BINARY)

        self.model = Wav2Vec2ForCTC.from_pretrained(self.MODEL_NAME, cache_dir=cache_dir)
        self.model.to(self.device)
        self.ngram_lm_model = self.get_decoder_ngram_model(self.processor.tokenizer, lm_file)

    @staticmethod
    def _decoder_vocab(tokenizer):
        """Return the tokenizer's tokens, ordered by id, prepared for the decoder.

        The two highest-id tokens are dropped (presumably the sentence
        boundary tokens -- confirm against the model's vocabulary), the pad
        and unknown tokens are replaced by the empty string, and the word
        delimiter token is replaced by a space.
        """
        vocab_dict = tokenizer.get_vocab()
        ordered = sorted(vocab_dict.items(), key=lambda item: item[1])
        vocab_list = [token for token, _ in ordered][:-2]
        vocab_list[tokenizer.pad_token_id] = ""
        vocab_list[tokenizer.unk_token_id] = ""
        vocab_list[tokenizer.word_delimiter_token_id] = " "
        return vocab_list

    def get_decoder_ngram_model(self, tokenizer, ngram_lm_path):
        """Build a CTC beam-search decoder backed by a KenLM n-gram model.

        Args:
            tokenizer: Tokenizer providing ``get_vocab()`` and the pad,
                unknown and word-delimiter token ids.
            ngram_lm_path: Path to the KenLM model file.

        Returns:
            A ``BeamSearchDecoderCTC`` whose CTC blank is the pad token.
        """
        vocab_list = self._decoder_vocab(tokenizer)
        alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=tokenizer.pad_token_id)
        lm_model = kenlm.Model(ngram_lm_path)
        return BeamSearchDecoderCTC(alphabet, language_model=LanguageModel(lm_model))

    def map_to_array(self, batch):
        """Read ``batch["file"]`` with soundfile and attach the audio to ``batch``.

        Adds the keys ``"speech"`` (the samples) and ``"sampling_rate"`` to
        the given dict and returns that same dict.
        """
        speech, sampling_rate = sf.read(batch["file"])
        batch["speech"] = speech
        batch["sampling_rate"] = sampling_rate
        return batch

    def inference(self, filename, beam_width=500):
        """Transcribe an audio file using LM-assisted beam search.

        Args:
            filename: Path of an audio file readable by soundfile.
            beam_width: Beam width passed to the decoder (default 500).

        Returns:
            The decoded transcription produced by the beam-search decoder.
        """
        ds = self.map_to_array({"file": filename})
        input_values = self.processor(
            ds["speech"], sampling_rate=ds["sampling_rate"], return_tensors="pt"
        ).input_values
        input_values = input_values.to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.model(input_values).logits[0]
        return self.ngram_lm_model.decode(logits.cpu().numpy(), beam_width=beam_width)
def main():
    """Transcribe every ``*.wav`` in ``--wavs`` and write transcript lists.

    Each output line is ``<wav file name> <transcription>``. All files are
    written to ``--train_file``; the first ``--val_count`` files (in sorted
    path order) are additionally written to ``--val_file``.
    """
    import glob
    import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument('--wavs', default="DATA/wavs", help="", type=str)
    parser.add_argument('--train_file', default="DATA/train.txt", help="", type=str)
    # The default used to be "DATA/train.txt" as well, which made the two
    # write handles clobber each other in the same file.
    parser.add_argument('--val_file', default="DATA/val.txt", help="", type=str)
    parser.add_argument('--val_count', default=64, type=int,
                        help="number of files also written to --val_file")
    args = parser.parse_args()

    if os.path.abspath(args.train_file) == os.path.abspath(args.val_file):
        parser.error("--train_file and --val_file must be different files")

    # Create the parent directory of each output file; a bare file name has
    # no directory part and os.makedirs("") would raise.
    for out_path in (args.train_file, args.val_file):
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    # Load the model only after the arguments are known to be valid, so that
    # --help and usage errors return immediately.
    w2v = Wav2Vec()

    # Sorted so that the validation subset is reproducible between runs.
    wav_paths = sorted(glob.glob(os.path.join(args.wavs, "*.wav")))
    with open(args.train_file, "w", encoding="utf-8") as fw, \
            open(args.val_file, "w", encoding="utf-8") as fw_val:
        for index, wav_path in enumerate(tqdm.tqdm(wav_paths)):
            audio_filename, _ = extract_audio(wav_path)
            try:
                output = w2v.inference(audio_filename)
            finally:
                # extract_audio leaves its temporary WAV on disk; remove it
                # so a large corpus does not fill the temp directory.
                if os.path.exists(audio_filename):
                    os.remove(audio_filename)
            line = os.path.basename(wav_path) + " " + output + "\n"
            fw.write(line)
            if index < args.val_count:
                fw_val.write(line)


if __name__ == "__main__":
    main()