from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

# Initialize FastAPI app
app = FastAPI(
    title="Indic Translation API",
    description="API for translating text between English and Indic languages",
    version="1.0.0"
)


# Define request body model
class InputData(BaseModel):
    sentences: List[str]
    target_lang: str

    class Config:
        schema_extra = {
            "example": {
                "sentences": ["Hello, how are you?", "What is your name?"],
                "target_lang": "hin_Deva"
            }
        }


# Initialize models and processors
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "ai4bharat/indictrans2-en-indic-1B",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "ai4bharat/indictrans2-en-indic-1B",
        trust_remote_code=True
    )
    ip = IndicProcessor(inference=True)

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(DEVICE)
except Exception as e:
    raise RuntimeError(f"Failed to load models: {str(e)}")


@app.get("/")
async def root():
    """Root endpoint returning API information"""
    return {
        "message": "Welcome to the Indic Translation API",
        "status": "active",
        "supported_languages": [
            "hin_Deva",  # Hindi
            "ben_Beng",  # Bengali
            "tam_Taml",  # Tamil
            # Add other supported languages here
        ]
    }


@app.post("/translate/")
async def translate(input_data: InputData):
    """
    Translate text from English to the specified Indic language.

    Args:
        input_data: InputData object containing sentences and target language

    Returns:
        Dictionary containing the translated text
    """
    try:
        # Source language is always English
        src_lang = "eng_Latn"
        tgt_lang = input_data.target_lang

        # Preprocess the input sentences
        batch = ip.preprocess_batch(
            input_data.sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang
        )

        # Tokenize the sentences
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True
        ).to(DEVICE)

        # Generate translations
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1
            )

        # Decode the generated tokens
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

        # Postprocess the translations
        translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        return {
            "translations": translations,
            "source_language": src_lang,
            "target_language": tgt_lang
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Translation error: {str(e)}"
        )


# Add health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}
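

# --- Optional local entry point (illustrative sketch, not part of the API above) ---
# This block assumes `uvicorn` is installed and that serving on port 8000 suits
# your setup; adjust host/port as needed. With the server running, the translate
# endpoint can be exercised with, for example:
#   curl -X POST http://localhost:8000/translate/ \
#        -H "Content-Type: application/json" \
#        -d '{"sentences": ["Hello, how are you?"], "target_lang": "hin_Deva"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)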