Spaces:
Runtime error
Runtime error
import os | |
import json | |
import logging | |
import hashlib | |
import pandas as pd | |
from .gpt_processor import (EmbeddingGenerator, KeywordsGenerator, Summarizer, | |
TopicsGenerator, Translator) | |
from .pdf_processor import PDFProcessor | |
# Registry of supported input formats: maps a file extension (the text after
# the last '.') to the processor class used to read files of that format.
processors = dict(
    pdf=PDFProcessor,
)
class WorkFlowController:
    """Run uploaded files through the processing pipeline and persist a knowledge base.

    On construction, every uploaded file is:

    1. parsed by the processor registered for its extension in ``processors``;
    2. translated to Chinese when the processor reports ``is_chinese`` False;
    3. embedded page by page and summarized.

    The combined results are written to ``knowledge_base.json`` and
    ``knowledge_base.csv`` in the current working directory. The written
    paths are kept in ``json_result_path`` and ``csv_result_path``.
    """

    def __init__(self, file_src) -> None:
        """Process every uploaded file and dump the combined results.

        Args:
            file_src: iterable of uploaded-file objects exposing a ``name``
                attribute that holds the file's path on disk (presumably
                Gradio file objects -- TODO confirm against the caller).

        Raises:
            KeyError: if a file's extension has no entry in ``processors``.
        """
        self.file_paths = [x.name for x in file_src]
        print(self.file_paths)
        self.files_info = {}
        for file_path in self.file_paths:
            file_name, file_format = self.__parse_file_path(file_path)
            try:
                self.file_processor = processors[file_format]
            except KeyError:
                # Keep the KeyError type callers may rely on, but say what failed.
                raise KeyError(
                    f"unsupported file format {file_format!r} for {file_path!r}; "
                    f"supported formats: {sorted(processors)}"
                ) from None
            file = self.file_processor(file_path).file_info
            # NOTE: results are keyed by base name, so two uploads sharing a
            # name (from different directories) overwrite each other.
            self.files_info[file_name] = self.__process_file(file)
        self.__dump_to_json()
        self.__dump_to_csv()

    @staticmethod
    def __parse_file_path(file_path):
        """Return ``(base_name, format)`` for *file_path*.

        ``format`` is the lower-cased text after the last '.' of the base
        name, or ``''`` when the base name has no '.'. The base name is
        used so that dots in directory names are never taken for the
        extension, and ``os.path.basename`` honours the OS path separator.
        """
        file_name = os.path.basename(file_path)
        _, dot, extension = file_name.rpartition('.')
        return file_name, (extension.lower() if dot else '')

    def __get_summary(self, file: dict):
        """Store a summary of ``file_full_content`` under ``summarized_content``."""
        summarizer = Summarizer()
        file['summarized_content'] = summarizer.summarize(file['file_full_content'])
        return file

    def __get_keywords(self, file: dict):
        """Store keywords extracted from ``file_full_content`` under ``keywords``."""
        keywords_generator = KeywordsGenerator()
        file['keywords'] = keywords_generator.extract_keywords(file['file_full_content'])
        return file

    def __get_topics(self, file: dict):
        """Store topics extracted from ``file_full_content`` under ``topics``."""
        topics_generator = TopicsGenerator()
        file['topics'] = topics_generator.extract_topics(file['file_full_content'])
        return file

    def __get_embedding(self, file):
        """Attach a ``page_embedding`` to every page dict in ``file_content``."""
        embedding_generator = EmbeddingGenerator()
        # Walk the page dicts directly rather than indexing with i+1, so this
        # does not depend on page keys being exactly the ints 1..n.
        for page in file['file_content'].values():
            page['page_embedding'] = embedding_generator.get_embedding(page['page_content'])
        return file

    def __translate_to_chinese(self, file: dict):
        """Translate every page to Chinese in place and rebuild the full text.

        Each page's ``page_content`` is replaced by its translation.
        ``file_full_content`` becomes the translated pages concatenated in
        the order ``file_content`` yields them (its insertion order, assumed
        to be page order -- confirm against the processor).
        """
        translator = Translator()
        translated_pages = []
        for page in file['file_content'].values():
            page['page_content'] = translator.translate_to_chinese(page['page_content'])
            translated_pages.append(page['page_content'])
        # join() avoids re-copying the growing string once per page.
        file['file_full_content'] = ''.join(translated_pages)
        return file

    def __process_file(self, file: dict):
        """Translate (if needed), embed and summarize one processor result.

        *file* is the processor's ``file_info`` dict. From the keys read
        here it presumably holds ``is_chinese``, ``file_name``,
        ``file_full_content`` and ``file_content`` (page number -> page
        dict); the schema is defined by the processor.
        """
        if not file['is_chinese']:
            file = self.__translate_to_chinese(file)
        file = self.__get_embedding(file)
        file = self.__get_summary(file)
        return file

    def __dump_to_json(self):
        """Write ``files_info`` to ``knowledge_base.json`` in the working directory."""
        json_path = os.path.join(os.getcwd(), 'knowledge_base.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            print("Dumping to json, the path is: " + json_path)
            self.json_result_path = json_path
            json.dump(self.files_info, f, indent=4, ensure_ascii=False)

    def __construct_knowledge_base_dataframe(self):
        """Flatten ``files_info`` into one DataFrame row per page.

        The file's full text is repeated on every row of that file.
        """
        rows = [
            {
                "file_name": content["file_name"],
                "page_num": page_details["page_num"],
                "page_content": page_details["page_content"],
                "page_embedding": page_details["page_embedding"],
                "file_full_content": content["file_full_content"],
            }
            for content in self.files_info.values()
            for page_details in content["file_content"].values()
        ]
        columns = ["file_name", "page_num", "page_content", "page_embedding", "file_full_content"]
        return pd.DataFrame(rows, columns=columns)

    def __dump_to_csv(self):
        """Write the per-page DataFrame to ``knowledge_base.csv`` in the working directory."""
        csv_path = os.path.join(os.getcwd(), 'knowledge_base.csv')
        df = self.__construct_knowledge_base_dataframe()
        df.to_csv(csv_path, index=False)
        print("Dumping to csv, the path is: " + csv_path)
        self.csv_result_path = csv_path

    def __get_file_name(self, file_src):
        """Return an MD5 hex digest identifying a set of uploaded files.

        Despite the name, no file name is returned. The digest covers the
        raw bytes of every file, read in 8 KiB chunks, with files taken in
        order of their base name. Not called anywhere in this class.
        """
        file_paths = [x.name for x in file_src]
        file_paths.sort(key=lambda x: os.path.basename(x))
        md5_hash = hashlib.md5()
        for file_path in file_paths:
            with open(file_path, "rb") as f:
                while chunk := f.read(8192):
                    md5_hash.update(chunk)
        return md5_hash.hexdigest()