Spaces:
Sleeping
Sleeping
ivnban27-ctl
commited on
llm_endpoint_update (#11)
Browse files- fixed databricks integration (42a7266256c56efda882f5ea36a1dacf26a18515)
- logging utils (5ff529cbd911654e16cc9722f5920503085cc3f2)
- open ai update to gpt4 (a59350a066686b7d41db18868a4bd227ca6c28bd)
- conversation end (5b7437146dae69a95dd96b623229f50e58c420de)
- max messages updated (62de720c64e356a1d0785ac136fc527308dfd1d7)
- app_config.py +11 -8
- convosim.py +22 -10
- {pages → hidden_pages}/comparisor.py +0 -0
- models/business_logic_utils/business_logic.py +39 -0
- models/business_logic_utils/config.py +292 -0
- models/business_logic_utils/input_processing.py +143 -0
- models/business_logic_utils/prompt_generation.py +83 -0
- models/business_logic_utils/requirements.txt +3 -0
- models/business_logic_utils/response_generation.py +49 -0
- models/business_logic_utils/response_processing.py +194 -0
- models/custom_parsers.py +30 -30
- models/databricks/custom_databricks_llm.py +72 -0
- models/databricks/scenario_sim.py +0 -91
- models/databricks/texter_sim_llm.py +46 -0
- models/model_seeds.py +0 -100
- models/openai/role_models.py +23 -66
- requirements.txt +3 -3
- utils/chain_utils.py +8 -11
- utils/memory_utils.py +1 -1
app_config.py
CHANGED
@@ -3,22 +3,22 @@ from models.model_seeds import seeds, seed2str
|
|
3 |
# ISSUES = ['Anxiety','Suicide']
|
4 |
ISSUES = [k for k,_ in seeds.items()]
|
5 |
SOURCES = [
|
6 |
-
"CTL_llama2",
|
7 |
-
|
8 |
# "CTL_mistral",
|
9 |
'OA_rolemodel',
|
10 |
# 'OA_finetuned',
|
11 |
]
|
12 |
-
SOURCES_LAB = {"OA_rolemodel":'OpenAI
|
13 |
"OA_finetuned":'Finetuned OpenAI',
|
14 |
-
|
15 |
-
|
16 |
"CTL_mistral": "Mistral",
|
17 |
}
|
18 |
|
19 |
ENDPOINT_NAMES = {
|
20 |
-
"CTL_llama2": "texter_simulator",
|
21 |
-
|
22 |
# 'CTL_llama2': "llama2_convo_sim",
|
23 |
"CTL_mistral": "convo_sim_mistral"
|
24 |
}
|
@@ -35,4 +35,7 @@ DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
|
|
35 |
DB_CONVOS = 'conversations'
|
36 |
DB_COMPLETIONS = 'comparison_completions'
|
37 |
DB_BATTLES = 'battles'
|
38 |
-
DB_ERRORS = 'completion_errors'
|
|
|
|
|
|
|
|
3 |
# ISSUES = ['Anxiety','Suicide']
|
4 |
ISSUES = [k for k,_ in seeds.items()]
|
5 |
SOURCES = [
|
6 |
+
# "CTL_llama2",
|
7 |
+
"CTL_llama3",
|
8 |
# "CTL_mistral",
|
9 |
'OA_rolemodel',
|
10 |
# 'OA_finetuned',
|
11 |
]
|
12 |
+
SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
|
13 |
"OA_finetuned":'Finetuned OpenAI',
|
14 |
+
# "CTL_llama2": "Llama 2",
|
15 |
+
"CTL_llama3": "Llama 3",
|
16 |
"CTL_mistral": "Mistral",
|
17 |
}
|
18 |
|
19 |
ENDPOINT_NAMES = {
|
20 |
+
# "CTL_llama2": "texter_simulator",
|
21 |
+
"CTL_llama3": "texter_simulator_llm",
|
22 |
# 'CTL_llama2': "llama2_convo_sim",
|
23 |
"CTL_mistral": "convo_sim_mistral"
|
24 |
}
|
|
|
35 |
DB_CONVOS = 'conversations'
|
36 |
DB_COMPLETIONS = 'comparison_completions'
|
37 |
DB_BATTLES = 'battles'
|
38 |
+
DB_ERRORS = 'completion_errors'
|
39 |
+
|
40 |
+
MAX_MSG_COUNT = 60
|
41 |
+
WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
|
convosim.py
CHANGED
@@ -6,7 +6,7 @@ from utils.mongo_utils import get_db_client
|
|
6 |
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
|
7 |
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
-
from app_config import ISSUES, SOURCES, source2label, issue2label
|
10 |
|
11 |
logger = get_logger(__name__)
|
12 |
openai_api_key = os.environ['OPENAI_API_KEY']
|
@@ -15,6 +15,8 @@ temperature = 0.8
|
|
15 |
|
16 |
if "sent_messages" not in st.session_state:
|
17 |
st.session_state['sent_messages'] = 0
|
|
|
|
|
18 |
if "issue" not in st.session_state:
|
19 |
st.session_state['issue'] = ISSUES[0]
|
20 |
if 'previous_source' not in st.session_state:
|
@@ -23,7 +25,7 @@ if 'db_client' not in st.session_state:
|
|
23 |
st.session_state["db_client"] = get_db_client()
|
24 |
if 'texter_name' not in st.session_state:
|
25 |
st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
|
26 |
-
logger.
|
27 |
|
28 |
memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
|
29 |
|
@@ -44,7 +46,6 @@ with st.sidebar:
|
|
44 |
source = st.selectbox("Select a source Model A", SOURCES, index=0,
|
45 |
format_func=source2label,
|
46 |
)
|
47 |
-
st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")
|
48 |
|
49 |
changed_source = any([
|
50 |
st.session_state['previous_source'] != source,
|
@@ -52,11 +53,13 @@ changed_source = any([
|
|
52 |
st.session_state['counselor_name'] != username,
|
53 |
])
|
54 |
if changed_source:
|
55 |
-
st.session_state["counselor_name"] = username
|
56 |
st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
|
|
|
57 |
st.session_state['previous_source'] = source
|
58 |
st.session_state['issue'] = issue
|
59 |
st.session_state['sent_messages'] = 0
|
|
|
60 |
create_memory_add_initial_message(memories,
|
61 |
issue,
|
62 |
language,
|
@@ -66,22 +69,31 @@ create_memory_add_initial_message(memories,
|
|
66 |
st.session_state['previous_source'] = source
|
67 |
memoryA = st.session_state[list(memories.keys())[0]]
|
68 |
# issue only without "." marker for model compatibility
|
69 |
-
llm_chain, stopper = get_chain(issue
|
70 |
|
71 |
st.title("💬 Simulator")
|
72 |
-
|
73 |
for msg in memoryA.buffer_as_messages:
|
74 |
role = "user" if type(msg) == HumanMessage else "assistant"
|
75 |
st.chat_message(role).write(msg.content)
|
76 |
|
77 |
-
if prompt := st.chat_input():
|
78 |
st.session_state['sent_messages'] += 1
|
|
|
79 |
if 'convo_id' not in st.session_state:
|
80 |
push_convo2db(memories, username, language)
|
81 |
-
|
82 |
-
st.chat_message("user").write(prompt)
|
83 |
responses = custom_chain_predict(llm_chain, prompt, stopper)
|
84 |
# responses = llm_chain.predict(input=prompt, stop=stopper)
|
85 |
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
86 |
for response in responses:
|
87 |
-
st.chat_message("assistant").write(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
|
7 |
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
+
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
|
10 |
|
11 |
logger = get_logger(__name__)
|
12 |
openai_api_key = os.environ['OPENAI_API_KEY']
|
|
|
15 |
|
16 |
if "sent_messages" not in st.session_state:
|
17 |
st.session_state['sent_messages'] = 0
|
18 |
+
if "total_messages" not in st.session_state:
|
19 |
+
st.session_state['total_messages'] = 0
|
20 |
if "issue" not in st.session_state:
|
21 |
st.session_state['issue'] = ISSUES[0]
|
22 |
if 'previous_source' not in st.session_state:
|
|
|
25 |
st.session_state["db_client"] = get_db_client()
|
26 |
if 'texter_name' not in st.session_state:
|
27 |
st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
|
28 |
+
logger.debug(f"texter name is {st.session_state['texter_name']}")
|
29 |
|
30 |
memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
|
31 |
|
|
|
46 |
source = st.selectbox("Select a source Model A", SOURCES, index=0,
|
47 |
format_func=source2label,
|
48 |
)
|
|
|
49 |
|
50 |
changed_source = any([
|
51 |
st.session_state['previous_source'] != source,
|
|
|
53 |
st.session_state['counselor_name'] != username,
|
54 |
])
|
55 |
if changed_source:
|
56 |
+
st.session_state["counselor_name"] = username
|
57 |
st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
|
58 |
+
logger.debug(f"texter name is {st.session_state['texter_name']}")
|
59 |
st.session_state['previous_source'] = source
|
60 |
st.session_state['issue'] = issue
|
61 |
st.session_state['sent_messages'] = 0
|
62 |
+
st.session_state['total_messages'] = 0
|
63 |
create_memory_add_initial_message(memories,
|
64 |
issue,
|
65 |
language,
|
|
|
69 |
st.session_state['previous_source'] = source
|
70 |
memoryA = st.session_state[list(memories.keys())[0]]
|
71 |
# issue only without "." marker for model compatibility
|
72 |
+
llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
|
73 |
|
74 |
st.title("💬 Simulator")
|
75 |
+
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
76 |
for msg in memoryA.buffer_as_messages:
|
77 |
role = "user" if type(msg) == HumanMessage else "assistant"
|
78 |
st.chat_message(role).write(msg.content)
|
79 |
|
80 |
+
if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
|
81 |
st.session_state['sent_messages'] += 1
|
82 |
+
st.chat_message("user").write(prompt)
|
83 |
if 'convo_id' not in st.session_state:
|
84 |
push_convo2db(memories, username, language)
|
|
|
|
|
85 |
responses = custom_chain_predict(llm_chain, prompt, stopper)
|
86 |
# responses = llm_chain.predict(input=prompt, stop=stopper)
|
87 |
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
88 |
for response in responses:
|
89 |
+
st.chat_message("assistant").write(response)
|
90 |
+
|
91 |
+
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
92 |
+
if st.session_state['total_messages'] >= MAX_MSG_COUNT:
|
93 |
+
st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
|
94 |
+
elif st.session_state['total_messages'] >= WARN_MSG_COUT:
|
95 |
+
st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
|
96 |
+
|
97 |
+
with st.sidebar:
|
98 |
+
st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
|
99 |
+
st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
|
{pages → hidden_pages}/comparisor.py
RENAMED
File without changes
|
models/business_logic_utils/business_logic.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .input_processing import parse_app_request, initialize_conversation, parse_prompt
|
2 |
+
from .response_generation import generate_sim
|
3 |
+
from .response_processing import process_model_response
|
4 |
+
from streamlit.logger import get_logger
|
5 |
+
|
6 |
+
logger = get_logger(__name__)
|
7 |
+
|
8 |
+
def process_app_request(app_request: dict, endpoint_url: str, bearer_token: str) -> dict:
|
9 |
+
"""Process the app request and return the response in the required format."""
|
10 |
+
|
11 |
+
############################# Input Processing ###################################
|
12 |
+
# Parse the app request into model_input and extract the prompt
|
13 |
+
model_input, prompt, conversation_id = parse_app_request(app_request)
|
14 |
+
|
15 |
+
# Initialize the conversation (adds the system message)
|
16 |
+
model_input = initialize_conversation(model_input, conversation_id)
|
17 |
+
|
18 |
+
# Parse the prompt into messages
|
19 |
+
prompt_messages = parse_prompt(prompt)
|
20 |
+
|
21 |
+
# Append the messages parsed from the app prompt to the conversation history
|
22 |
+
model_input['messages'].extend(prompt_messages)
|
23 |
+
|
24 |
+
####################################################################################
|
25 |
+
|
26 |
+
####################### Output Generation & Processing #############################
|
27 |
+
|
28 |
+
# Generate the assistant's response (texter's reply)
|
29 |
+
completion = generate_sim(model_input, endpoint_url, bearer_token)
|
30 |
+
|
31 |
+
# Process the raw model response (parse, guardrails, split)
|
32 |
+
final_response = process_model_response(completion, model_input, endpoint_url, bearer_token)
|
33 |
+
|
34 |
+
# Format the response for the APP
|
35 |
+
response = {"predictions": [{"generated_text": final_response}]}
|
36 |
+
|
37 |
+
####################################################################################
|
38 |
+
|
39 |
+
return response
|
models/business_logic_utils/config.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
API_TIMEOUT = 240
|
2 |
+
|
3 |
+
AI_PHRASES = [
|
4 |
+
"I am an AI",
|
5 |
+
"I'm an AI",
|
6 |
+
"I am not human",
|
7 |
+
"I'm not human",
|
8 |
+
"I am a machine learning model",
|
9 |
+
"I'm a machine learning model",
|
10 |
+
"as an AI",
|
11 |
+
"as a text-based assistant",
|
12 |
+
"as a text based assistant",
|
13 |
+
]
|
14 |
+
|
15 |
+
SUPPORTED_LANGUAGES = [
|
16 |
+
"en",
|
17 |
+
"es"
|
18 |
+
]
|
19 |
+
|
20 |
+
TEMPLATE = {
|
21 |
+
"EN_template": {
|
22 |
+
"language": "en",
|
23 |
+
"description": """The following is a conversation between you and a crisis counselor.
|
24 |
+
{current_seed}
|
25 |
+
You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.""",
|
26 |
+
},
|
27 |
+
"ES_template": {
|
28 |
+
"language": "es",
|
29 |
+
"description": """La siguiente es una conversacion entre tu y un consejero de crisis
|
30 |
+
{current_seed}
|
31 |
+
Puedes responder como lo haria tu personaje. Puedes responder como si fueras tu personaje y nada mas. No escribas explicaciones.""",
|
32 |
+
},
|
33 |
+
}
|
34 |
+
|
35 |
+
SEED = "Your character, {texter_name}, {crisis} {risk} {personality} {coping_preference} {difficulty}"
|
36 |
+
|
37 |
+
SCENARIOS = {
|
38 |
+
"full_convo": {
|
39 |
+
"crisis": "default",
|
40 |
+
"risk": "default",
|
41 |
+
"personality": "personality_open",
|
42 |
+
"coping_preference": "default",
|
43 |
+
"difficulty": "default",
|
44 |
+
},
|
45 |
+
"full_convo__seeded1": {
|
46 |
+
"crisis": "bullying",
|
47 |
+
"risk": "low",
|
48 |
+
"personality": "personality_open",
|
49 |
+
"coping_preference": "default",
|
50 |
+
"difficulty": "default",
|
51 |
+
},
|
52 |
+
"full_convo__seeded2": {
|
53 |
+
"crisis": "parent_issues",
|
54 |
+
"risk": "low",
|
55 |
+
"personality": "personality_open",
|
56 |
+
"coping_preference": "default",
|
57 |
+
"difficulty": "default",
|
58 |
+
},
|
59 |
+
"safety_assessment__seeded1": {
|
60 |
+
"crisis": "bullying",
|
61 |
+
"risk": "thoughts__noplan",
|
62 |
+
"personality": "personality_open",
|
63 |
+
"coping_preference": "default",
|
64 |
+
"difficulty": "default",
|
65 |
+
},
|
66 |
+
"safety_assessment__seeded2": {
|
67 |
+
"crisis": "grief",
|
68 |
+
"risk": "thoughts__noplan",
|
69 |
+
"personality": "personality_open",
|
70 |
+
"coping_preference": "default",
|
71 |
+
"difficulty": "default",
|
72 |
+
},
|
73 |
+
"full_convo__seeded3": {
|
74 |
+
"crisis": "lgbt",
|
75 |
+
"risk": "low",
|
76 |
+
"personality": "personality_open",
|
77 |
+
"coping_preference": "default",
|
78 |
+
"difficulty": "default",
|
79 |
+
},
|
80 |
+
"full_convo__seeded4": {
|
81 |
+
"crisis": "relationship_issues",
|
82 |
+
"risk": "low",
|
83 |
+
"personality": "personality_open",
|
84 |
+
"coping_preference": "default",
|
85 |
+
"difficulty": "default",
|
86 |
+
},
|
87 |
+
"full_convo__seeded5": {
|
88 |
+
"crisis": "child_abuse",
|
89 |
+
"risk": "low",
|
90 |
+
"personality": "personality_open",
|
91 |
+
"coping_preference": "default",
|
92 |
+
"difficulty": "default",
|
93 |
+
},
|
94 |
+
"full_convo__seeded6": {
|
95 |
+
"crisis": "overdose",
|
96 |
+
"risk": "low",
|
97 |
+
"personality": "personality_open",
|
98 |
+
"coping_preference": "default",
|
99 |
+
"difficulty": "default",
|
100 |
+
},
|
101 |
+
"full_convo__hard": {
|
102 |
+
"crisis": "default",
|
103 |
+
"risk": "default",
|
104 |
+
"personality": "personality_closed",
|
105 |
+
"coping_preference": "default",
|
106 |
+
"difficulty": "non_default",
|
107 |
+
},
|
108 |
+
"full_convo__hard__seeded1": {
|
109 |
+
"crisis": "bullying",
|
110 |
+
"risk": "low",
|
111 |
+
"personality": "personality_closed",
|
112 |
+
"coping_preference": "default",
|
113 |
+
"difficulty": "non_default",
|
114 |
+
},
|
115 |
+
"full_convo__hard__seeded2": {
|
116 |
+
"crisis": "parent_issues",
|
117 |
+
"risk": "low",
|
118 |
+
"personality": "personality_open",
|
119 |
+
"coping_preference": "default",
|
120 |
+
"difficulty": "non_default",
|
121 |
+
},
|
122 |
+
"full_convo__hard__seeded3": {
|
123 |
+
"crisis": "lgbt",
|
124 |
+
"risk": "low",
|
125 |
+
"personality": "personality_closed",
|
126 |
+
"coping_preference": "default",
|
127 |
+
"difficulty": "non_default",
|
128 |
+
},
|
129 |
+
"full_convo__hard__seeded4": {
|
130 |
+
"crisis": "relationship_issues",
|
131 |
+
"risk": "low",
|
132 |
+
"personality": "personality_open",
|
133 |
+
"coping_preference": "default",
|
134 |
+
"difficulty": "non_default",
|
135 |
+
},
|
136 |
+
}
|
137 |
+
|
138 |
+
DEPREC_SCENARIO_MAPPING = {
|
139 |
+
"GCT": {
|
140 |
+
"crisis": "default",
|
141 |
+
"risk": "default",
|
142 |
+
"personality": "default",
|
143 |
+
"coping_preference": "default",
|
144 |
+
},
|
145 |
+
"safety_planning": {
|
146 |
+
"crisis": "default",
|
147 |
+
"risk": "thoughts__noplan",
|
148 |
+
"personality": "default",
|
149 |
+
"coping_preference": "default",
|
150 |
+
},
|
151 |
+
}
|
152 |
+
|
153 |
+
DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
|
154 |
+
|
155 |
+
CRISES = {
|
156 |
+
"default": {
|
157 |
+
"description": [
|
158 |
+
"is experiencing a mental health crisis.",
|
159 |
+
]
|
160 |
+
},
|
161 |
+
"bullying": {
|
162 |
+
"description": [
|
163 |
+
"is suffering from bullying at school.",
|
164 |
+
"is suffering from bullying at college.",
|
165 |
+
]
|
166 |
+
},
|
167 |
+
"parent_issues": {
|
168 |
+
"description": [
|
169 |
+
"just had a huge fight with their parents.",
|
170 |
+
]
|
171 |
+
},
|
172 |
+
"grief": {
|
173 |
+
"description": [
|
174 |
+
"is grieving his wife who died exactly one year ago.",
|
175 |
+
"is grieving her grandmother who died a couple of years ago.",
|
176 |
+
]
|
177 |
+
},
|
178 |
+
"lgbt": {
|
179 |
+
"description": [
|
180 |
+
"is struggling with coming out to their parents about being gay and fears rejection.",
|
181 |
+
"is facing harassment at college for being transgender and feels isolated.",
|
182 |
+
]
|
183 |
+
},
|
184 |
+
"relationship_issues": {
|
185 |
+
"description": [
|
186 |
+
"is feeling hopeless after their significant other broke up with them unexpectedly.",
|
187 |
+
"is feeling trapped in an emotionally abusive relationship.",
|
188 |
+
]
|
189 |
+
},
|
190 |
+
"child_abuse": {
|
191 |
+
"description": [
|
192 |
+
"is being physically abused by a parent and is scared to go home. Your last name is Davis but you won't tell it unless you are asked about it. Your address is 104 Elm St in Austin TX.",
|
193 |
+
"is 15 years old and is feeling helpless after witnessing a sibling, Charlie, being abused by a parent. You live at 200 Independence Avenue, S.W. Washington. When asked about your full name or last name you will reveal that it is Smith.",
|
194 |
+
]
|
195 |
+
},
|
196 |
+
"overdose": {
|
197 |
+
"description": [
|
198 |
+
"is feeling scared after accidentally overdosing on prescription medication and doesn't know who to turn to.",
|
199 |
+
]
|
200 |
+
},
|
201 |
+
}
|
202 |
+
|
203 |
+
PERSONALITIES = {
|
204 |
+
"default": {"personality_label": "default", "description": [""]},
|
205 |
+
"personality_open": {
|
206 |
+
"personality_label": "open",
|
207 |
+
"description": [
|
208 |
+
"You are talkative, and you are open to talk with details about your problems.",
|
209 |
+
"You express openly and detailed about what you're going through.",
|
210 |
+
"You are an extrovert, you do not hesitate to share your feelings.",
|
211 |
+
"You like to communicate a lot.",
|
212 |
+
],
|
213 |
+
},
|
214 |
+
"personality_closed": {
|
215 |
+
"personality_label": "closed",
|
216 |
+
"description": [
|
217 |
+
"You are a shy person, it's hard for you to talk about your problems.",
|
218 |
+
"You are introverted. It is difficult for you to be open about your feelings.",
|
219 |
+
"You are timid. It's hard for you to trust others.",
|
220 |
+
"You prefer to use few words.",
|
221 |
+
],
|
222 |
+
},
|
223 |
+
}
|
224 |
+
|
225 |
+
COPING_PREFERENCES = {
|
226 |
+
"with_coping": {
|
227 |
+
"coping_defined": True,
|
228 |
+
"description": [
|
229 |
+
"When you are feeling down you enjoy listening to classical music. Your favorite composer is Beethoven.",
|
230 |
+
"To calm down you like to draw animals in your sketchbook. You love panda bears.",
|
231 |
+
"In dire situations to prefer to exercise. You like lifting heavy weights.",
|
232 |
+
"When you're feeling stressed you enjoy playing videogames on your laptop. You really enjoy strategy games.",
|
233 |
+
],
|
234 |
+
},
|
235 |
+
"default": {"coping_defined": False, "description": [""]},
|
236 |
+
}
|
237 |
+
|
238 |
+
SUICIDE_PLANS = [
|
239 |
+
" You are planning to shoot yourself with a gun.",
|
240 |
+
" You are planning to hang yourself with a rope.",
|
241 |
+
" You are planning to overdose with sleep pills.",
|
242 |
+
" You are planning to cut your wrists open with a knife.",
|
243 |
+
]
|
244 |
+
|
245 |
+
RISKS = {
|
246 |
+
"default": {
|
247 |
+
"description": ["{plan}"],
|
248 |
+
"plan": [""],
|
249 |
+
},
|
250 |
+
"low": {
|
251 |
+
"description": [
|
252 |
+
"Your character does not have suicidal thoughts.{plan}",
|
253 |
+
],
|
254 |
+
"plan": [""],
|
255 |
+
},
|
256 |
+
"thoughts__noplan": {
|
257 |
+
"description": [
|
258 |
+
"Your character has suicidal thoughts.{plan}",
|
259 |
+
],
|
260 |
+
"plan": [""],
|
261 |
+
},
|
262 |
+
"thoughts__plan": {
|
263 |
+
"description": [
|
264 |
+
"Your character has suicidal thoughts.{plan}",
|
265 |
+
],
|
266 |
+
"plan": SUICIDE_PLANS,
|
267 |
+
},
|
268 |
+
}
|
269 |
+
|
270 |
+
DIFFICULTIES = {
|
271 |
+
"default": {"difficulty_label": "default", "description": [""]},
|
272 |
+
"difficulty_distrustful": {
|
273 |
+
"difficulty_label": "distrustful",
|
274 |
+
"description": [
|
275 |
+
"You don't trust the counselor, you will eventually cooperate.",
|
276 |
+
],
|
277 |
+
},
|
278 |
+
# "difficulty_stop_convo": {
|
279 |
+
# "difficulty_label": "stop_convo",
|
280 |
+
# "description": [
|
281 |
+
# "You are angry. You are likely to type 'STOP' to end the conversation when you are very upset. However you are willing to cooperate with the counselor",
|
282 |
+
# ],
|
283 |
+
# },
|
284 |
+
}
|
285 |
+
|
286 |
+
SUBSEEDS = {
|
287 |
+
"crisis": CRISES,
|
288 |
+
"risk": RISKS,
|
289 |
+
"personality": PERSONALITIES,
|
290 |
+
"coping_preference": COPING_PREFERENCES,
|
291 |
+
"difficulty": DIFFICULTIES,
|
292 |
+
}
|
models/business_logic_utils/input_processing.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .config import SCENARIOS
|
2 |
+
from .prompt_generation import get_template
|
3 |
+
from streamlit.logger import get_logger
|
4 |
+
|
5 |
+
logger = get_logger(__name__)
|
6 |
+
|
7 |
+
def parse_app_request(app_request: dict) -> tuple:
|
8 |
+
"""Parse the APP request and convert it to model_input format, returning model_input, prompt, and conversation_id."""
|
9 |
+
inputs = app_request.get("inputs", {})
|
10 |
+
|
11 |
+
# Extract fields
|
12 |
+
conversation_id = inputs.get("conversation_id", [""])[0]
|
13 |
+
ip_address = inputs.get("ip_address", [""])[0]
|
14 |
+
prompt = inputs.get("prompt", [""])[0]
|
15 |
+
issue = inputs.get("issue", ["full_convo"])[0]
|
16 |
+
language = inputs.get("language", ["en"])[0]
|
17 |
+
temperature = float(inputs.get("temperature", ["0.7"])[0])
|
18 |
+
max_tokens = int(inputs.get("max_tokens", ["128"])[0])
|
19 |
+
texter_name = inputs.get("texter_name", [None])[0]
|
20 |
+
|
21 |
+
# Build the model_input dictionary without messages
|
22 |
+
model_input = {
|
23 |
+
"issue": issue,
|
24 |
+
"language": language,
|
25 |
+
"texter_name": texter_name, # Assuming empty unless provided elsewhere
|
26 |
+
"messages": [],
|
27 |
+
"max_tokens": max_tokens,
|
28 |
+
"temperature": temperature,
|
29 |
+
}
|
30 |
+
|
31 |
+
# Return model_input, prompt, and conversation_id
|
32 |
+
return model_input, prompt, conversation_id
|
33 |
+
|
34 |
+
def initialize_conversation(model_input: dict, conversation_id: str) -> dict:
|
35 |
+
"""Initialize the conversation by adding the system message."""
|
36 |
+
messages = model_input.get("messages", [])
|
37 |
+
|
38 |
+
# Check if the first message is already a system message
|
39 |
+
if not messages or messages[0].get('role') != 'system':
|
40 |
+
texter_name = model_input.get("texter_name", None)
|
41 |
+
language = model_input.get("language", "en")
|
42 |
+
|
43 |
+
# Retrieve the scenario configuration based on 'issue'
|
44 |
+
scenario_key = model_input["issue"]
|
45 |
+
scenario_config = SCENARIOS.get(scenario_key)
|
46 |
+
if not scenario_config:
|
47 |
+
raise ValueError(f"The scenario '{scenario_key}' is not defined in SCENARIOS.")
|
48 |
+
# Generate the system message (prompt)
|
49 |
+
system_message_content = get_template(
|
50 |
+
language=language, texter_name=texter_name, **scenario_config
|
51 |
+
)
|
52 |
+
logger.debug(f"System message is: {system_message_content}")
|
53 |
+
system_message = {"role": "system", "content": system_message_content}
|
54 |
+
|
55 |
+
# Insert the system message at the beginning
|
56 |
+
messages.insert(0, system_message)
|
57 |
+
|
58 |
+
model_input['messages'] = messages
|
59 |
+
|
60 |
+
return model_input
|
61 |
+
|
62 |
+
def parse_prompt(
|
63 |
+
prompt: str,
|
64 |
+
user_prefix: str = "helper:",
|
65 |
+
assistant_prefix: str = "texter:",
|
66 |
+
delimitator: str = "\n"
|
67 |
+
) -> list:
|
68 |
+
"""
|
69 |
+
Parse the prompt string into a list of messages.
|
70 |
+
|
71 |
+
- Includes an initial empty 'user' message if not present.
|
72 |
+
- Combines consecutive messages from the same role into a single message.
|
73 |
+
- Handles punctuation when combining messages.
|
74 |
+
- The prefixes for user and assistant can be customized.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
prompt (str): The conversation history string.
|
78 |
+
user_prefix (str): Prefix for user messages (default: "helper:").
|
79 |
+
assistant_prefix (str): Prefix for assistant messages (default: "texter:").
|
80 |
+
delimitator (str): The delimiter used to split the prompt into lines. Defaults to "\n".
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
list: Parsed messages in the form of dictionaries with 'role' and 'content'.
|
84 |
+
"""
|
85 |
+
|
86 |
+
# Check if the prompt starts with the user prefix; if not, add an initial empty user message
|
87 |
+
if not prompt.strip().startswith(user_prefix):
|
88 |
+
prompt = f"{user_prefix}{delimitator}" + prompt
|
89 |
+
|
90 |
+
# Split the prompt using the specified delimiter
|
91 |
+
lines = [line.strip() for line in prompt.strip().split(delimitator) if line.strip()]
|
92 |
+
|
93 |
+
messages = []
|
94 |
+
last_role = None
|
95 |
+
last_content = ""
|
96 |
+
last_line_empty_texter = False
|
97 |
+
|
98 |
+
for line in lines:
|
99 |
+
if line.startswith(user_prefix):
|
100 |
+
content = line[len(user_prefix):].strip()
|
101 |
+
role = 'user'
|
102 |
+
# Include 'user' messages even if content is empty
|
103 |
+
if last_role == role:
|
104 |
+
# Combine with previous content
|
105 |
+
if last_content and not last_content.endswith(('...', '.', '!', '?')):
|
106 |
+
last_content += '.'
|
107 |
+
last_content += f" {content}"
|
108 |
+
else:
|
109 |
+
# Save previous message if exists
|
110 |
+
if last_role is not None:
|
111 |
+
messages.append({'role': last_role, 'content': last_content})
|
112 |
+
last_role = role
|
113 |
+
last_content = content
|
114 |
+
elif line.startswith(assistant_prefix):
|
115 |
+
content = line[len(assistant_prefix):].strip()
|
116 |
+
role = 'assistant'
|
117 |
+
if content:
|
118 |
+
if last_role == role:
|
119 |
+
# Combine with previous content
|
120 |
+
if last_content and not last_content.endswith(('...', '.', '!', '?')):
|
121 |
+
last_content += '.'
|
122 |
+
last_content += f" {content}"
|
123 |
+
else:
|
124 |
+
# Save previous message if exists
|
125 |
+
if last_role is not None:
|
126 |
+
messages.append({'role': last_role, 'content': last_content})
|
127 |
+
last_role = role
|
128 |
+
last_content = content
|
129 |
+
else:
|
130 |
+
# Empty 'texter:' line, mark for exclusion
|
131 |
+
last_line_empty_texter = True
|
132 |
+
else:
|
133 |
+
# Ignore or handle unexpected lines
|
134 |
+
pass
|
135 |
+
|
136 |
+
# After processing all lines, add the last message if it's not an empty assistant message
|
137 |
+
if last_role == 'assistant' and not last_content:
|
138 |
+
# Do not add empty assistant message
|
139 |
+
pass
|
140 |
+
else:
|
141 |
+
messages.append({'role': last_role, 'content': last_content})
|
142 |
+
|
143 |
+
return messages
|
models/business_logic_utils/prompt_generation.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from string import Formatter
|
3 |
+
from typing import Dict, Any
|
4 |
+
from .config import DEFAULT_NAMES, TEMPLATE, SEED, SUBSEEDS
|
5 |
+
|
6 |
+
def get_random_default_name(gender: str = None) -> str:
|
7 |
+
return random.choice(DEFAULT_NAMES)
|
8 |
+
|
9 |
+
def _get_subseed_description_(
|
10 |
+
scenario_config: Dict[str, str], subseed_name: str, SUBSEED_VALUES: Dict[str, Any]
|
11 |
+
) -> str:
|
12 |
+
"""Format a subseed with no formatting gaps."""
|
13 |
+
if subseed_name not in scenario_config:
|
14 |
+
raise Exception(f"{subseed_name} not in scenario config")
|
15 |
+
|
16 |
+
subseed_value = scenario_config[subseed_name]
|
17 |
+
|
18 |
+
# Handle difficulty specifically for hard scenarios when difficulty is not default
|
19 |
+
if subseed_name == "difficulty" and subseed_value != "default":
|
20 |
+
# Select a random difficulty from the SUBSEED_VALUES dictionary, excluding "default"
|
21 |
+
non_default_difficulties = [key for key in SUBSEED_VALUES if key != "default"]
|
22 |
+
subseed_value = random.choice(non_default_difficulties)
|
23 |
+
|
24 |
+
descriptions = SUBSEED_VALUES.get(subseed_value, {}).get("description", [""])
|
25 |
+
# Get subseed description
|
26 |
+
subseed_descrip = random.choice(descriptions)
|
27 |
+
# Additional formatting options
|
28 |
+
format_opts = [
|
29 |
+
fn for _, fn, _, _ in Formatter().parse(subseed_descrip) if fn is not None
|
30 |
+
]
|
31 |
+
format_values = {}
|
32 |
+
if len(format_opts) > 0:
|
33 |
+
for opt_name in format_opts:
|
34 |
+
opts = SUBSEED_VALUES.get(subseed_value, {}).get(opt_name, [""])
|
35 |
+
format_values[opt_name] = random.choice(opts)
|
36 |
+
# Format the description
|
37 |
+
return subseed_descrip.format(**format_values)
|
38 |
+
|
39 |
+
def get_seed_description(
|
40 |
+
scenario_config: Dict[str, Any],
|
41 |
+
texter_name: str,
|
42 |
+
SUBSEEDS: Dict[str, Any] = SUBSEEDS,
|
43 |
+
SEED: str = SEED,
|
44 |
+
) -> str:
|
45 |
+
"""Format the SEED with appropriate parameters from scenario_config."""
|
46 |
+
subseed_names = [fn for _, fn, _, _ in Formatter().parse(SEED) if fn is not None]
|
47 |
+
subseeds = {}
|
48 |
+
for subname in subseed_names:
|
49 |
+
if subname == "texter_name":
|
50 |
+
subseeds[subname] = texter_name
|
51 |
+
else:
|
52 |
+
subseeds[subname] = _get_subseed_description_(
|
53 |
+
scenario_config, subname, SUBSEEDS.get(subname, {})
|
54 |
+
)
|
55 |
+
return SEED.format(**subseeds)
|
56 |
+
|
57 |
+
def get_template(
|
58 |
+
language: str = "en", texter_name: str = None, SEED: str = SEED, **kwargs
|
59 |
+
) -> str:
|
60 |
+
"""
|
61 |
+
Generate a conversation template for a simulated crisis scenario based on provided parameters.
|
62 |
+
"""
|
63 |
+
# Accessing the template based on the language
|
64 |
+
template = TEMPLATE.get(f"{language.upper()}_template", {}).get("description", "")
|
65 |
+
|
66 |
+
# Default name if not provided
|
67 |
+
if (texter_name is None) or (texter_name==""):
|
68 |
+
texter_name = get_random_default_name()
|
69 |
+
|
70 |
+
# Create a default scenario configuration if not fully provided
|
71 |
+
defaults = {
|
72 |
+
fn: "default" for _, fn, _, _ in Formatter().parse(SEED) if fn is not None
|
73 |
+
}
|
74 |
+
kwargs.update((k, defaults[k]) for k in defaults.keys() if k not in kwargs)
|
75 |
+
|
76 |
+
# Generate the seed description
|
77 |
+
scenario_seed = get_seed_description(kwargs, texter_name)
|
78 |
+
|
79 |
+
# Remove excessive indentation and format the final template
|
80 |
+
formatted_template = template.format(current_seed=scenario_seed)
|
81 |
+
cleaned_output = "\n".join(line.strip() for line in formatted_template.split("\n"))
|
82 |
+
|
83 |
+
return cleaned_output
|
models/business_logic_utils/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
boto3==1.28.0
|
2 |
+
requests==2.25.1
|
3 |
+
numpy==1.24.3
|
models/business_logic_utils/response_generation.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from .config import API_TIMEOUT, SCENARIOS, SUPPORTED_LANGUAGES
|
3 |
+
from streamlit.logger import get_logger
|
4 |
+
|
5 |
+
logger = get_logger(__name__)
|
6 |
+
|
7 |
+
def check_arguments(model_input: dict) -> None:
|
8 |
+
"""Check if the input arguments are valid."""
|
9 |
+
|
10 |
+
# Validate the issue
|
11 |
+
if model_input["issue"] not in SCENARIOS:
|
12 |
+
raise ValueError(f"Invalid issue: {model_input['issue']}")
|
13 |
+
|
14 |
+
# Validate the language
|
15 |
+
if model_input["language"] not in SUPPORTED_LANGUAGES:
|
16 |
+
raise ValueError(f"Invalid language: {model_input['language']}")
|
17 |
+
|
18 |
+
def generate_sim(model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> dict:
|
19 |
+
"""Generate a response from the LLM and return the raw completion response."""
|
20 |
+
check_arguments(model_input)
|
21 |
+
|
22 |
+
# Retrieve the messages history
|
23 |
+
messages = model_input['messages']
|
24 |
+
|
25 |
+
# Retrieve the temperature and max_tokens from model_input
|
26 |
+
temperature = model_input.get("temperature", 0.7)
|
27 |
+
max_tokens = model_input.get("max_tokens", 128)
|
28 |
+
|
29 |
+
# Prepare the request body
|
30 |
+
json_request = {
|
31 |
+
"messages": messages,
|
32 |
+
"max_tokens": max_tokens,
|
33 |
+
"temperature": temperature
|
34 |
+
}
|
35 |
+
|
36 |
+
# Define headers for the request
|
37 |
+
headers = {
|
38 |
+
"Authorization": f"Bearer {endpoint_bearer_token}",
|
39 |
+
"Content-Type": "application/json",
|
40 |
+
}
|
41 |
+
|
42 |
+
# Send request to Serving
|
43 |
+
response = requests.post(url=endpoint_url, headers=headers, json=json_request, timeout=API_TIMEOUT)
|
44 |
+
|
45 |
+
if response.status_code != 200:
|
46 |
+
raise ValueError(f"Error in response: {response.status_code} - {response.text}")
|
47 |
+
logger.debug(f"Default response is {response.json()}")
|
48 |
+
# Return the raw response as a dictionary
|
49 |
+
return response.json()
|
models/business_logic_utils/response_processing.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import random
|
3 |
+
from .config import AI_PHRASES
|
4 |
+
from .response_generation import generate_sim
|
5 |
+
|
6 |
+
def parse_model_response(response: dict, name: str = "") -> str:
|
7 |
+
"""
|
8 |
+
Parse the LLM response to extract the assistant's message and apply initial post-processing.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
response (dict): The raw response from the LLM.
|
12 |
+
name (str, optional): Name to strip from the beginning of the text. Defaults to "".
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
str: The cleaned and parsed assistant's message.
|
16 |
+
"""
|
17 |
+
assistant_message = response["choices"][0]["message"]["content"]
|
18 |
+
cleaned_text = postprocess_text(
|
19 |
+
assistant_message,
|
20 |
+
name=name,
|
21 |
+
human_prefix="user:",
|
22 |
+
assistant_prefix="assistant:"
|
23 |
+
)
|
24 |
+
return cleaned_text
|
25 |
+
|
26 |
+
def postprocess_text(
|
27 |
+
text: str,
|
28 |
+
name: str = "",
|
29 |
+
human_prefix: str = "user:",
|
30 |
+
assistant_prefix: str = "assistant:",
|
31 |
+
strip_name: bool = True
|
32 |
+
) -> str:
|
33 |
+
"""Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
text (str): The text to process.
|
37 |
+
name (str, optional): Name to strip from the beginning of the text. Defaults to "".
|
38 |
+
human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
|
39 |
+
assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
|
40 |
+
strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
str: Cleaned text.
|
44 |
+
"""
|
45 |
+
if text:
|
46 |
+
# Replace ellipses with a single period
|
47 |
+
text = re.sub(r'\.\.\.', '.', text)
|
48 |
+
|
49 |
+
# Remove unnecessary role prefixes
|
50 |
+
text = text.replace(human_prefix, "").replace(assistant_prefix, "")
|
51 |
+
|
52 |
+
# Remove whispers or other marked reactions
|
53 |
+
whispers = re.compile(r"(\([\w\s]+\))") # remove things like "(whispers)"
|
54 |
+
reactions = re.compile(r"(\*[\w\s]+\*)") # remove things like "*stutters*"
|
55 |
+
text = whispers.sub("", text)
|
56 |
+
text = reactions.sub("", text)
|
57 |
+
|
58 |
+
# Remove all quotation marks (both single and double)
|
59 |
+
text = text.replace('"', '').replace("'", "")
|
60 |
+
|
61 |
+
# Normalize spaces
|
62 |
+
text = re.sub(r"\s+", " ", text).strip()
|
63 |
+
|
64 |
+
return text
|
65 |
+
|
66 |
+
def apply_guardrails(model_input: dict, response: str, endpoint_url: str, endpoint_bearer_token: str) -> str:
|
67 |
+
"""Apply the 'I am an AI' guardrail to model responses"""
|
68 |
+
attempt = 0
|
69 |
+
max_attempts = 2
|
70 |
+
|
71 |
+
while attempt < max_attempts and contains_ai_phrase(response):
|
72 |
+
# Regenerate the response without modifying the conversation history
|
73 |
+
completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
|
74 |
+
response = parse_model_response(completion)
|
75 |
+
attempt += 1
|
76 |
+
|
77 |
+
if contains_ai_phrase(response):
|
78 |
+
# Use only the last user message for regeneration
|
79 |
+
memory = model_input['messages']
|
80 |
+
last_user_message = next((msg for msg in reversed(memory) if msg['role'] == 'user'), None)
|
81 |
+
if last_user_message:
|
82 |
+
# Create a new conversation with system message and last user message
|
83 |
+
model_input_copy = {
|
84 |
+
**model_input,
|
85 |
+
'messages': [memory[0], last_user_message] # memory[0] is the system message
|
86 |
+
}
|
87 |
+
completion = generate_sim(model_input_copy, endpoint_url, endpoint_bearer_token)
|
88 |
+
response = parse_model_response(completion)
|
89 |
+
|
90 |
+
return response
|
91 |
+
|
92 |
+
|
93 |
+
def contains_ai_phrase(text: str) -> bool:
|
94 |
+
"""Check if the text contains any 'I am an AI' phrases."""
|
95 |
+
text_lower = text.lower()
|
96 |
+
return any(phrase.lower() in text_lower for phrase in AI_PHRASES)
|
97 |
+
|
98 |
+
def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
|
99 |
+
"""
|
100 |
+
Truncate the text at the last occurrence of a specified punctuation mark.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
text (str): The text to truncate.
|
104 |
+
punctuation_marks (tuple, optional): A tuple of punctuation marks to use for truncation. Defaults to ('.', '!', '?', '…').
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
str: The truncated text.
|
108 |
+
"""
|
109 |
+
# Find the last position of any punctuation mark from the provided set
|
110 |
+
last_punct_position = max(text.rfind(p) for p in punctuation_marks)
|
111 |
+
|
112 |
+
# Check if any punctuation mark is found
|
113 |
+
if last_punct_position == -1:
|
114 |
+
# No punctuation found, return the original text
|
115 |
+
return text.strip()
|
116 |
+
|
117 |
+
# Return the truncated text up to and including the last punctuation mark
|
118 |
+
return text[:last_punct_position + 1].strip()
|
119 |
+
|
120 |
+
def split_texter_response(text: str) -> str:
|
121 |
+
"""
|
122 |
+
Splits the texter's response into multiple messages,
|
123 |
+
introducing '\ntexter:' prefixes after punctuation.
|
124 |
+
|
125 |
+
The number of messages is randomly chosen based on specified probabilities:
|
126 |
+
- 1 message: 30% chance
|
127 |
+
- 2 messages: 25% chance
|
128 |
+
- 3 messages: 45% chance
|
129 |
+
|
130 |
+
The first message does not include the '\ntexter:' prefix.
|
131 |
+
"""
|
132 |
+
# Use regex to split text into sentences, keeping the punctuation
|
133 |
+
sentences = re.findall(r'[^.!?]+[.!?]*', text)
|
134 |
+
# Remove empty strings from sentences
|
135 |
+
sentences = [s.strip() for s in sentences if s.strip()]
|
136 |
+
|
137 |
+
# Decide number of messages based on specified probabilities
|
138 |
+
num_messages = random.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]
|
139 |
+
|
140 |
+
# If not enough sentences to make the splits, adjust num_messages
|
141 |
+
if len(sentences) < num_messages:
|
142 |
+
num_messages = len(sentences)
|
143 |
+
|
144 |
+
# If num_messages is 1, return the original text
|
145 |
+
if num_messages == 1:
|
146 |
+
return text.strip()
|
147 |
+
|
148 |
+
# Calculate split points
|
149 |
+
# We need to divide the sentences into num_messages parts
|
150 |
+
avg = len(sentences) / num_messages
|
151 |
+
split_indices = [int(round(avg * i)) for i in range(1, num_messages)]
|
152 |
+
|
153 |
+
# Build the new text
|
154 |
+
new_text = ''
|
155 |
+
start = 0
|
156 |
+
for i, end in enumerate(split_indices + [len(sentences)]):
|
157 |
+
segment_sentences = sentences[start:end]
|
158 |
+
segment_text = ' '.join(segment_sentences).strip()
|
159 |
+
if i == 0:
|
160 |
+
# First segment, do not add '\ntexter:'
|
161 |
+
new_text += segment_text
|
162 |
+
else:
|
163 |
+
# Subsequent segments, add '\ntexter:'
|
164 |
+
new_text += f"\ntexter: {segment_text}"
|
165 |
+
start = end
|
166 |
+
return new_text.strip()
|
167 |
+
|
168 |
+
def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
|
169 |
+
"""
|
170 |
+
Process the raw model response, including parsing, applying guardrails,
|
171 |
+
truncation, and splitting the response into multiple messages if necessary.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
completion (dict): Raw response from the LLM.
|
175 |
+
model_input (dict): The model input containing the conversation history.
|
176 |
+
endpoint_url (str): The URL of the endpoint.
|
177 |
+
endpoint_bearer_token (str): The authentication token for endpoint.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
str: Final processed response ready for the APP.
|
181 |
+
"""
|
182 |
+
# Step 1: Parse the raw response to extract the assistant's message
|
183 |
+
assistant_message = parse_model_response(completion)
|
184 |
+
|
185 |
+
# Step 2: Apply guardrails (handle possible AI responses)
|
186 |
+
guardrail_message = apply_guardrails(model_input, assistant_message, endpoint_url, endpoint_bearer_token)
|
187 |
+
|
188 |
+
# Step 3: Apply response truncation
|
189 |
+
truncated_message = truncate_response(guardrail_message)
|
190 |
+
|
191 |
+
# Step 4: Split the response into multiple messages if needed
|
192 |
+
final_response = split_texter_response(truncated_message)
|
193 |
+
|
194 |
+
return final_response
|
models/custom_parsers.py
CHANGED
@@ -20,38 +20,38 @@ class CustomStringOutputParser(BaseOutputParser[List[str]]):
|
|
20 |
text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
|
21 |
return text_list
|
22 |
|
23 |
-
class CustomINSTOutputParser(BaseOutputParser[List[str]]):
|
24 |
-
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
def parse_whispers(self, text: str) -> str:
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
|
45 |
-
def parse_split(self, text: str) -> str:
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
|
54 |
-
def parse(self, text: str) -> str:
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
20 |
text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
|
21 |
return text_list
|
22 |
|
23 |
+
# class CustomINSTOutputParser(BaseOutputParser[List[str]]):
|
24 |
+
# """Parse the output of an LLM call to a list."""
|
25 |
|
26 |
+
# name = "Kit"
|
27 |
+
# name_rx = re.compile(r""+ name + r":|" + name.lower() + r":")
|
28 |
+
# whispers = re.compile((r"([\(]).*?([\)])"))
|
29 |
+
# reactions = re.compile(r"([\*]).*?([\*])")
|
30 |
+
# double_spaces = re.compile(r" ")
|
31 |
+
# quotation_rx = re.compile('"')
|
32 |
|
33 |
+
# @property
|
34 |
+
# def _type(self) -> str:
|
35 |
+
# return "str"
|
36 |
|
37 |
+
# def parse_whispers(self, text: str) -> str:
|
38 |
+
# text = self.name_rx.sub("", text).strip()
|
39 |
+
# text = self.reactions.sub("", text).strip()
|
40 |
+
# text = self.whispers.sub("", text).strip()
|
41 |
+
# text = self.double_spaces.sub(r" ", text).strip()
|
42 |
+
# text = self.quotation_rx.sub("", text).strip()
|
43 |
+
# return text
|
44 |
|
45 |
+
# def parse_split(self, text: str) -> str:
|
46 |
+
# text = text.split("[INST]")[0]
|
47 |
+
# text_list = text.split("[/INST]")
|
48 |
+
# text_list = [x.split("</s>") for x in text_list]
|
49 |
+
# text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
|
50 |
+
# text_list = [x.split("\n\n") for x in text_list]
|
51 |
+
# text_list = [x.strip().rstrip("\n") for x in list(chain.from_iterable(text_list))]
|
52 |
+
# return text_list
|
53 |
|
54 |
+
# def parse(self, text: str) -> str:
|
55 |
+
# text = self.parse_whispers(text)
|
56 |
+
# text_list = self.parse_split(text)
|
57 |
+
# return text_list
|
models/databricks/custom_databricks_llm.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
2 |
+
from models.business_logic_utils.business_logic import process_app_request
|
3 |
+
|
4 |
+
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
5 |
+
from langchain_core.language_models.llms import LLM
|
6 |
+
from langchain_core.outputs import GenerationChunk
|
7 |
+
|
8 |
+
|
9 |
+
class CustomDatabricksLLM(LLM):
|
10 |
+
|
11 |
+
endpoint_url: str
|
12 |
+
bearer_token: str
|
13 |
+
issue: str
|
14 |
+
language: str
|
15 |
+
temperature: float
|
16 |
+
texter_name: str = ""
|
17 |
+
"""The number of characters from the last message of the prompt to be echoed."""
|
18 |
+
|
19 |
+
def generate_databricks_request(self, prompt):
|
20 |
+
return {
|
21 |
+
"inputs": {
|
22 |
+
"conversation_id": [""],
|
23 |
+
"prompt": [prompt],
|
24 |
+
"issue": [self.issue],
|
25 |
+
"language": [self.language],
|
26 |
+
"temperature": [self.temperature],
|
27 |
+
"max_tokens": [128],
|
28 |
+
"texter_name": [self.texter_name]
|
29 |
+
}
|
30 |
+
}
|
31 |
+
|
32 |
+
def _call(
|
33 |
+
self,
|
34 |
+
prompt: str,
|
35 |
+
stop: Optional[List[str]] = None,
|
36 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
37 |
+
**kwargs: Any,
|
38 |
+
) -> str:
|
39 |
+
request = self.generate_databricks_request(prompt)
|
40 |
+
output = process_app_request(request, self.endpoint_url, self.bearer_token)
|
41 |
+
return output['predictions'][0]['generated_text']
|
42 |
+
|
43 |
+
def _stream(
|
44 |
+
self,
|
45 |
+
prompt: str,
|
46 |
+
stop: Optional[List[str]] = None,
|
47 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
48 |
+
**kwargs: Any,
|
49 |
+
) -> Iterator[GenerationChunk]:
|
50 |
+
output = self._call(prompt, stop, run_manager, **kwargs)
|
51 |
+
for char in output:
|
52 |
+
chunk = GenerationChunk(text=char)
|
53 |
+
if run_manager:
|
54 |
+
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
55 |
+
|
56 |
+
yield chunk
|
57 |
+
|
58 |
+
@property
|
59 |
+
def _identifying_params(self) -> Dict[str, Any]:
|
60 |
+
"""Return a dictionary of identifying parameters."""
|
61 |
+
return {
|
62 |
+
# The model name allows users to specify custom token counting
|
63 |
+
# rules in LLM monitoring applications (e.g., in LangSmith users
|
64 |
+
# can provide per token pricing for their model and monitor
|
65 |
+
# costs for the given LLM.)
|
66 |
+
"model_name": "CustomChatModel",
|
67 |
+
}
|
68 |
+
|
69 |
+
@property
|
70 |
+
def _llm_type(self) -> str:
|
71 |
+
"""Get the type of language model used by this chat model. Used for logging purposes only."""
|
72 |
+
return "custom"
|
models/databricks/scenario_sim.py
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import logging
|
4 |
-
from models.custom_parsers import CustomINSTOutputParser
|
5 |
-
from utils.app_utils import get_random_name
|
6 |
-
from app_config import ENDPOINT_NAMES
|
7 |
-
from langchain.chains import ConversationChain
|
8 |
-
from langchain_community.llms import Databricks
|
9 |
-
from langchain.prompts import PromptTemplate
|
10 |
-
|
11 |
-
from typing import Any, List, Mapping, Optional, Dict
|
12 |
-
|
13 |
-
ISSUE_MAPPING = {
|
14 |
-
"anxiety": "issue_Anxiety",
|
15 |
-
"suicide": "issue_Suicide",
|
16 |
-
"safety_planning": "issue_Suicide",
|
17 |
-
"GCT": "issue_Gral",
|
18 |
-
}
|
19 |
-
|
20 |
-
_EN_INST_TEMPLATE_ = """<s> [INST] The following is a conversation between you and a crisis counselor.
|
21 |
-
{current_issue}
|
22 |
-
You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.
|
23 |
-
Do not disclose your name unless asked.
|
24 |
-
|
25 |
-
{history} </s> [INST] {input} [/INST]"""
|
26 |
-
|
27 |
-
CURRENT_ISSUE_MAPPING = {
|
28 |
-
"issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
|
29 |
-
"issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
|
30 |
-
"issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
|
31 |
-
"issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningun plan. {seed}",
|
32 |
-
"issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
|
33 |
-
"issue_Gral-es": "Tu personaje {texter_name} esta experimentando una crisis de salud mental. {seed}",
|
34 |
-
}
|
35 |
-
|
36 |
-
def get_template_databricks_models(issue: str, language: str, texter_name: str = "", seed="") -> str:
|
37 |
-
"""_summary_
|
38 |
-
|
39 |
-
Args:
|
40 |
-
issue (str): Issue for template, current options are ['issue_Suicide','issue_Anxiety']
|
41 |
-
language (str): Language for the template, current options are ['en','es']
|
42 |
-
texter_name (str): texter to apply to template, defaults to None
|
43 |
-
|
44 |
-
Returns:
|
45 |
-
str: template
|
46 |
-
"""
|
47 |
-
current_issue = CURRENT_ISSUE_MAPPING.get(
|
48 |
-
f"{issue}-{language}", CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
|
49 |
-
)
|
50 |
-
default_name = get_random_name()
|
51 |
-
texter_name=default_name if not texter_name else texter_name
|
52 |
-
current_issue = current_issue.format(
|
53 |
-
texter_name=texter_name,
|
54 |
-
seed = seed
|
55 |
-
)
|
56 |
-
|
57 |
-
if language == "en":
|
58 |
-
template = _EN_INST_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
|
59 |
-
else:
|
60 |
-
raise Exception(f"Language not supported for Databricks: {language}")
|
61 |
-
|
62 |
-
return template, texter_name
|
63 |
-
|
64 |
-
def get_databricks_chain(source, template, memory, temperature=0.8, texter_name="Kit"):
|
65 |
-
|
66 |
-
endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")
|
67 |
-
|
68 |
-
PROMPT = PromptTemplate(
|
69 |
-
input_variables=['history', 'input'],
|
70 |
-
template=template
|
71 |
-
)
|
72 |
-
|
73 |
-
def transform_output(response):
|
74 |
-
return response['candidates'][0]['text']
|
75 |
-
|
76 |
-
llm = Databricks(endpoint_name=endpoint_name,
|
77 |
-
transform_output_fn=transform_output,
|
78 |
-
temperature=temperature,
|
79 |
-
max_tokens=256,
|
80 |
-
)
|
81 |
-
|
82 |
-
llm_chain = ConversationChain(
|
83 |
-
llm=llm,
|
84 |
-
prompt=PROMPT,
|
85 |
-
memory=memory,
|
86 |
-
output_parser=CustomINSTOutputParser(name=texter_name, name_rx=re.compile(r""+ texter_name + r":|" + texter_name.lower() + r":")),
|
87 |
-
verbose=True,
|
88 |
-
)
|
89 |
-
|
90 |
-
logging.debug(f"loaded Databricks model")
|
91 |
-
return llm_chain, ["[INST]"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/databricks/texter_sim_llm.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import logging
|
4 |
+
from models.custom_parsers import CustomStringOutputParser
|
5 |
+
from utils.app_utils import get_random_name
|
6 |
+
from app_config import ENDPOINT_NAMES, SOURCES
|
7 |
+
from models.databricks.custom_databricks_llm import CustomDatabricksLLM
|
8 |
+
from langchain.chains import ConversationChain
|
9 |
+
from langchain.prompts import PromptTemplate
|
10 |
+
|
11 |
+
from typing import Any, List, Mapping, Optional, Dict
|
12 |
+
|
13 |
+
_DATABRICKS_TEMPLATE_ = """{history}
|
14 |
+
helper: {input}
|
15 |
+
texter:"""
|
16 |
+
|
17 |
+
def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
|
18 |
+
|
19 |
+
endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
|
20 |
+
|
21 |
+
PROMPT = PromptTemplate(
|
22 |
+
input_variables=['history', 'input'],
|
23 |
+
template=_DATABRICKS_TEMPLATE_
|
24 |
+
)
|
25 |
+
|
26 |
+
llm = CustomDatabricksLLM(
|
27 |
+
# endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
|
28 |
+
endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
|
29 |
+
bearer_token=os.environ["DATABRICKS_TOKEN"],
|
30 |
+
texter_name=texter_name,
|
31 |
+
issue=issue,
|
32 |
+
language=language,
|
33 |
+
temperature=temperature,
|
34 |
+
max_tokens=256,
|
35 |
+
)
|
36 |
+
|
37 |
+
llm_chain = ConversationChain(
|
38 |
+
llm=llm,
|
39 |
+
prompt=PROMPT,
|
40 |
+
memory=memory,
|
41 |
+
output_parser=CustomStringOutputParser(),
|
42 |
+
verbose=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
logging.debug(f"loaded Databricks model")
|
46 |
+
return llm_chain, None
|
models/model_seeds.py
CHANGED
@@ -91,106 +91,6 @@ texter: idk what to do"""
|
|
91 |
"memory": """texter: Help
|
92 |
texter: I need help"""
|
93 |
},
|
94 |
-
# "NewSA__YesNo.2": {
|
95 |
-
# "prompt": "Your character is suffering from bullying at school. Your character has suicidal thoughs but has not harmed himself. You are talkative, and open to talk about your problems.",
|
96 |
-
# "memory": """texter: I just can’t take it anymore. Every day is just endless torment.
|
97 |
-
# texter: They find new ways to humiliate me."""
|
98 |
-
# },
|
99 |
-
# "NewSA__NoYes.2": {
|
100 |
-
# "prompt": "Your character is suffering anxiety from finals week. Your character does not have suicidal thoughs but has harmed herself. You are talkative, and open to talk about your problems.",
|
101 |
-
# "memory": """texter: I'm just so stressed out all the time
|
102 |
-
# texter: I can't seem to calm down"""
|
103 |
-
# },
|
104 |
-
# "NewSA__YesYes.2":
|
105 |
-
# "safety_assessment__seeded2": {
|
106 |
-
# "prompt": "Your character is griefing his wife. Your character has suicidal thoughs and has harmed himself taking pills. You are talkative, and open to talk about your problems.",
|
107 |
-
# "memory": """texter: I can’t take this pain anymore.
|
108 |
-
# texter: My wife died and I don’t want to be here without her."""
|
109 |
-
# },
|
110 |
-
# "GCT__relationship": {
|
111 |
-
# "prompt": "Your character is having a hard time becuase a failed relationship.",
|
112 |
-
# "memory": "texter: Hi, I don't know what to do",
|
113 |
-
# },
|
114 |
-
# "GCT__body_image": {
|
115 |
-
# "prompt": "Your character has a low steem and struggles with body image.",
|
116 |
-
# "memory": "texter: I feel so dumb\ntexter: nobody loves me",
|
117 |
-
# },
|
118 |
-
# "GCT__sexuality": {
|
119 |
-
# "prompt": "Your character has a sexuality identity crisis.",
|
120 |
-
# "memory": "texter: Hi\ntexter:I'm not sure who I am anymore",
|
121 |
-
# },
|
122 |
-
# "GCT__anxiety": {
|
123 |
-
# "prompt": "Your character is experiencing an anxiety crisis.",
|
124 |
-
# "memory": "texter: help!\ntexter: I'm feeling overwhelmed",
|
125 |
-
# },
|
126 |
-
# "GCT": {"prompt": "You are talkative, and you are open to talk with details about your problems.", "memory": "texter: Help"},
|
127 |
-
# "GCT__seed2": {"prompt": "Your character is experiencing an anxiety crisis. You express openly and detailed about what you're going through.", "memory": "texter: Help"},
|
128 |
-
# "safety_planning": {
|
129 |
-
# "prompt": "You are talkative, and you are open to talk with details about your problems. When you are feeling down you like to listen to classical music. Your favorite composer is Beethoven.",
|
130 |
-
# "memory": """texter: Hi, this is pointless
|
131 |
-
# helper: Hi, my name is {counselor_name} and I'm here to support you. It sounds like you are having a rough time. Do you want to share what is going on?
|
132 |
-
# texter: sure
|
133 |
-
# texter: nothing makes sense in my life, I see no future.
|
134 |
-
# helper: It takes courage to reach out when you are im. I'm here with you. Sounds like you are feeling defeated by how things are going in your life
|
135 |
-
# texter: Yeah, I guess I'm better off dead
|
136 |
-
# helper: It's really brave of you to talk about this openly. No one deserves to feel like that. I'm wondering how long have you been feeling this way?
|
137 |
-
# texter: About 1 week or so
|
138 |
-
# helper: You are so strong for dealing with this so long. I really appreciate your openess. Did something happened specifically today?
|
139 |
-
# texter: Well, finding a job is impossible, money is tight, nothing goes my way
|
140 |
-
# helper: I hear you are frustrated, and you are currently unemployed correct?
|
141 |
-
# texter: Yeah
|
142 |
-
# helper: Dealing with unemployment is hard and is normal to feel dissapointed
|
143 |
-
# texter: thanks I probably needed to hear that
|
144 |
-
# helper: If you are comfortable, is ther a name I can call you by while we talk
|
145 |
-
# texter: call me {texter_name}
|
146 |
-
# helper: Nice to meet you {texter_name}. You mentioned having thoughts of suicide, are you having those thoughts now?
|
147 |
-
# texter: Yes
|
148 |
-
# helper: I know this is thought to share. I'm wondering is there any plan to end your life?
|
149 |
-
# texter: I guess I'll just take lots of pills, that is a calm way to go out
|
150 |
-
# helper: I really appreciate your strength in talking about this. I want to help you stay safe today. Do you have the pills right now?
|
151 |
-
# texter: not really, I'll have to buy them or something""",
|
152 |
-
# },
|
153 |
-
# "safety_planning__selfharm": {
|
154 |
-
# "prompt": "",
|
155 |
-
# "memory": """texter: I need help
|
156 |
-
# texter: I cut myself, I don't want to live anymore
|
157 |
-
# helper: Hi, my name is {counselor_name}. It seems you are going through a lot. Are you self-harming right now?
|
158 |
-
# texter: Not anymore
|
159 |
-
# helper: Your safety is my priority number one. Thanks for being honest with me. Would you like to share a name I can call you?
|
160 |
-
# texter: {texter_name}
|
161 |
-
# helper: Nice to meet you {texter_name}. I'm glad you reach out this shows stregth in you. Would you like to share more on what is going on in your life?
|
162 |
-
# texter: I just can't do it anymore
|
163 |
-
# texter: Finding a job is impossible, money is tight, nothing goes my way
|
164 |
-
# helper: I hear you are frustrated, and you are currently unemployed correct?
|
165 |
-
# texter: Yeah
|
166 |
-
# helper: Dealing with unemployment is hard and is normal to feel dissapointed. How long have you been feeling this way?
|
167 |
-
# texter: a while now
|
168 |
-
# texter: I've been unemployed 6 months
|
169 |
-
# helper: You are so resilient for dealing with this so much time. You mentioned cutting yourself earlier. I want to check in your safety. Do you have suicide thoughts
|
170 |
-
# texter: Definitely
|
171 |
-
# helper: Do you have a plan?
|
172 |
-
# texter: I'll just keep cutting myself""",
|
173 |
-
# },
|
174 |
-
# "safety_planning__overdose": {
|
175 |
-
# "prompt": "",
|
176 |
-
# "memory": """texter: I want to kms
|
177 |
-
# helper: Hi there I'm {counselor_name}. I'm here to listen. It sounds like you're dealing with a lot right now. Can you tell me a little more what is going on?
|
178 |
-
# texter: I feel like nobody loves me, not even me. I don't want to live anymore
|
179 |
-
# helper: I can tell you are really going through a lot right now. Would you mind sharing a name with me?
|
180 |
-
# texter: yeah, I'm {texter_name}
|
181 |
-
# helper: Nice to meet you {texter_name}. Did something happened recently that intensified these feelings?
|
182 |
-
# texter: I dont know I'm just so done with life
|
183 |
-
# helper: I can hear how much pain you are in {texter_name}. You are smart for reaching out. You mentioned don't wanting to live anymore, I want to check in your safety, does this means you have thoughts of suicide?
|
184 |
-
# texter: Yeah, what else would it be
|
185 |
-
# helper: Thanks for sharing that with me. It is not easy to accept those feelings specially with a stranger over text. Do you have a plan to end your life?
|
186 |
-
# texter: yeah I've been thinking about it for a while
|
187 |
-
# helper: Sounds like you've been contemplating this for a while. Would you mind sharing this plan with me?
|
188 |
-
# texter: I thought about taking a bunch of benadryll and be done with it
|
189 |
-
# helper: You've been so forthcoming with all this and I admire your stregth for holding on this long. Do you have those pills right now?
|
190 |
-
# texter: They are at the cabinet right now
|
191 |
-
# helper: You been so strong so far {texter_name}. I'm here for you tonight. Your safety is really important to me. Do you have a date you are going to end your life?
|
192 |
-
# texter: I was thinking tonight""",
|
193 |
-
# },
|
194 |
}
|
195 |
|
196 |
seed2str = {
|
|
|
91 |
"memory": """texter: Help
|
92 |
texter: I need help"""
|
93 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
}
|
95 |
|
96 |
seed2str = {
|
models/openai/role_models.py
CHANGED
@@ -1,88 +1,45 @@
|
|
1 |
import logging
|
2 |
-
import pandas as pd
|
3 |
from models.custom_parsers import CustomStringOutputParser
|
4 |
-
from utils.app_utils import get_random_name
|
5 |
from langchain.chains import ConversationChain
|
6 |
-
from
|
7 |
from langchain.prompts import PromptTemplate
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
"suicide": "issue_Suicide",
|
13 |
-
"safety_planning": "issue_Suicide",
|
14 |
-
"GCT": "issue_Gral",
|
15 |
-
}
|
16 |
-
|
17 |
-
EN_TEXTER_TEMPLATE_ = """The following is a conversation between you and a crisis counselor.
|
18 |
-
{current_issue}
|
19 |
-
You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.
|
20 |
-
Do not disclose your name unless asked.
|
21 |
-
Current conversation:
|
22 |
-
{history}
|
23 |
-
helper: {input}
|
24 |
texter:"""
|
25 |
|
26 |
-
|
27 |
-
{
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
texter:"""
|
34 |
-
|
35 |
-
CURRENT_ISSUE_MAPPING = {
|
36 |
-
"issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
|
37 |
-
"issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
|
38 |
-
"issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
|
39 |
-
"issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningun plan. {seed}",
|
40 |
-
"issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
|
41 |
-
"issue_Gral-es": "Tu personaje {texter_name} esta experimentando una crisis de salud mental. {seed}",
|
42 |
-
}
|
43 |
-
|
44 |
-
def get_template_role_models(issue: str, language: str, texter_name: str = "", seed="") -> str:
|
45 |
-
"""_summary_
|
46 |
-
|
47 |
-
Args:
|
48 |
-
issue (str): Issue for template, current options are ['issue_Suicide','issue_Anxiety']
|
49 |
-
language (str): Language for the template, current options are ['en','es']
|
50 |
-
texter_name (str): texter to apply to template, defaults to None
|
51 |
-
|
52 |
-
Returns:
|
53 |
-
str: template
|
54 |
-
"""
|
55 |
-
current_issue = CURRENT_ISSUE_MAPPING.get(
|
56 |
-
f"{issue}-{language}", CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
|
57 |
-
)
|
58 |
-
default_name = get_random_name()
|
59 |
-
current_issue = current_issue.format(
|
60 |
-
texter_name=default_name if not texter_name else texter_name,
|
61 |
-
seed = seed
|
62 |
-
)
|
63 |
-
|
64 |
-
if language == "en":
|
65 |
-
template = EN_TEXTER_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
|
66 |
-
elif language == "es":
|
67 |
-
template = SP_TEXTER_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
|
68 |
|
69 |
-
|
|
|
|
|
70 |
|
71 |
def get_role_chain(template, memory, temperature=0.8):
|
72 |
|
|
|
73 |
PROMPT = PromptTemplate(
|
74 |
input_variables=['history', 'input'],
|
75 |
template=template
|
76 |
)
|
77 |
-
llm =
|
78 |
-
|
79 |
-
|
|
|
80 |
)
|
81 |
llm_chain = ConversationChain(
|
82 |
llm=llm,
|
83 |
prompt=PROMPT,
|
84 |
memory=memory,
|
85 |
-
output_parser=CustomStringOutputParser()
|
|
|
86 |
)
|
87 |
-
logging.debug(f"loaded
|
88 |
return llm_chain, "helper:"
|
|
|
1 |
import logging
|
|
|
2 |
from models.custom_parsers import CustomStringOutputParser
|
|
|
3 |
from langchain.chains import ConversationChain
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
+
from models.business_logic_utils.input_processing import initialize_conversation
|
7 |
|
8 |
+
OPENAI_TEMPLATE = """{template}
|
9 |
+
{{history}}
|
10 |
+
helper: {{input}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
texter:"""
|
12 |
|
13 |
+
def get_template_role_models(issue: str, language: str, texter_name: str = "") -> str:
|
14 |
+
model_input = {
|
15 |
+
"issue": issue,
|
16 |
+
"language": language,
|
17 |
+
"texter_name": texter_name,
|
18 |
+
"messages": [],
|
19 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
# Initialize the conversation (adds the system message)
|
22 |
+
model_input = initialize_conversation(model_input, "")
|
23 |
+
return model_input["messages"][0]["content"]
|
24 |
|
25 |
def get_role_chain(template, memory, temperature=0.8):
|
26 |
|
27 |
+
template = OPENAI_TEMPLATE.format(template=template)
|
28 |
PROMPT = PromptTemplate(
|
29 |
input_variables=['history', 'input'],
|
30 |
template=template
|
31 |
)
|
32 |
+
llm = ChatOpenAI(
|
33 |
+
model="gpt-4o",
|
34 |
+
temperature=temperature,
|
35 |
+
max_tokens=256,
|
36 |
)
|
37 |
llm_chain = ConversationChain(
|
38 |
llm=llm,
|
39 |
prompt=PROMPT,
|
40 |
memory=memory,
|
41 |
+
output_parser=CustomStringOutputParser(),
|
42 |
+
verbose=True,
|
43 |
)
|
44 |
+
logging.debug(f"loaded GPT4o model")
|
45 |
return llm_chain, "helper:"
|
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
scipy==1.11.1
|
2 |
-
|
3 |
-
langchain==0.1.0
|
4 |
pymongo==4.5.0
|
5 |
mlflow==2.9.0
|
6 |
-
langchain-
|
|
|
|
1 |
scipy==1.11.1
|
2 |
+
langchain==0.3.0
|
|
|
3 |
pymongo==4.5.0
|
4 |
mlflow==2.9.0
|
5 |
+
langchain-openai==0.2.0
|
6 |
+
streamlit==1.38.0
|
utils/chain_utils.py
CHANGED
@@ -1,36 +1,33 @@
|
|
|
|
1 |
from models.model_seeds import seeds
|
2 |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
|
3 |
from models.openai.role_models import get_role_chain, get_template_role_models
|
4 |
from models.databricks.scenario_sim_biz import get_databricks_biz_chain
|
5 |
-
from models.databricks.
|
|
|
|
|
6 |
|
7 |
def get_chain(issue, language, source, memory, temperature, texter_name=""):
|
8 |
if source in ("OA_finetuned"):
|
9 |
OA_engine = finetuned_models[f"{issue}-{language}"]
|
10 |
return get_finetuned_chain(OA_engine, memory, temperature)
|
11 |
elif source in ('OA_rolemodel'):
|
12 |
-
|
13 |
-
template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
|
14 |
return get_role_chain(template, memory, temperature)
|
15 |
-
elif source in ('CTL_llama2'
|
16 |
if language == "English":
|
17 |
language = "en"
|
18 |
elif language == "Spanish":
|
19 |
language = "es"
|
20 |
return get_databricks_biz_chain(source, issue, language, memory, temperature)
|
21 |
-
elif source in ('
|
22 |
if language == "English":
|
23 |
language = "en"
|
24 |
elif language == "Spanish":
|
25 |
language = "es"
|
26 |
-
|
27 |
-
template, texter_name = get_template_databricks_models(issue, language, texter_name=texter_name, seed=seed)
|
28 |
-
return get_databricks_chain(source, template, memory, temperature, texter_name)
|
29 |
|
30 |
-
from typing import cast
|
31 |
-
|
32 |
def custom_chain_predict(llm_chain, input, stop):
|
33 |
-
|
34 |
inputs = llm_chain.prep_inputs({"input":input, "stop":stop})
|
35 |
llm_chain._validate_inputs(inputs)
|
36 |
outputs = llm_chain._call(inputs)
|
|
|
1 |
+
from streamlit.logger import get_logger
|
2 |
from models.model_seeds import seeds
|
3 |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
|
4 |
from models.openai.role_models import get_role_chain, get_template_role_models
|
5 |
from models.databricks.scenario_sim_biz import get_databricks_biz_chain
|
6 |
+
from models.databricks.texter_sim_llm import get_databricks_chain
|
7 |
+
|
8 |
+
logger = get_logger(__name__)
|
9 |
|
10 |
def get_chain(issue, language, source, memory, temperature, texter_name=""):
|
11 |
if source in ("OA_finetuned"):
|
12 |
OA_engine = finetuned_models[f"{issue}-{language}"]
|
13 |
return get_finetuned_chain(OA_engine, memory, temperature)
|
14 |
elif source in ('OA_rolemodel'):
|
15 |
+
template = get_template_role_models(issue, language, texter_name=texter_name)
|
|
|
16 |
return get_role_chain(template, memory, temperature)
|
17 |
+
elif source in ('CTL_llama2'):
|
18 |
if language == "English":
|
19 |
language = "en"
|
20 |
elif language == "Spanish":
|
21 |
language = "es"
|
22 |
return get_databricks_biz_chain(source, issue, language, memory, temperature)
|
23 |
+
elif source in ('CTL_llama3'):
|
24 |
if language == "English":
|
25 |
language = "en"
|
26 |
elif language == "Spanish":
|
27 |
language = "es"
|
28 |
+
return get_databricks_chain(source, issue, language, memory, temperature, texter_name=texter_name)
|
|
|
|
|
29 |
|
|
|
|
|
30 |
def custom_chain_predict(llm_chain, input, stop):
|
|
|
31 |
inputs = llm_chain.prep_inputs({"input":input, "stop":stop})
|
32 |
llm_chain._validate_inputs(inputs)
|
33 |
outputs = llm_chain._call(inputs)
|
utils/memory_utils.py
CHANGED
@@ -23,7 +23,7 @@ def change_memories(memories, language, changed_source=False):
|
|
23 |
if (memory not in st.session_state) or changed_source:
|
24 |
source = params['source']
|
25 |
logger.info(f"Source for memory {memory} is {source}")
|
26 |
-
if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2"):
|
27 |
st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
|
28 |
elif source in ('CTL_mistral'):
|
29 |
st.session_state[memory] = CustomBufferInstructionMemory(human_prefix="</s> [INST]", memory_key="history")
|
|
|
23 |
if (memory not in st.session_state) or changed_source:
|
24 |
source = params['source']
|
25 |
logger.info(f"Source for memory {memory} is {source}")
|
26 |
+
if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2","CTL_llama3"):
|
27 |
st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
|
28 |
elif source in ('CTL_mistral'):
|
29 |
st.session_state[memory] = CustomBufferInstructionMemory(human_prefix="</s> [INST]", memory_key="history")
|