import contextlib
import os
import re
import shutil
import tempfile

import torch
import spaces
import gradio as gr
from pyannote.audio import Pipeline

# Instantiate the diarization pipeline once at startup. The Hugging Face
# access token is read from the "api" environment variable.
try:
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.environ["api"],
    )
    # Move the pipeline to the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)
except Exception as e:
    # Keep the UI up; every request will then report the failure instead.
    print(f"Error initializing pipeline: {e}")
    pipeline = None

# Matches the "[ HH:MM:SS.mmm --> HH:MM:SS.mmm]" segment prefix that the
# label parser expects at the start of each line of str(diarization).
_SEGMENT_RE = re.compile(r"\[\s*([\d:.]+)\s*-->\s*([\d:.]+)\s*\]")


def save_audio(audio):
    """Copy the uploaded audio to a private temporary file.

    Args:
        audio: Filepath of the upload as supplied by ``gr.Audio(type="filepath")``;
            ``None`` when nothing was uploaded.

    Returns:
        The path of the temporary copy, or a string starting with ``"Error"``
        if the pipeline is unavailable or the upload cannot be read.
    """
    if pipeline is None:
        return "Error: Pipeline not initialized"
    if not audio:
        return "Error: No audio file provided"

    # One unique file per request: a shared "temp.wav" lets concurrent
    # requests overwrite (and later delete) each other's audio. Keep the
    # upload's own extension rather than mislabelling everything as .wav.
    suffix = os.path.splitext(audio)[1] or ".wav"
    fd, temp_path = tempfile.mkstemp(prefix="diarize_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as dst, open(audio, "rb") as src:
            shutil.copyfileobj(src, dst)
    except OSError as e:
        with contextlib.suppress(OSError):
            os.remove(temp_path)
        return f"Error: Could not read uploaded audio: {e}"
    return temp_path


@spaces.GPU(duration=90)
def diarize_audio(temp_file, num_speakers, min_speakers, max_speakers):
    """Run speaker diarization on ``temp_file`` and return the result as text.

    Speaker-count arguments that are empty (``None``) or not positive are not
    passed to the pipeline. ``temp_file`` is deleted after the pipeline call,
    whether or not it succeeded.

    Returns:
        ``str(diarization)`` on success, otherwise a string starting with
        ``"Error"``.
    """
    if pipeline is None:
        return "Error: Pipeline not initialized"
    try:
        params = {}
        for name, value in (
            ("num_speakers", num_speakers),
            ("min_speakers", min_speakers),
            ("max_speakers", max_speakers),
        ):
            # gr.Number can deliver floats (e.g. 2.0) or None for an empty
            # field; only whole, positive counts are forwarded.
            count = int(value or 0)
            if count > 0:
                params[name] = count
        diarization = pipeline(temp_file, **params)
    except Exception as e:
        return f"Error processing audio: {e}"
    finally:
        # Remove the temporary file even on failure so it does not leak.
        with contextlib.suppress(OSError):
            os.remove(temp_file)
    return str(diarization)


def timestamp_to_seconds(timestamp):
    """Convert an ``"HH:MM:SS.mmm"`` timestamp to a number of seconds.

    Returns:
        The total seconds as a float, or ``None`` (after printing the reason)
        if the timestamp does not have three numeric ``:``-separated parts.
    """
    try:
        h, m, s = map(float, timestamp.split(":"))
    except ValueError as e:
        print(f"Error converting timestamp to seconds: '{timestamp}'. Error: {e}")
        return None
    return 3600 * h + 60 * m + s


def generate_labels_from_diarization(diarization_output, labels_path="labels.txt"):
    """Write diarization text as tab-separated ``start\\tend\\tlabel`` lines.

    Each input line is expected to look like
    ``[ 00:00:00.497 -->  00:00:01.587] A SPEAKER_00``; the label is the last
    whitespace-separated token after the closing bracket. Lines that cannot
    be parsed are reported and skipped, never written.

    Args:
        diarization_output: Text produced by ``diarize_audio``.
        labels_path: Destination file; pass a unique path per request when
            several requests may run concurrently.

    Returns:
        ``labels_path`` if at least one line was written, otherwise ``None``.
    """
    successful_lines = 0  # Counter for successfully processed lines
    try:
        with open(labels_path, "w", encoding="utf-8") as outfile:
            for raw_line in diarization_output.strip().splitlines():
                line = raw_line.strip()
                if not line:
                    continue
                match = _SEGMENT_RE.search(line)
                label_tokens = line[match.end():].split() if match else []
                if not label_tokens:
                    print(f"Error processing line: '{line}'. Error: expected '[start --> end] ... LABEL'")
                    continue
                start_seconds = timestamp_to_seconds(match.group(1))
                end_seconds = timestamp_to_seconds(match.group(2))
                if start_seconds is None or end_seconds is None:
                    # timestamp_to_seconds already printed the reason; do not
                    # emit a "None" row into the label file.
                    continue
                outfile.write(f"{start_seconds}\t{end_seconds}\t{label_tokens[-1]}\n")
                successful_lines += 1
    except OSError as e:
        print(f"Cannot write to file '{labels_path}'. Error: {e}")
        return None
    print(f"Processed {successful_lines} lines successfully.")
    return labels_path if successful_lines > 0 else None


def process_audio(audio, num_speakers, min_speakers, max_speakers):
    """Gradio callback: diarize the upload and build a downloadable label file.

    Returns:
        ``(diarization_text, label_file_path)``; the path is ``None`` on error
        or when no label lines could be produced.
    """
    temp_file = save_audio(audio)
    if temp_file.startswith("Error"):
        # Never hand an error message to the pipeline as if it were a path.
        return temp_file, None
    try:
        diarization_result = diarize_audio(temp_file, num_speakers, min_speakers, max_speakers)
    finally:
        # diarize_audio deletes the file itself; this covers the case where
        # the call raises before its body runs (e.g. from the GPU decorator).
        with contextlib.suppress(OSError):
            os.remove(temp_file)

    if diarization_result.startswith("Error"):
        return diarization_result, None  # No label file if there's an error
    # A fresh directory per request keeps concurrent users from sharing one
    # label file while the download keeps the name "labels.txt".
    labels_path = os.path.join(tempfile.mkdtemp(prefix="labels_"), "labels.txt")
    label_file = generate_labels_from_diarization(diarization_result, labels_path)
    return diarization_result, label_file


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🗣️Pyannote Speaker Diarization 3.1🗣️
        This model takes an audio file as input and outputs the diarization of the speakers in the audio.
        Please upload an audio file and adjust the parameters as needed.
        The maximum length of the audio file it can process is around **35-40 minutes**.
        If you find this space helpful, please ❤ it.
        """
    )
    audio_input = gr.Audio(type="filepath", label="Upload Audio")
    # precision=0 makes the speaker counts arrive as integers; 0 means "unset".
    num_speakers_input = gr.Number(label="Number of Speakers", value=0, precision=0)
    min_speakers_input = gr.Number(label="Minimum Number of Speakers", value=0, precision=0)
    max_speakers_input = gr.Number(label="Maximum Number of Speakers", value=0, precision=0)
    process_button = gr.Button("Process")
    diarization_output = gr.Textbox(label="Diarization Output")
    label_file_link = gr.File(label="Download DAW Labels")

    process_button.click(
        fn=process_audio,
        inputs=[audio_input, num_speakers_input, min_speakers_input, max_speakers_input],
        outputs=[diarization_output, label_file_link],
    )

if __name__ == "__main__":
    demo.launch()