import streamlit as st | |
import numpy as np | |
from collections import defaultdict | |
from langchain_core.messages import HumanMessage | |
from utils.app_utils import are_models_alive | |
from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response | |
from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS | |
st.set_page_config(page_title="Conversation Simulator - Scoring")

# This page needs live scoring models and a conversation in
# st.session_state["memory"]; without either, leave immediately. st.switch_page
# stops the current script run, so nothing below executes in that case.
# NOTE(review): "pages/" looks truncated — st.switch_page expects the path of a
# specific page script; confirm the intended target.
if not are_models_alive() or "memory" not in st.session_state:
    st.switch_page("pages/")

memory = st.session_state["memory"]
progress_text = "Scoring Conversation using AI models ..."
# Session-state slot holding this session's scoring results.
_TA_RESPONSES_KEY = "_ta_responses"


def get_ta_responses():
    """Score the current conversation with the TA models, one question at a time.

    For every question key in ``QUESTION2PHASE`` this calls ``TA_predict_convo``
    on the module-level ``memory`` (with ``make_explanation=True`` and the
    conversation id from ``st.session_state['convo_id']``), then splits the raw
    output with ``post_process_response``. A progress bar is shown while the
    questions are scored and removed afterwards.

    The result is memoised per browser session in ``st.session_state`` so that
    the page reruns triggered by every widget interaction do not re-query the
    models. Call ``get_ta_responses.clear()`` to force re-scoring.

    Returns:
        defaultdict mapping each question key to a defaultdict with the keys
        ``"response"`` and ``"explanation"``.
    """
    cached = st.session_state.get(_TA_RESPONSES_KEY)
    if cached is not None:
        return cached

    questions = list(QUESTION2PHASE)
    total = len(questions)
    my_bar = st.progress(0, text=progress_text)
    data = defaultdict(defaultdict)
    for done, question in enumerate(questions, start=1):
        _full_convo, _prompt, full_response = TA_predict_convo(
            memory,
            question,
            make_explanation=True,
            conversation_id=st.session_state["convo_id"],
        )
        response, explanation = post_process_response(full_response)
        data[question]["response"] = response
        data[question]["explanation"] = explanation
        my_bar.progress(done / total, text=progress_text)
    my_bar.empty()

    st.session_state[_TA_RESPONSES_KEY] = data
    return data


def _clear_ta_responses():
    """Forget this session's memoised scores so the next call re-scores."""
    st.session_state.pop(_TA_RESPONSES_KEY, None)


# The page calls ``get_ta_responses.clear()`` when leaving; expose that hook
# (same shape as the ``.clear()`` of Streamlit's cache decorators) but keep the
# cache per session, so one user's scores are never served to another.
get_ta_responses.clear = _clear_ta_responses
# Header row: navigation back to the previous page, plus a toggle that decides
# whether the per-question explanation boxes are rendered.
with st.container():
    nav_col, toggle_col = st.columns(2)
    if nav_col.button("Go Back"):
        # Reset the memoised scores before leaving (relies on get_ta_responses
        # exposing a .clear() hook) so a later visit re-scores.
        get_ta_responses.clear()
        st.switch_page("pages/")
    expl = toggle_col.checkbox("Show Scoring Explanations")

tab_labels = ["Scoring", "Conversation"]
tab1, tab2 = st.tabs(tab_labels)
data = get_ta_responses()
with tab2:
    # Replay the scored conversation: human turns render as "user", every
    # other message type as "assistant". isinstance (not an exact type check)
    # so subclasses of HumanMessage are also shown as user turns.
    for msg in memory.buffer_as_messages:
        role = "user" if isinstance(msg, HumanMessage) else "assistant"
        st.chat_message(role).write(msg.content)
with tab1:
    # One bordered card per question: a radio preselected with the model's
    # answer (the reviewer may override it) and, when the header toggle is on,
    # an editable box with the model's explanation.
    for question in QUESTION2PHASE:
        with st.container(border=True):
            # Keep only the text before the first " Answer" as the label
            # (presumably the rest is answer-format instructions for the model).
            question_str = NAME2QUESTION[question].split(' Answer')[0]
            ai_response = data[question]["response"]
            # If post-processing produced a value outside TA_OPTIONS, render the
            # radio with no selection instead of crashing the page on
            # ValueError from list.index.
            default_index = (
                TA_OPTIONS.index(ai_response) if ai_response in TA_OPTIONS else None
            )
            st.radio(
                f"**{question_str}**",
                options=TA_OPTIONS,
                index=default_index,
                horizontal=True,
                key=f"{question}_manual",
            )
            if expl:
                # A non-empty, visually hidden label avoids Streamlit's
                # empty-label accessibility warning.
                st.text_area(
                    label="Explanation",
                    value=data[question]["explanation"],
                    key=f"{question}_explanation_manual",
                    label_visibility="collapsed",
                )
# Footer row: a second "Go Back" (distinct widget key) and the submit action.
with st.container():
    back_col, submit_col = st.columns(2)
    if back_col.button("Go Back", key="goback2"):
        get_ta_responses.clear()
        st.switch_page("pages/")
    if submit_col.button("Submit Scoring", type="primary"):
        # Gather the reviewer's labels from widget state. Explanation boxes only
        # exist when the toggle is on, so their state is read only then.
        ytrue = {}
        for question in QUESTION2PHASE:
            entry = {"response": st.session_state[f"{question}_manual"]}
            entry["explanation"] = (
                st.session_state[f"{question}_explanation_manual"] if expl else ""
            )
            ytrue[question] = entry
        # Hand reviewer labels and model outputs over together (presumably the
        # helper persists the comparison), then reset and leave.
        ta_push_convo_comparison(ytrue, data)
        get_ta_responses.clear()
        st.switch_page("pages/")