ivnban27-ctl commited on
Commit
18d7e6c
·
verified ·
1 Parent(s): f3e0ba5

TA-fixesDB (#2)

Browse files

- ID exposure + DB fixes (4eaaa739f5fec0ec0ebac6a2b5ecf2714319c482)

models/ta_models/config.py CHANGED
@@ -19,6 +19,8 @@ def cpc_label2str(phase):
19
  def phase2int(phase):
20
  return int(phase.split("_")[0])
21
 
 
 
22
  BP_THRESHOLD = 0.7
23
  BP_LAB2STR = {
24
  "is_advice": "Advice",
 
19
  def phase2int(phase):
20
  return int(phase.split("_")[0])
21
 
22
+ BP_LBL_OPTS = ["None", "Advice", "Personal Info", "Advice & Personal Info"]
23
+
24
  BP_THRESHOLD = 0.7
25
  BP_LAB2STR = {
26
  "is_advice": "Advice",
pages/convosim.py CHANGED
@@ -7,7 +7,7 @@ from utils.app_utils import create_memory_add_initial_message, get_random_name,
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
  from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
- from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR
11
  from models.ta_models.cpc_utils import cpc_push2db
12
  from models.ta_models.bp_utils import bp_predict_message, bp_push2db
13
 
@@ -75,6 +75,7 @@ if changed_source:
75
  st.session_state['issue'] = issue
76
  st.session_state['sent_messages'] = 0
77
  st.session_state['total_messages'] = 0
 
78
  create_memory_add_initial_message(memories,
79
  issue,
80
  language,
@@ -121,17 +122,24 @@ def sent_request_llm(llm_chain, prompt):
121
  if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
122
  if 'convo_id' not in st.session_state:
123
  push_convo2db(memories, username, language)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
125
  st.session_state['last_message'] = prompt
126
- if (not st.session_state.changed_cpc) and st.session_state["sent_messages"] > 0:
127
- cpc_push2db(True)
128
- else: st.session_state.changed_cpc = False
129
- if (not st.session_state.changed_bp) and st.session_state["sent_messages"] > 0:
130
- bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
131
- else: st.session_state.changed_bp = False
132
-
133
- context = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
134
- st.session_state['bp_prediction'] = bp_predict_message(context, prompt)
135
  if any([x['score'] for x in st.session_state['bp_prediction']]):
136
  for bp in st.session_state['bp_prediction']:
137
  if bp["score"]:
@@ -142,6 +150,8 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
142
  # sent_request_llm(llm_chain, prompt)
143
 
144
  with st.sidebar:
 
 
145
  st.divider()
146
  st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
147
  st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
@@ -158,7 +168,7 @@ with st.sidebar:
158
  cpc_label2str(st.session_state['last_phase'])
159
  }**]. If not please select from the following options""",
160
 
161
- CPC_LBL_OPTS, index=None,format_func=cpc_label2str, on_change=on_change_cpc,
162
  key="sel_phase",
163
  )
164
 
@@ -168,7 +178,7 @@ with st.sidebar:
168
  }**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
169
  _ = st.selectbox(selecttitle + " If not please select from the following options""",
170
 
171
- ["None", "Advice", "Personal Info", "Advice & Personal Info"], index=None, on_change=on_change_bp,
172
  key="sel_bp"
173
  )
174
 
 
7
  from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
  from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
+ from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS
11
  from models.ta_models.cpc_utils import cpc_push2db
12
  from models.ta_models.bp_utils import bp_predict_message, bp_push2db
13
 
 
75
  st.session_state['issue'] = issue
76
  st.session_state['sent_messages'] = 0
77
  st.session_state['total_messages'] = 0
78
+
79
  create_memory_add_initial_message(memories,
80
  issue,
81
  language,
 
122
  if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
123
  if 'convo_id' not in st.session_state:
124
  push_convo2db(memories, username, language)
125
+
126
+ if st.session_state["sent_messages"] > 0:
127
+ if st.session_state.changed_cpc:
128
+ st.session_state["sel_phase"] = None
129
+ st.session_state.changed_cpc = False
130
+ else:
131
+ cpc_push2db(True)
132
+
133
+ if st.session_state.changed_bp:
134
+ st.session_state["sel_bp"] = None
135
+ st.session_state.changed_bp = False
136
+ else:
137
+ bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
138
+
139
+
140
  st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
141
  st.session_state['last_message'] = prompt
142
+ st.session_state['bp_prediction'] = bp_predict_message(st.session_state['context'], prompt)
 
 
 
 
 
 
 
 
143
  if any([x['score'] for x in st.session_state['bp_prediction']]):
144
  for bp in st.session_state['bp_prediction']:
145
  if bp["score"]:
 
150
  # sent_request_llm(llm_chain, prompt)
151
 
152
  with st.sidebar:
153
+ if "convo_id" in st.session_state:
154
+ st.write(f"Conversation ID is `{st.session_state['convo_id']}`")
155
  st.divider()
156
  st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
157
  st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
 
168
  cpc_label2str(st.session_state['last_phase'])
169
  }**]. If not please select from the following options""",
170
 
171
+ CPC_LBL_OPTS, index=None, format_func=cpc_label2str, on_change=on_change_cpc,
172
  key="sel_phase",
173
  )
174
 
 
178
  }**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
179
  _ = st.selectbox(selecttitle + " If not please select from the following options""",
180
 
181
+ BP_LBL_OPTS, index=None, format_func=lambda x: x, on_change=on_change_bp,
182
  key="sel_bp"
183
  )
184
 
utils/mongo_utils.py CHANGED
@@ -114,7 +114,7 @@ def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, yp
114
  db = client[DB_SCHEMA]
115
  cpc_comps = db[DB_CPC]
116
  comarison_id = cpc_comps.insert_one(comp).inserted_id
117
- logger.debug(f"DBUTILS: new error id is {comarison_id}")
118
 
119
  def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
120
  # context = memory.load_memory_variables({})[memory.memory_key]
 
114
  db = client[DB_SCHEMA]
115
  cpc_comps = db[DB_CPC]
116
  comarison_id = cpc_comps.insert_one(comp).inserted_id
117
+ logger.debug(f"DBUTILS: new CPC id is {comarison_id}")
118
 
119
  def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
120
  # context = memory.load_memory_variables({})[memory.memory_key]