Spaces:
Running
Running
geekyrakshit
commited on
Commit
Β·
0cca41e
1
Parent(s):
abfa06d
add: wandb api key based authentication
Browse files- app.py +1 -1
- application_pages/chat_app.py +63 -69
- application_pages/evaluation_app.py +100 -104
- application_pages/intro_page.py +47 -0
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
intro_page = st.Page(
|
4 |
-
"application_pages/intro_page.py", title="
|
5 |
)
|
6 |
chat_page = st.Page(
|
7 |
"application_pages/chat_app.py",
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
intro_page = st.Page(
|
4 |
+
"application_pages/intro_page.py", title="Authenticate", icon=":material/guardian:"
|
5 |
)
|
6 |
chat_page = st.Page(
|
7 |
"application_pages/chat_app.py",
|
application_pages/chat_app.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import importlib
|
2 |
-
import os
|
3 |
|
4 |
import streamlit as st
|
5 |
-
import weave
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
from guardrails_genie.guardrails import GuardrailManager
|
@@ -11,9 +9,6 @@ from guardrails_genie.llm import OpenAIModel
|
|
11 |
|
12 |
def initialize_session_state():
|
13 |
load_dotenv()
|
14 |
-
weave.init(project_name=os.getenv("WEAVE_PROJECT"))
|
15 |
-
if "weave_project_name" not in st.session_state:
|
16 |
-
st.session_state.weave_project_name = "guardrails-genie"
|
17 |
if "guardrails" not in st.session_state:
|
18 |
st.session_state.guardrails = []
|
19 |
if "guardrail_names" not in st.session_state:
|
@@ -121,70 +116,69 @@ def initialize_guardrails():
|
|
121 |
)
|
122 |
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
weave_project_name = st.sidebar.text_input(
|
128 |
-
"Weave project name", value=st.session_state.weave_project_name
|
129 |
-
)
|
130 |
-
st.session_state.weave_project_name = weave_project_name
|
131 |
-
if st.session_state.weave_project_name != "":
|
132 |
-
weave.init(project_name=st.session_state.weave_project_name)
|
133 |
-
|
134 |
-
openai_model = st.sidebar.selectbox(
|
135 |
-
"OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
|
136 |
-
)
|
137 |
-
chat_condition = openai_model != ""
|
138 |
-
|
139 |
-
guardrails = []
|
140 |
-
|
141 |
-
guardrail_names = st.sidebar.multiselect(
|
142 |
-
label="Select Guardrails",
|
143 |
-
options=[
|
144 |
-
cls_name
|
145 |
-
for cls_name, cls_obj in vars(
|
146 |
-
importlib.import_module("guardrails_genie.guardrails")
|
147 |
-
).items()
|
148 |
-
if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
|
149 |
-
],
|
150 |
-
)
|
151 |
-
st.session_state.guardrail_names = guardrail_names
|
152 |
-
|
153 |
-
if st.sidebar.button("Initialize Guardrails") and chat_condition:
|
154 |
-
st.session_state.initialize_guardrails = True
|
155 |
-
|
156 |
-
if st.session_state.initialize_guardrails:
|
157 |
-
with st.sidebar.status("Initializing Guardrails..."):
|
158 |
-
initialize_guardrails()
|
159 |
-
st.session_state.llm_model = OpenAIModel(model_name=openai_model)
|
160 |
-
|
161 |
-
user_prompt = st.text_area("User Prompt", value="")
|
162 |
-
st.session_state.user_prompt = user_prompt
|
163 |
-
|
164 |
-
test_guardrails_button = st.button("Test Guardrails")
|
165 |
-
st.session_state.test_guardrails = test_guardrails_button
|
166 |
-
|
167 |
-
if st.session_state.test_guardrails:
|
168 |
-
with st.sidebar.status("Running Guardrails..."):
|
169 |
-
guardrails_response, call = st.session_state.guardrails_manager.guard.call(
|
170 |
-
st.session_state.guardrails_manager, prompt=st.session_state.user_prompt
|
171 |
-
)
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
182 |
)
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import importlib
|
|
|
2 |
|
3 |
import streamlit as st
|
|
|
4 |
from dotenv import load_dotenv
|
5 |
|
6 |
from guardrails_genie.guardrails import GuardrailManager
|
|
|
9 |
|
10 |
def initialize_session_state():
|
11 |
load_dotenv()
|
|
|
|
|
|
|
12 |
if "guardrails" not in st.session_state:
|
13 |
st.session_state.guardrails = []
|
14 |
if "guardrail_names" not in st.session_state:
|
|
|
116 |
)
|
117 |
|
118 |
|
119 |
+
if st.session_state.is_authenticated:
|
120 |
+
initialize_session_state()
|
121 |
+
st.title(":material/robot: Guardrails Genie Playground")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
+
openai_model = st.sidebar.selectbox(
|
124 |
+
"OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
|
125 |
+
)
|
126 |
+
chat_condition = openai_model != ""
|
127 |
+
|
128 |
+
guardrails = []
|
129 |
+
|
130 |
+
guardrail_names = st.sidebar.multiselect(
|
131 |
+
label="Select Guardrails",
|
132 |
+
options=[
|
133 |
+
cls_name
|
134 |
+
for cls_name, cls_obj in vars(
|
135 |
+
importlib.import_module("guardrails_genie.guardrails")
|
136 |
+
).items()
|
137 |
+
if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
|
138 |
+
],
|
139 |
+
)
|
140 |
+
st.session_state.guardrail_names = guardrail_names
|
141 |
+
|
142 |
+
if st.sidebar.button("Initialize Guardrails") and chat_condition:
|
143 |
+
st.session_state.initialize_guardrails = True
|
144 |
+
|
145 |
+
if st.session_state.initialize_guardrails:
|
146 |
+
with st.sidebar.status("Initializing Guardrails..."):
|
147 |
+
initialize_guardrails()
|
148 |
+
st.session_state.llm_model = OpenAIModel(model_name=openai_model)
|
149 |
+
|
150 |
+
user_prompt = st.text_area("User Prompt", value="")
|
151 |
+
st.session_state.user_prompt = user_prompt
|
152 |
+
|
153 |
+
test_guardrails_button = st.button("Test Guardrails")
|
154 |
+
st.session_state.test_guardrails = test_guardrails_button
|
155 |
|
156 |
+
if st.session_state.test_guardrails:
|
157 |
+
with st.sidebar.status("Running Guardrails..."):
|
158 |
+
guardrails_response, call = (
|
159 |
+
st.session_state.guardrails_manager.guard.call(
|
160 |
+
st.session_state.guardrails_manager,
|
161 |
+
prompt=st.session_state.user_prompt,
|
162 |
+
)
|
163 |
)
|
164 |
+
|
165 |
+
if guardrails_response["safe"]:
|
166 |
+
st.markdown(
|
167 |
+
f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
|
168 |
+
)
|
169 |
+
|
170 |
+
with st.sidebar.status("Generating response from LLM..."):
|
171 |
+
response, call = st.session_state.llm_model.predict.call(
|
172 |
+
st.session_state.llm_model,
|
173 |
+
user_prompts=st.session_state.user_prompt,
|
174 |
+
)
|
175 |
+
st.markdown(
|
176 |
+
response.choices[0].message.content
|
177 |
+
+ f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
st.warning("Prompt is not safe!")
|
181 |
+
st.markdown(guardrails_response["summary"])
|
182 |
+
st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
|
183 |
+
else:
|
184 |
+
st.warning("Please authenticate your WandB account to use this feature.")
|
application_pages/evaluation_app.py
CHANGED
@@ -13,8 +13,6 @@ from guardrails_genie.metrics import AccuracyMetric
|
|
13 |
|
14 |
def initialize_session_state():
|
15 |
load_dotenv()
|
16 |
-
if "weave_project_name" not in st.session_state:
|
17 |
-
st.session_state.weave_project_name = "guardrails-genie"
|
18 |
if "uploaded_file" not in st.session_state:
|
19 |
st.session_state.uploaded_file = None
|
20 |
if "dataset_name" not in st.session_state:
|
@@ -124,108 +122,106 @@ def initialize_guardrails():
|
|
124 |
)
|
125 |
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
weave_project_name = st.sidebar.text_input(
|
131 |
-
"Weave project name", value=st.session_state.weave_project_name
|
132 |
-
)
|
133 |
-
st.session_state.weave_project_name = weave_project_name
|
134 |
-
if st.session_state.weave_project_name != "":
|
135 |
-
weave.init(project_name=st.session_state.weave_project_name)
|
136 |
-
|
137 |
-
uploaded_file = st.sidebar.file_uploader(
|
138 |
-
"Upload the evaluation dataset as a CSV file", type="csv"
|
139 |
-
)
|
140 |
-
st.session_state.uploaded_file = uploaded_file
|
141 |
-
|
142 |
-
if st.session_state.uploaded_file is not None:
|
143 |
-
dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
|
144 |
-
st.session_state.dataset_name = dataset_name
|
145 |
-
preview_in_app = st.sidebar.toggle("Preview in app", value=False)
|
146 |
-
st.session_state.preview_in_app = preview_in_app
|
147 |
-
publish_dataset_button = st.sidebar.button("Publish dataset")
|
148 |
-
st.session_state.publish_dataset_button = publish_dataset_button
|
149 |
-
|
150 |
-
if st.session_state.publish_dataset_button and (
|
151 |
-
st.session_state.dataset_name is not None
|
152 |
-
and st.session_state.dataset_name != ""
|
153 |
-
):
|
154 |
-
|
155 |
-
with st.expander("Evaluation Dataset Preview", expanded=True):
|
156 |
-
dataframe = pd.read_csv(st.session_state.uploaded_file)
|
157 |
-
data_list = dataframe.to_dict(orient="records")
|
158 |
-
|
159 |
-
dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
|
160 |
-
st.session_state.dataset_ref = weave.publish(dataset)
|
161 |
-
|
162 |
-
entity = st.session_state.dataset_ref.entity
|
163 |
-
project = st.session_state.dataset_ref.project
|
164 |
-
dataset_name = st.session_state.dataset_name
|
165 |
-
digest = st.session_state.dataset_ref._digest
|
166 |
-
dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
|
167 |
-
st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
|
168 |
-
|
169 |
-
if preview_in_app:
|
170 |
-
st.dataframe(dataframe.head(20))
|
171 |
-
if len(dataframe) > 20:
|
172 |
-
st.markdown(
|
173 |
-
f"β οΈ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
|
174 |
-
)
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
st.session_state.
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
dataset=st.session_state.dataset_ref,
|
201 |
-
scorers=[AccuracyMetric()],
|
202 |
-
streamlit_mode=True,
|
203 |
-
)
|
204 |
-
with st.expander("Evaluation Results", expanded=True):
|
205 |
-
evaluation_summary, call = asyncio.run(
|
206 |
-
evaluation.evaluate.call(
|
207 |
-
evaluation,
|
208 |
-
GuardrailManager(guardrails=st.session_state.guardrails),
|
209 |
-
__weave={
|
210 |
-
"display_name": (
|
211 |
-
"Evaluation.evaluate"
|
212 |
-
if st.session_state.evaluation_name == ""
|
213 |
-
else "Evaluation.evaluate:"
|
214 |
-
+ st.session_state.evaluation_name
|
215 |
-
)
|
216 |
-
},
|
217 |
-
)
|
218 |
-
)
|
219 |
-
x_axis = list(evaluation_summary["AccuracyMetric"].keys())
|
220 |
-
y_axis = [
|
221 |
-
evaluation_summary["AccuracyMetric"][x_axis_item]
|
222 |
-
for x_axis_item in x_axis
|
223 |
-
]
|
224 |
-
st.bar_chart(
|
225 |
-
pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
|
226 |
-
x="Metric",
|
227 |
-
y="Score",
|
228 |
)
|
229 |
-
st.
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def initialize_session_state():
|
15 |
load_dotenv()
|
|
|
|
|
16 |
if "uploaded_file" not in st.session_state:
|
17 |
st.session_state.uploaded_file = None
|
18 |
if "dataset_name" not in st.session_state:
|
|
|
122 |
)
|
123 |
|
124 |
|
125 |
+
if st.session_state.is_authenticated:
|
126 |
+
initialize_session_state()
|
127 |
+
st.title(":material/monitoring: Evaluation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
+
uploaded_file = st.sidebar.file_uploader(
|
130 |
+
"Upload the evaluation dataset as a CSV file", type="csv"
|
131 |
+
)
|
132 |
+
st.session_state.uploaded_file = uploaded_file
|
133 |
+
|
134 |
+
if st.session_state.uploaded_file is not None:
|
135 |
+
dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
|
136 |
+
st.session_state.dataset_name = dataset_name
|
137 |
+
preview_in_app = st.sidebar.toggle("Preview in app", value=False)
|
138 |
+
st.session_state.preview_in_app = preview_in_app
|
139 |
+
publish_dataset_button = st.sidebar.button("Publish dataset")
|
140 |
+
st.session_state.publish_dataset_button = publish_dataset_button
|
141 |
+
|
142 |
+
if st.session_state.publish_dataset_button and (
|
143 |
+
st.session_state.dataset_name is not None
|
144 |
+
and st.session_state.dataset_name != ""
|
145 |
+
):
|
146 |
+
|
147 |
+
with st.expander("Evaluation Dataset Preview", expanded=True):
|
148 |
+
dataframe = pd.read_csv(st.session_state.uploaded_file)
|
149 |
+
data_list = dataframe.to_dict(orient="records")
|
150 |
+
|
151 |
+
dataset = weave.Dataset(
|
152 |
+
name=st.session_state.dataset_name, rows=data_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
)
|
154 |
+
st.session_state.dataset_ref = weave.publish(dataset)
|
155 |
+
|
156 |
+
entity = st.session_state.dataset_ref.entity
|
157 |
+
project = st.session_state.dataset_ref.project
|
158 |
+
dataset_name = st.session_state.dataset_name
|
159 |
+
digest = st.session_state.dataset_ref._digest
|
160 |
+
dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
|
161 |
+
st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
|
162 |
+
|
163 |
+
if preview_in_app:
|
164 |
+
st.dataframe(dataframe.head(20))
|
165 |
+
if len(dataframe) > 20:
|
166 |
+
st.markdown(
|
167 |
+
f"β οΈ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
|
168 |
+
)
|
169 |
+
|
170 |
+
st.session_state.is_dataset_published = True
|
171 |
+
|
172 |
+
if st.session_state.is_dataset_published:
|
173 |
+
guardrail_names = st.sidebar.multiselect(
|
174 |
+
"Select Guardrails",
|
175 |
+
options=[
|
176 |
+
cls_name
|
177 |
+
for cls_name, cls_obj in vars(
|
178 |
+
import_module("guardrails_genie.guardrails")
|
179 |
+
).items()
|
180 |
+
if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
|
181 |
+
],
|
182 |
+
)
|
183 |
+
st.session_state.guardrail_names = guardrail_names
|
184 |
+
|
185 |
+
initialize_guardrails()
|
186 |
+
evaluation_name = st.sidebar.text_input("Evaluation Name", value="")
|
187 |
+
st.session_state.evaluation_name = evaluation_name
|
188 |
+
|
189 |
+
start_evaluations_button = st.sidebar.button("Start Evaluations")
|
190 |
+
st.session_state.start_evaluations_button = start_evaluations_button
|
191 |
+
if st.session_state.start_evaluations_button:
|
192 |
+
# st.write(len(st.session_state.guardrails))
|
193 |
+
evaluation = weave.Evaluation(
|
194 |
+
dataset=st.session_state.dataset_ref,
|
195 |
+
scorers=[AccuracyMetric()],
|
196 |
+
streamlit_mode=True,
|
197 |
)
|
198 |
+
with st.expander("Evaluation Results", expanded=True):
|
199 |
+
evaluation_summary, call = asyncio.run(
|
200 |
+
evaluation.evaluate.call(
|
201 |
+
evaluation,
|
202 |
+
GuardrailManager(guardrails=st.session_state.guardrails),
|
203 |
+
__weave={
|
204 |
+
"display_name": (
|
205 |
+
"Evaluation.evaluate"
|
206 |
+
if st.session_state.evaluation_name == ""
|
207 |
+
else "Evaluation.evaluate:"
|
208 |
+
+ st.session_state.evaluation_name
|
209 |
+
)
|
210 |
+
},
|
211 |
+
)
|
212 |
+
)
|
213 |
+
x_axis = list(evaluation_summary["AccuracyMetric"].keys())
|
214 |
+
y_axis = [
|
215 |
+
evaluation_summary["AccuracyMetric"][x_axis_item]
|
216 |
+
for x_axis_item in x_axis
|
217 |
+
]
|
218 |
+
st.bar_chart(
|
219 |
+
pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
|
220 |
+
x="Metric",
|
221 |
+
y="Score",
|
222 |
+
)
|
223 |
+
st.markdown(
|
224 |
+
f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
st.warning("Please authenticate your WandB account to use this feature.")
|
application_pages/intro_page.py
CHANGED
@@ -1,7 +1,54 @@
|
|
1 |
import streamlit as st
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
st.title("π§ββοΈ Guardrails Genie")
|
4 |
|
5 |
st.write(
|
6 |
"Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
|
7 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
import wandb
|
4 |
+
|
5 |
+
|
6 |
+
def initialize_session_state():
|
7 |
+
if "weave_project_name" not in st.session_state:
|
8 |
+
st.session_state.weave_project_name = "guardrails-genie"
|
9 |
+
if "weave_entity_name" not in st.session_state:
|
10 |
+
st.session_state.weave_entity_name = ""
|
11 |
+
if "wandb_api_key" not in st.session_state:
|
12 |
+
st.session_state.wandb_api_key = ""
|
13 |
+
if "authenticate_button" not in st.session_state:
|
14 |
+
st.session_state.authenticate_button = False
|
15 |
+
if "is_authenticated" not in st.session_state:
|
16 |
+
st.session_state.is_authenticated = False
|
17 |
+
|
18 |
+
|
19 |
+
initialize_session_state()
|
20 |
st.title("π§ββοΈ Guardrails Genie")
|
21 |
|
22 |
st.write(
|
23 |
"Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
|
24 |
)
|
25 |
+
|
26 |
+
with st.expander("Login to Your WandB Account", expanded=True):
|
27 |
+
st.markdown(
|
28 |
+
"Get your Wandb API key from [https://wandb.ai/authorize](https://wandb.ai/authorize)"
|
29 |
+
)
|
30 |
+
weave_entity_name = st.text_input(
|
31 |
+
"Weave Entity Name", value=st.session_state.weave_entity_name
|
32 |
+
)
|
33 |
+
st.session_state.weave_entity_name = weave_entity_name
|
34 |
+
weave_project_name = st.text_input(
|
35 |
+
"Weave Project Name", value=st.session_state.weave_project_name
|
36 |
+
)
|
37 |
+
st.session_state.weave_project_name = weave_project_name
|
38 |
+
wandb_api_key = st.text_input("Wandb API Key", value="", type="password")
|
39 |
+
st.session_state.wandb_api_key = wandb_api_key
|
40 |
+
authenticate_button = st.button("Authenticate")
|
41 |
+
st.session_state.authenticate_button = authenticate_button
|
42 |
+
|
43 |
+
if authenticate_button and (
|
44 |
+
st.session_state.wandb_api_key != ""
|
45 |
+
and st.session_state.weave_project_name != ""
|
46 |
+
):
|
47 |
+
is_wandb_logged_in = wandb.login(
|
48 |
+
key=st.session_state.wandb_api_key, relogin=True
|
49 |
+
)
|
50 |
+
if is_wandb_logged_in:
|
51 |
+
st.session_state.is_authenticated = True
|
52 |
+
st.success("Logged in to Wandb")
|
53 |
+
else:
|
54 |
+
st.error("Failed to log in to Wandb")
|