Spaces:
Sleeping
Sleeping
ivnban27-ctl
commited on
Commit
·
975a927
1
Parent(s):
39d66f7
Added MongoDB functionality (#1)
Browse files- Added MongoDB functionality (0f381ca015dc186af8f2478cf9029f86247ab469)
- Delete utils.py (655ffb4bff2e04c9a9e22032907957fed331d467)
- .gitignore +1 -1
- app_config.py +15 -0
- utils.py → app_utils.py +43 -11
- convosim.py +28 -13
- models/openai/finetuned_models.py +6 -5
- models/openai/role_models.py +1 -2
- mongo_utils.py +143 -0
- pages/comparisor.py +97 -33
- pages/manual_comparisor.py +108 -0
- requirements.txt +2 -1
.gitignore
CHANGED
@@ -177,7 +177,7 @@ cython_debug/
|
|
177 |
|
178 |
# Jupyter NB Checkpoints
|
179 |
.ipynb_checkpoints/
|
180 |
-
|
181 |
# exclude data from source control by default
|
182 |
/data/
|
183 |
|
|
|
177 |
|
178 |
# Jupyter NB Checkpoints
|
179 |
.ipynb_checkpoints/
|
180 |
+
*.ipynb
|
181 |
# exclude data from source control by default
|
182 |
/data/
|
183 |
|
app_config.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ISSUES = ['Anxiety','Suicide']
|
2 |
+
SOURCES = ['OA_rolemodel', 'OA_finetuned']
|
3 |
+
SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
|
4 |
+
"OA_finetuned":'Finetuned OpenAI'}
|
5 |
+
|
6 |
+
def source2label(source):
|
7 |
+
return SOURCES_LAB[source]
|
8 |
+
|
9 |
+
ENVIRON = "dev"
|
10 |
+
|
11 |
+
DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
|
12 |
+
DB_CONVOS = 'conversations'
|
13 |
+
DB_COMPLETIONS = 'comparison_completions'
|
14 |
+
DB_BATTLES = 'battles'
|
15 |
+
DB_ERRORS = 'completion_errors'
|
utils.py → app_utils.py
RENAMED
@@ -1,8 +1,16 @@
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
2 |
from langchain.memory import ConversationBufferMemory
|
3 |
|
|
|
4 |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
|
5 |
from models.openai.role_models import get_role_chain, role_templates
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def add_initial_message(model_name, memory):
|
8 |
if "Spanish" in model_name:
|
@@ -10,25 +18,49 @@ def add_initial_message(model_name, memory):
|
|
10 |
else:
|
11 |
memory.chat_memory.add_ai_message("Hi I need help")
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
st.session_state[memory].clear()
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
24 |
if len(st.session_state[memory].buffer_as_messages) < 1:
|
25 |
add_initial_message(language, st.session_state[memory])
|
26 |
|
27 |
|
28 |
def get_chain(issue, language, source, memory, temperature):
|
29 |
-
if source in ("
|
30 |
OA_engine = finetuned_models[f"{issue}-{language}"]
|
31 |
return get_finetuned_chain(OA_engine, memory, temperature)
|
32 |
-
|
33 |
template = role_templates[f"{issue}-{language}"]
|
34 |
return get_role_chain(template, memory, temperature)
|
|
|
1 |
+
import datetime as dt
|
2 |
import streamlit as st
|
3 |
+
from streamlit.logger import get_logger
|
4 |
+
import langchain
|
5 |
from langchain.memory import ConversationBufferMemory
|
6 |
|
7 |
+
from app_config import ENVIRON
|
8 |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
|
9 |
from models.openai.role_models import get_role_chain, role_templates
|
10 |
+
from mongo_utils import new_convo
|
11 |
+
|
12 |
+
langchain.verbose = ENVIRON=="dev"
|
13 |
+
logger = get_logger(__name__)
|
14 |
|
15 |
def add_initial_message(model_name, memory):
|
16 |
if "Spanish" in model_name:
|
|
|
18 |
else:
|
19 |
memory.chat_memory.add_ai_message("Hi I need help")
|
20 |
|
21 |
+
|
22 |
+
def push_convo2db(memories, username, language):
|
23 |
+
if len(memories) == 1:
|
24 |
+
issue = memories['memory']['issue']
|
25 |
+
model_one = memories['memory']['source']
|
26 |
+
new_convo(st.session_state['db_client'], issue, language, username, False, model_one)
|
27 |
+
else:
|
28 |
+
issue = memories['commonMemory']['issue']
|
29 |
+
model_one = memories['memoryA']['source']
|
30 |
+
model_two = memories['memoryB']['source']
|
31 |
+
new_convo(st.session_state['db_client'], issue, language, username, True, model_one, model_two)
|
32 |
+
|
33 |
+
def change_memories(memories, username, language, changed_source=False):
|
34 |
+
for memory, params in memories.items():
|
35 |
+
if (memory not in st.session_state) or changed_source:
|
36 |
+
source = params['source']
|
37 |
+
logger.info(f"Source for memory {memory} is {source}")
|
38 |
+
if source in ('OA_rolemodel','OA_finetuned'):
|
39 |
+
st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
|
40 |
+
|
41 |
+
if ("convo_id" in st.session_state) and changed_source:
|
42 |
+
del st.session_state['convo_id']
|
43 |
+
|
44 |
+
|
45 |
+
def clear_memory(memories, username, language):
|
46 |
+
for memory, _ in memories.items():
|
47 |
st.session_state[memory].clear()
|
48 |
|
49 |
+
if "convo_id" in st.session_state:
|
50 |
+
del st.session_state['convo_id']
|
51 |
+
|
52 |
+
|
53 |
+
def create_memory_add_initial_message(memories, username, language, changed_source=False):
|
54 |
+
change_memories(memories, username, language, changed_source=changed_source)
|
55 |
+
for memory, _ in memories.items():
|
56 |
if len(st.session_state[memory].buffer_as_messages) < 1:
|
57 |
add_initial_message(language, st.session_state[memory])
|
58 |
|
59 |
|
60 |
def get_chain(issue, language, source, memory, temperature):
|
61 |
+
if source in ("OA_finetuned"):
|
62 |
OA_engine = finetuned_models[f"{issue}-{language}"]
|
63 |
return get_finetuned_chain(OA_engine, memory, temperature)
|
64 |
+
elif source in ('OA_rolemodel'):
|
65 |
template = role_templates[f"{issue}-{language}"]
|
66 |
return get_role_chain(template, memory, temperature)
|
convosim.py
CHANGED
@@ -1,38 +1,53 @@
|
|
1 |
-
import openai
|
2 |
import os
|
3 |
import streamlit as st
|
|
|
4 |
from langchain.schema.messages import HumanMessage
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
openai_api_key = os.environ['OPENAI_API_KEY']
|
9 |
-
memories =
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
with st.sidebar:
|
|
|
12 |
temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
|
13 |
-
issue = st.selectbox("Select an Issue",
|
14 |
-
on_change=clear_memory,
|
15 |
)
|
16 |
supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
|
17 |
language = st.selectbox("Select a Language", supported_languages, index=0,
|
18 |
-
on_change=clear_memory,
|
19 |
)
|
20 |
|
21 |
-
source = st.selectbox("Select a source Model A",
|
22 |
-
|
23 |
)
|
24 |
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
st.title("💬 Simulator")
|
29 |
|
30 |
-
for msg in
|
31 |
role = "user" if type(msg) == HumanMessage else "assistant"
|
32 |
st.chat_message(role).write(msg.content)
|
33 |
|
34 |
if prompt := st.chat_input():
|
|
|
|
|
|
|
35 |
st.chat_message("user").write(prompt)
|
36 |
-
response = llm_chain.predict(input=prompt, stop=
|
37 |
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
38 |
st.chat_message("assistant").write(response)
|
|
|
|
|
1 |
import os
|
2 |
import streamlit as st
|
3 |
+
from streamlit.logger import get_logger
|
4 |
from langchain.schema.messages import HumanMessage
|
5 |
+
from mongo_utils import get_db_client
|
6 |
+
from app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
|
7 |
+
from app_config import ISSUES, SOURCES, source2label
|
8 |
|
9 |
+
logger = get_logger(__name__)
|
|
|
10 |
openai_api_key = os.environ['OPENAI_API_KEY']
|
11 |
+
memories = {'memory':{"issue": ISSUES[0], "source": SOURCES[0]}}
|
12 |
+
|
13 |
+
if 'previous_source' not in st.session_state:
|
14 |
+
st.session_state['previous_source'] = SOURCES[0]
|
15 |
+
if 'db_client' not in st.session_state:
|
16 |
+
st.session_state["db_client"] = get_db_client()
|
17 |
|
18 |
with st.sidebar:
|
19 |
+
username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
|
20 |
temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
|
21 |
+
issue = st.selectbox("Select an Issue", ISSUES, index=0,
|
22 |
+
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
23 |
)
|
24 |
supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
|
25 |
language = st.selectbox("Select a Language", supported_languages, index=0,
|
26 |
+
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
27 |
)
|
28 |
|
29 |
+
source = st.selectbox("Select a source Model A", SOURCES, index=1,
|
30 |
+
format_func=source2label,
|
31 |
)
|
32 |
|
33 |
+
memories = {'memory':{"issue":issue, "source":source}}
|
34 |
+
changed_source = st.session_state['previous_source'] != source
|
35 |
+
create_memory_add_initial_message(memories, username, language, changed_source=changed_source)
|
36 |
+
st.session_state['previous_source'] = source
|
37 |
+
memoryA = st.session_state[list(memories.keys())[0]]
|
38 |
+
llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature)
|
39 |
|
40 |
st.title("💬 Simulator")
|
41 |
|
42 |
+
for msg in memoryA.buffer_as_messages:
|
43 |
role = "user" if type(msg) == HumanMessage else "assistant"
|
44 |
st.chat_message(role).write(msg.content)
|
45 |
|
46 |
if prompt := st.chat_input():
|
47 |
+
if 'convo_id' not in st.session_state:
|
48 |
+
push_convo2db(memories, username, language)
|
49 |
+
|
50 |
st.chat_message("user").write(prompt)
|
51 |
+
response = llm_chain.predict(input=prompt, stop=stopper)
|
52 |
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
53 |
st.chat_message("assistant").write(response)
|
models/openai/finetuned_models.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
-
import
|
2 |
from models.custom_parsers import CustomStringOutputParser
|
3 |
from langchain.chains import LLMChain
|
4 |
from langchain.llms import OpenAI
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
-
|
|
|
|
|
7 |
|
8 |
finetuned_models = {
|
9 |
# "olivia_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-babbage-2023-02-23-19-57-19",
|
@@ -67,9 +69,8 @@ def get_finetuned_chain(model_name, memory, temperature=0.8):
|
|
67 |
llm_chain = LLMChain(
|
68 |
llm=llm,
|
69 |
prompt=PROMPT,
|
70 |
-
verbose=True,
|
71 |
memory=memory,
|
72 |
output_parser = CustomStringOutputParser()
|
73 |
)
|
74 |
-
|
75 |
-
return llm_chain
|
|
|
1 |
+
# from streamlit.logger import get_logger
|
2 |
from models.custom_parsers import CustomStringOutputParser
|
3 |
from langchain.chains import LLMChain
|
4 |
from langchain.llms import OpenAI
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
+
|
7 |
+
# logger = get_logger(__name__)
|
8 |
+
# logger.debug("START APP")
|
9 |
|
10 |
finetuned_models = {
|
11 |
# "olivia_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-babbage-2023-02-23-19-57-19",
|
|
|
69 |
llm_chain = LLMChain(
|
70 |
llm=llm,
|
71 |
prompt=PROMPT,
|
|
|
72 |
memory=memory,
|
73 |
output_parser = CustomStringOutputParser()
|
74 |
)
|
75 |
+
# logger.debug(f"{__name__}: loaded fine tuned model {model_name}")
|
76 |
+
return llm_chain, "helper:"
|
models/openai/role_models.py
CHANGED
@@ -49,9 +49,8 @@ def get_role_chain(template, memory, temperature=0.8):
|
|
49 |
llm_chain = ConversationChain(
|
50 |
llm=llm,
|
51 |
prompt=PROMPT,
|
52 |
-
verbose=True,
|
53 |
memory=memory,
|
54 |
output_parser=CustomStringOutputParser()
|
55 |
)
|
56 |
logging.debug(f"loaded GPT3.5 model")
|
57 |
-
return llm_chain
|
|
|
49 |
llm_chain = ConversationChain(
|
50 |
llm=llm,
|
51 |
prompt=PROMPT,
|
|
|
52 |
memory=memory,
|
53 |
output_parser=CustomStringOutputParser()
|
54 |
)
|
55 |
logging.debug(f"loaded GPT3.5 model")
|
56 |
+
return llm_chain, "helper:"
|
mongo_utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datetime as dt
|
3 |
+
import streamlit as st
|
4 |
+
from streamlit.logger import get_logger
|
5 |
+
from pymongo.mongo_client import MongoClient
|
6 |
+
from pymongo.server_api import ServerApi
|
7 |
+
from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
|
8 |
+
|
9 |
+
DB_URL = os.environ['MONGO_URL']
|
10 |
+
DB_USR = os.environ['MONGO_USR']
|
11 |
+
DB_PWD = os.environ['MONGO_PWD']
|
12 |
+
|
13 |
+
logger = get_logger(__name__)
|
14 |
+
|
15 |
+
def get_db_client():
|
16 |
+
uri = f"mongodb+srv://{DB_USR}:{DB_PWD}@{DB_URL}/?retryWrites=true&w=majority"
|
17 |
+
# Create a new client and connect to the server
|
18 |
+
client = MongoClient(uri, server_api=ServerApi('1'))
|
19 |
+
# Send a ping to confirm a successful connection
|
20 |
+
try:
|
21 |
+
client.admin.command('ping')
|
22 |
+
logger.info(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
|
23 |
+
return client
|
24 |
+
except Exception as e:
|
25 |
+
logger.error(e)
|
26 |
+
|
27 |
+
def new_convo(client, issue, language, username, is_comparison, model_one, model_two=None):
|
28 |
+
convo = {
|
29 |
+
"start_timestamp": dt.datetime.now(tz=dt.timezone.utc),
|
30 |
+
"issue": issue,
|
31 |
+
"language": language,
|
32 |
+
"username": username,
|
33 |
+
"is_comparison": is_comparison,
|
34 |
+
"model_one": model_one,
|
35 |
+
"model_two": model_two,
|
36 |
+
}
|
37 |
+
|
38 |
+
db = client[DB_SCHEMA]
|
39 |
+
convos = db[DB_CONVOS]
|
40 |
+
convo_id = convos.insert_one(convo).inserted_id
|
41 |
+
logger.info(f"DBUTILS: new convo id is {convo_id}")
|
42 |
+
st.session_state['convo_id'] = convo_id
|
43 |
+
|
44 |
+
def new_comparison(client, prompt_timestamp, completion_timestamp,
|
45 |
+
chat_history, prompt, completionA, completionB,
|
46 |
+
source="webapp", subset=None
|
47 |
+
):
|
48 |
+
comparison = {
|
49 |
+
"prompt_timestamp": prompt_timestamp,
|
50 |
+
"completion_timestamp": completion_timestamp,
|
51 |
+
"source": source,
|
52 |
+
"subset": subset,
|
53 |
+
"model_one_args": {
|
54 |
+
'temperature':0.8
|
55 |
+
},
|
56 |
+
"model_two_args": {
|
57 |
+
'temperature':0.8
|
58 |
+
},
|
59 |
+
"convo_id": st.session_state['convo_id'],
|
60 |
+
"chat_history": chat_history,
|
61 |
+
"prompt": prompt,
|
62 |
+
"compeltion_model_one": completionA,
|
63 |
+
"compeltion_model_two": completionB,
|
64 |
+
}
|
65 |
+
|
66 |
+
db = client[DB_SCHEMA]
|
67 |
+
comparisons = db[DB_COMPLETIONS]
|
68 |
+
comparison_id = comparisons.insert_one(comparison).inserted_id
|
69 |
+
logger.info(f"DBUTILS: new comparison id is {comparison_id}")
|
70 |
+
st.session_state['comparison_id'] = comparison_id
|
71 |
+
|
72 |
+
def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
|
73 |
+
battle = {
|
74 |
+
"battle_timestamp": dt.datetime.now(tz=dt.timezone.utc),
|
75 |
+
"comparison_id": comparison_id,
|
76 |
+
"convo_id": convo_id,
|
77 |
+
"username": username,
|
78 |
+
"model_one": model_one,
|
79 |
+
"model_two": model_two,
|
80 |
+
"winner": winner,
|
81 |
+
|
82 |
+
}
|
83 |
+
|
84 |
+
db = client[DB_SCHEMA]
|
85 |
+
battles = db[DB_BATTLES]
|
86 |
+
battle_id = battles.insert_one(battle).inserted_id
|
87 |
+
logger.info(f"DBUTILS: new battle id is {battle_id}")
|
88 |
+
|
89 |
+
def new_completion_error(client, comparison_id, username, model):
|
90 |
+
error = {
|
91 |
+
"error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
|
92 |
+
"comparison_id": comparison_id,
|
93 |
+
"username": username,
|
94 |
+
"model": model,
|
95 |
+
}
|
96 |
+
|
97 |
+
db = client[DB_SCHEMA]
|
98 |
+
errors = db[DB_ERRORS]
|
99 |
+
error_id = errors.insert_one(error).inserted_id
|
100 |
+
logger.info(f"DBUTILS: new error id is {error_id}")
|
101 |
+
|
102 |
+
def get_non_assesed_comparison(client, username):
|
103 |
+
from bson.son import SON
|
104 |
+
pipeline = [
|
105 |
+
{'$lookup': {
|
106 |
+
'from': DB_BATTLES,
|
107 |
+
'localField': '_id',
|
108 |
+
'foreignField': 'comparison_id',
|
109 |
+
"pipeline": [
|
110 |
+
{"$match": {"username":username}},
|
111 |
+
],
|
112 |
+
'as': 'battles'
|
113 |
+
}},
|
114 |
+
{'$lookup': {
|
115 |
+
'from': DB_CONVOS,
|
116 |
+
'localField': 'convo_id',
|
117 |
+
'foreignField': '_id',
|
118 |
+
'as': 'convo_info'
|
119 |
+
}},
|
120 |
+
{"$match":{
|
121 |
+
"battles": {"$size":0},
|
122 |
+
}},
|
123 |
+
{"$addFields": {
|
124 |
+
"priority": {
|
125 |
+
"$cond":[
|
126 |
+
{"$eq": ["$source","manual"]},
|
127 |
+
1,
|
128 |
+
0
|
129 |
+
]
|
130 |
+
},
|
131 |
+
}},
|
132 |
+
{"$sort": SON([
|
133 |
+
("priority", -1),
|
134 |
+
("prompt_timestamp", 1),
|
135 |
+
("convo_id", 1),
|
136 |
+
])
|
137 |
+
},
|
138 |
+
{"$limit": 1}
|
139 |
+
]
|
140 |
+
|
141 |
+
db = client[DB_SCHEMA]
|
142 |
+
return list(db[DB_COMPLETIONS].aggregate(pipeline))
|
143 |
+
|
pages/comparisor.py
CHANGED
@@ -1,14 +1,27 @@
|
|
1 |
-
|
2 |
import os
|
|
|
|
|
3 |
import streamlit as st
|
|
|
4 |
from langchain.schema.messages import HumanMessage
|
5 |
-
import
|
6 |
-
|
7 |
-
from
|
8 |
|
|
|
9 |
openai_api_key = os.environ['OPENAI_API_KEY']
|
10 |
-
memories =
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
def delete_last_message(memory):
|
14 |
last_prompt = memory.chat_memory.messages[-2].content
|
@@ -20,59 +33,91 @@ def replace_last_message(memory, new_message):
|
|
20 |
memory.chat_memory.add_ai_message(new_message)
|
21 |
|
22 |
def regenerateA():
|
23 |
-
last_prompt = delete_last_message(
|
24 |
-
new_response = llm_chainA.predict(input=last_prompt, stop=
|
25 |
col1.chat_message("user").write(last_prompt)
|
26 |
col1.chat_message("assistant").write(new_response)
|
|
|
27 |
|
28 |
def regenerateB():
|
29 |
-
last_prompt = delete_last_message(
|
30 |
-
new_response = llm_chainB.predict(input=last_prompt, stop=
|
31 |
col2.chat_message("user").write(last_prompt)
|
32 |
col2.chat_message("assistant").write(new_response)
|
|
|
33 |
|
34 |
def replaceA():
|
35 |
-
last_prompt =
|
36 |
-
new_message =
|
37 |
-
replace_last_message(
|
38 |
st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def replaceB():
|
41 |
-
last_prompt =
|
42 |
-
new_message =
|
43 |
-
replace_last_message(
|
44 |
st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def regenerateBoth():
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
def bothGood():
|
51 |
-
if len(
|
52 |
pass
|
53 |
else:
|
54 |
-
|
55 |
-
|
|
|
56 |
st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":last_reponse})
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
with st.sidebar:
|
59 |
-
|
60 |
-
|
|
|
61 |
)
|
62 |
supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
|
63 |
language = st.selectbox("Select a Language", supported_languages, index=0,
|
64 |
-
on_change=clear_memory,
|
65 |
)
|
66 |
|
67 |
with st.expander("Model A"):
|
68 |
temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
|
69 |
-
sourceA = st.selectbox("Select a source Model A",
|
70 |
-
|
71 |
)
|
72 |
with st.expander("Model B"):
|
73 |
temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
|
74 |
-
sourceB = st.selectbox("Select a source Model B",
|
75 |
-
|
76 |
)
|
77 |
|
78 |
sbcol1, sbcol2 = st.columns(2)
|
@@ -86,9 +131,20 @@ with st.sidebar:
|
|
86 |
# regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
|
87 |
clear = st.button("Clear History", on_click=clear_memory, args=(memories,))
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
st.title(f"💬 History")
|
94 |
for msg in st.session_state['commonMemory'].buffer_as_messages:
|
@@ -115,12 +171,20 @@ def disable_chat():
|
|
115 |
return True
|
116 |
|
117 |
if prompt := st.chat_input(disabled=disable_chat()):
|
|
|
|
|
|
|
|
|
118 |
col1.chat_message("user").write(prompt)
|
119 |
col2.chat_message("user").write(prompt)
|
120 |
|
121 |
-
responseA = llm_chainA.predict(input=prompt, stop=
|
122 |
-
responseB = llm_chainB.predict(input=prompt, stop=
|
|
|
123 |
|
|
|
|
|
|
|
124 |
col1.chat_message("assistant").write(responseA)
|
125 |
col2.chat_message("assistant").write(responseB)
|
126 |
|
|
|
1 |
+
|
2 |
import os
|
3 |
+
import random
|
4 |
+
import datetime as dt
|
5 |
import streamlit as st
|
6 |
+
from streamlit.logger import get_logger
|
7 |
from langchain.schema.messages import HumanMessage
|
8 |
+
from mongo_utils import get_db_client, new_comparison, new_battle_result
|
9 |
+
from app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
|
10 |
+
from app_config import ISSUES, SOURCES, source2label
|
11 |
|
12 |
+
logger = get_logger(__name__)
|
13 |
openai_api_key = os.environ['OPENAI_API_KEY']
|
14 |
+
memories = {
|
15 |
+
'memoryA': {"issue": ISSUES[0], "source": SOURCES[0]},
|
16 |
+
'memoryB': {"issue": ISSUES[0], "source": SOURCES[1]},
|
17 |
+
'commonMemory': {"issue": ISSUES[0], "source": SOURCES[0]}
|
18 |
+
}
|
19 |
+
if 'db_client' not in st.session_state:
|
20 |
+
st.session_state["db_client"] = get_db_client()
|
21 |
+
if 'previous_sourceA' not in st.session_state:
|
22 |
+
st.session_state['previous_sourceA'] = SOURCES[0]
|
23 |
+
if 'previous_sourceB' not in st.session_state:
|
24 |
+
st.session_state['previous_sourceB'] = SOURCES[1]
|
25 |
|
26 |
def delete_last_message(memory):
|
27 |
last_prompt = memory.chat_memory.messages[-2].content
|
|
|
33 |
memory.chat_memory.add_ai_message(new_message)
|
34 |
|
35 |
def regenerateA():
|
36 |
+
last_prompt = delete_last_message(memoryA)
|
37 |
+
new_response = llm_chainA.predict(input=last_prompt, stop=stopperA)
|
38 |
col1.chat_message("user").write(last_prompt)
|
39 |
col1.chat_message("assistant").write(new_response)
|
40 |
+
return new_response
|
41 |
|
42 |
def regenerateB():
|
43 |
+
last_prompt = delete_last_message(memoryB)
|
44 |
+
new_response = llm_chainB.predict(input=last_prompt, stop=stopperB)
|
45 |
col2.chat_message("user").write(last_prompt)
|
46 |
col2.chat_message("assistant").write(new_response)
|
47 |
+
return new_response
|
48 |
|
49 |
def replaceA():
|
50 |
+
last_prompt = memoryB.chat_memory.messages[-2].content
|
51 |
+
new_message = memoryB.chat_memory.messages[-1].content
|
52 |
+
replace_last_message(memoryA, new_message)
|
53 |
st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
|
54 |
|
55 |
+
new_battle_result(st.session_state['db_client'],
|
56 |
+
st.session_state['comparison_id'],
|
57 |
+
st.session_state['convo_id'],
|
58 |
+
username, sourceA, sourceB, winner='model_two'
|
59 |
+
)
|
60 |
+
|
61 |
def replaceB():
|
62 |
+
last_prompt = memoryA.chat_memory.messages[-2].content
|
63 |
+
new_message = memoryA.chat_memory.messages[-1].content
|
64 |
+
replace_last_message(memoryB, new_message)
|
65 |
st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
|
66 |
|
67 |
+
new_battle_result(st.session_state['db_client'],
|
68 |
+
st.session_state['comparison_id'],
|
69 |
+
st.session_state['convo_id'],
|
70 |
+
username, sourceA, sourceB, winner='model_one'
|
71 |
+
)
|
72 |
+
|
73 |
def regenerateBoth():
|
74 |
+
promt_ts = dt.datetime.now(tz=dt.timezone.utc)
|
75 |
+
new_battle_result(st.session_state['db_client'],
|
76 |
+
st.session_state['comparison_id'],
|
77 |
+
st.session_state['convo_id'],
|
78 |
+
username, sourceA, sourceB, winner='both_bad'
|
79 |
+
)
|
80 |
+
|
81 |
+
responseA = regenerateA()
|
82 |
+
responseB = regenerateB()
|
83 |
+
completion_ts = dt.datetime.now(tz=dt.timezone.utc)
|
84 |
+
new_comparison(st.session_state['db_client'], promt_ts, completion_ts,
|
85 |
+
st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB)
|
86 |
|
87 |
def bothGood():
|
88 |
+
if len(memoryA.buffer_as_messages) == 1:
|
89 |
pass
|
90 |
else:
|
91 |
+
i = random.choice([memoryA, memoryB])
|
92 |
+
last_prompt = i.chat_memory.messages[-2].content
|
93 |
+
last_reponse = i.chat_memory.messages[-1].content
|
94 |
st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":last_reponse})
|
95 |
+
|
96 |
+
new_battle_result(st.session_state['db_client'],
|
97 |
+
st.session_state['comparison_id'],
|
98 |
+
st.session_state['convo_id'],
|
99 |
+
username, sourceA, sourceB, winner='tie'
|
100 |
+
)
|
101 |
|
102 |
with st.sidebar:
|
103 |
+
username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
|
104 |
+
issue = st.selectbox("Select an Issue", ISSUES, index=0,
|
105 |
+
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
106 |
)
|
107 |
supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
|
108 |
language = st.selectbox("Select a Language", supported_languages, index=0,
|
109 |
+
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
110 |
)
|
111 |
|
112 |
with st.expander("Model A"):
|
113 |
temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
|
114 |
+
sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
|
115 |
+
format_func=source2label
|
116 |
)
|
117 |
with st.expander("Model B"):
|
118 |
temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
|
119 |
+
sourceB = st.selectbox("Select a source Model B", SOURCES, index=1,
|
120 |
+
format_func=source2label
|
121 |
)
|
122 |
|
123 |
sbcol1, sbcol2 = st.columns(2)
|
|
|
131 |
# regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
|
132 |
clear = st.button("Clear History", on_click=clear_memory, args=(memories,))
|
133 |
|
134 |
+
memories = {
|
135 |
+
'memoryA': {"issue": issue, "source": sourceA},
|
136 |
+
'memoryB': {"issue": issue, "source": sourceB},
|
137 |
+
'commonMemory': {"issue": issue, "source": SOURCES[0]}
|
138 |
+
}
|
139 |
+
changed_source = any([
|
140 |
+
st.session_state['previous_sourceA'] != sourceA,
|
141 |
+
st.session_state['previous_sourceB'] != sourceB
|
142 |
+
])
|
143 |
+
create_memory_add_initial_message(memories, username, language, changed_source=changed_source)
|
144 |
+
memoryA = st.session_state[list(memories.keys())[0]]
|
145 |
+
memoryB = st.session_state[list(memories.keys())[1]]
|
146 |
+
llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
|
147 |
+
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)
|
148 |
|
149 |
st.title(f"💬 History")
|
150 |
for msg in st.session_state['commonMemory'].buffer_as_messages:
|
|
|
171 |
return True
|
172 |
|
173 |
if prompt := st.chat_input(disabled=disable_chat()):
|
174 |
+
if 'convo_id' not in st.session_state:
|
175 |
+
push_convo2db(memories, username, language)
|
176 |
+
|
177 |
+
promt_ts = dt.datetime.now(tz=dt.timezone.utc)
|
178 |
col1.chat_message("user").write(prompt)
|
179 |
col2.chat_message("user").write(prompt)
|
180 |
|
181 |
+
responseA = llm_chainA.predict(input=prompt, stop=stopperA)
|
182 |
+
responseB = llm_chainB.predict(input=prompt, stop=stopperB)
|
183 |
+
completion_ts = dt.datetime.now(tz=dt.timezone.utc)
|
184 |
|
185 |
+
new_comparison(st.session_state['db_client'], promt_ts, completion_ts,
|
186 |
+
st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB)
|
187 |
+
|
188 |
col1.chat_message("assistant").write(responseA)
|
189 |
col2.chat_message("assistant").write(responseB)
|
190 |
|
pages/manual_comparisor.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import datetime as dt
|
5 |
+
import streamlit as st
|
6 |
+
from streamlit.logger import get_logger
|
7 |
+
from langchain.schema.messages import HumanMessage
|
8 |
+
from mongo_utils import get_db_client, new_battle_result, get_non_assesed_comparison, new_completion_error
|
9 |
+
from app_config import ISSUES, SOURCES
|
10 |
+
|
11 |
+
logger = get_logger(__name__)
|
12 |
+
openai_api_key = os.environ['OPENAI_API_KEY']
|
13 |
+
if 'db_client' not in st.session_state:
|
14 |
+
st.session_state["db_client"] = get_db_client()
|
15 |
+
|
16 |
+
def disable_buttons():
|
17 |
+
return len(comparison) == 0
|
18 |
+
|
19 |
+
def replaceA():
|
20 |
+
new_battle_result(st.session_state['db_client'],
|
21 |
+
st.session_state['comparison_id'],
|
22 |
+
st.session_state['convo_id'],
|
23 |
+
username, sourceA, sourceB, winner='model_two'
|
24 |
+
)
|
25 |
+
|
26 |
+
def replaceB():
|
27 |
+
new_battle_result(st.session_state['db_client'],
|
28 |
+
st.session_state['comparison_id'],
|
29 |
+
st.session_state['convo_id'],
|
30 |
+
username, sourceA, sourceB, winner='model_one'
|
31 |
+
)
|
32 |
+
|
33 |
+
def regenerateBoth():
|
34 |
+
new_battle_result(st.session_state['db_client'],
|
35 |
+
st.session_state['comparison_id'],
|
36 |
+
st.session_state['convo_id'],
|
37 |
+
username, sourceA, sourceB, winner='both_bad'
|
38 |
+
)
|
39 |
+
|
40 |
+
def bothGood():
|
41 |
+
new_battle_result(st.session_state['db_client'],
|
42 |
+
st.session_state['comparison_id'],
|
43 |
+
st.session_state['convo_id'],
|
44 |
+
username, sourceA, sourceB, winner='tie'
|
45 |
+
)
|
46 |
+
|
47 |
+
def error2db(model):
|
48 |
+
new_completion_error(st.session_state['db_client'],
|
49 |
+
st.session_state['comparison_id'],
|
50 |
+
username, model
|
51 |
+
)
|
52 |
+
|
53 |
+
def error2dbA():
|
54 |
+
error2db(sourceA)
|
55 |
+
|
56 |
+
def error2dbA():
|
57 |
+
error2db(sourceB)
|
58 |
+
|
59 |
+
with st.sidebar:
|
60 |
+
username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
|
61 |
+
|
62 |
+
comparison = get_non_assesed_comparison(st.session_state["db_client"], username)
|
63 |
+
|
64 |
+
with st.sidebar:
|
65 |
+
|
66 |
+
sbcol1, sbcol2 = st.columns(2)
|
67 |
+
beta = sbcol1.button("A is better", on_click=replaceB, disabled=disable_buttons())
|
68 |
+
betb = sbcol2.button("B is better", on_click=replaceA, disabled=disable_buttons())
|
69 |
+
|
70 |
+
same = sbcol1.button("Tie", on_click=bothGood, disabled=disable_buttons())
|
71 |
+
bbad = sbcol2.button("Both are bad", on_click=regenerateBoth, disabled=disable_buttons())
|
72 |
+
|
73 |
+
errorA = sbcol1.button("Error in A", on_click=error2dbA, disabled=disable_buttons())
|
74 |
+
errorB = sbcol2.button("Error in B", on_click=error2dbA, disabled=disable_buttons())
|
75 |
+
|
76 |
+
if len(comparison) > 0:
|
77 |
+
|
78 |
+
st.session_state['comparison_id'] = comparison[0]["_id"]
|
79 |
+
st.session_state['convo_id'] = comparison[0]["convo_id"]
|
80 |
+
st.session_state["disabled_buttons"] = False
|
81 |
+
|
82 |
+
st.sidebar.text_input("Issue", value=comparison[0]['convo_info'][0]['issue'], disabled=True)
|
83 |
+
|
84 |
+
st.title(f"💬 History")
|
85 |
+
|
86 |
+
for msg in comparison[0]['chat_history'].split("\n"):
|
87 |
+
parts = msg.split(":")
|
88 |
+
if len(parts) > 1:
|
89 |
+
role = "user" if parts[0] == 'helper' else "assistant"
|
90 |
+
st.chat_message(role).write(parts[1])
|
91 |
+
|
92 |
+
col1, col2 = st.columns(2)
|
93 |
+
col1.title(f"💬 Simulator A")
|
94 |
+
col2.title(f"💬 Simulator B")
|
95 |
+
|
96 |
+
selectedA = random.choice(['model_one', 'model_two'])
|
97 |
+
selectedB = "model_two" if selectedA == "model_one" else "model_one"
|
98 |
+
logger.info(f"selected A is {selectedA} and B is {selectedB}")
|
99 |
+
sourceA = comparison[0]['convo_info'][0][selectedA]
|
100 |
+
sourceB = comparison[0]['convo_info'][0][selectedB]
|
101 |
+
col1.chat_message("user").write(comparison[0]["prompt"])
|
102 |
+
col2.chat_message("user").write(comparison[0]["prompt"])
|
103 |
+
|
104 |
+
col1.chat_message("assistant").write(comparison[0][f"compeltion_{selectedA}"])
|
105 |
+
col2.chat_message("assistant").write(comparison[0][f"compeltion_{selectedB}"])
|
106 |
+
|
107 |
+
else:
|
108 |
+
st.write("No Comparisons left to Check")
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
scipy==1.11.1
|
2 |
openai==0.28.0
|
3 |
-
langchain==0.0.281
|
|
|
|
1 |
scipy==1.11.1
|
2 |
openai==0.28.0
|
3 |
+
langchain==0.0.281
|
4 |
+
pymongo==4.5.0
|