import base64
import json
import pickle

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import transformers

from model.funcs import (create_model_and_tokenizer, execution_time,
                         load_model, predict_sentiment)
from model.model import LSTMConcatAttentionEmbed
from preprocessing.preprocessing import data_preprocessing
from preprocessing.rnn_preprocessing import preprocess_single_string

def get_base64(file_path):
    """Read a file and return its contents as a base64-encoded string."""
    with open(file_path, "rb") as file:
        base64_bytes = base64.b64encode(file.read())
    base64_string = base64_bytes.decode("utf-8")
    return base64_string

def set_background(png_file):
    """Inject CSS that uses the given PNG as the full-page background of the app."""
    bin_str = get_base64(png_file)
    page_bg_img = (
        """
    <style>
    .stApp {
        background-image: url("data:image/png;base64,%s");
        background-size: auto;
    }
    </style>
    """
        % bin_str
    )
    st.markdown(page_bg_img, unsafe_allow_html=True)


set_background("main_background.png")

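# The three loaders below are cached with @st.cache_resource, so the pickled
# vectorizer/classifier, the LSTM weights, and the BERT model are loaded once
# and reused across Streamlit reruns instead of being reloaded on every
# widget interaction.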
@st.cache_resource
def load_logreg():
    """Load the pickled text vectorizer and logistic regression classifier."""
    with open("vectorizer.pkl", "rb") as f:
        logreg_vectorizer = pickle.load(f)

    with open("logreg_model.pkl", "rb") as f:
        logreg_predictor = pickle.load(f)
    return logreg_vectorizer, logreg_predictor


logreg_vectorizer, logreg_predictor = load_logreg()

@st.cache_resource
def load_lstm():
    """Load the vocabulary mappings and the LSTM-with-attention model weights."""
    with open("model/vocab.json", "r") as f:
        vocab_to_int = json.load(f)

    with open("model/int_vocab.json", "r") as f:
        int_to_vocab = json.load(f)

    model_concat_embed = LSTMConcatAttentionEmbed()
    model_concat_embed.load_state_dict(torch.load("model/model_weights.pt"))

    return vocab_to_int, int_to_vocab, model_concat_embed


vocab_to_int, int_to_vocab, model_concat_embed = load_lstm()

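# The BERT pipeline uses the pretrained "cointegrated/rubert-tiny2" checkpoint
# as its backbone; load_model (a project helper from model.funcs) is assumed
# to load the fine-tuned weights from model/best_bert_weights.pth on top of it.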
@st.cache_resource
def load_bert():
    """Load the fine-tuned ruBERT-tiny2 model and its tokenizer."""
    model_class = transformers.AutoModel
    tokenizer_class = transformers.AutoTokenizer
    pretrained_weights = "cointegrated/rubert-tiny2"
    weights_path = "model/best_bert_weights.pth"
    model = load_model(model_class, pretrained_weights, weights_path)
    tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

    return model, tokenizer


model, tokenizer = load_bert()

@execution_time
def plot_and_predict(review: str, SEQ_LEN: int, model: nn.Module):
    """Run the LSTM model on a single review and return the predicted class.

    Returns 1 (positive) if the sigmoid output exceeds 0.75, otherwise 0 (negative).
    """
    inp = preprocess_single_string(review, SEQ_LEN, vocab_to_int)
    model.eval()
    with torch.inference_mode():
        pred, _ = model(inp.long().unsqueeze(0))
        pred = pred.sigmoid().item()
    return 1 if pred > 0.75 else 0

def preprocess_text_logreg(text):
    """Clean a raw review and vectorize it for the logistic regression model."""
    clean_text = data_preprocessing(text)
    vectorized_text = logreg_vectorizer.transform([" ".join(clean_text)])
    return vectorized_text


@execution_time
def predict_sentiment_logreg(text):
    """Predict the sentiment class of a review with the logistic regression model."""
    processed_text = preprocess_text_logreg(text)
    prediction = logreg_predictor.predict(processed_text)
    # predict() returns a one-element array for a single sample;
    # return the scalar label so it can be compared with == below.
    return int(prediction[0])

metrics = {
    "Models": ["Logistic Regression", "LSTM + attention", "ruBERT-tiny2"],
    "f1-macro score": [0.94376, 0.93317, 0.94070],
}

df = pd.DataFrame(metrics)
df.set_index("Models", inplace=True)
df.index.name = "Model"

st.sidebar.title("Model Selection")
model_type = st.sidebar.radio("Select Model Type", ["Classic ML", "LSTM", "BERT"])

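# Custom CSS for the page title and the coloured verdict labels rendered below.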
styled_text = """
<style>
.styled-title {
    color: #FF00FF;
    font-size: 40px;
    text-shadow: -2px -2px 4px #000000;
    -webkit-text-stroke-width: 1px;
    -webkit-text-stroke-color: #000000;
}
.positive {
    color: #00FF00;
    font-size: 30px;
    text-shadow: -2px -2px 4px #000000;
    -webkit-text-stroke-width: 1px;
    -webkit-text-stroke-color: #000000;
}
.negative {
    color: #FF0000;
    font-size: 30px;
    text-shadow: -2px -2px 4px #000000;
    -webkit-text-stroke-width: 1px;
    -webkit-text-stroke-color: #000000;
}
</style>
"""

st.markdown(styled_text, unsafe_allow_html=True)

st.markdown('<div class="styled-title">Review Prediction</div>', unsafe_allow_html=True)
text_input = st.text_input("Enter your review:")
if st.button("Predict"):
    if model_type == "Classic ML":
        prediction = predict_sentiment_logreg(text_input)
    elif model_type == "LSTM":
        prediction = plot_and_predict(
            review=text_input, SEQ_LEN=25, model=model_concat_embed
        )
    elif model_type == "BERT":
        prediction = predict_sentiment(text_input, model, tokenizer, "cpu")

    if prediction == 1:
        st.markdown(
            '<div class="positive">The review is positive</div>', unsafe_allow_html=True
        )
    elif prediction == 0:
        st.markdown(
            '<div class="negative">The review is negative</div>', unsafe_allow_html=True
        )

st.write(df)
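# Assuming this file is the app entry point (its name is not given here), the
# app would be launched with Streamlit's CLI, e.g. `streamlit run <this_file>.py`.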