import io
import os
import torch
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
from peft import PeftModel, PeftConfig
import speech_recognition as sr
from datetime import datetime, timedelta
from queue import Queue
from tempfile import NamedTemporaryFile
from time import monotonic, sleep
from sys import platform


def _load_pipeline(peft_model_id, language, task, device):
    """Build an ASR pipeline from a Whisper base model plus a PEFT adapter.

    Args:
        peft_model_id: Hub id (or local path) of the PEFT adapter; its config
            names the Whisper base model to load underneath it.
        language: Language passed to the Whisper processor.
        task: Whisper task passed to the processor (e.g. "transcribe").
        device: Torch device string the model and pipeline run on.

    Returns:
        An ``AutomaticSpeechRecognitionPipeline`` ready to be called.
    """
    peft_config = PeftConfig.from_pretrained(peft_model_id)
    model = WhisperForConditionalGeneration.from_pretrained(
        peft_config.base_model_name_or_path
    )
    model = PeftModel.from_pretrained(model, peft_model_id)
    model.to(device)
    processor = WhisperProcessor.from_pretrained(
        peft_config.base_model_name_or_path, language=language, task=task
    )
    return AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        batch_size=8,
        torch_dtype=torch.float32,
        device=device,
    )


def _redraw(transcription):
    """Clear the console and reprint every transcription line."""
    os.system("cls" if os.name == "nt" else "clear")
    for line in transcription:
        print(line)
    # Flush stdout so the redraw appears immediately.
    print("", end="", flush=True)


def main():
    """Continuously record from the default microphone and print a live transcript.

    Audio arrives in short chunks from a background listener thread. Chunks are
    accumulated into the current phrase and the whole phrase is re-transcribed
    on every update. A gap longer than ``phrase_timeout`` seconds starts a new
    phrase (a new transcript line). Press Ctrl+C to stop; the full transcript is
    printed on exit.
    """
    # Set your default configuration values here.
    peft_model_id = "DuyTa/Vietnamese_ASR"
    language = "Vietnamese"
    task = "transcribe"
    default_energy_threshold = 900
    default_record_timeout = 0.6
    default_phrase_timeout = 3

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Monotonic timestamp of the last time audio was pulled from the queue.
    # A monotonic clock is used so wall-clock adjustments cannot fake or hide a pause.
    phrase_time = None
    # Raw audio bytes of the phrase currently being transcribed.
    last_sample = bytes()
    # Thread-safe queue carrying raw audio from the recording callback thread.
    data_queue = Queue()

    # SpeechRecognition is used for capture because it detects when speech ends.
    recorder = sr.Recognizer()
    recorder.energy_threshold = default_energy_threshold
    # Dynamic energy compensation can lower the threshold so far that the
    # recognizer never stops recording, so keep it disabled.
    recorder.dynamic_energy_threshold = False

    # Default system microphone, sampled at Whisper's expected 16 kHz.
    source = sr.Microphone(sample_rate=16000)

    pipe = _load_pipeline(peft_model_id, language, task, device)

    record_timeout = default_record_timeout
    phrase_timeout = default_phrase_timeout

    transcription = [""]

    # Calibrate against ambient noise. This recalibrates
    # recorder.energy_threshold, so the value set above is only a starting point.
    with source:
        recorder.adjust_for_ambient_noise(source)

    def record_callback(_, audio: sr.AudioData) -> None:
        """Threaded callback: push the raw bytes of each finished recording onto the queue.

        audio: An AudioData containing the recorded bytes.
        """
        data_queue.put(audio.get_raw_data())

    # Start a background thread that feeds raw audio chunks to the callback.
    # Keep the returned stopper so the listener can be shut down on exit.
    stop_listening = recorder.listen_in_background(
        source, record_callback, phrase_time_limit=record_timeout
    )

    print("Model loaded.\n")

    try:
        while True:
            now = monotonic()
            if not data_queue.empty():
                phrase_complete = False
                # If enough time has passed since the last chunk, the previous
                # phrase is finished: start a fresh audio buffer.
                if phrase_time is not None and now - phrase_time > phrase_timeout:
                    last_sample = bytes()
                    phrase_complete = True
                phrase_time = now

                # Drain everything queued so far (this loop is the only consumer).
                chunks = []
                while not data_queue.empty():
                    chunks.append(data_queue.get())
                last_sample += b"".join(chunks)

                # Wrap the raw PCM in a WAV container and hand the bytes straight
                # to the pipeline. This is the same decode path as reading a file,
                # but without a leaked temporary file or a per-loop disk write.
                audio_data = sr.AudioData(
                    last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH
                )
                text = pipe(
                    audio_data.get_wav_data(),
                    chunk_length_s=30,
                    return_timestamps=False,
                    generate_kwargs={"language": language, "task": task},
                )["text"]

                # A pause starts a new transcript line; otherwise the current
                # line is replaced with the refreshed transcription.
                if phrase_complete:
                    transcription.append(text)
                else:
                    transcription[-1] = text

                _redraw(transcription)

            # Avoid busy-waiting while no audio is queued.
            sleep(0.25)
    except KeyboardInterrupt:
        pass
    finally:
        # Stop the background recording thread.
        stop_listening(wait_for_stop=False)

    print("\n\nTranscription:")
    for line in transcription:
        print(line)


if __name__ == "__main__":
    main()