ivnban27-ctl committed · Commit f3e0ba5 (verified) · 1 Parent(s): 348d7de

training-adherence-features (#1)


- cpc and bad practices features (70a482f666f55772a84d65dc90e69fb21b3d96de)
- fix on tokenizer special tokens (cf6ebf90b37bb55854638f960ba1fb42e56b08f3)
- change in model aliveness calculations (2e79a3c9fdc409a99f016bad980f00c71e2860fc)
- aliveness calculation fixes (b74c038fe4c7136eb1858138c63c7ee531a574a8)
- training adherence scoring features (cfe8e1a5d05755a0346c269de80f255dfb39c501)
- making explanation editable (5f8859a38074d0644c3d0e5fc87c175ecb79071b)
- fix on roberta input len (5e4965f0930a2498341a903704d9ba347fd575e0)
- ta utils fix for explanation (7e14368f66f1989c0d91ee1fca1a5d16e638a523)
- changes on BL postprocessing (92dff98dd8d42a645c05178c183798b9fb1837e7)
- progress bar instead of spinner (a139603ab1caa2d4c2b71cca5e3d7b55e8788073)
- changed to prod (02544790fb5037419491dfe520951728f1e5cee6)

.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
+[client]
+showSidebarNavigation = false
README.md CHANGED
@@ -1,11 +1,11 @@
 ---
-title: Conversation Simulator
+title: Conversation Simulator DEV
 emoji: 💬
 colorFrom: red
 colorTo: red
 sdk: streamlit
-sdk_version: 1.26.0
-app_file: convosim.py
+sdk_version: 1.38.0
+app_file: main.py
 pinned: false
 ---
 
app_config.py CHANGED
@@ -18,9 +18,28 @@ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
 
 ENDPOINT_NAMES = {
     # "CTL_llama2": "texter_simulator",
-    "CTL_llama3": "texter_simulator_llm",
+    "CTL_llama3": {
+        "name": "texter_simulator_llm",
+        "model_type": "text-generation"
+    },
+    # "CTL_llama3": {
+    #     "name": "databricks-meta-llama-3-1-70b-instruct",
+    #     "model_type": "text-generation"
+    # },
     # 'CTL_llama2': "llama2_convo_sim",
-    "CTL_mistral": "convo_sim_mistral"
+    # "CTL_mistral": "convo_sim_mistral",
+    "CPC": {
+        "name": "phase_classifier",
+        "model_type": "classificator"
+    },
+    "BadPractices": {
+        "name": "training_adherence_bp",
+        "model_type": "classificator"
+    },
+    "training_adherence": {
+        "name": "training_adherence",
+        "model_type": "text-completion"
+    },
 }
 
 def source2label(source):
@@ -36,6 +55,9 @@ DB_CONVOS = 'conversations'
 DB_COMPLETIONS = 'comparison_completions'
 DB_BATTLES = 'battles'
 DB_ERRORS = 'completion_errors'
+DB_CPC = "cpc_comparison"
+DB_BP = "bad_practices_comparison"
+DB_TA = "convo_scoring_comparison"
 
 MAX_MSG_COUNT = 60
 WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
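
Downstream modules unpack these dict-valued entries by key; a minimal sketch of the consumption pattern, assuming DATABRICKS_URL is an endpoint-URL template containing {endpoint_name} (as utils/app_utils.py below uses it):

import os
from app_config import ENDPOINT_NAMES

# Each entry now carries the serving-endpoint name plus a model_type tag;
# callers format the URL from the name and pick a request body by type.
cpc = ENDPOINT_NAMES["CPC"]
url = os.environ["DATABRICKS_URL"].format(endpoint_name=cpc["name"])
print(cpc["model_type"], url)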
main.py ADDED
@@ -0,0 +1,14 @@
+import streamlit as st
+from streamlit.logger import get_logger
+
+from utils.app_utils import are_models_alive
+
+logger = get_logger(__name__)
+
+st.set_page_config(page_title="Conversation Simulator")
+
+with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
+    if not are_models_alive():
+        st.switch_page("pages/model_loader.py")
+    else:
+        st.switch_page("pages/convosim.py")
models/business_logic_utils/config.py CHANGED
@@ -272,7 +272,8 @@ DIFFICULTIES = {
     "difficulty_distrustful": {
         "difficulty_label": "distrustful",
         "description": [
-            "You don't trust the counselor, you will eventually cooperate.",
+            #"You don't trust the counselor, you will eventually cooperate.",
+            "You have a distrustful attitude towards the counselor.",
         ],
     },
     # "difficulty_stop_convo": {
models/business_logic_utils/response_processing.py CHANGED
@@ -48,6 +48,7 @@ def postprocess_text(
 
     # Remove unnecessary role prefixes
     text = text.replace(human_prefix, "").replace(assistant_prefix, "")
+
 
     # Remove whispers or other marked reactions
     whispers = re.compile(r"(\([\w\s]+\))") # remove things like "(whispers)"
@@ -55,8 +56,11 @@ def postprocess_text(
     text = whispers.sub("", text)
     text = reactions.sub("", text)
 
-    # Remove all quotation marks (both single and double)
-    text = text.replace('"', '').replace("'", "")
+    # Remove double quotation marks
+    text = text.replace('"', '')
+
+    # Remove stutters of any length (e.g., "M-m-my", "M-m-m-m-my", or "M-My" to "My")
+    text = re.sub(r'\b(\w)(-\1)+(\w*)', r'\1\3', text, flags=re.IGNORECASE)
 
     # Normalize spaces
     text = re.sub(r"\s+", " ", text).strip()
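
The new stutter rule in isolation; a standalone check of the substitution above:

import re

# Collapses single-letter stutters of any length ("M-m-my" -> "My").
stutter = re.compile(r'\b(\w)(-\1)+(\w*)', flags=re.IGNORECASE)
for sample in ["M-my phone died", "M-m-m-my phone died", "I-I'm f-fine"]:
    print(stutter.sub(r'\1\3', sample))
# -> "My phone died" / "My phone died" / "I'm fine"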
models/databricks/texter_sim_llm.py CHANGED
@@ -16,15 +16,13 @@ texter:"""
 
 def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
 
-    endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
-
+    endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")['name']
     PROMPT = PromptTemplate(
         input_variables=['history', 'input'],
         template=_DATABRICKS_TEMPLATE_
     )
 
     llm = CustomDatabricksLLM(
-        # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
         endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
         bearer_token=os.environ["DATABRICKS_TOKEN"],
         texter_name=texter_name,
@@ -43,4 +41,17 @@ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texte
     )
 
     logging.debug(f"loaded Databricks model")
-    return llm_chain, None
+    return llm_chain, None
+
+def cpc_is_alive():
+    body_request = {
+        "inputs": [""]
+    }
+    try:
+        # Send request to Serving
+        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request, timeout=2)
+        if response.status_code == 200:
+            return True
+        else:
+            return False
+    except:
+        return False
models/ta_models/bp_utils.py ADDED
@@ -0,0 +1,66 @@
+import streamlit as st
+from streamlit.logger import get_logger
+import requests
+import os
+from .config import model_name_or_path, BP_THRESHOLD
+from transformers import AutoTokenizer
+from utils.mongo_utils import new_bp_comparison
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
+BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"]['name'])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def bp_predict_message(context, input):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    encoding = tokenizer(
+        context,
+        input,
+        truncation="only_first",
+        max_length=tokenizer.model_max_length - 2,
+    )['input_ids']
+    body_request = {
+        "inputs": [tokenizer.decode(encoding[1:-1])],
+        "params": {
+            "top_k": None
+        }
+    }
+
+    try:
+        # Send request to Serving
+        logger.debug(f"raw BP body is {body_request}")
+        response = requests.post(url=BP_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            response = response.json()['predictions'][0]
+            logger.debug(f"Raw BP prediction is {response}")
+            return [{k: v > BP_THRESHOLD if k == "score" else v for k, v in dict_.items()} for _, dict_ in response.items()]
+        else:
+            raise Exception(f"Error in response: {response.json()}")
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
+
+def bp_push2db(manual_confirmation=None):
+    if manual_confirmation is None:
+        if st.session_state.sel_bp == "Advice":
+            manual_confirmation = {"is_advice": True, "is_personal_info": False}
+        elif st.session_state.sel_bp == "Personal Info":
+            manual_confirmation = {"is_advice": False, "is_personal_info": True}
+        elif st.session_state.sel_bp == "Advice & Personal Info":
+            manual_confirmation = {"is_advice": True, "is_personal_info": True}
+        else:
+            manual_confirmation = {"is_advice": False, "is_personal_info": False}
+    new_bp_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "model": st.session_state['source'],
+        "context": st.session_state["context"],
+        "last_message": st.session_state["last_message"],
+        "ytrue": manual_confirmation,
+        "ypred": {x['label']: x['score'] for x in st.session_state['bp_prediction']},
+    })
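
The list comprehension in bp_predict_message maps endpoint scores to booleans; a self-contained illustration with a hypothetical prediction payload (shape inferred from the code above):

BP_THRESHOLD = 0.7
raw = {  # hypothetical serving response, i.e. response.json()['predictions'][0]
    "0": {"label": "is_advice", "score": 0.91},
    "1": {"label": "is_personal_info", "score": 0.12},
}
flags = [{k: v > BP_THRESHOLD if k == "score" else v for k, v in d.items()}
         for _, d in raw.items()]
print(flags)
# [{'label': 'is_advice', 'score': True}, {'label': 'is_personal_info', 'score': False}]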
models/ta_models/config.py ADDED
@@ -0,0 +1,174 @@
+model_name_or_path = "FacebookAI/xlm-roberta-large"
+
+CPC_LABEL2STR = {
+    "0_ActiveEngagement": "Active Engagement",
+    "1_Explore": "Explore",
+    "2_IRA": "Immidiate Risk Assessment",
+    "3_SafetyAssessment": "Safety Assessment",
+    "4_SP&NS": "Safety Planning & Next Steps",
+    "5_EmergencyIntervention": "Emergency Intervention",
+    "6_WrappingUp": "Wrapping Up",
+    "7_Other": "Other",
+}
+
+CPC_LBL_OPTS = list(CPC_LABEL2STR.keys())
+
+def cpc_label2str(phase):
+    return CPC_LABEL2STR[phase]
+
+def phase2int(phase):
+    return int(phase.split("_")[0])
+
+BP_THRESHOLD = 0.7
+BP_LAB2STR = {
+    "is_advice": "Advice",
+    "is_personal_info": "Personal Info Sharing",
+}
+
+QUESTION2PHASE = {
+    "question_1": ["0_ActiveEngagement", "1_Explore"],
+    "question_4": ["1_Explore"],
+    "question_5": ["0_ActiveEngagement", "1_Explore"],
+    # "question_7": ["1_Explore"],
+    # "question_9": ["4_SP&NS"],
+    "question_10": ["4_SP&NS"],
+    # "question_11": ["4_SP&NS"],
+    "question_14": ["6_WrappingUp"],
+    # "question_15": ["ALL"],
+    "question_19": ["ALL"],
+    # "question_21": ["ALL"],
+    # "question_22": ["ALL"],
+    "question_23": ["2_IRA", "3_SafetyAssessment"],
+}
+
+QUESTION2FILTERARGS = {
+    "question_1": {
+        "phases": QUESTION2PHASE["question_1"],
+        "pre_n": 2,
+        "post_n": 8,
+        "ignore": ["7_Other"],
+    },
+    "question_4": {
+        "phases": QUESTION2PHASE["question_4"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    "question_5": {
+        "phases": QUESTION2PHASE["question_5"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    # "question_7": {
+    #     "phases": QUESTION2PHASE["question_7"],
+    #     "pre_n": 5,
+    #     "post_n": 15,
+    #     "ignore": ["7_Other"],
+    # },
+    # "question_9": {
+    #     "phases": QUESTION2PHASE["question_9"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_10": {
+        "phases": QUESTION2PHASE["question_10"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    # "question_11": {
+    #     "phases": QUESTION2PHASE["question_11"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_14": {
+        "phases": QUESTION2PHASE["question_14"],
+        "pre_n": 10,
+        "post_n": 0,
+        "ignore": ["7_Other"],
+    },
+    # "question_15": {
+    #     "phases": QUESTION2PHASE["question_15"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_19": {
+        "phases": QUESTION2PHASE["question_19"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+    # "question_21": {
+    #     "phases": QUESTION2PHASE["question_21"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    # "question_22": {
+    #     "phases": QUESTION2PHASE["question_22"],
+    #     "pre_n": 5,
+    #     "post_n": 5,
+    #     "ignore": ["7_Other"],
+    # },
+    "question_23": {
+        "phases": QUESTION2PHASE["question_23"],
+        "pre_n": 5,
+        "post_n": 5,
+        "ignore": ["7_Other"],
+    },
+}
+
+
+START_INST = "<|user|>"
+END_INST = "<|end|>\n<|assistant|>"
+
+NAME2QUESTION = {
+    "question_1": "Did the helper introduce themself in the opening message? Answer only Yes or No.",
+    "question_4": "Did the helper actively listened to the texter's crisis? Answer only Yes or No.",
+    "question_5": "Did the helper reflect on the main issue that led the texter reach out? Answer only Yes or No.",
+    # "question_7": "Did the helper collaborated with the texter to identify the goal of the conversation? Answer only Yes or No.",
+    # "question_9": "Did the helper collaborated with the texter to create next steps? Answer only Yes or No.",
+    "question_10": "Did the helper explored texter's existing coping skills? Answer only Yes or No.",
+    # "question_11": "Did the helper explored texter’s social support? Answer only Yes or No.",
+    "question_14": "Did helper reflected the texter’s plan, reiterate coping skills, and end in a supportive way? Answer only Yes or No.",
+    # "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
+    "question_19": "Did the helper consistently reflected empathy through the conversation? Answer only Yes or No.",
+    # "question_21": "Did the helper shared personal information? Answer only Yes or No.",
+    # "question_22": "Did the helper gave advice? Answer only Yes or No.",
+    "question_23": "Did the helper explicitely initiated imminent risk assessment? Answer only Yes or No.",
+}
+
+NAME2PROMPT = {
+    k: "--------Conversation:\n{convo}\n{start_inst}" + v + "\n{end_inst}"
+    for k, v in NAME2QUESTION.items()
+}
+
+NAME2PROMPT_EXPL = {
+    k: v.split("Answer only Yes or No.")[0] + "Answer Yes or No, and give an explanation in a new line.\n{end_inst}"
+    for k, v in NAME2PROMPT.items()
+}
+
+QUESTIONDEFAULTS = {
+    "question_1": {True: "No, There was no evidence of Active Engagement", False: "No"},
+    "question_4": {True: "No, There was no evidence of Exploration Phase", False: "No"},
+    "question_5": {True: "No, There was no evidence of Exploration Phase", False: "No"},
+    # "question_7": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_9": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    "question_10": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_11": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    "question_14": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
+    "question_19": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
+    # "question_21": "Did the helper shared personal information? Answer only Yes or No.",
+    # "question_22": "Did the helper gave advice? Answer only Yes or No.",
+    "question_23": {True: "No, There was no evidence of Imminent Risk Assessment", False: "No"},
+}
+
+TEXTER_PREFIX = "texter"
+HELPER_PREFIX = "helper"
+
+TA_OPTIONS = ["N/A", "No", "Yes"]
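
The two dict comprehensions compose into the final scorer prompt; a sketch with a hypothetical two-message transcript:

from models.ta_models.config import NAME2PROMPT, START_INST, END_INST

convo = "helper: Hi, this is Dani.\ntexter: hey... not doing great"
print(NAME2PROMPT["question_1"].format(convo=convo, start_inst=START_INST, end_inst=END_INST))
# --------Conversation:
# helper: Hi, this is Dani.
# texter: hey... not doing great
# <|user|>Did the helper introduce themself in the opening message? Answer only Yes or No.
# <|end|>
# <|assistant|>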
models/ta_models/cpc_utils.py ADDED
@@ -0,0 +1,53 @@
+import streamlit as st
+from streamlit.logger import get_logger
+import requests
+import os
+from .config import model_name_or_path
+from transformers import AutoTokenizer
+from utils.mongo_utils import new_cpc_comparison
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
+CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"]['name'])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def cpc_predict_message(context, input):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    encoding = tokenizer(
+        context,
+        input,
+        truncation="only_first",
+        max_length=tokenizer.model_max_length - 2,
+    )['input_ids']
+    body_request = {
+        "inputs": [tokenizer.decode(encoding[1:-1])]
+    }
+
+    try:
+        # Send request to Serving
+        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            return response.json()['predictions'][0]["0"]["label"]
+        else:
+            raise Exception(f"Error in response: {response.json()}")
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
+
+def cpc_push2db(is_same):
+    text_is_same = "SAME" if is_same else "WRONG"
+    logger.debug(f"pushing new {text_is_same} CPC")
+    new_cpc_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "model": st.session_state['source'],
+        "context": st.session_state["context"],
+        "last_message": st.session_state["last_message"],
+        "ytrue": st.session_state["last_phase"] if is_same else st.session_state["sel_phase"],
+        "ypred": st.session_state["last_phase"],
+    })
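
The tokenizer call above keeps the newest context by left-truncating only the first segment, then strips the outer special tokens before re-serializing; the same idea standalone (downloads the tokenizer; assumes the repo's pinned transformers version):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-large", truncation_side="left")
ids = tok("very long context ...", "new message",
          truncation="only_first", max_length=tok.model_max_length - 2)["input_ids"]
payload_text = tok.decode(ids[1:-1])  # drop the leading <s> and trailing </s>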
models/ta_models/ta_filter_utils.py ADDED
@@ -0,0 +1,150 @@
+from itertools import chain
+from typing import List, Tuple
+
+import numpy as np
+import pandas as pd
+
+possible_movements = [-1, 1]
+
+
+def dfs(indexes: List[int], x0: int, i: int, cur_island: List[int], d=2):
+    """Depth-First Search implementation over neighboring indexes.
+    To grow an island, only move one step left or right
+    (see `possible_movements`).
+
+    Args:
+        indexes (List[int]): Indexes of positive examples, e.g. [20,21,23,50,51]
+        x0 (int): Initial island anchor
+        i (int): Current index to test against anchor
+        cur_island (List[int]): Current island from anchor
+        d (int, optional): Bounding distance to consider an island. Defaults to 2. For example
+            the list [20,21,23,50,51] has two islands with d=2: (20,21,23) and (50,51), but it has
+            three islands with d=1: (20,21), (23), and (50,51)
+    """
+    rows = len(indexes)
+    if i < 0 or i >= rows:
+        return
+    if indexes[i] in cur_island:
+        return
+    if abs(indexes[x0] - indexes[i]) > d:
+        return
+    # computing coordinates with x0 as base
+    cur_island.append(indexes[i])
+
+    # repeat dfs for neighbors
+    for movement in possible_movements:
+        dfs(indexes, i, i + movement, cur_island, d)
+
+
+def get_list_islands(indexes: List[int], **kwargs) -> List[List[int]]:
+    """Wrapper over the DFS method to obtain islands from a list of indexes of positive examples
+
+    Args:
+        indexes (List[int]): Indexes of positive examples, e.g. [20,21,23,50,51]
+
+    Returns:
+        List[List[int]]: List of islands (each being a list)
+    """
+    islands = []
+    rows = len(indexes)
+    if rows == 0:
+        return islands
+
+    for i, valuei in enumerate(indexes):
+        # If index was already visited in another dfs, continue
+        if valuei in list(chain.from_iterable(islands)):
+            continue
+        # to hold coordinates of new island
+        cur_island = []
+        dfs(indexes, i, i, cur_island, **kwargs)
+
+        islands.append(cur_island)
+
+    return islands
+
+
+def get_phases_islands_minmax(
+    convo: pd.DataFrame,
+    phases: List[str],
+    column: str = "convo_part",
+    ignore: List[str] = [],
+    **kwargs,
+) -> List[Tuple[int]]:
+    """Given a conversation with predicted Phases (or Parts), get the minimum and maximum index of calculated islands.
+
+    Args:
+        convo (pd.DataFrame): Conversation with predicted phases stored in `column`
+        phases (List[str]): Phases to filter in
+        column (str, optional): Column where predicted phase information is stored. Defaults to "convo_part".
+        ignore (List[str], optional): Phases to ignore. Defaults to [].
+
+    Returns:
+        List[Tuple[int]]: Minimum and maximum values of calculated islands, e.g. [(20,30), (40,60)]
+    """
+
+    # `{column}=={column}` drops NaN rows (NaN != NaN)
+    reset = convo.query(f"{column}=={column} and {column} not in @ignore").reset_index()
+    sub_ = reset.query(f"{column} in @phases").copy()
+    indexes = sub_.index.tolist()
+    islands = get_list_islands(indexes, **kwargs)
+    if len(islands) > 1:
+        # If there is more than one island, root out comparatively small islands.
+        # I.e. next to an island with 10 messages, an island of 1 message is not useful.
+        max_len = np.max([len(x) for x in islands])
+        len_cut = 3 if max_len > 9 else 2 if max_len > 3 else 1
+        islands = [x for x in islands if len(x) > len_cut]
+
+    islands = [reset.iloc[x] for x in islands]
+    minmax_islands = [(x["index"].min(), x["index"].max()) for x in islands]
+
+    return minmax_islands
+
+
+def filter_convo(
+    convo: pd.DataFrame,
+    phases: List[str],
+    column: str = "convo_part",
+    strategy: str = "islands",
+    pre_n: int = 5,
+    post_n: int = 5,
+    return_all_on_empty: bool = False,
+    **kwargs,
+) -> pd.DataFrame:
+    """Filter a convo to include only the specified phases, taking into account that predicted phases
+    can be messy: i.e. a prediction of explore, explore, explore, safety_planning, explore should return all
+    of these messages as explore (the safety_planning message probably has low probability here).
+
+    Args:
+        convo (pd.DataFrame): Conversation with predicted phases stored in `column`
+        phases (List[str]): Phases to filter in
+        column (str, optional): Column where predicted phase information is stored. Defaults to "convo_part".
+        strategy (str, optional): Strategy to use, can be minmax or islands. Defaults to "islands".
+        pre_n (int, optional): How many messages pre-phase to include. Defaults to 5.
+        post_n (int, optional): How many messages post-phase to include. Defaults to 5.
+        return_all_on_empty (bool, optional): Whether to return all messages when the specified phases are not found. Defaults to False.
+
+    Returns:
+        pd.DataFrame: Filtered messages from the convo
+    """
+    if phases == ["ALL"]:
+        minidx = convo.index.min()
+        maxidx = convo.index.max()
+        minmax = [(minidx, maxidx)]
+    elif strategy == "minmax":
+        minidx = convo.query(f"{column} in @phases").index.min()
+        maxidx = convo.query(f"{column} in @phases").index.max() + 1
+        minmax = [(minidx, maxidx)]
+    elif strategy == "islands":
+        minmax = get_phases_islands_minmax(convo, phases, column, **kwargs)
+    parts = []
+    for minidx, maxidx in minmax:
+        minidx = max(convo.index.min(), minidx - pre_n)
+        maxidx = min(convo.index.max(), maxidx + post_n)
+        parts.append(convo.loc[minidx:maxidx])
+    if len(parts) == 0:
+        if return_all_on_empty:
+            return convo
+        else:
+            return pd.DataFrame(columns=convo.columns)
+    filtered = pd.concat(parts)
+    filtered = filtered[~filtered.index.duplicated(keep="first")]
+    return filtered
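
A doctest-style sanity check of the island grouping, matching the docstring example (assuming the repo root is importable):

from models.ta_models.ta_filter_utils import get_list_islands

idx = [20, 21, 23, 50, 51]
print(get_list_islands(idx, d=2))  # [[20, 21, 23], [50, 51]]
print(get_list_islands(idx, d=1))  # [[20, 21], [23], [50, 51]]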
models/ta_models/ta_prompt_utils.py ADDED
@@ -0,0 +1,128 @@
+import inspect
+
+import pandas as pd
+
+from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX
+
+# Utils to filter convo according to a phase
+from .ta_filter_utils import filter_convo
+
+
+def join_messages(
+    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
+) -> str:
+    """Join messages from a dataframe using texter and helper prefixes
+
+    Args:
+        grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
+            Must have the following columns:
+            - actor_role
+            - message
+
+        texter_prefix (str, optional): prefix to use as the texter. Defaults to "texter".
+        helper_prefix (str, optional): prefix to use as the counselor (helper). Defaults to "helper".
+
+    Returns:
+        str: joined messages string separated by prefixes
+    """
+
+    if "actor_role" not in grp:
+        raise Exception("Column 'actor_role' not in DataFrame")
+    if "message" not in grp:
+        raise Exception("Column 'message' not in DataFrame")
+
+    roles = grp.actor_role.replace(
+        {"texter": texter_prefix, "counselor": helper_prefix, "helper": helper_prefix}
+    )
+    messages = roles.str.strip() + ": " + grp.message.str.strip()
+    return "\n".join(messages)
+
+
+def _get_context(grp: pd.DataFrame, **kwargs) -> str:
+    """Get context as a str, taking into account messages to delete, the context marker
+    and the type of question to use. This allows for better truncation later.
+
+    Args:
+        grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
+            Must have the following columns:
+            - actor_role
+            - message
+
+    Returns:
+        str: joined messages string separated by prefixes
+    """
+
+    if "actor_role" not in grp:
+        raise Exception("Column 'actor_role' not in DataFrame")
+    if "message" not in grp:
+        raise Exception("Column 'message' not in DataFrame")
+
+    join_args = list(inspect.signature(join_messages).parameters)
+    join_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in join_args}
+
+    ## DEPRECATED
+    # context_args = list(inspect.signature(get_context_on_marker).parameters)
+    # context_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in context_args}
+
+    return join_messages(grp, **join_kwargs)
+
+
+def load_context(
+    messages: pd.DataFrame,
+    question: str,
+    message_col: str,
+    col_type: str,
+    inference: bool = False,
+    **kwargs,
+) -> pd.DataFrame:
+    """Load and filter a conversation from messages given a question (with configured parameters of which phase that question is answered in)
+
+    Args:
+        messages (pd.DataFrame): Messages dataframe with conversation_id, actor_role, `message_col` and phase prediction
+        question (str): Question to get context for
+        message_col (str): Column where messages are
+        col_type (str): type of message_col, can be "individual" or "joined"
+
+    Raises:
+        Exception: If question is not supported
+
+    Returns:
+        pd.DataFrame: filtered messages according to question configuration
+    """
+
+    if question not in QUESTION2FILTERARGS:
+        raise Exception(f"Question {question} not supported")
+
+    texter_prefix = TEXTER_PREFIX
+    helper_prefix = HELPER_PREFIX
+    context_data = messages.copy()
+
+    def convo_cpc_get_context(grp, **kwargs):
+        """Filter convo according to Convo Phase Classifier (CPC) predictions"""
+        context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
+        return _get_context(context_, **kwargs)
+
+    if col_type == "individual":
+        if "actor_role" in context_data:
+            context_data.dropna(subset=["actor_role"], inplace=True)
+        if "delete_message" in context_data:
+            context_data.delete_message.replace({1: True}, inplace=True)
+            context_data.delete_message.fillna(False, inplace=True)
+
+        context_data = (
+            context_data.groupby("conversation_id")
+            .apply(
+                convo_cpc_get_context,
+                helper_prefix=helper_prefix,
+                texter_prefix=texter_prefix,
+            )
+            .rename("q_context")
+        )
+    elif col_type == "joined":
+        context_data = context_data.groupby("conversation_id")[[message_col]].max()
+        context_data.rename(columns={message_col: "q_context"}, inplace=True)
+
+    return context_data
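
join_messages in brief, on a minimal two-row frame (assuming the repo root is importable):

import pandas as pd
from models.ta_models.ta_prompt_utils import join_messages

grp = pd.DataFrame({
    "actor_role": ["counselor", "texter"],
    "message": ["Hi, this is Dani.", "hey... not doing great"],
})
print(join_messages(grp))
# helper: Hi, this is Dani.
# texter: hey... not doing great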
models/ta_models/ta_utils.py ADDED
@@ -0,0 +1,127 @@
+import os
+import re
+import requests
+import string
+import streamlit as st
+from streamlit.logger import get_logger
+from app_config import ENDPOINT_NAMES
+from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
+import pandas as pd
+from langchain_core.messages import AIMessage, HumanMessage
+from models.ta_models.ta_prompt_utils import load_context
+from utils.mongo_utils import new_convo_scoring_comparison
+
+logger = get_logger(__name__)
+TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def memory2df(memory, conversation_id="convo1234"):
+    df = []
+    for i, msg in enumerate(memory.buffer_as_messages):
+        actor_role = "texter" if type(msg) == AIMessage else "helper" if type(msg) == HumanMessage else None
+        if actor_role:
+            convo_part = msg.response_metadata.get("phase", None)
+            row = {"conversation_id": conversation_id, "message_number": i+1, "actor_role": actor_role, "message": msg.content, "convo_part": convo_part}
+            df.append(row)
+
+    return pd.DataFrame(df)
+
+def get_default(question, make_explanation=False):
+    return QUESTIONDEFAULTS[question][make_explanation]
+
+def get_context(memory, question, make_explanation=False, **kwargs):
+    df = memory2df(memory, **kwargs)
+    contexti = load_context(df, question, "messages", "individual").iloc[0]
+    if contexti == "":
+        return ""
+
+    if make_explanation:
+        return NAME2PROMPT_EXPL[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
+    else:
+        return NAME2PROMPT[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
+
+def post_process_response(full_response, delimiter="\n\n", n=2):
+    parts = full_response.split(delimiter)[:n]
+    response = extract_response(parts[0])
+    logger.debug(f"Response extracted is {response}")
+    if len(parts) > 1:
+        if len(parts[0]) < len(parts[1]):
+            full_response = parts[1]
+        else:
+            full_response = parts[0]
+    else:
+        full_response = parts[0]
+    # removeprefix (not lstrip, which strips a character set) drops the leading answer
+    explanation = full_response.removeprefix(response).lstrip(string.punctuation)
+    explanation = explanation.strip()
+    logger.debug(f"Explanation extracted is {explanation}")
+    return response, explanation
+
+def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
+    full_convo = memory.load_memory_variables({})[memory.memory_key]
+    # first pass always asks for the short Yes/No answer
+    PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
+    logger.debug(f"Raw TA prompt is {PROMPT}")
+    if PROMPT == "":
+        full_response = get_default(question, make_explanation)
+        return full_convo, PROMPT, full_response
+
+    body_request = {
+        "prompt": PROMPT,
+        "temperature": 0,
+        "max_tokens": 3,
+    }
+
+    try:
+        # Send request to Serving
+        response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            response = response.json()
+        else:
+            raise Exception(f"Error in response: {response.json()}")
+        full_response = response[0]['choices'][0]['text']
+        if not make_explanation:
+            return full_convo, PROMPT, full_response
+        else:
+            extracted_response, _ = post_process_response(full_response)
+            PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
+            PROMPT = PROMPT + f" {extracted_response}"
+            logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
+            body_request["prompt"] = PROMPT
+            body_request["max_tokens"] = 128
+            response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
+            if response_expl.status_code == 200:
+                response_expl = response_expl.json()
+            else:
+                raise Exception(f"Error in response: {response_expl.json()}")
+            full_response_expl = f"{extracted_response} {response_expl[0]['choices'][0]['text']}"
+            return full_convo, PROMPT, full_response_expl
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
+
+def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
+    """Extract the scoring option from a generated answer.
+    Only the configured option strings are searched for.
+
+    Args:
+        x (str): prediction
+        default (str, optional): default used when no option is found. Defaults to "N/A".
+
+    Returns:
+        str: extracted option
+    """
+
+    try:
+        return re.findall("|".join(TA_OPTIONS), x)[0]
+    except Exception:
+        return default
+
+def ta_push_convo_comparison(ytrue, ypred):
+    new_convo_scoring_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "context": st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
+        "ytrue": ytrue,
+        "ypred": ypred,
+    })
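
post_process_response splits a raw completion into the Yes/No/N/A option and the free-text explanation; two illustrative inputs (made-up completions, not real model output):

from models.ta_models.ta_utils import post_process_response

print(post_process_response("Yes\n\nThe helper opened with their name."))
# ('Yes', 'The helper opened with their name.')
print(post_process_response("No. While the texter is trying, rapport is lacking."))
# ('No', 'While the texter is trying, rapport is lacking.')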
pages/convosim.py ADDED
@@ -0,0 +1,185 @@
+import os
+import streamlit as st
+from streamlit.logger import get_logger
+from langchain.schema.messages import HumanMessage
+from utils.mongo_utils import get_db_client
+from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF, are_models_alive
+from utils.memory_utils import clear_memory, push_convo2db
+from utils.chain_utils import get_chain, custom_chain_predict
+from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
+from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR
+from models.ta_models.cpc_utils import cpc_push2db
+from models.ta_models.bp_utils import bp_predict_message, bp_push2db
+
+logger = get_logger(__name__)
+temperature = 0.8
+# username = "barb-chase" #"ivnban-ctl"
+st.set_page_config(page_title="Conversation Simulator")
+
+if "sent_messages" not in st.session_state:
+    st.session_state['sent_messages'] = 0
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
+
+if "total_messages" not in st.session_state:
+    st.session_state['total_messages'] = 0
+if "issue" not in st.session_state:
+    st.session_state['issue'] = ISSUES[0]
+if 'previous_source' not in st.session_state:
+    st.session_state['previous_source'] = SOURCES[0]
+if 'db_client' not in st.session_state:
+    st.session_state["db_client"] = get_db_client()
+if 'texter_name' not in st.session_state:
+    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
+    logger.debug(f"texter name is {st.session_state['texter_name']}")
+if "last_phase" not in st.session_state:
+    st.session_state["last_phase"] = CPC_LBL_OPTS[0]
+    # st.session_state["sel_phase"] = CPC_LBL_OPTS[0]
+if "changed_cpc" not in st.session_state:
+    st.session_state["changed_cpc"] = False
+if "changed_bp" not in st.session_state:
+    st.session_state["changed_bp"] = False
+
+# st.session_state["sel_phase"] = st.session_state["last_phase"]
+
+memories = {'memory': {"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
+
+with st.sidebar:
+    username = st.text_input("Username", value='Dani', max_chars=30)
+    if 'counselor_name' not in st.session_state:
+        st.session_state["counselor_name"] = username  # get_random_name(names_df=DEFAULT_NAMES_DF)
+    # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
+    issue = st.selectbox("Select a Scenario", ISSUES, index=ISSUES.index(st.session_state['issue']), format_func=issue2label,
+                         on_change=clear_memory, kwargs={"memories": memories, "username": username, "language": "English"}
+                         )
+    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
+    language = st.selectbox("Select a Language", supported_languages, index=0,
+                            format_func=lambda x: "English" if x == "en" else "Spanish",
+                            on_change=clear_memory, kwargs={"memories": memories, "username": username, "language": "English"}
+                            )
+
+    source = st.selectbox("Select a source Model A", SOURCES, index=0,
+                          format_func=source2label, key="source"
+                          )
+
+changed_source = any([
+    st.session_state['previous_source'] != source,
+    st.session_state['issue'] != issue,
+    st.session_state['counselor_name'] != username,
+])
+if changed_source:
+    st.session_state["counselor_name"] = username
+    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
+    logger.debug(f"texter name is {st.session_state['texter_name']}")
+    st.session_state['previous_source'] = source
+    st.session_state['issue'] = issue
+    st.session_state['sent_messages'] = 0
+    st.session_state['total_messages'] = 0
+create_memory_add_initial_message(memories,
+                                  issue,
+                                  language,
+                                  changed_source=changed_source,
+                                  counselor_name=st.session_state["counselor_name"],
+                                  texter_name=st.session_state["texter_name"])
+st.session_state['previous_source'] = source
+memoryA = st.session_state[list(memories.keys())[0]]
+# issue only without "." marker for model compatibility
+llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
+
+st.title("💬 Simulator")
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+for msg in memoryA.buffer_as_messages:
+    role = "user" if type(msg) == HumanMessage else "assistant"
+    st.chat_message(role).write(msg.content)
+
+def sent_request_llm(llm_chain, prompt):
+    st.session_state['sent_messages'] += 1
+    st.chat_message("user").write(prompt)
+    responses = custom_chain_predict(llm_chain, prompt, stopper)
+    for response in responses:
+        st.chat_message("assistant").write(response)
+
+# @st.dialog("Bad Practice Detected")
+# def confirm_bp(bp_prediction, prompt):
+#     bps = [BP_LAB2STR[x['label']] for x in bp_prediction if x['score']]
+#     st.markdown(f"The last message was considered :red[{' and '.join(bps)}]")
+#     "Are you sure you want to send this message?"
+#     newprompt = st.text_input("Change message to:")
+#     "If you do not want to change leave textbox empty"
+#     for bp in BP_LAB2STR.keys():
+#         _ = st.checkbox(f"Original Message was {BP_LAB2STR[bp]}", key=f"chkbx_{bp}", value=BP_LAB2STR[bp] in bps)
+#
+#     if st.button("Confirm"):
+#         if newprompt is not None and newprompt != "":
+#             prompt = newprompt
+#         bp_push2db(
+#             {bp: st.session_state[f"chkbx_{bp}"] for bp in BP_LAB2STR.keys()}
+#         )
+#         sent_request_llm(llm_chain, prompt)
+#         st.rerun()
+
+if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4):  # account for next interaction
+    if 'convo_id' not in st.session_state:
+        push_convo2db(memories, username, language)
+    st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
+    st.session_state['last_message'] = prompt
+    if (not st.session_state.changed_cpc) and st.session_state["sent_messages"] > 0:
+        cpc_push2db(True)
+    else:
+        st.session_state.changed_cpc = False
+    if (not st.session_state.changed_bp) and st.session_state["sent_messages"] > 0:
+        bp_push2db({x['label']: x['score'] for x in st.session_state['bp_prediction']})
+    else:
+        st.session_state.changed_bp = False
+
+    context = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
+    st.session_state['bp_prediction'] = bp_predict_message(context, prompt)
+    if any([x['score'] for x in st.session_state['bp_prediction']]):
+        for bp in st.session_state['bp_prediction']:
+            if bp["score"]:
+                st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
+
+    sent_request_llm(llm_chain, prompt)
+    # else:
+    #     sent_request_llm(llm_chain, prompt)
+
+with st.sidebar:
+    st.divider()
+    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
+    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
+    # st.markdown()
+    def on_change_cpc():
+        cpc_push2db(False)
+        st.session_state.changed_cpc = True
+    def on_change_bp():
+        bp_push2db()
+        st.session_state.changed_bp = True
+
+    if st.session_state["sent_messages"] > 0:
+        _ = st.selectbox(f"""Last Human Message was considered :blue[**{
+            cpc_label2str(st.session_state['last_phase'])
+        }**]. If not please select from the following options""",
+                         CPC_LBL_OPTS, index=None, format_func=cpc_label2str, on_change=on_change_cpc,
+                         key="sel_phase",
+                         )
+
+        BPs = [BP_LAB2STR[x['label']] for x in st.session_state['bp_prediction'] if x['score']]
+        selecttitle = f"""Last Human Message was considered :blue[**{
+            " and ".join(BPs)
+        }**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
+        _ = st.selectbox(selecttitle + " If not please select from the following options",
+                         ["None", "Advice", "Personal Info", "Advice & Personal Info"], index=None, on_change=on_change_bp,
+                         key="sel_bp"
+                         )
+
+        if st.button("Score Conversation"):
+            st.switch_page("pages/training_adherence.py")
+
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+if st.session_state['total_messages'] >= MAX_MSG_COUNT:
+    st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
+elif st.session_state['total_messages'] >= WARN_MSG_COUT:
+    st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
+
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
pages/model_loader.py ADDED
@@ -0,0 +1,56 @@
+import time
+import streamlit as st
+from streamlit.logger import get_logger
+from utils.app_utils import is_model_alive
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+st.set_page_config(page_title="Conversation Simulator")
+
+models_alive = False
+start = time.time()
+
+MODELS2LOAD = {
+    "CPC": {"model_name": "Phase Classifier", "loaded": None},
+    "CTL_llama3": {"model_name": "Texter Simulator", "loaded": None},
+    "BadPractices": {"model_name": "Advice Identificator", "loaded": None},
+    "training_adherence": {"model_name": "Training Adherence", "loaded": None},
+}
+
+def write_model_status(writer, model_name, loaded, fail=False):
+    if loaded == "200":
+        writer.write(f"✅ - {model_name} Loaded")
+    elif fail:
+        if loaded in ["400", "500"]:
+            writer.write(f"❌ - {model_name} Failed to Load, Contact [email protected]")
+        elif loaded == "404":
+            writer.write(f"❌ - {model_name} Still loading, please try in a couple of minutes")
+    else:
+        writer.write(f"🔄 - {model_name} Loading")
+
+with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
+
+    for k in MODELS2LOAD.keys():
+        MODELS2LOAD[k]["writer"] = st.empty()
+
+    while not models_alive:
+        time.sleep(2)
+        for name, config in MODELS2LOAD.items():
+            config["loaded"] = is_model_alive(**ENDPOINT_NAMES[name])
+
+        models_alive = all([x['loaded'] == "200" for x in MODELS2LOAD.values()])
+
+        for _, config in MODELS2LOAD.items():
+            write_model_status(**config)
+
+        if int(time.time() - start) > 30:
+            status.update(
+                label="Models took too long to load. Please Refresh Page in a couple of minutes", state="error", expanded=True
+            )
+            for _, config in MODELS2LOAD.items():
+                write_model_status(**config, fail=True)
+            break
+
+if models_alive:
+    st.switch_page("pages/convosim.py")
pages/training_adherence.py ADDED
@@ -0,0 +1,86 @@
+import streamlit as st
+import numpy as np
+from collections import defaultdict
+from langchain_core.messages import HumanMessage
+from utils.app_utils import are_models_alive
+from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
+from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
+
+st.set_page_config(page_title="Conversation Simulator - Scoring")
+
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
+
+if "memory" not in st.session_state:
+    st.switch_page("pages/convosim.py")
+
+memory = st.session_state['memory']
+progress_text = "Scoring Conversation using AI models ..."
+
+@st.cache_data(show_spinner=False)
+def get_ta_responses():
+    my_bar = st.progress(0, text=progress_text)
+    data = defaultdict(defaultdict)
+    for i, question in enumerate(QUESTION2PHASE.keys()):
+        # responses = ["Yes, The helper showed some respect.",
+        #              "Yes. The helper is good! No doubt",
+        #              "N/A, Texter disengaged.",
+        #              "No. While texter is trying is lacking.",
+        #              "No \n\n This is an explanation."]
+        # full_response = np.random.choice(responses)
+        full_convo, prompt, full_response = TA_predict_convo(memory, question, make_explanation=True, conversation_id=st.session_state['convo_id'])
+        response, explanation = post_process_response(full_response)
+        data[question]["response"] = response
+        data[question]["explanation"] = explanation
+        my_bar.progress((i+1) / len(QUESTION2PHASE.keys()), text=progress_text)
+    import time
+    time.sleep(2)
+    my_bar.empty()
+    return data
+
+with st.container():
+    col1, col2 = st.columns(2)
+    if col1.button("Go Back"):
+        get_ta_responses.clear()
+        st.switch_page("pages/convosim.py")
+    expl = col2.checkbox("Show Scoring Explanations")
+
+tab1, tab2 = st.tabs(["Scoring", "Conversation"])
+data = get_ta_responses()
+
+with tab2:
+    for msg in memory.buffer_as_messages:
+        role = "user" if type(msg) == HumanMessage else "assistant"
+        st.chat_message(role).write(msg.content)
+
+with tab1:
+    for question in QUESTION2PHASE.keys():
+        with st.container(border=True):
+            question_str = NAME2QUESTION[question].split(' Answer')[0]
+            st.radio(
+                f"**{question_str}**", options=TA_OPTIONS,
+                index=TA_OPTIONS.index(data[question]['response']), horizontal=True,
+                key=f"{question}_manual"
+            )
+            if expl:
+                st.text_area(
+                    label="", value=data[question]["explanation"], key=f"{question}_explanation_manual"
+                )
+                # st.write(data[question]["explanation"])
+
+with st.container():
+    col1, col2 = st.columns(2)
+    if col1.button("Go Back", key="goback2"):
+        get_ta_responses.clear()
+        st.switch_page("pages/convosim.py")
+    if col2.button("Submit Scoring", type="primary"):
+        ytrue = {
+            question: {
+                "response": st.session_state[f"{question}_manual"],
+                "explanation": st.session_state[f"{question}_explanation_manual"] if expl else "",
+            }
+            for question in QUESTION2PHASE.keys()
+        }
+        ta_push_convo_comparison(ytrue, data)
+        get_ta_responses.clear()
+        st.switch_page("pages/convosim.py")
requirements.txt CHANGED
@@ -4,4 +4,5 @@ mlflow==2.9.0
 langchain==0.3.0
 langchain-openai==0.2.0
 langchain-community==0.3.0
-streamlit==1.38.0
+streamlit==1.38.0
+transformers==4.43.0
utils/app_utils.py CHANGED
@@ -1,19 +1,22 @@
 import pandas as pd
 import streamlit as st
 from streamlit.logger import get_logger
-import langchain
+import os
+import requests
 
-
-from app_config import ENVIRON
+from app_config import ENDPOINT_NAMES
 from utils.memory_utils import change_memories
 from models.model_seeds import seeds
 
-langchain.verbose = ENVIRON =="dev"
 logger = get_logger(__name__)
 
 # TODO: Include more variable and representative names
 DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
 DEFAULT_NAMES_DF = pd.read_csv("./utils/names.csv")
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
 
 def get_random_name(gender="Neutral", ethnical_group="Neutral", names_df=None):
     if names_df is None:
@@ -61,4 +64,47 @@ def create_memory_add_initial_message(memories, issue, language, changed_source=
     if len(st.session_state[memory].buffer_as_messages) < 1:
         add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)
 
-
+def is_model_alive(name, timeout=2, model_type="classificator"):
+    if model_type != "openai":
+        endpoint_url = os.environ['DATABRICKS_URL'].format(endpoint_name=name)
+        headers = HEADERS
+        if model_type == "classificator":
+            body_request = {
+                "inputs": [""]
+            }
+        elif model_type == "text-completion":
+            body_request = {
+                "prompt": "",
+                "temperature": 0,
+                "max_tokens": 1,
+            }
+        elif model_type == "text-generation":
+            body_request = {
+                "messages": [{"role": "user", "content": ""}],
+                "max_tokens": 1,
+                "temperature": 0
+            }
+        else:
+            raise Exception(f"Model Type {model_type} not supported")
+        try:
+            response = requests.post(url=endpoint_url, headers=HEADERS, json=body_request, timeout=timeout)
+            return str(response.status_code)
+        except:
+            return "404"
+    else:
+        endpoint_url = "https://api.openai.com/v1/models"
+        headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}
+        try:
+            _ = requests.get(url=endpoint_url, headers=headers, timeout=1)
+            return "200"
+        except:
+            return "404"
+
+@st.cache_data(ttl=300, show_spinner=False)
+def are_models_alive():
+    models_alive = []
+    for config in ENDPOINT_NAMES.values():
+        models_alive.append(is_model_alive(**config))
+    openai = is_model_alive("openai", model_type="openai")
+    return all([x == "200" for x in models_alive + [openai]])
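
For local debugging, the health check can be exercised directly (requires DATABRICKS_URL and DATABRICKS_TOKEN in the environment; endpoint config comes from app_config above):

from app_config import ENDPOINT_NAMES
from utils.app_utils import is_model_alive

status = is_model_alive(**ENDPOINT_NAMES["CPC"])  # name="phase_classifier", model_type="classificator"
print(status)  # "200" once the endpoint is up; "404" on timeout or connection error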
utils/chain_utils.py CHANGED
@@ -1,9 +1,11 @@
+import streamlit as st
 from streamlit.logger import get_logger
-from models.model_seeds import seeds
+from langchain_core.messages import HumanMessage
 from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
 from models.openai.role_models import get_role_chain, get_template_role_models
 from models.databricks.scenario_sim_biz import get_databricks_biz_chain
 from models.databricks.texter_sim_llm import get_databricks_chain
+from models.ta_models.cpc_utils import cpc_predict_message
 
 logger = get_logger(__name__)
 
@@ -32,7 +34,12 @@ def custom_chain_predict(llm_chain, input, stop):
     llm_chain._validate_inputs(inputs)
     outputs = llm_chain._call(inputs)
     llm_chain._validate_outputs(outputs)
-    llm_chain.memory.chat_memory.add_user_message(inputs['input'])
+    phase = cpc_predict_message(st.session_state['context'], st.session_state['last_message'])
+    st.session_state['last_phase'] = phase
+    logger.debug(phase)
+    llm_chain.memory.chat_memory.add_user_message(
+        HumanMessage(inputs['input'], response_metadata={"phase": phase})
+    )
     for out in outputs[llm_chain.output_key]:
         llm_chain.memory.chat_memory.add_ai_message(out)
     return outputs[llm_chain.output_key]
utils/memory_utils.py CHANGED
@@ -30,7 +30,6 @@ def change_memories(memories, language, changed_source=False):
 
     if ("convo_id" in st.session_state) and changed_source:
         del st.session_state['convo_id']
-
 
 def clear_memory(memories, username, language):
     for memory, _ in memories.items():
utils/mongo_utils.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
 from streamlit.logger import get_logger
 from pymongo.mongo_client import MongoClient
 from pymongo.server_api import ServerApi
-from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
+from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP, DB_TA
 
 DB_URL = os.environ['MONGO_URL']
 DB_USR = os.environ['MONGO_USR']
@@ -19,7 +19,7 @@ def get_db_client():
     # Send a ping to confirm a successful connection
     try:
         client.admin.command('ping')
-        logger.info(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
+        logger.debug(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
         return client
     except Exception as e:
         logger.error(e)
@@ -38,7 +38,7 @@ def new_convo(client, issue, language, username, is_comparison, model_one, model
     db = client[DB_SCHEMA]
     convos = db[DB_CONVOS]
     convo_id = convos.insert_one(convo).inserted_id
-    logger.info(f"DBUTILS: new convo id is {convo_id}")
+    logger.debug(f"DBUTILS: new convo id is {convo_id}")
     st.session_state['convo_id'] = convo_id
 
 def new_comparison(client, prompt_timestamp, completion_timestamp,
@@ -66,7 +66,7 @@ def new_comparison(client, prompt_timestamp, completion_timestamp,
     db = client[DB_SCHEMA]
     comparisons = db[DB_COMPLETIONS]
     comparison_id = comparisons.insert_one(comparison).inserted_id
-    logger.info(f"DBUTILS: new comparison id is {comparison_id}")
+    logger.debug(f"DBUTILS: new comparison id is {comparison_id}")
     st.session_state['comparison_id'] = comparison_id
 
 def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
@@ -84,7 +84,7 @@ def new_battle_result(client, comparison_id, convo_id, username, model_one, mode
     db = client[DB_SCHEMA]
     battles = db[DB_BATTLES]
     battle_id = battles.insert_one(battle).inserted_id
-    logger.info(f"DBUTILS: new battle id is {battle_id}")
+    logger.debug(f"DBUTILS: new battle id is {battle_id}")
 
 def new_completion_error(client, comparison_id, username, model):
     error = {
@@ -97,7 +97,58 @@ def new_completion_error(client, comparison_id, username, model):
     db = client[DB_SCHEMA]
     errors = db[DB_ERRORS]
     error_id = errors.insert_one(error).inserted_id
-    logger.info(f"DBUTILS: new error id is {error_id}")
+    logger.debug(f"DBUTILS: new error id is {error_id}")
+
+def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "CPC_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "model": model,
+        "context": context,
+        "last_message": last_message,
+        "predicted_phase": ypred,
+        "manual_phase": ytrue,
+    }
+
+    db = client[DB_SCHEMA]
+    cpc_comps = db[DB_CPC]
+    comparison_id = cpc_comps.insert_one(comp).inserted_id
+    logger.debug(f"DBUTILS: new CPC comparison id is {comparison_id}")
+
+def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "BP_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "model": model,
+        "context": context,
+        "last_message": last_message,
+        "is_advice": ypred["is_advice"],
+        "manual_is_advice": ytrue["is_advice"],
+        "is_pi": ypred["is_personal_info"],
+        "manual_is_pi": ytrue["is_personal_info"],
+    }
+
+    db = client[DB_SCHEMA]
+    bp_comps = db[DB_BP]
+    comparison_id = bp_comps.insert_one(comp).inserted_id
+    logger.debug(f"DBUTILS: new BP id is {comparison_id}")
+
+def new_convo_scoring_comparison(client, convo_id, context, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "scoring_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "context": context,
+        "manual_scoring": ytrue,
+        "model_scoring": ypred,
+    }
+
+    db = client[DB_SCHEMA]
+    ta_comps = db[DB_TA]
+    comparison_id = ta_comps.insert_one(comp).inserted_id
+    logger.debug(f"DBUTILS: new TA convo comparison id is {comparison_id}")
 
 def get_non_assesed_comparison(client, username):
     from bson.son import SON