Spaces:
Running
Running
File size: 7,467 Bytes
a1da02d 18bb72f a1da02d 18bb72f a1da02d 18bb72f a1da02d 18bb72f a1da02d 18bb72f a1da02d 18bb72f a1da02d 18bb72f a1da02d 18bb72f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import argparse
import gc
import json
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional, TextIO

import ffmpeg
import torch
class DiarizationEntry:
    """One speaker turn: a start time, an end time and a speaker label."""

    def __init__(self, start, end, speaker):
        self.start, self.end, self.speaker = start, end, speaker

    def __repr__(self):
        description = "start={} end={} speaker={}".format(self.start, self.end, self.speaker)
        return "<DiarizationEntry " + description + ">"

    def toJson(self):
        """Return the turn as a plain dict with "start", "end" and "speaker" keys."""
        return {field: getattr(self, field) for field in ("start", "end", "speaker")}
class Diarization:
    """Speaker diarization backed by a pyannote.audio pipeline.

    The pipeline is loaded lazily on the first call to ``run`` (or an explicit
    ``initialize``), so constructing an instance is cheap.
    """

    # NOTE(review): the checkpoint id reached review mangled by e-mail
    # obfuscation ("pyannote/[email protected]"). It is restored here as the
    # name@revision form documented by pyannote.audio - confirm the revision
    # against the pyannote.audio version this project pins.
    PIPELINE_NAME = "pyannote/speaker-diarization@2.1"

    # Extensions passed to the pipeline as-is; any other file is converted to
    # WAV with ffmpeg first. Compared case-insensitively.
    SUPPORTED_SUFFIXES = frozenset({".wav", ".flac", ".ogg", ".mat"})

    def __init__(self, auth_token=None):
        """Store the HuggingFace token used to download the pipeline.

        Args:
            auth_token: HuggingFace API token. When None, the value of the
                HK_ACCESS_TOKEN environment variable is used instead.

        Raises:
            ValueError: if no token is given and the variable is not set.
        """
        if auth_token is None:
            auth_token = os.environ.get("HK_ACCESS_TOKEN")
        if auth_token is None:
            raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HK_ACCESS_TOKEN environment variable")

        self.auth_token = auth_token
        self.initialized = False
        self.pipeline = None

    @staticmethod
    def has_libraries():
        """Return True when both pyannote.audio and intervaltree can be imported."""
        try:
            import pyannote.audio  # noqa: F401 - imported only to probe availability
            import intervaltree  # noqa: F401
            return True
        except ImportError:
            return False

    def initialize(self):
        """Load the pipeline once and move it to the GPU when CUDA is available.

        Raises:
            RuntimeError: if the pipeline loader returns None.
        """
        if self.initialized:
            return

        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained(self.PIPELINE_NAME, use_auth_token=self.auth_token)
        # NOTE(review): pyannote.audio appears to return None (after printing a
        # hint) instead of raising when the checkpoint cannot be downloaded,
        # e.g. for a rejected token - fail here with a clear message rather
        # than with an AttributeError further down.
        if pipeline is None:
            raise RuntimeError(
                f"Could not load diarization pipeline '{self.PIPELINE_NAME}' - "
                "check the HuggingFace token and that the model's user conditions have been accepted"
            )

        # Load GPU mode if available
        if torch.cuda.is_available():
            print("Diarization - using GPU")
            pipeline = pipeline.to(torch.device(0))
        else:
            print("Diarization - using CPU")

        # Only mark the instance ready once every step above has succeeded.
        self.pipeline = pipeline
        self.initialized = True

    def run(self, audio_file, **kwargs):
        """Diarize ``audio_file`` and yield one DiarizationEntry per speaker turn.

        This is a generator: nothing happens until it is iterated. Files whose
        extension is not in SUPPORTED_SUFFIXES are first converted to a
        temporary mono WAV file, which is removed again before the first entry
        is yielded - including when conversion or diarization fails.

        Args:
            audio_file: path of the audio file to process.
            **kwargs: forwarded unchanged to the pipeline call.

        Raises:
            ffmpeg.Error: if the conversion to WAV fails.
        """
        self.initialize()

        needs_conversion = Path(audio_file).suffix.lower() not in self.SUPPORTED_SUFFIXES
        target_file = audio_file

        if needs_conversion:
            # mkstemp creates the file atomically; mktemp only returns a name
            # and is documented as unsafe.
            handle, target_file = tempfile.mkstemp(prefix="diarization_", suffix=".wav")
            os.close(handle)

        try:
            if needs_conversion:
                try:
                    # The placeholder file already exists, so ffmpeg must be
                    # allowed to overwrite it.
                    ffmpeg.input(audio_file).output(target_file, ac=1).run(overwrite_output=True)
                except ffmpeg.Error as e:
                    print(f"Error occurred during audio conversion: {e.stderr}")
                    # Diarizing an empty placeholder would only fail later with
                    # a less helpful error.
                    raise

            diarization = self.pipeline(target_file, **kwargs)
        finally:
            if needs_conversion and os.path.exists(target_file):
                # Delete temp file
                os.remove(target_file)

        # Yield result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            yield DiarizationEntry(turn.start, turn.end, speaker)

    def mark_speakers(self, diarization_result: List["DiarizationEntry"], whisper_result: dict):
        """Return a copy of ``whisper_result`` with speaker data added per segment.

        Every segment that overlaps at least one diarization entry gains:
          * "longest_speaker": label of the entry with the longest overlap
            (the earliest entry wins a tie);
          * "speakers": toJson() of every overlapping entry, ordered by start.
        Segments without any overlap are left as they are. The segment dicts
        of ``whisper_result`` itself are not modified.

        Args:
            diarization_result: speaker turns, e.g. the output of ``run``.
            whisper_result: dict with a "segments" list whose items have
                "start" and "end" keys.

        Returns:
            The annotated copy.
        """
        from intervaltree import IntervalTree

        result = whisper_result.copy()
        # dict.copy() is shallow: copy each segment too so the caller's
        # segments are not annotated behind its back.
        result["segments"] = [dict(segment) for segment in whisper_result["segments"]]

        # Create an interval tree from the diarization results
        tree = IntervalTree()
        for entry in diarization_result:
            # IntervalTree rejects empty intervals with a ValueError; such a
            # turn cannot overlap anything anyway.
            if entry.end > entry.start:
                tree[entry.start:entry.end] = entry

        # Iterate through each segment in the Whisper JSON
        for segment in result["segments"]:
            segment_start = segment["start"]
            segment_end = segment["end"]

            # The tree returns a set; sort it so the output and the
            # tie-breaking below do not depend on set iteration order.
            overlapping_speakers = sorted(
                tree[segment_start:segment_end],
                key=lambda interval: (interval.begin, interval.end),
            )

            # If no speakers overlap with this segment, skip it
            if not overlapping_speakers:
                continue

            # If multiple speakers overlap with this segment, choose the one with the longest duration
            longest_speaker = None
            longest_duration = 0
            for speaker_interval in overlapping_speakers:
                overlap_start = max(speaker_interval.begin, segment_start)
                overlap_end = min(speaker_interval.end, segment_end)
                overlap_duration = overlap_end - overlap_start

                if overlap_duration > longest_duration:
                    longest_speaker = speaker_interval.data.speaker
                    longest_duration = overlap_duration

            # Add speakers
            segment["longest_speaker"] = longest_speaker
            segment["speakers"] = [speaker_interval.data.toJson() for speaker_interval in overlapping_speakers]

        # The write_srt will use the longest_speaker if it exist, and add it to the text field
        return result
def _write_file(input_file: str, output_path: Optional[str], output_extension: str,
                file_writer: Callable[[TextIO], None]):
    """Open the output file and let ``file_writer`` fill it.

    Args:
        input_file: path the default output name is derived from. Required
            even when ``output_path`` is given.
        output_path: destination path. When None, the file is written next to
            ``input_file`` as "<input without extension>_output<output_extension>".
        output_extension: extension, including the dot, for the derived name.
        file_writer: callable that receives the open UTF-8 text file.

    Raises:
        ValueError: if ``input_file`` or ``file_writer`` is None.
    """
    if input_file is None:
        raise ValueError("input_file is required")
    if file_writer is None:
        raise ValueError("file_writer is required")

    # Write file
    if output_path is None:
        effective_path = os.path.splitext(input_file)[0] + "_output" + output_extension
    else:
        effective_path = output_path

    with open(effective_path, 'w+', encoding="utf-8") as f:
        file_writer(f)

    print(f"Output saved to {effective_path}")
def main():
    """Command-line entry point: diarize an audio file and annotate a transcript.

    Loads the Whisper JSON/SRT transcript, runs the diarization on the audio
    file, prints the speaker turns, then writes the annotated transcript both
    as JSON and as SRT.
    """
    from src.utils import write_srt
    from src.diarization.transcriptLoader import load_transcript

    parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')

    # Positional inputs
    for name, help_text in (
        ('audio_file', 'Input audio file'),
        ('whisper_file', 'Input Whisper JSON/SRT file'),
    ):
        parser.add_argument(name, type=str, help=help_text)

    # Optional string settings
    for flag, help_text in (
        ('--output_json_file', 'Output JSON file (optional)'),
        ('--output_srt_file', 'Output SRT file (optional)'),
        ('--auth_token', 'HuggingFace API Token (optional)'),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument("--max_line_width", type=int, default=40, help="Maximum line width for SRT file (default: 40)")

    # Speaker-count hints
    for flag, help_text in (
        ("--num_speakers", "Number of speakers"),
        ("--min_speakers", "Minimum number of speakers"),
        ("--max_speakers", "Maximum number of speakers"),
    ):
        parser.add_argument(flag, type=int, default=None, help=help_text)

    options = parser.parse_args()

    print(f"\nReading whisper JSON from {options.whisper_file}")

    # Read whisper JSON or SRT file
    transcript = load_transcript(options.whisper_file)

    engine = Diarization(auth_token=options.auth_token)
    speaker_hints = {
        "num_speakers": options.num_speakers,
        "min_speakers": options.min_speakers,
        "max_speakers": options.max_speakers,
    }
    turns = list(engine.run(options.audio_file, **speaker_hints))

    # Print result
    print("Diarization result:")
    for turn in turns:
        print(f" start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{turn.speaker}")

    annotated = engine.mark_speakers(turns, transcript)

    def dump_json(handle):
        json.dump(annotated, handle, indent=4, ensure_ascii=False)

    def dump_srt(handle):
        write_srt(annotated["segments"], handle, maxLineWidth=options.max_line_width)

    # Write output JSON to file
    _write_file(options.whisper_file, options.output_json_file, ".json", dump_json)

    # Write SRT
    _write_file(options.whisper_file, options.output_srt_file, ".srt", dump_srt)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()

    # Disabled snippet, presumably a manual check that the pipeline loads:
    #test = Diarization()
    #print("Initializing")
    #test.initialize()
    #input("Press Enter to continue...")