import gradio as gr
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field, HttpUrl
from typing import List, Optional, Dict
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForCTC
import evaluate
import zipfile
from datetime import datetime
import json
import uuid
import os
from pathlib import Path
import os

hf_token = os.environ.get("HF_TOKEN")

app = FastAPI(title="TIMIT Phoneme Transcription Leaderboard")

# Minimal (hidden) Gradio UI; the functionality is exposed via the FastAPI routes.
demo = gr.Interface(
    fn=lambda x: x,
    inputs=gr.Textbox(visible=False),
    outputs=gr.Textbox(visible=False),
    title="TIMIT Phoneme Transcription Queue",
    description="API endpoints are available at /api/leaderboard, /api/evaluate, and /api/tasks/{task_id}"
)

# Absolute directory containing this file (the Space root on HF Spaces).
CURRENT_DIR = Path(__file__).parent.absolute()

# Constants
# TIMIT.zip is expected in the ".data" folder next to this file.
TIMIT_PATH = CURRENT_DIR / ".data" / "TIMIT.zip"
QUEUE_DIR = CURRENT_DIR / "queue"
# JSON stores managed by StorageManager, keyed by the names used in load()/save().
PATHS = {
    'tasks': QUEUE_DIR / "tasks.json",
    'results': QUEUE_DIR / "results.json",
    'leaderboard': QUEUE_DIR / "leaderboard.json"
}

# Phone error rate / phone feature error rate metric, downloaded from the HF Hub
# with the Space's token.
phone_errors = evaluate.load("ginic/phone_errors", download_config=evaluate.DownloadConfig(token=hf_token))


class TimitDataManager:
    """Read TIMIT audio and phone transcriptions directly from the zip archive."""

    # TIMIT phone label -> IPA string used for the reference transcriptions.
    # Several labels are deliberately collapsed (see the "Simplified from" notes),
    # and every label mapped to '' is dropped from the reference.
    TIMIT_TO_IPA = {
        # Vowels (simplified)
        'aa': 'ɑ', 'ae': 'æ', 'ah': 'ʌ', 'ao': 'ɔ',
        'aw': 'aʊ', 'ay': 'aɪ', 'eh': 'ɛ',
        'er': 'ɹ',  # Simplified from 'ɝ'
        'ey': 'eɪ', 'ih': 'ɪ',
        'ix': 'i',  # Simplified from 'ɨ'
        'iy': 'i', 'ow': 'oʊ', 'oy': 'ɔɪ', 'uh': 'ʊ', 'uw': 'u',
        'ux': 'u',  # Simplified from 'ʉ'
        'ax': 'ə',
        'ax-h': 'ə',  # Simplified from 'ə̥'
        'axr': 'ɹ',  # Simplified from 'ɚ'

        # Consonants
        # Stops: the closure label ('bcl', 'dcl', ...) carries the phone and the
        # release label maps to '', so a closure+release pair yields one phone.
        # NOTE(review): a release that appears without its closure label is
        # therefore dropped entirely - confirm this is intended.
        'b': '', 'bcl': 'b',
        'd': '', 'dcl': 'd',
        'g': '', 'gcl': 'g',
        'p': '', 'pcl': 'p',
        't': '', 'tcl': 't',
        'k': '', 'kcl': 'k',
        'dx': 'ɾ', 'q': 'ʔ',

        # Fricatives
        'jh': 'dʒ', 'ch': 'tʃ', 's': 's', 'sh': 'ʃ', 'z': 'z', 'zh': 'ʒ',
        'f': 'f', 'th': 'θ', 'v': 'v', 'dh': 'ð',
        'hh': 'h',
        'hv': 'h',  # Simplified from 'ɦ'

        # Nasals (simplified)
        'm': 'm', 'n': 'n', 'ng': 'ŋ',
        'em': 'm',  # Simplified from 'm̩'
        'en': 'n',  # Simplified from 'n̩'
        'eng': 'ŋ',  # Simplified from 'ŋ̍'
        'nx': 'ɾ',  # Simplified from 'ɾ̃'

        # Semivowels and Glides
        'l': 'l', 'r': 'ɹ', 'w': 'w', 'wh': 'ʍ', 'y': 'j',
        'el': 'l',  # Simplified from 'l̩'

        # Special
        'epi': '',  # Remove epenthetic silence
        'h#': '',  # Remove start/end silence
        'pau': '',  # Remove pause
    }

    # Optional TIMIT -> TIMIT label substitutions used by simplify_timit().
    # Empty by default, so simplify_timit() returns labels unchanged;
    # get_phonemes() maps labels through TIMIT_TO_IPA instead.
    PHONE_SUBSTITUTIONS: Dict[str, str] = {}

    def __init__(self, timit_path: Path):
        """Remember the archive location; raise FileNotFoundError if it is missing."""
        self.timit_path = timit_path
        self._zip = None  # opened lazily by the `zip` property
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")

        if not self.timit_path.exists():
            raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path.absolute()}")
        print("TIMIT dataset file exists!")

    @property
    def zip(self):
        """The TIMIT archive, opened read-only on first access and then reused."""
        if self._zip is None:
            try:
                self._zip = zipfile.ZipFile(self.timit_path, 'r')
                print("Successfully opened TIMIT zip file")
            except FileNotFoundError as err:
                raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path}") from err
        return self._zip

    def get_file_list(self, subset: str) -> List[str]:
        """Return archive members ending in '.WAV' whose path contains *subset* (case-insensitive)."""
        files = [f for f in self.zip.namelist()
                 if f.endswith('.WAV') and subset.lower() in f.lower()]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load *filename* from the archive as a mono, 16 kHz, normalised tensor of shape (1, T)."""
        with self.zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)

        # Downmix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # ModelManager.transcribe feeds the processor at 16 kHz.
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

        # Zero-mean / unit-variance; the epsilon avoids dividing by zero on silence.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Return the IPA reference for *filename*, read from the matching .PHN file.

        Each non-blank .PHN line is split into three fields and the third (the
        phone label) is mapped through TIMIT_TO_IPA. Labels that map to '' or are
        unknown are skipped, and the IPA symbols are joined without separators.
        """
        phn_file = filename.replace('.WAV', '.PHN')
        with self.zip.open(phn_file) as f:
            phonemes = []
            for line in f.read().decode('utf-8').splitlines():
                if line.strip():
                    _, _, phone = line.split()
                    phone = self.remove_stress_mark(phone)
                    ipa = self.TIMIT_TO_IPA.get(phone.lower(), '')
                    if ipa:
                        phonemes.append(ipa)
            return ''.join(phonemes)  # Join without spaces for IPA

    def simplify_timit(self, phoneme: str) -> str:
        """Apply PHONE_SUBSTITUTIONS to a TIMIT label; unknown labels pass through."""
        return self.PHONE_SUBSTITUTIONS.get(phoneme, phoneme)

    def remove_stress_mark(self, text: str) -> str:
        """Remove the combining double inverted breve / tie bar (U+0361, '͡') from *text*.

        Raises:
            TypeError: if *text* is not a str.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace('͡', '')


class ModelManager:
    """Load CTC models/processors on demand (cached by name) and run batched inference."""

    def __init__(self):
        self.models = {}  # model_name -> loaded model (kept for the process lifetime)
        self.processors = {}  # model_name -> processor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32  # Added batch size parameter

    def get_model_and_processor(self, model_name: str):
        """Return (model, processor) for *model_name*, loading and caching them on first use."""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)

            print(f"Loading model... {model_name}")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)

            self.models[model_name] = model
            self.processors[model_name] = processor

        return self.models[model_name], self.processors[model_name]

    def transcribe(self, audio_list: List[torch.Tensor], model_name: str) -> List[str]:
        """Transcribe 16 kHz waveforms with *model_name*; one string per input, in order.

        Each tensor is flattened to 1-D before feature extraction (load_audio
        yields shape (1, T)). Decoded strings have spaces removed so they match
        the unspaced IPA references built by TimitDataManager.get_phonemes.
        """
        model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise Exception("Model and processor not loaded")

        all_predictions = []
        for start in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[start:start + self.batch_size]

            # Give the processor variable-length 1-D arrays and let it do the
            # padding. Pre-padding into a rectangular array makes it report an
            # all-ones attention_mask, so the model would attend to the zero
            # padding of the shorter clips in the batch.
            arrays = [audio.reshape(-1).cpu().numpy() for audio in batch_audio]
            inputs = processor(
                arrays,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )
            input_values = inputs.input_values.to(self.device)

            attention_mask = inputs.get("attention_mask")
            if attention_mask is None:
                # Processor emits no mask: derive one from the true clip lengths.
                lengths = torch.tensor([arr.shape[0] for arr in arrays])
                positions = torch.arange(inputs.input_values.shape[-1])
                attention_mask = (positions[None, :] < lengths[:, None]).long()
            attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                outputs = model(
                    input_values=input_values,
                    attention_mask=attention_mask
                )

            # Greedy CTC decoding.
            predicted_ids = torch.argmax(outputs.logits, dim=-1)
            predictions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            all_predictions.extend(pred.replace(' ', '') for pred in predictions)

        return all_predictions


class StorageManager:
    """Persist JSON lists (tasks, results, leaderboard) in files named by *paths*."""

    def __init__(self, paths: Dict[str, Path]):
        self.paths = paths
        self._ensure_directories()

    def _ensure_directories(self):
        """Create parent directories and seed missing stores with an empty list."""
        for path in self.paths.values():
            path.parent.mkdir(parents=True, exist_ok=True)
            if not path.exists():
                path.write_text('[]', encoding='utf-8')

    def load(self, key: str) -> List:
        """Load and return the JSON document stored under *key*."""
        return json.loads(self.paths[key].read_text(encoding='utf-8'))

    def save(self, key: str, data: List):
        """Write *data* to the store for *key*, replacing it atomically.

        The JSON is written to a temporary sibling file which then replaces the
        target, so a crash mid-write cannot leave a truncated store behind.
        UTF-8 is explicit because ensure_ascii=False keeps IPA characters
        literal, which a non-UTF-8 locale encoding could not write.
        """
        path = self.paths[key]
        payload = json.dumps(data, indent=4, default=str, ensure_ascii=False)
        tmp_path = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
        try:
            tmp_path.write_text(payload, encoding='utf-8')
            tmp_path.replace(path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    def update_task(self, task_id: str, updates: Dict):
        """Merge *updates* into the first stored task whose 'id' equals *task_id*."""
        tasks = self.load('tasks')
        for task in tasks:
            if task['id'] == task_id:
                task.update(updates)
                break
        self.save('tasks', tasks)


class EvaluationRequest(BaseModel):
    """Request model for TIMIT evaluation"""
    transcription_model: str  # model id passed to from_pretrained
    subset: str = "test"  # matched case-insensitively as a substring of archive paths
    max_samples: Optional[int] = Field(default=None, ge=0)  # None or 0 = all files
    submission_name: str  # leaderboard key; the same name updates the existing entry
    github_url: Optional[str] = None
import asyncio
import concurrent.futures
import threading
import traceback


def init_directories():
    """Ensure all necessary directories exist"""
    (CURRENT_DIR / ".data").mkdir(parents=True, exist_ok=True)
    QUEUE_DIR.mkdir(parents=True, exist_ok=True)
    for path in PATHS.values():
        if not path.exists():
            path.write_text('[]')


init_directories()

# One shared set of managers, created once, after the directories exist.
timit_manager = TimitDataManager(TIMIT_PATH)
model_manager = ModelManager()
storage_manager = StorageManager(PATHS)

# Guards every read / read-modify-write of the JSON stores. Evaluations run on a
# worker thread while the route handlers run on the event loop, so both sides
# take this lock before touching the files.
_storage_lock = threading.Lock()

# A single worker thread runs evaluations one at a time, in submission order,
# so the event loop stays free to answer the other endpoints and models are
# never loaded or run concurrently.
_evaluation_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="timit-eval"
)


def _update_task(task_id: str, updates: Dict) -> None:
    """Merge *updates* into the stored task *task_id* while holding the storage lock."""
    with _storage_lock:
        storage_manager.update_task(task_id, updates)


def _score_files(request: EvaluationRequest) -> List[Dict]:
    """Transcribe the requested TIMIT files and score each one against its reference.

    Returns one dict per file with keys file, ground_truth, prediction, per, pwed.

    Raises:
        RuntimeError: if the model or the metric returns a different number of
            items than were submitted (instead of silently dropping files).
    """
    files = timit_manager.get_file_list(request.subset)
    if request.max_samples:
        files = files[:request.max_samples]

    results = []
    batch_size = model_manager.batch_size
    for start in range(0, len(files), batch_size):
        batch_files = files[start:start + batch_size]
        batch_audio = [timit_manager.load_audio(wav_file) for wav_file in batch_files]
        batch_ground_truth = [timit_manager.get_phonemes(wav_file) for wav_file in batch_files]

        predictions = model_manager.transcribe(batch_audio, request.transcription_model)
        if len(predictions) != len(batch_files):
            raise RuntimeError(
                f"Model returned {len(predictions)} predictions for {len(batch_files)} files"
            )

        # One metric call per batch; the metric returns per-pair score lists.
        metrics = phone_errors.compute(
            predictions=predictions,
            references=batch_ground_truth,
            is_normalize_pfer=True
        )
        pers = metrics['phone_error_rates']
        pweds = metrics['phone_feature_error_rates']
        if not len(pers) == len(pweds) == len(batch_files):
            raise RuntimeError("phone_errors returned a different number of scores than inputs")

        for wav_file, prediction, ground_truth, per, pwed in zip(
                batch_files, predictions, batch_ground_truth, pers, pweds):
            results.append({
                "file": wav_file,
                "ground_truth": ground_truth,
                "prediction": prediction,
                "per": per,
                "pwed": pwed
            })
    return results


def _record_results(task_id: str, request: EvaluationRequest, results: List[Dict]) -> None:
    """Append the run summary to the results store and upsert the leaderboard entry.

    The leaderboard is keyed by submission_name: an existing entry is updated
    in place (keeping its submission_id), otherwise a new entry is appended.
    """
    avg_per = sum(r["per"] for r in results) / len(results)
    avg_pwed = sum(r["pwed"] for r in results) / len(results)

    result = {
        "task_id": task_id,
        "model": request.transcription_model,
        "subset": request.subset,
        "num_files": len(results),
        "average_per": avg_per,
        "average_pwed": avg_pwed,
        "detailed_results": results[:5],
        "timestamp": datetime.now().isoformat()
    }

    with _storage_lock:
        print("Saving results...")
        current_results = storage_manager.load('results')
        current_results.append(result)
        storage_manager.save('results', current_results)

        print("Updating leaderboard...")
        leaderboard = storage_manager.load('leaderboard')
        entry = next((e for e in leaderboard if e["submission_name"] == request.submission_name), None)

        if entry:
            entry.update({
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "model": request.transcription_model,
                "subset": request.subset,
                "github_url": request.github_url,
                "submission_date": datetime.now().isoformat()
            })
        else:
            leaderboard.append({
                "submission_id": str(uuid.uuid4()),
                "submission_name": request.submission_name,
                "model": request.transcription_model,
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "subset": request.subset,
                "github_url": request.github_url,
                "submission_date": datetime.now().isoformat()
            })

        storage_manager.save('leaderboard', leaderboard)


def _run_evaluation(task_id: str, request: EvaluationRequest) -> None:
    """Blocking evaluation body; records 'processing' -> 'completed' / 'failed' on the task."""
    try:
        _update_task(task_id, {"status": "processing"})

        results = _score_files(request)
        if not results:
            raise RuntimeError("No files were successfully processed")

        _record_results(task_id, request, results)

        _update_task(task_id, {"status": "completed"})
        print("Evaluation completed successfully")

    except Exception as e:
        # Top-level boundary of the background job: report the failure on the
        # task instead of letting it disappear into the server log.
        error_msg = f"Evaluation failed: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        _update_task(task_id, {
            "status": "failed",
            "error": error_msg
        })


async def evaluate_model(task_id: str, request: EvaluationRequest):
    """Background task to evaluate model on TIMIT.

    Still a coroutine so it can be passed to BackgroundTasks unchanged, but the
    CPU/GPU-bound work runs on the single-thread evaluation executor rather than
    blocking the event loop. Queued tasks keep status "queued" until the worker
    picks them up.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(_evaluation_executor, _run_evaluation, task_id, request)


@app.get("/api/health")
async def health_check():
    """Simple health check endpoint"""
    return {"status": "healthy"}


@app.post("/api/evaluate")
async def submit_evaluation(
    request: EvaluationRequest,
    background_tasks: BackgroundTasks
):
    """Submit new evaluation task"""
    task_id = str(uuid.uuid4())

    task = {
        "id": task_id,
        "model": request.transcription_model,
        "subset": request.subset,
        "submission_name": request.submission_name,
        "github_url": request.github_url,
        "status": "queued",
        "submitted_at": datetime.now().isoformat()
    }

    # Append under the lock so a concurrent task-status update from the
    # evaluation worker cannot overwrite this insertion.
    with _storage_lock:
        tasks = storage_manager.load('tasks')
        tasks.append(task)
        storage_manager.save('tasks', tasks)

    background_tasks.add_task(evaluate_model, task_id, request)

    return {
        "message": "Evaluation task submitted successfully",
        "task_id": task_id
    }


@app.get("/api/tasks/{task_id}")
async def get_task(task_id: str):
    """Get specific task status"""
    with _storage_lock:
        tasks = storage_manager.load('tasks')
    task = next((t for t in tasks if t["id"] == task_id), None)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    return task


@app.get("/api/leaderboard")
async def get_leaderboard():
    """Get current leaderboard"""
    # Sorted ascending by average PER, ties broken by average PWED. Any error
    # while loading/sorting is logged and an empty leaderboard is returned.
    try:
        with _storage_lock:
            leaderboard = storage_manager.load('leaderboard')
        return sorted(leaderboard, key=lambda x: (x["average_per"], x["average_pwed"]))
    except Exception as e:
        print(f"Error loading leaderboard: {e}")
        return []


# Note: We need to mount the FastAPI app after defining all routes
app = gr.mount_gradio_app(app, demo, path="/")

# For local development
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)