ivnban27-ctl committed on
Commit
134b64a
·
1 Parent(s): adae8ef

saving new CPC after selection and updating aliveness logic

Browse files
models/ta_models/cpc_utils.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  from streamlit.logger import get_logger
3
  import requests
4
  import os
 
5
  from .config import model_name_or_path
6
  from transformers import AutoTokenizer
7
  from utils.mongo_utils import new_cpc_comparison
@@ -16,6 +17,13 @@ HEADERS = {
16
  "Content-Type": "application/json",
17
  }
18
 
 
 
 
 
 
 
 
19
  def cpc_predict_message(context, input):
20
  # context = memory.load_memory_variables({})[memory.memory_key]
21
  encoding = tokenizer(
 
2
  from streamlit.logger import get_logger
3
  import requests
4
  import os
5
+ from langchain_core.messages import HumanMessage
6
  from .config import model_name_or_path
7
  from transformers import AutoTokenizer
8
  from utils.mongo_utils import new_cpc_comparison
 
17
  "Content-Type": "application/json",
18
  }
19
 
20
def modify_last_human_message(memory, new_phase):
    """Tag the most recent human turn in the chat memory with a new CPC phase.

    Walks the conversation history backwards and, on the first
    ``HumanMessage`` found, overwrites its ``response_metadata`` with
    ``{"phase": new_phase}``, then stops. If the history contains no
    human message, the call is a no-op.

    Args:
        memory: a LangChain memory object exposing ``chat_memory.messages``.
        new_phase: the phase label selected by the user.
    """
    # reversed() iterates lazily instead of copying the list like [::-1];
    # isinstance() (not `type(...) ==`) also matches HumanMessage subclasses.
    for msg in reversed(memory.chat_memory.messages):
        if isinstance(msg, HumanMessage):
            # NOTE(review): this replaces any existing response_metadata
            # on the message rather than merging — confirm that is intended.
            msg.response_metadata = {"phase": new_phase}
            break
26
+
27
  def cpc_predict_message(context, input):
28
  # context = memory.load_memory_variables({})[memory.memory_key]
29
  encoding = tokenizer(
pages/convosim.py CHANGED
@@ -8,7 +8,7 @@ from utils.memory_utils import clear_memory, push_convo2db
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
  from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
  from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS
11
- from models.ta_models.cpc_utils import cpc_push2db
12
  from models.ta_models.bp_utils import bp_predict_message, bp_push2db
13
 
14
  logger = get_logger(__name__)
@@ -99,6 +99,7 @@ def sent_request_llm(llm_chain, prompt):
99
  responses = custom_chain_predict(llm_chain, prompt, stopper)
100
  for response in responses:
101
  st.chat_message("assistant").write(response)
 
102
  transcript = memoryA.load_memory_variables({})[memoryA.memory_key]
103
  update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
104
 
@@ -160,6 +161,7 @@ with st.sidebar:
160
  # st.markdown()
161
  def on_change_cpc():
162
  cpc_push2db(False)
 
163
  st.session_state.changed_cpc = True
164
  def on_change_bp():
165
  bp_push2db()
 
8
  from utils.chain_utils import get_chain, custom_chain_predict
9
  from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
  from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR, BP_LBL_OPTS
11
+ from models.ta_models.cpc_utils import cpc_push2db, modify_last_human_message
12
  from models.ta_models.bp_utils import bp_predict_message, bp_push2db
13
 
14
  logger = get_logger(__name__)
 
99
  responses = custom_chain_predict(llm_chain, prompt, stopper)
100
  for response in responses:
101
  st.chat_message("assistant").write(response)
102
+
103
  transcript = memoryA.load_memory_variables({})[memoryA.memory_key]
104
  update_convo(st.session_state["db_client"], st.session_state["convo_id"], transcript)
105
 
 
161
  # st.markdown()
162
def on_change_cpc():
    """Streamlit callback fired when the CPC phase selector changes."""
    # Push the comparison record to Mongo; NOTE(review): the False argument
    # presumably marks this as a user-corrected (not model-accepted)
    # prediction — confirm against cpc_push2db's signature.
    cpc_push2db(False)
    # Retag the latest human message with the phase the user just selected.
    modify_last_human_message(memoryA, st.session_state['sel_phase'])
    # Flag so the rest of the app knows the CPC was manually changed.
    st.session_state.changed_cpc = True
166
  def on_change_bp():
167
  bp_push2db()
utils/app_utils.py CHANGED
@@ -96,8 +96,8 @@ def is_model_alive(name, timeout=2, model_type="classificator"):
96
  endpoint_url="https://api.openai.com/v1/models"
97
  headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
98
  try:
99
- _ = requests.get(url=endpoint_url, headers=headers, timeout=1)
100
- return "200"
101
  except:
102
  return "404"
103
 
@@ -107,4 +107,5 @@ def are_models_alive():
107
  for config in ENDPOINT_NAMES.values():
108
  models_alive.append(is_model_alive(**config))
109
  openai = is_model_alive("openai", model_type="openai")
110
- return all([x=="200" for x in models_alive + [openai]])
 
 
96
  endpoint_url="https://api.openai.com/v1/models"
97
  headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
98
  try:
99
+ response = requests.get(url=endpoint_url, headers=headers, timeout=1)
100
+ return str(response.status_code)
101
  except:
102
  return "404"
103
 
 
107
  for config in ENDPOINT_NAMES.values():
108
  models_alive.append(is_model_alive(**config))
109
  openai = is_model_alive("openai", model_type="openai")
110
+ models_alive.append(openai)
111
+ return all([x=="200" for x in models_alive])