import os
import re
import string

import pandas as pd
import requests
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from streamlit.logger import get_logger

from app_config import ENDPOINT_NAMES
from models.ta_models.config import (
    NAME2PROMPT,
    NAME2PROMPT_EXPL,
    START_INST,
    END_INST,
    QUESTIONDEFAULTS,
    TA_OPTIONS,
    NAME2QUESTION,
)
from models.ta_models.ta_prompt_utils import load_context
from utils.mongo_utils import new_convo_scoring_comparison

logger = get_logger(__name__)

# Serving endpoint for the training-adherence (TA) model. Both environment
# variables are read at import time, so a missing one raises KeyError here.
TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}


def memory2df(memory, conversation_id="convo1234"):
    """Convert a conversation memory into a one-row-per-message DataFrame.

    Args:
        memory: Object exposing ``buffer_as_messages`` (an iterable of
            messages with ``content`` and ``response_metadata``).
        conversation_id (str, optional): Value written to every row's
            ``conversation_id`` column. Defaults to "convo1234".

    Returns:
        pd.DataFrame: Columns ``conversation_id``, ``message_number``,
        ``actor_role``, ``message`` and ``convo_part``. Messages that are
        neither AI nor human messages are skipped.
    """
    rows = []
    for index, msg in enumerate(memory.buffer_as_messages):
        # AI messages are labelled "texter" and human messages "helper".
        # isinstance (rather than an exact type comparison) also accepts
        # subclasses of the two message types.
        if isinstance(msg, AIMessage):
            actor_role = "texter"
        elif isinstance(msg, HumanMessage):
            actor_role = "helper"
        else:
            continue
        rows.append({
            "conversation_id": conversation_id,
            # Numbering is 1-based and counts skipped messages too.
            "message_number": index + 1,
            "actor_role": actor_role,
            "message": msg.content,
            "convo_part": msg.response_metadata.get("phase", None),
        })
    return pd.DataFrame(rows)


def get_default(question, make_explanation=False):
    """Return the configured default answer for ``question``.

    Args:
        question: Key into ``QUESTIONDEFAULTS``.
        make_explanation (bool, optional): Second-level index into the
            entry for ``question``. Defaults to False.

    Returns:
        The value stored at ``QUESTIONDEFAULTS[question][make_explanation]``.
    """
    return QUESTIONDEFAULTS[question][make_explanation]


def get_context(memory, question, make_explanation=False, **kwargs):
    """Build the TA prompt for ``question`` from the conversation memory.

    Args:
        memory: Conversation memory, forwarded to ``memory2df``.
        question: Key into ``NAME2PROMPT`` / ``NAME2PROMPT_EXPL``.
        make_explanation (bool, optional): Use the explanation prompt
            template instead of the plain one. Defaults to False.
        **kwargs: Forwarded to ``memory2df`` (e.g. ``conversation_id``).

    Returns:
        str: The formatted prompt, or "" when the loaded context is empty.
    """
    df = memory2df(memory, **kwargs)
    # NOTE(review): load_context is defined elsewhere; only the first
    # element of its result is used here.
    convo_context = load_context(df, question, "messages", "individual").iloc[0]
    if convo_context == "":
        return ""

    templates = NAME2PROMPT_EXPL if make_explanation else NAME2PROMPT
    return templates[question].format(
        convo=convo_context, start_inst=START_INST, end_inst=END_INST
    )


def post_process_response(full_response, delimiter="\n\n", n=2):
    """Split a raw completion into its label and its explanation.

    Args:
        full_response (str): Raw text generated by the model.
        delimiter (str, optional): Separator used to split the text into
            parts. Defaults to a blank line.
        n (int, optional): Number of leading parts to consider.
            Defaults to 2.

    Returns:
        tuple: ``(response, explanation)``. ``response`` is the label
        extracted from the first part; ``explanation`` is the longer of
        the first two parts with the leading label, leading punctuation
        and surrounding whitespace removed.
    """
    parts = full_response.split(delimiter)[:n]
    response = extract_response(parts[0])
    logger.debug(f"Response extracted is {response}")

    # Keep the second part only when it is strictly longer than the first.
    if len(parts) > 1 and len(parts[0]) < len(parts[1]):
        text = parts[1]
    else:
        text = parts[0]

    # Drop the label only when it is a literal prefix. str.lstrip(response)
    # would treat the label as a *set of characters* and could eat the
    # beginning of the explanation itself.
    if response and text.startswith(response):
        text = text[len(response):]

    explanation = text.lstrip(string.punctuation).strip()
    logger.debug(f"Explanation extracted is {explanation}")
    return response, explanation


def _query_ta_endpoint(prompt, max_tokens):
    """POST ``prompt`` to the TA endpoint and return the generated text."""
    body_request = {
        "prompt": prompt,
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
    if response.status_code != 200:
        raise Exception(f"Error in response: {response.json()}")
    return response.json()[0]['choices'][0]['text']


def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
    """Score the conversation on ``question`` with the TA model.

    The label is always requested first with the plain prompt. When
    ``make_explanation`` is true, a second request is sent with the
    explanation prompt followed by the extracted label.

    Args:
        memory: Conversation memory exposing ``load_memory_variables``,
            ``memory_key`` and ``buffer_as_messages``.
        question: Key identifying the TA question to score.
        make_explanation (bool, optional): Also request an explanation.
            Defaults to False.
        **kwargs: Forwarded to ``get_context``.

    Returns:
        tuple: ``(full_convo, prompt, full_response)``. ``prompt`` is the
        last prompt built ("" when there was no context, in which case
        ``full_response`` is the configured default and no request is sent).
        If a request fails, the app is redirected to the model loader page
        instead of returning a result.
    """
    full_convo = memory.load_memory_variables({})[memory.memory_key]
    prompt = get_context(memory, question, make_explanation=False, **kwargs)
    logger.debug(f"Raw TA prompt is {prompt}")
    if prompt == "":
        full_response = get_default(question, make_explanation)
        return full_convo, prompt, full_response

    try:
        # Three tokens are enough for the short label answer.
        full_response = _query_ta_endpoint(prompt, max_tokens=3)
        if not make_explanation:
            return full_convo, prompt, full_response

        label, _ = post_process_response(full_response)
        prompt = get_context(memory, question, make_explanation=True, **kwargs)
        # Prime the model with the label so it continues with the reasoning.
        prompt = prompt + f" {label}"
        logger.debug(f"Raw TA prompt for Explanation is {prompt}")
        explanation_text = _query_ta_endpoint(prompt, max_tokens=128)
        full_response_expl = f"{label} {explanation_text}"
        return full_convo, prompt, full_response_expl
    except Exception as e:
        # Any failure (HTTP error, malformed payload, ...) sends the user
        # to the model loader page rather than surfacing a traceback.
        logger.debug(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")


def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
    """Extract the first TA option that occurs in a generated answer.

    Args:
        x (str): Generated text to search.
        default (str, optional): Value returned when no option is found
            or ``x`` is not a string. Defaults to ``TA_OPTIONS[0]``.

    Returns:
        str: The first occurrence of any entry of ``TA_OPTIONS`` in ``x``,
        otherwise ``default``.
    """
    # Escape each option so it is matched literally even if it contains
    # regex metacharacters.
    pattern = "|".join(re.escape(option) for option in TA_OPTIONS)
    try:
        match = re.search(pattern, x)
    except TypeError:
        # x is not a string (e.g. None).
        return default
    if match is None:
        return default
    return match.group(0)


def ta_push_convo_comparison(ytrue, ypred):
    """Store a true-vs-predicted TA score for the current conversation.

    Reads ``db_client``, ``convo_id``, ``context`` and ``last_message``
    from ``st.session_state`` and passes them, with the two scores, to
    ``new_convo_scoring_comparison``.

    Args:
        ytrue: Reference score.
        ypred: Score predicted by the model.
    """
    new_convo_scoring_comparison(
        client=st.session_state['db_client'],
        convo_id=st.session_state['convo_id'],
        context=st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
        ytrue=ytrue,
        ypred=ypred,
    )