training-adherence-features (#1)
- cpc and bad practices features (70a482f666f55772a84d65dc90e69fb21b3d96de)
- fix on tokenizer special tokens (cf6ebf90b37bb55854638f960ba1fb42e56b08f3)
- change in model aliveness calculations (2e79a3c9fdc409a99f016bad980f00c71e2860fc)
- aliveness calculation fixes (b74c038fe4c7136eb1858138c63c7ee531a574a8)
- training adherence scoring features (cfe8e1a5d05755a0346c269de80f255dfb39c501)
- making explanation editable (5f8859a38074d0644c3d0e5fc87c175ecb79071b)
- fix on roberta input len (5e4965f0930a2498341a903704d9ba347fd575e0)
- ta utils fix for explanation (7e14368f66f1989c0d91ee1fca1a5d16e638a523)
- changes on BL postprocessing (92dff98dd8d42a645c05178c183798b9fb1837e7)
- progress bar instead of spinner (a139603ab1caa2d4c2b71cca5e3d7b55e8788073)
- changed to prod (02544790fb5037419491dfe520951728f1e5cee6)
Files changed:
- .streamlit/config.toml +2 -0
- README.md +3 -3
- app_config.py +24 -2
- main.py +14 -0
- models/business_logic_utils/config.py +2 -1
- models/business_logic_utils/response_processing.py +6 -2
- models/databricks/texter_sim_llm.py +15 -4
- models/ta_models/bp_utils.py +66 -0
- models/ta_models/config.py +174 -0
- models/ta_models/cpc_utils.py +53 -0
- models/ta_models/ta_filter_utils.py +150 -0
- models/ta_models/ta_prompt_utils.py +128 -0
- models/ta_models/ta_utils.py +127 -0
- pages/convosim.py +185 -0
- pages/model_loader.py +56 -0
- pages/training_adherence.py +86 -0
- requirements.txt +2 -1
- utils/app_utils.py +51 -5
- utils/chain_utils.py +9 -2
- utils/memory_utils.py +0 -1
- utils/mongo_utils.py +57 -6
.streamlit/config.toml (new file)
@@ -0,0 +1,2 @@
+[client]
+showSidebarNavigation = false
README.md
@@ -1,11 +1,11 @@
 ---
-title: Conversation Simulator
+title: Conversation Simulator DEV
 emoji: 💬
 colorFrom: red
 colorTo: red
 sdk: streamlit
-sdk_version: 1.
-app_file:
+sdk_version: 1.38.0
+app_file: main.py
 pinned: false
 ---
app_config.py
@@ -18,9 +18,28 @@ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
 
 ENDPOINT_NAMES = {
     # "CTL_llama2": "texter_simulator",
-    "CTL_llama3":
+    "CTL_llama3": {
+        "name": "texter_simulator_llm",
+        "model_type": "text-generation"
+    },
+    # "CTL_llama3": {
+    #     "name": "databricks-meta-llama-3-1-70b-instruct",
+    #     "model_type": "text-generation"
+    # },
     # 'CTL_llama2': "llama2_convo_sim",
-    "CTL_mistral": "convo_sim_mistral"
+    # "CTL_mistral": "convo_sim_mistral",
+    "CPC": {
+        "name": "phase_classifier",
+        "model_type": "classificator"
+    },
+    "BadPractices": {
+        "name": "training_adherence_bp",
+        "model_type": "classificator"
+    },
+    "training_adherence": {
+        "name": "training_adherence",
+        "model_type": "text-completion"
+    },
 }
 
 def source2label(source):

@@ -36,6 +55,9 @@ DB_CONVOS = 'conversations'
 DB_COMPLETIONS = 'comparison_completions'
 DB_BATTLES = 'battles'
 DB_ERRORS = 'completion_errors'
+DB_CPC = "cpc_comparison"
+DB_BP = "bad_practices_comparison"
+DB_TA = "convo_scoring_comparison"
 
 MAX_MSG_COUNT = 60
 WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
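Each entry in the new ENDPOINT_NAMES maps a model key to a dict, so callers now take the `name` field before building the serving URL (the old values were bare strings). A minimal sketch of the lookup pattern the new utils use; the URL template here is a placeholder, the real one comes from the DATABRICKS_URL environment variable:

    import os

    # Placeholder template, only for illustration.
    os.environ.setdefault(
        "DATABRICKS_URL",
        "https://example.cloud.databricks.com/serving-endpoints/{endpoint_name}/invocations",
    )

    ENDPOINT_NAMES = {"CPC": {"name": "phase_classifier", "model_type": "classificator"}}

    endpoint_name = ENDPOINT_NAMES["CPC"]["name"]  # the ['name'] access is the new part
    url = os.environ["DATABRICKS_URL"].format(endpoint_name=endpoint_name)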
main.py (new file)
@@ -0,0 +1,14 @@
+import streamlit as st
+from streamlit.logger import get_logger
+
+from utils.app_utils import are_models_alive
+
+logger = get_logger(__name__)
+
+st.set_page_config(page_title="Conversation Simulator")
+
+with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
+    if not are_models_alive():
+        st.switch_page("pages/model_loader.py")
+    else:
+        st.switch_page("pages/convosim.py")
models/business_logic_utils/config.py
@@ -272,7 +272,8 @@ DIFFICULTIES = {
     "difficulty_distrustful": {
        "difficulty_label": "distrustful",
        "description": [
-            "You don't trust the counselor, you will eventually cooperate.",
+            #"You don't trust the counselor, you will eventually cooperate.",
+            "You have a distrustful attitude towards the counselor.",
        ],
    },
    # "difficulty_stop_convo": {
models/business_logic_utils/response_processing.py
@@ -48,6 +48,7 @@ def postprocess_text(
 
     # Remove unnecessary role prefixes
     text = text.replace(human_prefix, "").replace(assistant_prefix, "")
+
 
     # Remove whispers or other marked reactions
     whispers = re.compile(r"(\([\w\s]+\))")  # remove things like "(whispers)"

@@ -55,8 +56,11 @@
     text = whispers.sub("", text)
     text = reactions.sub("", text)
 
-    # Remove
-    text = text.replace('"', '')
+    # Remove double quotation marks
+    text = text.replace('"', '')
+
+    # Remove stutters of any length (e.g., "M-m-my" or "M-m-m-m-my" or "M-My" to "My")
+    text = re.sub(r'\b(\w)(-\1)+-\1(\w*)', r'\1\3', text, flags=re.IGNORECASE)
 
     # Normalize spaces
     text = re.sub(r"\s+", " ", text).strip()
models/databricks/texter_sim_llm.py
@@ -16,15 +16,13 @@ texter:"""
 
 def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
 
-    endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
-
+    endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")['name']
     PROMPT = PromptTemplate(
         input_variables=['history', 'input'],
         template=_DATABRICKS_TEMPLATE_
     )
 
     llm = CustomDatabricksLLM(
-        # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
         endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
         bearer_token=os.environ["DATABRICKS_TOKEN"],
         texter_name=texter_name,

@@ -43,4 +41,17 @@ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
     )
 
     logging.debug(f"loaded Databricks model")
-    return llm_chain, None
+    return llm_chain, None
+
+def cpc_is_alive():
+    body_request = {
+        "inputs": [""]
+    }
+    try:
+        # Send request to Serving
+        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request, timeout=2)
+        if response.status_code == 200:
+            return True
+        else: return False
+    except:
+        return False
models/ta_models/bp_utils.py (new file)
@@ -0,0 +1,66 @@
+import streamlit as st
+from streamlit.logger import get_logger
+import requests
+import os
+from .config import model_name_or_path, BP_THRESHOLD
+from transformers import AutoTokenizer
+from utils.mongo_utils import new_bp_comparison
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
+BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"]['name'])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def bp_predict_message(context, input):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    encoding = tokenizer(
+        context,
+        input,
+        truncation="only_first",
+        max_length = tokenizer.model_max_length - 2,
+    )['input_ids']
+    body_request = {
+        "inputs": [tokenizer.decode(encoding[1:-1])],
+        "params": {
+            "top_k": None
+        }
+    }
+
+    try:
+        # Send request to Serving
+        logger.debug(f"raw BP body is {body_request}")
+        response = requests.post(url=BP_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            response = response.json()['predictions'][0]
+            logger.debug(f"Raw BP prediction is {response}")
+            return [{k:v > BP_THRESHOLD if k=="score" else v for k,v in dict_.items()} for _, dict_ in response.items() ]
+        else:
+            raise Exception(f"Error in response: {response.json()}")
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
+
+def bp_push2db(manual_confirmation=None):
+    if manual_confirmation is None:
+        if st.session_state.sel_bp == "Advice":
+            manual_confirmation = {"is_advice":True, "is_personal_info":False}
+        elif st.session_state.sel_bp == "Personal Info":
+            manual_confirmation = {"is_advice":False, "is_personal_info":True}
+        elif st.session_state.sel_bp == "Advice & Personal Info":
+            manual_confirmation = {"is_advice":True, "is_personal_info":True}
+        else:
+            manual_confirmation = {"is_advice":False, "is_personal_info":False}
+    new_bp_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "model": st.session_state['source'],
+        "context": st.session_state["context"],
+        "last_message": st.session_state["last_message"],
+        "ytrue": manual_confirmation,
+        "ypred": {x['label']:x['score'] for x in st.session_state['bp_prediction']},
+    })
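The dict comprehension in bp_predict_message turns raw scores into booleans against BP_THRESHOLD while passing the other keys through. A self-contained illustration with a made-up payload shaped like the serving response:

    BP_THRESHOLD = 0.7  # same value as models/ta_models/config.py

    # Hypothetical 'predictions[0]' payload from the endpoint.
    response = {
        "0": {"label": "is_advice", "score": 0.91},
        "1": {"label": "is_personal_info", "score": 0.12},
    }

    flags = [{k: v > BP_THRESHOLD if k == "score" else v for k, v in dict_.items()}
             for _, dict_ in response.items()]
    # -> [{'label': 'is_advice', 'score': True},
    #     {'label': 'is_personal_info', 'score': False}]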
models/ta_models/config.py (new file)
@@ -0,0 +1,174 @@
+model_name_or_path = "FacebookAI/xlm-roberta-large"
+
+CPC_LABEL2STR = {
+    "0_ActiveEngagement": "Active Engagement",
+    "1_Explore": "Explore",
+    "2_IRA": "Immidiate Risk Assessment",
+    "3_SafetyAssessment": "Safety Assessment",
+    "4_SP&NS": "Safety Planning & Next Steps",
+    "5_EmergencyIntervention": "Emergency Intervention",
+    "6_WrappingUp": "Wrapping Up",
+    "7_Other": "Other",
+}
+
+CPC_LBL_OPTS = list(CPC_LABEL2STR.keys())
+
+def cpc_label2str(phase):
+    return CPC_LABEL2STR[phase]
+
+def phase2int(phase):
+    return int(phase.split("_")[0])
+
+BP_THRESHOLD = 0.7
+BP_LAB2STR = {
+    "is_advice": "Advice",
+    "is_personal_info": "Personal Info Sharing",
+}
+
+QUESTION2PHASE = {
+    "question_1": ["0_ActiveEngagement","1_Explore"],
+    "question_4": ["1_Explore"],
+    "question_5": ["0_ActiveEngagement", "1_Explore"],
+    # "question_7": ["1_Explore"],
+    # "question_9": ["4_SP&NS"],
+    "question_10": ["4_SP&NS"],
+    # "question_11": ["4_SP&NS"],
+    "question_14": ["6_WrappingUp"],
+    # "question_15": ["ALL"],
+    "question_19": ["ALL"],
+    # "question_21": ["ALL"],
+    # "question_22": ["ALL"],
+    "question_23": ["2_IRA", "3_SafetyAssessment"],
+}
+
+QUESTION2FILTERARGS = {
+    "question_1": {
+        "phases": QUESTION2PHASE["question_1"],
+        "pre_n": 2,
+        "post_n": 8,
+        "ignore": ["7_Other"],
+    },
+    "question_4": {
+        "phases": QUESTION2PHASE["question_4"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    "question_5": {
+        "phases": QUESTION2PHASE["question_5"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    # "question_7": {
+    #     "phases": QUESTION2PHASE["question_7"],
+    #     "pre_n": 5,
+    #     "post_n": 15,
+    #     "ignore": ["7_Other"],
+    # },
+    # "question_9": {
+    #     "phases": QUESTION2PHASE["question_9"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_10": {
+        "phases": QUESTION2PHASE["question_10"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    # "question_11": {
+    #     "phases": QUESTION2PHASE["question_11"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_14": {
+        "phases": QUESTION2PHASE["question_14"],
+        "pre_n": 10,
+        "post_n": 0,
+        "ignore": ["7_Other"],
+    },
+    # "question_15": {
+    #     "phases": QUESTION2PHASE["question_15"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_19": {
+        "phases": QUESTION2PHASE["question_19"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    # "question_21": {
+    #     "phases": QUESTION2PHASE["question_21"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    # "question_22": {
+    #     "phases": QUESTION2PHASE["question_22"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_23": {
+        "phases": QUESTION2PHASE["question_23"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+}
+
+
+START_INST = "<|user|>"
+END_INST = "<|end|>\n<|assistant|>"
+
+NAME2QUESTION = {
+    "question_1": "Did the helper introduce themself in the opening message? Answer only Yes or No.",
+    "question_4": "Did the helper actively listened to the texter's crisis? Answer only Yes or No.",
+    "question_5": "Did the helper reflect on the main issue that led the texter reach out? Answer only Yes or No.",
+    # "question_7": "Did the helper collaborated with the texter to identify the goal of the conversation? Answer only Yes or No.",
+    # "question_9": "Did the helper collaborated with the texter to create next steps? Answer only Yes or No.",
+    "question_10": "Did the helper explored texter's existing coping skills? Answer only Yes or No.",
+    # "question_11": "Did the helper explored texter’s social support? Answer only Yes or No.",
+    "question_14": "Did helper reflected the texter’s plan, reiterate coping skills, and end in a supportive way? Answer only Yes or No.",
+    # "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
+    "question_19": "Did the helper consistently reflected empathy through the conversation? Answer only Yes or No.",
+    # "question_21": "Did the helper shared personal information? Answer only Yes or No.",
+    # "question_22": "Did the helper gave advice? Answer only Yes or No.",
+    "question_23": "Did the helper explicitely initiated imminent risk assessment? Answer only Yes or No.",
+}
+
+NAME2PROMPT = {
+    k: "--------Conversation:\n{convo}\n{start_inst}" + v + "\n{end_inst}"
+    for k, v in NAME2QUESTION.items()
+}
+
+NAME2PROMPT_EXPL = {
+    k: v.split("Answer only Yes or No.")[0] + "Answer Yes or No, and give an explanation in a new line.\n{end_inst}"
+    for k, v in NAME2PROMPT.items()
+}
+
+QUESTIONDEFAULTS = {
+    "question_1": {True: "No, There was no evidence of Active Engagement", False: "No"},
+    "question_4": {True: "No, There was no evidence of Exploration Phase", False: "No"},
+    "question_5": {True: "No, There was no evidence of Exploration Phase", False: "No"},
+    # "question_7": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_9": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    "question_10": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_11": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    "question_14": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
+    "question_19": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_21": "Did the helper shared personal information? Answer only Yes or No.",
+    # "question_22": "Did the helper gave advice? Answer only Yes or No.",
+    "question_23": {True: "No, There was no evidence of Imminent Risk Assessment", False: "No"},
+}
+
+TEXTER_PREFIX = "texter"
+HELPER_PREFIX = "helper"
+
+TA_OPTIONS = ["N/A", "No", "Yes"]
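As a reading aid (not part of the commit), this is how the templates above expand for one question; the two-line transcript is made up:

    prompt = NAME2PROMPT["question_1"].format(
        convo="helper: Hi, this is Dani\ntexter: hey",
        start_inst=START_INST,  # "<|user|>"
        end_inst=END_INST,      # "<|end|>\n<|assistant|>"
    )
    # --------Conversation:
    # helper: Hi, this is Dani
    # texter: hey
    # <|user|>Did the helper introduce themself in the opening message? Answer only Yes or No.
    # <|end|>
    # <|assistant|>

NAME2PROMPT_EXPL builds the same wrapper but swaps the trailing "Answer only Yes or No." for the explanation instruction.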
models/ta_models/cpc_utils.py (new file)
@@ -0,0 +1,53 @@
+import streamlit as st
+from streamlit.logger import get_logger
+import requests
+import os
+from .config import model_name_or_path
+from transformers import AutoTokenizer
+from utils.mongo_utils import new_cpc_comparison
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
+CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"]['name'])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def cpc_predict_message(context, input):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    encoding = tokenizer(
+        context,
+        input,
+        truncation="only_first",
+        max_length = tokenizer.model_max_length - 2,
+    )['input_ids']
+    body_request = {
+        "inputs": [tokenizer.decode(encoding[1:-1])]
+    }
+
+    try:
+        # Send request to Serving
+        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            return response.json()['predictions'][0]["0"]["label"]
+        else:
+            raise Exception(f"Error in response: {response.json()}")
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
+
+def cpc_push2db(is_same):
+    text_is_same = "SAME" if is_same else "WRONG"
+    logger.debug(f"pushing new {text_is_same} CPC")
+    new_cpc_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "model": st.session_state['source'],
+        "context": st.session_state["context"],
+        "last_message": st.session_state["last_message"],
+        "ytrue": st.session_state["last_phase"] if is_same else st.session_state["sel_phase"],
+        "ypred": st.session_state["last_phase"],
+    })
models/ta_models/ta_filter_utils.py (new file)
@@ -0,0 +1,150 @@
+from itertools import chain
+from typing import List, Tuple
+
+import numpy as np
+import pandas as pd
+
+possible_movements = [-1, 1]
+
+
+def dfs(indexes: List[int], x0: int, i: int, cur_island: List[int], d=2):
+    """Deep First Search Implementation for 2D movement.
+    To consider an Island only move one step left or right
+    See possible movements
+
+    Args:
+        indexes (List[int]): Indexes of positive examples. i.e [20,21,23,50,51]
+        x0 (int): Initial island anchor
+        i (int): Current index to test against anchor
+        cur_island (List[int]): Current Island from anchor
+        d (int, optional): Bounding distance to consider an island. Defaults to 2. For example
+            the list [20,21,23,50,51] has two islands with d=2: (20,21,23), and (50,51) but it has
+            three islands with d=: (20,21), (23), and (50,51)
+    """
+    rows = len(indexes)
+    if i < 0 or i >= rows:
+        return
+    if indexes[i] in cur_island:
+        return
+    if abs(indexes[x0] - indexes[i]) > d:
+        return
+    # computing coordinates with x0 as base
+    cur_island.append(indexes[i])
+
+    # repeat dfs for neighbors
+    for movement in possible_movements:
+        dfs(indexes, i, i + movement, cur_island, d)
+
+
+def get_list_islands(indexes: List[int], **kwargs) -> List[List[int]]:
+    """Wrapper over DFS method to obtain islands from list of indexes of positive examples
+
+    Args:
+        indexes (List[int]): Indexes of positive examples. i.e [20,21,23,50,51]
+
+    Returns:
+        List[List[int]]: List of islands (each being a list)
+    """
+    islands = []
+    rows = len(indexes)
+    if rows == 0:
+        return islands
+
+    for i, valuei in enumerate(indexes):
+        # If already visited index in another dfs continue
+        if valuei in list(chain.from_iterable(islands)):
+            continue
+        # to hold coordinates of new island
+        cur_island = []
+        dfs(indexes, i, i, cur_island, **kwargs)
+
+        islands.append(cur_island)
+
+    return islands
+
+
+def get_phases_islands_minmax(
+    convo: pd.DataFrame,
+    phases: List[str],
+    column: str = "convo_part",
+    ignore: List[str] = [],
+    **kwargs,
+) -> List[Tuple[int]]:
+    """Given a conversation with predicted Phases (or Parts), get minimum and maximum index of calculated islands.
+
+    Args:
+        convo (pd.DataFrame): Conversation with predicted phases stored in `column`
+        phases (List[str]): Phases to filter in
+        column (str, optional): Column where predicted phases information is stored. Defaults to "convo_part".
+        ignore (List[str], optional): Ignore phases list. Defaults to [].
+
+    Returns:
+        List[Tuple[int]]: Minimum and Maximum values of calulated islands. i.e [(20,30), (40,60)]
+    """
+
+    reset = convo.query(f"{column}=={column} and {column} not in @ignore").reset_index()
+    sub_ = reset.query(f"{column} in @phases").copy()
+    indexes = sub_.index.tolist()
+    islands = get_list_islands(indexes, **kwargs)
+    if len(islands) > 1:
+        # If there is more than one island we want to make sure to root out comparable small islands
+        # I.e. if there is an island with 10 messages, and island of 1 messages is not useful in that context.
+        max_len = np.max([len(x) for x in islands])
+        len_cut = 3 if max_len > 9 else 2 if max_len > 3 else 1
+        islands = [x for x in islands if len(x) > len_cut]
+
+    islands = [reset.iloc[x] for x in islands]
+    minmax_islands = [(x["index"].min(), x["index"].max()) for x in islands]
+
+    return minmax_islands
+
+
+def filter_convo(
+    convo: pd.DataFrame,
+    phases: List[str],
+    column: str = "convo_part",
+    strategy: str = "islands",
+    pre_n: int = 5,
+    post_n: int = 5,
+    return_all_on_empty: bool = False,
+    **kwargs,
+) -> pd.DataFrame:
+    """Filter convo to include only specified phases. Take into account that sometimes predicted phases
+    can be messy. i.e. a prediciton of explore, explore, explore, safety_planning, explore; should return all
+    these messages as explore (probably safety_planning message has a low probability here.)
+
+    Args:
+        convo (pd.DataFrame): Conversation with predicted phases stored in `column`
+        phases (List[str]): Phases to filter in
+        column (str, optional): Column where predicted phases information is stored. Defaults to "convo_part".
+        strategy (str, optional): Strategy to use, can be minmax or islands. Defaults to "islands".
+        pre_n (int, optional): How many messages pre-phase to include. Defaults to 5.
+        post_n (int, optional): How many messages post-phase to include. Defaults to 5.
+        return_all_on_empty (bool, optional): Whether to return all messages when specified phases is not found. Defaults to False.
+
+    Returns:
+        pd.DataFrame: Filtered messages from the convo
+    """
+    if phases == ["ALL"]:
+        minidx = convo.index.min()
+        maxidx = convo.index.max()
+        minmax = [(minidx, maxidx)]
+    elif strategy == "minmax":
+        minidx = convo.query(f"{column} in @phases").index.min()
+        maxidx = convo.query(f"{column} in @phases").index.max() + 1
+        minmax = [(minidx, maxidx)]
+    elif strategy == "islands":
+        minmax = get_phases_islands_minmax(convo, phases, column, **kwargs)
+    parts = []
+    for minidx, maxidx in minmax:
+        minidx = max(convo.index.min(), minidx - pre_n)
+        maxidx = min(convo.index.max(), maxidx + post_n)
+        parts.append(convo.loc[minidx:maxidx])
+    if len(parts) == 0:
+        if return_all_on_empty:
+            return convo
+        else:
+            return pd.DataFrame(columns=convo.columns)
+    filtered = pd.concat(parts)
+    filtered = filtered[~filtered.index.duplicated(keep="first")]
+    return filtered
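To make the dfs docstring example concrete, here is the island grouping on the sample index list (illustrative call, assuming this new module's import path):

    from models.ta_models.ta_filter_utils import get_list_islands

    indexes = [20, 21, 23, 50, 51]
    # With the default bounding distance d=2, 23 is within 2 of 21, so it joins
    # the first island; 50 is 27 away from 23 and starts a new one.
    print(get_list_islands(indexes))       # [[20, 21, 23], [50, 51]]
    print(get_list_islands(indexes, d=1))  # [[20, 21], [23], [50, 51]]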
models/ta_models/ta_prompt_utils.py (new file)
@@ -0,0 +1,128 @@
+import inspect
+
+import pandas as pd
+
+from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX
+
+# Utils to filter convo according to a phase
+from .ta_filter_utils import filter_convo
+
+
+def join_messages(
+    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
+) -> str:
+    """join messages from dataframe using texter an helper prefixes
+
+    Args:
+        grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
+            Must have the following columns:
+            - actor_role
+            - message
+
+        texter_prefix (str, optional): prefix to use as the texter. Defaults to "texter".
+        helper_prefix (str, optional): prefix to use as the counselor (helper). Defaults to "helper".
+
+    Returns:
+        str: joined messages string separated by prefixes
+    """
+
+    if "actor_role" not in grp:
+        raise Exception("Column 'actor_role' not in DataFrame")
+    if "message" not in grp:
+        raise Exception("Column 'message' not in DataFrame")
+
+    roles = grp.actor_role.replace(
+        {"texter": texter_prefix, "counselor": helper_prefix, "helper": helper_prefix}
+    )
+    messages = roles.str.strip() + ": " + grp.message.str.strip()
+    return "\n".join(messages)
+
+
+def _get_context(grp: pd.DataFrame, **kwargs) -> str:
+    """Get context as a str taking into account message to delete, context marker
+    and the type of question to use. This allows for better truncation later
+
+    Args:
+        grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
+            Must have the following columns:
+            - actor_role
+            - message
+            - `column`
+        column (str): column name in which the marker of the problem is
+
+    Returns:
+        pd.DataFrame: joined messages string separated by prefixes
+    """
+
+    if "actor_role" not in grp:
+        raise Exception("Column 'actor_role' not in DataFrame")
+    if "message" not in grp:
+        raise Exception("Column 'message' not in DataFrame")
+
+    join_args = list(inspect.signature(join_messages).parameters)
+    join_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in join_args}
+
+    ## DEPRECATED
+    # context_args = list(inspect.signature(get_context_on_marker).parameters)
+    # context_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in context_args}
+
+    return join_messages(grp, **join_kwargs)
+
+
+def load_context(
+    messages: pd.DataFrame,
+    question: str,
+    message_col: str,
+    col_type: str,
+    inference: bool = False,
+    **kwargs,
+) -> pd.DataFrame:
+    """Load and filter conversation from messages given a question (with configured parameters of what phase that question is answered)
+
+    Args:
+        messages (pd.DataFrame): Messages dataframe with conversation_id, actor_role, `message_col` and phase prediction
+        question (str): Question to get context to
+        message_col (str): Column where messages are
+        col_type (str): type of message_col, can be "individual" or "joined"
+        base_dir (str, optional): Base directory to find model base args. Defaults to "../../".
+
+    Raises:
+        Exception: If question is not supported
+
+    Returns:
+        pd.DataFrame: filtered messages according to question configuration
+    """
+
+    if question not in QUESTION2FILTERARGS:
+        raise Exception(f"Question {question} not supported")
+
+    texter_prefix = TEXTER_PREFIX
+    helper_prefix = HELPER_PREFIX
+    context_data = messages.copy()
+
+    def convo_cpc_get_context(grp, **kwargs):
+        """Filter convo according to Convo Phase Classifier (CPC) predictions"""
+        context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
+        return _get_context(context_, **kwargs)
+
+    if col_type == "individual":
+        if "actor_role" in context_data:
+            context_data.dropna(subset=["actor_role"], inplace=True)
+        if "delete_message" in context_data:
+            context_data.delete_message.replace({1: True}, inplace=True)
+            context_data.delete_message.fillna(False, inplace=True)
+
+        context_data = (
+            context_data.groupby("conversation_id")
+            .apply(
+                convo_cpc_get_context,
+                helper_prefix=helper_prefix,
+                texter_prefix=texter_prefix,
+            )
+            .rename("q_context")
+        )
+    elif col_type == "joined":
+        context_data = context_data.groupby("conversation_id")[[message_col]].max()
+        context_data.rename(columns={message_col: "q_context"}, inplace=True)
+
+    return context_data
models/ta_models/ta_utils.py (new file)
@@ -0,0 +1,127 @@
+import os
+import re
+import requests
+import string
+import streamlit as st
+from streamlit.logger import get_logger
+from app_config import ENDPOINT_NAMES
+from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
+import pandas as pd
+from langchain_core.messages import AIMessage, HumanMessage
+from models.ta_models.ta_prompt_utils import load_context
+from utils.mongo_utils import new_convo_scoring_comparison
+
+logger = get_logger(__name__)
+TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def memory2df(memory, conversation_id="convo1234"):
+    df = []
+    for i, msg in enumerate(memory.buffer_as_messages):
+        actor_role = "texter" if type(msg) == AIMessage else "helper" if type(msg) == HumanMessage else None
+        if actor_role:
+            convo_part = msg.response_metadata.get("phase",None)
+            row = {"conversation_id":conversation_id, "message_number":i+1, "actor_role":actor_role, "message":msg.content, "convo_part":convo_part}
+            df.append(row)
+
+    return pd.DataFrame(df)
+
+def get_default(question, make_explanation=False):
+    return QUESTIONDEFAULTS[question][make_explanation]
+
+def get_context(memory, question, make_explanation=False, **kwargs):
+    df = memory2df(memory, **kwargs)
+    contexti = load_context(df, question, "messages", "individual").iloc[0]
+    if contexti == "":
+        return ""
+
+    if make_explanation:
+        return NAME2PROMPT_EXPL[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
+    else:
+        return NAME2PROMPT[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
+
+def post_process_response(full_response, delimiter="\n\n", n=2):
+    parts = full_response.split(delimiter)[:n]
+    response = extract_response(parts[0])
+    logger.debug(f"Response extracted is {response}")
+    if len(parts) > 1:
+        if len(parts[0]) < len(parts[1]):
+            full_response = parts[1]
+        else: full_response = parts[0]
+    else:
+        full_response = parts[0]
+    explanation = full_response.lstrip(response).lstrip(string.punctuation)
+    explanation = explanation.strip()
+    logger.debug(f"Explanation extracted is {explanation}")
+    return response, explanation
+
+def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
+    full_convo = memory.load_memory_variables({})[memory.memory_key]
+    PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
+    logger.debug(f"Raw TA prompt is {PROMPT}")
+    if PROMPT == "":
+        full_response = get_default(question, make_explanation)
+        return full_convo, PROMPT, full_response
+
+    body_request = {
+        "prompt": PROMPT,
+        "temperature": 0,
+        "max_tokens": 3,
+    }
+
+    try:
+        # Send request to Serving
+        response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            response = response.json()
+        else:
+            raise Exception(f"Error in response: {response.json()}")
+        full_response = response[0]['choices'][0]['text']
+        if not make_explanation:
+            return full_convo, PROMPT, full_response
+        else:
+            extract_response, _ = post_process_response(full_response)
+            PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
+            PROMPT = PROMPT + f" {extract_response}"
+            logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
+            body_request["prompt"] = PROMPT
+            body_request["max_tokens"] = 128
+            response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
+            if response_expl.status_code == 200:
+                response_expl = response_expl.json()
+            else:
+                raise Exception(f"Error in response: {response_expl.json()}")
+            full_response_expl = f"{extract_response} {response_expl[0]['choices'][0]['text']}"
+            return full_convo, PROMPT, full_response_expl
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
+
+def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
+    """Extract Response from generated answer
+    Extract only search strings
+
+    Args:
+        x (str): prediction
+        default (str, optional): default in case no response founds. Defaults to "N/A".
+
+    Returns:
+        str: _description_
+    """
+
+    try:
+        return re.findall("|".join(TA_OPTIONS), x)[0]
+    except Exception:
+        return default
+
+def ta_push_convo_comparison(ytrue, ypred):
+    new_convo_scoring_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "context": st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
+        "ytrue": ytrue,
+        "ypred": ypred,
+    })
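extract_response simply takes the first TA_OPTIONS token that appears anywhere in the generation, defaulting to "N/A". A standalone illustration of the same regex, outside the module:

    import re

    TA_OPTIONS = ["N/A", "No", "Yes"]

    def extract_response(x, default=TA_OPTIONS[0]):
        try:
            return re.findall("|".join(TA_OPTIONS), x)[0]  # first match wins
        except Exception:  # IndexError when no option appears
            return default

    print(extract_response("Yes, the helper introduced themself."))  # Yes
    print(extract_response("The texter disengaged early."))          # N/A (no match)

Because the pattern is a plain alternation, any substring hit counts; for example "Noted" would be read as "No".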
pages/convosim.py (new file)
@@ -0,0 +1,185 @@
+import os
+import streamlit as st
+from streamlit.logger import get_logger
+from langchain.schema.messages import HumanMessage
+from utils.mongo_utils import get_db_client
+from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF, are_models_alive
+from utils.memory_utils import clear_memory, push_convo2db
+from utils.chain_utils import get_chain, custom_chain_predict
+from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
+from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR
+from models.ta_models.cpc_utils import cpc_push2db
+from models.ta_models.bp_utils import bp_predict_message, bp_push2db
+
+logger = get_logger(__name__)
+temperature = 0.8
+# username = "barb-chase" #"ivnban-ctl"
+st.set_page_config(page_title="Conversation Simulator")
+
+if "sent_messages" not in st.session_state:
+    st.session_state['sent_messages'] = 0
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
+
+if "total_messages" not in st.session_state:
+    st.session_state['total_messages'] = 0
+if "issue" not in st.session_state:
+    st.session_state['issue'] = ISSUES[0]
+if 'previous_source' not in st.session_state:
+    st.session_state['previous_source'] = SOURCES[0]
+if 'db_client' not in st.session_state:
+    st.session_state["db_client"] = get_db_client()
+if 'texter_name' not in st.session_state:
+    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
+    logger.debug(f"texter name is {st.session_state['texter_name']}")
+if "last_phase" not in st.session_state:
+    st.session_state["last_phase"] = CPC_LBL_OPTS[0]
+    # st.session_state["sel_phase"] = CPC_LBL_OPTS[0]
+if "changed_cpc" not in st.session_state:
+    st.session_state["changed_cpc"] = False
+if "changed_bp" not in st.session_state:
+    st.session_state["changed_bp"] = False
+
+# st.session_state["sel_phase"] = st.session_state["last_phase"]
+
+memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
+
+with st.sidebar:
+    username = st.text_input("Username", value='Dani', max_chars=30)
+    if 'counselor_name' not in st.session_state:
+        st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
+    # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
+    issue = st.selectbox("Select a Scenario", ISSUES, index=ISSUES.index(st.session_state['issue']), format_func=issue2label,
+        on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
+    )
+    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
+    language = st.selectbox("Select a Language", supported_languages, index=0,
+        format_func=lambda x: "English" if x=="en" else "Spanish",
+        on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
+    )
+
+    source = st.selectbox("Select a source Model A", SOURCES, index=0,
+        format_func=source2label, key="source"
+    )
+
+changed_source = any([
+    st.session_state['previous_source'] != source,
+    st.session_state['issue'] != issue,
+    st.session_state['counselor_name'] != username,
+])
+if changed_source:
+    st.session_state["counselor_name"] = username
+    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
+    logger.debug(f"texter name is {st.session_state['texter_name']}")
+    st.session_state['previous_source'] = source
+    st.session_state['issue'] = issue
+    st.session_state['sent_messages'] = 0
+    st.session_state['total_messages'] = 0
+create_memory_add_initial_message(memories,
+    issue,
+    language,
+    changed_source=changed_source,
+    counselor_name=st.session_state["counselor_name"],
+    texter_name=st.session_state["texter_name"])
+st.session_state['previous_source'] = source
+memoryA = st.session_state[list(memories.keys())[0]]
+# issue only without "." marker for model compatibility
+llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
+
+st.title("💬 Simulator")
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+for msg in memoryA.buffer_as_messages:
+    role = "user" if type(msg) == HumanMessage else "assistant"
+    st.chat_message(role).write(msg.content)
+
+def sent_request_llm(llm_chain, prompt):
+    st.session_state['sent_messages'] += 1
+    st.chat_message("user").write(prompt)
+    responses = custom_chain_predict(llm_chain, prompt, stopper)
+    for response in responses:
+        st.chat_message("assistant").write(response)
+
+# @st.dialog("Bad Practice Detected")
+# def confirm_bp(bp_prediction, prompt):
+#     bps = [BP_LAB2STR[x['label']] for x in bp_prediction if x['score']]
+#     st.markdown(f"The last message was considered :red[{' and '.join(bps)}]")
+#     "Are you sure you want to send this message?"
+#     newprompt = st.text_input("Change message to:")
+#     "If you do not want to change leave textbox empty"
+#     for bp in BP_LAB2STR.keys():
+#         _ = st.checkbox(f"Original Message was {BP_LAB2STR[bp]}", key=f"chkbx_{bp}", value=BP_LAB2STR[bp] in bps)
+
+#     if st.button("Confirm"):
+#         if newprompt is not None and newprompt != "":
+#             prompt = newprompt
+#         bp_push2db(
+#             {bp:st.session_state[f"chkbx_{bp}"] for bp in BP_LAB2STR.keys()}
+#         )
+#         sent_request_llm(llm_chain, prompt)
+#         st.rerun()
+
+if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
+    if 'convo_id' not in st.session_state:
+        push_convo2db(memories, username, language)
+    st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
+    st.session_state['last_message'] = prompt
+    if (not st.session_state.changed_cpc) and st.session_state["sent_messages"] > 0:
+        cpc_push2db(True)
+    else: st.session_state.changed_cpc = False
+    if (not st.session_state.changed_bp) and st.session_state["sent_messages"] > 0:
+        bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
+    else: st.session_state.changed_bp = False
+
+    context = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
+    st.session_state['bp_prediction'] = bp_predict_message(context, prompt)
+    if any([x['score'] for x in st.session_state['bp_prediction']]):
+        for bp in st.session_state['bp_prediction']:
+            if bp["score"]:
+                st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
+
+    sent_request_llm(llm_chain, prompt)
+    # else:
+    #     sent_request_llm(llm_chain, prompt)
+
+with st.sidebar:
+    st.divider()
+    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
+    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
+    # st.markdown()
+    def on_change_cpc():
+        cpc_push2db(False)
+        st.session_state.changed_cpc = True
+    def on_change_bp():
+        bp_push2db()
+        st.session_state.changed_bp = True
+
+    if st.session_state["sent_messages"] > 0:
+        _ = st.selectbox(f"""Last Human Message was considered :blue[**{
+            cpc_label2str(st.session_state['last_phase'])
+            }**]. If not please select from the following options""",
+
+            CPC_LBL_OPTS, index=None,format_func=cpc_label2str, on_change=on_change_cpc,
+            key="sel_phase",
+        )
+
+        BPs = [BP_LAB2STR[x['label']] for x in st.session_state['bp_prediction'] if x['score']]
+        selecttitle = f"""Last Human Message was considered :blue[**{
+            " and ".join(BPs)
+            }**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
+        _ = st.selectbox(selecttitle + " If not please select from the following options""",
+
+            ["None", "Advice", "Personal Info", "Advice & Personal Info"], index=None, on_change=on_change_bp,
+            key="sel_bp"
+        )
+
+    if st.button("Score Conversation"):
+        st.switch_page("pages/training_adherence.py")
+
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+if st.session_state['total_messages'] >= MAX_MSG_COUNT:
+    st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
+elif st.session_state['total_messages'] >= WARN_MSG_COUT:
+    st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
+
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
pages/model_loader.py (new file)
@@ -0,0 +1,56 @@
+import time
+import streamlit as st
+from streamlit.logger import get_logger
+from utils.app_utils import is_model_alive
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+st.set_page_config(page_title="Conversation Simulator")
+
+models_alive = False
+start = time.time()
+
+MODELS2LOAD = {
+    "CPC": {"model_name": "Phase Classifier", "loaded":None,},
+    "CTL_llama3": {"model_name": "Texter Simulator", "loaded":None,},
+    "BadPractices": {"model_name": "Advice Identificator", "loaded":None},
+    "training_adherence": {"model_name": "Training Adherence", "loaded":None},
+}
+
+def write_model_status(writer, model_name, loaded, fail=False):
+    if loaded == "200":
+        writer.write(f"✅ - {model_name} Loaded")
+    if fail:
+        if loaded in ["400", "500"]:
+            writer.write(f"❌ - {model_name} Failed to Load, Contact [email protected]")
+        elif loaded == "404":
+            writer.write(f"❌ - {model_name} Still loading, please try in a couple of minutes")
+    else:
+        writer.write(f"🔄 - {model_name} Loading")
+
+with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
+
+    for k in MODELS2LOAD.keys():
+        MODELS2LOAD[k]["writer"] = st.empty()
+
+    while not models_alive:
+        time.sleep(2)
+        for name, config in MODELS2LOAD.items():
+            config["loaded"] = is_model_alive(**ENDPOINT_NAMES[name])
+
+        models_alive = all([x['loaded']=="200" for x in MODELS2LOAD.values()])
+
+        for _, config in MODELS2LOAD.items():
+            write_model_status(**config)
+
+        if int(time.time()-start) > 30:
+            status.update(
+                label="Models took too long to load. Please Refresh Page in a couple of minutes", state="error", expanded=True
+            )
+            for _, config in MODELS2LOAD.items():
+                write_model_status(**config, fail=True)
+            break
+
+if models_alive:
+    st.switch_page("pages/convosim.py")
@@ -0,0 +1,86 @@
+import streamlit as st
+import numpy as np
+from collections import defaultdict
+from langchain_core.messages import HumanMessage
+from utils.app_utils import are_models_alive
+from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
+from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
+
+st.set_page_config(page_title="Conversation Simulator - Scoring")
+
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
+
+if "memory" not in st.session_state:
+    st.switch_page("pages/convosim.py")
+
+memory = st.session_state['memory']
+progress_text = "Scoring Conversation using AI models ..."
+
+@st.cache_data(show_spinner=False)
+def get_ta_responses():
+    my_bar = st.progress(0, text=progress_text)
+    data = defaultdict(defaultdict)
+    for i, question in enumerate(QUESTION2PHASE.keys()):
+        # responses = ["Yes, The helper showed some respect.",
+        #              "Yes. The helper is good! No doubt",
+        #              "N/A, Texter disengaged.",
+        #              "No. While texter is trying is lacking.",
+        #              "No \n\n This is an explanation."]
+        # full_response = np.random.choice(responses)
+        full_convo, prompt, full_response = TA_predict_convo(memory, question, make_explanation=True, conversation_id=st.session_state['convo_id'])
+        response, explanation = post_process_response(full_response)
+        data[question]["response"] = response
+        data[question]["explanation"] = explanation
+        my_bar.progress((i+1) / len(QUESTION2PHASE.keys()), text = progress_text)
+    import time
+    time.sleep(2)
+    my_bar.empty()
+    return data
+
+with st.container():
+    col1, col2 = st.columns(2)
+    if col1.button("Go Back"):
+        get_ta_responses.clear()
+        st.switch_page("pages/convosim.py")
+    expl = col2.checkbox("Show Scoring Explanations")
+
+tab1, tab2 = st.tabs(["Scoring", "Conversation"])
+data = get_ta_responses()
+
+with tab2:
+    for msg in memory.buffer_as_messages:
+        role = "user" if type(msg) == HumanMessage else "assistant"
+        st.chat_message(role).write(msg.content)
+
+with tab1:
+    for question in QUESTION2PHASE.keys():
+        with st.container(border=True):
+            question_str = NAME2QUESTION[question].split(' Answer')[0]
+            st.radio(
+                f"**{question_str}**", options=TA_OPTIONS,
+                index=TA_OPTIONS.index(data[question]['response']), horizontal=True,
+                key=f"{question}_manual"
+            )
+            if expl:
+                st.text_area(
+                    label="", value=data[question]["explanation"], key=f"{question}_explanation_manual"
+                )
+            # st.write(data[question]["explanation"])
+
+with st.container():
+    col1, col2 = st.columns(2)
+    if col1.button("Go Back", key="goback2"):
+        get_ta_responses.clear()
+        st.switch_page("pages/convosim.py")
+    if col2.button("Submit Scoring", type="primary"):
+        ytrue = {
+            question: {
+                "response": st.session_state[f"{question}_manual"],
+                "explanation": st.session_state[f"{question}_explanation_manual"] if expl else "",
+            }
+            for question in QUESTION2PHASE.keys()
+        }
+        ta_push_convo_comparison(ytrue, data)
+        get_ta_responses.clear()
+        st.switch_page("pages/convosim.py")
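The commented-out mock responses above document the expected model output: a Yes/No/N/A verdict, a separator, then a free-text justification. A minimal sketch of the split that post_process_response presumably performs under that assumption (the real helper lives in models/ta_models/ta_utils.py and may differ):

# Hypothetical sketch, assuming answers look like "Yes, <explanation>"
# or "No \n\n <explanation>"; TA_OPTIONS mirrors models/ta_models/config.py.
import re

TA_OPTIONS = ["Yes", "No", "N/A"]  # assumed option set

def post_process_response(full_response: str):
    text = full_response.strip()
    for option in sorted(TA_OPTIONS, key=len, reverse=True):  # longest prefix first
        if text.upper().startswith(option.upper()):
            # Drop the verdict plus any punctuation, keep the rest as explanation
            explanation = re.sub(
                rf"^{re.escape(option)}[\s.,:]*", "", text, flags=re.IGNORECASE
            )
            return option, explanation.strip()
    return "N/A", text  # fall back when no recognizable verdict prefix is found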
@@ -4,4 +4,5 @@ mlflow==2.9.0
 langchain==0.3.0
 langchain-openai==0.2.0
 langchain-community==0.3.0
-streamlit==1.38.0
+streamlit==1.38.0
+transformers==4.43.0
@@ -1,19 +1,22 @@
 import pandas as pd
 import streamlit as st
 from streamlit.logger import get_logger
-import langchain
+import os
+import requests

-
-from app_config import ENVIRON
+from app_config import ENDPOINT_NAMES
 from utils.memory_utils import change_memories
 from models.model_seeds import seeds

-langchain.verbose = ENVIRON =="dev"
 logger = get_logger(__name__)

 # TODO: Include more variable and representative names
 DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
 DEFAULT_NAMES_DF = pd.read_csv("./utils/names.csv")
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}

 def get_random_name(gender="Neutral", ethnical_group="Neutral", names_df=None):
     if names_df is None:
@@ -61,4 +64,47 @@ def create_memory_add_initial_message(memories, issue, language, changed_source=
     if len(st.session_state[memory].buffer_as_messages) < 1:
         add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)

-
+def is_model_alive(name, timeout=2, model_type="classificator"):
+    if model_type!="openai":
+        endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=name)
+        headers = HEADERS
+        if model_type == "classificator":
+            body_request = {
+                "inputs": [""]
+            }
+        elif model_type == "text-completion":
+            body_request = {
+                "prompt": "",
+                "temperature": 0,
+                "max_tokens": 1,
+            }
+        elif model_type == "text-generation":
+            body_request = {
+                "messages": [{"role":"user","content":""}],
+                "max_tokens": 1,
+                "temperature": 0
+            }
+
+        else:
+            raise Exception(f"Model Type {model_type} not supported")
+        try:
+            response = requests.post(url=endpoint_url, headers=HEADERS, json=body_request, timeout=timeout)
+            return str(response.status_code)
+        except:
+            return "404"
+    else:
+        endpoint_url="https://api.openai.com/v1/models"
+        headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
+        try:
+            _ = requests.get(url=endpoint_url, headers=headers, timeout=1)
+            return "200"
+        except:
+            return "404"
+
+@st.cache_data(ttl=300, show_spinner=False)
+def are_models_alive():
+    models_alive = []
+    for config in ENDPOINT_NAMES.values():
+        models_alive.append(is_model_alive(**config))
+    openai = is_model_alive("openai", model_type="openai")
+    return all([x=="200" for x in models_alive + [openai]])
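The Databricks probe above can be reproduced standalone. In the sketch below the workspace URL is a placeholder standing in for the DATABRICKS_URL format string the function reads from the environment, and DATABRICKS_TOKEN is assumed to be set:

# Standalone sketch of the same liveness probe (assumed endpoint URL, not the app's code):
import os
import requests

endpoint_url = (
    "https://example.cloud.databricks.com/serving-endpoints/"
    "phase_classifier/invocations"
)
headers = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

try:
    # An empty input is enough to learn whether the endpoint answers at all.
    resp = requests.post(endpoint_url, headers=headers, json={"inputs": [""]}, timeout=2)
    print(resp.status_code)  # 200 means the classifier endpoint is serving
except requests.RequestException:
    print("endpoint unreachable")

Returning status codes as strings keeps the all(x == "200") aggregation uniform across HTTP responses and connection failures, and the ttl=300 on st.cache_data bounds how often these probes fire.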
@@ -1,9 +1,11 @@
+import streamlit as st
 from streamlit.logger import get_logger
-from …
+from langchain_core.messages import HumanMessage
 from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
 from models.openai.role_models import get_role_chain, get_template_role_models
 from models.databricks.scenario_sim_biz import get_databricks_biz_chain
 from models.databricks.texter_sim_llm import get_databricks_chain
+from models.ta_models.cpc_utils import cpc_predict_message

 logger = get_logger(__name__)

@@ -32,7 +34,12 @@ def custom_chain_predict(llm_chain, input, stop):
     llm_chain._validate_inputs(inputs)
     outputs = llm_chain._call(inputs)
     llm_chain._validate_outputs(outputs)
-
+    phase = cpc_predict_message(st.session_state['context'], st.session_state['last_message'])
+    st.session_state['last_phase'] = phase
+    logger.debug(phase)
+    llm_chain.memory.chat_memory.add_user_message(
+        HumanMessage(inputs['input'], response_metadata={"phase":phase})
+    )
     for out in outputs[llm_chain.output_key]:
         llm_chain.memory.chat_memory.add_ai_message(out)
     return outputs[llm_chain.output_key]
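The phase tag travels with the stored user message via response_metadata, so later pages can read it back without re-calling the classifier. A small round-trip sketch under that assumption (message content and phase label are invented):

# Round-trip sketch (not the app's retrieval code):
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import HumanMessage

history = ChatMessageHistory()
history.add_user_message(
    HumanMessage("I feel really overwhelmed", response_metadata={"phase": "Exploration"})
)

for msg in history.messages:
    # response_metadata defaults to {}, so untagged messages read as "unknown"
    print(type(msg).__name__, msg.response_metadata.get("phase", "unknown"))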
@@ -30,7 +30,6 @@ def change_memories(memories, language, changed_source=False):

     if ("convo_id" in st.session_state) and changed_source:
         del st.session_state['convo_id']
-

 def clear_memory(memories, username, language):
     for memory, _ in memories.items():
@@ -4,7 +4,7 @@ import streamlit as st
 from streamlit.logger import get_logger
 from pymongo.mongo_client import MongoClient
 from pymongo.server_api import ServerApi
-from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
+from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP, DB_TA

 DB_URL = os.environ['MONGO_URL']
 DB_USR = os.environ['MONGO_USR']
@@ -19,7 +19,7 @@ def get_db_client():
     # Send a ping to confirm a successful connection
     try:
         client.admin.command('ping')
-        logger.…
+        logger.debug(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
         return client
     except Exception as e:
         logger.error(e)
@@ -38,7 +38,7 @@ def new_convo(client, issue, language, username, is_comparison, model_one, model
     db = client[DB_SCHEMA]
     convos = db[DB_CONVOS]
     convo_id = convos.insert_one(convo).inserted_id
-    logger.…
+    logger.debug(f"DBUTILS: new convo id is {convo_id}")
     st.session_state['convo_id'] = convo_id

 def new_comparison(client, prompt_timestamp, completion_timestamp,
@@ -66,7 +66,7 @@ def new_comparison(client, prompt_timestamp, completion_timestamp,
     db = client[DB_SCHEMA]
     comparisons = db[DB_COMPLETIONS]
     comparison_id = comparisons.insert_one(comparison).inserted_id
-    logger.…
+    logger.debug(f"DBUTILS: new comparison id is {comparison_id}")
     st.session_state['comparison_id'] = comparison_id

 def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
@@ -84,7 +84,7 @@ def new_battle_result(client, comparison_id, convo_id, username, model_one, mode
     db = client[DB_SCHEMA]
     battles = db[DB_BATTLES]
     battle_id = battles.insert_one(battle).inserted_id
-    logger.…
+    logger.debug(f"DBUTILS: new battle id is {battle_id}")

 def new_completion_error(client, comparison_id, username, model):
     error = {
@@ -97,7 +97,58 @@ def new_completion_error(client, comparison_id, username, model):
     db = client[DB_SCHEMA]
     errors = db[DB_ERRORS]
     error_id = errors.insert_one(error).inserted_id
-    logger.…
+    logger.debug(f"DBUTILS: new error id is {error_id}")
+
+def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "CPC_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "model": model,
+        "context": context,
+        "last_message": last_message,
+        "predicted_phase": ypred,
+        "manual_phase": ytrue,
+    }
+
+    db = client[DB_SCHEMA]
+    cpc_comps = db[DB_CPC]
+    comparison_id = cpc_comps.insert_one(comp).inserted_id
+    logger.debug(f"DBUTILS: new CPC comparison id is {comparison_id}")
+
+def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "BP_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "model": model,
+        "context": context,
+        "last_message": last_message,
+        "is_advice": ypred["is_advice"],
+        "manual_is_advice": ytrue["is_advice"],
+        "is_pi": ypred["is_personal_info"],
+        "manual_is_pi": ytrue["is_personal_info"],
+    }
+
+    db = client[DB_SCHEMA]
+    bp_comps = db[DB_BP]
+    comparison_id = bp_comps.insert_one(comp).inserted_id
+    logger.debug(f"DBUTILS: new BP id is {comparison_id}")
+
+def new_convo_scoring_comparison(client, convo_id, context, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "scoring_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "context": context,
+        "manual_scoring": ytrue,
+        "model_scoring": ypred,
+    }
+
+    db = client[DB_SCHEMA]
+    ta_comps = db[DB_TA]
+    comparison_id = ta_comps.insert_one(comp).inserted_id
+    logger.debug(f"DBUTILS: new TA convo comparison id is {comparison_id}")

 def get_non_assesed_comparison(client, username):
     from bson.son import SON
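A hedged example of calling one of the new writers: the client comes from get_db_client(), and every field value below is an invented placeholder, not data from the app:

# Hypothetical call site (not committed code); assumes the MONGO_* env vars are set.
from utils.mongo_utils import get_db_client, new_cpc_comparison

client = get_db_client()

# Record one phase-classifier comparison: model prediction vs. manual label.
new_cpc_comparison(
    client,
    convo_id="placeholder-convo-id",   # normally st.session_state['convo_id']
    model="phase_classifier",
    context="texter: I had a rough week.",
    last_message="What happened this week?",
    ytrue="Exploration",               # counselor's manual phase
    ypred="Rapport Building",          # classifier's predicted phase
)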