"""FastAPI gateway that relays HTTP requests to an MLaaS worker over WebSocket.

HTTP clients POST to ``/hi_mlhub``. The payload is packaged with
``PackageManager.gzip`` and sent as bytes to ``manager.available``. The handler
then waits (bounded by REQUEST_TIMEOUT_SECONDS) for the reply that
``ConnectionManager.listen`` yields for that request id. Workers connect to
``/ws`` and must send a shared secret in ``x-token`` and a whitelisted id in
``x-channel-id``.
"""
import asyncio
import logging
import os
import secrets

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Response, WebSocket, WebSocketDisconnect, status  # noqa: F401
from fastapi.concurrency import asynccontextmanager  # noqa: F401
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from models.connection_manager import ConnectionManager
from models.request_payload import RequestPayload
from utils.package_manager import PackageManager

# Load environment variables from .env file
load_dotenv()


def _parse_csv_env(raw):
    """Split a comma-separated env value into stripped, non-empty items.

    Returns [] for None. Empty items are dropped so that an empty variable
    (or a trailing comma) cannot whitelist the empty string.
    """
    if raw is None:
        return []
    return [item.strip() for item in raw.split(',') if item.strip()]


IS_DEV = os.environ.get('ENV', 'DEV') != 'PROD'
WEBSOCKET_SECURE_TOKEN = os.getenv("SECURE_TOKEN")
X_REQUEST_USER = os.environ.get('X_REQUEST_USER')
X_API_KEY = os.environ.get('X_API_KEY')
WHITELIST_CHANNEL_IDS = _parse_csv_env(os.getenv('WHITELIST_CHANNEL_IDS'))
# Comma-separated list of allowed browser origins; defaults to "*" (any origin).
CORS_ALLOW_ORIGINS = _parse_csv_env(os.getenv('CORS_ALLOW_ORIGINS')) or ["*"]
# How long /hi_mlhub waits for the worker's reply before answering 504.
REQUEST_TIMEOUT_SECONDS = 10.0

app = FastAPI()

# Initialize the connection manager
manager = ConnectionManager()
package = PackageManager()

logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s %(name)s %(levelname)-8s %(message)s',
    datefmt='(%H:%M:%S)'
)
logger = logging.getLogger(__name__)

# CORS middleware. Set CORS_ALLOW_ORIGINS to restrict access to trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    """Root liveness probe: reply 200 with the plain-text body ``ok``."""
    return Response(content='ok', status_code=status.HTTP_200_OK, media_type="text/plain")


@app.get("/health")
def healthcheck():
    """Health probe: reply 200 with the plain-text body ``ok``."""
    return Response(content='ok', status_code=status.HTTP_200_OK, media_type="text/plain")


@app.post("/hi_mlhub")
async def hi_mlhub(payload: RequestPayload):
    """Relay *payload* to the available MLaaS worker and return its reply.

    Returns:
        200 with the worker's reply as JSON content on success;
        504 ``{"error": "Timeout"}`` if no reply arrives in time;
        502 if no worker is available or the relay fails for another reason.
    """
    # Snapshot once so the send and the listen target the same worker even if
    # ``manager.available`` changes while this coroutine is suspended.
    target = manager.available
    if target is None:
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={
            "error": "MLaaS is not available."
        })

    request_id, compressed_data = package.gzip(payload)
    # NOTE(review): the reply listener is awaited only after the send; if
    # ConnectionManager.listen ignores replies that arrived earlier, a very
    # fast reply could be missed. Confirm against ConnectionManager.
    await manager.send_bytes(target, compressed_data)
    try:
        data = await asyncio.wait_for(
            manager.listen(target, request_id), timeout=REQUEST_TIMEOUT_SECONDS
        )
    except asyncio.TimeoutError:
        return JSONResponse(status_code=status.HTTP_504_GATEWAY_TIMEOUT, content={
            "error": "Timeout"
        })
    except Exception:
        # Do not disguise real failures as timeouts; keep the traceback.
        logger.exception("Relaying request %s to MLaaS failed", request_id)
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={
            "error": "MLaaS request failed."
        })
    return JSONResponse(status_code=status.HTTP_200_OK, content=data)


def _secure_equals(supplied, expected):
    """Constant-time string equality that fails closed.

    Returns False when either side is None or empty. This stops an unset
    secret and a missing header from comparing equal (``None == None``).
    """
    if not supplied or not expected:
        return False
    return secrets.compare_digest(supplied.encode('utf-8'), expected.encode('utf-8'))


# Simple token-based authentication
def is_valid_token(token: str):
    """Return True only if *token* matches the configured SECURE_TOKEN."""
    return _secure_equals(token, WEBSOCKET_SECURE_TOKEN)


def is_valid_apikey(channel_id: str):
    """Return True if *channel_id* is listed in WHITELIST_CHANNEL_IDS."""
    return channel_id is not None and channel_id in WHITELIST_CHANNEL_IDS


# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register an authenticated MLaaS worker and forward its messages.

    Each text message received on the channel is passed to
    ``manager.notify``. On disconnect the channel is removed and a notice is
    broadcast.
    """
    headers = websocket.headers
    token = headers.get("x-token")
    channel_id = headers.get("x-channel-id")

    # Closing before accept() rejects the handshake. An HTTPException must not
    # be *returned* here: that sends nothing to the client.
    if not is_valid_token(token):
        logger.warning("Rejected WebSocket connection: invalid token")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None
    if not is_valid_apikey(channel_id):
        logger.warning("Rejected WebSocket connection: channel %r not whitelisted", channel_id)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None

    await manager.connect(channel_id, websocket)
    try:
        while True:
            # Common receiver
            data = await manager.receive_text(channel_id)
            print(f"Message from MLaaS: {data}")
            # Notify the manager that a message was received
            await manager.notify(channel_id, data)
    except WebSocketDisconnect:
        manager.disconnect(channel_id)
        await manager.broadcast(f"A client has disconnected with ID: {channel_id}")
    except Exception:
        # Deregister on unexpected errors too, so a dead socket is not left
        # registered, then let the server log the error.
        logger.exception("WebSocket channel %s failed; disconnecting", channel_id)
        manager.disconnect(channel_id)
        raise
    return None


def is_valid(u, p):
    """Return True if (*u*, *p*) match X_REQUEST_USER and X_API_KEY.

    Both comparisons always run, so timing does not reveal which one failed.
    """
    user_ok = _secure_equals(u, X_REQUEST_USER)
    key_ok = _secure_equals(p, X_API_KEY)
    return user_ok and key_ok


if __name__ == "__main__":
    # Auto-reload is a development convenience; it is disabled when ENV=PROD.
    uvicorn.run('app:app', host='0.0.0.0', port=7860, reload=IS_DEV)