Spaces:
Sleeping
Sleeping
import asyncio | |
import json | |
import logging | |
from fastapi import HTTPException, WebSocket, status | |
from typing import Dict | |
class InRequest:
    """Pending response futures belonging to one socket.

    ``responses`` maps a request id to the future that
    ``ConnectionManager.listen`` awaits and ``ConnectionManager.notify``
    completes with the response payload.
    """

    def __init__(self):
        pending: Dict[str, asyncio.Future] = dict()
        self.responses = pending
class ConnectionManager:
    """Track WebSocket connections and match responses to pending requests.

    Connections are registered under a caller-supplied ``socket_id``.
    ``listen`` parks a future per ``(socket_id, request_id)`` pair and
    ``notify`` resolves it when a JSON message of the form
    ``{"request_id": ..., "payload": ...}`` arrives for that socket.
    Several requests may be pending on the same socket at once.
    """

    def __init__(self):
        # socket_id designated as available (first connected, or the oldest
        # remaining connection after it disconnects), or None.
        self.available = None
        # Maps socket ID to WebSocket connection.
        self.active_connections: Dict[str, WebSocket] = {}
        # Maps socket ID to its pending response futures.
        self.in_request: Dict[str, InRequest] = {}

    async def connect(self, socket_id: str, websocket: WebSocket):
        """Accept ``websocket`` and register it under ``socket_id``.

        If no socket is currently available, this one becomes ``available``.

        Returns:
            The ``socket_id`` that was registered.
        """
        await websocket.accept()
        self.active_connections[socket_id] = websocket
        if self.available is None:
            self.available = socket_id
        return socket_id

    def disconnect(self, socket_id: str):
        """Unregister ``socket_id`` and cancel every request waiting on it.

        Cancelling the pending futures wakes their ``listen`` callers, which
        then raise HTTP 502 instead of waiting forever. If the disconnected
        socket was ``available``, the oldest remaining connection (if any)
        takes its place.
        """
        self.active_connections.pop(socket_id, None)
        pending = self.in_request.pop(socket_id, None)
        if pending is not None:
            for future in pending.responses.values():
                future.cancel()  # no-op for futures that are already done
        if self.available == socket_id:
            # dicts keep insertion order, so this picks the oldest connection
            self.available = next(iter(self.active_connections), None)

    async def broadcast(self, message: str):
        """Send ``message`` as text to every connected socket."""
        # Iterate over a snapshot: connect/disconnect may run while we await.
        for connection in list(self.active_connections.values()):
            await connection.send_text(message)

    def _get_websocket(self, socket_id: str, detail: str) -> WebSocket:
        """Return the socket for ``socket_id`` or raise HTTP 502 with ``detail``."""
        websocket = self.active_connections.get(socket_id)
        if websocket is None:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=detail)
        return websocket

    async def receive_text(self, socket_id: str):
        """Receive one text message from ``socket_id``.

        Raises:
            HTTPException: 502 if ``socket_id`` is not connected.
        """
        websocket = self._get_websocket(
            socket_id, f"Socket ID {socket_id} not connected")
        return await websocket.receive_text()

    async def send_text(self, socket_id: str, message: str):
        """Send ``message`` as text to ``socket_id``.

        Raises:
            HTTPException: 502 if ``socket_id`` is not connected.
        """
        websocket = self._get_websocket(
            socket_id, "WebSocket connection not found.")
        await websocket.send_text(message)

    async def send_bytes(self, socket_id: str, binary_data: bytes):
        """Send ``binary_data`` as a binary frame to ``socket_id``.

        Raises:
            HTTPException: 502 if ``socket_id`` is not connected.
        """
        websocket = self._get_websocket(
            socket_id, f"Socket ID {socket_id} not connected")
        await websocket.send_bytes(binary_data)

    def _drop_if_idle(self, socket_id: str, req: InRequest):
        """Remove ``req`` from ``in_request`` once it has nothing pending."""
        # Identity check: the socket may have reconnected and got a new entry.
        if not req.responses and self.in_request.get(socket_id) is req:
            del self.in_request[socket_id]

    async def listen(self, socket_id: str, request_id: str) -> str:
        """Wait for the response to ``request_id`` arriving on ``socket_id``.

        Registers a future that ``notify`` resolves with the response
        payload. Other requests already pending on the same socket are left
        intact. The future is unregistered on exit, including when the
        caller is cancelled or times out.

        NOTE(review): register before sending the request (e.g. start this
        coroutine as a task first); a response that arrives before the
        future exists is ignored by ``notify``.

        Returns:
            The ``payload`` field of the matching response message.

        Raises:
            HTTPException: 502 if ``socket_id`` is not connected, or if the
                wait is cancelled (including by ``disconnect``).
        """
        detail = f"Socket ID {socket_id} not connected or canceled"
        # Fail fast: a future on an unconnected socket could never resolve.
        self._get_websocket(socket_id, detail)

        req = self.in_request.get(socket_id)
        if req is None:
            req = self.in_request[socket_id] = InRequest()
        future = asyncio.get_running_loop().create_future()
        req.responses[request_id] = future
        try:
            return await future  # resolved by notify, cancelled by disconnect
        except asyncio.CancelledError:
            # NOTE: cancellation of the calling task (e.g. a timeout) is also
            # converted to a 502 here; this preserves the existing contract.
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=detail)
        finally:
            if req.responses.get(request_id) is future:
                del req.responses[request_id]
            self._drop_if_idle(socket_id, req)

    async def notify(self, socket_id: str, message: str):
        """Resolve the pending ``listen`` future that ``message`` answers.

        ``message`` is parsed with ``extract_message``. A message is ignored
        when nothing is pending on ``socket_id``, when it has no
        ``request_id``, or when that request id is not (or no longer)
        pending; the last case is logged as a warning.
        """
        logging.debug(message)
        req = self.in_request.get(socket_id)
        if req is None:
            return
        request_id, payload = self.extract_message(message)
        if request_id is None:
            return
        future = req.responses.pop(request_id, None)
        if future is None:
            logging.warning(
                "notify: no pending request %r on socket %s",
                request_id, socket_id)
            return
        # The waiter may have been cancelled before this response arrived.
        if not future.done():
            future.set_result(payload)
        self._drop_if_idle(socket_id, req)

    def extract_message(self, message: str):
        """Parse ``message`` as JSON and return ``(request_id, payload)``.

        Each element is None when its key is absent. Both are None when
        ``message`` is not a JSON object; a warning is logged unless the
        message is JSON ``null``.
        """
        try:
            decoded = json.loads(message)
        except (TypeError, ValueError, RecursionError) as exc:
            logging.warning("extract_message error: %s", exc)
            return None, None
        if not isinstance(decoded, dict):
            if decoded is not None:
                logging.warning(
                    "extract_message error: expected a JSON object, got %s",
                    type(decoded).__name__)
            return None, None
        return decoded.get('request_id'), decoded.get('payload')