import os

# Set the Hugging Face cache directories to a writable location
# *before* importing transformers, so the model download lands there.
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the IndicTrans2 model and tokenizer once at startup and move
# the model to the selected device
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
)
ip = IndicProcessor(inference=True)

app = FastAPI()


# Define the request body with Pydantic
class InputData(BaseModel):
    sentences: List[str]
    target_lang: str


# API endpoint to receive input sentences and return translations
@app.post("/translate/")
async def predict(input_data: InputData):
    try:
        src_lang, tgt_lang = "eng_Latn", input_data.target_lang

        # Normalize the input sentences and add language tags
        batch = ip.preprocess_batch(
            input_data.sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        # Tokenize the sentences and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Postprocess the translations, including entity replacement
        translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        return {"output": translations}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
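
# A minimal usage sketch for this endpoint. Assumptions not present in the
# original code: the file is saved as main.py and the server listens on
# localhost:8000; adjust both to your setup. "hin_Deva" is the FLORES code
# IndicTrans2 uses for Hindi in Devanagari script.
#
# Start the server:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Send a translation request:
#   curl -X POST http://localhost:8000/translate/ \
#        -H "Content-Type: application/json" \
#        -d '{"sentences": ["How are you?"], "target_lang": "hin_Deva"}'
#
# Expected response shape: {"output": ["<translated sentence>"]}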