# Spaces: Sleeping (page-status text captured together with the source; not part of the program)
import asyncio
import hmac
import logging
import os

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status, HTTPException
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from models.connection_manager import ConnectionManager
from models.request_payload import RequestPayload
from utils.package_manager import PackageManager
# Load environment variables from .env file
load_dotenv()

# Any ENV value other than 'PROD' (including an unset ENV) counts as development.
IS_DEV = os.environ.get('ENV', 'DEV') != 'PROD'
# Shared secret compared against the "x-token" header of WebSocket clients.
WEBSOCKET_SECURE_TOKEN = os.getenv("SECURE_TOKEN")
# Credentials compared by is_valid().
X_REQUEST_USER = os.environ.get('X_REQUEST_USER')
X_API_KEY = os.environ.get('X_API_KEY')
# Comma-separated channel IDs allowed to connect. Entries are stripped and
# empty ones dropped, so an unset or empty variable yields [] rather than
# [''] (which would whitelist an empty channel ID).
WHITELIST_CHANNEL_IDS = [
    channel_id.strip()
    for channel_id in (os.getenv('WHITELIST_CHANNEL_IDS') or '').split(',')
    if channel_id.strip()
]
# ASGI application object.
app = FastAPI()
# Initialize the connection manager
manager = ConnectionManager()
# Helper used by the handlers to gzip outgoing payloads.
package = PackageManager()

# Root logger: WARNING and above, timestamps rendered as (HH:MM:SS).
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s %(name)s %(levelname)-8s %(message)s",
    datefmt="(%H:%M:%S)",
)

# CORS middleware: currently wide open (any origin, any method, any header).
# NOTE(review): to restrict access to trusted origins, replace the wildcard
# with an explicit list such as ["https://your-frontend-domain.com"] and
# consider allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def root():
    """Return a plain 200 response whose body is ``ok``.

    NOTE(review): no route decorator is visible on this function in the
    source as captured -- confirm how it is registered on ``app``.
    """
    # Response takes the body via ``content``; it has no ``data`` keyword,
    # so the previous call raised TypeError on every request.
    return Response(status_code=status.HTTP_200_OK, content='ok')
def healthcheck():
    """Return a plain 200 response whose body is ``ok``.

    NOTE(review): no route decorator is visible on this function in the
    source as captured -- confirm how it is registered on ``app``.
    """
    # Response takes the body via ``content``; it has no ``data`` keyword,
    # so the previous call raised TypeError on every request.
    return Response(status_code=status.HTTP_200_OK, content='ok')
async def hi_mlhub(payload: RequestPayload):
    """Relay *payload* to the available MLaaS WebSocket client and return its reply.

    The payload is gzipped via ``package.gzip`` (which also yields a request
    ID), sent as bytes to the channel reported by ``manager.available``, and
    the handler then waits up to 10 seconds for ``manager.listen`` to produce
    the matching response.

    Returns:
        JSONResponse: 200 with the received data, 504 ``{"error": "Timeout"}``
        when no usable reply arrives, or 502 when no MLaaS client is available.
    """
    # Resolve the target once: re-reading manager.available after each await
    # could address a different channel (or None) for the send and the listen.
    channel = manager.available
    if channel is None:
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={ "error": "MLaaS is not available." })

    request_id, compressed_data = package.gzip(payload)
    # Send the compressed request to the selected WebSocket client
    await manager.send_bytes(channel, compressed_data)
    try:
        # Wait for the response with a timeout (10 seconds)
        data = await asyncio.wait_for(manager.listen(channel, request_id), timeout=10.0)
        return JSONResponse(status_code=status.HTTP_200_OK, content=data)
    except asyncio.TimeoutError:
        return JSONResponse(status_code=status.HTTP_504_GATEWAY_TIMEOUT, content={ "error": "Timeout" })
    except Exception:
        # Same response as before for backward compatibility, but the real
        # cause is no longer silently discarded.
        logging.exception("Failed to relay MLaaS response for request %s", request_id)
        return JSONResponse(status_code=status.HTTP_504_GATEWAY_TIMEOUT, content={ "error": "Timeout" })
# Simple token-based authentication dependency
def is_valid_token(token: str):
    """Return True only if *token* equals the configured WEBSOCKET_SECURE_TOKEN.

    A missing token or an unconfigured secret is always rejected; with a
    plain ``==`` both being None compared equal and authenticated the caller.
    """
    if not isinstance(token, str) or not isinstance(WEBSOCKET_SECURE_TOKEN, str):
        return False
    # Constant-time comparison on bytes: compare_digest rejects non-ASCII str,
    # and "surrogatepass" keeps encoding from failing on surrogate-escaped
    # environment values.
    return hmac.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        WEBSOCKET_SECURE_TOKEN.encode("utf-8", "surrogatepass"),
    )
def is_valid_apikey(channel_id: str):
    """Report whether *channel_id* is present and listed in WHITELIST_CHANNEL_IDS."""
    if channel_id is None:
        return False
    return channel_id in WHITELIST_CHANNEL_IDS
# WebSocket endpoint
async def websocket_endpoint(websocket: WebSocket):
    """Authenticate a WebSocket client, register it, and pump its messages.

    The ``x-token`` header must pass ``is_valid_token`` and the
    ``x-channel-id`` header must pass ``is_valid_apikey``; otherwise the
    socket is closed with a policy-violation code. Accepted clients are
    registered with ``manager`` under their channel ID, and every text
    message received is passed to ``manager.notify``. On disconnect the
    channel is removed and the remaining clients are informed.

    NOTE(review): no route decorator is visible on this function in the
    source as captured -- confirm how it is registered on ``app``.
    """
    headers = websocket.headers
    token = headers.get("x-token")
    channel_id = headers.get("x-channel-id")

    # Returning an HTTPException object from a WebSocket handler has no
    # effect; reject the client explicitly by closing the socket instead.
    if not is_valid_token(token):
        logging.warning("WebSocket rejected: invalid token")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None
    if not is_valid_apikey(channel_id):
        logging.warning("WebSocket rejected: channel %r has no permission", channel_id)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return None

    await manager.connect(channel_id, websocket)
    try:
        while True:
            # Common receiver
            data = await manager.receive_text(channel_id)
            print(f"Message from MLaaS: {data}")
            # Notify the manager that a message was received
            await manager.notify(channel_id, data)
    except WebSocketDisconnect:
        manager.disconnect(channel_id)
        await manager.broadcast(f"A client has disconnected with ID: {channel_id}")
    return None
def is_valid(u, p):
    """Return True only if *u* matches X_REQUEST_USER and *p* matches X_API_KEY.

    Missing values or unconfigured credentials are always rejected; with a
    plain ``==`` a None argument compared equal to an unset setting.
    """
    supplied_and_expected = (u, X_REQUEST_USER, p, X_API_KEY)
    if not all(isinstance(value, str) for value in supplied_and_expected):
        return False
    # Evaluate both constant-time comparisons before combining them so the
    # result does not short-circuit on the user name.
    user_ok = hmac.compare_digest(
        u.encode("utf-8", "surrogatepass"),
        X_REQUEST_USER.encode("utf-8", "surrogatepass"),
    )
    key_ok = hmac.compare_digest(
        p.encode("utf-8", "surrogatepass"),
        X_API_KEY.encode("utf-8", "surrogatepass"),
    )
    return user_ok and key_ok
if __name__ == "__main__":
    # Auto-reload is a development aid; tie it to IS_DEV so that running
    # with ENV=PROD does not start the file watcher. IS_DEV is True unless
    # ENV is exactly 'PROD', so the default behaviour is unchanged.
    uvicorn.run('app:app', host='0.0.0.0', port=7860, reload=IS_DEV)