traces-tool / main.py
TRACES's picture
Update main.py
157fc26 verified
raw
history blame
7.66 kB
import json
import os
import streamlit as st
import pickle
from transformers import AutoTokenizer, BertForSequenceClassification, pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
def _text_classifier(model_name, num_labels, tokenizer_name=None, **model_kwargs):
    """Build a ``text-classification`` pipeline around a BERT checkpoint.

    Args:
        model_name: Hub id of the ``BertForSequenceClassification`` checkpoint.
        num_labels: number of output labels passed to the model's ``from_pretrained``.
        tokenizer_name: Hub id of the tokenizer; defaults to ``model_name``.
        **model_kwargs: extra keyword arguments forwarded to the model's
            ``from_pretrained`` only (e.g. an auth token); the tokenizer does
            not receive them.

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    model = BertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
    return pipeline(task="text-classification", model=model, tokenizer=tokenizer)


def load_models():
    """Load the classifiers used by the app into ``st.session_state``.

    Stores three text-classification pipelines as ``bert_disinfo``,
    ``bert_gpt`` and ``emotions``.  ``st.session_state.loaded`` is set only
    after all of them have loaded, so a failed load is retried on the next
    rerun instead of leaving the session without models.

    Raises:
        KeyError: if the ``ACCESS_TOKEN2`` environment variable (the auth
            token for the ``TRACES/emotions`` checkpoint) is not set.
    """
    # Disabled TF-IDF + SVM "untrue information" detector, kept for reference:
    # with open('models/tfidf_vectorizer_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
    #     st.session_state.tfidf_vectorizer_untrue_inf = pickle.load(f)
    # with open('models/SVM_model_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
    #     st.session_state.untrue_detector = pickle.load(f)
    st.session_state.bert_disinfo = _text_classifier("usmiva/bert-desinform-bg", num_labels=2)
    st.session_state.bert_gpt = _text_classifier("usmiva/bert-deepfake-bg", num_labels=2)
    # The emotions model is gated behind a token; its tokenizer comes from a
    # separate public checkpoint and is loaded without the token.
    st.session_state.emotions = _text_classifier(
        "TRACES/emotions",
        num_labels=11,
        tokenizer_name="usmiva/bert-web-bg",
        use_auth_token=os.environ['ACCESS_TOKEN2'],
    )
    # Mark the session as loaded only once every model is in place.
    st.session_state.loaded = True
def load_content():
    """Read the localized page strings from ``resource/page_content.json``.

    The path is relative to the current working directory and the file is
    decoded as UTF-8.

    Returns:
        The parsed JSON document; the app indexes it as ``content[key][lang]``.
    """
    content_path = os.path.join('resource', 'page_content.json')
    with open(content_path, encoding='utf8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def switch_lang(lang):
    """Button callback that sets the UI language in session state.

    Args:
        lang: ``'bg'`` selects Bulgarian; any other value selects English.
    """
    # Applied unconditionally so a language click is never silently dropped
    # when 'lang' has not been initialized yet.
    st.session_state.lang = 'bg' if lang == 'bg' else 'en'
# ---- Session bootstrap: runs on every script run, initializes only missing state ----
if 'lang' not in st.session_state:
    # Bulgarian is the default UI language.
    st.session_state.lang = 'bg'

_RESULT_KEYS = (
    'bert_gpt_result',
    # 'untrue_detector_result',  (disabled detector; it also used
    #                              untrue_detector_probability = 1)
    'bert_disinfo_result',
    'emotions_result',
)
if not any(key in st.session_state for key in _RESULT_KEYS):
    # Placeholders shaped like a pipeline's output ([{'label', 'score'}]) so the
    # result display code can index [0]['label'] / [0]['score'] before any
    # analysis has been run.
    for key in _RESULT_KEYS:
        st.session_state[key] = [{'label': '', 'score': 1}]

# Localized UI strings, indexed as content[key][lang].
content = load_content()
if 'loaded' not in st.session_state:
    load_models()
#######################################################################################################################
# Page header: localized title plus EN / BG buttons wired to switch_lang via on_click.
st.title(content['title'][st.session_state.lang])
col1, col2, col3 = st.columns([1, 1, 10])
for _column, _code in ((col1, 'en'), (col2, 'bg')):
    with _column:
        # Label is the upper-cased language code ('EN' / 'BG'); the code doubles as widget key.
        st.button(label=_code.upper(), key=_code, on_click=switch_lang, args=[_code])
def _percent(score):
    """Return *score* (presumably a probability in [0, 1]) as a percentage string rounded to 2 decimals."""
    return str(round(score * 100, 2))


def _accept_terms():
    """on_click callback for the consent button: record that the terms were accepted."""
    st.session_state.agree = True


def _run_analysis(text):
    """Run every loaded classifier on *text* and store the raw outputs in session state.

    ``truncation=True`` is forwarded to each pipeline's tokenizer so that an
    input longer than the model's maximum sequence length is cut down instead
    of making the model raise (assumes each tokenizer defines
    ``model_max_length`` — TODO confirm for these checkpoints).
    """
    st.session_state.bert_gpt_result = st.session_state.bert_gpt(text, truncation=True)
    # Disabled TF-IDF + SVM "untrue information" detector:
    # user_tfidf_untrue_inf = st.session_state.tfidf_vectorizer_untrue_inf.transform([text])
    # st.session_state.untrue_detector_result = st.session_state.untrue_detector.predict(user_tfidf_untrue_inf)[0]
    # st.session_state.untrue_detector_probability = st.session_state.untrue_detector.predict_proba(user_tfidf_untrue_inf)[0]
    # st.session_state.untrue_detector_probability = max(st.session_state.untrue_detector_probability[0], st.session_state.untrue_detector_probability[1])
    st.session_state.bert_disinfo_result = st.session_state.bert_disinfo(text, truncation=True)
    st.session_state.emotions_result = st.session_state.emotions(text, truncation=True)


def _render_results():
    """Show the stored classifier results as Streamlit message boxes in the current UI language."""
    lang = st.session_state.lang

    # Machine-generated vs. human-written text: 'LABEL_1' is shown as the warning case.
    gpt = st.session_state.bert_gpt_result[0]
    if gpt['label'] == 'LABEL_1':
        st.warning(content['bert_gpt'][lang] + _percent(gpt['score']) +
                   content['bert_gpt_prob'][lang], icon="⚠️")
    else:
        st.success(content['bert_human'][lang] + _percent(gpt['score']) +
                   content['bert_human_prob'][lang], icon="✅")

    # Disabled "untrue information" detector output:
    # if st.session_state.untrue_detector_result == 0:
    #     st.warning(content['untrue_getect_yes'][lang] +
    #                str(round(st.session_state.untrue_detector_probability * 100, 2)) +
    #                content['untrue_yes_proba'][lang], icon="⚠️")
    # else:
    #     st.success(content['untrue_getect_no'][lang] +
    #                str(round(st.session_state.untrue_detector_probability * 100, 2)) +
    #                content['untrue_no_proba'][lang], icon="✅")

    # Disinformation classifier: 'LABEL_1' is shown as the warning case.
    disinfo = st.session_state.bert_disinfo_result[0]
    if disinfo['label'] == 'LABEL_1':
        st.warning(content['bert_yes_1'][lang] + _percent(disinfo['score']) +
                   content['bert_yes_2'][lang], icon="⚠️")
    else:
        st.success(content['bert_no_1'][lang] + _percent(disinfo['score']) +
                   content['bert_no_2'][lang], icon="✅")

    # Emotion classifier: the same message prefix is used in both cases.
    emotion = st.session_state.emotions_result[0]
    emotion_text = (content['emotions_label_1'][lang] + str(emotion['label']) +
                    content['emotions_label_2'][lang] + _percent(emotion['score']) +
                    content['emotions_label_3'][lang])
    # NOTE(review): scores below 0.97 are shown as a warning (with
    # emotions_label_4), others as info (with emotions_label_5); the meaning of
    # the 0.97 threshold is not documented here — confirm with the model owners.
    if emotion['score'] < 0.97:
        st.warning(emotion_text + content['emotions_label_4'][lang], icon="⚠️")
    else:
        st.info(emotion_text + content['emotions_label_5'][lang])


if 'agree' not in st.session_state:
    st.session_state.agree = False

ui_lang = st.session_state.lang
if st.session_state.agree:
    # Terms accepted: show the analysis tool and, in a second tab, the terms.
    tab_tool, tab_terms = st.tabs([content['tab_tool'][ui_lang], content['tab_terms'][ui_lang]])
    with tab_tool:
        user_input = st.text_area(content['textbox_title'][ui_lang],
                                  content['text_placeholder'][ui_lang]).strip('\n')
        if st.button(content['analyze_button'][ui_lang]):
            _run_analysis(user_input)
        # Results live in session state, so the last analysis stays visible across reruns.
        _render_results()
        st.info(content['disinformation_definition'][ui_lang], icon="ℹ️")
    with tab_terms:
        st.write(content['disclaimer'][ui_lang])
else:
    # Terms not yet accepted: show only the disclaimer and the consent button.
    st.write(content['disclaimer_title'][ui_lang])
    st.write(content['disclaimer'][ui_lang])
    # Accept via an on_click callback (same pattern as the language buttons):
    # the flag is set before Streamlit's rerun triggered by the click, so the
    # deprecated st.experimental_rerun() is no longer needed.
    st.button(content['disclaimer_agree_text'][ui_lang], on_click=_accept_terms)