darshankr committed (verified)
Commit b6cd2a4 · Parent: 0b7c166

Update app.py

Files changed (1):
  1. app.py (+136 -34)
app.py CHANGED
@@ -1,46 +1,128 @@
-# app.py
-import streamlit as st
+# run.py
+import subprocess
+import sys
+import os
+
+def main():
+    # Start Streamlit server only
+    port = int(os.environ.get("PORT", 7860)) # Hugging Face Spaces uses port 7860
+    streamlit_process = subprocess.Popen([
+        sys.executable,
+        "-m",
+        "streamlit",
+        "run",
+        "app.py",
+        "--server.port",
+        str(port),
+        "--server.address",
+        "0.0.0.0"
+    ])
+
+    try:
+        streamlit_process.wait()
+    except KeyboardInterrupt:
+        streamlit_process.terminate()
+
+if __name__ == "__main__":
+    main()
+
+# api.py
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import List
+from app import translate_text
+
+app = FastAPI()
+
+class InputData(BaseModel):
+    sentences: List[str]
+    target_lang: str
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+@app.post("/translate")
+async def translate(input_data: InputData):
+    try:
+        result = translate_text(
+            sentences=input_data.sentences,
+            target_lang=input_data.target_lang
+        )
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# app.py
+import streamlit as st
+from fastapi import FastAPI
+from typing import List
 import torch
-import asyncio
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from IndicTransToolkit import IndicProcessor
-import requests
 import json
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+import uvicorn
+
+# Initialize FastAPI
+api = FastAPI()
+
+# Add CORS middleware
+api.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

 # Initialize models and processors
-model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True)
-tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True)
+model = AutoModelForSeq2SeqLM.from_pretrained(
+    "ai4bharat/indictrans2-en-indic-1B",
+    trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(
+    "ai4bharat/indictrans2-en-indic-1B",
+    trust_remote_code=True
+)
 ip = IndicProcessor(inference=True)
-
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 model = model.to(DEVICE)

 def translate_text(sentences: List[str], target_lang: str):
     try:
         src_lang = "eng_Latn"
-        batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
-        inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(DEVICE)
-
+        batch = ip.preprocess_batch(
+            sentences,
+            src_lang=src_lang,
+            tgt_lang=target_lang
+        )
+        inputs = tokenizer(
+            batch,
+            truncation=True,
+            padding="longest",
+            return_tensors="pt",
+            return_attention_mask=True
+        ).to(DEVICE)
+
         with torch.no_grad():
             generated_tokens = model.generate(
-                inputs,
+                **inputs,
                 use_cache=True,
                 min_length=0,
                 max_length=256,
                 num_beams=5,
                 num_return_sequences=1
             )
-
+
         with tokenizer.as_target_tokenizer():
             generated_tokens = tokenizer.batch_decode(
                 generated_tokens.detach().cpu().tolist(),
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=True
             )
-
+
         translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
         return {
             "translations": translations,
@@ -50,13 +132,26 @@ def translate_text(sentences: List[str], target_lang: str):
     except Exception as e:
         raise Exception(f"Translation failed: {str(e)}")

+# FastAPI routes
+@api.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+@api.post("/translate")
+async def translate_endpoint(sentences: List[str], target_lang: str):
+    try:
+        result = translate_text(sentences=sentences, target_lang=target_lang)
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
 # Streamlit interface
 def main():
     st.title("Indic Language Translator")
-
+
     # Input text
     text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
-
+
     # Language selection
     target_languages = {
         "Hindi": "hin_Deva",
@@ -71,13 +166,17 @@ def main():
         "Odia": "ori_Orya"
     }

-    target_lang = st.selectbox("Select target language:", options=list(target_languages.keys()))
-
+    target_lang = st.selectbox(
+        "Select target language:",
+        options=list(target_languages.keys())
+    )
+
     if st.button("Translate"):
         try:
-            result = translate_text(sentences=[text_input], target_lang=target_languages[target_lang])
-
-            # Display result
+            result = translate_text(
+                sentences=[text_input],
+                target_lang=target_languages[target_lang]
+            )
             st.success("Translation:")
             st.write(result["translations"][0])
         except Exception as e:
@@ -88,8 +187,9 @@ def main():
     st.header("API Documentation")
     st.markdown("""
     To use the translation API, send POST requests to:
+    ```
     https://USERNAME-SPACE_NAME.hf.space/translate
-
+    ```
     Request body format:
     ```json
     {
@@ -97,19 +197,21 @@ def main():
         "target_lang": "hin_Deva"
     }
     ```
-
-    Available target languages:
-    - Hindi: hin_Deva
-    - Bengali: ben_Beng
-    - Tamil: tam_Taml
-    - Telugu: tel_Telu
-    - Marathi: mar_Deva
-    - Gujarati: guj_Gujr
-    - Kannada: kan_Knda
-    - Malayalam: mal_Mlym
-    - Punjabi: pan_Guru
-    - Odia: ori_Orya
     """)
+    st.markdown("Available target languages:")
+    for lang, code in target_languages.items():
+        st.markdown(f"- {lang}: `{code}`")

 if __name__ == "__main__":
-    main()
+    # Run both Streamlit and FastAPI
+    import threading
+
+    def run_fastapi():
+        uvicorn.run(api, host="0.0.0.0", port=8000)
+
+    # Start FastAPI in a separate thread
+    api_thread = threading.Thread(target=run_fastapi)
+    api_thread.start()
+
+    # Run Streamlit
+    main()
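The API documentation embedded in the commit describes a POST `/translate` endpoint that accepts a JSON body with `sentences` and `target_lang` and returns a `translations` list, plus a GET `/health` probe. Below is a minimal client sketch against that documented contract; the base URL is the placeholder from the docstring, and the endpoint and field names are taken from the diff, not verified against a running Space.

```python
# client_example.py -- hypothetical client, not part of the commit
import requests

# Placeholder URL copied from the API documentation string; replace with the real Space URL.
BASE_URL = "https://USERNAME-SPACE_NAME.hf.space"

# Liveness probe matching the GET /health route added in the diff.
health = requests.get(f"{BASE_URL}/health", timeout=30)
print(health.json())  # expected: {"status": "healthy"}

# Request body format documented in app.py's API docstring.
payload = {
    "sentences": ["Hello, how are you?"],
    "target_lang": "hin_Deva",  # Hindi, per the target_languages table
}

response = requests.post(f"{BASE_URL}/translate", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["translations"])
```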