ChenyuRabbitLove commited on
Commit
c88c1d9
·
1 Parent(s): 1beaddf

feat/add g-drive coontection

Browse files
Files changed (3) hide show
  1. app.py +9 -4
  2. utils/chatbot.py +112 -7
  3. utils/work_flow_controller.py +12 -10
app.py CHANGED
@@ -30,7 +30,11 @@ with gr.Blocks() as demo:
30
 
31
  with gr.Row():
32
  index_file = gr.File(
33
- file_count="multiple", file_types=["pdf"], label="Upload PDF file"
 
 
 
 
34
  )
35
 
36
  with gr.Row():
@@ -42,7 +46,8 @@ with gr.Blocks() as demo:
42
  3. 可以根據下方的摘要內容來提問
43
  4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
44
  5. 要切換檢索的文件,請點選「清除」按鈕後再重新提問
45
- """
 
46
  )
47
 
48
  with gr.Row():
@@ -80,6 +85,7 @@ with gr.Blocks() as demo:
80
  **bot_args
81
  ).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
82
 
 
83
  # defining workflow of clear state
84
  clear_state_args = dict(
85
  fn=clear_state,
@@ -98,7 +104,7 @@ with gr.Blocks() as demo:
98
 
99
  bulid_knowledge_base_args = dict(
100
  fn=build_knowledge_base,
101
- inputs=[user_chatbot, index_file],
102
  outputs=None,
103
  )
104
 
@@ -118,6 +124,5 @@ with gr.Blocks() as demo:
118
 
119
  video_text_input.submit(video_bot, [test_video_chabot, video_text_input], video_text_output, api_name="video_bot")
120
 
121
-
122
  if __name__ == "__main__":
123
  demo.launch()
 
30
 
31
  with gr.Row():
32
  index_file = gr.File(
33
+ file_count="multiple", file_types=["pdf"], label="Upload PDF file", scale=3
34
+ )
35
+ upload_to_db = gr.CheckboxGroup(
36
+ ["Upload to Database"],
37
+ label="是否上傳至資料庫", info="將資料上傳至資料庫時,資料庫會自動建立索引,下次使用時可以直接檢索,預設為僅作這次使用", scale=1
38
  )
39
 
40
  with gr.Row():
 
46
  3. 可以根據下方的摘要內容來提問
47
  4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
48
  5. 要切換檢索的文件,請點選「清除」按鈕後再重新提問
49
+
50
+ """,
51
  )
52
 
53
  with gr.Row():
 
85
  **bot_args
86
  ).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
87
 
88
+
89
  # defining workflow of clear state
90
  clear_state_args = dict(
91
  fn=clear_state,
 
104
 
105
  bulid_knowledge_base_args = dict(
106
  fn=build_knowledge_base,
107
+ inputs=[user_chatbot, index_file, upload_to_db],
108
  outputs=None,
109
  )
110
 
 
124
 
125
  video_text_input.submit(video_bot, [test_video_chabot, video_text_input], video_text_output, api_name="video_bot")
126
 
 
127
  if __name__ == "__main__":
128
  demo.launch()
utils/chatbot.py CHANGED
@@ -1,31 +1,62 @@
1
- import json
2
  import os
 
 
 
3
 
 
 
4
  import openai
5
  import pandas as pd
6
- import numpy as np
7
- import gradio as gr
 
8
  from openai.embeddings_utils import distances_from_embeddings
9
 
10
- from .work_flow_controller import WorkFlowController
11
  from .gpt_processor import QuestionAnswerer
 
12
 
 
 
13
 
14
  class Chatbot:
15
  def __init__(self) -> None:
16
  self.history = []
17
  self.upload_state = "waiting"
 
18
 
 
19
  self.knowledge_base = None
20
  self.context = None
21
  self.context_page_num = None
22
  self.context_file_name = None
23
 
24
- def build_knowledge_base(self, files):
25
- work_flow_controller = WorkFlowController(files)
26
  self.csv_result_path = work_flow_controller.csv_result_path
27
  self.json_result_path = work_flow_controller.json_result_path
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
30
  knowledge_base = pd.read_csv(fp)
31
  knowledge_base["page_embedding"] = (
@@ -35,10 +66,81 @@ class Chatbot:
35
  self.knowledge_base = knowledge_base
36
  self.upload_state = "done"
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def clear_state(self):
39
  self.context = None
40
  self.context_page_num = None
41
  self.context_file_name = None
 
42
  self.upload_state = "waiting"
43
  self.history = []
44
 
@@ -130,9 +232,12 @@ class Chatbot:
130
  self.context_page_num = self.knowledge_base["page_num"].values[0]
131
  self.context_file_name = self.knowledge_base["file_name"].values[0]
132
 
 
 
 
 
133
  class VideoChatbot:
134
  def __init__(self) -> None:
135
- openai.api_key = os.getenv("OPENAI_API_KEY")
136
  self.metadata_keys = ["標題", "逐字稿", "摘要", "關鍵字"]
137
  self.metadata = {
138
  "c2fK-hxnPSY":{
 
1
+ import io
2
  import os
3
+ import json
4
+ import logging
5
+ import secrets
6
 
7
+ import gradio as gr
8
+ import numpy as np
9
  import openai
10
  import pandas as pd
11
+ from google.oauth2.service_account import Credentials
12
+ from googleapiclient.discovery import build
13
+ from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
14
  from openai.embeddings_utils import distances_from_embeddings
15
 
 
16
  from .gpt_processor import QuestionAnswerer
17
+ from .work_flow_controller import WorkFlowController
18
 
19
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
+ openai.api_key = OPENAI_API_KEY
21
 
22
  class Chatbot:
23
  def __init__(self) -> None:
24
  self.history = []
25
  self.upload_state = "waiting"
26
+ self.uid = self.__generate_uid()
27
 
28
+ self.g_drive_service = self.__init_drive_service()
29
  self.knowledge_base = None
30
  self.context = None
31
  self.context_page_num = None
32
  self.context_file_name = None
33
 
34
+ def build_knowledge_base(self, files, upload_mode="僅作這次使用"):
35
+ work_flow_controller = WorkFlowController(files, self.uid)
36
  self.csv_result_path = work_flow_controller.csv_result_path
37
  self.json_result_path = work_flow_controller.json_result_path
38
 
39
+ if upload_mode == "上傳至資料庫":
40
+ self.knowledge_base = self.__get_db_knowledge_base()
41
+ else:
42
+ self.knowledge_base = self.__get_local_knowledge_base()
43
+
44
+ def __get_db_knowledge_base(self):
45
+ filename = "knowledge_base.csv"
46
+ db = self.__read_db(self.g_drive_service)
47
+ cur_content = pd.read_csv(self.csv_result_path)
48
+ for _ in range(10):
49
+ try:
50
+ self.__write_into_db(self.g_drive_service, db, cur_content)
51
+ break
52
+ except Exception as e:
53
+ logging.error(e)
54
+ logging.error("Failed to upload to database, retrying...")
55
+ continue
56
+ self.knowledge_base = db
57
+ self.upload_state = "done"
58
+
59
+ def __get_local_knowledge_base(self):
60
  with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
61
  knowledge_base = pd.read_csv(fp)
62
  knowledge_base["page_embedding"] = (
 
66
  self.knowledge_base = knowledge_base
67
  self.upload_state = "done"
68
 
69
+ def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
70
+ # db = pd.concat([db, cur_content], ignore_index=True)
71
+ # db.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
72
+ cur_content.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
73
+ media = MediaFileUpload(f"{self.uid}_knowledge_base.csv", resumable=True)
74
+ request = service.files().update(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media).execute()
75
+
76
+ def __init_drive_service(self):
77
+ SCOPES = ['https://www.googleapis.com/auth/drive']
78
+ SERVICE_ACCOUNT_FILE = os.getenv("CREDENTIALS")
79
+
80
+ creds = Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE, scopes=SCOPES)
81
+
82
+ return build('drive', 'v3', credentials=creds)
83
+
84
+ def __read_db(self, service):
85
+ request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
86
+ fh = io.BytesIO()
87
+ downloader = MediaIoBaseDownload(fh, request)
88
+
89
+ done = False
90
+ while done is False:
91
+ status, done = downloader.next_chunk()
92
+ print(f"Download {int(status.progress() * 100)}%.")
93
+
94
+ # file_content = fh.getvalue().decode('utf-8')
95
+ fh.seek(0)
96
+
97
+ return pd.read_csv(fh)
98
+
99
+ def __read_file(self, service, filename) -> pd.DataFrame:
100
+ query = f"name='{filename}'"
101
+ results = service.files().list(q=query).execute()
102
+ files = results.get('files', [])
103
+
104
+ file_id = files[0]['id']
105
+
106
+ request = service.files().get_media(fileId=file_id)
107
+ fh = io.BytesIO()
108
+ downloader = MediaIoBaseDownload(fh, request)
109
+
110
+ done = False
111
+ while done is False:
112
+ status, done = downloader.next_chunk()
113
+ print(f"Download {int(status.progress() * 100)}%.")
114
+
115
+ # file_content = fh.getvalue().decode('utf-8')
116
+ fh.seek(0)
117
+
118
+ return pd.read_csv(fh)
119
+
120
+ def __upload_file(self, service):
121
+ results = service.files().list(pageSize=10).execute()
122
+ items = results.get('files', [])
123
+ if not items:
124
+ print('No files found.')
125
+ else:
126
+ print('Files:')
127
+ for item in items:
128
+ print(f"{item['name']} ({item['id']})")
129
+
130
+ media = MediaFileUpload(self.csv_result_path, resumable=True)
131
+ filename_prefix = 'ex_bot_database_'
132
+ filename = filename_prefix + self.uid + '.csv'
133
+ request = service.files().create(media_body=media, body={
134
+ 'name': filename,
135
+ 'parents': ["1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"] # Optional, to place the file in a specific folder
136
+ }).execute()
137
+
138
+
139
  def clear_state(self):
140
  self.context = None
141
  self.context_page_num = None
142
  self.context_file_name = None
143
+ self.knowledge_base = None
144
  self.upload_state = "waiting"
145
  self.history = []
146
 
 
232
  self.context_page_num = self.knowledge_base["page_num"].values[0]
233
  self.context_file_name = self.knowledge_base["file_name"].values[0]
234
 
235
+ def __generate_uid(self):
236
+ return secrets.token_hex(8)
237
+
238
+
239
  class VideoChatbot:
240
  def __init__(self) -> None:
 
241
  self.metadata_keys = ["標題", "逐字稿", "摘要", "關鍵字"]
242
  self.metadata = {
243
  "c2fK-hxnPSY":{
utils/work_flow_controller.py CHANGED
@@ -20,10 +20,11 @@ processors = {
20
 
21
 
22
  class WorkFlowController:
23
- def __init__(self, file_src) -> None:
24
  # check if the file_path is list
25
  # self.file_paths = self.__get_file_name(file_src)
26
  self.file_paths = [x.name for x in file_src]
 
27
 
28
  print(self.file_paths)
29
 
@@ -83,6 +84,7 @@ class WorkFlowController:
83
 
84
  for i, _ in enumerate(file["file_content"]):
85
  # use i+1 to meet the index of file_content
 
86
  file["file_content"][i + 1][
87
  "page_content"
88
  ] = translator.translate_to_chinese(
@@ -97,33 +99,34 @@ class WorkFlowController:
97
  # process file content
98
  # return processed data
99
  if not file["is_chinese"]:
 
100
  file = self.__translate_to_chinese(file)
 
101
  file = self.__get_embedding(file)
 
102
  file = self.__get_summary(file)
103
  return file
104
 
105
  def __dump_to_json(self):
106
  with open(
107
- os.path.join(os.getcwd(), "knowledge_base.json"), "w", encoding="utf-8"
108
  ) as f:
109
  print(
110
  "Dumping to json, the path is: "
111
- + os.path.join(os.getcwd(), "knowledge_base.json")
112
  )
113
- self.json_result_path = os.path.join(os.getcwd(), "knowledge_base.json")
114
  json.dump(self.files_info, f, indent=4, ensure_ascii=False)
115
 
116
  def __construct_knowledge_base_dataframe(self):
117
  rows = []
118
  for file_path, content in self.files_info.items():
119
- file_full_content = content["file_full_content"]
120
  for page_num, page_details in content["file_content"].items():
121
  row = {
122
  "file_name": content["file_name"],
123
  "page_num": page_details["page_num"],
124
  "page_content": page_details["page_content"],
125
  "page_embedding": page_details["page_embedding"],
126
- "file_full_content": file_full_content,
127
  }
128
  rows.append(row)
129
 
@@ -132,19 +135,18 @@ class WorkFlowController:
132
  "page_num",
133
  "page_content",
134
  "page_embedding",
135
- "file_full_content",
136
  ]
137
  df = pd.DataFrame(rows, columns=columns)
138
  return df
139
 
140
  def __dump_to_csv(self):
141
  df = self.__construct_knowledge_base_dataframe()
142
- df.to_csv(os.path.join(os.getcwd(), "knowledge_base.csv"), index=False)
143
  print(
144
  "Dumping to csv, the path is: "
145
- + os.path.join(os.getcwd(), "knowledge_base.csv")
146
  )
147
- self.csv_result_path = os.path.join(os.getcwd(), "knowledge_base.csv")
148
 
149
  def __get_file_name(self, file_src):
150
  file_paths = [x.name for x in file_src]
 
20
 
21
 
22
  class WorkFlowController:
23
+ def __init__(self, file_src, uid) -> None:
24
  # check if the file_path is list
25
  # self.file_paths = self.__get_file_name(file_src)
26
  self.file_paths = [x.name for x in file_src]
27
+ self.uid = uid
28
 
29
  print(self.file_paths)
30
 
 
84
 
85
  for i, _ in enumerate(file["file_content"]):
86
  # use i+1 to meet the index of file_content
87
+ print("Translating page: " + str(i + 1))
88
  file["file_content"][i + 1][
89
  "page_content"
90
  ] = translator.translate_to_chinese(
 
99
  # process file content
100
  # return processed data
101
  if not file["is_chinese"]:
102
+ print("Translating to chinese...")
103
  file = self.__translate_to_chinese(file)
104
+ print("Getting embedding...")
105
  file = self.__get_embedding(file)
106
+ print("Getting summary...")
107
  file = self.__get_summary(file)
108
  return file
109
 
110
  def __dump_to_json(self):
111
  with open(
112
+ os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json"), "w", encoding="utf-8"
113
  ) as f:
114
  print(
115
  "Dumping to json, the path is: "
116
+ + os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
117
  )
118
+ self.json_result_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
119
  json.dump(self.files_info, f, indent=4, ensure_ascii=False)
120
 
121
  def __construct_knowledge_base_dataframe(self):
122
  rows = []
123
  for file_path, content in self.files_info.items():
 
124
  for page_num, page_details in content["file_content"].items():
125
  row = {
126
  "file_name": content["file_name"],
127
  "page_num": page_details["page_num"],
128
  "page_content": page_details["page_content"],
129
  "page_embedding": page_details["page_embedding"],
 
130
  }
131
  rows.append(row)
132
 
 
135
  "page_num",
136
  "page_content",
137
  "page_embedding",
 
138
  ]
139
  df = pd.DataFrame(rows, columns=columns)
140
  return df
141
 
142
  def __dump_to_csv(self):
143
  df = self.__construct_knowledge_base_dataframe()
144
+ df.to_csv(os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv"), index=False)
145
  print(
146
  "Dumping to csv, the path is: "
147
+ + os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
148
  )
149
+ self.csv_result_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
150
 
151
  def __get_file_name(self, file_src):
152
  file_paths = [x.name for x in file_src]