Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit.logger import get_logger | |
import requests | |
import os | |
from langchain_core.messages import HumanMessage | |
from .config import model_name_or_path | |
from transformers import AutoTokenizer | |
from utils.mongo_utils import new_cpc_comparison | |
from app_config import ENDPOINT_NAMES | |
logger = get_logger(__name__)

# Tokenizer for the CPC model. truncation_side="left" makes truncation drop
# tokens from the start of a sequence instead of the end (presumably so the
# most recent part of a long conversation context survives — confirm).
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    truncation_side="left",
)

# Serving URL of the CPC endpoint: DATABRICKS_URL is formatted with an
# `endpoint_name` field taken from the app's endpoint registry.
CPC_URL = os.environ["DATABRICKS_URL"].format(
    endpoint_name=ENDPOINT_NAMES["CPC"]["name"],
)

# Headers sent with every request to the serving endpoint (bearer-token auth,
# JSON body).
HEADERS = {
    "Authorization": "Bearer " + os.environ["DATABRICKS_TOKEN"],
    "Content-Type": "application/json",
}
def modify_last_human_message(memory, new_phase):
    """Tag the most recent human message in ``memory`` with a phase label.

    Walks ``memory.chat_memory.messages`` from newest to oldest. On the first
    ``HumanMessage`` found, it replaces that message's ``response_metadata``
    with ``{"phase": new_phase}`` and stops. The message is modified in place
    and nothing is returned. If there is no human message, nothing changes.

    Args:
        memory: Object exposing ``chat_memory.messages`` (presumably a
            LangChain conversation memory — confirm against callers).
        new_phase: Phase label stored under the ``"phase"`` key.
    """
    # reversed() iterates backwards without first building a reversed copy
    # of the whole message list (which a [::-1] slice would do).
    for msg in reversed(memory.chat_memory.messages):
        # isinstance, unlike an exact type() comparison, also accepts
        # subclasses of HumanMessage.
        if isinstance(msg, HumanMessage):
            # NOTE: replaces any existing response_metadata wholesale rather
            # than merging the "phase" key into it.
            msg.response_metadata = {"phase": new_phase}
            break
def cpc_predict_message(context, input, timeout=60):
    """Ask the CPC serving endpoint for the phase label of a message.

    ``context`` and ``input`` are tokenized as a sequence pair, and only
    ``context`` is truncated when the pair is too long. The resulting token
    ids, minus the first and last one, are decoded back to text. That text is
    POSTed to ``CPC_URL``.

    Args:
        context: Conversation context (first sequence of the pair; the only
            one that gets truncated).
        input: Message to classify (second sequence of the pair). The name
            shadows the builtin but is kept for keyword callers.
        timeout: Seconds ``requests`` waits for the endpoint. Without a
            timeout, a stalled endpoint would block the script indefinitely.

    Returns:
        The ``label`` of the first prediction in the endpoint's JSON reply,
        i.e. ``response.json()["predictions"][0]["0"]["label"]``.
        On any request, HTTP or parsing failure the error is logged and the
        user is redirected to ``pages/model_loader.py`` via
        ``st.switch_page``. If that call returns, this function returns
        ``None``.
    """
    # Cap the encoded pair two tokens below the model limit. The first and
    # last ids are stripped before decoding; presumably these are the special
    # tokens the endpoint re-adds when it tokenizes the text again
    # (TODO confirm).
    encoding = tokenizer(
        context,
        input,
        truncation="only_first",
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]
    body_request = {"inputs": [tokenizer.decode(encoding[1:-1])]}

    try:
        # Send the request to the model serving endpoint.
        response = requests.post(
            url=CPC_URL,
            headers=HEADERS,
            json=body_request,
            timeout=timeout,
        )
        if response.status_code != 200:
            # Use the raw body: error pages are not guaranteed to be JSON, and
            # calling response.json() here would mask the real failure.
            raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
        return response.json()["predictions"][0]["0"]["label"]
    except Exception as e:
        # Deliberate catch-all boundary: any failure (network, timeout, bad
        # status, unexpected payload) sends the user to the model loader page.
        # Logged at WARNING so it is not hidden at the default log level.
        logger.warning("Error in response: %s", e, exc_info=True)
        st.switch_page("pages/model_loader.py")
def cpc_push2db(is_same):
    """Store one CPC prediction and its reference label via ``new_cpc_comparison``.

    Everything except ``is_same`` comes from ``st.session_state``. The
    predicted label is always ``last_phase``. The reference label is also
    ``last_phase`` when ``is_same`` is truthy (the prediction was confirmed);
    otherwise it is the user-selected ``sel_phase``.

    Args:
        is_same: Truthy if the prediction was judged correct.
    """
    state = st.session_state
    verdict = "WRONG" if not is_same else "SAME"
    logger.debug(f"pushing new {verdict} CPC")
    new_cpc_comparison(
        client=state["db_client"],
        convo_id=state["convo_id"],
        model=state["source"],
        context=state["context"],
        last_message=state["last_message"],
        ytrue=state["sel_phase"] if not is_same else state["last_phase"],
        ypred=state["last_phase"],
    )