Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit.logger import get_logger | |
import requests | |
import os | |
from .config import model_name_or_path, BP_THRESHOLD | |
from transformers import AutoTokenizer | |
from utils.mongo_utils import new_bp_comparison | |
from app_config import ENDPOINT_NAMES | |
# Logger obtained from Streamlit's logging helper, named after this module.
logger = get_logger(__name__)

# Tokenizer for the configured model. It is loaded with
# truncation_side="left", so any truncation removes tokens from the start of
# the truncated sequence.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    truncation_side="left",
)

# Serving URL: the DATABRICKS_URL environment variable is formatted with
# `endpoint_name` set to the configured "BadPractices" endpoint name.
BP_URL = os.environ["DATABRICKS_URL"].format(
    endpoint_name=ENDPOINT_NAMES["BadPractices"]["name"],
)

# HTTP headers for requests to the serving endpoint: bearer-token
# authorization (token read from DATABRICKS_TOKEN) and a JSON content type.
HEADERS = {
    "Authorization": "Bearer " + os.environ["DATABRICKS_TOKEN"],
    "Content-Type": "application/json",
}
def bp_predict_message(context, input, timeout=60):
    """Request a bad-practice prediction for a message from the serving endpoint.

    ``context`` and ``input`` are tokenized together as a pair; if the pair is
    too long, only the first sequence (``context``) is truncated. The token
    ids, minus the first and last one, are decoded back to text and posted to
    ``BP_URL`` as the single model input.

    Args:
        context: First sequence passed to the tokenizer (presumably the
            conversation history; confirm against callers).
        input: Second sequence passed to the tokenizer (presumably the latest
            message; confirm against callers). The name shadows the builtin
            but is kept so keyword callers keep working.
        timeout: Seconds to wait for the endpoint before giving up. Passed
            straight to ``requests.post``, so a ``(connect, read)`` tuple or
            ``None`` (wait forever, the previous behaviour) is also accepted.

    Returns:
        A list with one dict per entry of the first prediction. In each dict
        the ``"score"`` value is replaced by the boolean
        ``score > BP_THRESHOLD``; every other key is passed through unchanged.
        If the request fails, the error is logged and the app is switched to
        ``pages/model_loader.py``; should that call return, the result is
        ``None``.
    """
    token_ids = tokenizer(
        context,
        input,
        truncation="only_first",
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]

    body_request = {
        # The first and last ids are dropped before decoding; presumably
        # these are the special tokens the tokenizer adds around the pair
        # (TODO confirm for the configured model).
        "inputs": [tokenizer.decode(token_ids[1:-1])],
        "params": {"top_k": None},
    }

    try:
        # Send request to Serving
        logger.debug(f"raw BP body is {body_request}")
        response = requests.post(
            url=BP_URL,
            headers=HEADERS,
            json=body_request,
            # Without a timeout a stalled endpoint would block the script
            # indefinitely.
            timeout=timeout,
        )
        if response.status_code != 200:
            # Report status and raw text: an error body is not guaranteed to
            # be JSON, and decoding it would hide the real failure.
            raise Exception(
                f"Error in response: HTTP {response.status_code}: {response.text}"
            )

        prediction = response.json()["predictions"][0]
        logger.debug(f"Raw BP prediction is {prediction}")
        return [
            {
                key: value > BP_THRESHOLD if key == "score" else value
                for key, value in entry.items()
            }
            for entry in prediction.values()
        ]
    except Exception as e:
        # Any failure (network, timeout, bad status, unexpected payload) is
        # recorded with its traceback, then the user is sent to the loader
        # page.
        logger.exception(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")
def bp_push2db(manual_confirmation=None):
    """Store a bad-practice comparison record built from the session state.

    Args:
        manual_confirmation: Dict with boolean ``"is_advice"`` and
            ``"is_personal_info"`` entries used as the ground truth. When
            ``None``, it is derived from ``st.session_state.sel_bp``:
            ``"Advice"``, ``"Personal Info"`` and ``"Advice & Personal Info"``
            set the matching flag(s) to ``True``; any other value yields both
            flags ``False``.

    The record is handed to ``new_bp_comparison`` together with the database
    client, conversation id, model source, context and last message read from
    the session state, plus the label -> score mapping of the stored
    prediction.
    """
    if manual_confirmation is None:
        selection = st.session_state.sel_bp
        manual_confirmation = {
            "is_advice": selection in ("Advice", "Advice & Personal Info"),
            "is_personal_info": selection
            in ("Personal Info", "Advice & Personal Info"),
        }

    state = st.session_state
    new_bp_comparison(
        client=state['db_client'],
        convo_id=state['convo_id'],
        model=state['source'],
        context=state["context"],
        last_message=state["last_message"],
        ytrue=manual_confirmation,
        ypred={pred['label']: pred['score'] for pred in state['bp_prediction']},
    )