# Spaces:
# Runtime error
# Runtime error
import logging
import os
import re
from typing import List

import openai
import tiktoken
from opencc import OpenCC
# API key read from the environment once at import time; ``None`` when the
# OPENAI_API_KEY variable is not set.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
class GPTAgent:
    """Base class holding shared OpenAI request settings and text helpers.

    Subclasses choose a model name and issue their own requests; this class
    stores the sampling parameters and provides sentence-based chunking,
    newline normalisation and parsing of separator-delimited model replies.
    """

    def __init__(self, model):
        """Store the model name and the default request parameters.

        Args:
            model: Model name used by the requests of this agent.
        """
        openai.api_key = OPENAI_API_KEY
        self.model = model
        self.temperature = 0.8
        self.frequency_penalty = 0
        self.presence_penalty = 0.6
        self.max_tokens = 2048
        # Kept for compatibility; no method visible in this file reads it.
        self.split_max_tokens = 13000

    def request(self, messages):
        """Send ``messages`` as a chat completion and return the reply text.

        If an ``agent`` attribute has been set on the instance it is used
        exactly as before (``agent.complete(messages=...)``). Otherwise the
        request goes through ``openai.ChatCompletion.create`` with the
        parameters stored on the instance, so the method no longer fails
        with ``AttributeError`` on a plain agent.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The ``content`` of the first choice's message.
        """
        agent = getattr(self, "agent", None)
        if agent is not None:
            response = agent.complete(messages=messages)
        else:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
            )
        return response.choices[0].message["content"]

    @staticmethod
    def split_into_many(text, max_tokens=500):
        """Split ``text`` into chunks of whole sentences under a token budget.

        Sentences are delimited by "。" and counted with the ``cl100k_base``
        tiktoken encoding. Declared as a static method so that both
        ``self.split_into_many(text)`` and ``GPTAgent.split_into_many(text)``
        work (the former raised ``TypeError`` before).

        Args:
            text: Text to split.
            max_tokens: Token budget per chunk (default 500, the previously
                hard-coded value).

        Returns:
            A list of chunks, each ending in "。". Falls back to ``[text]``
            when no chunk could be built (empty text, or every sentence
            exceeds the budget).
        """
        tokenizer = tiktoken.get_encoding("cl100k_base")
        chunks = []
        current = []
        tokens_so_far = 0
        for sentence in text.split("。"):
            # A trailing "。" yields an empty last piece; skipping it avoids
            # a doubled "。" at the end of the final chunk.
            if not sentence.strip():
                continue
            token_count = len(tokenizer.encode(" " + sentence))
            if tokens_so_far + token_count > max_tokens:
                # Flush only non-empty chunks so no bare "。" is emitted.
                if current:
                    chunks.append("。".join(current) + "。")
                current = []
                tokens_so_far = 0
            if token_count > max_tokens:
                # A single sentence over the budget is dropped, as before;
                # the warning makes the loss visible.
                logging.warning(
                    "Skipping a sentence of %d tokens (limit %d)",
                    token_count,
                    max_tokens,
                )
                continue
            current.append(sentence)
            tokens_so_far += token_count + 1
        if current:
            chunks.append("。".join(current) + "。")
        return chunks if chunks else [text]

    def preprocess(self, text):
        """Return ``text`` with newlines turned into spaces and "\\r" removed."""
        return text.replace("\n", " ").replace("\r", "")

    def parse_result(self, result):
        """Flatten separator-delimited model replies into a list of terms.

        Each reply is split on the first separator that actually divides it,
        tried in this order: ASCII comma, "、", full-width comma. Every term
        is converted with OpenCC ``s2tw``, stripped of "。" and surrounding
        whitespace, and empty terms are discarded. ``result`` itself is not
        modified.

        Args:
            result: List of reply strings.

        Returns:
            A flat list of cleaned terms.
        """
        chinese_converter = OpenCC("s2tw")
        # "\uff0c" is the full-width comma; written as an escape so the
        # character survives copy/paste normalisation.
        separators = (",", "、", "\uff0c")
        parsed_result = []
        for raw in result:
            parts = [raw]
            for separator in separators:
                parts = raw.split(separator)
                if len(parts) > 1:
                    break
            for word in parts:
                try:
                    cleaned = chinese_converter.convert(word).replace("。", "").strip()
                except Exception:
                    logging.exception("Failed to parse result")
                    continue
                if cleaned:
                    parsed_result.append(cleaned)
        return parsed_result
class Translator(GPTAgent):
    """Agent that rewrites English text as Traditional Chinese."""

    def __init__(self):
        """Configure the agent with the ``gpt-3.5-turbo`` model."""
        super().__init__("gpt-3.5-turbo")

    def translate_to_chinese(self, text):
        """Translate ``text`` and return the reply converted with OpenCC ``s2tw``.

        The system prompt (in Chinese) asks the model to act as a translator,
        spelling corrector and improver, to answer only with the improved
        Traditional Chinese text, and to keep the meaning while making it
        more literary.

        Args:
            text: Text to translate, sent as the user message.

        Returns:
            The stripped reply after the ``s2tw`` conversion.

        Raises:
            Exception: Whatever the API call raises is logged and re-raised.
                Previously the failure surfaced as an ``UnboundLocalError``
                because ``response`` was read after the failed call.
        """
        system_prompt_zh_tw = (
            "\n"
            "我希望你擔任中文翻譯、拼寫糾正及改進的角色。\n"
            "我將用英文與你交流,請將其翻譯並用繁體中文回答,同時對我的文本進行糾正和改進。\n"
            "保持原意不變,但使其更具文學性。我希望你僅回覆更正、改進的部分,不要寫解釋,也不要使用任何简体中文。\n"
        )
        messages = [
            {"role": "system", "content": system_prompt_zh_tw},
            {"role": "user", "content": text},
        ]
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
            )
        except Exception:
            logging.exception("Failed to translate to Chinese")
            raise
        # The model may still answer in Simplified Chinese; convert the
        # reply to Traditional Chinese (Taiwan standard).
        chinese_converter = OpenCC("s2tw")
        return chinese_converter.convert(
            response["choices"][0]["message"]["content"].strip()
        )
class EmbeddingGenerator(GPTAgent):
    """Agent that requests an embedding for a piece of text."""

    def __init__(self):
        """Initialise the base agent with the ``text-davinci-002`` model name.

        The embedding request below uses its own hard-coded engine, not this
        model name.
        """
        super().__init__("text-davinci-002")

    def get_embedding(self, text):
        """Return the ``embedding`` field of the first ``data`` entry.

        Args:
            text: Input passed to ``openai.Embedding.create``.
        """
        reply = openai.Embedding.create(
            input=text,
            engine="text-embedding-ada-002",
        )
        first_entry = reply["data"][0]
        return first_entry["embedding"]
class KeywordsGenerator(GPTAgent):
    """Agent that asks the chat model for search keywords of a text."""

    def __init__(self):
        """Configure the agent with the ``gpt-3.5-turbo`` model."""
        super().__init__("gpt-3.5-turbo")

    def extract_keywords(self, text):
        """Collect keyword replies for every chunk of ``text``.

        The text is chunked with ``split_into_many``; each chunk is passed
        through ``preprocess`` and sent with a prompt asking for 5
        comma-separated keywords. Chunks whose request fails are logged and
        skipped.

        Args:
            text: Text to extract keywords from.

        Returns:
            The result of ``parse_result`` applied to the collected replies.
        """
        system_prompt = "\n請你為以下內容抓出 5 個關鍵字用以搜尋這篇文章,並用「,」來分隔\n"
        raw_replies = []
        for chunk in self.split_into_many(text):
            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.preprocess(chunk)},
            ]
            try:
                completion = openai.ChatCompletion.create(
                    model=self.model,
                    messages=conversation,
                    temperature=0,
                    max_tokens=self.max_tokens,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty,
                )
                reply_text = completion["choices"][0]["message"]["content"]
                raw_replies.append(reply_text.strip())
            except Exception as err:
                logging.error(err)
                logging.error("Failed to extract keywords")
        return self.parse_result(raw_replies)
class TopicsGenerator(GPTAgent):
    """Agent that asks the chat model for abstract topics of a text."""

    def __init__(self):
        """Configure the agent with the ``gpt-3.5-turbo`` model."""
        super().__init__("gpt-3.5-turbo")

    def extract_topics(self, text):
        """Collect topic replies for every chunk of ``text``.

        Each chunk from ``split_into_many`` is normalised with ``preprocess``
        and sent with a prompt asking for 3 highly abstract, comma-separated
        topics. A failed request is logged and that chunk contributes
        nothing.

        Args:
            text: Text to classify.

        Returns:
            The result of ``parse_result`` applied to the collected replies.
        """
        system_prompt = "\n請你為以下內容給予 3 個高度抽象的主題分類這篇文章,並用「,」來分隔\n"
        system_message = {"role": "system", "content": system_prompt}
        collected = []
        pieces = self.split_into_many(text)
        for piece in pieces:
            user_message = {"role": "user", "content": self.preprocess(piece)}
            try:
                completion = openai.ChatCompletion.create(
                    model=self.model,
                    messages=[system_message, user_message],
                    temperature=0,
                    max_tokens=self.max_tokens,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty,
                )
                answer = completion["choices"][0]["message"]["content"].strip()
                collected.append(answer)
            except Exception as err:
                logging.error(err)
                logging.error("Failed to extract topics")
        return self.parse_result(collected)
class Summarizer(GPTAgent):
    """Agent that summarises text, chunk by chunk when the text is long."""

    def __init__(self):
        """Configure the agent with the ``gpt-3.5-turbo-16k`` model."""
        super().__init__("gpt-3.5-turbo-16k")

    def _summarize_once(self, system_prompt, content):
        """Send one summarisation request; return the reply after ``s2tw``."""
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        summary = response["choices"][0]["message"]["content"].strip()
        logging.debug("the summary is %s", summary)
        return OpenCC("s2tw").convert(summary)

    def summarize(self, text):
        """Summarise ``text`` and return the summary converted with ``s2tw``.

        When ``split_into_many`` yields a single chunk the whole text is
        summarised in one request. Otherwise every chunk is summarised on
        its own, the partial summaries are concatenated, and the
        concatenation is summarised once more.

        Args:
            text: Text to summarise.

        Returns:
            The summary. If only the final pass over the concatenated
            partial summaries fails, the concatenation itself is returned.

        Raises:
            Exception: Re-raised from the API call on the single-chunk path.
            RuntimeError: When no chunk could be summarised at all.
        """
        system_prompt = "\n請幫我總結以下的文章。\n"
        text_chunks = self.split_into_many(text)

        if len(text_chunks) <= 1:
            try:
                return self._summarize_once(system_prompt, text)
            except Exception:
                logging.exception("Failed to summarize")
                raise

        partial_summaries = []
        for text_chunk in text_chunks:
            # Summarise the chunk itself; the old code indexed single
            # characters of ``text`` here.
            try:
                partial_summaries.append(
                    self._summarize_once(system_prompt, self.preprocess(text_chunk))
                )
            except Exception:
                # Skip the failed chunk instead of re-using the previous
                # chunk's response.
                logging.exception("Failed to summarize text_chunk")

        concated_summary = "".join(partial_summaries)
        if not concated_summary:
            raise RuntimeError("Failed to summarize any text chunk")

        try:
            return self._summarize_once(system_prompt, concated_summary)
        except Exception:
            logging.exception("Failed to summarize concated_summary")
            # Best effort: the partial summaries are still a usable result.
            return concated_summary
class QuestionAnswerer(GPTAgent):
    """Agent that answers questions about a supplied document."""

    def __init__(self):
        """Configure the agent with the ``gpt-3.5-turbo-16k`` model."""
        super().__init__("gpt-3.5-turbo-16k")

    def answer_chunk_question(self, text, question):
        """Answer ``question`` against every chunk of ``text``.

        Each chunk from ``split_into_many`` is appended to the system prompt
        and the question is sent as the user message. Chunks whose request
        fails are logged and skipped, so a failure no longer re-uses the
        previous chunk's response.

        Args:
            text: Document text.
            question: Question to ask about each chunk.

        Returns:
            The per-chunk answers (converted with ``s2tw``) joined by "。".
        """
        system_prompt = (
            "\n你是一個知識檢索系統,我會給你一份文件,請幫我依照文件內容回答問題,"
            "並用繁體中文回答。以下是文件內容\n"
        )
        chinese_converter = OpenCC("s2tw")
        answer_chunks = []
        for text_chunk in self.split_into_many(text):
            messages = [
                # The old f-string embedded the literal characters
                # " + '\n' '" in the prompt; only the newline was intended.
                {"role": "system", "content": f"{system_prompt}\n'{text_chunk}'"},
                {"role": "user", "content": f"{question}"},
            ]
            try:
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature,
                    max_tokens=1024,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty,
                )
            except Exception:
                logging.exception("Failed to answer question")
                continue
            answer_chunks.append(
                chinese_converter.convert(
                    response["choices"][0]["message"]["content"].strip()
                )
            )
        return "。".join(answer_chunks)

    def answer_question(self, context, context_page_num, context_file_name, history):
        """Answer the latest question in ``history`` using ``context``.

        Args:
            context: Document text placed in the system prompt.
            context_page_num: Page number quoted in the answer header.
            context_file_name: File name quoted in the answer header.
            history: Sequence of ``(user_message, assistant_message)`` pairs;
                the assistant part may be ``None``.

        Returns:
            A header naming the file and page followed by the model's
            answer, converted with ``s2tw``; ``None`` when the request fails
            (the failure is logged).
        """
        system_prompt = (
            "\n你是一個知識檢索系統,我會給你一份文件,請幫我依照文件內容回答問題,"
            "並用繁體中文回答。以下是文件內容\n"
        )
        history = self.__construct_message_history(history)
        messages = [
            # See answer_chunk_question: the stray " + '\n' " text is removed.
            {"role": "system", "content": f"{system_prompt}\n'''{context}'''"},
        ] + history
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=2048,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
            )
            chinese_converter = OpenCC("s2tw")
            page_num_message = f"以下內容來自 {context_file_name},第 {context_page_num} 頁\n\n"
            bot_answer = response["choices"][0]["message"]["content"]
            whole_answer = page_num_message + bot_answer
            return chinese_converter.convert(whole_answer)
        except Exception:
            logging.exception("Failed to answer question")
            return None

    def __construct_message_history(self, history):
        """Turn the last 10 history pairs into chat messages.

        Args:
            history: Sequence of pairs whose first item is the user message
                and whose second item is the assistant message or ``None``.
                ``None`` is treated as an empty history.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts, user message
            first, followed by the assistant message when it is not ``None``.
        """
        # Debug-level logging replaces the former print() of the whole chat.
        logging.debug("history is %s", history)
        max_history_length = 10
        recent_turns = (history or [])[-max_history_length:]
        messages = []
        for turn in recent_turns:
            messages.append({"role": "user", "content": turn[0]})
            if turn[1] is not None:
                messages.append({"role": "assistant", "content": turn[1]})
        return messages