# trustgate / app.py
# Last commit by patharanor: "improve websocket connection" (4096277)
import asyncio
import logging
import os
import secrets

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status, HTTPException
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from models.connection_manager import ConnectionManager
from models.request_payload import RequestPayload
from utils.package_manager import PackageManager
# Load environment variables from a local .env file (if present) before reading them.
load_dotenv()

# Any ENV value other than 'PROD' (including an unset ENV) counts as development.
IS_DEV = os.environ.get('ENV', 'DEV') != 'PROD'

# Shared secret websocket clients must present in the `x-token` header.
WEBSOCKET_SECURE_TOKEN = os.environ.get('SECURE_TOKEN')

# Comma-separated channel IDs allowed to open the websocket (`x-channel-id`).
# Entries are stripped and blanks dropped, so "a, b," yields ["a", "b"] and an
# empty variable yields [] instead of [""] (which would whitelist an empty ID).
WHITELIST_CHANNEL_IDS = [
    channel_id.strip()
    for channel_id in os.environ.get('WHITELIST_CHANNEL_IDS', '').split(',')
    if channel_id.strip()
]

# Username / API key pair checked by `is_valid`.
X_REQUEST_USER = os.environ.get('X_REQUEST_USER')
X_API_KEY = os.environ.get('X_API_KEY')
app = FastAPI()

# Module-wide objects shared by the HTTP and websocket handlers in this file:
# `manager` holds the websocket connections (connect/disconnect/send/listen),
# `package` packs outgoing request payloads (see `package.gzip`).
manager = ConnectionManager()
package = PackageManager()

# Root logger: WARNING and above, each line prefixed with "(HH:MM:SS) name LEVEL".
logging.basicConfig(
    format='%(asctime)s %(name)s %(levelname)-8s %(message)s',
    datefmt='(%H:%M:%S)',
    level=logging.WARNING,
)
# CORS middleware. NOTE: this is currently fully permissive -- every origin,
# method and header is allowed. TODO: to restrict access to trusted origins,
# replace "*" in allow_origins with the frontend origin(s), e.g.
# ["https://your-frontend-domain.com"], and add allow_credentials=True only if
# the browser must send cookies/auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/")
def root():
    """Reply 200 with the plain-text body 'ok' (simple liveness check on '/')."""
    # Starlette's Response takes the body as `content`; the previous `data=`
    # keyword is not a Response parameter and raised TypeError on every call.
    return Response(content='ok', status_code=status.HTTP_200_OK, media_type='text/plain')
@app.get("/health")
def healthcheck():
    """Health-check endpoint: reply 200 with the plain-text body 'ok'."""
    # Body must be passed as `content`; `data=` raised TypeError in Response().
    return Response(content='ok', status_code=status.HTTP_200_OK, media_type='text/plain')
@app.post("/hi_mlhub")
async def hi_mlhub(payload: RequestPayload):
    """Relay *payload* to the connected MLaaS websocket client and return its reply.

    Responses:
        200 with the reply data received via ``manager.listen``;
        502 ``{"error": "MLaaS is not available."}`` when ``manager.available`` is None;
        504 ``{"error": "Timeout"}`` when no reply arrives within 10 seconds;
        502 ``{"error": "MLaaS request failed."}`` for any other send/receive error.
    """
    # Read the target once: the client may disconnect (or `available` may pick a
    # different connection) while we await below, and send/listen must agree.
    target = manager.available
    if target is None:
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={"error": "MLaaS is not available."})

    request_id, compressed_data = package.gzip(payload)
    try:
        # Send the packed request to the selected connection only (not a broadcast).
        await manager.send_bytes(target, compressed_data)
        # Wait for the reply matching request_id, giving up after 10 seconds.
        data = await asyncio.wait_for(manager.listen(target, request_id), timeout=10.0)
    except asyncio.TimeoutError:
        return JSONResponse(status_code=status.HTTP_504_GATEWAY_TIMEOUT, content={"error": "Timeout"})
    except Exception:
        # Previously every failure was reported as a timeout; log the real cause.
        logging.exception("Relaying request %s to MLaaS (%s) failed", request_id, target)
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={"error": "MLaaS request failed."})

    return JSONResponse(status_code=status.HTTP_200_OK, content=data)
# Simple token-based authentication check for the websocket handshake.
def is_valid_token(token: str):
    """Return True only if *token* equals the configured ``SECURE_TOKEN``.

    Fails closed: a missing/empty/non-string token, or an unset/empty
    ``WEBSOCKET_SECURE_TOKEN``, is rejected. Without this, an unset secret and
    a missing header compared ``None == None`` and authenticated the caller.
    The comparison is constant-time to avoid leaking the secret via timing.
    """
    expected = WEBSOCKET_SECURE_TOKEN
    if not (isinstance(token, str) and token and isinstance(expected, str) and expected):
        return False
    return secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
def is_valid_apikey(channel_id: str):
    """Return True if *channel_id* is a non-empty entry of ``WHITELIST_CHANNEL_IDS``.

    Despite the name, this checks the channel-ID whitelist (the ``x-channel-id``
    header), not an API key. Missing or empty IDs are always rejected, even if
    the whitelist happens to contain an empty string.
    """
    if not channel_id:
        return False
    return channel_id in WHITELIST_CHANNEL_IDS
# WebSocket endpoint used by MLaaS clients.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Authenticate an MLaaS websocket client, then relay its text messages.

    The handshake must carry a valid ``x-token`` (see ``is_valid_token``) and a
    whitelisted ``x-channel-id`` (see ``is_valid_apikey``); otherwise it is
    closed with a 1008 policy-violation code before being registered. Each text
    message received is passed to ``manager.notify`` for that channel.
    """
    headers = websocket.headers
    token = headers.get("x-token")
    channel_id = headers.get("x-channel-id")

    # Returning an HTTPException object (rather than raising it) had no effect;
    # close the handshake explicitly so the rejection is intentional and visible.
    if not is_valid_token(token):
        logging.warning("Rejected websocket connection: invalid token")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    if not is_valid_apikey(channel_id):
        logging.warning("Rejected websocket connection: channel %r not whitelisted", channel_id)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await manager.connect(channel_id, websocket)
    try:
        while True:
            data = await manager.receive_text(channel_id)
            logging.debug("Message from MLaaS: %s", data)
            # Hand the message to the manager (e.g. to wake a pending listen()).
            await manager.notify(channel_id, data)
    except WebSocketDisconnect:
        pass  # Normal client-initiated close; cleanup happens in `finally`.
    finally:
        # Always deregister, even if an unexpected error ends the loop, so the
        # manager does not keep a dead socket registered for this channel.
        manager.disconnect(channel_id)

    await manager.broadcast(f"A client has disconnected with ID: {channel_id}")
def is_valid(u, p):
    """Return True only if (*u*, *p*) match ``X_REQUEST_USER`` / ``X_API_KEY``.

    Fails closed: non-string or empty credentials, or unset/empty configured
    values, are rejected. Previously, with both env vars unset, ``is_valid(None, None)``
    returned True. Both fields are compared in constant time and both
    comparisons always run, so timing does not reveal which one failed.
    """
    values = (u, p, X_REQUEST_USER, X_API_KEY)
    if not all(isinstance(value, str) and value for value in values):
        return False
    user_ok = secrets.compare_digest(u.encode("utf-8"), X_REQUEST_USER.encode("utf-8"))
    key_ok = secrets.compare_digest(p.encode("utf-8"), X_API_KEY.encode("utf-8"))
    return user_ok and key_ok
if __name__ == "__main__":
    # Auto-reload is a development convenience; IS_DEV (True unless ENV=PROD)
    # keeps it on for local runs and turns it off in production.
    uvicorn.run('app:app', host='0.0.0.0', port=7860, reload=IS_DEV)