from whisperx.alignment import (
    DEFAULT_ALIGN_MODELS_TORCH as DAMT,
    DEFAULT_ALIGN_MODELS_HF as DAMHF,
)
from whisperx.utils import TO_LANGUAGE_CODE
import whisperx
import torch
import gc
import os
import soundfile as sf
from IPython.utils import capture  # noqa
from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
from .logging_setup import logger
from .postprocessor import sanitize_file_name
from .utils import remove_directory_contents, run_command

# ZERO GPU CONFIG
import spaces
import copy
import random
import time


def random_sleep():
    """Sleep for a random 7.2-9.9 s interval when ZERO_GPU is "TRUE".

    Does nothing when the ZERO_GPU environment variable has any other value.
    """
    if os.environ.get("ZERO_GPU") == "TRUE":
        print("Random sleep")
        sleep_time = round(random.uniform(7.2, 9.9), 1)
        time.sleep(sleep_time)


def _get_device():
    """Return the device name handed to whisperX loaders and aligners.

    "cuda" when the ZERO_GPU environment variable is "TRUE"; otherwise the
    value of the SONITR_DEVICE environment variable (None if unset).
    """
    if os.environ.get("ZERO_GPU") == "TRUE":
        return "cuda"
    return os.environ.get("SONITR_DEVICE")


@spaces.GPU
def load_and_transcribe_audio(
    asr_model,
    audio,
    compute_type,
    language,
    asr_options,
    batch_size,
    segment_duration_limit,
):
    """Load a whisperX ASR model, transcribe ``audio`` and free the model.

    Parameters:
    - asr_model (str): Model name or local model directory for whisperx.
    - audio: Audio as returned by ``whisperx.load_audio``.
    - compute_type (str): Compute type forwarded to ``whisperx.load_model``.
    - language (str | None): Language forwarded to ``whisperx.load_model``.
    - asr_options (dict): ASR options forwarded to ``whisperx.load_model``.
    - batch_size (int): Batch size for ``model.transcribe``.
    - segment_duration_limit (int): Passed as ``chunk_size`` to transcribe.

    Returns:
    - The result of ``model.transcribe``.
    """
    model = whisperx.load_model(
        asr_model,
        _get_device(),
        compute_type=compute_type,
        language=language,
        asr_options=asr_options,
    )
    try:
        result = model.transcribe(
            audio,
            batch_size=batch_size,
            chunk_size=segment_duration_limit,
            print_progress=True,
        )
    finally:
        # Release the model and cached GPU memory even if transcription fails.
        del model
        gc.collect()
        torch.cuda.empty_cache()  # noqa
    return result


def load_align_and_align_segments(result, audio, DAMHF):
    """Load the alignment model for ``result["language"]`` and align.

    Parameters:
    - result (dict): Transcription result with "language" and "segments".
    - audio: Audio data forwarded to ``whisperx.align``.
    - DAMHF (dict): Languages for which whisperX's default alignment model
      is used (``model_name=None``). Any other language uses the model name
      from ``EXTRA_ALIGN``.

    Returns:
    - The value returned by ``whisperx.align`` (char alignments included).
    """
    language = result["language"]
    model_a, metadata = whisperx.load_align_model(
        language_code=language,
        device=_get_device(),
        model_name=None if language in DAMHF.keys() else EXTRA_ALIGN[language],
    )
    try:
        alignment_result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            _get_device(),
            return_char_alignments=True,
            print_progress=False,
        )
    finally:
        # Release the alignment model and cached GPU memory even on failure.
        del model_a
        gc.collect()
        torch.cuda.empty_cache()  # noqa
    return alignment_result


@spaces.GPU
def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
    """Run ``diarize_model`` on ``audio_wav`` within the speaker bounds.

    When ZERO_GPU is "TRUE", ``diarize_model.model`` is first moved to the
    CUDA device. Returns whatever the pipeline call returns.
    """
    if os.environ.get("ZERO_GPU") == "TRUE":
        diarize_model.model.to(torch.device("cuda"))
    diarize_segments = diarize_model(
        audio_wav, min_speakers=min_speakers, max_speakers=max_speakers
    )
    return diarize_segments


# ZERO GPU CONFIG

ASR_MODEL_OPTIONS = [
    "tiny",
    "base",
    "small",
    "medium",
    "large",
    "large-v1",
    "large-v2",
    "large-v3",
    "distil-large-v2",
    "Systran/faster-distil-whisper-large-v3",
    "tiny.en",
    "base.en",
    "small.en",
    "medium.en",
    "distil-small.en",
    "distil-medium.en",
    "OpenAI_API_Whisper",
]

COMPUTE_TYPE_GPU = [
    "default",
    "auto",
    "int8",
    "int8_float32",
    "int8_float16",
    "int8_bfloat16",
    "float16",
    "bfloat16",
    "float32"
]

COMPUTE_TYPE_CPU = [
    "default",
    "auto",
    "int8",
    "int8_float32",
    "int16",
    "float32",
]

WHISPER_MODELS_PATH = './WHISPER_MODELS'


def openai_api_whisper(
    input_audio_file,
    source_lang=None,
    chunk_duration=1800
):
    """Transcribe an audio file with OpenAI's hosted ``whisper-1`` model.

    The file is re-encoded to Ogg/Vorbis with ffmpeg into
    ``./whisper_api_audio_parts`` (emptied first). Files longer than
    ``chunk_duration`` seconds are split into ``chunk_duration``-second
    segments. Each part is sent to the API. Segment timestamps are shifted
    by ``chunk_index * chunk_duration`` so they refer to the original file.

    Parameters:
    - input_audio_file (str): Path to the input audio file.
    - source_lang (str | None): Language passed to the API. When falsy, the
      language reported for the first chunk is mapped through
      ``TO_LANGUAGE_CODE`` and used for the remaining chunks.
    - chunk_duration (int): Maximum length in seconds of each uploaded part.

    Returns:
    - Tuple ``(audio, result)``. ``audio`` is ``whisperx.load_audio`` of the
      input file. ``result`` is ``{"segments": [...], "language": ...}``,
      and each segment has "text", "start" and "end".
    """
    from openai import OpenAI

    info = sf.info(input_audio_file)
    duration = info.duration

    output_directory = "./whisper_api_audio_parts"
    os.makedirs(output_directory, exist_ok=True)
    remove_directory_contents(output_directory)

    if duration > chunk_duration:
        # Split the audio file into chunk_duration-second pieces
        # (30 minutes with the default of 1800 s).
        cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
        run_command(cm)
        # Get list of generated chunk files; %03d numbering sorts correctly.
        chunk_files = sorted(
            f"{output_directory}/{f}"
            for f in os.listdir(output_directory)
            if f.endswith('.ogg')
        )
    else:
        one_file = f"{output_directory}/output000.ogg"
        cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
        run_command(cm)
        chunk_files = [one_file]

    # Transcript
    segments = []
    language = source_lang if source_lang else None
    client = OpenAI()  # one client reused for every chunk

    for i, chunk in enumerate(chunk_files):
        # Context manager guarantees the upload handle is closed.
        with open(chunk, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                language=language,
                response_format="verbose_json",
                timestamp_granularities=["segment"],
            )

        try:
            transcript_dict = transcription.model_dump()
        except AttributeError:
            # Fallback for response objects without ``model_dump``.
            transcript_dict = transcription.to_dict()

        if language is None:
            logger.info(f'Language detected: {transcript_dict["language"]}')
            language = TO_LANGUAGE_CODE[transcript_dict["language"]]

        chunk_time = chunk_duration * i

        for seg in transcript_dict["segments"]:
            if "start" in seg.keys():
                segments.append(
                    {
                        "text": seg["text"],
                        "start": seg["start"] + chunk_time,
                        "end": seg["end"] + chunk_time,
                    }
                )

    audio = whisperx.load_audio(input_audio_file)
    result = {"segments": segments, "language": language}

    return audio, result


def find_whisper_models():
    """Return the names of sub-folders of WHISPER_MODELS_PATH that contain a
    ``model.bin`` file (an empty list if the path does not exist)."""
    path = WHISPER_MODELS_PATH
    folders = []
    if os.path.exists(path):
        for folder in os.listdir(path):
            folder_path = os.path.join(path, folder)
            if (
                os.path.isdir(folder_path)
                and 'model.bin' in os.listdir(folder_path)
            ):
                folders.append(folder)
    return folders


def transcribe_speech(
    audio_wav,
    asr_model,
    compute_type,
    batch_size,
    SOURCE_LANGUAGE,
    literalize_numbers=True,
    segment_duration_limit=15,
):
    """
    Transcribe speech using a whisper model.

    Parameters:
    - audio_wav (str): Path to the audio file in WAV format.
    - asr_model (str): The whisper model to be loaded. "OpenAI_API_Whisper"
      uses OpenAI's hosted API. A name not in ASR_MODEL_OPTIONS is converted
      with ctranslate2's TransformersConverter into WHISPER_MODELS_PATH
      (once) and loaded from there.
    - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
    - batch_size (int): Batch size for transcription.
    - SOURCE_LANGUAGE (str): Source language for transcription.
    - literalize_numbers (bool): Passed as the "suppress_numerals" ASR option.
      The OpenAI API path ignores it and only logs a notice.
    - segment_duration_limit (int): Chunk size passed to transcription.

    Returns:
    - Tuple containing:
        - audio: Loaded audio file.
        - result: Transcription result as a dictionary.
    """
    if asr_model == "OpenAI_API_Whisper":
        if literalize_numbers:
            logger.info(
                "OpenAI's API Whisper does not support "
                "the literalization of numbers."
            )
        return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)

    # https://github.com/openai/whisper/discussions/277
    # Initial prompt "The following are sentences in Mandarin." (written in
    # Simplified characters) is only used when "zh" was requested explicitly.
    prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
    # Whisper only knows "zh"; zh-TW is transcribed as zh.
    SOURCE_LANGUAGE = (
        SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
    )
    asr_options = {
        "initial_prompt": prompt,
        "suppress_numerals": literalize_numbers
    }

    if asr_model not in ASR_MODEL_OPTIONS:
        base_dir = WHISPER_MODELS_PATH
        os.makedirs(base_dir, exist_ok=True)
        model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))

        if not os.path.exists(model_dir):
            from ctranslate2.converters import TransformersConverter

            quantization = "float32"
            # Download new model
            try:
                converter = TransformersConverter(
                    asr_model,
                    low_cpu_mem_usage=True,
                    copy_files=[
                        "tokenizer_config.json", "preprocessor_config.json"
                    ]
                )
                converter.convert(
                    model_dir,
                    quantization=quantization,
                    force=False
                )
            except Exception as error:
                if "File tokenizer_config.json does not exist" in str(error):
                    # Retry copying tokenizer.json instead, overwriting the
                    # partially written output directory.
                    converter._copy_files = [
                        "tokenizer.json", "preprocessor_config.json"
                    ]
                    converter.convert(
                        model_dir,
                        quantization=quantization,
                        force=True
                    )
                else:
                    raise

        asr_model = model_dir
        logger.info(f"ASR Model: {str(model_dir)}")

    audio = whisperx.load_audio(audio_wav)

    result = load_and_transcribe_audio(
        asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options,
        batch_size, segment_duration_limit
    )

    # Chinese that was not explicitly requested as "zh" (no Simplified
    # prompt) is labelled Traditional Chinese.
    if result["language"] == "zh" and not prompt:
        result["language"] = "zh-TW"
        logger.info("Chinese - Traditional (zh-TW)")

    return audio, result


def align_speech(audio, result):
    """
    Aligns speech segments based on the provided audio and result metadata.

    Parameters:
    - audio (array): The audio data in a suitable format for alignment.
    - result (dict): Metadata containing information about the segments
         and language.

    Returns:
    - result (dict): Updated metadata after aligning the segments with
         the audio. Includes character-level alignments
         (``return_char_alignments=True``). If EXTRA_ALIGN maps the language
         to "", alignment is skipped and ``result`` is returned unchanged.

    Raises:
    - ValueError: if the language is in neither the whisperX default
      alignment models nor EXTRA_ALIGN.

    Notes:
    - This function uses language-specific models to align speech segments.
    - It performs language compatibility checks and selects the
         appropriate alignment model.
    - Cleans up memory by releasing resources after alignment.
    """
    # NOTE: merges the torch defaults into the HF defaults dict imported from
    # whisperx.alignment, in place (module-wide, idempotent side effect).
    DAMHF.update(DAMT)  # lang align
    language = result["language"]

    if language not in DAMHF.keys() and language not in EXTRA_ALIGN.keys():
        logger.warning(
            "Automatic detection: Source language not compatible with align"
        )
        raise ValueError(
            f"Detected language {result['language']}  incompatible, "
            "you can select the source language to avoid this error."
        )

    if language in EXTRA_ALIGN.keys() and EXTRA_ALIGN[language] == "":
        lang_name = (
            INVERTED_LANGUAGES[language]
            if language in INVERTED_LANGUAGES.keys()
            else language
        )
        logger.warning(
            "No compatible wav2vec2 model found "
            f"for the language '{lang_name}', skipping alignment."
        )
        return result

    # random_sleep()
    result = load_align_and_align_segments(result, audio, DAMHF)

    return result


diarization_models = {
    "pyannote_3.1": "pyannote/speaker-diarization-3.1",
    "pyannote_2.1": "pyannote/speaker-diarization@2.1",
    "disable": "",
}


def reencode_speakers(result):
    """Relabel speakers as SPEAKER_00, SPEAKER_01, ... by first appearance.

    Segments are modified in place. ``result`` is returned untouched when it
    has no segments or when the first segment is already "SPEAKER_00".
    """
    segments = result["segments"]
    # Guard: empty list (would raise IndexError) or already normalized.
    if not segments or segments[0]["speaker"] == "SPEAKER_00":
        return result

    speaker_mapping = {}
    counter = 0

    logger.debug("Reencode speakers")

    for segment in segments:
        old_speaker = segment["speaker"]
        if old_speaker not in speaker_mapping:
            speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
            counter += 1
        segment["speaker"] = speaker_mapping[old_speaker]

    return result


def diarize_speech(
    audio_wav,
    result,
    min_speakers,
    max_speakers,
    YOUR_HF_TOKEN,
    model_name="pyannote/speaker-diarization@2.1",
):
    """
    Performs speaker diarization on speech segments.

    Parameters:
    - audio_wav: Audio passed directly to the diarization pipeline
         (path or waveform — TODO confirm against callers).
    - result (dict): Metadata containing information about speech segments
        and alignments.
    - min_speakers (int): Minimum number of speakers expected in the audio.
    - max_speakers (int): Maximum number of speakers expected in the audio.
    - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
        authentication.
    - model_name (str): Name of the speaker diarization model to be used
        (default: "pyannote/speaker-diarization@2.1").

    Returns:
    - result_diarize (dict): Updated metadata after assigning speaker
        labels to segments, renumbered by ``reencode_speakers``.

    Raises:
    - ValueError: with license instructions when a pyannote 2.1/3.1 model
        fails to load with "'NoneType' object has no attribute 'to'".
    - Any other pipeline-loading error is re-raised unchanged.

    Notes:
    - This function utilizes a speaker diarization model to label speaker
        segments in the audio.
    - It assigns speakers to word-level segments based on diarization results.
    - Cleans up memory by releasing resources after diarization.
    - If at most one speaker is specified, or model_name is empty, each
        segment is assigned as the first speaker and no diarization
        inference runs.
    """
    if max(min_speakers, max_speakers) > 1 and model_name:
        try:
            diarize_model = whisperx.DiarizationPipeline(
                model_name=model_name,
                use_auth_token=YOUR_HF_TOKEN,
                device=os.environ.get("SONITR_DEVICE"),
            )
        except Exception as error:
            error_str = str(error)
            gc.collect()
            torch.cuda.empty_cache()  # noqa
            if "'NoneType' object has no attribute 'to'" in error_str:
                if model_name == diarization_models["pyannote_2.1"]:
                    raise ValueError(
                        "Accept the license agreement for using Pyannote 2.1."
                        " You need to have an account on Hugging Face and "
                        "accept the license to use the models: "
                        "https://huggingface.co/pyannote/speaker-diarization "
                        "and https://huggingface.co/pyannote/segmentation "
                        "Get your KEY TOKEN here: "
                        "https://hf.co/settings/tokens "
                    ) from error
                if model_name == diarization_models["pyannote_3.1"]:
                    raise ValueError(
                        "New Licence Pyannote 3.1: You need to have an account"
                        " on Hugging Face and accept the license to use the "
                        "models: https://huggingface.co/pyannote/speaker-diarization-3.1 "  # noqa
                        "and https://huggingface.co/pyannote/segmentation-3.0 "
                    ) from error
            # Always propagate: falling through would leave diarize_model
            # unbound and hide the real cause behind an UnboundLocalError.
            raise

        random_sleep()
        diarize_segments = diarize_audio(
            diarize_model, audio_wav, min_speakers, max_speakers
        )

        result_diarize = whisperx.assign_word_speakers(
            diarize_segments, result
        )

        for segment in result_diarize["segments"]:
            if "speaker" not in segment:
                segment["speaker"] = "SPEAKER_00"
                logger.warning(
                    f"No speaker detected in {segment['start']}. First TTS "
                    f"will be used for the segment text: {segment['text']} "
                )

        del diarize_model
        gc.collect()
        torch.cuda.empty_cache()  # noqa
    else:
        result_diarize = result
        result_diarize["segments"] = [
            {**item, "speaker": "SPEAKER_00"}
            for item in result_diarize["segments"]
        ]
    return reencode_speakers(result_diarize)