Spaces:
Runtime error
Runtime error
ChenyuRabbitLove
committed on
Commit
·
e4c798e
1
Parent(s):
abab449
feat: add summarizer map-reduce
Browse files- utils/chatbot_diff.py +249 -0
- utils/gpt_processor.py +62 -21
utils/chatbot_diff.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import secrets
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import openai
|
10 |
+
import pandas as pd
|
11 |
+
from google.oauth2.service_account import Credentials
|
12 |
+
from googleapiclient.discovery import build
|
13 |
+
from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
|
14 |
+
from openai.embeddings_utils import distances_from_embeddings
|
15 |
+
|
16 |
+
from .gpt_processor import QuestionAnswerer
|
17 |
+
from .work_flow_controller import WorkFlowController
|
18 |
+
|
19 |
+
# Read the OpenAI API key from the environment (None when unset) and hand it
# to the openai SDK so later API calls in this module are authenticated.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY
|
21 |
+
|
22 |
+
|
23 |
+
class Chatbot:
    """Chat session that answers questions from an embedded knowledge base.

    The knowledge base is a pandas DataFrame with (at least) the columns
    ``page_embedding``, ``page_content``, ``page_num`` and ``file_name``.
    It is built either from the files processed in this session only, or by
    merging them into a shared CSV stored on Google Drive.
    """

    # Google Drive file holding the shared knowledge-base CSV.
    _DB_FILE_ID = "1m3ozrphHP221hhdCFMFX9-10nzSDfNyW"
    # Google Drive folder that receives per-session CSV uploads.
    _UPLOAD_FOLDER_ID = "1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"
    # A page is used as context only if its cosine distance to the question
    # does not exceed this value.
    _MAX_CONTEXT_DISTANCE = 0.2
    # How many times the upload to the shared database is attempted.
    _UPLOAD_ATTEMPTS = 10

    def __init__(self):
        """Create an empty session and connect to Google Drive."""
        self.history = []
        self.upload_state = "waiting"
        self.uid = self.__generate_uid()

        self.g_drive_service = self.__init_drive_service()
        self.knowledge_base = None
        self.context = None
        self.context_page_num = None
        self.context_file_name = None
        # Filled in by build_knowledge_base(); None until files are processed.
        self.csv_result_path = None
        self.json_result_path = None

    def build_knowledge_base(self, files, upload_mode="once"):
        """Process ``files`` and load the resulting knowledge base.

        Args:
            files: Uploaded files, passed straight to ``WorkFlowController``.
            upload_mode: ``"Upload to Database"`` merges the result into the
                shared Drive CSV; any other value keeps it local.
        """
        work_flow_controller = WorkFlowController(files, self.uid)
        self.csv_result_path = work_flow_controller.csv_result_path
        self.json_result_path = work_flow_controller.json_result_path

        if upload_mode == "Upload to Database":
            self.__get_db_knowledge_base()
        else:
            self.__get_local_knowledge_base()

    @staticmethod
    def __parse_embeddings(frame):
        """Return a copy of ``frame`` with ``page_embedding`` as numpy arrays.

        The CSV stores each embedding as the text of a Python list.
        ``ast.literal_eval`` replaces the previous ``eval`` call: it accepts
        the same list-of-floats text but cannot execute arbitrary code that
        might be embedded in a CSV cell.
        """
        import ast

        parsed = frame.copy()
        parsed["page_embedding"] = (
            parsed["page_embedding"].apply(ast.literal_eval).apply(np.array)
        )
        return parsed

    def __get_db_knowledge_base(self):
        """Merge this session's pages into the shared database and load it."""
        db = self.__read_db(self.g_drive_service)
        cur_content = pd.read_csv(self.csv_result_path)

        merged = None
        for _ in range(self._UPLOAD_ATTEMPTS):
            try:
                merged = self.__write_into_db(self.g_drive_service, db, cur_content)
                break
            except Exception as e:
                logging.error(e)
                logging.error("Failed to upload to database, retrying...")
        else:
            logging.error(
                "Giving up on database upload after %d attempts",
                self._UPLOAD_ATTEMPTS,
            )

        if merged is None:
            # Upload never succeeded: still answer from the merged data
            # in memory so the current session stays usable.
            merged = pd.concat([db, cur_content], ignore_index=True)

        # The in-memory knowledge base must contain the new pages and have
        # its embeddings parsed, otherwise distance computation cannot work.
        self.knowledge_base = self.__parse_embeddings(merged)
        self.upload_state = "done"

    def __get_local_knowledge_base(self):
        """Load the knowledge base from this session's CSV only."""
        with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
            knowledge_base = pd.read_csv(fp)

        self.knowledge_base = self.__parse_embeddings(knowledge_base)
        self.upload_state = "done"

    def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
        """Append ``cur_content`` to ``db`` and overwrite the shared Drive CSV.

        Returns:
            The merged DataFrame that was uploaded.
        """
        merged = pd.concat([db, cur_content], ignore_index=True)
        local_copy = f"{self.uid}_knowledge_base.csv"
        merged.to_csv(local_copy, index=False)
        media = MediaFileUpload(local_copy, resumable=True)
        service.files().update(fileId=self._DB_FILE_ID, media_body=media).execute()
        return merged

    def __init_drive_service(self):
        """Build a Drive v3 client from the ``CREDENTIALS`` environment variable.

        Raises:
            RuntimeError: If ``CREDENTIALS`` is unset or empty.
        """
        scopes = ["https://www.googleapis.com/auth/drive"]
        service_account_info = os.getenv("CREDENTIALS")
        if not service_account_info:
            raise RuntimeError("CREDENTIALS environment variable is not set")

        creds = Credentials.from_service_account_info(
            json.loads(service_account_info), scopes=scopes
        )

        return build("drive", "v3", credentials=creds)

    def __download_csv(self, service, file_id) -> pd.DataFrame:
        """Download the Drive file ``file_id`` and parse it as CSV."""
        request = service.files().get_media(fileId=file_id)
        fh = io.BytesIO()
        downloader = MediaIoBaseDownload(fh, request)

        done = False
        while not done:
            status, done = downloader.next_chunk()
            print(f"Download {int(status.progress() * 100)}%.")

        fh.seek(0)

        return pd.read_csv(fh)

    def __read_db(self, service):
        """Download the shared knowledge-base CSV as a DataFrame."""
        return self.__download_csv(service, self._DB_FILE_ID)

    def __read_file(self, service, filename) -> pd.DataFrame:
        """Download the first Drive file named ``filename`` as a DataFrame.

        Raises:
            FileNotFoundError: If no file with that name is listed.
        """
        # Backslashes and single quotes must be escaped inside a Drive query
        # string literal, otherwise a name containing them breaks the query.
        escaped = filename.replace("\\", "\\\\").replace("'", "\\'")
        query = f"name='{escaped}'"
        results = service.files().list(q=query).execute()
        files = results.get("files", [])
        if not files:
            raise FileNotFoundError(f"No file named {filename!r} found on Drive")

        return self.__download_csv(service, files[0]["id"])

    def __upload_file(self, service):
        """Print up to ten visible Drive files, then upload this session's CSV."""
        results = service.files().list(pageSize=10).execute()
        items = results.get("files", [])
        if not items:
            print("No files found.")
        else:
            print("Files:")
            for item in items:
                print(f"{item['name']} ({item['id']})")

        media = MediaFileUpload(self.csv_result_path, resumable=True)
        filename_prefix = "ex_bot_database_"
        filename = filename_prefix + self.uid + ".csv"
        service.files().create(
            media_body=media,
            body={
                "name": filename,
                "parents": [self._UPLOAD_FOLDER_ID],
            },
        ).execute()

    def clear_state(self):
        """Forget the knowledge base, the selected context and the chat history."""
        self.context = None
        self.context_page_num = None
        self.context_file_name = None
        self.knowledge_base = None
        self.upload_state = "waiting"
        self.history = []

    def send_system_notification(self):
        """Return a one-turn conversation describing the upload state.

        Returns ``None`` for any state other than ``"waiting"`` or ``"done"``.
        """
        if self.upload_state == "waiting":
            conversation = [["已上傳文件", "文件處理中(摘要、翻譯等),結束後將自動回覆"]]
            return conversation
        elif self.upload_state == "done":
            conversation = [["已上傳文件", "文件處理完成,請開始提問"]]
            return conversation

    def change_md(self):
        """Return a Gradio Markdown update showing the per-file summaries."""
        content = self.__construct_summary()
        return gr.Markdown.update(content, visible=True)

    def __construct_summary(self):
        """Build the Markdown summary text from the JSON result file.

        Returns an empty string when no files have been processed yet.
        """
        if self.json_result_path is None:
            return ""

        with open(self.json_result_path, "r", encoding="UTF-8") as fp:
            knowledge_base = json.load(fp)

        sections = []
        for entry in knowledge_base.values():
            file_name = entry["file_name"]
            total_page = entry["total_pages"]
            summary = entry["summarized_content"]
            sections.append(
                "\n### 文件摘要\n"
                f"{file_name} (共 {total_page} 頁)<br><br>\n"
                f"{summary}<br><br>\n"
            )
        return "".join(sections)

    def user(self, message):
        """Append ``message`` as a pending turn; return ``("", history)``."""
        self.history += [[message, None]]
        return "", self.history

    def bot(self):
        """Answer the latest user message and return the updated history.

        Only the last turn is rewritten, so earlier turns are preserved even
        when no file is loaded or no relevant page is found.
        """
        if not self.history:
            # Nothing to answer yet.
            return self.history

        user_message = self.history[-1][0]
        print(f"user_message: {user_message}")

        if self.knowledge_base is None:
            bot_message = "請先上傳文件"
        else:
            self.__get_index_file(user_message)
            if self.context is None:
                bot_message = "無法找到相關文件,請重新提問"
            else:
                qa_processor = QuestionAnswerer()
                bot_message = qa_processor.answer_question(
                    self.context,
                    self.context_page_num,
                    self.context_file_name,
                    self.history,
                )
                print(f"bot_message: {bot_message}")

        self.history[-1] = [user_message, bot_message]
        return self.history

    def __get_index_file(self, user_message):
        """Select the page closest to ``user_message`` as the answer context.

        Sets ``context``, ``context_page_num`` and ``context_file_name`` to
        the best page, or resets all three to ``None`` when the knowledge
        base is empty or the best page is farther than
        ``_MAX_CONTEXT_DISTANCE``.
        """
        # Reset first so a miss never leaves values from a previous question.
        self.context = None
        self.context_page_num = None
        self.context_file_name = None

        if self.knowledge_base.empty:
            return

        user_message_embedding = openai.Embedding.create(
            input=user_message, engine="text-embedding-ada-002"
        )["data"][0]["embedding"]

        distances = distances_from_embeddings(
            user_message_embedding,
            self.knowledge_base["page_embedding"].values,
            distance_metric="cosine",
        )
        self.knowledge_base["distance"] = distances

        # Only the nearest page is needed, so take the minimum instead of
        # sorting the whole frame on every question.
        best = int(np.argmin(distances))
        if distances[best] > self._MAX_CONTEXT_DISTANCE:
            return

        self.context = self.knowledge_base["page_content"].values[best]
        self.context_page_num = self.knowledge_base["page_num"].values[best]
        self.context_file_name = self.knowledge_base["file_name"].values[best]

    def __generate_uid(self):
        """Return a random 16-character hexadecimal session id."""
        return secrets.token_hex(8)
|
utils/gpt_processor.py
CHANGED
@@ -24,38 +24,30 @@ class GPTAgent:
|
|
24 |
response = self.agent.complete(messages=messages)
|
25 |
return response.choices[0].message["content"]
|
26 |
|
27 |
-
def split_into_many(
|
28 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
29 |
-
# Split the text into sentences
|
30 |
-
sentences = text.split("。")
|
31 |
|
32 |
-
|
33 |
n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
|
34 |
|
35 |
chunks = []
|
36 |
tokens_so_far = 0
|
37 |
chunk = []
|
38 |
|
39 |
-
# Loop through the sentences and tokens joined together in a tuple
|
40 |
for sentence, token in zip(sentences, n_tokens):
|
41 |
-
|
42 |
-
|
43 |
-
# the chunk and tokens so far
|
44 |
-
if tokens_so_far + token > self.split_max_tokens:
|
45 |
chunks.append("。".join(chunk) + "。")
|
46 |
chunk = []
|
47 |
tokens_so_far = 0
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
if token > self.split_max_tokens:
|
52 |
-
continue
|
53 |
-
|
54 |
-
# Otherwise, add the sentence to the chunk and add the number of tokens to the total
|
55 |
chunk.append(sentence)
|
56 |
tokens_so_far += token + 1
|
57 |
|
58 |
-
|
|
|
59 |
return [text] if len(chunks) == 0 else chunks
|
60 |
|
61 |
def preprocess(self, text):
|
@@ -202,10 +194,59 @@ class Summarizer(GPTAgent):
|
|
202 |
system_prompt = """
|
203 |
請幫我總結以下的文章。
|
204 |
"""
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
try:
|
210 |
response = openai.ChatCompletion.create(
|
211 |
model=self.model,
|
@@ -224,7 +265,7 @@ class Summarizer(GPTAgent):
|
|
224 |
response["choices"][0]["message"]["content"]
|
225 |
)
|
226 |
|
227 |
-
return
|
228 |
|
229 |
|
230 |
class QuestionAnswerer(GPTAgent):
|
|
|
24 |
response = self.agent.complete(messages=messages)
|
25 |
return response.choices[0].message["content"]
|
26 |
|
27 |
+
@staticmethod
def split_into_many(text, max_tokens=500):
    """Split ``text`` into chunks of whole sentences under a token budget.

    Sentences are delimited by "。" and token counts come from the
    ``cl100k_base`` tiktoken encoding. Declared as a static method so that
    both ``self.split_into_many(text)`` and a call through the class work;
    the previous plain ``def split_into_many(text)`` raised ``TypeError``
    when invoked on an instance.

    Args:
        text: The text to split.
        max_tokens: Upper bound on the accumulated token count of a chunk.

    Returns:
        A list of chunks, each ending in "。". If no chunk could be built
        (for example, every sentence exceeds ``max_tokens``), ``[text]``.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")

    # Ignore empty fragments such as the one that follows a trailing "。";
    # keeping them produced doubled "。" in the re-joined chunk.
    sentences = [s for s in text.split("。") if s.strip()]
    n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]

    chunks = []
    tokens_so_far = 0
    chunk = []

    for sentence, token in zip(sentences, n_tokens):
        # A single sentence over the budget can never fit; it is skipped,
        # as it was before.
        if token > max_tokens:
            continue

        # Flush only a non-empty buffer, so no bare "。" chunk is emitted.
        if chunk and tokens_so_far + token > max_tokens:
            chunks.append("。".join(chunk) + "。")
            chunk = []
            tokens_so_far = 0

        chunk.append(sentence)
        tokens_so_far += token + 1

    # Emit whatever is left in the buffer after the last sentence.
    if chunk:
        chunks.append("。".join(chunk) + "。")

    return chunks if chunks else [text]
|
52 |
|
53 |
def preprocess(self, text):
|
|
|
194 |
system_prompt = """
|
195 |
請幫我總結以下的文章。
|
196 |
"""
|
197 |
+
|
198 |
+
text_chunks = self.split_into_many(text)
|
199 |
+
if len(text_chunks) > 1:
|
200 |
+
concated_summary = ""
|
201 |
+
for i in range(len(text_chunks)):
|
202 |
+
text_chunk = text[i].replace("\n", " ").replace("\r", "")
|
203 |
+
messages = [
|
204 |
+
{"role": "system", "content": f"{system_prompt}"},
|
205 |
+
{"role": "user", "content": text_chunk},
|
206 |
+
]
|
207 |
+
try:
|
208 |
+
response = openai.ChatCompletion.create(
|
209 |
+
model=self.model,
|
210 |
+
messages=messages,
|
211 |
+
temperature=self.temperature,
|
212 |
+
max_tokens=self.max_tokens,
|
213 |
+
frequency_penalty=self.frequency_penalty,
|
214 |
+
presence_penalty=self.presence_penalty,
|
215 |
+
)
|
216 |
+
except Exception as e:
|
217 |
+
logging.error(e)
|
218 |
+
logging.error("Failed to summarize text_chunk")
|
219 |
+
chinese_converter = OpenCC("s2tw")
|
220 |
+
concated_summary += chinese_converter.convert(
|
221 |
+
response["choices"][0]["message"]["content"].strip()
|
222 |
+
)
|
223 |
+
|
224 |
+
# summarize concated_summary
|
225 |
+
messages = [
|
226 |
+
{"role": "system", "content": f"{system_prompt}"},
|
227 |
+
{"role": "user", "content": concated_summary},
|
228 |
+
]
|
229 |
+
try:
|
230 |
+
response = openai.ChatCompletion.create(
|
231 |
+
model=self.model,
|
232 |
+
messages=messages,
|
233 |
+
temperature=self.temperature,
|
234 |
+
max_tokens=self.max_tokens,
|
235 |
+
frequency_penalty=self.frequency_penalty,
|
236 |
+
presence_penalty=self.presence_penalty,
|
237 |
+
)
|
238 |
+
except Exception as e:
|
239 |
+
logging.error(e)
|
240 |
+
logging.error("Failed to summarize concated_summary")
|
241 |
+
chinese_converter = OpenCC("s2tw")
|
242 |
+
return chinese_converter.convert(
|
243 |
+
response["choices"][0]["message"]["content"].strip()
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
messages = [
|
247 |
+
{"role": "system", "content": f"{system_prompt}"},
|
248 |
+
{"role": "user", "content": text},
|
249 |
+
]
|
250 |
try:
|
251 |
response = openai.ChatCompletion.create(
|
252 |
model=self.model,
|
|
|
265 |
response["choices"][0]["message"]["content"]
|
266 |
)
|
267 |
|
268 |
+
return response
|
269 |
|
270 |
|
271 |
class QuestionAnswerer(GPTAgent):
|