Spaces:
Sleeping
Sleeping
File size: 1,931 Bytes
70a482f 2e79a3c 70a482f 5e4965f 70a482f 5e4965f 70a482f 5e4965f 70a482f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import streamlit as st
from streamlit.logger import get_logger
import requests
import os
from .config import model_name_or_path
from transformers import AutoTokenizer
from utils.mongo_utils import new_cpc_comparison
from app_config import ENDPOINT_NAMES
# Module logger, obtained through Streamlit's logging helper.
logger = get_logger(__name__)

# Tokenizer shared by the functions in this module. It is loaded once at
# import time and configured so that truncation is applied on the left side.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    truncation_side="left",
)

# URL of the CPC endpoint: the DATABRICKS_URL environment variable is
# formatted with the endpoint name taken from the application config.
# A missing environment variable raises KeyError at import time.
CPC_URL = os.environ["DATABRICKS_URL"].format(
    endpoint_name=ENDPOINT_NAMES["CPC"]["name"],
)

# Headers sent with every request to the endpoint; the bearer token is read
# from the DATABRICKS_TOKEN environment variable.
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}
def cpc_predict_message(context, input, timeout=120):
    """Query the CPC serving endpoint and return the predicted label.

    ``context`` and ``input`` are tokenized together as a pair, with only the
    first sequence (``context``) eligible for truncation. The resulting ids
    are decoded back to a single string, which is sent as the sole element
    of the ``"inputs"`` list in the request body.

    Args:
        context: First sequence of the pair given to the tokenizer
            (presumably the conversation history -- confirm with callers).
        input: Second sequence of the pair (presumably the latest message).
            The name shadows the builtin but is kept for caller
            compatibility.
        timeout: Seconds passed to ``requests.post`` as its ``timeout``.
            Without it a stalled endpoint would block the script forever.
            NOTE(review): 120 is a conservative default, tune it to the
            endpoint's real latency.

    Returns:
        The value found at ``predictions[0]["0"]["label"]`` in the JSON
        response when the endpoint answers with HTTP 200.

        On any request failure (network error, timeout, non-200 status,
        unexpected payload) the error is logged at debug level and
        ``st.switch_page("pages/model_loader.py")`` is called; per the
        Streamlit documentation that stops the current page run, so no
        value is returned on that path.
    """
    token_ids = tokenizer(
        context,
        input,
        truncation="only_first",
        # NOTE(review): the reason for the 2-token headroom is not visible
        # in this file -- confirm before changing.
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]

    # Drop the first and last ids before decoding (presumably the special
    # tokens added by the tokenizer -- TODO confirm for this model).
    body_request = {"inputs": [tokenizer.decode(token_ids[1:-1])]}

    try:
        # Send request to Serving
        response = requests.post(
            url=CPC_URL,
            headers=HEADERS,
            json=body_request,
            timeout=timeout,
        )
        if response.status_code == 200:
            return response.json()["predictions"][0]["0"]["label"]
        # Report status and raw body: an error response is not guaranteed to
        # be JSON, and parsing it here would hide the real HTTP failure
        # behind a decode error.
        raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
    except Exception as e:
        logger.debug(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")
def cpc_push2db(is_same):
    """Store one CPC comparison record built from the Streamlit session state.

    All fields are read from ``st.session_state`` and forwarded to
    ``new_cpc_comparison``. The predicted value (``ypred``) is always
    ``last_phase``; the true value (``ytrue``) is ``last_phase`` when
    ``is_same`` is truthy and ``sel_phase`` otherwise.

    Args:
        is_same: Whether the prediction is to be recorded as matching the
            true value.

    Raises:
        KeyError: If a required key is missing from the session state.
    """
    verdict = "SAME" if is_same else "WRONG"
    logger.debug(f"pushing new {verdict} CPC")

    state = st.session_state
    # Keyword arguments are evaluated in the order written, so session keys
    # are looked up in the same sequence as before.
    new_cpc_comparison(
        client=state["db_client"],
        convo_id=state["convo_id"],
        model=state["source"],
        context=state["context"],
        last_message=state["last_message"],
        ytrue=state["last_phase" if is_same else "sel_phase"],
        ypred=state["last_phase"],
    )
|