import gradio as gr
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, HttpUrl
from typing import List, Optional, Dict
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForCTC
import evaluate
import zipfile
from datetime import datetime
import json
import uuid
import os
from pathlib import Path
from huggingface_hub import HfApi
from phone_metrics import PhoneErrorMetrics
# FastAPI application; the Gradio placeholder UI below is mounted onto it at the end of the file
app = FastAPI(title="TIMIT Phoneme Transcription Leaderboard")

# Create Gradio interface
demo = gr.Interface(
    fn=lambda x: x,
    inputs=gr.Textbox(visible=False),
    outputs=gr.Textbox(visible=False),
    title="TIMIT Phoneme Transcription Queue",
    description="API endpoints are available at /api/leaderboard, /api/evaluate, and /api/tasks/{task_id}"
)
# Get absolute path - updated for HF Spaces
CURRENT_DIR = Path(__file__).parent.absolute()

# Constants - updated for the HF Spaces environment
TIMIT_PATH = CURRENT_DIR / ".data" / "TIMIT.zip"  # TIMIT.zip lives under .data/ at the root of the Space
QUEUE_DIR = CURRENT_DIR / "queue"

PATHS = {
    'tasks': QUEUE_DIR / "tasks.json",
    'results': QUEUE_DIR / "results.json",
    'leaderboard': QUEUE_DIR / "leaderboard.json"
}

# Initialize evaluation metric
phone_errors = PhoneErrorMetrics()
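# PhoneErrorMetrics comes from the local phone_metrics module. As used further below, its
# compute() is assumed to take `predictions`, `references`, and `is_normalize_pfer`, and to
# return per-sample 'phone_error_rates' and 'phone_feature_error_rates' lists, e.g.:
#     metrics = phone_errors.compute(predictions=['ʃi'], references=['ʃi'], is_normalize_pfer=True)
#     metrics['phone_error_rates'][0]  # -> 0.0 for an exact match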
class TimitDataManager:
    """Handles all TIMIT dataset operations"""

    # TIMIT to IPA mapping with direct simplifications
    TIMIT_TO_IPA = {
        # Vowels (simplified)
        'aa': 'ɑ',
        'ae': 'æ',
        'ah': 'ʌ',
        'ao': 'ɔ',
        'aw': 'aʊ',
        'ay': 'aɪ',
        'eh': 'ɛ',
        'er': 'ɹ',    # Simplified from 'ɝ'
        'ey': 'eɪ',
        'ih': 'ɪ',
        'ix': 'i',    # Simplified from 'ɨ'
        'iy': 'i',
        'ow': 'oʊ',
        'oy': 'ɔɪ',
        'uh': 'ʊ',
        'uw': 'u',
        'ux': 'u',    # Simplified from 'ʉ'
        'ax': 'ə',
        'ax-h': 'ə',  # Simplified from 'ə̥'
        'axr': 'ɹ',   # Simplified from 'ɚ'

        # Stops (closure/release pairs)
        'b': '',
        'bcl': 'b',
        'd': '',
        'dcl': 'd',
        'g': '',
        'gcl': 'g',
        'p': '',
        'pcl': 'p',
        't': '',
        'tcl': 't',
        'k': '',
        'kcl': 'k',
        'dx': 'ɾ',
        'q': 'ʔ',

        # Affricates and fricatives
        'jh': 'dʒ',
        'ch': 'tʃ',
        's': 's',
        'sh': 'ʃ',
        'z': 'z',
        'zh': 'ʒ',
        'f': 'f',
        'th': 'θ',
        'v': 'v',
        'dh': 'ð',
        'hh': 'h',
        'hv': 'h',    # Simplified from 'ɦ'

        # Nasals (simplified)
        'm': 'm',
        'n': 'n',
        'ng': 'ŋ',
        'em': 'm',    # Simplified from 'm̩'
        'en': 'n',    # Simplified from 'n̩'
        'eng': 'ŋ',   # Simplified from 'ŋ̍'
        'nx': 'ɾ',    # Simplified from 'ɾ̃'

        # Semivowels and glides
        'l': 'l',
        'r': 'ɹ',
        'w': 'w',
        'wh': 'ʍ',
        'y': 'j',
        'el': 'l',    # Simplified from 'l̩'

        # Special
        'epi': '',    # Remove epenthetic silence
        'h#': '',     # Remove start/end silence
        'pau': '',    # Remove pause
    }
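    # Note on the stops above: TIMIT labels the closure ('bcl', 'dcl', ...) and the release
    # ('b', 'd', ...) as separate symbols. The closure carries the IPA stop and the release
    # maps to the empty string, so a 'bcl' 'b' pair collapses to a single 'b' in the output
    # of get_phonemes below.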
    def __init__(self, timit_path: Path):
        self.timit_path = timit_path
        self._zip = None
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path.absolute()}")
        print("TIMIT dataset file exists!")

    @property
    def zip(self):
        """Lazily open the TIMIT zip archive on first access"""
        if not self._zip:
            try:
                self._zip = zipfile.ZipFile(self.timit_path, 'r')
                print("Successfully opened TIMIT zip file")
            except FileNotFoundError:
                raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path}")
        return self._zip
    def get_file_list(self, subset: str) -> List[str]:
        """Get list of WAV files for given subset"""
        files = [f for f in self.zip.namelist()
                 if f.endswith('.WAV') and subset.lower() in f.lower()]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess audio file"""
        with self.zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sample_rate != 16000:
                waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
            waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA"""
        phn_file = filename.replace('.WAV', '.PHN')
        with self.zip.open(phn_file) as f:
            phonemes = []
            for line in f.read().decode('utf-8').splitlines():
                if line.strip():
                    _, _, phone = line.split()
                    phone = self.remove_stress_mark(phone)
                    # Convert to IPA instead of using simplify_timit
                    ipa = self.TIMIT_TO_IPA.get(phone.lower(), '')
                    if ipa:
                        phonemes.append(ipa)
            return ''.join(phonemes)  # Join without spaces for IPA
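    # Illustrative example (hypothetical PHN excerpt, not taken from the dataset here):
    # a file containing the lines
    #     0 3050 h#
    #     3050 4559 sh
    #     4559 5723 ix
    # yields 'ʃi': 'h#' maps to the empty string and is dropped, while 'sh' -> 'ʃ'
    # and 'ix' -> 'i' per TIMIT_TO_IPA above.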
    def simplify_timit(self, phoneme: str) -> str:
        """Apply substitutions to simplify TIMIT phonemes (legacy; superseded by TIMIT_TO_IPA)"""
        # PHONE_SUBSTITUTIONS is not defined on this class, so fall back to the identity mapping
        return getattr(self, 'PHONE_SUBSTITUTIONS', {}).get(phoneme, phoneme)

    def remove_stress_mark(self, text: str) -> str:
        """Removes the combining double inverted breve (͡) from text"""
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace('͡', '')
class ModelManager:
    """Handles model loading and inference"""

    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32  # Number of files processed per inference batch

    def get_model_and_processor(self, model_name: str):
        """Get or load model and processor"""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)

            print(f"Loading model {model_name}...")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)

            self.models[model_name] = model
            self.processors[model_name] = processor

        return self.models[model_name], self.processors[model_name]
    def transcribe(self, audio_list: List[torch.Tensor], model_name: str) -> List[str]:
        """Transcribe a batch of audio using specified model"""
        model, processor = self.get_model_and_processor(model_name)
        if not model or not processor:
            raise Exception("Model and processor not loaded")

        # Process audio in batches
        all_predictions = []
        for i in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[i:i + self.batch_size]

            # Pad sequences within the batch to the longest waveform
            max_length = max(audio.shape[-1] for audio in batch_audio)
            padded_audio = torch.zeros((len(batch_audio), 1, max_length))
            attention_mask = torch.zeros((len(batch_audio), max_length))
            for j, audio in enumerate(batch_audio):
                padded_audio[j, :, :audio.shape[-1]] = audio
                attention_mask[j, :audio.shape[-1]] = 1

            # Process batch
            inputs = processor(
                padded_audio.squeeze(1).numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )
            input_values = inputs.input_values.to(self.device)
            attention_mask = inputs.get("attention_mask", attention_mask).to(self.device)

            with torch.no_grad():
                outputs = model(
                    input_values=input_values,
                    attention_mask=attention_mask
                )
                logits = outputs.logits

            predicted_ids = torch.argmax(logits, dim=-1)
            predictions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            predictions = [pred.replace(' ', '') for pred in predictions]
            all_predictions.extend(predictions)

        return all_predictions
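    # Example usage (sketch; the model name is only a placeholder for any HF CTC phoneme model):
    #     manager = ModelManager()
    #     waveforms = [timit_manager.load_audio(f) for f in timit_manager.get_file_list("test")[:4]]
    #     ipa_strings = manager.transcribe(waveforms, "facebook/wav2vec2-lv-60-espeak-cv-ft")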
class StorageManager:
    """Handles all JSON storage operations"""

    def __init__(self, paths: Dict[str, Path]):
        self.paths = paths
        self._ensure_directories()

    def _ensure_directories(self):
        """Ensure all necessary directories and files exist"""
        for path in self.paths.values():
            path.parent.mkdir(parents=True, exist_ok=True)
            if not path.exists():
                path.write_text('[]')

    def load(self, key: str) -> List:
        """Load JSON file"""
        return json.loads(self.paths[key].read_text())

    def save(self, key: str, data: List):
        """Save data to JSON file"""
        self.paths[key].write_text(json.dumps(data, indent=4, default=str, ensure_ascii=False))

    def update_task(self, task_id: str, updates: Dict):
        """Update specific task with new data"""
        tasks = self.load('tasks')
        for task in tasks:
            if task['id'] == task_id:
                task.update(updates)
                break
        self.save('tasks', tasks)
class EvaluationRequest(BaseModel):
    """Request model for TIMIT evaluation"""
    transcription_model: str
    subset: str = "test"
    max_samples: Optional[int] = None
    submission_name: str
    github_url: Optional[str] = None
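# Example request body for POST /api/evaluate (the model name is just a placeholder):
#     {
#         "transcription_model": "facebook/wav2vec2-lv-60-espeak-cv-ft",
#         "subset": "test",
#         "max_samples": 10,
#         "submission_name": "my-submission",
#         "github_url": "https://github.com/user/repo"
#     }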
async def evaluate_model(task_id: str, request: EvaluationRequest):
    """Background task to evaluate model on TIMIT"""
    try:
        storage_manager.update_task(task_id, {"status": "processing"})

        files = timit_manager.get_file_list(request.subset)
        if request.max_samples:
            files = files[:request.max_samples]

        results = []
        total_per = total_pwed = 0

        # Process files in batches
        batch_size = model_manager.batch_size
        for i in range(0, len(files), batch_size):
            batch_files = files[i:i + batch_size]

            # Load batch audio and ground truth
            batch_audio = []
            batch_ground_truth = []
            for wav_file in batch_files:
                audio = timit_manager.load_audio(wav_file)
                ground_truth = timit_manager.get_phonemes(wav_file)
                batch_audio.append(audio)
                batch_ground_truth.append(ground_truth)

            # Get predictions for batch
            predictions = model_manager.transcribe(batch_audio, request.transcription_model)

            # Calculate metrics for each file in batch
            for wav_file, prediction, ground_truth in zip(batch_files, predictions, batch_ground_truth):
                # Convert Unicode to readable format
                # prediction_str = repr(prediction)[1:-1]  # Remove quotes but keep escaped unicode
                metrics = phone_errors.compute(
                    predictions=[prediction],
                    references=[ground_truth],
                    is_normalize_pfer=True
                )

                per = metrics['phone_error_rates'][0]
                pwed = metrics['phone_feature_error_rates'][0]

                results.append({
                    "file": wav_file,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "per": per,
                    "pwed": pwed
                })
                total_per += per
                total_pwed += pwed

        if not results:
            raise Exception("No files were successfully processed")

        avg_per = total_per / len(results)
        avg_pwed = total_pwed / len(results)

        result = {
            "task_id": task_id,
            "model": request.transcription_model,
            "subset": request.subset,
            "num_files": len(results),
            "average_per": avg_per,
            "average_pwed": avg_pwed,
            "detailed_results": results[:5],
            "timestamp": datetime.now().isoformat()
        }

        # Save results
        print("Saving results...")
        current_results = storage_manager.load('results')
        current_results.append(result)
        storage_manager.save('results', current_results)

        # Update leaderboard
        print("Updating leaderboard...")
        leaderboard = storage_manager.load('leaderboard')
        entry = next((e for e in leaderboard
                      if e["submission_name"] == request.submission_name), None)
        if entry:
            # Simply update with new scores
            entry.update({
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "model": request.transcription_model,
                "subset": request.subset,
                "github_url": request.github_url,
                "submission_date": datetime.now().isoformat()
            })
        else:
            leaderboard.append({
                "submission_id": str(uuid.uuid4()),
                "submission_name": request.submission_name,
                "model": request.transcription_model,
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "subset": request.subset,
                "github_url": request.github_url,
                "submission_date": datetime.now().isoformat()
            })
        storage_manager.save('leaderboard', leaderboard)

        storage_manager.update_task(task_id, {"status": "completed"})
        print("Evaluation completed successfully")

    except Exception as e:
        error_msg = f"Evaluation failed: {str(e)}"
        print(error_msg)
        storage_manager.update_task(task_id, {
            "status": "failed",
            "error": error_msg
        })
# Ensure directories exist, then initialize the shared managers
def init_directories():
    """Ensure all necessary directories and queue files exist"""
    (CURRENT_DIR / ".data").mkdir(parents=True, exist_ok=True)
    QUEUE_DIR.mkdir(parents=True, exist_ok=True)
    for path in PATHS.values():
        if not path.exists():
            path.write_text('[]')


init_directories()
timit_manager = TimitDataManager(TIMIT_PATH)
model_manager = ModelManager()
storage_manager = StorageManager(PATHS)
# Register the API routes on the FastAPI app (paths follow the interface description above)
@app.get("/health")  # health-check path is an assumption; it is not listed in the description
async def health_check():
    """Simple health check endpoint"""
    return {"status": "healthy"}


@app.post("/api/evaluate")
async def submit_evaluation(
    request: EvaluationRequest,
    background_tasks: BackgroundTasks
):
    """Submit new evaluation task"""
    task_id = str(uuid.uuid4())
    task = {
        "id": task_id,
        "model": request.transcription_model,
        "subset": request.subset,
        "submission_name": request.submission_name,
        "github_url": request.github_url,
        "status": "queued",
        "submitted_at": datetime.now().isoformat()
    }

    tasks = storage_manager.load('tasks')
    tasks.append(task)
    storage_manager.save('tasks', tasks)

    background_tasks.add_task(evaluate_model, task_id, request)

    return {
        "message": "Evaluation task submitted successfully",
        "task_id": task_id
    }


@app.get("/api/tasks/{task_id}")
async def get_task(task_id: str):
    """Get specific task status"""
    tasks = storage_manager.load('tasks')
    task = next((t for t in tasks if t["id"] == task_id), None)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    return task


@app.get("/api/leaderboard")
async def get_leaderboard():
    """Get current leaderboard"""
    try:
        leaderboard = storage_manager.load('leaderboard')
        # Sort by feature-weighted error first, then phone error rate (lower is better)
        sorted_leaderboard = sorted(leaderboard, key=lambda x: (x["average_pwed"], x["average_per"]))
        return sorted_leaderboard
    except Exception as e:
        print(f"Error loading leaderboard: {e}")
        return []
# Note: the Gradio app must be mounted onto FastAPI only after all routes are defined
app = gr.mount_gradio_app(app, demo, path="/")

# For local development
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
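# ---------------------------------------------------------------------------
# Example client flow (sketch, not executed by this module). BASE_URL is a
# placeholder; replace it with your Space URL, or http://localhost:7860 when
# running locally. The model name is also only a placeholder.
#
#     import time
#     import requests
#
#     BASE_URL = "http://localhost:7860"
#     resp = requests.post(f"{BASE_URL}/api/evaluate", json={
#         "transcription_model": "facebook/wav2vec2-lv-60-espeak-cv-ft",
#         "subset": "test",
#         "max_samples": 10,
#         "submission_name": "my-submission",
#     })
#     task_id = resp.json()["task_id"]
#
#     while requests.get(f"{BASE_URL}/api/tasks/{task_id}").json()["status"] not in ("completed", "failed"):
#         time.sleep(10)
#
#     print(requests.get(f"{BASE_URL}/api/leaderboard").json())
# ---------------------------------------------------------------------------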