# Ensure GPU memory in diarization can be cleaned up
# (commit 18bb72f, author: aadnk)
import argparse
import gc
import json
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional, TextIO

import ffmpeg
import torch
class DiarizationEntry:
    """One speaker turn from the diarization pipeline.

    Attributes:
        start: Start time of the turn.
        end: End time of the turn.
        speaker: Speaker label assigned by the pipeline.
    """

    def __init__(self, start, end, speaker):
        self.start, self.end, self.speaker = start, end, speaker

    def __repr__(self):
        return f"<DiarizationEntry start={self.start} end={self.end} speaker={self.speaker}>"

    def toJson(self):
        """Return the turn as a plain, JSON-serializable dict (start, end, speaker)."""
        return dict(start=self.start, end=self.end, speaker=self.speaker)
class Diarization:
    """Speaker diarization using the pyannote/speaker-diarization pipeline.

    The pipeline is loaded lazily by ``initialize`` (which ``run`` calls) and
    can be released again with ``cleanup`` so that its memory, including any
    GPU memory, can be reclaimed. It is reloaded on the next ``initialize``
    or ``run``.
    """

    # HuggingFace model id of the pinned pyannote pipeline.
    PIPELINE_NAME = "pyannote/speaker-diarization@2.1"

    # Extensions handed to the pipeline as-is (formats soundfile reads
    # directly); any other file is first converted to a temporary mono WAV.
    SUPPORTED_AUDIO_SUFFIXES = (".wav", ".flac", ".ogg", ".mat")

    def __init__(self, auth_token=None):
        """Create a diarization wrapper.

        Args:
            auth_token: HuggingFace API token. When None, the value of the
                HK_ACCESS_TOKEN environment variable is used instead.

        Raises:
            ValueError: if no token is given and HK_ACCESS_TOKEN is unset.
        """
        if auth_token is None:
            # NOTE(review): "HK_ACCESS_TOKEN" may be a typo for an HF_* name;
            # kept as-is because users may already rely on it.
            auth_token = os.environ.get("HK_ACCESS_TOKEN")
            if auth_token is None:
                raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HK_ACCESS_TOKEN environment variable")

        self.auth_token = auth_token
        self.initialized = False
        self.pipeline = None

    @staticmethod
    def has_libraries():
        """Return True if the optional pyannote.audio and intervaltree packages can be imported."""
        try:
            import pyannote.audio  # noqa: F401
            import intervaltree  # noqa: F401
            return True
        except ImportError:
            return False

    def initialize(self):
        """Load the pyannote pipeline once, moving it to the GPU when CUDA is available.

        Raises:
            RuntimeError: if the pipeline could not be loaded.
        """
        if self.initialized:
            return
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained(self.PIPELINE_NAME, use_auth_token=self.auth_token)
        if pipeline is None:
            # Fail here with a clear message instead of an obscure
            # AttributeError/TypeError on first use of the pipeline.
            raise RuntimeError(f"Could not load the {self.PIPELINE_NAME} pipeline - please check the HuggingFace API Token")

        # Load GPU mode if available
        if torch.cuda.is_available():
            print("Diarization - using GPU")
            pipeline = pipeline.to(torch.device(0))
        else:
            print("Diarization - using CPU")

        # Publish the pipeline only once it is fully set up, so a failure above
        # leaves the object uninitialized and initialize() can be retried.
        self.pipeline = pipeline
        self.initialized = True

    def cleanup(self):
        """Release the pipeline so its memory (including GPU memory) can be reclaimed.

        Safe to call repeatedly; the pipeline is reloaded on the next
        initialize() or run().
        """
        self.initialized = False
        if self.pipeline is None:
            return
        self.pipeline = None
        # Collect the now-unreferenced pipeline first, then let the CUDA caching
        # allocator return its unused blocks to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _convert_to_wav(audio_file):
        """Convert audio_file to a temporary mono WAV file and return its path.

        The caller owns the returned file and must delete it.

        Raises:
            ffmpeg.Error: if the conversion fails (ffmpeg's stderr is printed first).
        """
        # mkstemp creates the file atomically, unlike the race-prone mktemp.
        fd, target_file = tempfile.mkstemp(prefix="diarization_", suffix=".wav")
        os.close(fd)
        try:
            # overwrite_output: mkstemp already created the (empty) file.
            # capture_stderr: otherwise ffmpeg.Error.stderr is None.
            ffmpeg.input(audio_file).output(target_file, ac=1).run(
                capture_stdout=True, capture_stderr=True, overwrite_output=True)
        except ffmpeg.Error as e:
            details = e.stderr.decode("utf-8", errors="replace") if e.stderr else str(e)
            print(f"Error occurred during audio conversion: {details}")
            os.remove(target_file)
            raise
        return target_file

    def run(self, audio_file, **kwargs):
        """Diarize audio_file, yielding a DiarizationEntry for every speaker turn.

        Files whose extension is not in SUPPORTED_AUDIO_SUFFIXES are first
        converted to a temporary mono WAV with ffmpeg. That file is always
        deleted, even if the pipeline raises. Extra keyword arguments (e.g.
        num_speakers, min_speakers, max_speakers) are forwarded to the pipeline.

        This is a generator: nothing runs until it is iterated.

        Raises:
            ffmpeg.Error: if the audio conversion fails.
        """
        self.initialize()

        if Path(audio_file).suffix in self.SUPPORTED_AUDIO_SUFFIXES:
            diarization = self.pipeline(audio_file, **kwargs)
        else:
            target_file = self._convert_to_wav(audio_file)
            try:
                diarization = self.pipeline(target_file, **kwargs)
            finally:
                # Delete temp file
                os.remove(target_file)

        # Yield result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            yield DiarizationEntry(turn.start, turn.end, speaker)

    def mark_speakers(self, diarization_result: List["DiarizationEntry"], whisper_result: dict):
        """Annotate Whisper segments with the diarization speakers that overlap them.

        Returns a copy of whisper_result. Each segment that overlaps at least
        one diarization turn gets two extra keys:
          - "longest_speaker": the speaker with the largest overlap,
          - "speakers": toJson() dicts of all overlapping turns, ordered by time.
        Segments with no overlapping turn are left unchanged. whisper_result
        itself, including its segment dicts, is not modified.
        """
        from intervaltree import IntervalTree

        # Copy the segment dicts too: a plain dict.copy() would share them with
        # the caller, who would then see the keys added below.
        result = whisper_result.copy()
        result["segments"] = [segment.copy() for segment in whisper_result["segments"]]

        # Create an interval tree from the diarization results
        tree = IntervalTree()
        for entry in diarization_result:
            # IntervalTree rejects empty/inverted intervals, and a zero-length
            # turn cannot overlap any segment anyway.
            if entry.end > entry.start:
                tree[entry.start:entry.end] = entry

        for segment in result["segments"]:
            segment_start = segment["start"]
            segment_end = segment["end"]

            # The tree returns an unordered set; sort it so the "speakers" order
            # and the longest-speaker tie-break are deterministic.
            overlapping_speakers = sorted(
                tree[segment_start:segment_end],
                key=lambda interval: (interval.begin, interval.end, str(interval.data.speaker)))

            # If no speakers overlap with this segment, skip it
            if not overlapping_speakers:
                continue

            # If several speakers overlap, pick the one with the longest overlap
            longest_speaker = None
            longest_duration = 0
            for speaker_interval in overlapping_speakers:
                overlap_duration = min(speaker_interval.end, segment_end) - max(speaker_interval.begin, segment_start)
                if overlap_duration > longest_duration:
                    longest_speaker = speaker_interval.data.speaker
                    longest_duration = overlap_duration

            segment["longest_speaker"] = longest_speaker
            segment["speakers"] = [speaker_interval.data.toJson() for speaker_interval in overlapping_speakers]

        # The write_srt will use the longest_speaker if it exist, and add it to the text field
        return result
def _write_file(input_file: str, output_path: str, output_extension: str, file_writer: lambda f: None):
if input_file is None:
raise ValueError("input_file is required")
if file_writer is None:
raise ValueError("file_writer is required")
# Write file
if output_path is None:
effective_path = os.path.splitext(input_file)[0] + "_output" + output_extension
else:
effective_path = output_path
with open(effective_path, 'w+', encoding="utf-8") as f:
file_writer(f)
print(f"Output saved to {effective_path}")
def main():
    """Command-line entry point: add diarized speakers to a Whisper transcript.

    Diarizes the audio file, marks the transcript segments with their speakers,
    and writes the result as JSON and as SRT. Each output goes to the path given
    on the command line, or next to the transcript with an "_output" suffix.
    """
    from src.utils import write_srt
    from src.diarization.transcriptLoader import load_transcript

    parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
    parser.add_argument('audio_file', type=str, help='Input audio file')
    parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')

    # Optional string options (output locations and credentials), in help order.
    optional_strings = (
        ('--output_json_file', 'Output JSON file (optional)'),
        ('--output_srt_file', 'Output SRT file (optional)'),
        ('--auth_token', 'HuggingFace API Token (optional)'),
    )
    for flag, help_text in optional_strings:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument("--max_line_width", type=int, default=40, help="Maximum line width for SRT file (default: 40)")

    # Speaker-count hints forwarded to the diarization pipeline.
    speaker_options = (
        ("--num_speakers", "Number of speakers"),
        ("--min_speakers", "Minimum number of speakers"),
        ("--max_speakers", "Maximum number of speakers"),
    )
    for flag, help_text in speaker_options:
        parser.add_argument(flag, type=int, default=None, help=help_text)

    args = parser.parse_args()

    print(f"\nReading whisper JSON from {args.whisper_file}")
    # Accepts either a Whisper JSON or an SRT transcript
    whisper_result = load_transcript(args.whisper_file)

    diarization = Diarization(auth_token=args.auth_token)
    speaker_limits = {
        "num_speakers": args.num_speakers,
        "min_speakers": args.min_speakers,
        "max_speakers": args.max_speakers,
    }
    diarization_result = list(diarization.run(args.audio_file, **speaker_limits))

    print("Diarization result:")
    for entry in diarization_result:
        print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")

    marked_whisper_result = diarization.mark_speakers(diarization_result, whisper_result)

    def dump_json(stream):
        json.dump(marked_whisper_result, stream, indent=4, ensure_ascii=False)

    def dump_srt(stream):
        write_srt(marked_whisper_result["segments"], stream, maxLineWidth=args.max_line_width)

    _write_file(args.whisper_file, args.output_json_file, ".json", dump_json)
    _write_file(args.whisper_file, args.output_srt_file, ".srt", dump_srt)
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()