koheibaba
upload files
e7a412f
raw
history blame
3.74 kB
import gradio as gr
from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word
from utils.llm import load_llm_from_pretrained, inference
wordlist_1_path_s = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_1_sexual.txt"
wordlist_2_path_s = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_2_sexual.txt"
wordlist_1_path_o = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_1_offensive.txt"
wordlist_2_path_o = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_2_offensive.txt"
pretrained_model_path = "input/llm_weights"
print("モデルをロード")
ng_wordlist_1_s, ng_wordlist_2_s = get_ng_wordlist_from_saved(wordlist_1_path_s, wordlist_2_path_s)
ng_wordlist_1_o, ng_wordlist_2_o = get_ng_wordlist_from_saved(wordlist_1_path_o, wordlist_2_path_o)
model, tokenizer = load_llm_from_pretrained(pretrained_model_path)
# 検出結果を生成
def detect_ng_word(input_text):
response = []
rtn_s = search_ng_word(data_point["input"], ng_wordlist_1_s, ng_wordlist_2_s)
rtn_o = search_ng_word(data_point["input"], ng_wordlist_1_o, ng_wordlist_2_o)
rtn = rtn_s + rtn_o
if len(rtn) == 0:
response.append("NGワードは検知されませんでした \n")
else:
response.append('以下のNGワードを検知しました \n')
for rtn_i in rtn:
ng_word = str(rtn_i) + " \n"
response.append(ng_word)
rtn_s = [ri + "(sexual)" for ri in rtn_s]
rtn_o = [ri + "(offensive)" for ri in rtn_o]
ngword_with_label = rtn_s + rtn_o
output = inference(model, tokenizer, input_text, ngword_with_label)
if output == "はい。攻撃的だから。</s>":
response.append('不適切な内容を検知しました(攻撃的)')
elif output == "はい。暴力的だから。</s>":
response.append('不適切な内容を検知しました(暴力的)')
elif output == "はい。差別的だから。</s>":
response.append('不適切な内容を検知しました(差別的)')
elif output == "はい。性的だから。</s>":
response.append('不適切な内容を検知しました(性的)')
elif output == "はい。政治的だから。</s>":
response.append('不適切な内容を検知しました(政治的)')
else:
response.append("不適切な内容は検知されませんでした")
return response
# 会話履歴用リスト型変数
message_history = []
def chat(user_msg):
"""
AIとの会話を実行後、全会話履歴を返す
user_msg: 入力されたユーザのメッセージ
"""
global message_history
# ユーザの会話を履歴に追加
message_history.append({
"role": "user",
"content": user_msg
})
# AIの回答を履歴に追加
response = detect_ng_word(user_msg)
assistant_msg = " ".join(response)
message_history.append({
"role": "assistant",
"content": assistant_msg
})
# 全会話履歴をChatbot用タプル・リストに変換して返す
return [(message_history[i]["content"], message_history[i+1]["content"]) for i in range(0, len(message_history)-1, 2)]
with gr.Blocks() as demo:
# チャットボットUI処理
chatbot = gr.Chatbot()
input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください")
input.submit(fn=chat, inputs=input, outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示
input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア
demo.launch()