Spaces:
Sleeping
Sleeping
ivnban27-ctl
commited on
TA-fixesDB (#2)
Browse files- ID exposure + DB fixes (4eaaa739f5fec0ec0ebac6a2b5ecf2714319c482)
- models/ta_models/config.py +2 -0
- pages/convosim.py +22 -12
- utils/mongo_utils.py +1 -1
models/ta_models/config.py
CHANGED
@@ -19,6 +19,8 @@ def cpc_label2str(phase):
|
|
19 |
def phase2int(phase):
|
20 |
return int(phase.split("_")[0])
|
21 |
|
|
|
|
|
22 |
BP_THRESHOLD = 0.7
|
23 |
BP_LAB2STR = {
|
24 |
"is_advice": "Advice",
|
|
|
19 |
def phase2int(phase):
|
20 |
return int(phase.split("_")[0])
|
21 |
|
22 |
+
BP_LBL_OPTS = ["None", "Advice", "Personal Info", "Advice & Personal Info"]
|
23 |
+
|
24 |
BP_THRESHOLD = 0.7
|
25 |
BP_LAB2STR = {
|
26 |
"is_advice": "Advice",
|
pages/convosim.py
CHANGED
@@ -7,7 +7,7 @@ from utils.app_utils import create_memory_add_initial_message, get_random_name,
|
|
7 |
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
|
10 |
-
from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR
|
11 |
from models.ta_models.cpc_utils import cpc_push2db
|
12 |
from models.ta_models.bp_utils import bp_predict_message, bp_push2db
|
13 |
|
@@ -75,6 +75,7 @@ if changed_source:
|
|
75 |
st.session_state['issue'] = issue
|
76 |
st.session_state['sent_messages'] = 0
|
77 |
st.session_state['total_messages'] = 0
|
|
|
78 |
create_memory_add_initial_message(memories,
|
79 |
issue,
|
80 |
language,
|
@@ -121,17 +122,24 @@ def sent_request_llm(llm_chain, prompt):
|
|
121 |
if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
|
122 |
if 'convo_id' not in st.session_state:
|
123 |
push_convo2db(memories, username, language)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
|
125 |
st.session_state['last_message'] = prompt
|
126 |
-
|
127 |
-
cpc_push2db(True)
|
128 |
-
else: st.session_state.changed_cpc = False
|
129 |
-
if (not st.session_state.changed_bp) and st.session_state["sent_messages"] > 0:
|
130 |
-
bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
|
131 |
-
else: st.session_state.changed_bp = False
|
132 |
-
|
133 |
-
context = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
|
134 |
-
st.session_state['bp_prediction'] = bp_predict_message(context, prompt)
|
135 |
if any([x['score'] for x in st.session_state['bp_prediction']]):
|
136 |
for bp in st.session_state['bp_prediction']:
|
137 |
if bp["score"]:
|
@@ -142,6 +150,8 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
|
|
142 |
# sent_request_llm(llm_chain, prompt)
|
143 |
|
144 |
with st.sidebar:
|
|
|
|
|
145 |
st.divider()
|
146 |
st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
|
147 |
st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
|
@@ -158,7 +168,7 @@ with st.sidebar:
|
|
158 |
cpc_label2str(st.session_state['last_phase'])
|
159 |
}**]. If not please select from the following options""",
|
160 |
|
161 |
-
CPC_LBL_OPTS, index=None,format_func=cpc_label2str, on_change=on_change_cpc,
|
162 |
key="sel_phase",
|
163 |
)
|
164 |
|
@@ -168,7 +178,7 @@ with st.sidebar:
|
|
168 |
}**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
|
169 |
_ = st.selectbox(selecttitle + " If not please select from the following options""",
|
170 |
|
171 |
-
|
172 |
key="sel_bp"
|
173 |
)
|
174 |
|
|
|
7 |
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
|
10 |
+
from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS
|
11 |
from models.ta_models.cpc_utils import cpc_push2db
|
12 |
from models.ta_models.bp_utils import bp_predict_message, bp_push2db
|
13 |
|
|
|
75 |
st.session_state['issue'] = issue
|
76 |
st.session_state['sent_messages'] = 0
|
77 |
st.session_state['total_messages'] = 0
|
78 |
+
|
79 |
create_memory_add_initial_message(memories,
|
80 |
issue,
|
81 |
language,
|
|
|
122 |
if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
|
123 |
if 'convo_id' not in st.session_state:
|
124 |
push_convo2db(memories, username, language)
|
125 |
+
|
126 |
+
if st.session_state["sent_messages"] > 0:
|
127 |
+
if st.session_state.changed_cpc:
|
128 |
+
st.session_state["sel_phase"] = None
|
129 |
+
st.session_state.changed_cpc = False
|
130 |
+
else:
|
131 |
+
cpc_push2db(True)
|
132 |
+
|
133 |
+
if st.session_state.changed_bp:
|
134 |
+
st.session_state["sel_bp"] = None
|
135 |
+
st.session_state.changed_bp = False
|
136 |
+
else:
|
137 |
+
bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
|
138 |
+
|
139 |
+
|
140 |
st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
|
141 |
st.session_state['last_message'] = prompt
|
142 |
+
st.session_state['bp_prediction'] = bp_predict_message(st.session_state['context'], prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
if any([x['score'] for x in st.session_state['bp_prediction']]):
|
144 |
for bp in st.session_state['bp_prediction']:
|
145 |
if bp["score"]:
|
|
|
150 |
# sent_request_llm(llm_chain, prompt)
|
151 |
|
152 |
with st.sidebar:
|
153 |
+
if "convo_id" in st.session_state:
|
154 |
+
st.write(f"Conversation ID is `{st.session_state['convo_id']}`")
|
155 |
st.divider()
|
156 |
st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
|
157 |
st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
|
|
|
168 |
cpc_label2str(st.session_state['last_phase'])
|
169 |
}**]. If not please select from the following options""",
|
170 |
|
171 |
+
CPC_LBL_OPTS, index=None, format_func=cpc_label2str, on_change=on_change_cpc,
|
172 |
key="sel_phase",
|
173 |
)
|
174 |
|
|
|
178 |
}**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
|
179 |
_ = st.selectbox(selecttitle + " If not please select from the following options""",
|
180 |
|
181 |
+
BP_LBL_OPTS, index=None, format_func=lambda x: x, on_change=on_change_bp,
|
182 |
key="sel_bp"
|
183 |
)
|
184 |
|
utils/mongo_utils.py
CHANGED
@@ -114,7 +114,7 @@ def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, yp
|
|
114 |
db = client[DB_SCHEMA]
|
115 |
cpc_comps = db[DB_CPC]
|
116 |
comarison_id = cpc_comps.insert_one(comp).inserted_id
|
117 |
-
logger.debug(f"DBUTILS: new
|
118 |
|
119 |
def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
|
120 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
|
|
114 |
db = client[DB_SCHEMA]
|
115 |
cpc_comps = db[DB_CPC]
|
116 |
comarison_id = cpc_comps.insert_one(comp).inserted_id
|
117 |
+
logger.debug(f"DBUTILS: new CPC id is {comarison_id}")
|
118 |
|
119 |
def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
|
120 |
# context = memory.load_memory_variables({})[memory.memory_key]
|