Darshan commited on
Commit
6f55a35
·
1 Parent(s): 2218bb2
Files changed (1) hide show
  1. app.py +89 -91
app.py CHANGED
@@ -7,14 +7,13 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
  from IndicTransToolkit import IndicProcessor
8
  import json
9
  from fastapi.middleware.cors import CORSMiddleware
10
- from fastapi.staticfiles import StaticFiles
11
  import uvicorn
12
 
13
  # Initialize FastAPI
14
- api = FastAPI()
15
 
16
  # Add CORS middleware
17
- api.add_middleware(
18
  CORSMiddleware,
19
  allow_origins=["*"],
20
  allow_credentials=True,
@@ -24,33 +23,28 @@ api.add_middleware(
24
 
25
  # Initialize models and processors
26
  model = AutoModelForSeq2SeqLM.from_pretrained(
27
- "ai4bharat/indictrans2-en-indic-1B",
28
- trust_remote_code=True
29
  )
30
  tokenizer = AutoTokenizer.from_pretrained(
31
- "ai4bharat/indictrans2-en-indic-1B",
32
- trust_remote_code=True
33
  )
34
  ip = IndicProcessor(inference=True)
35
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
  model = model.to(DEVICE)
37
 
 
38
  def translate_text(sentences: List[str], target_lang: str):
39
  try:
40
  src_lang = "eng_Latn"
41
- batch = ip.preprocess_batch(
42
- sentences,
43
- src_lang=src_lang,
44
- tgt_lang=target_lang
45
- )
46
  inputs = tokenizer(
47
  batch,
48
  truncation=True,
49
  padding="longest",
50
  return_tensors="pt",
51
- return_attention_mask=True
52
  ).to(DEVICE)
53
-
54
  with torch.no_grad():
55
  generated_tokens = model.generate(
56
  **inputs,
@@ -58,31 +52,33 @@ def translate_text(sentences: List[str], target_lang: str):
58
  min_length=0,
59
  max_length=256,
60
  num_beams=5,
61
- num_return_sequences=1
62
  )
63
-
64
  with tokenizer.as_target_tokenizer():
65
  generated_tokens = tokenizer.batch_decode(
66
  generated_tokens.detach().cpu().tolist(),
67
  skip_special_tokens=True,
68
- clean_up_tokenization_spaces=True
69
  )
70
-
71
  translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
72
  return {
73
  "translations": translations,
74
  "source_language": src_lang,
75
- "target_language": target_lang
76
  }
77
  except Exception as e:
78
  raise Exception(f"Translation failed: {str(e)}")
79
 
 
80
  # FastAPI routes
81
- @api.get("/health")
82
  async def health_check():
83
  return {"status": "healthy"}
84
 
85
- @api.post("/translate")
 
86
  async def translate_endpoint(sentences: List[str], target_lang: str):
87
  try:
88
  result = translate_text(sentences=sentences, target_lang=target_lang)
@@ -90,73 +86,75 @@ async def translate_endpoint(sentences: List[str], target_lang: str):
90
  except Exception as e:
91
  raise HTTPException(status_code=500, detail=str(e))
92
 
93
- # Streamlit interface
94
- def main():
95
- st.title("Indic Language Translator")
96
-
97
- # Input text
98
- text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
99
-
100
- # Language selection
101
- target_languages = {
102
- "Hindi": "hin_Deva",
103
- "Bengali": "ben_Beng",
104
- "Tamil": "tam_Taml",
105
- "Telugu": "tel_Telu",
106
- "Marathi": "mar_Deva",
107
- "Gujarati": "guj_Gujr",
108
- "Kannada": "kan_Knda",
109
- "Malayalam": "mal_Mlym",
110
- "Punjabi": "pan_Guru",
111
- "Odia": "ori_Orya"
112
- }
113
-
114
- target_lang = st.selectbox(
115
- "Select target language:",
116
- options=list(target_languages.keys())
117
- )
118
-
119
- if st.button("Translate"):
120
- try:
121
- result = translate_text(
122
- sentences=[text_input],
123
- target_lang=target_languages[target_lang]
124
- )
125
- st.success("Translation:")
126
- st.write(result["translations"][0])
127
- except Exception as e:
128
- st.error(f"Translation failed: {str(e)}")
129
-
130
- # Add API documentation
131
- st.markdown("---")
132
- st.header("API Documentation")
133
- st.markdown("""
134
- To use the translation API, send POST requests to:
135
- ```
136
- https://USERNAME-SPACE_NAME.hf.space/translate
137
- ```
138
- Request body format:
139
- ```json
140
- {
141
- "sentences": ["Your text here"],
142
- "target_lang": "hin_Deva"
143
- }
144
- ```
145
- """)
146
- st.markdown("Available target languages:")
147
- for lang, code in target_languages.items():
148
- st.markdown(f"- {lang}: `{code}`")
149
-
150
- if __name__ == "__main__":
151
- # Run both Streamlit and FastAPI
152
- import threading
153
-
154
- def run_fastapi():
155
- uvicorn.run(api, host="0.0.0.0", port=8000)
156
-
157
- # Start FastAPI in a separate thread
158
- api_thread = threading.Thread(target=run_fastapi)
159
- api_thread.start()
160
-
161
- # Run Streamlit
162
- main()
 
 
 
7
  from IndicTransToolkit import IndicProcessor
8
  import json
9
  from fastapi.middleware.cors import CORSMiddleware
 
10
  import uvicorn
11
 
12
  # Initialize FastAPI
13
+ app = FastAPI()
14
 
15
  # Add CORS middleware
16
+ app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
19
  allow_credentials=True,
 
23
 
24
  # Initialize models and processors
25
  model = AutoModelForSeq2SeqLM.from_pretrained(
26
+ "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
 
27
  )
28
  tokenizer = AutoTokenizer.from_pretrained(
29
+ "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
 
30
  )
31
  ip = IndicProcessor(inference=True)
32
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
  model = model.to(DEVICE)
34
 
35
+
36
  def translate_text(sentences: List[str], target_lang: str):
37
  try:
38
  src_lang = "eng_Latn"
39
+ batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
 
 
 
 
40
  inputs = tokenizer(
41
  batch,
42
  truncation=True,
43
  padding="longest",
44
  return_tensors="pt",
45
+ return_attention_mask=True,
46
  ).to(DEVICE)
47
+
48
  with torch.no_grad():
49
  generated_tokens = model.generate(
50
  **inputs,
 
52
  min_length=0,
53
  max_length=256,
54
  num_beams=5,
55
+ num_return_sequences=1,
56
  )
57
+
58
  with tokenizer.as_target_tokenizer():
59
  generated_tokens = tokenizer.batch_decode(
60
  generated_tokens.detach().cpu().tolist(),
61
  skip_special_tokens=True,
62
+ clean_up_tokenization_spaces=True,
63
  )
64
+
65
  translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
66
  return {
67
  "translations": translations,
68
  "source_language": src_lang,
69
+ "target_language": target_lang,
70
  }
71
  except Exception as e:
72
  raise Exception(f"Translation failed: {str(e)}")
73
 
74
+
75
  # FastAPI routes
76
+ @app.get("/health")
77
  async def health_check():
78
  return {"status": "healthy"}
79
 
80
+
81
+ @app.post("/translate")
82
  async def translate_endpoint(sentences: List[str], target_lang: str):
83
  try:
84
  result = translate_text(sentences=sentences, target_lang=target_lang)
 
86
  except Exception as e:
87
  raise HTTPException(status_code=500, detail=str(e))
88
 
89
+
90
+ # # Streamlit interface
91
+ # def main():
92
+ # st.title("Indic Language Translator")
93
+
94
+ # # Input text
95
+ # text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
96
+
97
+ # # Language selection
98
+ # target_languages = {
99
+ # "Hindi": "hin_Deva",
100
+ # "Bengali": "ben_Beng",
101
+ # "Tamil": "tam_Taml",
102
+ # "Telugu": "tel_Telu",
103
+ # "Marathi": "mar_Deva",
104
+ # "Gujarati": "guj_Gujr",
105
+ # "Kannada": "kan_Knda",
106
+ # "Malayalam": "mal_Mlym",
107
+ # "Punjabi": "pan_Guru",
108
+ # "Odia": "ori_Orya",
109
+ # }
110
+
111
+ # target_lang = st.selectbox(
112
+ # "Select target language:", options=list(target_languages.keys())
113
+ # )
114
+
115
+ # if st.button("Translate"):
116
+ # try:
117
+ # result = translate_text(
118
+ # sentences=[text_input], target_lang=target_languages[target_lang]
119
+ # )
120
+ # st.success("Translation:")
121
+ # st.write(result["translations"][0])
122
+ # except Exception as e:
123
+ # st.error(f"Translation failed: {str(e)}")
124
+
125
+ # # Add API documentation
126
+ # st.markdown("---")
127
+ # st.header("API Documentation")
128
+ # st.markdown(
129
+ # """
130
+ # To use the translation API, send POST requests to:
131
+ # ```
132
+ # https://darshankr-trans-en-indic.hf.space/translate
133
+ # ```
134
+ # Request body format:
135
+ # ```json
136
+ # {
137
+ # "sentences": ["Your text here"],
138
+ # "target_lang": "hin_Deva"
139
+ # }
140
+ # ```
141
+ # """
142
+ # )
143
+ # st.markdown("Available target languages:")
144
+ # for lang, code in target_languages.items():
145
+ # st.markdown(f"- {lang}: `{code}`")
146
+
147
+
148
+ # if __name__ == "__main__":
149
+ # # Run both Streamlit and FastAPI
150
+ # import threading
151
+
152
+ # def run_fastapi():
153
+ # uvicorn.run(api, host="0.0.0.0", port=8000)
154
+
155
+ # # Start FastAPI in a separate thread
156
+ # api_thread = threading.Thread(target=run_fastapi)
157
+ # api_thread.start()
158
+
159
+ # # Run Streamlit
160
+ # main()