File size: 6,820 Bytes
edad211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
'''
Contributed by SagsMug. Modified by binary-husky
https://github.com/oobabooga/text-generation-webui/pull/175
'''
import asyncio
import json
import random
import string
import websockets
import logging
import time
import threading
import importlib
from toolbox import get_conf, update_ui
def random_hash():
    """Return a 9-character random token made of lowercase ASCII letters and digits.

    ``run`` sends this as the websocket ``session_hash``. It is drawn from the
    ``random`` module, so it is not suitable for anything security-sensitive.
    """
    pool = string.ascii_lowercase + string.digits
    picked = [random.choice(pool) for _ in range(9)]
    return "".join(picked)
async def run(context, max_token, temperature, top_p, addr, port, fn_index=12):
    """Stream a completion from a text-generation-webui server over its queue websocket.

    Connects to ``ws://{addr}:{port}/queue/join`` and answers the server's
    ``send_hash`` / ``send_data`` requests. It then yields generated text
    until a ``process_completed`` message arrives.

    Args:
        context: prompt text; sent as the first element of the ``data`` payload.
        max_token: value sent as ``max_new_tokens``.
        temperature: sampling temperature, forwarded unchanged.
        top_p: nucleus-sampling parameter, forwarded unchanged.
        addr: server host name or address.
        port: server port.
        fn_index: server-side function index used in both the hash and data
            messages. Defaults to 12, the value this bridge has always used.
            NOTE(review): the correct index depends on the server's version/UI
            layout; confirm it against the running webui.

    Yields:
        ``content["output"]["data"][0]`` for every ``process_generating`` and
        ``process_completed`` message. Callers in this file treat each value
        as the full text generated so far, not as a delta.
    """
    params = {
        'max_new_tokens': max_token,
        'do_sample': True,
        'temperature': temperature,
        'top_p': top_p,
        'typical_p': 1,
        'repetition_penalty': 1.05,
        'encoder_repetition_penalty': 1.0,
        'top_k': 0,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': True,
        'seed': -1,
    }
    session = random_hash()
    async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
        # An empty (falsy) JSON message ends the loop.
        while content := json.loads(await websocket.recv()):
            msg = content["msg"]
            if msg == "send_hash":
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": fn_index,
                }))
            elif msg == "send_data":
                # The server expects positional arguments: the prompt followed by
                # every entry of `params` in insertion order (dicts preserve it).
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": fn_index,
                    "data": [context, *params.values()],
                }))
            elif msg in ("process_generating", "process_completed"):
                yield content["output"]["data"][0]
                # A caller can also stop generation early by abandoning this generator.
                if msg == "process_completed":
                    break
            # Other messages ("estimation", "process_starts", ...) are ignored.
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream = True, additional_fn=None):
    """Send a chat turn to a text-generation-webui model and stream the reply into the UI.

    This is the basic chat entry point. It is a generator: every ``yield``
    comes from ``update_ui`` and refreshes the web interface.

    Args:
        inputs: the user's message for this turn.
        llm_kwargs: model settings. Reads ``'llm_model'`` (e.g.
            ``"tgui:galactica-1.3b@localhost:7860"``), ``'max_length'``,
            ``'temperature'`` and ``'top_p'``.
        plugin_kwargs: not used by this function.
        chatbot: list of displayed dialogue entries. An ``[inputs, ""]`` entry
            is appended, and its last element is replaced as the reply streams in.
        history: list of previous turns, mutated in place: ``[inputs, ""]`` is
            appended and the last entry is overwritten with the reply. Very long
            inputs/history may overflow the model's token limit. Defaults to a
            new empty list (previously a shared mutable default).
        system_prompt: not used by this function.
        stream: not used by this function.
        additional_fn: key of the button that was clicked (see
            ``core_functional``). When given, that entry's optional
            ``"PreProcess"`` function and its ``"Prefix"``/``"Suffix"`` are
            applied to ``inputs``.

    Raises:
        AssertionError: if the address part of ``llm_model`` has no ``':'``.
        ValueError: if ``llm_model`` does not split into exactly one model
            part and one ``addr:port`` part.
    """
    if history is None:
        # A fresh list per call: the old `history=[]` default was shared
        # between calls and grew every time it was extended below.
        history = []

    if additional_fn is not None:
        import core_functional
        importlib.reload(core_functional)  # hot-reload prompt definitions
        button = core_functional.get_core_functions()[additional_fn]
        if "PreProcess" in button:
            inputs = button["PreProcess"](inputs)  # optional preprocessing hook
        inputs = button["Prefix"] + inputs + button["Suffix"]

    prompt = "What I would like to say is the following: " + inputs
    history.extend([inputs, ""])
    chatbot.append([inputs, ""])
    yield from update_ui(chatbot=chatbot, history=history, msg="等待响应")  # refresh UI

    _, addr_port = llm_kwargs['llm_model'].split('@')
    assert ':' in addr_port, "LLM_MODEL 格式不正确!" + llm_kwargs['llm_model']
    addr, port = addr_port.split(':')

    # mutable[0]: latest text written by the listener thread.
    # mutable[1]: heartbeat timestamp refreshed by this generator. The listener
    #             stops when it has not been refreshed for more than 3 s
    #             (e.g. the UI stopped consuming this generator).
    mutable = ["", time.time()]

    def run_coroutine(mutable):
        async def get_result(mutable):
            async for response in run(context=prompt, max_token=llm_kwargs['max_length'],
                                      temperature=llm_kwargs['temperature'],
                                      top_p=llm_kwargs['top_p'], addr=addr, port=port):
                print(response[len(mutable[0]):])  # echo only the new suffix
                mutable[0] = response
                if (time.time() - mutable[1]) > 3:
                    print('exit when no listener')
                    break
        # asyncio.run gets its own event loop on this worker thread.
        asyncio.run(get_result(mutable))

    thread_listen = threading.Thread(target=run_coroutine, args=(mutable,), daemon=True)
    thread_listen.start()

    tgui_say = ""
    listener_done = False
    while not listener_done:
        # Sample liveness BEFORE syncing. Once the thread is seen as finished,
        # the sync below is guaranteed to include its final output. The old
        # `while is_alive()` loop could exit with the last chunk unshown.
        listener_done = not thread_listen.is_alive()
        if not listener_done:
            time.sleep(1)
            mutable[1] = time.time()  # heartbeat for the listener thread
        if tgui_say != mutable[0]:
            tgui_say = mutable[0]
            history[-1] = tgui_say
            chatbot[-1] = (history[-2], history[-1])
            yield from update_ui(chatbot=chatbot, history=history)  # refresh UI
def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, observe_window=None, console_slience=False):
    """Query a text-generation-webui model without the UI and return its reply.

    Blocks until generation finishes or the watchdog stops it. The previous
    version started the worker thread and returned ``observe_window[0]``
    immediately, i.e. before any text had arrived.

    Args:
        inputs: the user's message; prefixed with
            ``"What I would like to say is the following: "``.
        llm_kwargs: reads ``'llm_model'`` (``"<name>@<addr>:<port>"``),
            ``'max_length'``, ``'temperature'`` and ``'top_p'``.
        history: not used by this function.
        sys_prompt: not used by this function.
        observe_window: optional shared list. Index 0 receives the text
            generated so far. Index 1 is a timestamp the caller keeps
            refreshing; if it is more than 5 s old when a new chunk arrives,
            generation is abandoned. When ``None``, a private window is used
            and this watchdog is disabled.
        console_slience: when True, do not print streamed chunks to stdout.

    Returns:
        The generated text. It may be partial if the watchdog fired or the
        connection failed (exceptions in the worker thread are reported by
        ``threading``'s default excepthook, as before).

    Raises:
        AssertionError: if the address part of ``llm_model`` has no ``':'``.
        ValueError: if ``llm_model`` does not split into exactly one model
            part and one ``addr:port`` part.
    """
    prompt = "What I would like to say is the following: " + inputs
    _, addr_port = llm_kwargs['llm_model'].split('@')
    assert ':' in addr_port, "LLM_MODEL 格式不正确!" + llm_kwargs['llm_model']
    addr, port = addr_port.split(':')

    use_watchdog = observe_window is not None
    if not use_watchdog:
        observe_window = ["", time.time()]

    async def get_result():
        async for response in run(context=prompt, max_token=llm_kwargs['max_length'],
                                  temperature=llm_kwargs['temperature'],
                                  top_p=llm_kwargs['top_p'], addr=addr, port=port):
            if not console_slience:
                print(response[len(observe_window[0]):])  # echo only the new suffix
            observe_window[0] = response
            if use_watchdog and (time.time() - observe_window[1]) > 5:
                print('exit when no listener')
                break

    # Run the coroutine on a dedicated thread so asyncio.run never collides
    # with an event loop that may already be running in the caller's thread,
    # then wait for it so the complete reply is returned.
    thread_listen = threading.Thread(target=lambda: asyncio.run(get_result()))
    thread_listen.start()
    thread_listen.join()
    return observe_window[0]
|