Spaces:
Sleeping
Sleeping
ivnban27-ctl
commited on
bug/ifb_fix_changed_CPC_saving (#4)
Browse files- fix on change CPC after selection and update aliveness logic (28bc4e32cafddf8ef76c60c668b6d5256f38360f)
- models/ta_models/cpc_utils.py +9 -0
- pages/convosim.py +4 -1
- utils/app_utils.py +4 -3
models/ta_models/cpc_utils.py
CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
|
|
2 |
from streamlit.logger import get_logger
|
3 |
import requests
|
4 |
import os
|
|
|
5 |
from .config import model_name_or_path
|
6 |
from transformers import AutoTokenizer
|
7 |
from utils.mongo_utils import new_cpc_comparison
|
@@ -16,6 +17,14 @@ HEADERS = {
|
|
16 |
"Content-Type": "application/json",
|
17 |
}
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def cpc_predict_message(context, input):
|
20 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
21 |
encoding = tokenizer(
|
|
|
2 |
from streamlit.logger import get_logger
|
3 |
import requests
|
4 |
import os
|
5 |
+
from langchain_core.messages import HumanMessage
|
6 |
from .config import model_name_or_path
|
7 |
from transformers import AutoTokenizer
|
8 |
from utils.mongo_utils import new_cpc_comparison
|
|
|
17 |
"Content-Type": "application/json",
|
18 |
}
|
19 |
|
20 |
+
def modify_last_human_message(memory, new_phase):
|
21 |
+
# Travel list backwards
|
22 |
+
for msg in memory.chat_memory.messages[::-1]:
|
23 |
+
if type(msg) == HumanMessage:
|
24 |
+
msg.response_metadata = {"phase":new_phase}
|
25 |
+
break
|
26 |
+
|
27 |
+
|
28 |
def cpc_predict_message(context, input):
|
29 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
30 |
encoding = tokenizer(
|
pages/convosim.py
CHANGED
@@ -8,7 +8,7 @@ from utils.memory_utils import clear_memory, push_convo2db
|
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
|
10 |
from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS
|
11 |
-
from models.ta_models.cpc_utils import cpc_push2db
|
12 |
from models.ta_models.bp_utils import bp_predict_message, bp_push2db
|
13 |
|
14 |
logger = get_logger(__name__)
|
@@ -99,6 +99,7 @@ def sent_request_llm(llm_chain, prompt):
|
|
99 |
responses = custom_chain_predict(llm_chain, prompt, stopper)
|
100 |
for response in responses:
|
101 |
st.chat_message("assistant").write(response)
|
|
|
102 |
transcript = memoryA.load_memory_variables({})[memoryA.memory_key]
|
103 |
update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
|
104 |
|
@@ -160,6 +161,8 @@ with st.sidebar:
|
|
160 |
# st.markdown()
|
161 |
def on_change_cpc():
|
162 |
cpc_push2db(False)
|
|
|
|
|
163 |
st.session_state.changed_cpc = True
|
164 |
def on_change_bp():
|
165 |
bp_push2db()
|
|
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
|
10 |
from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS
|
11 |
+
from models.ta_models.cpc_utils import cpc_push2db, modify_last_human_message
|
12 |
from models.ta_models.bp_utils import bp_predict_message, bp_push2db
|
13 |
|
14 |
logger = get_logger(__name__)
|
|
|
99 |
responses = custom_chain_predict(llm_chain, prompt, stopper)
|
100 |
for response in responses:
|
101 |
st.chat_message("assistant").write(response)
|
102 |
+
logger.info(f"After Change: {[x.response_metadata for x in memoryA.chat_memory.messages]}")
|
103 |
transcript = memoryA.load_memory_variables({})[memoryA.memory_key]
|
104 |
update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
|
105 |
|
|
|
161 |
# st.markdown()
|
162 |
def on_change_cpc():
|
163 |
cpc_push2db(False)
|
164 |
+
modify_last_human_message(memoryA, st.session_state['sel_phase'])
|
165 |
+
logger.info(f"After Change: {[x.response_metadata for x in memoryA.chat_memory.messages]}")
|
166 |
st.session_state.changed_cpc = True
|
167 |
def on_change_bp():
|
168 |
bp_push2db()
|
utils/app_utils.py
CHANGED
@@ -96,8 +96,8 @@ def is_model_alive(name, timeout=2, model_type="classificator"):
|
|
96 |
endpoint_url="https://api.openai.com/v1/models"
|
97 |
headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
|
98 |
try:
|
99 |
-
|
100 |
-
return
|
101 |
except:
|
102 |
return "404"
|
103 |
|
@@ -107,4 +107,5 @@ def are_models_alive():
|
|
107 |
for config in ENDPOINT_NAMES.values():
|
108 |
models_alive.append(is_model_alive(**config))
|
109 |
openai = is_model_alive("openai", model_type="openai")
|
110 |
-
|
|
|
|
96 |
endpoint_url="https://api.openai.com/v1/models"
|
97 |
headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
|
98 |
try:
|
99 |
+
response = requests.get(url=endpoint_url, headers=headers, timeout=1)
|
100 |
+
return str(response.status_code)
|
101 |
except:
|
102 |
return "404"
|
103 |
|
|
|
107 |
for config in ENDPOINT_NAMES.values():
|
108 |
models_alive.append(is_model_alive(**config))
|
109 |
openai = is_model_alive("openai", model_type="openai")
|
110 |
+
models_alive.append(openai)
|
111 |
+
return all([x=="200" for x in models_alive])
|