"""Helpers for querying the CPC serving endpoint and storing CPC comparisons.

The endpoint URL and bearer token are read from the ``DATABRICKS_URL`` and
``DATABRICKS_TOKEN`` environment variables at import time; a missing variable
raises ``KeyError`` when this module is imported.
"""

import os

import requests
import streamlit as st
from langchain_core.messages import HumanMessage
from streamlit.logger import get_logger
from transformers import AutoTokenizer

from .config import model_name_or_path
from app_config import ENDPOINT_NAMES
from utils.mongo_utils import new_cpc_comparison

logger = get_logger(__name__)

# truncation_side="left": when truncation is needed, tokens are dropped from
# the start of the truncated sequence, keeping its most recent part.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")

CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"]["name"])
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

# Upper bound (seconds) on a single prediction request, so a stalled endpoint
# cannot block the Streamlit script indefinitely.
REQUEST_TIMEOUT_SECONDS = 120


def modify_last_human_message(memory, new_phase) -> None:
    """Tag the most recent human message in ``memory`` with ``new_phase``.

    Walks ``memory.chat_memory.messages`` from newest to oldest and, on the
    first ``HumanMessage`` found, replaces its ``response_metadata`` with
    ``{"phase": new_phase}``. Does nothing if there is no human message.

    Args:
        memory: Object exposing ``chat_memory.messages`` (a sequence of
            messages).
        new_phase: Value stored under the ``"phase"`` key.
    """
    # reversed() walks the list backwards without copying it.
    for msg in reversed(memory.chat_memory.messages):
        if isinstance(msg, HumanMessage):
            # NOTE: this overwrites any metadata already on the message.
            msg.response_metadata = {"phase": new_phase}
            break


def cpc_predict_message(context, input):
    """Ask the CPC serving endpoint for the label of ``input`` given ``context``.

    ``context`` and ``input`` are tokenized as a pair; if the pair is too long
    only ``context`` is truncated (from its left side). The token ids are then
    decoded back to text and posted to ``CPC_URL``.

    On any failure (network error, timeout, non-200 status, unexpected response
    shape) the error is logged and the app switches to
    ``pages/model_loader.py``.

    Args:
        context: First sequence of the pair. Presumably the chat history, e.g.
            ``memory.load_memory_variables({})[memory.memory_key]`` - confirm
            against the caller.
        input: Second sequence of the pair. The name shadows the builtin but is
            kept because it is part of the public signature.

    Returns:
        The value found at ``["predictions"][0]["0"]["label"]`` in the JSON
        response.
    """
    # NOTE(review): the "- 2" margin and the [1:-1] slice below presumably
    # account for the leading/trailing special tokens the tokenizer adds -
    # confirm for this model.
    token_ids = tokenizer(
        context,
        input,
        truncation="only_first",
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]

    body_request = {"inputs": [tokenizer.decode(token_ids[1:-1])]}

    try:
        # Send request to Serving
        response = requests.post(
            url=CPC_URL,
            headers=HEADERS,
            json=body_request,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            # Use the raw text: an error body is not guaranteed to be JSON, and
            # decoding it here would hide the real status and message.
            raise RuntimeError(
                f"Error in response: HTTP {response.status_code}: {response.text}"
            )
        return response.json()["predictions"][0]["0"]["label"]
    except Exception as e:
        # Deliberately broad: any failure sends the user to the model loader
        # page, but it is logged with its traceback instead of at DEBUG level.
        logger.exception("Error in response: %s", e)
        st.switch_page("pages/model_loader.py")


def cpc_push2db(is_same: bool) -> None:
    """Store the current CPC prediction and its reference label.

    All values are read from ``st.session_state``. The predicted label is
    always ``last_phase``; the true label is ``last_phase`` when ``is_same`` is
    truthy and ``sel_phase`` otherwise.

    Args:
        is_same: Whether the prediction is considered correct.
    """
    text_is_same = "SAME" if is_same else "WRONG"
    logger.debug(f"pushing new {text_is_same} CPC")

    state = st.session_state
    new_cpc_comparison(
        client=state["db_client"],
        convo_id=state["convo_id"],
        model=state["source"],
        context=state["context"],
        last_message=state["last_message"],
        ytrue=state["last_phase"] if is_same else state["sel_phase"],
        ypred=state["last_phase"],
    )