# attribution: code for demo is based on
# https://huggingface.co/spaces/Geonmo/nllb-translation-demo
"""Translation service exposing an NLLB model over HTTP and a Gradio UI.

The module builds a FastAPI application with:

* a middleware that logs every incoming request (method, URL, headers, body),
* a ``/translate/`` endpoint (GET and POST) that translates a JSON list of
  strings,
* a placeholder ``/logo.png`` endpoint,
* a Gradio interface mounted under ``CUSTOM_PATH``.

The model and tokenizer are loaded once at import time.
"""
import json
import logging
from functools import lru_cache
from typing import Dict, List, Union

import gradio as gr
import jwt
import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyQuery
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from flores200_codes import flores_codes

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CUSTOM_PATH = "/gradio"

app = FastAPI()


class LoggingMiddleware(BaseHTTPMiddleware):
    """Log the method, URL, headers and raw body of every request."""

    async def dispatch(self, request: Request, call_next):
        """Log the incoming request, then hand it to the next handler.

        NOTE(review): headers and the URL are logged verbatim, so credentials
        (e.g. the ``jwtToken`` query parameter or auth headers) end up in the
        logs. Consider redacting before enabling this outside of debugging.
        """
        logger.info("--- RAW REQUEST ---")
        logger.info("Method: %s", request.method)
        logger.info("URL: %s", request.url)
        logger.info("Headers:")
        for name, value in request.headers.items():
            logger.info(" %s: %s", name, value)

        # Get raw body
        body = await request.body()
        logger.info("Body:")
        # errors="replace" keeps logging from failing the request when the
        # body is not valid UTF-8.
        logger.info(body.decode("utf-8", errors="replace"))
        logger.info("--- END RAW REQUEST ---")

        # We need to set the body again since we've already read it
        request._body = body

        response = await call_next(request)
        return response


app.add_middleware(LoggingMiddleware)

# This should be a secure secret key in a real application
SECRET_KEY = "your_secret_key_here"

# Define the security scheme
api_key_query = APIKeyQuery(name="jwtToken", auto_error=False)


class TranslationRequest(BaseModel):
    """Request payload: plain strings or dicts of string fields."""

    strings: List[Union[str, Dict[str, str]]]


class TranslationResponse(BaseModel):
    """Response payload: a mapping from a key to a list of strings."""

    data: Dict[str, List[str]]


@lru_cache()
def load_model():
    """Load the NLLB model and tokenizer and return them as a pair.

    The model is moved to CUDA when available, otherwise it stays on the CPU.
    ``lru_cache`` makes repeated calls return the already loaded objects.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name_dict = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }
    call_name = "nllb-distilled-600M"
    real_name = model_name_dict[call_name]
    print(f"\tLoading model: {call_name}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(real_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(real_name)
    return model, tokenizer


model, tokenizer = load_model()


def translate_text(
    text: List[str],
    source_lang: str,
    target_lang: str,
    max_length: int = 400,
) -> List[str]:
    """Translate a batch of strings between two languages.

    Args:
        text: Strings to translate.
        source_lang: Key of ``flores_codes`` naming the source language.
        target_lang: Key of ``flores_codes`` naming the target language.
        max_length: ``max_length`` value forwarded to the translation
            pipeline call.

    Returns:
        The ``translation_text`` of each pipeline output item, in order.

    Raises:
        KeyError: If either language is not a key of ``flores_codes``.
    """
    source = flores_codes[source_lang]
    target = flores_codes[target_lang]
    # A pipeline object is built per call around the shared, already loaded
    # model and tokenizer.
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(text, max_length=max_length)
    return [item["translation_text"] for item in output]


async def verify_token(token: str = Depends(api_key_query)):
    """Return the ``jwtToken`` query value, or a stand-in when it is absent.

    NOTE: token verification is temporarily disabled. A missing token yields
    the placeholder ``"test123"`` and a supplied token is returned without
    being decoded. To re-enable, reject a missing token with a 401
    ("Token is missing"), decode with
    ``jwt.decode(token, SECRET_KEY, algorithms=["HS256"])`` and turn a
    decoding failure into a 401 ("Token is invalid").
    """
    if not token:
        return "test123"
    return token


def _extract_texts(strings) -> List[str]:
    """Return the plain texts from the ``strings`` field of a request.

    Each item may be a string (simple request) or a dict with a string
    ``"text"`` entry (extended request). A bare string is treated as a
    one-item list.

    Raises:
        ValueError: If ``strings`` or one of its items has another shape.
    """
    if isinstance(strings, str):
        strings = [strings]
    if not isinstance(strings, list):
        raise ValueError('"strings" must be a list')

    texts = []
    for index, item in enumerate(strings):
        if isinstance(item, str):
            texts.append(item)
        elif isinstance(item, dict) and isinstance(item.get("text"), str):
            texts.append(item["text"])
        else:
            raise ValueError(
                f'Item {index} of "strings" must be a string or an object '
                'with a string "text" field'
            )
    return texts


@app.get("/translate/", response_model=TranslationResponse)
@app.post("/translate/", response_model=TranslationResponse)
async def translate(
    request: Request,
    source: str,
    target: str,
    project_id: str,
    token: str = Depends(verify_token),
):
    """Translate the strings in the JSON request body.

    The body must be a JSON object with a ``"strings"`` list whose items are
    either strings or objects carrying a ``"text"`` field.

    Args:
        request: Incoming request; its body carries the JSON payload.
        source: Source language (a key of ``flores_codes``).
        target: Target language (a key of ``flores_codes``).
        project_id: Required query parameter; not otherwise used here.
        token: Result of ``verify_token``.

    Returns:
        A ``TranslationResponse`` whose ``data["translations"]`` holds the
        translated strings.

    Raises:
        HTTPException: 400 for missing parameters, unsupported languages, a
            body that is not a valid JSON object, or missing/malformed
            strings; 500 when the translation itself fails.
    """
    if not all([source, target, project_id]):
        raise HTTPException(
            status_code=400, detail={"message": "Missing required parameters"}
        )

    unsupported = [lang for lang in (source, target) if lang not in flores_codes]
    if unsupported:
        raise HTTPException(
            status_code=400,
            detail={"message": f"Unsupported language: {', '.join(unsupported)}"},
        )

    try:
        data = await request.json()
    except ValueError as exc:
        # Covers both malformed JSON and undecodable bytes (JSONDecodeError
        # and UnicodeDecodeError are ValueError subclasses).
        raise HTTPException(
            status_code=400, detail={"message": "Request body is not valid JSON"}
        ) from exc

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail={"message": "Request body must be a JSON object"}
        )

    strings = data.get("strings", [])
    if not strings:
        raise HTTPException(
            status_code=400, detail={"message": "No strings provided for translation"}
        )

    try:
        texts = _extract_texts(strings)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail={"message": str(exc)}) from exc

    try:
        translations = translate_text(texts, source, target)
        return TranslationResponse(data={"translations": translations})
    except Exception as exc:
        logger.exception("Translation failed")
        raise HTTPException(status_code=500, detail={"message": str(exc)}) from exc


@app.get("/logo.png")
async def logo():
    """Return a placeholder string instead of an actual logo."""
    # TODO: Implement logic to serve the logo
    return "Logo placeholder"


lang_codes = list(flores_codes.keys())
inputs = [
    gr.Dropdown(lang_codes, value='English', label='Source'),
    gr.Dropdown(lang_codes, value='Crimean Tatar', label='Target'),
    gr.Textbox(lines=5, label="Input text"),
]

outputs = gr.Textbox(label="Output")

title = "Crimean Tatar Translator based on NLLB distilled 600M demo"

description = "Details: https://github.com/facebookresearch/fairseq/tree/nllb."
# Defined but not passed to the interface below.
examples = [
    ['English', 'Korean', 'Hi. nice to meet you']
]


def translate_single(source_lang: str, target_lang: str, text: str) -> str:
    """Translate one string; adapter between the Gradio UI and translate_text."""
    return translate_text([text], source_lang, target_lang)[0]


io = gr.Interface(
    translate_single,
    inputs,
    outputs,
    title=title,
    description=description,
)
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)