Spaces:
Sleeping
Sleeping
import os | |
import re | |
import requests | |
import string | |
import streamlit as st | |
from streamlit.logger import get_logger | |
from app_config import ENDPOINT_NAMES | |
from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION | |
import pandas as pd | |
from langchain_core.messages import AIMessage, HumanMessage | |
from models.ta_models.ta_prompt_utils import load_context | |
from utils.mongo_utils import new_convo_scoring_comparison | |
logger = get_logger(__name__)
# Databricks model-serving endpoint for the training-adherence (TA) model.
# DATABRICKS_URL is expected to be a template containing "{endpoint_name}".
TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
# Auth headers for every request to the serving endpoint.
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}
def memory2df(memory, conversation_id="convo1234"):
    """Convert a LangChain conversation memory buffer into a tidy DataFrame.

    Args:
        memory: conversation memory exposing ``buffer_as_messages``
            (a sequence of LangChain message objects).
        conversation_id (str): identifier stamped on every row.

    Returns:
        pd.DataFrame: one row per texter/helper message with columns
        ``conversation_id, message_number, actor_role, message, convo_part``.
        The columns are present even when the buffer is empty, so downstream
        column access never fails on a fresh conversation.
    """
    columns = ["conversation_id", "message_number", "actor_role", "message", "convo_part"]
    rows = []
    for i, msg in enumerate(memory.buffer_as_messages):
        # isinstance (not ``type(msg) ==``) so message subclasses
        # (e.g. streamed chunks) are classified instead of silently dropped.
        if isinstance(msg, AIMessage):
            actor_role = "texter"
        elif isinstance(msg, HumanMessage):
            actor_role = "helper"
        else:
            continue  # system/tool messages are not part of the transcript
        rows.append({
            "conversation_id": conversation_id,
            "message_number": i + 1,
            "actor_role": actor_role,
            "message": msg.content,
            "convo_part": msg.response_metadata.get("phase", None),
        })
    return pd.DataFrame(rows, columns=columns)
def get_default(question, make_explanation=False):
    """Return the canned default answer for *question*.

    ``QUESTIONDEFAULTS`` maps each question to a pair of defaults keyed by
    whether an explanation is requested.
    """
    defaults = QUESTIONDEFAULTS[question]
    return defaults[make_explanation]
def get_context(memory, question, make_explanation=False, **kwargs):
    """Build the raw TA prompt for *question* from the conversation memory.

    Returns "" when the loaded context is empty, signalling the caller to
    fall back to the question's default answer. Extra keyword arguments are
    forwarded to ``memory2df``.
    """
    convo_df = memory2df(memory, **kwargs)
    context = load_context(convo_df, question, "messages", "individual").iloc[0]
    if context == "":
        return ""
    # Pick the explanation-style or plain prompt template for this question.
    templates = NAME2PROMPT_EXPL if make_explanation else NAME2PROMPT
    return templates[question].format(convo=context, start_inst=START_INST, end_inst=END_INST)
def post_process_response(full_response, delimiter="\n\n", n=2):
    """Split a raw model generation into ``(response, explanation)``.

    The generation typically contains a short answer label and a longer
    explanation separated by *delimiter*; only the first *n* parts are
    considered, and the longer of the first two is taken as the body.

    Args:
        full_response (str): raw text returned by the model.
        delimiter (str): separator between generation parts.
        n (int): maximum number of parts to inspect.

    Returns:
        tuple[str, str]: the extracted response label and the explanation
        text with the label and leading punctuation removed.
    """
    parts = full_response.split(delimiter)[:n]
    response = extract_response(parts[0])
    logger.debug(f"Response extracted is {response}")
    # Prefer the longer part: the explanation usually dwarfs the bare label.
    if len(parts) > 1 and len(parts[0]) < len(parts[1]):
        body = parts[1]
    else:
        body = parts[0]
    # BUG FIX: ``str.lstrip(response)`` strips any characters that appear in
    # `response` (a character set), not the prefix string itself, which could
    # eat leading letters of the explanation. Remove the exact prefix instead.
    if response and body.startswith(response):
        body = body[len(response):]
    explanation = body.lstrip(string.punctuation).strip()
    logger.debug(f"Explanation extracted is {explanation}")
    return response, explanation
def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
    """Score a conversation for training adherence via the serving endpoint.

    Returns (full_convo, PROMPT, full_response). On any request failure the
    Streamlit app is redirected to the model-loader page instead of returning.
    """
    full_convo = memory.load_memory_variables({})[memory.memory_key]
    # First pass always uses the non-explanation prompt: the short answer is
    # extracted from it before optionally asking for an explanation below.
    PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
    logger.debug(f"Raw TA prompt is {PROMPT}")
    if PROMPT == "":
        # No usable context yet: fall back to the question's canned default.
        full_response = get_default(question, make_explanation)
        return full_convo, PROMPT, full_response
    body_request = {
        "prompt": PROMPT,
        "temperature": 0,
        "max_tokens": 3,  # short answer only (just the label)
    }
    try:
        # Send request to Serving
        response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
        if response.status_code == 200:
            response = response.json()
        else:
            raise Exception(f"Error in response: {response.json()}")
        full_response = response[0]['choices'][0]['text']
        if not make_explanation:
            return full_convo, PROMPT, full_response
        else:
            # NOTE: this local deliberately shadows the module-level
            # extract_response() helper; it holds the extracted answer label.
            extract_response, _ = post_process_response(full_response)
            # Second pass: explanation prompt seeded with the extracted label.
            PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
            PROMPT = PROMPT + f" {extract_response}"
            logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
            body_request["prompt"] = PROMPT
            body_request["max_tokens"] = 128  # allow room for the explanation
            response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
            if response_expl.status_code == 200:
                response_expl = response_expl.json()
            else:
                raise Exception(f"Error in response: {response_expl.json()}")
            full_response_expl = f"{extract_response} {response_expl[0]['choices'][0]['text']}"
            return full_convo, PROMPT, full_response_expl
    except Exception as e:
        logger.debug(f"Error in response: {e}")
        # Any failure (network error, non-200 status, unexpected payload)
        # sends the user to the loader page so the endpoint can be restarted.
        st.switch_page("pages/model_loader.py")
def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
    """Extract the first TA option label found in a generated answer.

    Args:
        x (str): raw model prediction text.
        default (str, optional): value returned when no option is found.
            Defaults to the first entry of ``TA_OPTIONS``.

    Returns:
        str: the first matching option label, or *default*.
    """
    # re.escape each option so labels containing regex metacharacters are
    # matched literally; re.search stops at the first hit instead of
    # collecting every match like findall, and no exception is needed.
    match = re.search("|".join(map(re.escape, TA_OPTIONS)), x)
    return match.group(0) if match else default
def ta_push_convo_comparison(ytrue, ypred):
    """Persist a human-vs-model scoring comparison for the current convo."""
    session = st.session_state
    # Append the helper's last message so the stored context is complete.
    full_context = session["context"] + "\nhelper:" + session["last_message"]
    new_convo_scoring_comparison(
        client=session['db_client'],
        convo_id=session['convo_id'],
        context=full_context,
        ytrue=ytrue,
        ypred=ypred,
    )