Spaces:
Sleeping
Sleeping
import os | |
import json | |
import nltk | |
import gradio as gr | |
from datetime import datetime | |
from pathlib import Path | |
import shutil | |
# Download the NLTK sentence-tokenizer data used by nltk.sent_tokenize().
# Newer NLTK releases load the 'punkt_tab' resource while older ones load
# 'punkt'; fetch both so tokenization works with either installed version.
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource)
class TTSDatasetCollector:
    """Manages TTS dataset collection and organization.

    Recordings are copied to ``<root>/audio/<speaker_id>/`` and the matching
    transcription files are written to ``<root>/transcriptions/<speaker_id>/``.
    The collector keeps the list of sentences from the loaded text file and
    the index of the sentence currently being recorded.
    """

    def __init__(self, root_path: str = "dataset_root"):
        """Create the collector and its directory tree under *root_path*."""
        self.root_path = Path(root_path)
        # Sentences of the currently loaded text, in reading order.
        self.sentences = []
        # Zero-based index of the sentence currently being recorded.
        self.current_index = 0
        self.setup_directories()

    def setup_directories(self):
        """Create necessary directory structure."""
        for subdir in ('audio', 'transcriptions', 'metadata'):
            (self.root_path / subdir).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _is_safe_component(name) -> bool:
        """Return True if *name* is usable as a single path component."""
        text = str(name)
        if text in ('.', '..'):
            return False
        return not any(ch in text for ch in ('/', '\\', '\0'))

    def load_text_file(self, file):
        """Read a UTF-8 text file and split it into sentences.

        Args:
            file: Either a filesystem path (``str`` / ``os.PathLike``) or an
                object exposing the path as its ``name`` attribute.

        Returns:
            ``(success, message)``. On success the sentence list is replaced
            and the current index is reset to 0.
        """
        try:
            # Accept both a plain path and an upload wrapper with ``.name``.
            path = file if isinstance(file, (str, os.PathLike)) else file.name
            with open(path, 'r', encoding='utf-8') as f:
                text = f.read()
            self.sentences = nltk.sent_tokenize(text)
            self.current_index = 0
        except Exception as e:
            # Boundary with the UI: report the failure instead of raising.
            return False, f"Error loading file: {str(e)}"
        if not self.sentences:
            # An empty file leaves nothing to record; do not report success.
            return False, "No sentences found in file"
        return True, f"Loaded {len(self.sentences)} sentences"

    def generate_filenames(self, dataset_name: str, speaker_id: str) -> tuple:
        """Generate unique filenames for audio and text.

        Returns:
            ``(audio_name, text_name)`` sharing one base name built from the
            dataset name, speaker id, 1-based zero-padded sentence number and
            a ``YYYYmmddHHMMSS`` timestamp.
        """
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        sentence_id = f"{self.current_index+1:04d}"
        base_name = f"{dataset_name}_{speaker_id}_{sentence_id}_{timestamp}"
        return f"{base_name}.wav", f"{base_name}.txt"

    def save_recording(self, audio_file, speaker_id: str, dataset_name: str):
        """Save recording and transcription for the current sentence.

        Args:
            audio_file: Path of the recorded audio file to copy.
            speaker_id: Speaker identifier; used as a directory name.
            dataset_name: Dataset name; used as a filename prefix.

        Returns:
            ``(success, message)``. Nothing is written unless all inputs are
            valid and a sentence is available at the current index.
        """
        if not audio_file or not speaker_id or not dataset_name:
            return False, "Missing required information"
        # Both values end up in filesystem paths; refuse anything that could
        # escape the per-speaker directory.
        if not (self._is_safe_component(speaker_id)
                and self._is_safe_component(dataset_name)):
            return False, "Speaker ID and dataset name must not contain path separators"
        # Check before copying so no orphan audio file is left behind.
        if not 0 <= self.current_index < len(self.sentences):
            return False, "No sentence available; load a text file first"
        try:
            # Generate filenames
            audio_name, text_name = self.generate_filenames(dataset_name, speaker_id)
            # Create speaker directories
            audio_dir = self.root_path / 'audio' / speaker_id
            text_dir = self.root_path / 'transcriptions' / speaker_id
            audio_dir.mkdir(parents=True, exist_ok=True)
            text_dir.mkdir(parents=True, exist_ok=True)
            # Save audio file
            shutil.copy2(audio_file, audio_dir / audio_name)
            # Save transcription
            self.save_transcription(
                text_dir / text_name,
                self.sentences[self.current_index],
                {
                    'speaker_id': speaker_id,
                    'dataset_name': dataset_name,
                    'timestamp': datetime.now().isoformat(),
                    'audio_file': audio_name,
                },
            )
            return True, "Recording saved successfully"
        except Exception as e:
            # Boundary with the UI: report the failure instead of raising.
            return False, f"Error saving recording: {str(e)}"

    def save_transcription(self, file_path: Path, text: str, metadata: dict):
        """Save transcription with metadata.

        Writes a ``[METADATA]`` header (recording id, speaker id, dataset
        name, timestamp) followed by a ``[TEXT]`` section holding *text*.
        *metadata* must contain the keys ``audio_file``, ``speaker_id``,
        ``dataset_name`` and ``timestamp``.
        """
        lines = [
            "[METADATA]",
            f"Recording_ID: {metadata['audio_file']}",
            f"Speaker_ID: {metadata['speaker_id']}",
            f"Dataset_Name: {metadata['dataset_name']}",
            f"Timestamp: {metadata['timestamp']}",
            "[TEXT]",
            f"{text}",
        ]
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(lines) + "\n")
def create_interface():
    """Create Gradio interface for TTS data collection.

    Builds a Blocks layout with a configuration column (text file upload,
    speaker id, dataset name), a recording column (current sentence, audio
    recorder, next sentence), navigation/save buttons and progress/status
    boxes, and wires the buttons to a single ``TTSDatasetCollector``.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    # One collector is created here and captured by every handler below, so
    # all handler calls on this interface read and modify the same state.
    collector = TTSDatasetCollector()

    with gr.Blocks(title="TTS Dataset Collection Tool") as interface:
        gr.Markdown("# TTS Dataset Collection Tool")

        with gr.Row():
            # Left column - Configuration
            with gr.Column():
                file_input = gr.File(
                    label="Upload Text File (.txt)",
                    file_types=[".txt"],
                )
                speaker_id = gr.Textbox(
                    label="Speaker ID",
                    placeholder="Enter unique speaker identifier",
                )
                dataset_name = gr.Textbox(
                    label="Dataset Name",
                    placeholder="Enter dataset name",
                )
            # Right column - Recording
            with gr.Column():
                current_text = gr.Textbox(
                    label="Current Sentence",
                    interactive=False,
                )
                audio_recorder = gr.Audio(
                    label="Record Audio",
                    type="filepath",
                )
                next_text = gr.Textbox(
                    label="Next Sentence",
                    interactive=False,
                )

        # Controls
        with gr.Row():
            prev_btn = gr.Button("Previous")
            next_btn = gr.Button("Next")
            save_btn = gr.Button("Save Recording", variant="primary")

        # Status
        with gr.Row():
            progress = gr.Textbox(
                label="Progress",
                interactive=False,
            )
            status = gr.Textbox(
                label="Status",
                interactive=False,
            )

        # Event handlers
        def blank_display(message):
            """Return an update that clears the text boxes and shows *message*."""
            return {
                current_text: "",
                next_text: "",
                progress: "",
                status: message,
            }

        def update_display(message="Ready for recording"):
            """Update interface display from the collector's current state."""
            total = len(collector.sentences)
            if not total:
                return blank_display("No text loaded")
            idx = collector.current_index
            upcoming = collector.sentences[idx + 1] if idx + 1 < total else ""
            return {
                current_text: collector.sentences[idx],
                next_text: upcoming,
                progress: f"Sentence {idx + 1} of {total}",
                status: message,
            }

        def load_file(file):
            """Load the uploaded text file and show its first sentence."""
            if not file:
                return blank_display("No file selected")
            success, msg = collector.load_text_file(file)
            if not success:
                return blank_display(msg)
            if not collector.sentences:
                # A file that tokenizes to nothing must not be indexed.
                return blank_display("No sentences found in file")
            return update_display(msg)

        def next_sentence():
            """Move to next sentence."""
            if collector.sentences and collector.current_index < len(collector.sentences) - 1:
                collector.current_index += 1
            return update_display()

        def prev_sentence():
            """Move to previous sentence."""
            if collector.sentences and collector.current_index > 0:
                collector.current_index -= 1
            return update_display()

        def save_recording(audio, spk_id, ds_name):
            """Handle saving recording; report the outcome in the status box."""
            if not audio:
                return {status: "No audio recorded"}
            if not spk_id:
                return {status: "Speaker ID required"}
            if not ds_name:
                return {status: "Dataset name required"}
            _success, msg = collector.save_recording(audio, spk_id, ds_name)
            return {status: msg}

        # Connect event handlers
        display_outputs = [current_text, next_text, progress, status]
        file_input.change(
            load_file,
            inputs=[file_input],
            outputs=display_outputs,
        )
        next_btn.click(
            next_sentence,
            outputs=display_outputs,
        )
        prev_btn.click(
            prev_sentence,
            outputs=display_outputs,
        )
        save_btn.click(
            save_recording,
            inputs=[audio_recorder, speaker_id, dataset_name],
            outputs=[status],
        )

    return interface
if __name__ == "__main__":
    # Script entry point: build the UI, then launch it with the server
    # address, port and share flag collected below.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    interface = create_interface()
    interface.launch(**launch_options)