Commit 7e14368 committed by ivnban27-ctl
Parent: 5e4965f

    ta utils fix for explanation

Files changed:
- models/ta_models/ta_utils.py (+23 -9)
- pages/convosim.py (+4 -4)
- pages/training_adherence.py (+2 -0)
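In short: `TA_predict_convo` previously sized a single completion request from `make_explanation` (128 tokens with an explanation, 3 without); it now always runs a 3-token classification request first and, when an explanation is requested, issues a second 128-token request seeded with the extracted label. The chat page now forwards every user message to the LLM after the behavioral-pattern toasts, and the scoring page gets its own page title.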
models/ta_models/ta_utils.py

```diff
@@ -60,18 +60,16 @@ def post_process_response(full_response, delimiter="\n\n", n=2):
 
 def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
     full_convo = memory.load_memory_variables({})[memory.memory_key]
-    PROMPT = get_context(memory, question, make_explanation=make_explanation, **kwargs)
+    PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
     logger.debug(f"Raw TA prompt is {PROMPT}")
     if PROMPT == "":
         full_response = get_default(question, make_explanation)
-        # response, explanation = post_process_response(full_response)
         return full_convo, PROMPT, full_response
 
-    max_tokens = 128 if make_explanation else 3
     body_request = {
         "prompt": PROMPT,
         "temperature": 0,
-        "max_tokens": max_tokens,
+        "max_tokens": 3,
     }
 
     try:
@@ -79,12 +77,28 @@ def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
         response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
         if response.status_code == 200:
             response = response.json()
+        else:
+            raise Exception(f"Error in response: {response.json()}")
         full_response = response[0]['choices'][0]['text']
-        [five deleted lines (old 83-87) were not rendered in the page snapshot]
+        if not make_explanation:
+            return full_convo, PROMPT, full_response
+        else:
+            extract_response, _ = post_process_response(full_response)
+            PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
+            PROMPT = PROMPT + f" {extract_response}"
+            logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
+            body_request["prompt"] = PROMPT
+            body_request["max_tokens"] = 128
+            response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
+            if response_expl.status_code == 200:
+                response_expl = response_expl.json()
+            else:
+                raise Exception(f"Error in response: {response_expl.json()}")
+            full_response_expl = f"{extract_response} {response_expl[0]['choices'][0]['text']}"
+            return full_convo, PROMPT, full_response_expl
+    except Exception as e:
+        logger.debug(f"Error in response: {e}")
+        st.switch_page("pages/model_loader.py")
 
 def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
     """Extract Response from generated answer
```
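The new flow makes two completion calls instead of sizing one call by `make_explanation`: a 3-token pass to obtain the adherence label, then, if an explanation was requested, a 128-token pass whose prompt is the explanation context with the extracted label appended. Two of the deletion lines above were truncated in the page snapshot; `make_explanation=make_explanation, **kwargs)` and `"max_tokens": max_tokens,` are reconstructions from the surrounding code. A minimal sketch of the two-pass pattern, assuming an OpenAI-style completions endpoint that wraps its payload in a list (the code indexes `response[0]`); the `TA_URL` and `HEADERS` values, both prompt arguments, and the helper names are placeholders, not the app's real definitions:

```python
# Minimal sketch of the two-pass flow; not the app's exact code.
# TA_URL/HEADERS values and helper names are assumptions; the request
# body and response indexing mirror the diff above.
import requests

TA_URL = "http://localhost:8000/v1/completions"  # assumed; the app imports its own TA_URL
HEADERS = {"Content-Type": "application/json"}   # assumed; the app imports its own HEADERS


def _complete(prompt: str, max_tokens: int) -> str:
    """POST one completion request and return the generated text."""
    body = {"prompt": prompt, "temperature": 0, "max_tokens": max_tokens}
    resp = requests.post(url=TA_URL, headers=HEADERS, json=body)
    if resp.status_code != 200:
        raise Exception(f"Error in response: {resp.json()}")
    # The app indexes response[0], i.e. the server wraps the payload in a list.
    return resp.json()[0]["choices"][0]["text"]


def predict(classify_prompt: str, explain_prompt: str, make_explanation: bool) -> str:
    # Pass 1: three tokens are enough for the adherence label alone.
    label = _complete(classify_prompt, max_tokens=3).strip()
    if not make_explanation:
        return label
    # Pass 2: re-prompt with the label appended and a 128-token budget,
    # so the model continues with an explanation of that label.
    explanation = _complete(f"{explain_prompt} {label}", max_tokens=128)
    return f"{label} {explanation}"
```

One readability note on the hunk itself: the local assignment `extract_response, _ = post_process_response(full_response)` shadows the module-level `extract_response` helper defined immediately below.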
pages/convosim.py

```diff
@@ -136,10 +136,10 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
     for bp in st.session_state['bp_prediction']:
         if bp["score"]:
             st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
-            [deleted line not rendered in the page snapshot]
-            [deleted line not rendered in the page snapshot]
-        else:
-            [deleted line not rendered in the page snapshot]
+
+    sent_request_llm(llm_chain, prompt)
+    # else:
+    #     sent_request_llm(llm_chain, prompt)
 
 with st.sidebar:
     st.divider()
```
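The effect of this hunk: the behavioral-pattern loop now only raises toasts, and the user message is forwarded to the LLM unconditionally afterwards; the commented-out `else` suggests the call previously ran only when no pattern fired. Roughly, under the names used in the diff (with `toast` standing in for `st.toast` and the `bp['label']` mapping omitted):

```python
# Sketch of the new control flow; bp_prediction entries, llm_chain, and
# sent_request_llm are the diff's names, passed in here to keep the
# sketch self-contained.
def handle_user_message(prompt, llm_chain, bp_prediction, toast, sent_request_llm):
    for bp in bp_prediction:
        if bp["score"]:  # a behavioral pattern was detected in the last message
            toast(f"Detected {bp['label']} in the last message!")
    # Always query the simulated texter, whether or not a pattern fired.
    sent_request_llm(llm_chain, prompt)
```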
pages/training_adherence.py

```diff
@@ -6,6 +6,8 @@ from utils.app_utils import are_models_alive
 from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
 from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
 
+st.set_page_config(page_title="Conversation Simulator - Scoring")
+
 if "memory" not in st.session_state:
     st.switch_page("pages/convosim.py")
 
```
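`st.set_page_config` gives the scoring page its own browser-tab title; Streamlit requires it to be the first Streamlit command a page executes, which is why it lands directly after the imports and before the `st.session_state` guard. A minimal skeleton of the resulting page, mirroring the hunk above:

```python
# Skeleton of the guarded scoring page.
import streamlit as st

# Must run before any other st.* call on this page.
st.set_page_config(page_title="Conversation Simulator - Scoring")

# If the user landed here without an active conversation, send them back.
if "memory" not in st.session_state:
    st.switch_page("pages/convosim.py")
```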