# convosim-ui-dev / pages/training_adherence.py
# Last commit: f3e0ba5 "training-adherence-features (#1)" by ivnban27-ctl (verified), 3.47 kB
import time
from collections import defaultdict

import numpy as np
import streamlit as st
from langchain_core.messages import HumanMessage

from utils.app_utils import are_models_alive
from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
# --- Page setup and navigation guards ------------------------------------
st.set_page_config(page_title="Conversation Simulator - Scoring")

# Each guard sends the user to the page that can supply the missing
# prerequisite. Per the Streamlit docs, ``st.switch_page`` ends the current
# script run, so nothing below executes once a redirect fires.
models_ready = are_models_alive()
if not models_ready:
    st.switch_page("pages/model_loader.py")

conversation_started = "memory" in st.session_state
if not conversation_started:
    st.switch_page("pages/convosim.py")

# Conversation memory kept in session state (presumably populated by
# pages/convosim.py — confirm); it is scored and replayed on this page.
memory = st.session_state['memory']
progress_text = "Scoring Conversation using AI models ..."
@st.cache_data(show_spinner=False)
def _score_conversation(convo_id, transcript, _memory):
    """Run every training-adherence question against one conversation.

    Streamlit caches the result keyed on ``convo_id`` and ``transcript``; the
    leading underscore on ``_memory`` tells ``st.cache_data`` not to hash it.

    Args:
        convo_id: Conversation identifier forwarded to ``TA_predict_convo``.
        transcript: Tuple of message contents. It is part of the cache key so a
            conversation that changed since it was last scored is re-scored.
        _memory: Conversation memory object passed to ``TA_predict_convo``.

    Returns:
        dict mapping each question name (keys of ``QUESTION2PHASE``) to
        ``{"response": ..., "explanation": ...}`` as produced by
        ``post_process_response``.
    """
    questions = list(QUESTION2PHASE.keys())
    total = len(questions)  # hoisted: invariant across the loop
    my_bar = st.progress(0, text=progress_text)
    data = {}
    for done, question in enumerate(questions, start=1):
        _, _, full_response = TA_predict_convo(
            _memory, question, make_explanation=True, conversation_id=convo_id
        )
        response, explanation = post_process_response(full_response)
        data[question] = {"response": response, "explanation": explanation}
        my_bar.progress(done / total, text=progress_text)
    # Presumably a short pause so the finished bar is visible before removal.
    time.sleep(2)
    my_bar.empty()
    return data


def get_ta_responses():
    """Return the (cached) AI scores for the conversation in session state.

    ``st.cache_data`` entries are shared by every session of the app, so the
    cache key must identify the conversation being scored. A zero-argument
    cached function would serve one user's scores to every other user until
    someone cleared the cache.

    Returns:
        dict of question name -> {"response": ..., "explanation": ...}.
    """
    transcript = tuple(msg.content for msg in memory.buffer_as_messages)
    return _score_conversation(st.session_state['convo_id'], transcript, memory)


# Callers invalidate scores with ``get_ta_responses.clear()``; keep that API.
get_ta_responses.clear = _score_conversation.clear
# --- Header controls -----------------------------------------------------
with st.container():
    back_col, toggle_col = st.columns(2)
    # Leaving the page drops cached scores so the next visit scores afresh.
    if back_col.button("Go Back"):
        get_ta_responses.clear()
        st.switch_page("pages/convosim.py")
    # When ticked, the scoring tab also shows an editable explanation box.
    expl = toggle_col.checkbox("Show Scoring Explanations")

tab1, tab2 = st.tabs(["Scoring", "Conversation"])
data = get_ta_responses()
# --- Conversation tab: read-only replay of the transcript ----------------
with tab2:
    for msg in memory.buffer_as_messages:
        # isinstance (not ``type(...) ==``) so HumanMessage subclasses, e.g.
        # HumanMessageChunk, are still rendered as user turns.
        role = "user" if isinstance(msg, HumanMessage) else "assistant"
        st.chat_message(role).write(msg.content)
# --- Scoring tab: one editable card per adherence question ---------------
with tab1:
    for question in QUESTION2PHASE.keys():
        with st.container(border=True):
            # Show only the text before " Answer" (presumably answer-format
            # instructions follow it in NAME2QUESTION — confirm).
            question_str = NAME2QUESTION[question].split(' Answer')[0]
            ai_response = data[question]['response']
            # Pre-select the model's answer. If it is not one of the allowed
            # options, start unselected instead of crashing on list.index().
            default_index = (
                TA_OPTIONS.index(ai_response) if ai_response in TA_OPTIONS else None
            )
            st.radio(
                f"**{question_str}**", options=TA_OPTIONS,
                index=default_index, horizontal=True,
                key=f"{question}_manual"
            )
            if expl:
                # A non-empty label avoids Streamlit's empty-label warning;
                # it stays visually hidden.
                st.text_area(
                    label="Explanation", value=data[question]["explanation"],
                    label_visibility="collapsed",
                    key=f"{question}_explanation_manual"
                )
# --- Footer actions ------------------------------------------------------
with st.container():
    left, right = st.columns(2)
    if left.button("Go Back", key="goback2"):
        get_ta_responses.clear()
        st.switch_page("pages/convosim.py")
    submitted = right.button("Submit Scoring", type="primary")
    if submitted:
        # Collect the reviewer's answers from the widget state of the scoring
        # tab. Explanation boxes only exist while the toggle is on.
        ytrue = {}
        for q_name in QUESTION2PHASE.keys():
            ytrue[q_name] = {
                "response": st.session_state[f"{q_name}_manual"],
                "explanation": (
                    st.session_state[f"{q_name}_explanation_manual"] if expl else ""
                ),
            }
        # Hand the reviewer's answers (ytrue) and the model's answers (data)
        # to ta_push_convo_comparison, then leave with a fresh cache.
        ta_push_convo_comparison(ytrue, data)
        get_ta_responses.clear()
        st.switch_page("pages/convosim.py")