Spaces:
Runtime error
Runtime error
File size: 6,881 Bytes
9c48ae2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
from datetime import datetime
import websockets
import asyncio
import os
import uuid
import json
import functools
import traceback
import sys
import logging
from multiprocessing import current_process, Process, Queue, queues
from common import MessageType, format_message, timestamp
import startup
# Registry of live connections: uid -> websocket. echo() adds an entry when a
# client connects and removes it when the connection ends.
user_dict = {}

# If a client sends exactly this value as an API key, handle_message() swaps in
# the matching server-side default key below. Unset disables the substitution.
KEY_TO_USE_DEFAULT = os.getenv("KEY_TO_USE_DEFAULT")
if KEY_TO_USE_DEFAULT is None:
    # Without the placeholder there is nothing to substitute, so the default
    # keys are deliberately not read from the environment.
    DEFAULT_LLM_API_KEY = DEFAULT_SERP_API_KEY = None
else:
    DEFAULT_LLM_API_KEY = os.getenv("DEFAULT_LLM_API_KEY")
    DEFAULT_SERP_API_KEY = os.getenv("DEFAULT_SERP_API_KEY")

_LOG_FORMAT = '%(asctime)s | %(levelname)-8s | %(module)s:%(funcName)s:%(lineno)d - %(message)s'
logging.basicConfig(level=logging.WARNING, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
def _resolve_api_key(data, field, fallback, sentinel, default_key):
    """Pick the API key for one service from a client request's data payload.

    The client-supplied ``data[field]`` is accepted only when it is a string
    whose stripped length is at least 32 characters; otherwise ``fallback`` is
    kept. If the accepted key equals ``sentinel`` and ``default_key`` is set,
    ``default_key`` is substituted for it.

    Args:
        data: The ``"data"`` dict of a client message.
        field: Key in ``data`` holding the candidate API key.
        fallback: Key to use when the client did not supply a usable one.
        sentinel: Placeholder value that requests the server default
            (``KEY_TO_USE_DEFAULT``); ``None`` disables substitution.
        default_key: Server-side key substituted for ``sentinel``.

    Returns:
        ``(key, used_default)``; ``used_default`` is True only when the
        sentinel was replaced by ``default_key``.
    """
    candidate = data.get(field)
    # Non-strings (e.g. JSON null) are ignored instead of crashing on .strip().
    if not isinstance(candidate, str) or len(candidate.strip()) < 32:
        return fallback, False
    candidate = candidate.strip()
    if sentinel is not None and default_key is not None and candidate == sentinel:
        return default_key, True
    return candidate, False


async def handle_message(task_id=None, message=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Validate a RunTask request and run it through ``startup.startup``.

    Runs inside the per-task child process started by ``read_msg_worker``
    (via ``handle_message_wrapper``). Every outcome (validation error,
    exception text, or "finished") is reported by putting a formatted message
    on ``alg_msg_queue``.

    Args:
        task_id: Identifier of this task; echoed back in the "finished" message.
        message: Decoded client message; its ``"data"`` dict may carry
            ``llm_api_key``, ``serpapi_key`` and ``idea``.
        alg_msg_queue: Queue whose messages ``send_msg_worker`` forwards to the client.
        proxy: Passed through to ``startup.startup``.
        llm_api_key: Server-configured LLM key, used when the client sends none.
        serpapi_key: Server-configured SerpAPI key, used when the client sends none.
    """
    data = message.get("data") if isinstance(message, dict) else None
    if not isinstance(data, dict):
        # A malformed request falls through to the validation errors below
        # instead of raising before the try block, where the client would
        # never be told anything.
        data = {}

    llm_api_key, replaced = _resolve_api_key(
        data, "llm_api_key", llm_api_key, KEY_TO_USE_DEFAULT, DEFAULT_LLM_API_KEY)
    if replaced:
        logger.warning("Using default llm api key")
    serpapi_key, replaced = _resolve_api_key(
        data, "serpapi_key", serpapi_key, KEY_TO_USE_DEFAULT, DEFAULT_SERP_API_KEY)
    if replaced:
        logger.warning("Using default serp api key")

    raw_idea = data.get("idea")
    idea = raw_idea.strip() if isinstance(raw_idea, str) else ""

    if not llm_api_key:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid OpenAI key"))
        return
    if not serpapi_key:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid SerpAPI key"))
        return
    if len(idea) < 2:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid task idea"))
        return

    try:
        await startup.startup(idea=idea, task_id=task_id, llm_api_key=llm_api_key, serpapi_key=serpapi_key, proxy=proxy, alg_msg_queue=alg_msg_queue)
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, data={'task_id': task_id}, msg="finished"))
    except Exception as e:
        # Top-level boundary of the task process: report to the client, log the traceback.
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg=f"{e}"))
        logger.exception("Task %s failed", task_id)
def handle_message_wrapper(task_id=None, message=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Synchronous ``Process`` target that drives ``handle_message`` to completion.

    ``read_msg_worker`` starts this in a child process named after the task
    id; a fresh event loop is created here via ``asyncio.run``.
    """
    worker_name = current_process().name
    logger.warning(f"New task:{worker_name}")
    coroutine = handle_message(
        task_id=task_id,
        message=message,
        alg_msg_queue=alg_msg_queue,
        proxy=proxy,
        llm_api_key=llm_api_key,
        serpapi_key=serpapi_key,
    )
    asyncio.run(coroutine)
def clear_queue(alg_msg_queue: Queue = None):
    """Discard every message currently retrievable from ``alg_msg_queue``.

    Called after a task process is terminated so its stale output is not
    forwarded to the client. Draining is best-effort: with a
    ``multiprocessing.Queue``, items still buffered by a producer's feeder
    thread may become visible only after this returns.

    Args:
        alg_msg_queue: Queue to drain; ``None`` makes this a no-op.
    """
    # Test the argument itself; the ``Queue`` factory is always truthy.
    if alg_msg_queue is None:
        return
    try:
        while True:
            alg_msg_queue.get_nowait()
    except queues.Empty:
        pass
def _stop_task(process, alg_msg_queue):
    """Terminate a running task process and discard output it already queued."""
    logger.warning("Interrupt task:" + process.name)
    process.terminate()
    clear_queue(alg_msg_queue=alg_msg_queue)


async def read_msg_worker(websocket=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Read client messages and start/stop this connection's single task process.

    Actions handled:
      * ``MessageType.Interrupt``: if the running process is named
        ``data["task_id"]``, terminate it, drain the queue, and enqueue an
        Interrupt message followed by a RunTask "finished" message.
      * ``MessageType.RunTask``: terminate any running task, then start a new
        daemon ``Process`` running ``handle_message_wrapper`` under a fresh
        uuid4 task id.
    Other actions are ignored.

    Whenever reading stops (normal end of stream, an exception from the
    websocket or from parsing, or cancellation) a still-running task is
    terminated. On a normal end of stream this raises ``ConnectionClosed`` so
    the ``asyncio.gather`` in ``echo`` finishes.
    """
    process = None
    try:
        async for raw_message in websocket:
            message = json.loads(raw_message)
            if message["action"] == MessageType.Interrupt.value:
                # force interrupt a specific task
                task_id = message["data"]["task_id"]
                if process and process.is_alive() and process.name == task_id:
                    _stop_task(process, alg_msg_queue)
                    process = None
                    alg_msg_queue.put_nowait(format_message(action=MessageType.Interrupt.value, data={'task_id': task_id}))
                    alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, data={'task_id': task_id}, msg="finished"))
            elif message["action"] == MessageType.RunTask.value:
                # auto interrupt previous task: only one task per connection
                if process and process.is_alive():
                    _stop_task(process, alg_msg_queue)
                    process = None
                task_id = str(uuid.uuid4())
                process = Process(target=handle_message_wrapper, args=(task_id, message, alg_msg_queue, proxy, llm_api_key, serpapi_key))
                process.daemon = True
                process.name = task_id
                process.start()
    finally:
        # Runs on every exit path, not just a clean end of the async for, so
        # a task never outlives its connection.
        if process and process.is_alive():
            _stop_task(process, alg_msg_queue)
    raise websockets.exceptions.ConnectionClosed(0, "websocket closed")
async def send_msg_worker(websocket=None, alg_msg_queue=None):
    """Forward queued task output to the client, forever.

    Whenever ``alg_msg_queue`` has data, one message is taken, echoed to
    stdout and sent over ``websocket``; otherwise the coroutine sleeps 0.5 s
    before checking again. It never returns normally: it stops only when an
    exception escapes (e.g. from ``websocket.send``) or the task is cancelled.
    """
    while True:
        if not alg_msg_queue.empty():
            outgoing = alg_msg_queue.get_nowait()
            print("=====Sending msg=====\n", outgoing)
            await websocket.send(outgoing)
            continue
        # Nothing pending yet; yield to the event loop before polling again.
        await asyncio.sleep(0.5)
async def echo(websocket, proxy=None, llm_api_key=None, serpapi_key=None):
    """Serve one websocket connection.

    Registers the connection in ``user_dict`` under a timestamp + uuid4 uid,
    runs ``read_msg_worker`` and ``send_msg_worker`` concurrently over a fresh
    ``multiprocessing.Queue``, and unregisters the uid when they stop.
    A ``ConnectionClosed`` from the workers is logged and swallowed.
    """
    # auto register
    uid = datetime.strftime(datetime.now(), '%Y%m%d%H%M%S.%f')+'_'+str(uuid.uuid4())
    logger.warning(f"New user registered, uid: {uid}")
    if uid not in user_dict:
        user_dict[uid] = websocket
    else:
        logger.warning(f"Duplicate user, uid: {uid}")
    # message handling
    workers = []
    try:
        alg_msg_queue = Queue()
        workers.append(asyncio.create_task(read_msg_worker(websocket=websocket, alg_msg_queue=alg_msg_queue, proxy=proxy, llm_api_key=llm_api_key, serpapi_key=serpapi_key)))
        workers.append(asyncio.create_task(send_msg_worker(websocket=websocket, alg_msg_queue=alg_msg_queue)))
        await asyncio.gather(*workers)
    except websockets.exceptions.ConnectionClosed:
        logger.warning("Websocket closed: remote endpoint going away")
    finally:
        # gather() does not cancel the remaining awaitables when one raises;
        # without this, the endless send_msg_worker loop would keep polling
        # after the connection is gone. Cancelling a finished task is a no-op.
        for worker in workers:
            worker.cancel()
        # auto unregister
        logger.warning(f"Auto unregister, uid: {uid}")
        if uid in user_dict:
            user_dict.pop(uid)
async def run_service(host: str = "localhost", port: int=9000, proxy: str=None, llm_api_key:str=None, serpapi_key:str=None):
    """Start the websocket server on ``host:port`` and serve until cancelled.

    Each connection is handled by ``echo`` with this service's proxy and
    server-side API keys bound in as keyword arguments.
    """
    proxy_note = f'[proxy={proxy}]' if proxy else ''
    handler = functools.partial(
        echo,
        proxy=proxy,
        llm_api_key=llm_api_key,
        serpapi_key=serpapi_key,
    )
    async with websockets.serve(handler, host, port):
        logger.warning(f"Websocket server started: {host}:{port} {proxy_note}")
        # A future nobody resolves: keeps the server context open indefinitely.
        await asyncio.Future()
|