import os
import re
import requests
import string
import streamlit as st
from streamlit.logger import get_logger
from app_config import ENDPOINT_NAMES
from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
import pandas as pd
from langchain_core.messages import AIMessage, HumanMessage
from models.ta_models.ta_prompt_utils import load_context
from utils.mongo_utils import new_convo_scoring_comparison
logger = get_logger(__name__)
TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}
def memory2df(memory, conversation_id="convo1234"):
    """Convert a LangChain memory buffer into a DataFrame of texter/helper messages."""
    rows = []
    for i, msg in enumerate(memory.buffer_as_messages):
        # AIMessage turns are the texter, HumanMessage turns are the helper; anything else is skipped.
        actor_role = "texter" if isinstance(msg, AIMessage) else "helper" if isinstance(msg, HumanMessage) else None
        if actor_role:
            convo_part = msg.response_metadata.get("phase", None)
            rows.append({
                "conversation_id": conversation_id,
                "message_number": i + 1,
                "actor_role": actor_role,
                "message": msg.content,
                "convo_part": convo_part,
            })
    return pd.DataFrame(rows)
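# Illustrative sketch of a row produced by memory2df (values are hypothetical):
#   {"conversation_id": "convo1234", "message_number": 1, "actor_role": "texter",
#    "message": "hi... i need to talk to someone", "convo_part": None}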
def get_default(question, make_explanation=False):
    """Return the configured default answer (or answer plus explanation) for a question."""
    return QUESTIONDEFAULTS[question][make_explanation]
def get_context(memory, question, make_explanation=False, **kwargs):
    """Build the TA prompt for a question from the conversation memory; return "" if there is no context."""
    df = memory2df(memory, **kwargs)
    contexti = load_context(df, question, "messages", "individual").iloc[0]
    if contexti == "":
        return ""
    if make_explanation:
        return NAME2PROMPT_EXPL[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
    return NAME2PROMPT[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
def post_process_response(full_response, delimiter="\n\n", n=2):
    """Split a raw completion into (response, explanation)."""
    parts = full_response.split(delimiter)[:n]
    response = extract_response(parts[0])
    logger.debug(f"Response extracted is {response}")
    # Keep the longer of the first two chunks as the text to mine for an explanation.
    if len(parts) > 1:
        if len(parts[0]) < len(parts[1]):
            full_response = parts[1]
        else:
            full_response = parts[0]
    else:
        full_response = parts[0]
    # removeprefix drops the leading answer token; lstrip/strip then clean up separators and whitespace.
    explanation = full_response.removeprefix(response).lstrip(string.punctuation)
    explanation = explanation.strip()
    logger.debug(f"Explanation extracted is {explanation}")
    return response, explanation
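# Illustrative sketch of post_process_response (assuming "Yes" is one of TA_OPTIONS):
#   post_process_response("Yes\n\nYes, the helper introduced themselves by name.")
#   -> ("Yes", "the helper introduced themselves by name.")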
def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
    """Query the TA serving endpoint for a question; optionally run a second pass for an explanation."""
    full_convo = memory.load_memory_variables({})[memory.memory_key]
    # First pass always uses the short-answer prompt; the explanation prompt is built later if needed.
    PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
    logger.debug(f"Raw TA prompt is {PROMPT}")
    if PROMPT == "":
        full_response = get_default(question, make_explanation)
        return full_convo, PROMPT, full_response
    body_request = {
        "prompt": PROMPT,
        "temperature": 0,
        "max_tokens": 3,
    }
    try:
        # Send request to Serving
        response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
        if response.status_code == 200:
            response = response.json()
        else:
            raise Exception(f"Error in response: {response.json()}")
        full_response = response[0]['choices'][0]['text']
        if not make_explanation:
            return full_convo, PROMPT, full_response
        # Second pass: re-prompt with the explanation template seeded with the extracted answer.
        extracted_response, _ = post_process_response(full_response)
        PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
        PROMPT = PROMPT + f" {extracted_response}"
        logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
        body_request["prompt"] = PROMPT
        body_request["max_tokens"] = 128
        response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
        if response_expl.status_code == 200:
            response_expl = response_expl.json()
        else:
            raise Exception(f"Error in response: {response_expl.json()}")
        full_response_expl = f"{extracted_response} {response_expl[0]['choices'][0]['text']}"
        return full_convo, PROMPT, full_response_expl
    except Exception as e:
        logger.debug(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")
def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
    """Extract the TA answer from a generated completion.

    Args:
        x (str): raw model prediction
        default (str, optional): value returned when no TA option is found. Defaults to TA_OPTIONS[0].

    Returns:
        str: the first TA_OPTIONS value found in the prediction, or the default
    """
    try:
        return re.findall("|".join(TA_OPTIONS), x)[0]
    except Exception:
        return default
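# Illustrative sketch of extract_response (assuming TA_OPTIONS contains "Yes" and "No"):
#   extract_response("Yes, the helper asked about risk")  -> "Yes"
#   extract_response("unrelated text")                     -> TA_OPTIONS[0]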
def ta_push_convo_comparison(ytrue, ypred):
    """Store a manual-vs-model scoring comparison for the current conversation in Mongo."""
    new_convo_scoring_comparison(**{
        "client": st.session_state['db_client'],
        "convo_id": st.session_state['convo_id'],
        "context": st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
        "ytrue": ytrue,
        "ypred": ypred,
    })
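# Illustrative end-to-end usage sketch (assumes a populated LangChain buffer memory and that
# "question_name" is a key of NAME2PROMPT / QUESTIONDEFAULTS; the names below are hypothetical):
#   full_convo, prompt, answer = TA_predict_convo(memory, "question_name", make_explanation=True)
#   response, explanation = post_process_response(answer)
#   ta_push_convo_comparison(ytrue=manual_label, ypred=response)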