Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -1,46 +1,128 @@
-#
-import
+# run.py
+import subprocess
+import sys
+import os
+
+def main():
+    # Start Streamlit server only
+    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
+    streamlit_process = subprocess.Popen([
+        sys.executable,
+        "-m",
+        "streamlit",
+        "run",
+        "app.py",
+        "--server.port",
+        str(port),
+        "--server.address",
+        "0.0.0.0"
+    ])
+
+    try:
+        streamlit_process.wait()
+    except KeyboardInterrupt:
+        streamlit_process.terminate()
+
+if __name__ == "__main__":
+    main()
+
+# api.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
+from app import translate_text
+
+app = FastAPI()
+
+class InputData(BaseModel):
+    sentences: List[str]
+    target_lang: str
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+@app.post("/translate")
+async def translate(input_data: InputData):
+    try:
+        result = translate_text(
+            sentences=input_data.sentences,
+            target_lang=input_data.target_lang
+        )
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# app.py
+import streamlit as st
+from fastapi import FastAPI
+from typing import List
import torch
-import asyncio
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
-import requests
import json
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+import uvicorn
+
+# Initialize FastAPI
+api = FastAPI()
+
+# Add CORS middleware
+api.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

# Initialize models and processors
-model = AutoModelForSeq2SeqLM.from_pretrained(
-
+model = AutoModelForSeq2SeqLM.from_pretrained(
+    "ai4bharat/indictrans2-en-indic-1B",
+    trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(
+    "ai4bharat/indictrans2-en-indic-1B",
+    trust_remote_code=True
+)
ip = IndicProcessor(inference=True)
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)

def translate_text(sentences: List[str], target_lang: str):
    try:
        src_lang = "eng_Latn"
-        batch = ip.preprocess_batch(
-
-
+        batch = ip.preprocess_batch(
+            sentences,
+            src_lang=src_lang,
+            tgt_lang=target_lang
+        )
+        inputs = tokenizer(
+            batch,
+            truncation=True,
+            padding="longest",
+            return_tensors="pt",
+            return_attention_mask=True
+        ).to(DEVICE)
+
        with torch.no_grad():
            generated_tokens = model.generate(
-                inputs,
+                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1
            )
-
+
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
-
+
        translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
        return {
            "translations": translations,
@@ -50,13 +132,26 @@ def translate_text(sentences: List[str], target_lang: str):
    except Exception as e:
        raise Exception(f"Translation failed: {str(e)}")

+# FastAPI routes
+@api.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+@api.post("/translate")
+async def translate_endpoint(sentences: List[str], target_lang: str):
+    try:
+        result = translate_text(sentences=sentences, target_lang=target_lang)
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
# Streamlit interface
def main():
    st.title("Indic Language Translator")
-
+
    # Input text
    text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
-
+
    # Language selection
    target_languages = {
        "Hindi": "hin_Deva",
@@ -71,13 +166,17 @@ def main():
        "Odia": "ori_Orya"
    }

-    target_lang = st.selectbox(
-
+    target_lang = st.selectbox(
+        "Select target language:",
+        options=list(target_languages.keys())
+    )
+
    if st.button("Translate"):
        try:
-            result = translate_text(
-
-
+            result = translate_text(
+                sentences=[text_input],
+                target_lang=target_languages[target_lang]
+            )
            st.success("Translation:")
            st.write(result["translations"][0])
        except Exception as e:
@@ -88,8 +187,9 @@ def main():
    st.header("API Documentation")
    st.markdown("""
    To use the translation API, send POST requests to:
+    ```
    https://USERNAME-SPACE_NAME.hf.space/translate
-
+    ```
    Request body format:
    ```json
    {
@@ -97,19 +197,21 @@ def main():
        "target_lang": "hin_Deva"
    }
    ```
-
-    Available target languages:
-    - Hindi: hin_Deva
-    - Bengali: ben_Beng
-    - Tamil: tam_Taml
-    - Telugu: tel_Telu
-    - Marathi: mar_Deva
-    - Gujarati: guj_Gujr
-    - Kannada: kan_Knda
-    - Malayalam: mal_Mlym
-    - Punjabi: pan_Guru
-    - Odia: ori_Orya
    """)
+    st.markdown("Available target languages:")
+    for lang, code in target_languages.items():
+        st.markdown(f"- {lang}: `{code}`")

if __name__ == "__main__":
-
+    # Run both Streamlit and FastAPI
+    import threading
+
+    def run_fastapi():
+        uvicorn.run(api, host="0.0.0.0", port=8000)
+
+    # Start FastAPI in a separate thread
+    api_thread = threading.Thread(target=run_fastapi)
+    api_thread.start()
+
+    # Run Streamlit
+    main()
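The API documentation embedded above describes the request shape but not a concrete call. Below is a minimal client-side sketch of that request using Python's `requests` library. It assumes the `/translate` route is actually reachable at the documented Space URL; with the threaded setup above, uvicorn listens on port 8000 while Spaces exposes only the Streamlit port (7860), so the host and port may need adjusting. The URL placeholder from the docs is kept as-is.

```python
# Hypothetical client for the /translate endpoint documented above.
# The URL below is the placeholder from the docs, not a live endpoint;
# point it at wherever the FastAPI app is actually exposed.
import requests

API_URL = "https://USERNAME-SPACE_NAME.hf.space/translate"

payload = {
    "sentences": ["Hello, how are you?"],
    "target_lang": "hin_Deva",
}

response = requests.post(API_URL, json=payload, timeout=120)
response.raise_for_status()
print(response.json()["translations"])
```

The payload mirrors the `InputData` model defined in the api.py section (`sentences` plus `target_lang`).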