"""Streamlit app for batch detection of offensive language and social bias in a CSV text column."""

import os

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel

import appbuilder

# MacBERT encoder used as a shared, frozen feature extractor for the classifier defined below.
pretrained = BertModel.from_pretrained('hfl/chinese-macbert-base')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pretrained.to(device)

# The encoder is used purely for feature extraction, so its weights stay frozen.
for param in pretrained.parameters():
    param.requires_grad_(False)
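
# Note: BertModel.from_pretrained('hfl/chinese-macbert-base') downloads the encoder
# weights from the Hugging Face Hub on the first run (and reads them from the local
# cache afterwards), so the app needs network access or a pre-populated cache the
# first time it starts.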


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MultiHeadAttention, self).__init__()

        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Project to q/k/v and split into heads: (batch, num_heads, seq_len, head_dim).
        q = self.linear_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.linear_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.linear_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention over the sequence dimension.
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float))
        attn_weights = F.softmax(scores, dim=-1)

        # Merge the heads back to (batch, seq_len, hidden_size) and apply the output projection.
        context = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        out = self.linear_out(context)
        return out
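
# Shape sanity check for MultiHeadAttention (illustrative sketch only, not executed by
# the app; the dummy batch size and sequence length are arbitrary):
#   attn = MultiHeadAttention(hidden_size=768, num_heads=8)
#   dummy = torch.randn(2, 16, 768)           # (batch, seq_len, hidden_size)
#   assert attn(dummy).shape == (2, 16, 768)  # attention preserves the input shape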


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(p=0.5)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.activation = nn.ReLU()
        self.multihead_attention = MultiHeadAttention(hidden_size=768, num_heads=8)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Encode with the frozen MacBERT backbone.
        out = pretrained(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids).last_hidden_state

        # Re-weight the token representations, then keep the [CLS] position.
        out = self.multihead_attention(out)
        out = out[:, 0]

        # Two-layer classification head with batch norm and dropout, ending in a
        # softmax over the two classes.
        out = self.activation(self.bn1(self.fc1(out)))
        out = self.dropout(out)
        out = self.activation(self.bn2(self.fc2(out)))
        out = self.dropout(out)
        out = self.fc3(out)
        out = out.softmax(dim=1)
        return out
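
# The checkpoints under models/ are loaded below with torch.load(...) as full pickled
# Model objects, which implies they were produced roughly like this (illustrative
# sketch only; the training loop is not part of this file):
#   model = Model().to(device)
#   ...  # fine-tune the classification head
#   torch.save(model, 'models/MacBERT-base-CDialBias.pth')
# On PyTorch >= 2.6, loading such pickled modules additionally requires
# torch.load(..., weights_only=False).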


def load_models_and_predict(text, device):
    # Load both checkpoints onto the same device as the inputs prepared below.
    MacBERT_base_CDialBias = torch.load('models/MacBERT-base-CDialBias.pth', map_location=device)
    # The COLD checkpoint path is assumed to mirror the CDialBias naming scheme.
    MacBERT_base_COLD = torch.load('models/MacBERT-base-COLD.pth', map_location=device)

    # AppBuilder agent queried as an additional signal alongside the two classifiers.
    os.environ['APPBUILDER_TOKEN'] = "bce-v3/ALTAK-n2XgeA6FS3Q5E7Jab6UwE/850b44ebec64c4cad705986ab0b5e3df4b05d407"
    app_id = "df881861-9fa6-40b6-b3bd-26df5f5d4b9a"

    your_agent = appbuilder.AppBuilderClient(app_id)
    conversation_id = your_agent.create_conversation()

    # Tokenise the input text for the two MacBERT-based classifiers.
    tokenizer = BertTokenizer.from_pretrained('hfl/chinese-macbert-base')
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    MacBERT_base_CDialBias.eval()
    MacBERT_base_COLD.eval()

    # Ask the agent; it is expected to reply with "0" or "1" as the first character.
    msg = your_agent.run(conversation_id, text)
    answer = msg.content.answer

    with torch.no_grad():
        out1 = MacBERT_base_CDialBias(**inputs)
        out2 = MacBERT_base_COLD(**inputs)

    out1 = torch.argmax(out1, dim=1).item()
    out2 = torch.argmax(out2, dim=1).item()
    out3 = answer[0]

    # Combine the agent's verdict (out3) with the two classifier outputs: out1 comes
    # from the CDial-Bias checkpoint (1 = social bias) and out2 from the COLD
    # checkpoint (1 = offensive), as the result strings below reflect.
    if out3 == "1":
        if out1 == out2 == 1:
            result = "这句话具有攻击性和社会偏见"
        elif out1 == 0 and out2 == 1:
            result = "这句话具有攻击性,但无社会偏见"
        elif out1 == 1 and out2 == 0:
            result = "这句话不具有攻击性,但具有社会偏见"
        else:
            result = "这句话具有攻击性"
    elif out3 == "0":
        if out1 == out2 == 0:
            result = "这句话不具有攻击性和社会偏见"
        elif out1 == 0 and out2 == 1:
            result = "这句话具有攻击性,但无社会偏见"
        elif out1 == 1 and out2 == 0:
            result = "这句话不具有攻击性,但具有社会偏见"
        else:
            result = "这句话不具有攻击性"
    else:
        # Fallback so result is always defined if the agent replies with something
        # other than "0" or "1".
        result = answer
    return result
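
# Note: load_models_and_predict reloads both checkpoints, the tokenizer and a fresh
# AppBuilder conversation on every call, i.e. once per CSV row. A possible optimisation
# (sketch only, not applied here) would be to hoist the heavy setup into a cached
# helper, e.g.:
#   @st.cache_resource
#   def get_resources():
#       ...  # load the checkpoints and tokenizer once and return them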


st.set_page_config(page_title="文件式文本检测工具")
st.title("批量检测攻击性和偏见")

with st.sidebar:
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False

    username = st.text_input('用户名')
    password = st.text_input('密码', type='password')

    if st.button('登录'):
        if username == 'admin' and password == '12345':
            st.session_state.logged_in = True
            st.success('登录成功!')
        else:
            st.error('用户名或密码错误,请重试。')
            st.stop()

st.divider()

file = st.file_uploader("上传你的CSV文件", type=["csv"])

if file is not None:
    df = pd.read_csv(file)
    st.dataframe(df)

    column = st.text_input("请输入需要判断的内容的列名")

    save_results = st.checkbox("保存结果为CSV文件")

    if st.button("开始检测"):
        if not st.session_state.logged_in:
            st.error("请先登录!")
            st.stop()
        if column not in df.columns:
            st.error(f"列名 '{column}' 不存在于数据集中,请检查并重新输入。")
        else:
            results_df = pd.DataFrame(columns=['检测文本', '检测结果'])

            progress_bar = st.progress(0)

            stop_flag = False

            # Note: Streamlit reruns the whole script on each interaction, so clicking
            # this button triggers a rerun; stop_button cannot flip to True while the
            # loop below is already running.
            stop_button = st.button("停止检测")

            for i, (index, row) in enumerate(df.iterrows()):
                if stop_button:
                    stop_flag = True
                    break

                text = row[column]

                with st.spinner("AI正在思考中,请稍等..."):
                    result = load_models_and_predict(text, device)

                # Record the prediction and show the row that was just processed.
                results_df.loc[i] = [text, result]
                r = results_df.loc[i]
                st.dataframe(r)

                st.divider()

                progress_bar.progress((i + 1) / len(df))

            progress_bar.empty()

            if stop_flag:
                st.warning("检测已停止。")
            else:
                st.success("所有文本已检测完成!")

            if save_results and not stop_flag:
                csv_result = results_df.to_csv(index=False)
                st.download_button(
                    label="下载结果",
                    data=csv_result,
                    file_name='results.csv',
                    mime='text/csv'
                )