# Hugging Face file-view header captured with this source:
# author: arunasrivastava | commit ef871e4 ("req updatesgradio") | size 17.5 kB
import gradio as gr
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, HttpUrl
from typing import List, Optional, Dict
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForCTC
import evaluate
import zipfile
from datetime import datetime
import json
import uuid
import os
from pathlib import Path
from huggingface_hub import HfApi
import evaluate
# ASGI application that serves the JSON API; the Gradio UI is mounted onto it
# at the bottom of this file.
app = FastAPI(title="TIMIT Phoneme Transcription Leaderboard")

# Placeholder Gradio page: an identity function wired to hidden textboxes, so the
# page only shows the title and a pointer to the REST endpoints.
demo = gr.Interface(
    fn=lambda value: value,
    inputs=gr.Textbox(visible=False),
    outputs=gr.Textbox(visible=False),
    title="TIMIT Phoneme Transcription Queue",
    description=(
        "API endpoints are available at /api/leaderboard, /api/evaluate, "
        "and /api/tasks/{task_id}"
    ),
)
# Directory that contains this script; every data path below hangs off it.
CURRENT_DIR = Path(__file__).parent.absolute()

# The TIMIT archive is expected at <script dir>/.data/TIMIT.zip.
TIMIT_PATH = CURRENT_DIR / ".data" / "TIMIT.zip"

# JSON-backed stores for the task queue, the results log and the leaderboard.
QUEUE_DIR = CURRENT_DIR / "queue"
PATHS = dict(
    tasks=QUEUE_DIR / "tasks.json",
    results=QUEUE_DIR / "results.json",
    leaderboard=QUEUE_DIR / "leaderboard.json",
)

# Metric that yields per-utterance phone error rates and phone-feature error rates
# (read as 'phone_error_rates' / 'phone_feature_error_rates' in evaluate_model).
phone_errors = evaluate.load("ginic/phone_errors")
class TimitDataManager:
    """Read audio and phone transcriptions for TIMIT utterances stored in a zip archive."""

    # TIMIT phone label -> simplified IPA string. An empty string means "drop this
    # label": stop releases map to '' while their closures (e.g. 'bcl') carry the
    # stop symbol, and the silence/pause labels are removed.
    TIMIT_TO_IPA = {
        # Vowels (simplified)
        'aa': 'ɑ',
        'ae': 'æ',
        'ah': 'ʌ',
        'ao': 'ɔ',
        'aw': 'aʊ',
        'ay': 'aɪ',
        'eh': 'ɛ',
        'er': 'ɹ',  # Simplified from 'ɝ'
        'ey': 'eɪ',
        'ih': 'ɪ',
        'ix': 'i',  # Simplified from 'ɨ'
        'iy': 'i',
        'ow': 'oʊ',
        'oy': 'ɔɪ',
        'uh': 'ʊ',
        'uw': 'u',
        'ux': 'u',  # Simplified from 'ʉ'
        'ax': 'ə',
        'ax-h': 'ə',  # Simplified from 'ə̥'
        'axr': 'ɹ',  # Simplified from 'ɚ'
        # Consonants
        'b': '',
        'bcl': 'b',
        'd': '',
        'dcl': 'd',
        'g': '',
        'gcl': 'g',
        'p': '',
        'pcl': 'p',
        't': '',
        'tcl': 't',
        'k': '',
        'kcl': 'k',
        'dx': 'ɾ',
        'q': 'ʔ',
        # Fricatives
        'jh': 'dʒ',
        'ch': 'tʃ',
        's': 's',
        'sh': 'ʃ',
        'z': 'z',
        'zh': 'ʒ',
        'f': 'f',
        'th': 'θ',
        'v': 'v',
        'dh': 'ð',
        'hh': 'h',
        'hv': 'h',  # Simplified from 'ɦ'
        # Nasals (simplified)
        'm': 'm',
        'n': 'n',
        'ng': 'ŋ',
        'em': 'm',  # Simplified from 'm̩'
        'en': 'n',  # Simplified from 'n̩'
        'eng': 'ŋ',  # Simplified from 'ŋ̍'
        'nx': 'ɾ',  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        'l': 'l',
        'r': 'ɹ',
        'w': 'w',
        'wh': 'ʍ',
        'y': 'j',
        'el': 'l',  # Simplified from 'l̩'
        # Special
        'epi': '',  # Remove epenthetic silence
        'h#': '',  # Remove start/end silence
        'pau': '',  # Remove pause
    }

    def __init__(self, timit_path: Path):
        """Remember the archive location and fail fast if it does not exist.

        Raises:
            FileNotFoundError: if ``timit_path`` does not exist.
        """
        self.timit_path = timit_path
        self._zip = None  # opened lazily by the ``zip`` property
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path.absolute()}")
        print("TIMIT dataset file exists!")

    @property
    def zip(self):
        """The archive as a read-only ``zipfile.ZipFile``, opened on first access and cached."""
        if self._zip is None:
            try:
                self._zip = zipfile.ZipFile(self.timit_path, 'r')
                print("Successfully opened TIMIT zip file")
            except FileNotFoundError as err:
                raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path}") from err
        return self._zip

    def get_file_list(self, subset: str) -> List[str]:
        """Return archive members ending in '.WAV' whose path contains ``subset`` (case-insensitive)."""
        needle = subset.lower()
        files = [f for f in self.zip.namelist()
                 if f.endswith('.WAV') and needle in f.lower()]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load one archive member as a mono, 16 kHz, zero-mean / unit-variance waveform.

        Returns:
            A tensor shaped (1, num_samples).
        """
        with self.zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)
        if waveform.shape[0] > 1:
            # Down-mix multi-channel audio by averaging the channels.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        # Epsilon guards against division by zero on silent (constant) audio.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Return the utterance's reference transcription as a space-free IPA string.

        Reads the '.PHN' file next to ``filename``. Stress/tie marks are stripped from
        each label, which is then mapped through ``TIMIT_TO_IPA``. Labels that are
        unknown or that map to '' are dropped.
        """
        # Swap only the extension; replace() would also rewrite '.WAV' inside directory names.
        if filename.endswith('.WAV'):
            phn_file = filename[:-len('.WAV')] + '.PHN'
        else:
            phn_file = filename.replace('.WAV', '.PHN')
        with self.zip.open(phn_file) as f:
            phonemes = []
            for line in f.read().decode('utf-8').splitlines():
                if not line.strip():
                    continue
                # PHN lines are "<start_sample> <end_sample> <label>".
                _, _, phone = line.split()
                phone = self.remove_stress_mark(phone)
                ipa = self.TIMIT_TO_IPA.get(phone.lower(), '')
                if ipa:
                    phonemes.append(ipa)
        return ''.join(phonemes)  # IPA symbols are joined without separators

    def simplify_timit(self, phoneme: str) -> str:
        """Map one TIMIT label to its simplified IPA string.

        Lookup is case-insensitive, matching ``get_phonemes``. Unlike ``get_phonemes``,
        unknown labels are returned unchanged rather than dropped.
        """
        return self.TIMIT_TO_IPA.get(phoneme.lower(), phoneme)

    def remove_stress_mark(self, text: str) -> str:
        """Remove the combining double inverted breve / tie bar (U+0361) from ``text``.

        Raises:
            TypeError: if ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace('\u0361', '')
class ModelManager:
    """Load Hugging Face CTC models on demand and run batched greedy transcription."""

    def __init__(self):
        # Loaded models / processors, keyed by the model name passed to from_pretrained.
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Number of utterances per forward pass.
        self.batch_size = 32

    def get_model_and_processor(self, model_name: str):
        """Return ``(model, processor)`` for ``model_name``, loading and caching on first use.

        A newly loaded model is moved to ``self.device``.
        """
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)
            print(f"Loading model... {model_name}")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)
            self.models[model_name] = model
            self.processors[model_name] = processor
        return self.models[model_name], self.processors[model_name]

    def transcribe(self, audio_list: List[torch.Tensor], model_name: str) -> List[str]:
        """Greedy (argmax) CTC transcription of ``audio_list`` with ``model_name``.

        Args:
            audio_list: 16 kHz waveforms, each shaped (1, T) or (T,). Input with more
                than one channel is averaged to mono.
            model_name: model identifier accepted by ``from_pretrained``.

        Returns:
            One decoded string per input, in order, with all spaces removed.

        Raises:
            RuntimeError: if the model or processor is missing.
        """
        model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise RuntimeError("Model and processor not loaded")

        all_predictions = []
        for start in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[start:start + self.batch_size]

            # Give the processor each utterance at its true length and let it pad.
            # Pre-padding here made any processor-built attention_mask cover the
            # zero padding, and per-utterance normalization then included those zeros.
            arrays = [
                torch.atleast_2d(audio).to(torch.float32).mean(dim=0).detach().cpu().numpy()
                for audio in batch_audio
            ]
            lengths = [len(array) for array in arrays]
            inputs = processor(
                arrays,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )
            input_values = inputs.input_values.to(self.device)

            attention_mask = inputs.get("attention_mask")
            if attention_mask is None:
                # Processor returned no mask: mark each utterance's real samples ourselves.
                attention_mask = torch.zeros(
                    (len(arrays), input_values.shape[-1]), dtype=torch.long
                )
                for row, length in enumerate(lengths):
                    attention_mask[row, :length] = 1
            attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                logits = model(
                    input_values=input_values,
                    attention_mask=attention_mask
                ).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            predictions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            # References are space-free IPA strings, so strip the tokenizer's word spaces.
            all_predictions.extend(pred.replace(' ', '') for pred in predictions)
        return all_predictions
class StorageManager:
    """Persist JSON lists in files, one file per logical key (e.g. 'tasks')."""

    def __init__(self, paths: Dict[str, Path]):
        self.paths = paths
        self._ensure_directories()

    def _ensure_directories(self):
        """Create each file's parent directory and seed missing files with an empty JSON list."""
        for path in self.paths.values():
            path.parent.mkdir(parents=True, exist_ok=True)
            if not path.exists():
                path.write_text('[]', encoding='utf-8')

    def load(self, key: str) -> List:
        """Return the decoded JSON stored under ``key``.

        Raises:
            KeyError: if ``key`` is not a configured store.
        """
        # Explicit UTF-8: save() writes non-ASCII (IPA) text verbatim.
        return json.loads(self.paths[key].read_text(encoding='utf-8'))

    def save(self, key: str, data: List):
        """Serialize ``data`` and atomically replace the file stored under ``key``.

        Values that json cannot encode natively are stringified (``default=str``).
        """
        path = self.paths[key]
        payload = json.dumps(data, indent=4, default=str, ensure_ascii=False)
        # Write a sibling temp file and rename it over the target. A crash mid-write
        # then cannot leave a truncated file that would break every later load().
        tmp_path = path.with_name(path.name + '.tmp')
        tmp_path.write_text(payload, encoding='utf-8')
        tmp_path.replace(path)

    def update_task(self, task_id: str, updates: Dict):
        """Merge ``updates`` into the first task whose 'id' equals ``task_id``, then save.

        An unknown ``task_id`` leaves the tasks unchanged (they are still re-saved).
        """
        tasks = self.load('tasks')
        for task in tasks:
            if task['id'] == task_id:
                task.update(updates)
                break
        self.save('tasks', tasks)
class EvaluationRequest(BaseModel):
    """Request body for POST /api/evaluate: one TIMIT evaluation job."""
    # Model identifier passed to AutoProcessor / AutoModelForCTC.from_pretrained.
    transcription_model: str
    # Matched case-insensitively as a substring of archive paths when selecting WAV files.
    subset: str = "test"
    # Optional cap on the number of files evaluated; None evaluates every matching file.
    max_samples: Optional[int] = None
    # Leaderboard key: re-submitting an existing name overwrites that entry's scores.
    submission_name: str
    # Stored verbatim on the task and leaderboard entry; not validated as a URL here.
    github_url: Optional[str] = None
# The shared managers (timit_manager, model_manager, storage_manager) are built once,
# further down, right after init_directories() has prepared the on-disk layout.
# Serializes evaluations so only one model run uses the device at a time.
# Created lazily inside the running event loop (see evaluate_model).
_EVALUATION_LOCK = None


def _score_files(files: List[str], model_name: str) -> List[Dict]:
    """Transcribe ``files`` with ``model_name`` and score each against its TIMIT reference.

    This is blocking work (archive reads, model inference, metric computation).
    ``evaluate_model`` runs it in a worker thread so the event loop stays responsive.

    Returns:
        One dict per file, in input order, with keys ``file``, ``ground_truth``,
        ``prediction``, ``per`` and ``pwed``.
    """
    results = []
    batch_size = model_manager.batch_size
    for start in range(0, len(files), batch_size):
        batch_files = files[start:start + batch_size]

        batch_audio = []
        batch_ground_truth = []
        for wav_file in batch_files:
            batch_audio.append(timit_manager.load_audio(wav_file))
            batch_ground_truth.append(timit_manager.get_phonemes(wav_file))

        predictions = model_manager.transcribe(batch_audio, model_name)

        for wav_file, prediction, ground_truth in zip(batch_files, predictions, batch_ground_truth):
            metrics = phone_errors.compute(
                predictions=[prediction],
                references=[ground_truth],
                is_normalize_pfer=True
            )
            results.append({
                "file": wav_file,
                "ground_truth": ground_truth,
                "prediction": prediction,
                "per": metrics['phone_error_rates'][0],
                "pwed": metrics['phone_feature_error_rates'][0]
            })
    return results


def _update_leaderboard(request: EvaluationRequest, avg_per: float, avg_pwed: float) -> None:
    """Insert a leaderboard entry for ``request.submission_name`` or overwrite its scores."""
    leaderboard = storage_manager.load('leaderboard')
    entry = next(
        (e for e in leaderboard if e["submission_name"] == request.submission_name),
        None
    )
    if entry is not None:
        # Existing submission: keep its submission_id, refresh everything else.
        entry.update({
            "average_per": avg_per,
            "average_pwed": avg_pwed,
            "model": request.transcription_model,
            "subset": request.subset,
            "github_url": request.github_url,
            "submission_date": datetime.now().isoformat()
        })
    else:
        leaderboard.append({
            "submission_id": str(uuid.uuid4()),
            "submission_name": request.submission_name,
            "model": request.transcription_model,
            "average_per": avg_per,
            "average_pwed": avg_pwed,
            "subset": request.subset,
            "github_url": request.github_url,
            "submission_date": datetime.now().isoformat()
        })
    storage_manager.save('leaderboard', leaderboard)


async def evaluate_model(task_id: str, request: EvaluationRequest):
    """Background task: evaluate ``request.transcription_model`` on a TIMIT subset.

    Moves the task's status to "processing", then to "completed" or to "failed"
    (with an "error" message). On success it appends a summary (averages plus the
    first five per-file results) to the results store and upserts the leaderboard.

    The transcription and scoring run in a worker thread. FastAPI's BackgroundTasks
    awaits this coroutine on the event loop, so running the blocking work inline
    would stall every other request (health checks, task polling) until it ended.
    Storage reads and writes stay on the event-loop thread, like the endpoints'.
    """
    global _EVALUATION_LOCK
    import asyncio  # local import: only this coroutine needs it

    if _EVALUATION_LOCK is None:
        _EVALUATION_LOCK = asyncio.Lock()

    # Later tasks wait here (their status stays "queued") until the current run ends.
    async with _EVALUATION_LOCK:
        try:
            storage_manager.update_task(task_id, {"status": "processing"})

            files = timit_manager.get_file_list(request.subset)
            # Only a positive cap limits the run. None/0 still mean "all files", and a
            # negative value no longer slices files off the end of the list.
            if request.max_samples is not None and request.max_samples > 0:
                files = files[:request.max_samples]

            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(
                None, _score_files, files, request.transcription_model
            )
            if not results:
                raise Exception("No files were successfully processed")

            # Unweighted mean over files: every utterance counts once.
            avg_per = sum(r["per"] for r in results) / len(results)
            avg_pwed = sum(r["pwed"] for r in results) / len(results)

            result = {
                "task_id": task_id,
                "model": request.transcription_model,
                "subset": request.subset,
                "num_files": len(results),
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "detailed_results": results[:5],
                "timestamp": datetime.now().isoformat()
            }

            print("Saving results...")
            current_results = storage_manager.load('results')
            current_results.append(result)
            storage_manager.save('results', current_results)

            print("Updating leaderboard...")
            _update_leaderboard(request, avg_per, avg_pwed)

            storage_manager.update_task(task_id, {"status": "completed"})
            print("Evaluation completed successfully")
        except Exception as e:
            # Top-level boundary of a background job: record the failure on the task.
            error_msg = f"Evaluation failed: {str(e)}"
            print(error_msg)
            storage_manager.update_task(task_id, {
                "status": "failed",
                "error": error_msg
            })
def init_directories():
    """Create the data and queue directories and seed each missing JSON store with '[]'."""
    for directory in (CURRENT_DIR / ".data", QUEUE_DIR):
        directory.mkdir(parents=True, exist_ok=True)
    for store in PATHS.values():
        if store.exists():
            continue
        store.write_text('[]')
# Prepare the on-disk layout first, then build the shared managers the endpoints use.
init_directories()
timit_manager, model_manager, storage_manager = (
    TimitDataManager(TIMIT_PATH),
    ModelManager(),
    StorageManager(PATHS),
)
@app.get("/api/health")
async def health_check():
    """Liveness probe; unconditionally reports the service as healthy."""
    return dict(status="healthy")
@app.post("/api/evaluate")
async def submit_evaluation(
    request: EvaluationRequest,
    background_tasks: BackgroundTasks
):
    """Record a new evaluation task as "queued" and schedule it in the background.

    The task is persisted before scheduling, so its id can be polled via
    /api/tasks/{task_id} as soon as this response is returned.
    """
    new_id = str(uuid.uuid4())
    record = dict(
        id=new_id,
        model=request.transcription_model,
        subset=request.subset,
        submission_name=request.submission_name,
        github_url=request.github_url,
        status="queued",
        submitted_at=datetime.now().isoformat(),
    )

    queue = storage_manager.load('tasks')
    queue.append(record)
    storage_manager.save('tasks', queue)

    background_tasks.add_task(evaluate_model, new_id, request)

    return dict(
        message="Evaluation task submitted successfully",
        task_id=new_id,
    )
@app.get("/api/tasks/{task_id}")
async def get_task(task_id: str):
    """Return the stored record for ``task_id``; respond 404 if there is none."""
    for candidate in storage_manager.load('tasks'):
        if candidate["id"] == task_id:
            return candidate
    raise HTTPException(status_code=404, detail="Task not found")
@app.get("/api/leaderboard")
async def get_leaderboard():
    """Return leaderboard entries sorted ascending by PER, ties broken by PWED.

    If loading or sorting fails, the error is printed and an empty list is returned.
    """
    try:
        entries = storage_manager.load('leaderboard')
        entries.sort(key=lambda entry: (entry["average_per"], entry["average_pwed"]))
    except Exception as e:
        print(f"Error loading leaderboard: {e}")
        return []
    return entries
# Mount Gradio only after every /api route has been registered. The "/" mount is
# appended to the route list last, so the API routes above are matched before it.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Local development entry point.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")