# Spaces:
# Sleeping
# Sleeping
""" | |
TTS Dataset Collection Tool with Custom Fonts and Enhanced Features | |
""" | |
import html
import json
import logging
import os
import re
import shutil
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import nltk
import soundfile as sf
# Configure logging first: the NLTK bootstrap below reports failures through
# `logger`, so the logger must exist before that code can run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Download required NLTK data during initialization. Any failure (no network,
# resource still missing after the download) is logged and tolerated.
try:
    nltk.download('punkt')  # Download punkt tokenizer data
    nltk.data.find('tokenizers/punkt')
except Exception as e:
    logger.warning(f"Error downloading NLTK data: {str(e)}")
    logger.warning("NLTK tokenization might not work properly")
# Built-in font presets, keyed by the identifier used in the font selector.
# Each preset carries a display name, the CSS font family, and a ready-to-use
# inline CSS declaration.
FONT_STYLES = dict(
    english_serif=dict(
        name="Times New Roman",
        family="Times New Roman",
        css="font-family: 'Times New Roman', serif;",
    ),
    english_sans=dict(
        name="Arial",
        family="Arial",
        css="font-family: Arial, sans-serif;",
    ),
    nastaliq=dict(
        name="Nastaliq",
        family="Noto Nastaliq Urdu",
        css="font-family: 'Noto Nastaliq Urdu', serif;",
    ),
    naskh=dict(
        name="Naskh",
        family="Scheherazade New",
        css="font-family: 'Scheherazade New', serif;",
    ),
)
class TTSDatasetCollector:
    """Manages TTS dataset collection and organization with enhanced features.

    Data is stored under ``<script dir>/dataset`` (or ``<cwd>/dataset`` when
    ``__file__`` is unavailable) in the ``audio``, ``transcriptions``,
    ``metadata`` and ``fonts`` subdirectories. The instance keeps the loaded
    sentences, the index of the sentence currently shown, the selected font
    key, fonts uploaded in this session and the recordings saved in this
    session (keyed by sentence index).
    """

    # Fallback sentence boundary: whitespace following a Latin or Urdu
    # sentence terminator. The terminator stays attached to its sentence.
    _SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?۔؟])\s+')

    def __init__(self):
        """Initialize the TTS Dataset Collector and its directory tree."""
        # Handle both script and notebook environments for root path
        try:
            # When running as a script
            self.root_path = Path(os.path.dirname(os.path.abspath(__file__))) / "dataset"
        except NameError:
            # When running in Jupyter/IPython
            self.root_path = Path.cwd() / "dataset"
        self.fonts_path = self.root_path / "fonts"
        self.sentences = []
        self.current_index = 0
        self.current_font = "english_serif"
        self.custom_fonts = {}
        self.recordings = {}  # Store recordings by sentence index
        self.setup_directories()
        # Ensure NLTK data is downloaded
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        logger.info("TTS Dataset Collector initialized")

    @staticmethod
    def _normalize_id(value) -> Optional[str]:
        """Return ``value`` stripped if it is a non-empty alphanumeric string, else None.

        Speaker and dataset identifiers become directory names, so anything
        that is not purely alphanumeric (path separators, dots, ...) is rejected.
        """
        if not isinstance(value, str):
            return None
        stripped = value.strip()
        return stripped if stripped.isalnum() else None

    @classmethod
    def _simple_split_sentences(cls, text: str) -> list:
        """Split ``text`` into sentences without NLTK.

        Each non-empty line is split at whitespace that follows a sentence
        terminator; terminators are kept so transcriptions match the source.
        """
        sentences = []
        for line in text.split('\n'):
            line = line.strip()
            if not line:
                continue
            sentences.extend(
                part.strip()
                for part in cls._SENTENCE_BOUNDARY.split(line)
                if part.strip()
            )
        return sentences

    def setup_directories(self) -> None:
        """Create necessary directory structure with logging.

        Raises:
            RuntimeError: if any directory or the log file cannot be created.
        """
        try:
            # Create main dataset directory
            self.root_path.mkdir(parents=True, exist_ok=True)
            # Create subdirectories
            for subdir in ['audio', 'transcriptions', 'metadata', 'fonts']:
                (self.root_path / subdir).mkdir(parents=True, exist_ok=True)
            # Initialize log file
            log_file = self.root_path / 'dataset_log.txt'
            if not log_file.exists():
                with open(log_file, 'w', encoding='utf-8') as f:
                    f.write(f"Dataset collection initialized on {datetime.now().isoformat()}\n")
            logger.info("Directory structure created successfully")
        except Exception as e:
            logger.error(f"Failed to create directory structure: {str(e)}")
            logger.error(traceback.format_exc())
            raise RuntimeError("Failed to initialize directory structure") from e

    def log_operation(self, message: str, level: str = "info") -> None:
        """Append ``message`` to the dataset log file and mirror it to the logger.

        Failures to write the log file are reported through the logger only.
        """
        try:
            log_file = self.root_path / 'dataset_log.txt'
            timestamp = datetime.now().isoformat()
            with open(log_file, 'a', encoding='utf-8') as f:
                f.write(f"[{timestamp}] [{level.upper()}] {message}\n")
            if level.lower() == "error":
                logger.error(message)
            else:
                logger.info(message)
        except Exception as e:
            logger.error(f"Failed to log operation: {str(e)}")

    def process_text(self, text: str) -> Tuple[bool, str]:
        """Split ``text`` into sentences and make them the current sentence list.

        NLTK's sentence tokenizer is tried first; if it raises, the simple
        regex splitter is used instead.

        Returns:
            ``(success, message)``.
        """
        try:
            if not text or not text.strip():
                return False, "Text is empty"
            cleaned = text.strip()
            try:
                # Try NLTK first
                sentences = nltk.sent_tokenize(cleaned)
            except Exception as e:
                logger.warning(f"NLTK tokenization failed, falling back to simple splitting: {str(e)}")
                # Fallback to simple splitting
                sentences = self._simple_split_sentences(cleaned)
            self.sentences = [s.strip() for s in sentences if s.strip()]
            self.current_index = 0
            if not self.sentences:
                return False, "No valid sentences found in text"
            # Log success
            self.log_operation(f"Processed text with {len(self.sentences)} sentences")
            return True, f"Successfully loaded {len(self.sentences)} sentences"
        except Exception as e:
            error_msg = f"Error processing text: {str(e)}"
            self.log_operation(error_msg, "error")
            logger.error(traceback.format_exc())
            return False, error_msg

    def load_text_file(self, file) -> Tuple[bool, str]:
        """Load sentences from an uploaded ``.txt`` file.

        ``file`` may be a filesystem path or an object exposing ``name`` and,
        optionally, ``read()``; both bytes and str payloads are accepted.

        Returns:
            ``(success, message)``.
        """
        if not file:
            return False, "No file provided"
        try:
            if isinstance(file, (str, os.PathLike)):
                file_name = os.fspath(file)
            else:
                file_name = str(getattr(file, 'name', ''))
            # Validate file extension
            if not file_name.lower().endswith('.txt'):
                return False, "Only .txt files are supported"
            if hasattr(file, 'read'):
                raw = file.read()
            else:
                with open(file_name, 'rb') as f:
                    raw = f.read()
            # utf-8-sig also accepts plain UTF-8 and drops a leading BOM
            text = raw.decode('utf-8-sig') if isinstance(raw, (bytes, bytearray)) else raw
            return self.process_text(text)
        except UnicodeDecodeError:
            error_msg = "File encoding error. Please ensure the file is UTF-8 encoded"
            self.log_operation(error_msg, "error")
            return False, error_msg
        except Exception as e:
            error_msg = f"Error loading file: {str(e)}"
            self.log_operation(error_msg, "error")
            logger.error(traceback.format_exc())
            return False, error_msg

    def get_styled_text(self, text: str) -> str:
        """Wrap ``text`` in a ``<div>`` styled with the current font.

        The text and the CSS are HTML-escaped and the attribute is
        double-quoted, because the font CSS itself contains single quotes.
        """
        font_css = FONT_STYLES.get(self.current_font, {}).get('css', '')
        return f'<div style="{html.escape(font_css)}">{html.escape(text)}</div>'

    def set_font(self, font_style: str) -> Tuple[bool, str]:
        """Select ``font_style`` as the current font if it is a known key."""
        if font_style not in FONT_STYLES and font_style not in self.custom_fonts:
            available_fonts = ', '.join(list(FONT_STYLES.keys()) + list(self.custom_fonts.keys()))
            return False, f"Invalid font style. Available styles: {available_fonts}"
        self.current_font = font_style
        return True, f"Font style set to {font_style}"

    def add_custom_font(self, font_file_path) -> Tuple[bool, str]:
        """Copy an uploaded TTF file into the fonts directory and register it.

        The font is registered under a generated family name in both
        ``self.custom_fonts`` and the module-level ``FONT_STYLES``.

        Returns:
            ``(success, message)``.
        """
        try:
            if not font_file_path:
                return False, "No font file provided"
            font_file_path = os.fspath(font_file_path)
            if not font_file_path.lower().endswith('.ttf'):
                return False, "Only .ttf font files are supported"
            # Generate a unique font family name
            font_family = f"font_{uuid.uuid4().hex[:8]}"
            font_dest = self.fonts_path / (font_family + '.ttf')
            # Copy the font file into the dataset's fonts directory
            shutil.copyfile(font_file_path, font_dest)
            # Add to custom fonts
            self.custom_fonts[font_family] = {
                'name': os.path.basename(font_file_path),
                'family': font_family,
                'css': f"font-family: '{font_family}', serif;"
            }
            # Update the FONT_STYLES with the custom font
            FONT_STYLES[font_family] = self.custom_fonts[font_family]
            # Log success
            self.log_operation(f"Added custom font: {font_file_path} as {font_family}")
            return True, f"Custom font '{os.path.basename(font_file_path)}' added successfully"
        except Exception as e:
            error_msg = f"Error adding custom font: {str(e)}"
            self.log_operation(error_msg, "error")
            logger.error(traceback.format_exc())
            return False, error_msg

    def generate_filenames(self, dataset_name: str, speaker_id: str, sentence_text: str) -> Tuple[str, str]:
        """Generate the audio and transcription filenames for the current sentence.

        Returns:
            ``(audio_name, text_name)`` sharing one base name built from the
            sanitized dataset name, speaker ID, 1-based line number, a sentence
            excerpt and a second-resolution timestamp.
        """
        line_number = self.current_index + 1
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

        # Sanitize strings for filenames
        def sanitize_filename(s):
            return re.sub(r'[^a-zA-Z0-9_-]', '_', s)[:50]

        dataset_name_safe = sanitize_filename(dataset_name)
        speaker_id_safe = sanitize_filename(speaker_id)
        sentence_excerpt = sanitize_filename(sentence_text[:20])
        base_name = f"{dataset_name_safe}_{speaker_id_safe}_line{line_number}_{sentence_excerpt}_{timestamp}"
        return f"{base_name}.wav", f"{base_name}.txt"

    def save_recording(self, audio_file, speaker_id: str, dataset_name: str) -> Tuple[bool, str, Dict]:
        """Save the recording for the current sentence together with its transcription.

        Args:
            audio_file: path of the recorded audio, readable by ``soundfile``.
            speaker_id: alphanumeric speaker identifier.
            dataset_name: alphanumeric dataset name.

        Returns:
            ``(success, message, recordings)``; ``recordings`` is empty when
            the request is rejected before anything is written.
        """
        # Strip up front so validation and the directory names below agree
        if isinstance(speaker_id, str):
            speaker_id = speaker_id.strip()
        if isinstance(dataset_name, str):
            dataset_name = dataset_name.strip()

        required = (
            (audio_file, "audio recording"),
            (speaker_id, "speaker ID"),
            (dataset_name, "dataset name"),
        )
        missing = [label for value, label in required if not value]
        if missing:
            return False, f"Missing required information: {', '.join(missing)}", {}
        # Check if sentences have been loaded
        if not self.sentences:
            return False, "No sentences have been loaded. Please load text before saving recordings.", {}
        if self.current_index >= len(self.sentences):
            return False, "Current sentence index is out of range.", {}
        try:
            # Validate inputs
            if self._normalize_id(speaker_id) is None:
                return False, "Speaker ID must contain only letters and numbers", {}
            if self._normalize_id(dataset_name) is None:
                return False, "Dataset name must contain only letters and numbers", {}
            # Get current sentence text
            sentence_text = self.sentences[self.current_index]
            # Generate filenames
            audio_name, text_name = self.generate_filenames(dataset_name, speaker_id, sentence_text)
            # Create speaker directories
            audio_dir = self.root_path / 'audio' / speaker_id
            text_dir = self.root_path / 'transcriptions' / speaker_id
            audio_dir.mkdir(parents=True, exist_ok=True)
            text_dir.mkdir(parents=True, exist_ok=True)
            # Read the recorded audio and write it to the dataset as WAV
            audio_path = audio_dir / audio_name
            audio_data, sampling_rate = sf.read(audio_file)
            sf.write(str(audio_path), audio_data, sampling_rate)
            # Save transcription
            text_path = text_dir / text_name
            self.save_transcription(
                text_path,
                sentence_text,
                {
                    'speaker_id': speaker_id,
                    'dataset_name': dataset_name,
                    'timestamp': datetime.now().isoformat(),
                    'audio_file': audio_name,
                    'font_style': self.current_font
                }
            )
            # Update metadata
            self.update_metadata(speaker_id, dataset_name)
            # Store the recording
            self.recordings[self.current_index] = {
                'audio_file': audio_file,
                'speaker_id': speaker_id,
                'dataset_name': dataset_name,
                'sentence': sentence_text
            }
            # Log success
            self.log_operation(
                f"Saved recording: Speaker={speaker_id}, Dataset={dataset_name}, "
                f"Audio={audio_name}, Text={text_name}"
            )
            return True, f"Recording saved successfully as {audio_name}", self.recordings
        except Exception as e:
            error_msg = f"Error saving recording: {str(e)}"
            self.log_operation(error_msg, "error")
            logger.error(traceback.format_exc())
            return False, error_msg, self.recordings

    def save_transcription(self, file_path: Path, text: str, metadata: Dict) -> None:
        """Write ``text`` to ``file_path`` preceded by a ``[METADATA]`` header."""
        lines = [
            "[METADATA]",
            f"Recording_ID: {metadata['audio_file']}",
            f"Speaker_ID: {metadata['speaker_id']}",
            f"Dataset_Name: {metadata['dataset_name']}",
            f"Timestamp: {metadata['timestamp']}",
            f"Font_Style: {metadata['font_style']}",
            "[TEXT]",
            f"{text}",
        ]
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(lines) + "\n")

    def update_metadata(self, speaker_id: str, dataset_name: str) -> None:
        """Record one more recording for ``speaker_id``/``dataset_name`` in dataset_info.json.

        Errors are logged and swallowed so a metadata problem does not fail
        the save that triggered the update.
        """
        metadata_file = self.root_path / 'metadata' / 'dataset_info.json'
        try:
            if metadata_file.exists():
                with open(metadata_file, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
            else:
                metadata = {'speakers': {}, 'last_updated': None}
            now = datetime.now().isoformat()
            # Update speaker data
            speaker = metadata['speakers'].setdefault(speaker_id, {
                'total_recordings': 0,
                'datasets': {}
            })
            dataset = speaker['datasets'].setdefault(dataset_name, {
                'recordings': 0,
                'sentences': len(self.sentences),
                'recorded_sentences': [],
                'first_recording': now,
                'last_recording': None,
                'font_styles_used': []
            })
            # Update counts and timestamps
            speaker['total_recordings'] += 1
            dataset['recordings'] += 1
            dataset['last_recording'] = now
            # Add current index to recorded sentences
            if self.current_index not in dataset['recorded_sentences']:
                dataset['recorded_sentences'].append(self.current_index)
            # Update font styles
            if self.current_font not in dataset['font_styles_used']:
                dataset['font_styles_used'].append(self.current_font)
            metadata['last_updated'] = now
            # Save updated metadata
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)
            self.log_operation(f"Updated metadata for {speaker_id} in {dataset_name}")
        except Exception as e:
            error_msg = f"Error updating metadata: {str(e)}"
            self.log_operation(error_msg, "error")
            logger.error(traceback.format_exc())

    def get_navigation_info(self) -> Dict[str, Optional[str]]:
        """Return styled HTML for the current and next sentence plus a progress label."""
        if not self.sentences:
            return {
                'current': None,
                'next': None,
                'progress': "No text loaded"
            }
        current = self.get_styled_text(self.sentences[self.current_index])
        next_text = None
        if self.current_index < len(self.sentences) - 1:
            next_text = self.get_styled_text(self.sentences[self.current_index + 1])
        progress = f"Sentence {self.current_index + 1} of {len(self.sentences)}"
        return {
            'current': current,
            'next': next_text,
            'progress': progress
        }

    def navigate(self, direction: str) -> Dict[str, Optional[str]]:
        """Move to the "next" or "prev" sentence, staying within bounds.

        Returns the navigation info with an extra ``status`` entry.
        """
        if not self.sentences:
            return {
                'current': None,
                'next': None,
                'progress': "No text loaded",
                'status': "⚠️ Please load a text file first"
            }
        if direction == "next" and self.current_index < len(self.sentences) - 1:
            self.current_index += 1
        elif direction == "prev" and self.current_index > 0:
            self.current_index -= 1
        nav_info = self.get_navigation_info()
        nav_info['status'] = "✅ Navigation successful"
        return nav_info

    def get_dataset_statistics(self) -> Dict:
        """Return summary statistics read from dataset_info.json.

        "Recorded Sentences" is summed over every speaker and dataset in the
        metadata file, so it can exceed the number of sentences currently
        loaded; "Remaining Sentences" is therefore clamped at zero. Returns an
        empty dict when the file is missing or unreadable.
        """
        try:
            metadata_file = self.root_path / 'metadata' / 'dataset_info.json'
            if not metadata_file.exists():
                return {}
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            # Flatten statistics for display
            total_sentences = len(self.sentences)
            recorded = sum(
                len(dataset.get('recorded_sentences', []))
                for speaker in metadata['speakers'].values()
                for dataset in speaker['datasets'].values()
            )
            return {
                "Total Sentences": total_sentences,
                "Recorded Sentences": recorded,
                "Remaining Sentences": max(0, total_sentences - recorded),
                "Last Updated": metadata.get('last_updated', 'N/A')
            }
        except Exception as e:
            logger.error(f"Error reading dataset statistics: {str(e)}")
            return {}

    def _latest_file(self, category: str, speaker_id: str, pattern: str) -> Optional[str]:
        """Return the most recently modified file matching ``pattern`` for a speaker.

        Returns None for an invalid speaker ID, a missing directory or no match.
        """
        clean_id = self._normalize_id(speaker_id)
        if clean_id is None:
            return None
        directory = self.root_path / category / clean_id
        if not directory.is_dir():
            return None
        candidates = list(directory.glob(pattern))
        if not candidates:
            return None
        return str(max(candidates, key=lambda f: f.stat().st_mtime))

    def get_last_audio_path(self, speaker_id: str) -> Optional[str]:
        """Get the path to the last saved audio file for downloading."""
        return self._latest_file('audio', speaker_id, '*.wav')

    def get_last_transcript_path(self, speaker_id: str) -> Optional[str]:
        """Get the path to the last saved transcription file for downloading."""
        return self._latest_file('transcriptions', speaker_id, '*.txt')

    def create_zip_archive(self, speaker_id: str) -> Optional[str]:
        """Create a ZIP archive of all recordings and transcriptions for a speaker.

        Returns the archive path, or None when the speaker ID is not
        alphanumeric or the archive cannot be written.
        """
        try:
            from zipfile import ZipFile
            import tempfile

            # The ID becomes part of filesystem paths; reject anything unsafe
            clean_id = self._normalize_id(speaker_id)
            if clean_id is None:
                logger.error("Error creating zip archive: invalid speaker ID")
                return None
            # Create temporary zip file
            temp_dir = Path(tempfile.gettempdir())
            zip_path = temp_dir / f"{clean_id}_recordings.zip"
            with ZipFile(zip_path, 'w') as zipf:
                # Add audio files
                audio_dir = self.root_path / 'audio' / clean_id
                if audio_dir.exists():
                    for audio_file in audio_dir.glob('*.wav'):
                        zipf.write(audio_file, f"audio/{audio_file.name}")
                # Add transcription files
                text_dir = self.root_path / 'transcriptions' / clean_id
                if text_dir.exists():
                    for text_file in text_dir.glob('*.txt'):
                        zipf.write(text_file, f"transcriptions/{text_file.name}")
            return str(zip_path)
        except Exception as e:
            logger.error(f"Error creating zip archive: {str(e)}")
            return None
def create_interface():
    """Create Gradio interface with enhanced features.

    Builds a ``gr.Blocks`` app around a single ``TTSDatasetCollector``: a
    configuration column (text input, speaker/dataset IDs, font selection and
    statistics), a recording column (sentence display, recorder, navigation
    and download widgets) and a list of the recordings saved in this session.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    collector = TTSDatasetCollector()
    # Create custom CSS for fonts
    custom_css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .record-button {
        font-size: 1em !important;
        padding: 10px !important;
    }
    .sentence-display {
        font-size: 1.4em !important;
        padding: 15px !important;
        border: 1px solid #ddd !important;
        border-radius: 8px !important;
        margin: 10px 0 !important;
        min-height: 100px !important;
    }
    .small-input {
        max-width: 300px !important;
    }
    """
    # Include Google Fonts for Nastaliq and Naskh
    google_fonts_css = """
    @import url('https://fonts.googleapis.com/earlyaccess/notonastaliqurdu.css');
    @import url('https://fonts.googleapis.com/css2?family=Scheherazade+New&display=swap');
    """
    custom_css = google_fonts_css + custom_css

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a copy of this file that had lost its indentation;
    # confirm the placement of the download button and the recordings column.
    with gr.Blocks(title="TTS Dataset Collection Tool", css=custom_css) as interface:
        gr.Markdown("# TTS Dataset Collection Tool")
        status = gr.Textbox(
            label="Status",
            interactive=False,
            max_lines=3,
            elem_classes=["small-input"]
        )
        with gr.Row():
            # Left column - Configuration and Input
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Paste Text",
                    placeholder="Paste your text here...",
                    lines=5,
                    elem_classes=["small-input"],
                    interactive=True
                )
                file_input = gr.File(
                    label="Or Upload Text File (.txt)",
                    file_types=[".txt"],
                    elem_classes=["small-input"]
                )
                speaker_id = gr.Textbox(
                    label="Speaker ID",
                    placeholder="Enter unique speaker identifier (letters and numbers only)",
                    elem_classes=["small-input"]
                )
                dataset_name = gr.Textbox(
                    label="Dataset Name",
                    placeholder="Enter dataset name (letters and numbers only)",
                    elem_classes=["small-input"]
                )
                font_select = gr.Dropdown(
                    choices=list(FONT_STYLES.keys()),
                    value="english_serif",
                    label="Select Font Style",
                    elem_classes=["small-input"]
                )
                # Custom font upload
                with gr.Accordion("Custom Font Upload", open=False):
                    font_file_input = gr.File(
                        label="Upload Custom Font (.ttf)",
                        file_types=[".ttf"],
                        elem_classes=["small-input"],
                        type="filepath"
                    )
                    add_font_btn = gr.Button("Add Custom Font")
                # Dataset Info
                with gr.Accordion("Dataset Statistics", open=False):
                    dataset_info = gr.JSON(
                        label="",
                        value={}
                    )
            # Right column - Recording
            with gr.Column(scale=2):
                current_text = gr.HTML(
                    label="Current Sentence",
                    elem_classes=["sentence-display"]
                )
                next_text = gr.HTML(
                    label="Next Sentence",
                    elem_classes=["sentence-display"]
                )
                progress = gr.HTML("")
                with gr.Row():
                    audio_recorder = gr.Audio(
                        label="Record Audio",
                        type="filepath",
                        elem_classes=["record-button"],
                        interactive=True,
                        streaming=False  # Disable streaming to prevent freezing
                    )
                    clear_btn = gr.Button("Clear Recording", variant="secondary")
                # Controls
                with gr.Row():
                    prev_btn = gr.Button("Previous", variant="secondary")
                    save_btn = gr.Button("Save Recording", variant="primary")
                    next_btn = gr.Button("Next", variant="primary")
                # Download Links
                with gr.Row():
                    download_audio = gr.File(label="Download Last Audio", interactive=False)
                    download_transcript = gr.File(label="Download Last Transcript", interactive=False)
                    download_all = gr.File(label="Download All Recordings", interactive=False)
                download_all_btn = gr.Button("Download All Recordings", variant="secondary")
            # Add recordings display
            with gr.Column(scale=2):
                recordings_display = gr.HTML(
                    label="Saved Recordings",
                    value="<div id='recordings-list'></div>"
                )

        # ---- Helpers shared by the event handlers ----

        def render_progress(nav_info):
            """Return the progress HTML; the bar is omitted while no text is loaded."""
            if not collector.sentences:
                return nav_info['progress']
            return (
                f"<progress value='{collector.current_index + 1}' "
                f"max='{len(collector.sentences)}'></progress> {nav_info['progress']}"
            )

        def cleared_text_state(message):
            """Blank the sentence display and report ``message`` in the status box."""
            return {
                current_text: "",
                next_text: "",
                progress: "",
                status: message,
                dataset_info: collector.get_dataset_statistics()
            }

        def loaded_text_state(message):
            """Show the current sentence pair after a successful load."""
            nav_info = collector.get_navigation_info()
            return {
                current_text: nav_info['current'],
                next_text: nav_info['next'],
                progress: render_progress(nav_info),
                status: message,
                dataset_info: collector.get_dataset_statistics()
            }

        def create_recordings_display(recordings):
            """Create HTML display for recordings.

            Sentence text and audio paths are HTML-escaped because they come
            from user input.
            """
            parts = ["<div id='recordings-list'><h3>Saved Recordings:</h3>"]
            for idx, rec in recordings.items():
                parts.append(f"""
                <div style='margin: 10px 0; padding: 10px; border: 1px solid #ddd; border-radius: 5px;'>
                <p><strong>Sentence {idx + 1}:</strong> {html.escape(str(rec['sentence']))}</p>
                <audio controls src='{html.escape(str(rec['audio_file']))}'></audio>
                </div>
                """)
            parts.append("</div>")
            return "".join(parts)

        def current_recordings_html():
            """Render this session's recordings, or a placeholder when there are none."""
            if collector.recordings:
                return create_recordings_display(collector.recordings)
            return "<div id='recordings-list'>No recordings yet</div>"

        # ---- Event handlers ----

        def download_all_recordings(speaker_id_value):
            """Handle downloading all recordings for a speaker"""
            if not speaker_id_value:
                return {
                    status: "⚠️ Please enter a Speaker ID first",
                    download_all: None
                }
            zip_path = collector.create_zip_archive(speaker_id_value)
            if zip_path:
                return {
                    status: "✅ Archive created successfully",
                    download_all: zip_path
                }
            return {
                status: "❌ Failed to create archive",
                download_all: None
            }

        def process_pasted_text(text):
            """Handle pasted text input"""
            if not text:
                return cleared_text_state("⚠️ No text provided")
            success, msg = collector.process_text(text)
            if not success:
                return cleared_text_state(f"❌ {msg}")
            return loaded_text_state(f"✅ {msg}")

        def update_font(font_style):
            """Update font and refresh display"""
            success, msg = collector.set_font(font_style)
            if not success:
                return {status: msg}
            nav_info = collector.get_navigation_info()
            return {
                current_text: nav_info['current'],
                next_text: nav_info['next'],
                status: f"Font updated to {font_style}"
            }

        def load_file(file):
            """Handle file loading with enhanced error reporting"""
            if not file:
                return cleared_text_state("⚠️ No file selected")
            success, msg = collector.load_text_file(file)
            if not success:
                return cleared_text_state(f"❌ {msg}")
            return loaded_text_state(f"✅ {msg}")

        def save_current_recording(audio_file, speaker_id_value, dataset_name_value):
            """Handle saving the current recording"""
            if not audio_file:
                return {
                    status: "⚠️ Please record audio first",
                    download_audio: None,
                    download_transcript: None,
                    download_all: None,
                    # Keep earlier recordings visible instead of wiping the list
                    recordings_display: current_recordings_html(),
                    audio_recorder: None  # Clear the recorder
                }
            success, msg, recordings = collector.save_recording(
                audio_file, speaker_id_value, dataset_name_value
            )
            if not success:
                return {
                    status: f"❌ {msg}",
                    dataset_info: collector.get_dataset_statistics(),
                    download_audio: None,
                    download_transcript: None,
                    download_all: None,
                    recordings_display: current_recordings_html()
                }
            # Get paths to the saved files
            audio_path = collector.get_last_audio_path(speaker_id_value)
            transcript_path = collector.get_last_transcript_path(speaker_id_value)
            zip_path = collector.create_zip_archive(speaker_id_value)
            # Auto-advance to next sentence after successful save
            nav_info = collector.navigate("next")
            return {
                current_text: nav_info['current'],
                next_text: nav_info['next'],
                progress: render_progress(nav_info),
                status: f"✅ {msg}",
                dataset_info: collector.get_dataset_statistics(),
                download_audio: audio_path,
                download_transcript: transcript_path,
                download_all: zip_path,
                recordings_display: create_recordings_display(recordings),
                audio_recorder: None  # Clear the recorder after successful save
            }

        def navigate_sentences(direction):
            """Handle navigation between sentences"""
            nav_info = collector.navigate(direction)
            return {
                current_text: nav_info['current'],
                next_text: nav_info['next'],
                progress: render_progress(nav_info),
                status: nav_info['status']
            }

        def add_custom_font(font_file_path):
            """Handle adding a custom font"""
            if not font_file_path:
                return {
                    font_select: gr.update(),
                    status: "⚠️ No font file selected"
                }
            success, msg = collector.add_custom_font(font_file_path)
            if not success:
                return {
                    font_select: gr.update(),
                    status: f"❌ {msg}"
                }
            # Custom fonts may already be registered in FONT_STYLES, so merge
            # the two key sets without listing any font twice.
            font_choices = list(dict.fromkeys([*FONT_STYLES, *collector.custom_fonts]))
            return {
                font_select: gr.update(choices=font_choices),
                status: f"✅ {msg}"
            }

        def clear_recording():
            """Clear the current recording"""
            return {
                audio_recorder: None,
                status: "Recording cleared"
            }

        # ---- Event wiring ----
        download_all_btn.click(
            download_all_recordings,
            inputs=[speaker_id],
            outputs=[status, download_all]
        )
        clear_btn.click(
            clear_recording,
            outputs=[audio_recorder, status]
        )
        text_input.change(
            process_pasted_text,
            inputs=[text_input],
            outputs=[current_text, next_text, progress, status, dataset_info]
        )
        file_input.upload(
            load_file,
            inputs=[file_input],
            outputs=[current_text, next_text, progress, status, dataset_info]
        )
        font_select.change(
            update_font,
            inputs=[font_select],
            outputs=[current_text, next_text, status]
        )
        add_font_btn.click(
            add_custom_font,
            inputs=[font_file_input],
            outputs=[font_select, status]
        )
        save_btn.click(
            save_current_recording,
            inputs=[audio_recorder, speaker_id, dataset_name],
            outputs=[current_text, next_text, progress, status, dataset_info,
                     download_audio, download_transcript, download_all, recordings_display,
                     audio_recorder]  # audio_recorder is an output so it can be cleared
        )
        prev_btn.click(
            lambda: navigate_sentences("prev"),
            outputs=[current_text, next_text, progress, status]
        )
        next_btn.click(
            lambda: navigate_sentences("next"),
            outputs=[current_text, next_text, progress, status]
        )
        # Initialize dataset info
        dataset_info.value = collector.get_dataset_statistics()
    return interface
if __name__ == "__main__": | |
try: | |
# Set up any required environment variables | |
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0" | |
os.environ["GRADIO_SERVER_PORT"] = "7860" | |
# Create and launch the interface | |
interface = create_interface() | |
interface.queue() # Enable queuing for better handling of concurrent users | |
interface.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=True, | |
debug=True, | |
show_error=True | |
) | |
except Exception as e: | |
logger.error(f"Failed to launch interface: {str(e)}") | |
logger.error(traceback.format_exc()) | |
raise | |