import os

import requests
import streamlit as st
from streamlit.logger import get_logger
from transformers import AutoTokenizer

from .config import model_name_or_path, BP_THRESHOLD
from utils.mongo_utils import new_bp_comparison
from app_config import ENDPOINT_NAMES

logger = get_logger(__name__)

# truncation_side="left" drops tokens from the start of the text when it is
# too long (presumably the oldest part of the conversation context).
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")

BP_URL = os.environ["DATABRICKS_URL"].format(
    endpoint_name=ENDPOINT_NAMES["BadPractices"]['name']
)
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

# Ground-truth flags for each manual bad-practice selection in
# st.session_state.sel_bp; any other selection means "no bad practice".
_SEL_BP_TO_CONFIRMATION = {
    "Advice": {"is_advice": True, "is_personal_info": False},
    "Personal Info": {"is_advice": False, "is_personal_info": True},
    "Advice & Personal Info": {"is_advice": True, "is_personal_info": True},
}
_NO_BAD_PRACTICE = {"is_advice": False, "is_personal_info": False}


def bp_predict_message(context, input, *, timeout=None):
    """Classify a message for bad practices using the Databricks serving endpoint.

    The context and the message are tokenized as a pair, with only the context
    truncated. The result is decoded back to text and sent to the endpoint.

    Args:
        context: Conversation context preceding the message.
        input: The message to classify.
        timeout: Optional ``requests`` timeout in seconds. The default ``None``
            waits indefinitely, which matches the historical behavior.

    Returns:
        A list with one dict per entry of the first prediction. Each dict is
        copied from the endpoint's output, except that ``"score"`` is replaced
        by the boolean ``score > BP_THRESHOLD``.

    On any failure (network error, non-200 status, malformed payload), the
    error is logged and the app switches to ``pages/model_loader.py``.
    """
    # Only the context ("only_first") is truncated, so the message itself is
    # kept whole. NOTE(review): the "- 2" margin presumably leaves room for
    # special tokens re-added by the endpoint's own tokenizer; confirm.
    encoding = tokenizer(
        context,
        input,
        truncation="only_first",
        max_length=tokenizer.model_max_length - 2,
    )['input_ids']

    # [1:-1] drops the first and last token ids, presumably the special
    # start/end tokens, before decoding back to text.
    body_request = {
        "inputs": [tokenizer.decode(encoding[1:-1])],
        "params": {
            "top_k": None
        }
    }

    # Every failure below is handled identically (route to the model loader
    # page), so a single boundary-level handler is used.
    try:
        logger.debug(f"raw BP body is {body_request}")
        response = requests.post(
            url=BP_URL, headers=HEADERS, json=body_request, timeout=timeout
        )
        if response.status_code != 200:
            # Use .text rather than .json(): error bodies (e.g. gateway HTML)
            # are not guaranteed to be JSON.
            raise RuntimeError(
                f"Error in response: status {response.status_code}: {response.text}"
            )
        prediction = response.json()['predictions'][0]
        logger.debug(f"Raw BP prediction is {prediction}")
        return [
            {k: v > BP_THRESHOLD if k == "score" else v for k, v in label_dict.items()}
            for label_dict in prediction.values()
        ]
    except Exception as e:
        # Log at ERROR with traceback; DEBUG hid endpoint failures in production.
        logger.error(f"Error in response: {e}", exc_info=True)
        # Per Streamlit docs, switch_page stops the current script run, so
        # callers do not continue with a missing prediction.
        st.switch_page("pages/model_loader.py")


def bp_push2db(manual_confirmation=None):
    """Store a manual-vs-model bad-practice comparison via ``new_bp_comparison``.

    Args:
        manual_confirmation: Ground-truth dict with keys ``"is_advice"`` and
            ``"is_personal_info"``. When ``None``, the value is derived from
            ``st.session_state.sel_bp``; an unrecognized selection means both
            flags are False.

    Reads ``db_client``, ``convo_id``, ``source``, ``context``,
    ``last_message`` and ``bp_prediction`` from ``st.session_state``.
    ``bp_prediction`` is expected to be a list of dicts with ``"label"`` and
    ``"score"`` keys (presumably the output of ``bp_predict_message``).
    """
    if manual_confirmation is None:
        # Copy so the module-level table is never shared with or mutated by
        # downstream code.
        manual_confirmation = dict(
            _SEL_BP_TO_CONFIRMATION.get(st.session_state.sel_bp, _NO_BAD_PRACTICE)
        )

    new_bp_comparison(
        client=st.session_state['db_client'],
        convo_id=st.session_state['convo_id'],
        model=st.session_state['source'],
        context=st.session_state["context"],
        last_message=st.session_state["last_message"],
        ytrue=manual_confirmation,
        ypred={x['label']: x['score'] for x in st.session_state['bp_prediction']},
    )