geekyrakshit committed on
Commit
49dbefc
·
unverified ·
2 Parent(s): 9e04c4b 21eed35

Merge pull request #15 from soumik12345/fix/eval-app

Browse files
README.md CHANGED
@@ -18,11 +18,6 @@ source .venv/bin/activate
18
  ## Run the App
19
 
20
  ```bash
21
- export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
22
- export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
23
- export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
24
- export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
25
- export WANDB_LOG_MODEL="checkpoint"
26
  streamlit run app.py
27
  ```
28
 
 
18
  ## Run the App
19
 
20
  ```bash
 
 
 
 
 
21
  streamlit run app.py
22
  ```
23
 
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
 
3
  intro_page = st.Page(
4
- "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
  "application_pages/chat_app.py",
@@ -13,23 +13,11 @@ evaluation_page = st.Page(
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
16
- # train_classifier_page = st.Page(
17
- # "application_pages/train_classifier.py",
18
- # title="Train Classifier",
19
- # icon=":material/fitness_center:",
20
- # )
21
- # llama_guard_fine_tuning_page = st.Page(
22
- # "application_pages/llama_guard_fine_tuning.py",
23
- # title="Fine-Tune LLama Guard",
24
- # icon=":material/star:",
25
- # )
26
  page_navigation = st.navigation(
27
  [
28
  intro_page,
29
  chat_page,
30
  evaluation_page,
31
- # train_classifier_page,
32
- # llama_guard_fine_tuning_page,
33
  ]
34
  )
35
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
 
1
  import streamlit as st
2
 
3
  intro_page = st.Page(
4
+ "application_pages/intro_page.py", title="Authenticate", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
  "application_pages/chat_app.py",
 
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
 
 
 
 
 
 
 
 
 
 
16
  page_navigation = st.navigation(
17
  [
18
  intro_page,
19
  chat_page,
20
  evaluation_page,
 
 
21
  ]
22
  )
23
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
application_pages/chat_app.py CHANGED
@@ -1,8 +1,6 @@
1
  import importlib
2
- import os
3
 
4
  import streamlit as st
5
- import weave
6
  from dotenv import load_dotenv
7
 
8
  from guardrails_genie.guardrails import GuardrailManager
@@ -11,8 +9,6 @@ from guardrails_genie.llm import OpenAIModel
11
 
12
  def initialize_session_state():
13
  load_dotenv()
14
- weave.init(project_name=os.getenv("WEAVE_PROJECT"))
15
-
16
  if "guardrails" not in st.session_state:
17
  st.session_state.guardrails = []
18
  if "guardrail_names" not in st.session_state:
@@ -53,7 +49,6 @@ def initialize_guardrails():
53
  [
54
  "",
55
  "ProtectAI/deberta-v3-base-prompt-injection-v2",
56
- "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
57
  ],
58
  )
59
  if classifier_model_name != "":
@@ -120,63 +115,69 @@ def initialize_guardrails():
120
  )
121
 
122
 
123
- initialize_session_state()
124
- st.title(":material/robot: Guardrails Genie Playground")
125
-
126
- openai_model = st.sidebar.selectbox(
127
- "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
128
- )
129
- chat_condition = openai_model != ""
130
 
131
- guardrails = []
132
-
133
- guardrail_names = st.sidebar.multiselect(
134
- label="Select Guardrails",
135
- options=[
136
- cls_name
137
- for cls_name, cls_obj in vars(
138
- importlib.import_module("guardrails_genie.guardrails")
139
- ).items()
140
- if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
141
- ],
142
- )
143
- st.session_state.guardrail_names = guardrail_names
 
 
 
 
 
144
 
145
- if st.sidebar.button("Initialize Guardrails") and chat_condition:
146
- st.session_state.initialize_guardrails = True
147
 
148
- if st.session_state.initialize_guardrails:
149
- with st.sidebar.status("Initializing Guardrails..."):
150
- initialize_guardrails()
151
- st.session_state.llm_model = OpenAIModel(model_name=openai_model)
152
 
153
- user_prompt = st.text_area("User Prompt", value="")
154
- st.session_state.user_prompt = user_prompt
155
 
156
- test_guardrails_button = st.button("Test Guardrails")
157
- st.session_state.test_guardrails = test_guardrails_button
158
 
159
- if st.session_state.test_guardrails:
160
- with st.sidebar.status("Running Guardrails..."):
161
- guardrails_response, call = st.session_state.guardrails_manager.guard.call(
162
- st.session_state.guardrails_manager, prompt=st.session_state.user_prompt
163
- )
 
 
 
164
 
165
- if guardrails_response["safe"]:
166
- st.markdown(
167
- f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
168
- )
169
 
170
- with st.sidebar.status("Generating response from LLM..."):
171
- response, call = st.session_state.llm_model.predict.call(
172
- st.session_state.llm_model,
173
- user_prompts=st.session_state.user_prompt,
 
 
 
 
174
  )
175
- st.markdown(
176
- response.choices[0].message.content
177
- + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
178
- )
179
- else:
180
- st.warning("Prompt is not safe!")
181
- st.markdown(guardrails_response["summary"])
182
- st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
 
1
  import importlib
 
2
 
3
  import streamlit as st
 
4
  from dotenv import load_dotenv
5
 
6
  from guardrails_genie.guardrails import GuardrailManager
 
9
 
10
  def initialize_session_state():
11
  load_dotenv()
 
 
12
  if "guardrails" not in st.session_state:
13
  st.session_state.guardrails = []
14
  if "guardrail_names" not in st.session_state:
 
49
  [
50
  "",
51
  "ProtectAI/deberta-v3-base-prompt-injection-v2",
 
52
  ],
53
  )
54
  if classifier_model_name != "":
 
115
  )
116
 
117
 
118
+ if st.session_state.is_authenticated:
119
+ initialize_session_state()
120
+ st.title(":material/robot: Guardrails Genie Playground")
 
 
 
 
121
 
122
+ openai_model = st.sidebar.selectbox(
123
+ "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
124
+ )
125
+ chat_condition = openai_model != ""
126
+
127
+ guardrails = []
128
+
129
+ guardrail_names = st.sidebar.multiselect(
130
+ label="Select Guardrails",
131
+ options=[
132
+ cls_name
133
+ for cls_name, cls_obj in vars(
134
+ importlib.import_module("guardrails_genie.guardrails")
135
+ ).items()
136
+ if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
137
+ ],
138
+ )
139
+ st.session_state.guardrail_names = guardrail_names
140
 
141
+ if st.sidebar.button("Initialize Guardrails") and chat_condition:
142
+ st.session_state.initialize_guardrails = True
143
 
144
+ if st.session_state.initialize_guardrails:
145
+ with st.sidebar.status("Initializing Guardrails..."):
146
+ initialize_guardrails()
147
+ st.session_state.llm_model = OpenAIModel(model_name=openai_model)
148
 
149
+ user_prompt = st.text_area("User Prompt", value="")
150
+ st.session_state.user_prompt = user_prompt
151
 
152
+ test_guardrails_button = st.button("Test Guardrails")
153
+ st.session_state.test_guardrails = test_guardrails_button
154
 
155
+ if st.session_state.test_guardrails:
156
+ with st.sidebar.status("Running Guardrails..."):
157
+ guardrails_response, call = (
158
+ st.session_state.guardrails_manager.guard.call(
159
+ st.session_state.guardrails_manager,
160
+ prompt=st.session_state.user_prompt,
161
+ )
162
+ )
163
 
164
+ if guardrails_response["safe"]:
165
+ st.markdown(
166
+ f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
167
+ )
168
 
169
+ with st.sidebar.status("Generating response from LLM..."):
170
+ response, call = st.session_state.llm_model.predict.call(
171
+ st.session_state.llm_model,
172
+ user_prompts=st.session_state.user_prompt,
173
+ )
174
+ st.markdown(
175
+ response.choices[0].message.content
176
+ + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
177
  )
178
+ else:
179
+ st.warning("Prompt is not safe!")
180
+ st.markdown(guardrails_response["summary"])
181
+ st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
182
+ else:
183
+ st.warning("Please authenticate your WandB account to use this feature.")
 
 
application_pages/evaluation_app.py CHANGED
@@ -1,10 +1,7 @@
1
  import asyncio
2
- import os
3
- import time
4
  from importlib import import_module
5
 
6
  import pandas as pd
7
- import rich
8
  import streamlit as st
9
  import weave
10
  from dotenv import load_dotenv
@@ -12,7 +9,6 @@ from dotenv import load_dotenv
12
  from guardrails_genie.guardrails import GuardrailManager
13
  from guardrails_genie.llm import OpenAIModel
14
  from guardrails_genie.metrics import AccuracyMetric
15
- from guardrails_genie.utils import EvaluationCallManager
16
 
17
 
18
  def initialize_session_state():
@@ -20,48 +16,34 @@ def initialize_session_state():
20
  if "uploaded_file" not in st.session_state:
21
  st.session_state.uploaded_file = None
22
  if "dataset_name" not in st.session_state:
23
- st.session_state.dataset_name = ""
24
  if "preview_in_app" not in st.session_state:
25
  st.session_state.preview_in_app = False
 
 
 
 
26
  if "dataset_ref" not in st.session_state:
27
  st.session_state.dataset_ref = None
28
- if "dataset_previewed" not in st.session_state:
29
- st.session_state.dataset_previewed = False
30
- if "guardrail_names" not in st.session_state:
31
- st.session_state.guardrail_names = []
32
  if "guardrails" not in st.session_state:
33
  st.session_state.guardrails = []
34
- if "start_evaluation" not in st.session_state:
35
- st.session_state.start_evaluation = False
36
- if "evaluation_summary" not in st.session_state:
37
- st.session_state.evaluation_summary = None
38
- if "guardrail_manager" not in st.session_state:
39
- st.session_state.guardrail_manager = None
40
  if "evaluation_name" not in st.session_state:
41
  st.session_state.evaluation_name = ""
42
- if "show_result_table" not in st.session_state:
43
- st.session_state.show_result_table = False
44
- if "weave_client" not in st.session_state:
45
- st.session_state.weave_client = weave.init(
46
- project_name=os.getenv("WEAVE_PROJECT")
47
- )
48
- if "evaluation_call_manager" not in st.session_state:
49
- st.session_state.evaluation_call_manager = None
50
- if "call_id" not in st.session_state:
51
- st.session_state.call_id = None
52
- if "llama_guardrail_checkpoint" not in st.session_state:
53
- st.session_state.llama_guardrail_checkpoint = None
54
-
55
-
56
- def initialize_guardrail():
57
- guardrails = []
58
  for guardrail_name in st.session_state.guardrail_names:
59
  if guardrail_name == "PromptInjectionSurveyGuardrail":
60
  survey_guardrail_model = st.sidebar.selectbox(
61
  "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
62
  )
63
  if survey_guardrail_model:
64
- guardrails.append(
65
  getattr(
66
  import_module("guardrails_genie.guardrails"),
67
  guardrail_name,
@@ -73,29 +55,60 @@ def initialize_guardrail():
73
  [
74
  "",
75
  "ProtectAI/deberta-v3-base-prompt-injection-v2",
76
- "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
77
  ],
78
  )
79
- if classifier_model_name:
80
  st.session_state.guardrails.append(
81
  getattr(
82
  import_module("guardrails_genie.guardrails"),
83
  guardrail_name,
84
  )(model_name=classifier_model_name)
85
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  elif guardrail_name == "PromptInjectionLlamaGuardrail":
87
- llama_guardrail_checkpoint = st.sidebar.text_input(
88
- "Llama Guardrail Checkpoint",
89
- value=None,
90
  )
91
- st.session_state.llama_guardrail_checkpoint = llama_guardrail_checkpoint
92
- if st.session_state.llama_guardrail_checkpoint is not None:
93
- st.session_state.guardrails.append(
94
- getattr(
95
- import_module("guardrails_genie.guardrails"),
96
- guardrail_name,
97
- )(checkpoint=st.session_state.llama_guardrail_checkpoint)
 
 
 
 
98
  )
 
99
  else:
100
  st.session_state.guardrails.append(
101
  getattr(
@@ -103,64 +116,79 @@ def initialize_guardrail():
103
  guardrail_name,
104
  )()
105
  )
106
- st.session_state.guardrails = guardrails
107
- st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
108
-
109
-
110
- initialize_session_state()
111
- st.title(":material/monitoring: Evaluation")
112
-
113
- uploaded_file = st.sidebar.file_uploader(
114
- "Upload the evaluation dataset as a CSV file", type="csv"
115
- )
116
- st.session_state.uploaded_file = uploaded_file
117
- dataset_name = st.sidebar.text_input("Evaluation dataset name", value="")
118
- st.session_state.dataset_name = dataset_name
119
- preview_in_app = st.sidebar.toggle("Preview in app", value=False)
120
- st.session_state.preview_in_app = preview_in_app
121
-
122
- if st.session_state.uploaded_file is not None and st.session_state.dataset_name != "":
123
- with st.expander("Evaluation Dataset Preview", expanded=True):
124
- dataframe = pd.read_csv(st.session_state.uploaded_file)
125
- data_list = dataframe.to_dict(orient="records")
126
-
127
- dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
128
- st.session_state.dataset_ref = weave.publish(dataset)
129
-
130
- entity = st.session_state.dataset_ref.entity
131
- project = st.session_state.dataset_ref.project
132
- dataset_name = st.session_state.dataset_name
133
- digest = st.session_state.dataset_ref._digest
134
- st.markdown(
135
- f"Dataset published to [**Weave**](https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest})"
136
- )
137
-
138
- if preview_in_app:
139
- st.dataframe(dataframe)
140
-
141
- st.session_state.dataset_previewed = True
142
-
143
- if st.session_state.dataset_previewed:
144
- guardrail_names = st.sidebar.multiselect(
145
- "Select Guardrails",
146
- options=[
147
- cls_name
148
- for cls_name, cls_obj in vars(
149
- import_module("guardrails_genie.guardrails")
150
- ).items()
151
- if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
152
- ],
153
  )
154
- st.session_state.guardrail_names = guardrail_names
155
-
156
- if st.session_state.guardrail_names != []:
157
- initialize_guardrail()
158
- evaluation_name = st.sidebar.text_input("Evaluation name", value="")
159
- st.session_state.evaluation_name = evaluation_name
160
- if st.session_state.guardrail_manager is not None:
161
- if st.sidebar.button("Start Evaluation"):
162
- st.session_state.start_evaluation = True
163
- if st.session_state.start_evaluation:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  evaluation = weave.Evaluation(
165
  dataset=st.session_state.dataset_ref,
166
  scorers=[AccuracyMetric()],
@@ -170,10 +198,14 @@ if st.session_state.dataset_previewed:
170
  evaluation_summary, call = asyncio.run(
171
  evaluation.evaluate.call(
172
  evaluation,
173
- st.session_state.guardrail_manager,
174
  __weave={
175
- "display_name": "Evaluation.evaluate:"
176
- + st.session_state.evaluation_name
 
 
 
 
177
  },
178
  )
179
  )
@@ -187,37 +219,8 @@ if st.session_state.dataset_previewed:
187
  x="Metric",
188
  y="Score",
189
  )
190
- st.session_state.evaluation_summary = evaluation_summary
191
- st.session_state.call_id = call.id
192
- st.session_state.start_evaluation = False
193
-
194
- if not st.session_state.start_evaluation:
195
- time.sleep(5)
196
- st.session_state.evaluation_call_manager = (
197
- EvaluationCallManager(
198
- entity="geekyrakshit",
199
- project="guardrails-genie",
200
- call_id=st.session_state.call_id,
201
- )
202
- )
203
- for guardrail_name in st.session_state.guardrail_names:
204
- st.session_state.evaluation_call_manager.call_list.append(
205
- {
206
- "guardrail_name": guardrail_name,
207
- "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
208
- }
209
- )
210
- rich.print(
211
- st.session_state.evaluation_call_manager.call_list
212
- )
213
- st.dataframe(
214
- st.session_state.evaluation_call_manager.render_calls_to_streamlit()
215
- )
216
- if st.session_state.evaluation_call_manager.show_warning_in_app:
217
- st.warning(
218
- f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
219
- )
220
- st.markdown(
221
- f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
222
- )
223
- st.session_state.evaluation_call_manager = None
 
1
  import asyncio
 
 
2
  from importlib import import_module
3
 
4
  import pandas as pd
 
5
  import streamlit as st
6
  import weave
7
  from dotenv import load_dotenv
 
9
  from guardrails_genie.guardrails import GuardrailManager
10
  from guardrails_genie.llm import OpenAIModel
11
  from guardrails_genie.metrics import AccuracyMetric
 
12
 
13
 
14
  def initialize_session_state():
 
16
  if "uploaded_file" not in st.session_state:
17
  st.session_state.uploaded_file = None
18
  if "dataset_name" not in st.session_state:
19
+ st.session_state.dataset_name = None
20
  if "preview_in_app" not in st.session_state:
21
  st.session_state.preview_in_app = False
22
+ if "is_dataset_published" not in st.session_state:
23
+ st.session_state.is_dataset_published = False
24
+ if "publish_dataset_button" not in st.session_state:
25
+ st.session_state.publish_dataset_button = False
26
  if "dataset_ref" not in st.session_state:
27
  st.session_state.dataset_ref = None
 
 
 
 
28
  if "guardrails" not in st.session_state:
29
  st.session_state.guardrails = []
30
+ if "guardrail_names" not in st.session_state:
31
+ st.session_state.guardrail_names = []
32
+ if "start_evaluations_button" not in st.session_state:
33
+ st.session_state.start_evaluations_button = False
 
 
34
  if "evaluation_name" not in st.session_state:
35
  st.session_state.evaluation_name = ""
36
+
37
+
38
+ def initialize_guardrails():
39
+ st.session_state.guardrails = []
 
 
 
 
 
 
 
 
 
 
 
 
40
  for guardrail_name in st.session_state.guardrail_names:
41
  if guardrail_name == "PromptInjectionSurveyGuardrail":
42
  survey_guardrail_model = st.sidebar.selectbox(
43
  "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
44
  )
45
  if survey_guardrail_model:
46
+ st.session_state.guardrails.append(
47
  getattr(
48
  import_module("guardrails_genie.guardrails"),
49
  guardrail_name,
 
55
  [
56
  "",
57
  "ProtectAI/deberta-v3-base-prompt-injection-v2",
 
58
  ],
59
  )
60
+ if classifier_model_name != "":
61
  st.session_state.guardrails.append(
62
  getattr(
63
  import_module("guardrails_genie.guardrails"),
64
  guardrail_name,
65
  )(model_name=classifier_model_name)
66
  )
67
+ elif guardrail_name == "PresidioEntityRecognitionGuardrail":
68
+ st.session_state.guardrails.append(
69
+ getattr(
70
+ import_module("guardrails_genie.guardrails"),
71
+ guardrail_name,
72
+ )(should_anonymize=True)
73
+ )
74
+ elif guardrail_name == "RegexEntityRecognitionGuardrail":
75
+ st.session_state.guardrails.append(
76
+ getattr(
77
+ import_module("guardrails_genie.guardrails"),
78
+ guardrail_name,
79
+ )(should_anonymize=True)
80
+ )
81
+ elif guardrail_name == "TransformersEntityRecognitionGuardrail":
82
+ st.session_state.guardrails.append(
83
+ getattr(
84
+ import_module("guardrails_genie.guardrails"),
85
+ guardrail_name,
86
+ )(should_anonymize=True)
87
+ )
88
+ elif guardrail_name == "RestrictedTermsJudge":
89
+ st.session_state.guardrails.append(
90
+ getattr(
91
+ import_module("guardrails_genie.guardrails"),
92
+ guardrail_name,
93
+ )(should_anonymize=True)
94
+ )
95
  elif guardrail_name == "PromptInjectionLlamaGuardrail":
96
+ llama_guard_checkpoint_name = st.sidebar.text_input(
97
+ "Checkpoint Name", value=""
 
98
  )
99
+ st.session_state.llama_guard_checkpoint_name = llama_guard_checkpoint_name
100
+ st.session_state.guardrails.append(
101
+ getattr(
102
+ import_module("guardrails_genie.guardrails"),
103
+ guardrail_name,
104
+ )(
105
+ checkpoint=(
106
+ None
107
+ if st.session_state.llama_guard_checkpoint_name == ""
108
+ else st.session_state.llama_guard_checkpoint_name
109
+ )
110
  )
111
+ )
112
  else:
113
  st.session_state.guardrails.append(
114
  getattr(
 
116
  guardrail_name,
117
  )()
118
  )
119
+ st.session_state.guardrails_manager = GuardrailManager(
120
+ guardrails=st.session_state.guardrails
121
+ )
122
+
123
+
124
+ if st.session_state.is_authenticated:
125
+ initialize_session_state()
126
+ st.title(":material/monitoring: Evaluation")
127
+
128
+ uploaded_file = st.sidebar.file_uploader(
129
+ "Upload the evaluation dataset as a CSV file", type="csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
+ st.session_state.uploaded_file = uploaded_file
132
+
133
+ if st.session_state.uploaded_file is not None:
134
+ dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
135
+ st.session_state.dataset_name = dataset_name
136
+ preview_in_app = st.sidebar.toggle("Preview in app", value=False)
137
+ st.session_state.preview_in_app = preview_in_app
138
+ publish_dataset_button = st.sidebar.button("Publish dataset")
139
+ st.session_state.publish_dataset_button = publish_dataset_button
140
+
141
+ if st.session_state.publish_dataset_button and (
142
+ st.session_state.dataset_name is not None
143
+ and st.session_state.dataset_name != ""
144
+ ):
145
+
146
+ with st.expander("Evaluation Dataset Preview", expanded=True):
147
+ dataframe = pd.read_csv(st.session_state.uploaded_file)
148
+ data_list = dataframe.to_dict(orient="records")
149
+
150
+ dataset = weave.Dataset(
151
+ name=st.session_state.dataset_name, rows=data_list
152
+ )
153
+ st.session_state.dataset_ref = weave.publish(dataset)
154
+
155
+ entity = st.session_state.dataset_ref.entity
156
+ project = st.session_state.dataset_ref.project
157
+ dataset_name = st.session_state.dataset_name
158
+ digest = st.session_state.dataset_ref._digest
159
+ dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
160
+ st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
161
+
162
+ if preview_in_app:
163
+ st.dataframe(dataframe.head(20))
164
+ if len(dataframe) > 20:
165
+ st.markdown(
166
+ f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
167
+ )
168
+
169
+ st.session_state.is_dataset_published = True
170
+
171
+ if st.session_state.is_dataset_published:
172
+ guardrail_names = st.sidebar.multiselect(
173
+ "Select Guardrails",
174
+ options=[
175
+ cls_name
176
+ for cls_name, cls_obj in vars(
177
+ import_module("guardrails_genie.guardrails")
178
+ ).items()
179
+ if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
180
+ ],
181
+ )
182
+ st.session_state.guardrail_names = guardrail_names
183
+
184
+ initialize_guardrails()
185
+ evaluation_name = st.sidebar.text_input("Evaluation Name", value="")
186
+ st.session_state.evaluation_name = evaluation_name
187
+
188
+ start_evaluations_button = st.sidebar.button("Start Evaluations")
189
+ st.session_state.start_evaluations_button = start_evaluations_button
190
+ if st.session_state.start_evaluations_button:
191
+ # st.write(len(st.session_state.guardrails))
192
  evaluation = weave.Evaluation(
193
  dataset=st.session_state.dataset_ref,
194
  scorers=[AccuracyMetric()],
 
198
  evaluation_summary, call = asyncio.run(
199
  evaluation.evaluate.call(
200
  evaluation,
201
+ GuardrailManager(guardrails=st.session_state.guardrails),
202
  __weave={
203
+ "display_name": (
204
+ "Evaluation.evaluate"
205
+ if st.session_state.evaluation_name == ""
206
+ else "Evaluation.evaluate:"
207
+ + st.session_state.evaluation_name
208
+ )
209
  },
210
  )
211
  )
 
219
  x="Metric",
220
  y="Score",
221
  )
222
+ st.markdown(
223
+ f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
224
+ )
225
+ else:
226
+ st.warning("Please authenticate your WandB account to use this feature.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
application_pages/intro_page.py CHANGED
@@ -1,7 +1,54 @@
 
 
1
  import streamlit as st
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  st.title("🧞‍♂️ Guardrails Genie")
4
 
5
  st.write(
6
  "Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
7
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
  import streamlit as st
4
 
5
+ import wandb
6
+
7
+
8
+ def initialize_session_state():
9
+ if "weave_project_name" not in st.session_state:
10
+ st.session_state.weave_project_name = "guardrails-genie"
11
+ if "weave_entity_name" not in st.session_state:
12
+ st.session_state.weave_entity_name = ""
13
+ if "wandb_api_key" not in st.session_state:
14
+ st.session_state.wandb_api_key = ""
15
+ if "authenticate_button" not in st.session_state:
16
+ st.session_state.authenticate_button = False
17
+ if "is_authenticated" not in st.session_state:
18
+ st.session_state.is_authenticated = False
19
+
20
+
21
+ initialize_session_state()
22
  st.title("🧞‍♂️ Guardrails Genie")
23
 
24
  st.write(
25
  "Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
26
  )
27
+
28
+ st.sidebar.markdown(
29
+ "Get your Wandb API key from [https://wandb.ai/authorize](https://wandb.ai/authorize)"
30
+ )
31
+ weave_entity_name = st.sidebar.text_input(
32
+ "Weave Entity Name", value=st.session_state.weave_entity_name
33
+ )
34
+ st.session_state.weave_entity_name = weave_entity_name
35
+ weave_project_name = st.sidebar.text_input(
36
+ "Weave Project Name", value=st.session_state.weave_project_name
37
+ )
38
+ st.session_state.weave_project_name = weave_project_name
39
+ wandb_api_key = st.sidebar.text_input("Wandb API Key", value="", type="password")
40
+ st.session_state.wandb_api_key = wandb_api_key
41
+ openai_api_key = st.sidebar.text_input("OpenAI API Key", value="", type="password")
42
+ os.environ["OPENAI_API_KEY"] = openai_api_key
43
+ authenticate_button = st.sidebar.button("Authenticate")
44
+ st.session_state.authenticate_button = authenticate_button
45
+
46
+ if authenticate_button and (
47
+ st.session_state.wandb_api_key != "" and st.session_state.weave_project_name != ""
48
+ ):
49
+ is_wandb_logged_in = wandb.login(key=st.session_state.wandb_api_key, relogin=True)
50
+ if is_wandb_logged_in:
51
+ st.session_state.is_authenticated = True
52
+ st.success("Logged in to Wandb")
53
+ else:
54
+ st.error("Failed to log in to Wandb")
guardrails_genie/train/llama_guard.py CHANGED
@@ -3,12 +3,13 @@ import shutil
3
  from glob import glob
4
  from typing import Optional
5
 
 
 
6
  import plotly.graph_objects as go
7
  import streamlit as st
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
- import torch.optim as optim
12
  from datasets import load_dataset
13
  from pydantic import BaseModel
14
  from rich.progress import track
@@ -335,8 +336,8 @@ class LlamaGuardFineTuner:
335
 
336
  def train(
337
  self,
338
- batch_size: int = 32,
339
- lr: float = 5e-6,
340
  num_classes: int = 2,
341
  log_interval: int = 1,
342
  save_interval: int = 50,
@@ -358,7 +359,7 @@ class LlamaGuardFineTuner:
358
 
359
  Args:
360
  batch_size (int, optional): The number of samples per batch during training.
361
- lr (float, optional): The learning rate for the optimizer.
362
  num_classes (int, optional): The number of output classes for the classifier.
363
  log_interval (int, optional): The interval (in batches) at which to log the loss.
364
  save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
@@ -377,7 +378,7 @@ class LlamaGuardFineTuner:
377
  wandb.config.dataset_args = self.dataset_args.model_dump()
378
  wandb.config.model_name = self.model_name
379
  wandb.config.batch_size = batch_size
380
- wandb.config.lr = lr
381
  wandb.config.num_classes = num_classes
382
  wandb.config.log_interval = log_interval
383
  wandb.config.save_interval = save_interval
@@ -387,7 +388,16 @@ class LlamaGuardFineTuner:
387
  self.model.num_labels = num_classes
388
  self.model = self.model.to(self.device)
389
  self.model.train()
390
- optimizer = optim.AdamW(self.model.parameters(), lr=lr)
 
 
 
 
 
 
 
 
 
391
  data_loader = DataLoader(
392
  self.train_dataset,
393
  batch_size=batch_size,
@@ -405,9 +415,14 @@ class LlamaGuardFineTuner:
405
  loss = outputs.loss
406
  optimizer.zero_grad()
407
  loss.backward()
 
 
 
408
  optimizer.step()
 
409
  if (i + 1) % log_interval == 0:
410
  wandb.log({"loss": loss.item()}, step=i + 1)
 
411
  if progress_bar:
412
  progress_percentage = (i + 1) * 100 // len(data_loader)
413
  progress_bar.progress(
 
3
  from glob import glob
4
  from typing import Optional
5
 
6
+ # import torch.optim as optim
7
+ import bitsandbytes.optim as optim
8
  import plotly.graph_objects as go
9
  import streamlit as st
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
 
13
  from datasets import load_dataset
14
  from pydantic import BaseModel
15
  from rich.progress import track
 
336
 
337
  def train(
338
  self,
339
+ batch_size: int = 16,
340
+ starting_lr: float = 1e-7,
341
  num_classes: int = 2,
342
  log_interval: int = 1,
343
  save_interval: int = 50,
 
359
 
360
  Args:
361
  batch_size (int, optional): The number of samples per batch during training.
362
+ starting_lr (float, optional): The starting learning rate for the optimizer.
363
  num_classes (int, optional): The number of output classes for the classifier.
364
  log_interval (int, optional): The interval (in batches) at which to log the loss.
365
  save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
 
378
  wandb.config.dataset_args = self.dataset_args.model_dump()
379
  wandb.config.model_name = self.model_name
380
  wandb.config.batch_size = batch_size
381
+ wandb.config.starting_lr = starting_lr
382
  wandb.config.num_classes = num_classes
383
  wandb.config.log_interval = log_interval
384
  wandb.config.save_interval = save_interval
 
388
  self.model.num_labels = num_classes
389
  self.model = self.model.to(self.device)
390
  self.model.train()
391
+ # optimizer = optim.AdamW(self.model.parameters(), lr=starting_lr)
392
+ optimizer = optim.Lion(
393
+ self.model.parameters(), lr=starting_lr, weight_decay=0.01
394
+ )
395
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
396
+ optimizer,
397
+ max_lr=starting_lr,
398
+ steps_per_epoch=len(self.train_dataset) // batch_size + 1,
399
+ epochs=1,
400
+ )
401
  data_loader = DataLoader(
402
  self.train_dataset,
403
  batch_size=batch_size,
 
415
  loss = outputs.loss
416
  optimizer.zero_grad()
417
  loss.backward()
418
+
419
+ # torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
420
+
421
  optimizer.step()
422
+ scheduler.step()
423
  if (i + 1) % log_interval == 0:
424
  wandb.log({"loss": loss.item()}, step=i + 1)
425
+ wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, step=i + 1)
426
  if progress_bar:
427
  progress_percentage = (i + 1) * 100 // len(data_loader)
428
  progress_bar.progress(