indic-trans-api / app.py
darshankr's picture
Update app.py
0b8919f verified
raw
history blame
2.15 kB
from fastapi import FastAPI, HTTPException
from typing import List
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
from fastapi.middleware.cors import CORSMiddleware
import torch
# ASGI application object; the route decorators below register against it.
app = FastAPI()

# Enable CORS for every origin, HTTP method and request header.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers must not be ["*"] when
# allow_credentials=True. Confirm whether credentialed cross-origin
# requests are really needed; if not, allow_credentials should be False,
# otherwise the allowed origins should be listed explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Checkpoint shared by the model and its tokenizer.
_CHECKPOINT = "ai4bharat/indictrans2-en-indic-1B"

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True is passed because the original code loads this
# checkpoint that way; it lets the checkpoint's own Python code execute.
model = AutoModelForSeq2SeqLM.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
)

# Processor used to prepare sentence batches before tokenization.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], target_lang: str) -> List[str]:
    """Translate a batch of English sentences into ``target_lang``.

    The source language is fixed to ``"eng_Latn"``. The batch is preprocessed
    with the module-level ``IndicProcessor``, tokenized, run through beam
    search generation, decoded, and finally postprocessed for the target
    language.

    Args:
        sentences: English sentences to translate as a single batch.
        target_lang: Target language tag, forwarded to the processor as
            ``tgt_lang`` for preprocessing and as ``lang`` for
            postprocessing.

    Returns:
        One translated string per input sentence, in input order. An empty
        input yields an empty list without touching the model.

    Raises:
        Exception: Any failure in preprocessing, tokenization, generation or
            decoding propagates to the caller. Errors are deliberately not
            converted into a return value, so a failure can never be mistaken
            for a translation.
    """
    # Nothing to translate; skip the tokenizer/model round trip entirely.
    if not sentences:
        return []

    src_lang = "eng_Latn"
    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)

    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the tokenizer switched to its target-side mode.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
        )

    # The toolkit's documented pipeline postprocesses decoded text for the
    # target language; returning the raw decoded strings skips that step.
    return ip.postprocess_batch(decoded, lang=target_lang)
@app.get("/")
def read_root():
    """Return a fixed greeting payload for GET requests to the root path."""
    greeting = {"Hello": "World"}
    return greeting
class TranslateRequest(BaseModel):
    """Request body accepted by the POST /translate/ endpoint."""

    # Sentences to translate; handed to translate_text as one batch.
    sentences: List[str]
    # Target language tag; handed to translate_text as target_lang.
    target_lang: str
@app.post("/translate/")
def translate(request: TranslateRequest):
    """Translate the sentences in the request body into the target language.

    Args:
        request: Parsed request body carrying ``sentences`` and
            ``target_lang``.

    Returns:
        The list of translated sentences produced by ``translate_text``.

    Raises:
        HTTPException: Status 500 with the error message as ``detail`` when
            translation fails, whether the failure is raised as an exception
            or reported as a plain error string.
    """
    try:
        result = translate_text(request.sentences, request.target_lang)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # A successful translation is a list of strings. A bare string result is
    # an error message, so report it as a server error instead of sending it
    # back with a 200 status as if it were a translation.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)

    return result