File size: 3,736 Bytes
e7a412f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr

from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word
from utils.llm import load_llm_from_pretrained, inference


wordlist_1_path_s = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_1_sexual.txt"
wordlist_2_path_s = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_2_sexual.txt"
wordlist_1_path_o = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_1_offensive.txt"
wordlist_2_path_o = "/content/drive/MyDrive/llm_qlora_ngword/ng_wordlists/ng_wordlist_2_offensive.txt"

pretrained_model_path = "input/llm_weights"

print("モデルをロード")
ng_wordlist_1_s, ng_wordlist_2_s = get_ng_wordlist_from_saved(wordlist_1_path_s, wordlist_2_path_s)
ng_wordlist_1_o, ng_wordlist_2_o = get_ng_wordlist_from_saved(wordlist_1_path_o, wordlist_2_path_o)

model, tokenizer = load_llm_from_pretrained(pretrained_model_path)

# 検出結果を生成
def detect_ng_word(input_text):
    response = []
    rtn_s = search_ng_word(data_point["input"], ng_wordlist_1_s, ng_wordlist_2_s)
    rtn_o = search_ng_word(data_point["input"], ng_wordlist_1_o, ng_wordlist_2_o)
    rtn = rtn_s + rtn_o

    if len(rtn) == 0:
        response.append("NGワードは検知されませんでした  \n")
    else:
        response.append('以下のNGワードを検知しました  \n')
        for rtn_i in rtn:
            ng_word = str(rtn_i) + "  \n"
            response.append(ng_word)

    rtn_s = [ri + "(sexual)" for ri in rtn_s]
    rtn_o = [ri + "(offensive)" for ri in rtn_o]
    ngword_with_label = rtn_s + rtn_o
    
    output = inference(model, tokenizer, input_text, ngword_with_label)
    
    if output == "はい。攻撃的だから。</s>":
        response.append('不適切な内容を検知しました(攻撃的)')
    elif output == "はい。暴力的だから。</s>":
        response.append('不適切な内容を検知しました(暴力的)')
    elif output == "はい。差別的だから。</s>":
        response.append('不適切な内容を検知しました(差別的)')
    elif output == "はい。性的だから。</s>":
        response.append('不適切な内容を検知しました(性的)')
    elif output == "はい。政治的だから。</s>":
        response.append('不適切な内容を検知しました(政治的)')
    else:
        response.append("不適切な内容は検知されませんでした")
      
    return response


# 会話履歴用リスト型変数
message_history = []

def chat(user_msg):
    """
    AIとの会話を実行後、全会話履歴を返す
    user_msg: 入力されたユーザのメッセージ
    """
    global message_history

    # ユーザの会話を履歴に追加
    message_history.append({
        "role": "user",
        "content": user_msg
    })

    # AIの回答を履歴に追加
    response = detect_ng_word(user_msg)
    assistant_msg = " ".join(response)
    message_history.append({
        "role": "assistant",
        "content": assistant_msg
    })

    # 全会話履歴をChatbot用タプル・リストに変換して返す
    return [(message_history[i]["content"], message_history[i+1]["content"]) for i in range(0, len(message_history)-1, 2)]


with gr.Blocks() as demo:
    # チャットボットUI処理
    chatbot = gr.Chatbot()
    input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください")
    input.submit(fn=chat, inputs=input, outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示
    input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア

demo.launch()