darshankr commited on
Commit
45a86ac
·
verified ·
1 Parent(s): 83b59a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -1,20 +1,29 @@
1
  # app.py
2
  import streamlit as st
3
- from fastapi import FastAPI
 
 
4
  from typing import List
5
  import torch
6
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
  from IndicTransToolkit import IndicProcessor
8
  import json
9
  from fastapi.middleware.cors import CORSMiddleware
10
- from fastapi.staticfiles import StaticFiles
11
  import uvicorn
 
 
 
 
 
 
 
 
12
 
13
  # Initialize FastAPI
14
- api = FastAPI()
15
 
16
  # Add CORS middleware
17
- api.add_middleware(
18
  CORSMiddleware,
19
  allow_origins=["*"],
20
  allow_credentials=True,
@@ -35,6 +44,10 @@ ip = IndicProcessor(inference=True)
35
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
  model = model.to(DEVICE)
37
 
 
 
 
 
38
  def translate_text(sentences: List[str], target_lang: str):
39
  try:
40
  src_lang = "eng_Latn"
@@ -78,20 +91,23 @@ def translate_text(sentences: List[str], target_lang: str):
78
  raise Exception(f"Translation failed: {str(e)}")
79
 
80
  # FastAPI routes
81
- @api.get("/health")
82
  async def health_check():
83
  return {"status": "healthy"}
84
 
85
- @api.post("/translate")
86
- async def translate_endpoint(sentences: List[str], target_lang: str):
87
  try:
88
- result = translate_text(sentences=sentences, target_lang=target_lang)
89
- return result
 
 
 
90
  except Exception as e:
91
  raise HTTPException(status_code=500, detail=str(e))
92
 
93
  # Streamlit interface
94
- def main():
95
  st.title("Indic Language Translator")
96
 
97
  # Input text
@@ -133,7 +149,7 @@ def main():
133
  st.markdown("""
134
  To use the translation API, send POST requests to:
135
  ```
136
- https://USERNAME-SPACE_NAME.hf.space/translate
137
  ```
138
  Request body format:
139
  ```json
@@ -147,16 +163,16 @@ def main():
147
  for lang, code in target_languages.items():
148
  st.markdown(f"- {lang}: `{code}`")
149
 
 
 
 
 
 
 
 
 
150
  if __name__ == "__main__":
151
- # Run both Streamlit and FastAPI
152
- import threading
153
-
154
- def run_fastapi():
155
- uvicorn.run(api, host="0.0.0.0", port=8000)
156
-
157
- # Start FastAPI in a separate thread
158
- api_thread = threading.Thread(target=run_fastapi)
159
- api_thread.start()
160
-
161
- # Run Streamlit
162
- main()
 
1
  # app.py
2
  import streamlit as st
3
+ from fastapi import FastAPI, HTTPException, Request
4
+ from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
  from typing import List
7
  import torch
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
  from IndicTransToolkit import IndicProcessor
10
  import json
11
  from fastapi.middleware.cors import CORSMiddleware
 
12
  import uvicorn
13
+ from starlette.applications import Starlette
14
+ from starlette.routing import Mount, Route
15
+ from starlette.staticfiles import StaticFiles
16
+ import asyncio
17
+ import nest_asyncio
18
+
19
+ # Enable nested event loops
20
+ nest_asyncio.apply()
21
 
22
  # Initialize FastAPI
23
+ app = FastAPI()
24
 
25
  # Add CORS middleware
26
+ app.add_middleware(
27
  CORSMiddleware,
28
  allow_origins=["*"],
29
  allow_credentials=True,
 
44
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
45
  model = model.to(DEVICE)
46
 
47
+ class TranslationRequest(BaseModel):
48
+ sentences: List[str]
49
+ target_lang: str
50
+
51
  def translate_text(sentences: List[str], target_lang: str):
52
  try:
53
  src_lang = "eng_Latn"
 
91
  raise Exception(f"Translation failed: {str(e)}")
92
 
93
  # FastAPI routes
94
+ @app.get("/api/health")
95
  async def health_check():
96
  return {"status": "healthy"}
97
 
98
+ @app.post("/api/translate")
99
+ async def translate_endpoint(request: TranslationRequest):
100
  try:
101
+ result = translate_text(
102
+ sentences=request.sentences,
103
+ target_lang=request.target_lang
104
+ )
105
+ return JSONResponse(content=result)
106
  except Exception as e:
107
  raise HTTPException(status_code=500, detail=str(e))
108
 
109
  # Streamlit interface
110
+ def streamlit_app():
111
  st.title("Indic Language Translator")
112
 
113
  # Input text
 
149
  st.markdown("""
150
  To use the translation API, send POST requests to:
151
  ```
152
+ https://darshankr-trans-en-indic.hf.space/api/translate
153
  ```
154
  Request body format:
155
  ```json
 
163
  for lang, code in target_languages.items():
164
  st.markdown(f"- {lang}: `{code}`")
165
 
166
+ # Create a unified application
167
+ def create_app():
168
+ routes = [
169
+ Mount("/api", app),
170
+ Mount("/", StaticFiles(directory="static", html=True), name="static"),
171
+ ]
172
+ return Starlette(routes=routes)
173
+
174
  if __name__ == "__main__":
175
+ if "streamlit" in sys.argv[0]:
176
+ streamlit_app()
177
+ else:
178
+ uvicorn.run(create_app(), host="0.0.0.0", port=7860)