# trustgate/models/connection_manager.py
# Last commit 4096277 (patharanor): improve websocket connection
import asyncio
import json
import logging
from fastapi import HTTPException, WebSocket, status
from typing import Dict
class InRequest:
    """Container for the response futures that are still being awaited.

    ``responses`` maps a request ID to the ``asyncio.Future`` that will carry
    the reply for that request.
    """

    def __init__(self):
        # Every instance starts with its own, empty request-ID -> future table.
        pending: Dict[str, asyncio.Future] = dict()
        self.responses = pending
class ConnectionManager:
    """Registry of accepted WebSocket connections and the replies awaited on them.

    Connections are stored by socket ID. A caller that expects a reply to a
    request calls ``listen`` and is suspended until ``notify`` receives a JSON
    message carrying the same ``request_id`` for that socket, or until the
    socket is removed with ``disconnect``.
    """

    def __init__(self):
        # ID of the socket recorded by ``connect`` while no socket was marked
        # available; reset to None when that socket disconnects.
        self.available = None
        self.active_connections: Dict[str, WebSocket] = {}  # Maps socket ID to WebSocket connection
        self.in_request: Dict[str, InRequest] = {}  # Pending response futures, grouped by socket ID

    def _get_connection(self, socket_id: str, detail=None):
        """Return the connection registered for ``socket_id``.

        Raises:
            HTTPException: 502 when the socket ID is not registered. ``detail``
                overrides the default error text.
        """
        websocket = self.active_connections.get(socket_id)
        # Compare against None instead of relying on truthiness: presence in
        # the registry is what matters, not how the object evaluates as a bool.
        if websocket is None:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=detail or f"Socket ID {socket_id} not connected")
        return websocket

    def _discard_pending(self, socket_id: str, request_id, future):
        """Forget ``future`` if it is still the one registered for the request."""
        pending = self.in_request.get(socket_id)
        if pending is None:
            return
        # Identity check: the slot may already belong to a newer request that
        # reused the same request ID.
        if pending.responses.get(request_id) is future:
            del pending.responses[request_id]
        if not pending.responses:
            self.in_request.pop(socket_id, None)

    async def connect(self, socket_id: str, websocket: WebSocket):
        """Accept ``websocket``, register it under ``socket_id`` and return the ID.

        The socket is also recorded as ``available`` when no other socket
        currently holds that role.
        """
        await websocket.accept()
        self.active_connections[socket_id] = websocket
        if self.available is None:
            self.available = socket_id
        return socket_id

    def disconnect(self, socket_id: str):
        """Unregister ``socket_id`` and release everything waiting on it.

        Futures still pending for the socket are cancelled so that the
        corresponding ``listen`` calls fail instead of waiting forever for a
        reply that can no longer arrive.
        """
        self.active_connections.pop(socket_id, None)
        if self.available == socket_id:
            # NOTE(review): no other connected socket is promoted here; the
            # next ``connect`` call fills the slot. Kept as in the original.
            self.available = None

        pending = self.in_request.pop(socket_id, None)
        if pending is not None:
            for future in pending.responses.values():
                if not future.done():
                    future.cancel()

    async def broadcast(self, message: str):
        """Send ``message`` as text to every registered connection."""
        # Iterate over a snapshot: connections may be added or removed while
        # a send is awaited, which would otherwise raise
        # "dictionary changed size during iteration".
        for connection in list(self.active_connections.values()):
            await connection.send_text(message)

    async def receive_text(self, socket_id: str):
        """Wait for and return the next text frame from ``socket_id``.

        Raises:
            HTTPException: 502 when the socket ID is not registered.
        """
        websocket = self._get_connection(socket_id)
        return await websocket.receive_text()

    async def send_text(self, socket_id: str, message: str):
        """Send ``message`` as a text frame to ``socket_id``.

        Raises:
            HTTPException: 502 when the socket ID is not registered.
        """
        websocket = self._get_connection(
            socket_id, detail="WebSocket connection not found.")
        await websocket.send_text(message)

    async def send_bytes(self, socket_id: str, binary_data: bytes):
        """Send ``binary_data`` as a binary frame to ``socket_id``.

        Raises:
            HTTPException: 502 when the socket ID is not registered.
        """
        websocket = self._get_connection(socket_id)
        await websocket.send_bytes(binary_data)  # Send binary data

    async def listen(self, socket_id: str, request_id: str) -> str:
        """Wait for the reply to ``request_id`` on ``socket_id`` and return its payload.

        Several requests may be pending on the same socket at once; each one
        is resolved independently by ``notify``.

        Raises:
            HTTPException: 502 when the wait is cancelled, e.g. because the
                socket disconnected.
        """
        future = asyncio.get_running_loop().create_future()

        # Reuse the socket's existing table so that requests already pending
        # on it are not orphaned by this new one.
        pending = self.in_request.get(socket_id)
        if pending is None:
            pending = self.in_request[socket_id] = InRequest()

        # A still-pending future under the same request ID can never be
        # resolved once replaced; cancel it so its waiter fails instead of
        # hanging.
        stale = pending.responses.get(request_id)
        if stale is not None and not stale.done():
            stale.cancel()
        pending.responses[request_id] = future

        try:
            return await future  # Await the future until it's set with a response
        except asyncio.CancelledError:
            # Kept from the original contract: any cancellation of the wait is
            # reported to the caller as a 502.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Socket ID {socket_id} not connected or canceled")
        finally:
            # Never leave a dead future behind, whatever the exit path was.
            self._discard_pending(socket_id, request_id, future)

    async def notify(self, socket_id: str, message: str):
        """Deliver an incoming ``message`` to the request waiting for it.

        ``message`` is expected to be a JSON object with ``request_id`` and
        ``payload`` keys. Messages for sockets with nothing pending, messages
        without a usable request ID, unknown request IDs and replies to
        already-completed requests are ignored.
        """
        logging.debug(message)
        pending = self.in_request.get(socket_id)
        if pending is None:
            return

        request_id, payload = self.extract_message(message)
        if request_id is None:
            return

        try:
            future = pending.responses.pop(request_id, None)
        except TypeError:
            # The peer sent an unhashable request_id (JSON array/object).
            future = None

        if future is not None and not future.done():
            future.set_result(payload)

        # Drop the socket entry only once nothing else is pending on it.
        if not pending.responses:
            self.in_request.pop(socket_id, None)

    def extract_message(self, message: str):
        """Parse ``message`` and return its ``(request_id, payload)`` pair.

        Returns ``(None, None)`` when ``message`` is not valid JSON or does
        not decode to a JSON object; a key that is absent yields ``None`` for
        that element.
        """
        logging.debug(message)
        try:
            decoded = json.loads(message)
        except (TypeError, ValueError, RecursionError) as e:
            # json.JSONDecodeError is a ValueError subclass.
            logging.warning("extract_message error: %s", e)
            return None, None

        if decoded is None:
            return None, None
        if not isinstance(decoded, dict):
            logging.warning(
                "extract_message error: expected a JSON object, got %s",
                type(decoded).__name__)
            return None, None
        return decoded.get('request_id'), decoded.get('payload')