File size: 1,931 Bytes
70a482f
 
 
 
 
 
 
 
 
 
 
 
2e79a3c
70a482f
 
 
 
 
 
 
 
 
 
 
5e4965f
70a482f
 
5e4965f
70a482f
 
 
 
 
 
 
5e4965f
 
 
 
 
70a482f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
from streamlit.logger import get_logger
import requests
import os
from .config import model_name_or_path
from transformers import AutoTokenizer
from utils.mongo_utils import new_cpc_comparison
from app_config import ENDPOINT_NAMES

logger = get_logger(__name__)

# Loaded once at import time. truncation_side="left" makes truncation remove
# tokens from the left (beginning) of a sequence rather than from its end.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    truncation_side="left",
)

# Serving URL for the CPC endpoint: the DATABRICKS_URL template is formatted
# with endpoint_name set to the CPC endpoint's configured name.
CPC_URL = os.environ["DATABRICKS_URL"].format(
    endpoint_name=ENDPOINT_NAMES["CPC"]["name"],
)

# Headers sent with every request to the serving endpoint.
HEADERS = {
    "Authorization": "Bearer " + os.environ["DATABRICKS_TOKEN"],
    "Content-Type": "application/json",
}

def cpc_predict_message(context, input):
    """Ask the CPC serving endpoint for a label for ``input`` given ``context``.

    The pair is tokenized with the module-level ``tokenizer``. When it is too
    long, only ``context`` (the first sequence) is truncated. The resulting
    token ids, minus the first and last one, are decoded back to text and
    POSTed as JSON to ``CPC_URL``.

    Args:
        context: Text passed to the tokenizer as the first sequence.
        input: Text passed to the tokenizer as the second sequence.

    Returns:
        ``predictions[0]["0"]["label"]`` from the endpoint's JSON reply.
        On any failure (request error, non-200 status, unexpected payload)
        the error is logged and ``st.switch_page("pages/model_loader.py")``
        is called. If that call returns, this function returns ``None``.
    """
    input_ids = tokenizer(
        context,
        input,
        truncation="only_first",
        # NOTE(review): keeps 2 tokens of headroom, presumably for special
        # tokens re-added when the endpoint re-tokenizes the text — confirm.
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]
    # Drop the first and last ids. These are presumably the special tokens the
    # tokenizer adds around the pair; verify for the configured model.
    body_request = {"inputs": [tokenizer.decode(input_ids[1:-1])]}

    try:
        # NOTE(review): no timeout is passed, so a hung endpoint blocks this
        # call indefinitely.
        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
        if response.status_code != 200:
            # Report the raw body: error replies are not guaranteed to be
            # JSON, and calling response.json() here could raise and hide
            # both the status code and the server's message.
            raise Exception(f"HTTP {response.status_code}: {response.text}")
        return response.json()["predictions"][0]["0"]["label"]
    except Exception as e:
        # Any failure sends the user to the model loader page. Log it at
        # ERROR with a traceback so the redirect is not silent.
        logger.exception("Error in response: %s", e)
        st.switch_page("pages/model_loader.py")

def cpc_push2db(is_same):
    """Store one CPC comparison built from the current Streamlit session state.

    Args:
        is_same: When truthy, the predicted phase (``session_state["last_phase"]``)
            is recorded as the true label. Otherwise ``session_state["sel_phase"]``
            is recorded as the true label. In both cases the predicted label
            is ``last_phase``.
    """
    state = st.session_state
    verdict = "SAME" if is_same else "WRONG"
    logger.debug(f"pushing new {verdict} CPC")

    # Session keys are read in the same order as the stored fields.
    record = dict(
        client=state['db_client'],
        convo_id=state['convo_id'],
        model=state['source'],
        context=state["context"],
        last_message=state["last_message"],
    )
    if is_same:
        record["ytrue"] = state["last_phase"]
    else:
        record["ytrue"] = state["sel_phase"]
    record["ypred"] = state["last_phase"]

    new_cpc_comparison(**record)