Spaces:
Sleeping
Sleeping
import time | |
import streamlit as st | |
from streamlit.logger import get_logger | |
from langchain.schema.messages import HumanMessage | |
from utils.mongo_utils import get_db_client, update_convo | |
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF, are_models_alive | |
from utils.memory_utils import clear_memory, push_convo2db | |
from utils.chain_utils import get_chain, custom_chain_predict | |
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT | |
from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS | |
from models.ta_models.cpc_utils import cpc_push2db, modify_last_human_message | |
from models.ta_models.bp_utils import bp_predict_message, bp_push2db | |
logger = get_logger(__name__)
# Temperature value handed to get_chain when the LLM chain is built.
temperature = 0.8

st.set_page_config(page_title="Conversation Simulator")

# Brand-new session: initialise the sent-message counter and, only on this
# first run, redirect to the model loader page when are_models_alive() is False.
if "sent_messages" not in st.session_state:
    st.session_state['sent_messages'] = 0
    if not are_models_alive():
        st.switch_page("pages/model_loader.py")

# Remaining per-session defaults, in a fixed order. Each factory is called
# only when its key is missing, so e.g. get_db_client() runs once per session.
_SESSION_DEFAULTS = (
    ("total_messages", lambda: 0),
    ("issue", lambda: ISSUES[0]),
    ("previous_source", lambda: SOURCES[0]),
    ("db_client", get_db_client),
    ("texter_name", lambda: get_random_name(names_df=DEFAULT_NAMES_DF)),
    ("last_phase", lambda: CPC_LBL_OPTS[0]),
    ("changed_cpc", lambda: False),
    ("changed_bp", lambda: False),
    ("last_message_ts", time.time),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key in st.session_state:
        continue
    st.session_state[_state_key] = _make_default()
    if _state_key == "texter_name":
        logger.debug(f"texter name is {st.session_state['texter_name']}")

# Maps the session-state key that holds the conversation memory ('memory')
# to the scenario metadata (issue / source) of the current conversation.
# NOTE(review): the memory object itself is presumably created from this
# mapping by create_memory_add_initial_message — confirm.
memories = {'memory': {"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
with st.sidebar:
    username = st.text_input("Username", value='Dani', max_chars=30)
    if 'counselor_name' not in st.session_state:
        st.session_state["counselor_name"] = username

    # NOTE(review): both reset callbacks pass language="English" regardless of
    # the selected language; `language` is not bound yet when these kwargs are
    # evaluated. Confirm clear_memory really expects a fixed value here.
    issue = st.selectbox(
        "Select a Scenario",
        ISSUES,
        index=ISSUES.index(st.session_state['issue']),
        format_func=issue2label,
        on_change=clear_memory,
        kwargs=dict(memories=memories, username=username, language="English"),
    )

    # Spanish is offered only for the "Anxiety" scenario.
    supported_languages = ['en']
    if issue == "Anxiety":
        supported_languages.append("es")
    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        format_func=lambda code: {"en": "English"}.get(code, "Spanish"),
        on_change=clear_memory,
        kwargs=dict(memories=memories, username=username, language="English"),
    )

    source = st.selectbox(
        "Select a source Model A",
        SOURCES,
        index=0,
        format_func=source2label,
        key="source",
    )

# Counters are reset and a new texter name is drawn whenever the model
# source, the scenario or the counselor name differs from the stored values.
changed_source = not (
    st.session_state['previous_source'] == source
    and st.session_state['issue'] == issue
    and st.session_state['counselor_name'] == username
)
if changed_source:
    st.session_state["counselor_name"] = username
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")
    st.session_state['previous_source'] = source
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0
    st.session_state['total_messages'] = 0

create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)
st.session_state['previous_source'] = source

# The conversation memory lives in session state under the first key of `memories`.
memoryA = st.session_state[next(iter(memories))]
# issue only without "." marker for model compatibility
llm_chain, stopper = get_chain(
    issue,
    language,
    source,
    memoryA,
    temperature,
    texter_name=st.session_state["texter_name"],
)
st.title("💬 Simulator")

# Keep the message counter in sync with what the conversation memory holds.
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)

# Render every stored message: human turns as "user", all others as "assistant".
for msg in memoryA.buffer_as_messages:
    # isinstance (not an exact type() comparison) so that subclasses of
    # HumanMessage are still displayed as user turns.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)
def sent_request_llm(llm_chain, prompt):
    """Send the user's ``prompt`` through the chain and persist the transcript.

    Increments ``st.session_state['sent_messages']``, echoes the prompt as a
    "user" chat bubble, and writes each response yielded by
    ``custom_chain_predict`` as an "assistant" bubble. It then stores the
    current transcript of ``memoryA`` under the session's ``convo_id`` via
    ``update_convo``.

    Relies on the module globals ``stopper`` and ``memoryA``. It also requires
    ``st.session_state["convo_id"]`` to be set already; otherwise the final
    lookup raises ``KeyError``.

    Args:
        llm_chain: chain object forwarded to ``custom_chain_predict``.
        prompt: message text submitted by the user.
    """
    st.session_state['sent_messages'] = st.session_state['sent_messages'] + 1
    st.chat_message("user").write(prompt)

    for reply in custom_chain_predict(llm_chain, prompt, stopper):
        st.chat_message("assistant").write(reply)

    memory_vars = memoryA.load_memory_variables({})
    transcript = memory_vars[memoryA.memory_key]
    update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
# @st.dialog("Bad Practice Detected") | |
# def confirm_bp(bp_prediction, prompt): | |
# bps = [BP_LAB2STR[x['label']] for x in bp_prediction if x['score']] | |
# st.markdown(f"The last message was considered :red[{' and '.join(bps)}]") | |
# "Are you sure you want to send this message?" | |
# newprompt = st.text_input("Change message to:") | |
# "If you do not want to change leave textbox empty" | |
# for bp in BP_LAB2STR.keys(): | |
# _ = st.checkbox(f"Original Message was {BP_LAB2STR[bp]}", key=f"chkbx_{bp}", value=BP_LAB2STR[bp] in bps) | |
# if st.button("Confirm"): | |
# if newprompt is not None and newprompt != "": | |
# prompt = newprompt | |
# bp_push2db( | |
# {bp:st.session_state[f"chkbx_{bp}"] for bp in BP_LAB2STR.keys()} | |
# ) | |
# sent_request_llm(llm_chain, prompt) | |
# st.rerun() | |
# Chat input is disabled once total_messages exceeds MAX_MSG_COUNT - 4
# (headroom for the next interaction).
prompt = st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4)
if prompt:
    st.session_state['last_message_ts'] = time.time()
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    # Settle the feedback for the previous message first. If the user changed
    # a prediction in the sidebar, clear that selection and its flag.
    # Otherwise push the automatic result.
    if st.session_state["sent_messages"] > 0:
        if not st.session_state.changed_cpc:
            cpc_push2db(True)
        else:
            st.session_state["sel_phase"] = None
            st.session_state.changed_cpc = False

        if not st.session_state.changed_bp:
            bp_push2db({pred['label']: pred['score'] for pred in st.session_state['bp_prediction']})
        else:
            st.session_state["sel_bp"] = None
            st.session_state.changed_bp = False

    chain_memory = llm_chain.memory
    st.session_state['context'] = chain_memory.load_memory_variables({})[chain_memory.memory_key]
    st.session_state['last_message'] = prompt
    st.session_state['bp_prediction'] = bp_predict_message(st.session_state['context'], prompt)

    # Toast each detected bad practice; the message is sent regardless.
    for pred in st.session_state['bp_prediction']:
        if pred["score"]:
            st.toast(f"Detected {BP_LAB2STR[pred['label']]} in the last message!", icon=":material/warning:")

    sent_request_llm(llm_chain, prompt)
with st.sidebar:
    if "convo_id" in st.session_state:
        st.write(f"Conversation ID is `{st.session_state['convo_id']}`")
    st.divider()
    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")

    def on_change_cpc():
        """Callback for the phase selectbox (key ``sel_phase``).

        Calls ``cpc_push2db(False)``, rewrites the last human message in
        ``memoryA`` with the selected phase, and sets ``changed_cpc``. That
        flag makes the next submitted message skip the automatic
        ``cpc_push2db(True)``.
        """
        cpc_push2db(False)
        modify_last_human_message(memoryA, st.session_state['sel_phase'])
        st.session_state.changed_cpc = True

    def on_change_bp():
        """Callback for the bad-practice selectbox (key ``sel_bp``).

        Calls ``bp_push2db()`` and sets ``changed_bp``. That flag makes the
        next submitted message skip pushing the automatic bad-practice
        prediction.
        """
        bp_push2db()
        st.session_state.changed_bp = True

    # Shown only after at least one message has been sent, by which point
    # 'bp_prediction' is present in session state.
    if st.session_state["sent_messages"] > 0:
        phase_str = cpc_label2str(st.session_state['last_phase'])
        _ = st.selectbox(
            f"Last Human Message was considered :blue[**{phase_str}**]. "
            "If not please select from the following options",
            CPC_LBL_OPTS,
            index=None,
            format_func=cpc_label2str,
            on_change=on_change_cpc,
            key="sel_phase",
        )

        flagged = [BP_LAB2STR[pred['label']] for pred in st.session_state['bp_prediction'] if pred['score']]
        if flagged:
            bp_title = f"Last Human Message was considered :blue[**{' and '.join(flagged)}**]."
        else:
            bp_title = "Last Human Message was NOT considered Bad Practice."
        _ = st.selectbox(
            bp_title + " If not please select from the following options",
            BP_LBL_OPTS,
            index=None,
            format_func=lambda x: x,
            on_change=on_change_bp,
            key="sel_bp",
        )

    if st.button("Score Conversation"):
        st.switch_page("pages/training_adherence.py")
# Refresh the counter and warn as the conversation approaches its cap.
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
n_messages = st.session_state['total_messages']
if n_messages >= MAX_MSG_COUNT:
    st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
elif n_messages >= WARN_MSG_COUT:
    st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")

# More than 40 minutes since the last submitted message (or session start):
# re-probe the models and go to the loader page if are_models_alive() is False.
_IDLE_RECHECK_SECONDS = 40 * 60
_idle_seconds = time.time() - st.session_state['last_message_ts']
if _idle_seconds > _IDLE_RECHECK_SECONDS and not are_models_alive():
    st.switch_page("pages/model_loader.py")