darshankr commited on
Commit
3b25c41
·
verified ·
1 Parent(s): 659554e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ import torch
5
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
6
+ from IndicTransToolkit import IndicProcessor
7
+
8
+ # Initialize FastAPI app
9
+ app = FastAPI(
10
+ title="Indic Translation API",
11
+ description="API for translating text between English and Indic languages",
12
+ version="1.0.0"
13
+ )
14
+
15
+ # Define request body model
16
+ class InputData(BaseModel):
17
+ sentences: List[str]
18
+ target_lang: str
19
+
20
+ class Config:
21
+ schema_extra = {
22
+ "example": {
23
+ "sentences": ["Hello, how are you?", "What is your name?"],
24
+ "target_lang": "hin_Deva"
25
+ }
26
+ }
27
+
28
+ # Initialize models and processors
29
+ try:
30
+ model = AutoModelForSeq2SeqLM.from_pretrained(
31
+ "ai4bharat/indictrans2-en-indic-1B",
32
+ trust_remote_code=True
33
+ )
34
+ tokenizer = AutoTokenizer.from_pretrained(
35
+ "ai4bharat/indictrans2-en-indic-1B",
36
+ trust_remote_code=True
37
+ )
38
+ ip = IndicProcessor(inference=True)
39
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
+ model = model.to(DEVICE)
41
+ except Exception as e:
42
+ raise RuntimeError(f"Failed to load models: {str(e)}")
43
+
44
+ @app.get("/")
45
+ async def root():
46
+ """Root endpoint returning API information"""
47
+ return {
48
+ "message": "Welcome to the Indic Translation API",
49
+ "status": "active",
50
+ "supported_languages": [
51
+ "hin_Deva", # Hindi
52
+ "ben_Beng", # Bengali
53
+ "tam_Taml", # Tamil
54
+ # Add other supported languages here
55
+ ]
56
+ }
57
+
58
+ @app.post("/translate/")
59
+ async def translate(input_data: InputData):
60
+ """
61
+ Translate text from English to specified Indic language
62
+
63
+ Args:
64
+ input_data: InputData object containing sentences and target language
65
+
66
+ Returns:
67
+ Dictionary containing translated text
68
+ """
69
+ try:
70
+ # Source language is always English
71
+ src_lang = "eng_Latn"
72
+ tgt_lang = input_data.target_lang
73
+
74
+ # Preprocess the input sentences
75
+ batch = ip.preprocess_batch(
76
+ input_data.sentences,
77
+ src_lang=src_lang,
78
+ tgt_lang=tgt_lang
79
+ )
80
+
81
+ # Tokenize the sentences
82
+ inputs = tokenizer(
83
+ batch,
84
+ truncation=True,
85
+ padding="longest",
86
+ return_tensors="pt",
87
+ return_attention_mask=True
88
+ ).to(DEVICE)
89
+
90
+ # Generate translations
91
+ with torch.no_grad():
92
+ generated_tokens = model.generate(
93
+ **inputs,
94
+ use_cache=True,
95
+ min_length=0,
96
+ max_length=256,
97
+ num_beams=5,
98
+ num_return_sequences=1
99
+ )
100
+
101
+ # Decode the generated tokens
102
+ with tokenizer.as_target_tokenizer():
103
+ generated_tokens = tokenizer.batch_decode(
104
+ generated_tokens.detach().cpu().tolist(),
105
+ skip_special_tokens=True,
106
+ clean_up_tokenization_spaces=True
107
+ )
108
+
109
+ # Postprocess the translations
110
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
111
+
112
+ return {
113
+ "translations": translations,
114
+ "source_language": src_lang,
115
+ "target_language": tgt_lang
116
+ }
117
+
118
+ except Exception as e:
119
+ raise HTTPException(
120
+ status_code=500,
121
+ detail=f"Translation error: {str(e)}"
122
+ )
123
+
124
+ # Add health check endpoint
125
+ @app.get("/health")
126
+ async def health_check():
127
+ """Health check endpoint"""
128
+ return {"status": "healthy"}