Spaces:
Sleeping
Sleeping
import gradio as gr | |
from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word | |
from utils.llm import load_llm_from_pretrained, inference | |
wordlist_1_path_s = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_1_sexual.txt" | |
wordlist_2_path_s = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_2_sexual.txt" | |
wordlist_1_path_o = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_1_offensive.txt" | |
wordlist_2_path_o = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_2_offensive.txt" | |
pretrained_model_path = "input/llm_weights" | |
print("モデルをロード") | |
ng_wordlist_1_s, ng_wordlist_2_s = get_ng_wordlist_from_saved(wordlist_1_path_s, wordlist_2_path_s) | |
ng_wordlist_1_o, ng_wordlist_2_o = get_ng_wordlist_from_saved(wordlist_1_path_o, wordlist_2_path_o) | |
model, tokenizer = load_llm_from_pretrained(pretrained_model_path) | |
# 検出結果を生成 | |
def detect_ng_word(input_text): | |
response = [] | |
rtn_s = search_ng_word(data_point["input"], ng_wordlist_1_s, ng_wordlist_2_s) | |
rtn_o = search_ng_word(data_point["input"], ng_wordlist_1_o, ng_wordlist_2_o) | |
rtn = rtn_s + rtn_o | |
if len(rtn) == 0: | |
response.append("NGワードは検知されませんでした \n") | |
else: | |
response.append('以下のNGワードを検知しました \n') | |
for rtn_i in rtn: | |
ng_word = str(rtn_i) + " \n" | |
response.append(ng_word) | |
rtn_s = [ri + "(sexual)" for ri in rtn_s] | |
rtn_o = [ri + "(offensive)" for ri in rtn_o] | |
ngword_with_label = rtn_s + rtn_o | |
output = inference(model, tokenizer, input_text, ngword_with_label) | |
if output == "はい。攻撃的だから。</s>": | |
response.append('不適切な内容を検知しました(攻撃的)') | |
elif output == "はい。暴力的だから。</s>": | |
response.append('不適切な内容を検知しました(暴力的)') | |
elif output == "はい。差別的だから。</s>": | |
response.append('不適切な内容を検知しました(差別的)') | |
elif output == "はい。性的だから。</s>": | |
response.append('不適切な内容を検知しました(性的)') | |
elif output == "はい。政治的だから。</s>": | |
response.append('不適切な内容を検知しました(政治的)') | |
else: | |
response.append("不適切な内容は検知されませんでした") | |
return response | |
# 会話履歴用リスト型変数 | |
message_history = [] | |
def chat(user_msg): | |
""" | |
AIとの会話を実行後、全会話履歴を返す | |
user_msg: 入力されたユーザのメッセージ | |
""" | |
global message_history | |
# ユーザの会話を履歴に追加 | |
message_history.append({ | |
"role": "user", | |
"content": user_msg | |
}) | |
# AIの回答を履歴に追加 | |
response = detect_ng_word(user_msg) | |
assistant_msg = " ".join(response) | |
message_history.append({ | |
"role": "assistant", | |
"content": assistant_msg | |
}) | |
# 全会話履歴をChatbot用タプル・リストに変換して返す | |
return [(message_history[i]["content"], message_history[i+1]["content"]) for i in range(0, len(message_history)-1, 2)] | |
with gr.Blocks() as demo: | |
# チャットボットUI処理 | |
chatbot = gr.Chatbot() | |
input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください") | |
input.submit(fn=chat, inputs=input, outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示 | |
input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア | |
demo.launch() |