File size: 6,881 Bytes
9c48ae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from datetime import datetime
import websockets
import asyncio
import os
import uuid
import json
import functools
import traceback
import sys
import logging
from multiprocessing import current_process, Process, Queue, queues

from common import MessageType, format_message, timestamp
import startup
# Registry of currently connected clients: uid -> websocket connection.
user_dict = {}

# Placeholder value a client may send in place of a real key. The server-side
# default keys are only read from the environment when the placeholder is set.
KEY_TO_USE_DEFAULT = os.getenv("KEY_TO_USE_DEFAULT")
if KEY_TO_USE_DEFAULT is None:
    DEFAULT_LLM_API_KEY = None
    DEFAULT_SERP_API_KEY = None
else:
    DEFAULT_LLM_API_KEY = os.getenv("DEFAULT_LLM_API_KEY")
    DEFAULT_SERP_API_KEY = os.getenv("DEFAULT_SERP_API_KEY")

# Root logging configuration and the logger used throughout this module.
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s | %(levelname)-8s | %(module)s:%(funcName)s:%(lineno)d - %(message)s',
)
logger = logging.getLogger(__name__)

def _select_api_key(data, field, current, default_key, label):
    """Pick the API key to use for one backend service.

    A client-supplied ``data[field]`` replaces *current* only when it is a
    string that is at least 32 characters long after stripping.  If the
    resulting key equals the ``KEY_TO_USE_DEFAULT`` placeholder and
    *default_key* is configured, *default_key* is returned instead.
    """
    supplied = data.get(field)
    if isinstance(supplied, str):
        supplied = supplied.strip()
        if len(supplied) >= 32:
            current = supplied
    if KEY_TO_USE_DEFAULT is not None and \
            default_key is not None and \
            current == KEY_TO_USE_DEFAULT:
        # replace with default key
        logger.warning("Using default %s api key", label)
        return default_key
    return current


async def handle_message(task_id=None, message=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Validate a RunTask request and run it, reporting the outcome on the queue.

    Args:
        task_id: Identifier echoed back in the "finished" message.
        message: Decoded request; ``message["data"]`` may carry ``llm_api_key``,
            ``serpapi_key`` and ``idea``.
        alg_msg_queue: Queue that receives the formatted status messages.
        proxy: Forwarded unchanged to ``startup.startup``.
        llm_api_key: Fallback LLM key used when the request supplies none.
        serpapi_key: Fallback SerpAPI key used when the request supplies none.

    Every outcome is reported through ``alg_msg_queue`` as a RunTask message:
    "Invalid ..." for a rejected request, "finished" on success, or the
    exception text if ``startup.startup`` raises.  Nothing is returned.
    """
    # A missing or malformed "data" payload is treated as empty so the request
    # is answered with an "Invalid ..." message instead of raising here, where
    # no reply would reach the client.
    data = message.get("data") if isinstance(message, dict) else None
    if not isinstance(data, dict):
        data = {}

    llm_api_key = _select_api_key(data, "llm_api_key", llm_api_key, DEFAULT_LLM_API_KEY, "llm")
    serpapi_key = _select_api_key(data, "serpapi_key", serpapi_key, DEFAULT_SERP_API_KEY, "serp")

    raw_idea = data.get("idea")
    idea = raw_idea.strip() if isinstance(raw_idea, str) else ""

    if not llm_api_key:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid OpenAI key"))
        return
    if not serpapi_key:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid SerpAPI key"))
        return
    if len(idea) < 2:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid task idea"))
        return
    try:
        await startup.startup(idea=idea, task_id=task_id, llm_api_key=llm_api_key, serpapi_key=serpapi_key, proxy=proxy, alg_msg_queue=alg_msg_queue)
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, data={'task_id':task_id}, msg="finished"))
    except Exception as e:
        # Top-level boundary of the task: report the failure to the client and
        # log the full traceback.
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg=f"{e}"))
        logger.error(traceback.format_exc())

def handle_message_wrapper(task_id=None, message=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Synchronous entry point that runs ``handle_message`` to completion.

    Logs the name of the current process, then drives the coroutine with
    ``asyncio.run``.  All arguments are forwarded unchanged.
    """
    process_name = current_process().name
    logger.warning("New task:" + process_name)
    task = handle_message(
        task_id=task_id,
        message=message,
        alg_msg_queue=alg_msg_queue,
        proxy=proxy,
        llm_api_key=llm_api_key,
        serpapi_key=serpapi_key,
    )
    asyncio.run(task)

def clear_queue(alg_msg_queue:Queue=None):
    """Discard every message currently retrievable from *alg_msg_queue*.

    Args:
        alg_msg_queue: Queue to drain; ``None`` is accepted and is a no-op.

    Items are removed with ``get_nowait`` until the queue reports ``Empty``.
    """
    # Test the argument itself: the imported ``Queue`` name is always truthy,
    # so checking it would never short-circuit for a missing queue.
    if alg_msg_queue is None:
        return
    try:
        while True:
            alg_msg_queue.get_nowait()
    except queues.Empty:
        pass

# read websocket messages
async def read_msg_worker(websocket=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Consume client messages and manage the single worker process they control.

    Args:
        websocket: Connection iterated for incoming JSON text messages.
        alg_msg_queue: Queue shared with the worker process for outgoing messages.
        proxy, llm_api_key, serpapi_key: Forwarded to ``handle_message_wrapper``.

    Handled actions:
        * Interrupt - terminates the live worker whose name equals the given
          ``task_id``, clears the queue, then enqueues an Interrupt message and
          a RunTask "finished" message for that ``task_id``.
        * RunTask - terminates any live previous worker, then starts a new
          daemon process (named with a fresh UUID) running
          ``handle_message_wrapper``.

    Raises:
        websockets.exceptions.ConnectionClosed: always, once the message
            iteration ends normally.
        Any exception raised while receiving or decoding a message propagates
        to the caller after the worker process has been cleaned up.
    """
    process = None
    try:
        async for raw_message in websocket:
            message = json.loads(raw_message)
            action = message["action"]
            if action == MessageType.Interrupt.value:
                # force interrupt a specific task
                task_id = message["data"]["task_id"]
                if process is not None and process.name == task_id and _terminate_if_alive(process):
                    process = None
                clear_queue(alg_msg_queue=alg_msg_queue)
                alg_msg_queue.put_nowait(format_message(action=MessageType.Interrupt.value, data={'task_id': task_id}))
                alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, data={'task_id': task_id}, msg="finished"))

            elif action == MessageType.RunTask.value:
                # auto interrupt previous task
                if _terminate_if_alive(process):
                    process = None
                    clear_queue(alg_msg_queue=alg_msg_queue)

                task_id = str(uuid.uuid4())
                process = Process(
                    target=handle_message_wrapper,
                    args=(task_id, message, alg_msg_queue, proxy, llm_api_key, serpapi_key),
                    name=task_id,
                    daemon=True,
                )
                process.start()
    finally:
        # auto terminate process; in ``finally`` so the worker is also stopped
        # when the loop exits through an exception or task cancellation.
        if _terminate_if_alive(process):
            process = None
            clear_queue(alg_msg_queue=alg_msg_queue)

    raise websockets.exceptions.ConnectionClosed(0, "websocket closed")


def _terminate_if_alive(process):
    """Terminate *process* if it is still running; return True when it was."""
    if process is None or not process.is_alive():
        return False
    logger.warning("Interrupt task:" + process.name)
    process.terminate()
    return True

# send
async def send_msg_worker(websocket=None, alg_msg_queue=None):
    """Forward queued messages to the websocket, forever.

    Polls *alg_msg_queue*: when a message is available it is printed and sent
    through ``websocket.send``; when the queue is empty the coroutine sleeps
    for half a second before checking again.  The loop only ends when an
    exception (or cancellation) propagates out of it.
    """
    while True:
        if not alg_msg_queue.empty():
            outgoing = alg_msg_queue.get_nowait()
            print("=====Sending msg=====\n", outgoing)
            await websocket.send(outgoing)
            continue
        # Nothing queued: yield to the event loop before polling again.
        await asyncio.sleep(0.5)
async def echo(websocket, proxy=None, llm_api_key=None, serpapi_key=None):
    """Serve one websocket connection until it closes.

    Registers the connection in ``user_dict`` under a freshly generated uid,
    runs ``read_msg_worker`` and ``send_msg_worker`` concurrently on a
    per-connection queue, and unregisters the uid when they stop.

    Args:
        websocket: The accepted client connection.
        proxy, llm_api_key, serpapi_key: Forwarded to ``read_msg_worker``.

    ``ConnectionClosed`` from either worker is logged and swallowed; any other
    exception propagates to the caller after cleanup.
    """
    # auto register
    uid = datetime.strftime(datetime.now(), '%Y%m%d%H%M%S.%f')+'_'+str(uuid.uuid4())
    logger.warning(f"New user registered, uid: {uid}")
    if uid not in user_dict:
        user_dict[uid] = websocket
    else:
        logger.warning(f"Duplicate user, uid: {uid}")

    # message handling
    sender = None
    try:
        alg_msg_queue = Queue()
        reader = asyncio.create_task(
            read_msg_worker(websocket=websocket, alg_msg_queue=alg_msg_queue, proxy=proxy, llm_api_key=llm_api_key, serpapi_key=serpapi_key)
        )
        sender = asyncio.create_task(
            send_msg_worker(websocket=websocket, alg_msg_queue=alg_msg_queue)
        )
        await asyncio.gather(reader, sender)
    except websockets.exceptions.ConnectionClosed:
        logger.warning("Websocket closed: remote endpoint going away")
    finally:
        # gather() does not cancel the remaining awaitable when one of them
        # raises, and the sender loops forever, so it must be cancelled
        # explicitly or it keeps polling the queue after the client is gone.
        # The reader is left alone: it stops by itself once the connection's
        # message iteration ends and then cleans up its worker process.
        if sender is not None:
            sender.cancel()
        # auto unregister
        logger.warning(f"Auto unregister, uid: {uid}")
        user_dict.pop(uid, None)


async def run_service(host: str = "localhost", port: int=9000, proxy: str=None, llm_api_key:str=None, serpapi_key:str=None):
    """Start the websocket server and keep it running.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
        proxy, llm_api_key, serpapi_key: Bound into every ``echo`` handler call.

    The coroutine does not return on its own: it awaits a future that is
    never resolved.
    """
    proxy_note = f'[proxy={proxy}]' if proxy else ''
    connection_handler = functools.partial(
        echo,
        proxy=proxy,
        llm_api_key=llm_api_key,
        serpapi_key=serpapi_key,
    )
    server = websockets.serve(connection_handler, host, port)
    async with server:
        logger.warning(f"Websocket server started: {host}:{port} {proxy_note}")
        # Awaiting a bare Future blocks here while the server handles clients.
        await asyncio.Future()