Upload folder using huggingface_hub
Browse files- ChatWorld/ChatWorld.py +6 -8
- ChatWorld/NaiveDB.py +5 -2
- ChatWorld/models.py +8 -5
- app.py +31 -5
ChatWorld/ChatWorld.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from jinja2 import Template
|
2 |
import torch
|
3 |
|
4 |
-
from .models import GLM
|
5 |
|
6 |
from .NaiveDB import NaiveDB
|
7 |
from .utils import *
|
@@ -19,7 +19,7 @@ class ChatWorld:
|
|
19 |
|
20 |
self.history = []
|
21 |
|
22 |
-
self.client =
|
23 |
self.model = GLM()
|
24 |
self.db = NaiveDB()
|
25 |
self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
|
@@ -81,17 +81,15 @@ class ChatWorld:
|
|
81 |
return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
|
82 |
|
83 |
def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
|
|
|
|
|
84 |
message = [self.getSystemPrompt(text,
|
85 |
-
user_role_name, user_role_nick_name)
|
86 |
-
print(message)
|
87 |
if use_local_model:
|
88 |
response = self.model.get_response(message)
|
89 |
else:
|
90 |
-
response = self.client.chat(
|
91 |
-
user_role_name, text, user_role_nick_name)
|
92 |
|
93 |
-
self.history.append(
|
94 |
-
{"role": "user", "content": f"{user_role_name}:「{text}」"})
|
95 |
self.history.append(
|
96 |
{"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
|
97 |
return response
|
|
|
1 |
from jinja2 import Template
|
2 |
import torch
|
3 |
|
4 |
+
from .models import GLM, GLM_api
|
5 |
|
6 |
from .NaiveDB import NaiveDB
|
7 |
from .utils import *
|
|
|
19 |
|
20 |
self.history = []
|
21 |
|
22 |
+
self.client = GLM_api()
|
23 |
self.model = GLM()
|
24 |
self.db = NaiveDB()
|
25 |
self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
|
|
|
81 |
return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
|
82 |
|
83 |
def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
|
84 |
+
self.history.append(
|
85 |
+
{"role": "user", "content": f"{user_role_name}:「{text}」"})
|
86 |
message = [self.getSystemPrompt(text,
|
87 |
+
user_role_name, user_role_nick_name), {"role": "user", "content": f"{user_role_name}:「{text}」"}]
|
|
|
88 |
if use_local_model:
|
89 |
response = self.model.get_response(message)
|
90 |
else:
|
91 |
+
response = self.client.chat(message)
|
|
|
92 |
|
|
|
|
|
93 |
self.history.append(
|
94 |
{"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
|
95 |
return response
|
ChatWorld/NaiveDB.py
CHANGED
@@ -81,7 +81,10 @@ class NaiveDB:
|
|
81 |
similarities.sort(key=lambda x: x[0], reverse=True)
|
82 |
self.last_search_ids = [x[1] for x in similarities[:n_results]]
|
83 |
|
84 |
-
|
|
|
|
|
85 |
|
86 |
-
top_stories = [self.stories[
|
|
|
87 |
return top_stories
|
|
|
81 |
similarities.sort(key=lambda x: x[0], reverse=True)
|
82 |
self.last_search_ids = [x[1] for x in similarities[:n_results]]
|
83 |
|
84 |
+
stories_length = len(self.stories)
|
85 |
+
search_id_range = [(max(0, i-3), min(i+4, stories_length))
|
86 |
+
for i in self.last_search_ids]
|
87 |
|
88 |
+
top_stories = ["\n".join(self.stories[start:end+1])
|
89 |
+
for start, end in search_id_range]
|
90 |
return top_stories
|
ChatWorld/models.py
CHANGED
@@ -40,7 +40,7 @@ class GLM():
|
|
40 |
|
41 |
self.client = client.eval()
|
42 |
|
43 |
-
def message2query(messages) -> str:
|
44 |
# [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
|
45 |
# <|system|>
|
46 |
# You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
@@ -53,7 +53,9 @@ class GLM():
|
|
53 |
return "".join([template.substitute(message) for message in messages])
|
54 |
|
55 |
def get_response(self, message):
|
56 |
-
response, history = self.client.chat(
|
|
|
|
|
57 |
return response
|
58 |
|
59 |
|
@@ -62,7 +64,8 @@ class GLM_api:
|
|
62 |
self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
|
63 |
self.model = model_name
|
64 |
|
65 |
-
def
|
|
|
66 |
response = self.client.chat.completions.create(
|
67 |
-
model=self.model,
|
68 |
-
return response.choices[0].message
|
|
|
40 |
|
41 |
self.client = client.eval()
|
42 |
|
43 |
+
def message2query(self, messages) -> str:
|
44 |
# [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
|
45 |
# <|system|>
|
46 |
# You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
|
|
53 |
return "".join([template.substitute(message) for message in messages])
|
54 |
|
55 |
def get_response(self, message):
|
56 |
+
response, history = self.client.chat(
|
57 |
+
self.tokenizer, self.message2query(message))
|
58 |
+
print(self.message2query(message))
|
59 |
return response
|
60 |
|
61 |
|
|
|
64 |
self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
|
65 |
self.model = model_name
|
66 |
|
67 |
+
def chat(self, message):
|
68 |
+
print(message)
|
69 |
response = self.client.chat.completions.create(
|
70 |
+
model=self.model, messages=message)
|
71 |
+
return response.choices[0].message.content
|
app.py
CHANGED
@@ -11,6 +11,8 @@ logging.basicConfig(level=logging.INFO, filename="demo.log", filemode="w",
|
|
11 |
|
12 |
chatWorld = ChatWorld()
|
13 |
|
|
|
|
|
14 |
|
15 |
def getContent(input_file):
|
16 |
# 读取文件内容
|
@@ -31,33 +33,57 @@ def getContent(input_file):
|
|
31 |
role_name_list = [i for i in role_name_set if i != ""]
|
32 |
logging.info(f"role_name_list: {role_name_list}")
|
33 |
|
|
|
|
|
|
|
34 |
return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
|
35 |
|
36 |
|
37 |
def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
|
|
38 |
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
39 |
response = chatWorld.chat(message,
|
40 |
role_name, role_nickname, use_local_model=True)
|
41 |
return response
|
42 |
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
with gr.Blocks() as demo:
|
45 |
|
46 |
upload_c = gr.File(label="上传文档文件")
|
47 |
|
48 |
with gr.Row():
|
49 |
-
model_role_name = gr.Radio(
|
50 |
model_role_nickname = gr.Textbox(label="模型角色昵称")
|
51 |
|
52 |
with gr.Row():
|
53 |
-
role_name = gr.Radio(
|
54 |
role_nickname = gr.Textbox(label="角色昵称")
|
55 |
|
56 |
upload_c.upload(fn=getContent, inputs=upload_c,
|
57 |
outputs=[model_role_name, role_name])
|
58 |
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
61 |
|
62 |
|
63 |
-
demo.launch(
|
|
|
11 |
|
12 |
chatWorld = ChatWorld()
|
13 |
|
14 |
+
role_name_list_global = None
|
15 |
+
|
16 |
|
17 |
def getContent(input_file):
|
18 |
# 读取文件内容
|
|
|
33 |
role_name_list = [i for i in role_name_set if i != ""]
|
34 |
logging.info(f"role_name_list: {role_name_list}")
|
35 |
|
36 |
+
global role_name_list_global
|
37 |
+
role_name_list_global = role_name_list
|
38 |
+
|
39 |
return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
|
40 |
|
41 |
|
42 |
def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
43 |
+
print(f"history: {history}")
|
44 |
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
45 |
response = chatWorld.chat(message,
|
46 |
role_name, role_nickname, use_local_model=True)
|
47 |
return response
|
48 |
|
49 |
|
50 |
+
def submit_message_api(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
51 |
+
print(f"history: {history}")
|
52 |
+
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
53 |
+
response = chatWorld.chat(message,
|
54 |
+
role_name, role_nickname, use_local_model=False)
|
55 |
+
return response
|
56 |
+
|
57 |
+
|
58 |
+
def get_role_list():
|
59 |
+
global role_name_list_global
|
60 |
+
if role_name_list_global:
|
61 |
+
return role_name_list_global
|
62 |
+
else:
|
63 |
+
return []
|
64 |
+
|
65 |
+
|
66 |
with gr.Blocks() as demo:
|
67 |
|
68 |
upload_c = gr.File(label="上传文档文件")
|
69 |
|
70 |
with gr.Row():
|
71 |
+
model_role_name = gr.Radio(get_role_list(), label="模型角色名")
|
72 |
model_role_nickname = gr.Textbox(label="模型角色昵称")
|
73 |
|
74 |
with gr.Row():
|
75 |
+
role_name = gr.Radio(get_role_list(), label="角色名")
|
76 |
role_nickname = gr.Textbox(label="角色昵称")
|
77 |
|
78 |
upload_c.upload(fn=getContent, inputs=upload_c,
|
79 |
outputs=[model_role_name, role_name])
|
80 |
|
81 |
+
with gr.Row():
|
82 |
+
chatBox_local = gr.ChatInterface(
|
83 |
+
submit_message, chatbot=gr.Chatbot(height=400, label="本地模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
|
84 |
+
|
85 |
+
chatBox_api = gr.ChatInterface(
|
86 |
+
submit_message_api, chatbot=gr.Chatbot(height=400, label="API模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
|
87 |
|
88 |
|
89 |
+
demo.launch(server_name="0.0.0.0")
|