# Scraped file-view header (not code), preserved as comments:
# author: ivnban27-ctl
# commit: f3e0ba5 "training-adherence-features (#1)" (verified)
# size: 2.66 kB
import streamlit as st
from streamlit.logger import get_logger
import requests
import os
from .config import model_name_or_path, BP_THRESHOLD
from transformers import AutoTokenizer
from utils.mongo_utils import new_bp_comparison
from app_config import ENDPOINT_NAMES
logger = get_logger(__name__)

# Tokenizer shared by every prediction request. truncation_side="left" makes
# truncation drop tokens from the start of the sequence it truncates
# (presumably so the most recent part of the context is kept — confirm).
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    truncation_side="left",
)

# Serving URL for the bad-practices model: the DATABRICKS_URL template is
# filled with the endpoint name configured under ENDPOINT_NAMES["BadPractices"].
# Missing environment variables raise KeyError at import time.
BP_URL = os.environ["DATABRICKS_URL"].format(
    endpoint_name=ENDPOINT_NAMES["BadPractices"]["name"],
)

# Request headers: bearer-token auth and a JSON request body.
HEADERS = {
    "Authorization": "Bearer " + os.environ["DATABRICKS_TOKEN"],
    "Content-Type": "application/json",
}
def bp_predict_message(context, input, *, timeout=None):
    """Classify the latest message for bad practices via the serving endpoint.

    The context and the message are tokenized as a sequence pair. Only the
    context is truncated (``truncation="only_first"``) so the pair fits within
    ``tokenizer.model_max_length - 2`` tokens. The first and last token ids are
    dropped, the rest is decoded back to text, and that text is POSTed to
    ``BP_URL``.

    Args:
        context: Conversation context preceding the message.
        input: The message to classify. The name shadows the builtin but is
            kept so existing keyword callers keep working.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, matching the previous behavior.

    Returns:
        A list with one dict per value of the first entry in the endpoint's
        ``"predictions"``. Each dict copies every key, except that ``"score"``
        is replaced by the bool ``score > BP_THRESHOLD``.
        NOTE(review): callers read ``"label"`` and ``"score"`` keys from these
        dicts — confirm the endpoint always returns both.

    On any failure (non-200 status, network error, unexpected payload) the
    error is logged with its traceback and the user is sent to
    ``pages/model_loader.py`` via ``st.switch_page``.
    """
    encoding = tokenizer(
        context,
        input,
        truncation="only_first",
        # Reserve two positions: the text is re-tokenized by the endpoint,
        # presumably adding two special tokens back (assumes the served model
        # uses the same tokenizer — confirm).
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]
    body_request = {
        # Strip the first and last ids (presumably the special tokens added
        # around the pair) before decoding back to text.
        "inputs": [tokenizer.decode(encoding[1:-1])],
        "params": {"top_k": None},
    }
    logger.debug(f"raw BP body is {body_request}")
    try:
        # Send request to Serving
        response = requests.post(
            url=BP_URL, headers=HEADERS, json=body_request, timeout=timeout
        )
        if response.status_code != 200:
            # Report the raw text: calling .json() on a non-JSON error body
            # would raise and hide the actual HTTP failure.
            raise RuntimeError(
                f"Error in response ({response.status_code}): {response.text}"
            )
        prediction = response.json()["predictions"][0]
        logger.debug(f"Raw BP prediction is {prediction}")
        return [
            {k: (v > BP_THRESHOLD if k == "score" else v) for k, v in item.items()}
            for item in prediction.values()
        ]
    except Exception as e:
        # Any failure sends the user to the model-loader page; log it at
        # ERROR level with a traceback so it is visible in default logging.
        logger.exception(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")
def bp_push2db(manual_confirmation=None):
    """Record a bad-practice prediction next to its manual label.

    Calls ``new_bp_comparison`` with the current conversation data from
    ``st.session_state`` (presumably persisting it to MongoDB, given the
    ``utils.mongo_utils`` import — confirm).

    Args:
        manual_confirmation: Ground-truth dict with ``"is_advice"`` and
            ``"is_personal_info"`` booleans. When ``None``, it is derived from
            ``st.session_state.sel_bp``: "Advice", "Personal Info" and
            "Advice & Personal Info" set the matching flags, and any other
            value sets both to False.
    """
    if manual_confirmation is None:
        choice = st.session_state.sel_bp
        manual_confirmation = {
            "is_advice": choice in ("Advice", "Advice & Personal Info"),
            "is_personal_info": choice in ("Personal Info", "Advice & Personal Info"),
        }
    new_bp_comparison(
        client=st.session_state['db_client'],
        convo_id=st.session_state['convo_id'],
        model=st.session_state['source'],
        context=st.session_state["context"],
        last_message=st.session_state["last_message"],
        ytrue=manual_confirmation,
        # Model prediction keyed by label. Entries are assumed to be dicts
        # with "label" and "score" keys — verify against where bp_prediction
        # is stored.
        ypred={
            entry['label']: entry['score']
            for entry in st.session_state['bp_prediction']
        },
    )