ChenyuRabbitLove commited on
Commit
e4c798e
·
1 Parent(s): abab449

feat: add summerizer map-reduce

Browse files
Files changed (2) hide show
  1. utils/chatbot_diff.py +249 -0
  2. utils/gpt_processor.py +62 -21
utils/chatbot_diff.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import json
4
+ import logging
5
+ import secrets
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import openai
10
+ import pandas as pd
11
+ from google.oauth2.service_account import Credentials
12
+ from googleapiclient.discovery import build
13
+ from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
14
+ from openai.embeddings_utils import distances_from_embeddings
15
+
16
+ from .gpt_processor import QuestionAnswerer
17
+ from .work_flow_controller import WorkFlowController
18
+
19
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
+ openai.api_key = OPENAI_API_KEY
21
+
22
+
23
+ class Chatbot:
24
+ def __init__(self):
25
+ self.history = []
26
+ self.upload_state = "waiting"
27
+ self.uid = self.__generate_uid()
28
+
29
+ self.g_drive_service = self.__init_drive_service()
30
+ self.knowledge_base = None
31
+ self.context = None
32
+ self.context_page_num = None
33
+ self.context_file_name = None
34
+
35
+ def build_knowledge_base(self, files, upload_mode="once"):
36
+ work_flow_controller = WorkFlowController(files, self.uid)
37
+ self.csv_result_path = work_flow_controller.csv_result_path
38
+ self.json_result_path = work_flow_controller.json_result_path
39
+
40
+ if upload_mode == "Upload to Database":
41
+ self.__get_db_knowledge_base()
42
+ else:
43
+ self.__get_local_knowledge_base()
44
+
45
+ def __get_db_knowledge_base(self):
46
+ filename = "knowledge_base.csv"
47
+ db = self.__read_db(self.g_drive_service)
48
+ cur_content = pd.read_csv(self.csv_result_path)
49
+ for _ in range(10):
50
+ try:
51
+ self.__write_into_db(self.g_drive_service, db, cur_content)
52
+ break
53
+ except Exception as e:
54
+ logging.error(e)
55
+ logging.error("Failed to upload to database, retrying...")
56
+ continue
57
+ self.knowledge_base = db
58
+ self.upload_state = "done"
59
+
60
+ def __get_local_knowledge_base(self):
61
+ with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
62
+ knowledge_base = pd.read_csv(fp)
63
+ knowledge_base["page_embedding"] = (
64
+ knowledge_base["page_embedding"].apply(eval).apply(np.array)
65
+ )
66
+
67
+ self.knowledge_base = knowledge_base
68
+ self.upload_state = "done"
69
+
70
+ def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
71
+ db = pd.concat([db, cur_content], ignore_index=True)
72
+ db.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
73
+ media = MediaFileUpload(f"{self.uid}_knowledge_base.csv", resumable=True)
74
+ request = (
75
+ service.files()
76
+ .update(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media)
77
+ .execute()
78
+ )
79
+
80
+ def __init_drive_service(self):
81
+ SCOPES = ["https://www.googleapis.com/auth/drive"]
82
+ SERVICE_ACCOUNT_INFO = os.getenv("CREDENTIALS")
83
+ service_account_info_dict = json.loads(SERVICE_ACCOUNT_INFO)
84
+
85
+ creds = Credentials.from_service_account_info(
86
+ service_account_info_dict, scopes=SCOPES
87
+ )
88
+
89
+ return build("drive", "v3", credentials=creds)
90
+
91
+ def __read_db(self, service):
92
+ request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
93
+ fh = io.BytesIO()
94
+ downloader = MediaIoBaseDownload(fh, request)
95
+
96
+ done = False
97
+ while done is False:
98
+ status, done = downloader.next_chunk()
99
+ print(f"Download {int(status.progress() * 100)}%.")
100
+
101
+ fh.seek(0)
102
+
103
+ return pd.read_csv(fh)
104
+
105
+ def __read_file(self, service, filename) -> pd.DataFrame:
106
+ query = f"name='{filename}'"
107
+ results = service.files().list(q=query).execute()
108
+ files = results.get("files", [])
109
+
110
+ file_id = files[0]["id"]
111
+
112
+ request = service.files().get_media(fileId=file_id)
113
+ fh = io.BytesIO()
114
+ downloader = MediaIoBaseDownload(fh, request)
115
+
116
+ done = False
117
+ while done is False:
118
+ status, done = downloader.next_chunk()
119
+ print(f"Download {int(status.progress() * 100)}%.")
120
+
121
+ fh.seek(0)
122
+
123
+ return pd.read_csv(fh)
124
+
125
+ def __upload_file(self, service):
126
+ results = service.files().list(pageSize=10).execute()
127
+ items = results.get("files", [])
128
+ if not items:
129
+ print("No files found.")
130
+ else:
131
+ print("Files:")
132
+ for item in items:
133
+ print(f"{item['name']} ({item['id']})")
134
+
135
+ media = MediaFileUpload(self.csv_result_path, resumable=True)
136
+ filename_prefix = "ex_bot_database_"
137
+ filename = filename_prefix + self.uid + ".csv"
138
+ request = (
139
+ service.files()
140
+ .create(
141
+ media_body=media,
142
+ body={
143
+ "name": filename,
144
+ "parents": [
145
+ "1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"
146
+ ],
147
+ },
148
+ )
149
+ .execute()
150
+ )
151
+
152
+ def clear_state(self):
153
+ self.context = None
154
+ self.context_page_num = None
155
+ self.context_file_name = None
156
+ self.knowledge_base = None
157
+ self.upload_state = "waiting"
158
+ self.history = []
159
+
160
+ def send_system_notification(self):
161
+ if self.upload_state == "waiting":
162
+ conversation = [["已上傳文件", "文件處理中(摘要、翻譯等),結束後將自動回覆"]]
163
+ return conversation
164
+ elif self.upload_state == "done":
165
+ conversation = [["已上傳文件", "文件處理完成,請開始提問"]]
166
+ return conversation
167
+
168
+ def change_md(self):
169
+ content = self.__construct_summary()
170
+ return gr.Markdown.update(content, visible=True)
171
+
172
+ def __construct_summary(self):
173
+ with open(self.json_result_path, "r", encoding="UTF-8") as fp:
174
+ knowledge_base = json.load(fp)
175
+
176
+ context = ""
177
+ for key in knowledge_base.keys():
178
+ file_name = knowledge_base[key]["file_name"]
179
+ total_page = knowledge_base[key]["total_pages"]
180
+ summary = knowledge_base[key]["summarized_content"]
181
+ file_context = f"""
182
+ ### 文件摘要
183
+ {file_name} (共 {total_page} 頁)<br><br>
184
+ {summary}<br><br>
185
+ """
186
+ context += file_context
187
+ return context
188
+
189
+ def user(self, message):
190
+ self.history += [[message, None]]
191
+ return "", self.history
192
+
193
+ def bot(self):
194
+ user_message = self.history[-1][0]
195
+ print(f"user_message: {user_message}")
196
+
197
+ if self.knowledge_base is None:
198
+ response = [
199
+ [user_message, "請先上傳文件"],
200
+ ]
201
+ self.history = response
202
+ return self.history
203
+
204
+ else:
205
+ self.__get_index_file(user_message)
206
+ if self.context is None:
207
+ response = [
208
+ [user_message, "無法找到相關文件,請重新提問"],
209
+ ]
210
+ self.history = response
211
+ return self.history
212
+ else:
213
+ qa_processor = QuestionAnswerer()
214
+ bot_message = qa_processor.answer_question(
215
+ self.context,
216
+ self.context_page_num,
217
+ self.context_file_name,
218
+ self.history,
219
+ )
220
+ print(f"bot_message: {bot_message}")
221
+ response = [
222
+ [user_message, bot_message],
223
+ ]
224
+ self.history[-1] = response[0]
225
+ return self.history
226
+
227
+ def __get_index_file(self, user_message):
228
+ user_message_embedding = openai.Embedding.create(
229
+ input=user_message, engine="text-embedding-ada-002"
230
+ )["data"][0]["embedding"]
231
+
232
+ self.knowledge_base["distance"] = distances_from_embeddings(
233
+ user_message_embedding,
234
+ self.knowledge_base["page_embedding"].values,
235
+ distance_metric="cosine",
236
+ )
237
+ self.knowledge_base = self.knowledge_base.sort_values(
238
+ by="distance", ascending=True
239
+ )
240
+
241
+ if self.knowledge_base["distance"].values[0] > 0.2:
242
+ self.context = None
243
+ else:
244
+ self.context = self.knowledge_base["page_content"].values[0]
245
+ self.context_page_num = self.knowledge_base["page_num"].values[0]
246
+ self.context_file_name = self.knowledge_base["file_name"].values[0]
247
+
248
+ def __generate_uid(self):
249
+ return secrets.token_hex(8)
utils/gpt_processor.py CHANGED
@@ -24,38 +24,30 @@ class GPTAgent:
24
  response = self.agent.complete(messages=messages)
25
  return response.choices[0].message["content"]
26
 
27
- def split_into_many(self, text) -> List[str]:
28
  tokenizer = tiktoken.get_encoding("cl100k_base")
29
- # Split the text into sentences
30
- sentences = text.split("。")
31
 
32
- # Get the number of tokens for each sentence
33
  n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
34
 
35
  chunks = []
36
  tokens_so_far = 0
37
  chunk = []
38
 
39
- # Loop through the sentences and tokens joined together in a tuple
40
  for sentence, token in zip(sentences, n_tokens):
41
- # If the number of tokens so far plus the number of tokens in the current sentence is greater
42
- # than the max number of tokens, then add the chunk to the list of chunks and reset
43
- # the chunk and tokens so far
44
- if tokens_so_far + token > self.split_max_tokens:
45
  chunks.append("。".join(chunk) + "。")
46
  chunk = []
47
  tokens_so_far = 0
48
 
49
- # If the number of tokens in the current sentence is greater than the max number of
50
- # tokens, go to the next sentence
51
- if token > self.split_max_tokens:
52
- continue
53
-
54
- # Otherwise, add the sentence to the chunk and add the number of tokens to the total
55
  chunk.append(sentence)
56
  tokens_so_far += token + 1
57
 
58
- # if the length of the text is less than the max number of tokens, then return the text
 
59
  return [text] if len(chunks) == 0 else chunks
60
 
61
  def preprocess(self, text):
@@ -202,10 +194,59 @@ class Summarizer(GPTAgent):
202
  system_prompt = """
203
  請幫我總結以下的文章。
204
  """
205
- messages = [
206
- {"role": "system", "content": f"{system_prompt}"},
207
- {"role": "user", "content": text},
208
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  try:
210
  response = openai.ChatCompletion.create(
211
  model=self.model,
@@ -224,7 +265,7 @@ class Summarizer(GPTAgent):
224
  response["choices"][0]["message"]["content"]
225
  )
226
 
227
- return re.sub(r"\n+", "<br>", response)
228
 
229
 
230
  class QuestionAnswerer(GPTAgent):
 
24
  response = self.agent.complete(messages=messages)
25
  return response.choices[0].message["content"]
26
 
27
+ def split_into_many(text):
28
  tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
29
 
30
+ sentences = text.split("。")
31
  n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
32
 
33
  chunks = []
34
  tokens_so_far = 0
35
  chunk = []
36
 
 
37
  for sentence, token in zip(sentences, n_tokens):
38
+
39
+ if tokens_so_far + token > 500:
 
 
40
  chunks.append("。".join(chunk) + "。")
41
  chunk = []
42
  tokens_so_far = 0
43
 
44
+ if token > 500:
45
+ continue
 
 
 
 
46
  chunk.append(sentence)
47
  tokens_so_far += token + 1
48
 
49
+ chunks.append("。".join(chunk) + "。")
50
+
51
  return [text] if len(chunks) == 0 else chunks
52
 
53
  def preprocess(self, text):
 
194
  system_prompt = """
195
  請幫我總結以下的文章。
196
  """
197
+
198
+ text_chunks = self.split_into_many(text)
199
+ if len(text_chunks) > 1:
200
+ concated_summary = ""
201
+ for i in range(len(text_chunks)):
202
+ text_chunk = text[i].replace("\n", " ").replace("\r", "")
203
+ messages = [
204
+ {"role": "system", "content": f"{system_prompt}"},
205
+ {"role": "user", "content": text_chunk},
206
+ ]
207
+ try:
208
+ response = openai.ChatCompletion.create(
209
+ model=self.model,
210
+ messages=messages,
211
+ temperature=self.temperature,
212
+ max_tokens=self.max_tokens,
213
+ frequency_penalty=self.frequency_penalty,
214
+ presence_penalty=self.presence_penalty,
215
+ )
216
+ except Exception as e:
217
+ logging.error(e)
218
+ logging.error("Failed to summarize text_chunk")
219
+ chinese_converter = OpenCC("s2tw")
220
+ concated_summary += chinese_converter.convert(
221
+ response["choices"][0]["message"]["content"].strip()
222
+ )
223
+
224
+ # summarize concated_summary
225
+ messages = [
226
+ {"role": "system", "content": f"{system_prompt}"},
227
+ {"role": "user", "content": concated_summary},
228
+ ]
229
+ try:
230
+ response = openai.ChatCompletion.create(
231
+ model=self.model,
232
+ messages=messages,
233
+ temperature=self.temperature,
234
+ max_tokens=self.max_tokens,
235
+ frequency_penalty=self.frequency_penalty,
236
+ presence_penalty=self.presence_penalty,
237
+ )
238
+ except Exception as e:
239
+ logging.error(e)
240
+ logging.error("Failed to summarize concated_summary")
241
+ chinese_converter = OpenCC("s2tw")
242
+ return chinese_converter.convert(
243
+ response["choices"][0]["message"]["content"].strip()
244
+ )
245
+ else:
246
+ messages = [
247
+ {"role": "system", "content": f"{system_prompt}"},
248
+ {"role": "user", "content": text},
249
+ ]
250
  try:
251
  response = openai.ChatCompletion.create(
252
  model=self.model,
 
265
  response["choices"][0]["message"]["content"]
266
  )
267
 
268
+ return response
269
 
270
 
271
  class QuestionAnswerer(GPTAgent):