geekyrakshit committed on
Commit
0cca41e
·
1 Parent(s): abfa06d

add: wandb api key based authentication

Browse files
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
 
3
  intro_page = st.Page(
4
- "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
  "application_pages/chat_app.py",
 
1
  import streamlit as st
2
 
3
  intro_page = st.Page(
4
+ "application_pages/intro_page.py", title="Authenticate", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
  "application_pages/chat_app.py",
application_pages/chat_app.py CHANGED
@@ -1,8 +1,6 @@
1
  import importlib
2
- import os
3
 
4
  import streamlit as st
5
- import weave
6
  from dotenv import load_dotenv
7
 
8
  from guardrails_genie.guardrails import GuardrailManager
@@ -11,9 +9,6 @@ from guardrails_genie.llm import OpenAIModel
11
 
12
  def initialize_session_state():
13
  load_dotenv()
14
- weave.init(project_name=os.getenv("WEAVE_PROJECT"))
15
- if "weave_project_name" not in st.session_state:
16
- st.session_state.weave_project_name = "guardrails-genie"
17
  if "guardrails" not in st.session_state:
18
  st.session_state.guardrails = []
19
  if "guardrail_names" not in st.session_state:
@@ -121,70 +116,69 @@ def initialize_guardrails():
121
  )
122
 
123
 
124
- initialize_session_state()
125
- st.title(":material/robot: Guardrails Genie Playground")
126
-
127
- weave_project_name = st.sidebar.text_input(
128
- "Weave project name", value=st.session_state.weave_project_name
129
- )
130
- st.session_state.weave_project_name = weave_project_name
131
- if st.session_state.weave_project_name != "":
132
- weave.init(project_name=st.session_state.weave_project_name)
133
-
134
- openai_model = st.sidebar.selectbox(
135
- "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
136
- )
137
- chat_condition = openai_model != ""
138
-
139
- guardrails = []
140
-
141
- guardrail_names = st.sidebar.multiselect(
142
- label="Select Guardrails",
143
- options=[
144
- cls_name
145
- for cls_name, cls_obj in vars(
146
- importlib.import_module("guardrails_genie.guardrails")
147
- ).items()
148
- if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
149
- ],
150
- )
151
- st.session_state.guardrail_names = guardrail_names
152
-
153
- if st.sidebar.button("Initialize Guardrails") and chat_condition:
154
- st.session_state.initialize_guardrails = True
155
-
156
- if st.session_state.initialize_guardrails:
157
- with st.sidebar.status("Initializing Guardrails..."):
158
- initialize_guardrails()
159
- st.session_state.llm_model = OpenAIModel(model_name=openai_model)
160
-
161
- user_prompt = st.text_area("User Prompt", value="")
162
- st.session_state.user_prompt = user_prompt
163
-
164
- test_guardrails_button = st.button("Test Guardrails")
165
- st.session_state.test_guardrails = test_guardrails_button
166
-
167
- if st.session_state.test_guardrails:
168
- with st.sidebar.status("Running Guardrails..."):
169
- guardrails_response, call = st.session_state.guardrails_manager.guard.call(
170
- st.session_state.guardrails_manager, prompt=st.session_state.user_prompt
171
- )
172
 
173
- if guardrails_response["safe"]:
174
- st.markdown(
175
- f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
176
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- with st.sidebar.status("Generating response from LLM..."):
179
- response, call = st.session_state.llm_model.predict.call(
180
- st.session_state.llm_model,
181
- user_prompts=st.session_state.user_prompt,
 
 
 
182
  )
183
- st.markdown(
184
- response.choices[0].message.content
185
- + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
186
- )
187
- else:
188
- st.warning("Prompt is not safe!")
189
- st.markdown(guardrails_response["summary"])
190
- st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import importlib
 
2
 
3
  import streamlit as st
 
4
  from dotenv import load_dotenv
5
 
6
  from guardrails_genie.guardrails import GuardrailManager
 
9
 
10
  def initialize_session_state():
11
  load_dotenv()
 
 
 
12
  if "guardrails" not in st.session_state:
13
  st.session_state.guardrails = []
14
  if "guardrail_names" not in st.session_state:
 
116
  )
117
 
118
 
119
# Chat playground page. Access is gated on the WandB authentication flag set
# by the intro page's login flow.
#
# NOTE(review): `is_authenticated` is created only by intro_page.py's
# initialize_session_state(); if a user navigates straight to this page the
# key does not exist and plain attribute access (`st.session_state.is_authenticated`)
# raises. Use .get() with a False default so the page degrades to the
# "please authenticate" warning instead of crashing.
if st.session_state.get("is_authenticated", False):
    initialize_session_state()
    st.title(":material/robot: Guardrails Genie Playground")

    # Empty string is the "nothing selected" sentinel; chat is disabled until
    # the user picks a concrete model.
    openai_model = st.sidebar.selectbox(
        "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
    )
    chat_condition = openai_model != ""

    # NOTE(review): this local appears unused in this block; kept as-is in
    # case code outside this view references it -- confirm before removing.
    guardrails = []

    # Offer every guardrail class exported by the guardrails package except
    # the manager itself.
    guardrail_names = st.sidebar.multiselect(
        label="Select Guardrails",
        options=[
            cls_name
            for cls_name, cls_obj in vars(
                importlib.import_module("guardrails_genie.guardrails")
            ).items()
            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
        ],
    )
    st.session_state.guardrail_names = guardrail_names

    if st.sidebar.button("Initialize Guardrails") and chat_condition:
        st.session_state.initialize_guardrails = True

    if st.session_state.initialize_guardrails:
        with st.sidebar.status("Initializing Guardrails..."):
            initialize_guardrails()
            st.session_state.llm_model = OpenAIModel(model_name=openai_model)

        user_prompt = st.text_area("User Prompt", value="")
        st.session_state.user_prompt = user_prompt

        test_guardrails_button = st.button("Test Guardrails")
        st.session_state.test_guardrails = test_guardrails_button

        if st.session_state.test_guardrails:
            # Run the guardrail chain via Weave's traced .call() so we get a
            # `call` handle with a UI URL for the trace.
            with st.sidebar.status("Running Guardrails..."):
                guardrails_response, call = (
                    st.session_state.guardrails_manager.guard.call(
                        st.session_state.guardrails_manager,
                        prompt=st.session_state.user_prompt,
                    )
                )

            if guardrails_response["safe"]:
                st.markdown(
                    f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
                )

                # Only generate an LLM response once the prompt passed all
                # guardrails; `call` is rebound to the generation trace here.
                with st.sidebar.status("Generating response from LLM..."):
                    response, call = st.session_state.llm_model.predict.call(
                        st.session_state.llm_model,
                        user_prompts=st.session_state.user_prompt,
                    )
                st.markdown(
                    response.choices[0].message.content
                    + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
                )
            else:
                st.warning("Prompt is not safe!")
                st.markdown(guardrails_response["summary"])
                st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
else:
    st.warning("Please authenticate your WandB account to use this feature.")
application_pages/evaluation_app.py CHANGED
@@ -13,8 +13,6 @@ from guardrails_genie.metrics import AccuracyMetric
13
 
14
  def initialize_session_state():
15
  load_dotenv()
16
- if "weave_project_name" not in st.session_state:
17
- st.session_state.weave_project_name = "guardrails-genie"
18
  if "uploaded_file" not in st.session_state:
19
  st.session_state.uploaded_file = None
20
  if "dataset_name" not in st.session_state:
@@ -124,108 +122,106 @@ def initialize_guardrails():
124
  )
125
 
126
 
127
- initialize_session_state()
128
- st.title(":material/monitoring: Evaluation")
129
-
130
- weave_project_name = st.sidebar.text_input(
131
- "Weave project name", value=st.session_state.weave_project_name
132
- )
133
- st.session_state.weave_project_name = weave_project_name
134
- if st.session_state.weave_project_name != "":
135
- weave.init(project_name=st.session_state.weave_project_name)
136
-
137
- uploaded_file = st.sidebar.file_uploader(
138
- "Upload the evaluation dataset as a CSV file", type="csv"
139
- )
140
- st.session_state.uploaded_file = uploaded_file
141
-
142
- if st.session_state.uploaded_file is not None:
143
- dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
144
- st.session_state.dataset_name = dataset_name
145
- preview_in_app = st.sidebar.toggle("Preview in app", value=False)
146
- st.session_state.preview_in_app = preview_in_app
147
- publish_dataset_button = st.sidebar.button("Publish dataset")
148
- st.session_state.publish_dataset_button = publish_dataset_button
149
-
150
- if st.session_state.publish_dataset_button and (
151
- st.session_state.dataset_name is not None
152
- and st.session_state.dataset_name != ""
153
- ):
154
-
155
- with st.expander("Evaluation Dataset Preview", expanded=True):
156
- dataframe = pd.read_csv(st.session_state.uploaded_file)
157
- data_list = dataframe.to_dict(orient="records")
158
-
159
- dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
160
- st.session_state.dataset_ref = weave.publish(dataset)
161
-
162
- entity = st.session_state.dataset_ref.entity
163
- project = st.session_state.dataset_ref.project
164
- dataset_name = st.session_state.dataset_name
165
- digest = st.session_state.dataset_ref._digest
166
- dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
167
- st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
168
-
169
- if preview_in_app:
170
- st.dataframe(dataframe.head(20))
171
- if len(dataframe) > 20:
172
- st.markdown(
173
- f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
174
- )
175
 
176
- st.session_state.is_dataset_published = True
177
-
178
- if st.session_state.is_dataset_published:
179
- guardrail_names = st.sidebar.multiselect(
180
- "Select Guardrails",
181
- options=[
182
- cls_name
183
- for cls_name, cls_obj in vars(
184
- import_module("guardrails_genie.guardrails")
185
- ).items()
186
- if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
187
- ],
188
- )
189
- st.session_state.guardrail_names = guardrail_names
190
-
191
- initialize_guardrails()
192
- evaluation_name = st.sidebar.text_input("Evaluation Name", value="")
193
- st.session_state.evaluation_name = evaluation_name
194
-
195
- start_evaluations_button = st.sidebar.button("Start Evaluations")
196
- st.session_state.start_evaluations_button = start_evaluations_button
197
- if st.session_state.start_evaluations_button:
198
- # st.write(len(st.session_state.guardrails))
199
- evaluation = weave.Evaluation(
200
- dataset=st.session_state.dataset_ref,
201
- scorers=[AccuracyMetric()],
202
- streamlit_mode=True,
203
- )
204
- with st.expander("Evaluation Results", expanded=True):
205
- evaluation_summary, call = asyncio.run(
206
- evaluation.evaluate.call(
207
- evaluation,
208
- GuardrailManager(guardrails=st.session_state.guardrails),
209
- __weave={
210
- "display_name": (
211
- "Evaluation.evaluate"
212
- if st.session_state.evaluation_name == ""
213
- else "Evaluation.evaluate:"
214
- + st.session_state.evaluation_name
215
- )
216
- },
217
- )
218
- )
219
- x_axis = list(evaluation_summary["AccuracyMetric"].keys())
220
- y_axis = [
221
- evaluation_summary["AccuracyMetric"][x_axis_item]
222
- for x_axis_item in x_axis
223
- ]
224
- st.bar_chart(
225
- pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
226
- x="Metric",
227
- y="Score",
228
  )
229
- st.markdown(
230
- f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def initialize_session_state():
15
  load_dotenv()
 
 
16
  if "uploaded_file" not in st.session_state:
17
  st.session_state.uploaded_file = None
18
  if "dataset_name" not in st.session_state:
 
122
  )
123
 
124
 
125
# Evaluation page. Access is gated on the WandB authentication flag set by
# the intro page's login flow.
#
# NOTE(review): `is_authenticated` is created only by intro_page.py's
# initialize_session_state(); opening this page directly would make plain
# attribute access raise on the missing key. Use .get() with a False default
# so the page shows the authentication warning instead of crashing.
if st.session_state.get("is_authenticated", False):
    initialize_session_state()
    st.title(":material/monitoring: Evaluation")

    uploaded_file = st.sidebar.file_uploader(
        "Upload the evaluation dataset as a CSV file", type="csv"
    )
    st.session_state.uploaded_file = uploaded_file

    if st.session_state.uploaded_file is not None:
        dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
        st.session_state.dataset_name = dataset_name
        preview_in_app = st.sidebar.toggle("Preview in app", value=False)
        st.session_state.preview_in_app = preview_in_app
        publish_dataset_button = st.sidebar.button("Publish dataset")
        st.session_state.publish_dataset_button = publish_dataset_button

        # Publish only when the user supplied a non-empty dataset name.
        if st.session_state.publish_dataset_button and (
            st.session_state.dataset_name is not None
            and st.session_state.dataset_name != ""
        ):

            with st.expander("Evaluation Dataset Preview", expanded=True):
                dataframe = pd.read_csv(st.session_state.uploaded_file)
                data_list = dataframe.to_dict(orient="records")

                dataset = weave.Dataset(
                    name=st.session_state.dataset_name, rows=data_list
                )
                st.session_state.dataset_ref = weave.publish(dataset)

                # Build a deep link to the published dataset version in the
                # Weave UI. `_digest` is a private attribute -- NOTE(review):
                # confirm a public accessor isn't available in the weave SDK.
                entity = st.session_state.dataset_ref.entity
                project = st.session_state.dataset_ref.project
                dataset_name = st.session_state.dataset_name
                digest = st.session_state.dataset_ref._digest
                dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
                st.markdown(f"Dataset published to [**Weave**]({dataset_url})")

                if preview_in_app:
                    st.dataframe(dataframe.head(20))
                    if len(dataframe) > 20:
                        st.markdown(
                            f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
                        )

            st.session_state.is_dataset_published = True

        if st.session_state.is_dataset_published:
            guardrail_names = st.sidebar.multiselect(
                "Select Guardrails",
                options=[
                    cls_name
                    for cls_name, cls_obj in vars(
                        import_module("guardrails_genie.guardrails")
                    ).items()
                    if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
                ],
            )
            st.session_state.guardrail_names = guardrail_names

            initialize_guardrails()
            evaluation_name = st.sidebar.text_input("Evaluation Name", value="")
            st.session_state.evaluation_name = evaluation_name

            start_evaluations_button = st.sidebar.button("Start Evaluations")
            st.session_state.start_evaluations_button = start_evaluations_button
            if st.session_state.start_evaluations_button:
                # st.write(len(st.session_state.guardrails))
                evaluation = weave.Evaluation(
                    dataset=st.session_state.dataset_ref,
                    scorers=[AccuracyMetric()],
                    streamlit_mode=True,
                )
                with st.expander("Evaluation Results", expanded=True):
                    # Traced async evaluation; an optional display name tags
                    # the run in the Weave UI.
                    evaluation_summary, call = asyncio.run(
                        evaluation.evaluate.call(
                            evaluation,
                            GuardrailManager(guardrails=st.session_state.guardrails),
                            __weave={
                                "display_name": (
                                    "Evaluation.evaluate"
                                    if st.session_state.evaluation_name == ""
                                    else "Evaluation.evaluate:"
                                    + st.session_state.evaluation_name
                                )
                            },
                        )
                    )
                    x_axis = list(evaluation_summary["AccuracyMetric"].keys())
                    y_axis = [
                        evaluation_summary["AccuracyMetric"][x_axis_item]
                        for x_axis_item in x_axis
                    ]
                    st.bar_chart(
                        pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
                        x="Metric",
                        y="Score",
                    )
                    st.markdown(
                        f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
                    )
else:
    st.warning("Please authenticate your WandB account to use this feature.")
application_pages/intro_page.py CHANGED
@@ -1,7 +1,54 @@
1
  import streamlit as st
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  st.title("πŸ§žβ€β™‚οΈ Guardrails Genie")
4
 
5
  st.write(
6
  "Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
7
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
+ import wandb
4
+
5
+
6
def initialize_session_state():
    """Seed the Streamlit session with the defaults the login flow relies on.

    Keys that already exist are left untouched, so values survive reruns of
    the script.
    """
    defaults = {
        "weave_project_name": "guardrails-genie",
        "weave_entity_name": "",
        "wandb_api_key": "",
        "authenticate_button": False,
        "is_authenticated": False,
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default
17
+
18
+
19
# Landing page: collects the Weave entity/project coordinates and the user's
# WandB API key, then authenticates via wandb.login. Success flips
# `st.session_state.is_authenticated`, which the other pages check.
initialize_session_state()
st.title("🧞‍♂️ Guardrails Genie")

st.write(
    "Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
)

with st.expander("Login to Your WandB Account", expanded=True):
    st.markdown(
        "Get your Wandb API key from [https://wandb.ai/authorize](https://wandb.ai/authorize)"
    )
    weave_entity_name = st.text_input(
        "Weave Entity Name", value=st.session_state.weave_entity_name
    )
    st.session_state.weave_entity_name = weave_entity_name
    weave_project_name = st.text_input(
        "Weave Project Name", value=st.session_state.weave_project_name
    )
    st.session_state.weave_project_name = weave_project_name
    # type="password" keeps the key masked in the UI; note it still lives in
    # session state in plain text for this session.
    wandb_api_key = st.text_input("Wandb API Key", value="", type="password")
    st.session_state.wandb_api_key = wandb_api_key
    authenticate_button = st.button("Authenticate")
    st.session_state.authenticate_button = authenticate_button

    if authenticate_button and (
        st.session_state.wandb_api_key != ""
        and st.session_state.weave_project_name != ""
    ):
        # wandb.login raises (rather than returning False) on a malformed
        # key; catch it so a bad paste shows an error instead of a stack
        # trace in the app. relogin=True forces re-validation even when a
        # previous login is cached.
        try:
            is_wandb_logged_in = wandb.login(
                key=st.session_state.wandb_api_key, relogin=True
            )
        except Exception:
            is_wandb_logged_in = False
        if is_wandb_logged_in:
            st.session_state.is_authenticated = True
            st.success("Logged in to Wandb")
        else:
            st.error("Failed to log in to Wandb")
    elif authenticate_button:
        # Button pressed with missing inputs: previously this silently did
        # nothing, leaving the user with no feedback.
        st.error("Please provide both a Wandb API key and a Weave project name.")