ivnban27-ctl commited on
Commit
7e14368
·
1 Parent(s): 5e4965f

ta utils fix for explanation

Browse files
models/ta_models/ta_utils.py CHANGED
@@ -60,18 +60,16 @@ def post_process_response(full_response, delimiter="\n\n", n=2):
60
 
61
  def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
62
  full_convo = memory.load_memory_variables({})[memory.memory_key]
63
- PROMPT = get_context(memory, question, make_explanation=make_explanation, **kwargs)
64
  logger.debug(f"Raw TA prompt is {PROMPT}")
65
  if PROMPT == "":
66
  full_response = get_default(question, make_explanation)
67
- # response, explanation = post_process_response(full_response)
68
  return full_convo, PROMPT, full_response
69
 
70
- max_tokens = 128 if make_explanation else 3
71
  body_request = {
72
  "prompt": PROMPT,
73
  "temperature": 0,
74
- "max_tokens": max_tokens,
75
  }
76
 
77
  try:
@@ -79,12 +77,28 @@ def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
79
  response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
80
  if response.status_code == 200:
81
  response = response.json()
 
 
82
  full_response = response[0]['choices'][0]['text']
83
- logger.debug(f"Raw TA response is {full_response}")
84
- # response, explanation = post_process_response(full_response)
85
- return full_convo, PROMPT, full_response
86
- except:
87
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
90
  """Extract Response from generated answer
 
60
 
61
  def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
62
  full_convo = memory.load_memory_variables({})[memory.memory_key]
63
+ PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
64
  logger.debug(f"Raw TA prompt is {PROMPT}")
65
  if PROMPT == "":
66
  full_response = get_default(question, make_explanation)
 
67
  return full_convo, PROMPT, full_response
68
 
 
69
  body_request = {
70
  "prompt": PROMPT,
71
  "temperature": 0,
72
+ "max_tokens": 3,
73
  }
74
 
75
  try:
 
77
  response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
78
  if response.status_code == 200:
79
  response = response.json()
80
+ else:
81
+ raise Exception(f"Error in response: {response.json()}")
82
  full_response = response[0]['choices'][0]['text']
83
+ if not make_explanation:
84
+ return full_convo, PROMPT, full_response
85
+ else:
86
+ extract_response, _ = post_process_response(full_response)
87
+ PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
88
+ PROMPT = PROMPT + f" {extract_response}"
89
+ logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
90
+ body_request["prompt"] = PROMPT
91
+ body_request["max_tokens"] = 128
92
+ response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
93
+ if response_expl.status_code == 200:
94
+ response_expl = response_expl.json()
95
+ else:
96
+ raise Exception(f"Error in response: {response_expl.json()}")
97
+ full_response_expl = f"{extract_response} {response_expl[0]['choices'][0]['text']}"
98
+ return full_convo, PROMPT, full_response_expl
99
+ except Exception as e:
100
+ logger.debug(f"Error in response: {e}")
101
+ st.switch_page("pages/model_loader.py")
102
 
103
  def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
104
  """Extract Response from generated answer
pages/convosim.py CHANGED
@@ -136,10 +136,10 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
136
  for bp in st.session_state['bp_prediction']:
137
  if bp["score"]:
138
  st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
139
- st.session_state.changed_bp = True
140
- sent_request_llm(llm_chain, prompt)
141
- else:
142
- sent_request_llm(llm_chain, prompt)
143
 
144
  with st.sidebar:
145
  st.divider()
 
136
  for bp in st.session_state['bp_prediction']:
137
  if bp["score"]:
138
  st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
139
+
140
+ sent_request_llm(llm_chain, prompt)
141
+ # else:
142
+ # sent_request_llm(llm_chain, prompt)
143
 
144
  with st.sidebar:
145
  st.divider()
pages/training_adherence.py CHANGED
@@ -6,6 +6,8 @@ from utils.app_utils import are_models_alive
6
  from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
7
  from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
8
 
 
 
9
  if "memory" not in st.session_state:
10
  st.switch_page("pages/convosim.py")
11
 
 
6
  from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
7
  from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
8
 
9
+ st.set_page_config(page_title="Conversation Simulator - Scoring")
10
+
11
  if "memory" not in st.session_state:
12
  st.switch_page("pages/convosim.py")
13