geekyrakshit committed
Merge pull request #15 from soumik12345/fix/eval-app

Files changed:
- README.md +0 -5
- app.py +1 -13
- application_pages/chat_app.py +57 -56
- application_pages/evaluation_app.py +141 -138
- application_pages/intro_page.py +47 -0
- guardrails_genie/train/llama_guard.py +21 -6
README.md
CHANGED
@@ -18,11 +18,6 @@ source .venv/bin/activate
 
 ## Run the App
 
 ```bash
-export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
-export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
-export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
-export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
-export WANDB_LOG_MODEL="checkpoint"
 streamlit run app.py
 ```
 
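Those exports are no longer required up front because the app now collects credentials in the UI at runtime (see application_pages/intro_page.py below). A minimal sketch of that pattern, not the app's exact code:

```python
import os

import streamlit as st

# Ask for the key in the sidebar instead of requiring
# `export OPENAI_API_KEY=...` before `streamlit run app.py`.
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
if openai_api_key:
    # Downstream clients (e.g. the OpenAI SDK) read this variable.
    os.environ["OPENAI_API_KEY"] = openai_api_key
```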
app.py
CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
 
 intro_page = st.Page(
-    "application_pages/intro_page.py", title="
+    "application_pages/intro_page.py", title="Authenticate", icon=":material/guardian:"
 )
 chat_page = st.Page(
     "application_pages/chat_app.py",
@@ -13,23 +13,11 @@ evaluation_page = st.Page(
     title="Evaluation",
     icon=":material/monitoring:",
 )
-# train_classifier_page = st.Page(
-#     "application_pages/train_classifier.py",
-#     title="Train Classifier",
-#     icon=":material/fitness_center:",
-# )
-# llama_guard_fine_tuning_page = st.Page(
-#     "application_pages/llama_guard_fine_tuning.py",
-#     title="Fine-Tune LLama Guard",
-#     icon=":material/star:",
-# )
 page_navigation = st.navigation(
     [
         intro_page,
         chat_page,
         evaluation_page,
-        # train_classifier_page,
-        # llama_guard_fine_tuning_page,
     ]
 )
 st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
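For context, app.py relies on Streamlit's multipage API: each st.Page wraps a page script, and st.navigation returns whichever page the user selected. A self-contained sketch with placeholder paths and titles:

```python
import streamlit as st

# Each st.Page wraps a script; st.navigation renders the page picker
# and returns the page the user selected.
home_page = st.Page("pages/home.py", title="Home", icon=":material/home:")
about_page = st.Page("pages/about.py", title="About", icon=":material/info:")

page_navigation = st.navigation([home_page, about_page])
st.set_page_config(page_title="Demo App")
page_navigation.run()  # executes the selected page's script
```

The run() call is outside the hunks shown above, but st.navigation pages only render once it is invoked.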
application_pages/chat_app.py
CHANGED
@@ -1,8 +1,6 @@
 import importlib
-import os
 
 import streamlit as st
-import weave
 from dotenv import load_dotenv
 
 from guardrails_genie.guardrails import GuardrailManager
@@ -11,8 +9,6 @@ from guardrails_genie.llm import OpenAIModel
 
 def initialize_session_state():
     load_dotenv()
-    weave.init(project_name=os.getenv("WEAVE_PROJECT"))
-
     if "guardrails" not in st.session_state:
         st.session_state.guardrails = []
     if "guardrail_names" not in st.session_state:
@@ -53,7 +49,6 @@ def initialize_guardrails():
                 [
                     "",
                     "ProtectAI/deberta-v3-base-prompt-injection-v2",
-                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
                 ],
             )
             if classifier_model_name != "":
@@ -120,63 +115,69 @@ def initialize_guardrails():
 )
 
 
-[… removed lines truncated in the source rendering …]
-openai_model = st.sidebar.selectbox(
-    "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
-)
-chat_condition = openai_model != ""
-[… removed lines truncated in the source rendering …]
-if st.sidebar.button("Initialize Guardrails") and chat_condition:
-[… removed lines truncated in the source rendering …]
-if st.session_state.initialize_guardrails:
-[… removed lines truncated in the source rendering …]
-st.markdown(guardrails_response["summary"])
-st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
+if st.session_state.is_authenticated:
+    initialize_session_state()
+    st.title(":material/robot: Guardrails Genie Playground")
+
+    openai_model = st.sidebar.selectbox(
+        "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
+    )
+    chat_condition = openai_model != ""
+
+    guardrails = []
+
+    guardrail_names = st.sidebar.multiselect(
+        label="Select Guardrails",
+        options=[
+            cls_name
+            for cls_name, cls_obj in vars(
+                importlib.import_module("guardrails_genie.guardrails")
+            ).items()
+            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
+        ],
+    )
+    st.session_state.guardrail_names = guardrail_names
+
+    if st.sidebar.button("Initialize Guardrails") and chat_condition:
+        st.session_state.initialize_guardrails = True
+
+    if st.session_state.initialize_guardrails:
+        with st.sidebar.status("Initializing Guardrails..."):
+            initialize_guardrails()
+            st.session_state.llm_model = OpenAIModel(model_name=openai_model)
+
+        user_prompt = st.text_area("User Prompt", value="")
+        st.session_state.user_prompt = user_prompt
+
+        test_guardrails_button = st.button("Test Guardrails")
+        st.session_state.test_guardrails = test_guardrails_button
+
+        if st.session_state.test_guardrails:
+            with st.sidebar.status("Running Guardrails..."):
+                guardrails_response, call = (
+                    st.session_state.guardrails_manager.guard.call(
+                        st.session_state.guardrails_manager,
+                        prompt=st.session_state.user_prompt,
+                    )
+                )
+
+            if guardrails_response["safe"]:
+                st.markdown(
+                    f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
+                )
+
+                with st.sidebar.status("Generating response from LLM..."):
+                    response, call = st.session_state.llm_model.predict.call(
+                        st.session_state.llm_model,
+                        user_prompts=st.session_state.user_prompt,
+                    )
+                st.markdown(
+                    response.choices[0].message.content
+                    + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
                 )
+            else:
+                st.warning("Prompt is not safe!")
+                st.markdown(guardrails_response["summary"])
+                st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
+else:
+    st.warning("Please authenticate your WandB account to use this feature.")
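The guard.call(...) and predict.call(...) invocations above use Weave's op-calling convention: calling .call() on a weave.op returns both the op's output and a Call object whose ui_url links to the captured trace. A minimal sketch of the convention with a toy op (the project name and guardrail logic are illustrative):

```python
import weave

weave.init(project_name="my-project")  # illustrative project name


@weave.op()
def check_prompt(prompt: str) -> dict:
    # Toy guardrail: flag prompts that try to override instructions.
    return {"safe": "ignore previous" not in prompt.lower()}


# .call() returns the result *and* the Call object for the trace,
# which is how the app builds its "Explore ... on Weave" links.
result, call = check_prompt.call("What is the capital of France?")
print(result["safe"], call.ui_url)
```

For ops defined as methods (as with GuardrailManager.guard and OpenAIModel.predict above), the instance is passed explicitly as the first argument to .call().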
application_pages/evaluation_app.py
CHANGED
@@ -1,10 +1,7 @@
 import asyncio
-import os
-import time
 from importlib import import_module
 
 import pandas as pd
-import rich
 import streamlit as st
 import weave
 from dotenv import load_dotenv
@@ -12,7 +9,6 @@ from dotenv import load_dotenv
 from guardrails_genie.guardrails import GuardrailManager
 from guardrails_genie.llm import OpenAIModel
 from guardrails_genie.metrics import AccuracyMetric
-from guardrails_genie.utils import EvaluationCallManager
 
 
 def initialize_session_state():
@@ -20,48 +16,34 @@ def initialize_session_state():
     if "uploaded_file" not in st.session_state:
         st.session_state.uploaded_file = None
     if "dataset_name" not in st.session_state:
-        st.session_state.dataset_name =
+        st.session_state.dataset_name = None
     if "preview_in_app" not in st.session_state:
         st.session_state.preview_in_app = False
+    if "is_dataset_published" not in st.session_state:
+        st.session_state.is_dataset_published = False
+    if "publish_dataset_button" not in st.session_state:
+        st.session_state.publish_dataset_button = False
     if "dataset_ref" not in st.session_state:
         st.session_state.dataset_ref = None
-    if "dataset_previewed" not in st.session_state:
-        st.session_state.dataset_previewed = False
-    if "guardrail_names" not in st.session_state:
-        st.session_state.guardrail_names = []
     if "guardrails" not in st.session_state:
         st.session_state.guardrails = []
-    if "
-        st.session_state.
-    if "
-        st.session_state.
-    if "guardrail_manager" not in st.session_state:
-        st.session_state.guardrail_manager = None
+    if "guardrail_names" not in st.session_state:
+        st.session_state.guardrail_names = []
+    if "start_evaluations_button" not in st.session_state:
+        st.session_state.start_evaluations_button = False
     if "evaluation_name" not in st.session_state:
         st.session_state.evaluation_name = ""
-[… removed lines truncated in the source rendering …]
-        project_name=os.getenv("WEAVE_PROJECT")
-    )
-    if "evaluation_call_manager" not in st.session_state:
-        st.session_state.evaluation_call_manager = None
-    if "call_id" not in st.session_state:
-        st.session_state.call_id = None
-    if "llama_guardrail_checkpoint" not in st.session_state:
-        st.session_state.llama_guardrail_checkpoint = None
 
 
-def initialize_guardrail():
-    guardrails = []
+def initialize_guardrails():
+    st.session_state.guardrails = []
     for guardrail_name in st.session_state.guardrail_names:
         if guardrail_name == "PromptInjectionSurveyGuardrail":
             survey_guardrail_model = st.sidebar.selectbox(
                 "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
             )
             if survey_guardrail_model:
-                guardrails.append(
+                st.session_state.guardrails.append(
                     getattr(
                         import_module("guardrails_genie.guardrails"),
                         guardrail_name,
@@ -73,29 +55,60 @@ def initialize_guardrail():
                 [
                     "",
                     "ProtectAI/deberta-v3-base-prompt-injection-v2",
-                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
                 ],
             )
-            if classifier_model_name:
+            if classifier_model_name != "":
                 st.session_state.guardrails.append(
                     getattr(
                         import_module("guardrails_genie.guardrails"),
                         guardrail_name,
                     )(model_name=classifier_model_name)
                 )
+        elif guardrail_name == "PresidioEntityRecognitionGuardrail":
+            st.session_state.guardrails.append(
+                getattr(
+                    import_module("guardrails_genie.guardrails"),
+                    guardrail_name,
+                )(should_anonymize=True)
+            )
+        elif guardrail_name == "RegexEntityRecognitionGuardrail":
+            st.session_state.guardrails.append(
+                getattr(
+                    import_module("guardrails_genie.guardrails"),
+                    guardrail_name,
+                )(should_anonymize=True)
+            )
+        elif guardrail_name == "TransformersEntityRecognitionGuardrail":
+            st.session_state.guardrails.append(
+                getattr(
+                    import_module("guardrails_genie.guardrails"),
+                    guardrail_name,
+                )(should_anonymize=True)
+            )
+        elif guardrail_name == "RestrictedTermsJudge":
+            st.session_state.guardrails.append(
+                getattr(
+                    import_module("guardrails_genie.guardrails"),
+                    guardrail_name,
+                )(should_anonymize=True)
+            )
         elif guardrail_name == "PromptInjectionLlamaGuardrail":
-            [… removed lines truncated in the source rendering …]
-                value=None,
-            st.session_state.
+            llama_guard_checkpoint_name = st.sidebar.text_input(
+                "Checkpoint Name", value=""
+            )
+            st.session_state.llama_guard_checkpoint_name = llama_guard_checkpoint_name
+            st.session_state.guardrails.append(
+                getattr(
+                    import_module("guardrails_genie.guardrails"),
+                    guardrail_name,
+                )(
+                    checkpoint=(
+                        None
+                        if st.session_state.llama_guard_checkpoint_name == ""
+                        else st.session_state.llama_guard_checkpoint_name
+                    )
+                )
+            )
         else:
             st.session_state.guardrails.append(
                 getattr(
@@ -103,64 +116,79 @@ def initialize_guardrail():
                     guardrail_name,
                 )()
             )
-    st.session_state.
-[… removed lines truncated in the source rendering …]
-    dataset_name = st.sidebar.text_input("Evaluation dataset name", value="")
-    st.session_state.dataset_name = dataset_name
-    preview_in_app = st.sidebar.toggle("Preview in app", value=False)
-    st.session_state.preview_in_app = preview_in_app
-
-    if st.session_state.uploaded_file is not None and st.session_state.dataset_name != "":
-        with st.expander("Evaluation Dataset Preview", expanded=True):
-            dataframe = pd.read_csv(st.session_state.uploaded_file)
-            data_list = dataframe.to_dict(orient="records")
-
-            dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
-            st.session_state.dataset_ref = weave.publish(dataset)
-
-            entity = st.session_state.dataset_ref.entity
-            project = st.session_state.dataset_ref.project
-            dataset_name = st.session_state.dataset_name
-            digest = st.session_state.dataset_ref._digest
-            st.markdown(
-                f"Dataset published to [**Weave**](https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest})"
-            )
-
-            if preview_in_app:
-                st.dataframe(dataframe)
-
-        st.session_state.dataset_previewed = True
-
-if st.session_state.dataset_previewed:
-    guardrail_names = st.sidebar.multiselect(
-        "Select Guardrails",
-        options=[
-            cls_name
-            for cls_name, cls_obj in vars(
-                import_module("guardrails_genie.guardrails")
-            ).items()
-            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
-        ],
-    )
-    st.session_state.
-[… removed lines truncated in the source rendering …]
+    st.session_state.guardrails_manager = GuardrailManager(
+        guardrails=st.session_state.guardrails
+    )
+
+
+if st.session_state.is_authenticated:
+    initialize_session_state()
+    st.title(":material/monitoring: Evaluation")
+
+    uploaded_file = st.sidebar.file_uploader(
+        "Upload the evaluation dataset as a CSV file", type="csv"
+    )
+    st.session_state.uploaded_file = uploaded_file
+
+    if st.session_state.uploaded_file is not None:
+        dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
+        st.session_state.dataset_name = dataset_name
+        preview_in_app = st.sidebar.toggle("Preview in app", value=False)
+        st.session_state.preview_in_app = preview_in_app
+        publish_dataset_button = st.sidebar.button("Publish dataset")
+        st.session_state.publish_dataset_button = publish_dataset_button
+
+        if st.session_state.publish_dataset_button and (
+            st.session_state.dataset_name is not None
+            and st.session_state.dataset_name != ""
+        ):
+            with st.expander("Evaluation Dataset Preview", expanded=True):
+                dataframe = pd.read_csv(st.session_state.uploaded_file)
+                data_list = dataframe.to_dict(orient="records")
+
+                dataset = weave.Dataset(
+                    name=st.session_state.dataset_name, rows=data_list
+                )
+                st.session_state.dataset_ref = weave.publish(dataset)
+
+                entity = st.session_state.dataset_ref.entity
+                project = st.session_state.dataset_ref.project
+                dataset_name = st.session_state.dataset_name
+                digest = st.session_state.dataset_ref._digest
+                dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
+                st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
+
+                if preview_in_app:
+                    st.dataframe(dataframe.head(20))
+                    if len(dataframe) > 20:
+                        st.markdown(
+                            f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
+                        )
+
+            st.session_state.is_dataset_published = True
+
+    if st.session_state.is_dataset_published:
+        guardrail_names = st.sidebar.multiselect(
+            "Select Guardrails",
+            options=[
+                cls_name
+                for cls_name, cls_obj in vars(
+                    import_module("guardrails_genie.guardrails")
+                ).items()
+                if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
+            ],
+        )
+        st.session_state.guardrail_names = guardrail_names
+
+        initialize_guardrails()
+        evaluation_name = st.sidebar.text_input("Evaluation Name", value="")
+        st.session_state.evaluation_name = evaluation_name
+
+        start_evaluations_button = st.sidebar.button("Start Evaluations")
+        st.session_state.start_evaluations_button = start_evaluations_button
+        if st.session_state.start_evaluations_button:
+            # st.write(len(st.session_state.guardrails))
             evaluation = weave.Evaluation(
                 dataset=st.session_state.dataset_ref,
                 scorers=[AccuracyMetric()],
@@ -170,10 +198,14 @@ if st.session_state.dataset_previewed:
             evaluation_summary, call = asyncio.run(
                 evaluation.evaluate.call(
                     evaluation,
-                    st.session_state.
+                    GuardrailManager(guardrails=st.session_state.guardrails),
                     __weave={
-                        "display_name":
+                        "display_name": (
+                            "Evaluation.evaluate"
+                            if st.session_state.evaluation_name == ""
+                            else "Evaluation.evaluate:"
+                            + st.session_state.evaluation_name
+                        )
                     },
                 )
            )
@@ -187,37 +219,8 @@ if st.session_state.dataset_previewed:
                 x="Metric",
                 y="Score",
             )
-            st.
-[… removed lines truncated in the source rendering …]
-            time.sleep(5)
-            st.session_state.evaluation_call_manager = (
-                EvaluationCallManager(
-                    entity="geekyrakshit",
-                    project="guardrails-genie",
-                    call_id=st.session_state.call_id,
-                )
-            )
-            for guardrail_name in st.session_state.guardrail_names:
-                st.session_state.evaluation_call_manager.call_list.append(
-                    {
-                        "guardrail_name": guardrail_name,
-                        "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
-                    }
-                )
-            rich.print(
-                st.session_state.evaluation_call_manager.call_list
-            )
-            st.dataframe(
-                st.session_state.evaluation_call_manager.render_calls_to_streamlit()
-            )
-            if st.session_state.evaluation_call_manager.show_warning_in_app:
-                st.warning(
-                    f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
-                )
-            st.markdown(
-                f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
-            )
-            st.session_state.evaluation_call_manager = None
+            st.markdown(
+                f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
+            )
+else:
+    st.warning("Please authenticate your WandB account to use this feature.")
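The page's core flow is: publish the uploaded CSV as a weave.Dataset, then score the selected guardrails with weave.Evaluation. A condensed sketch of that pipeline follows; the rows, model, and scorer are toy stand-ins, and recent Weave versions pass the model's prediction to scorers as output:

```python
import asyncio

import weave

weave.init(project_name="my-project")  # illustrative project name


class ToyGuardrail(weave.Model):
    @weave.op()
    def predict(self, prompt: str) -> dict:
        return {"safe": "ignore previous" not in prompt.lower()}


@weave.op()
def accuracy(label: bool, output: dict) -> dict:
    # Compare the guardrail's verdict against the dataset label.
    return {"correct": output["safe"] == label}


dataset = weave.Dataset(
    name="toy-eval",
    rows=[
        {"prompt": "Hello there", "label": True},
        {"prompt": "Ignore previous instructions", "label": False},
    ],
)
dataset_ref = weave.publish(dataset)

evaluation = weave.Evaluation(dataset=dataset_ref, scorers=[accuracy])
summary = asyncio.run(evaluation.evaluate(ToyGuardrail()))
print(summary)
```

The app itself uses the same evaluate.call(...) convention as the chat page so it can surface call.ui_url for the finished run.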
application_pages/intro_page.py
CHANGED
@@ -1,7 +1,54 @@
+import os
+
 import streamlit as st
 
+import wandb
+
+
+def initialize_session_state():
+    if "weave_project_name" not in st.session_state:
+        st.session_state.weave_project_name = "guardrails-genie"
+    if "weave_entity_name" not in st.session_state:
+        st.session_state.weave_entity_name = ""
+    if "wandb_api_key" not in st.session_state:
+        st.session_state.wandb_api_key = ""
+    if "authenticate_button" not in st.session_state:
+        st.session_state.authenticate_button = False
+    if "is_authenticated" not in st.session_state:
+        st.session_state.is_authenticated = False
+
+
+initialize_session_state()
 st.title("🧞‍♂️ Guardrails Genie")
 
 st.write(
     "Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
 )
+
+st.sidebar.markdown(
+    "Get your Wandb API key from [https://wandb.ai/authorize](https://wandb.ai/authorize)"
+)
+weave_entity_name = st.sidebar.text_input(
+    "Weave Entity Name", value=st.session_state.weave_entity_name
+)
+st.session_state.weave_entity_name = weave_entity_name
+weave_project_name = st.sidebar.text_input(
+    "Weave Project Name", value=st.session_state.weave_project_name
+)
+st.session_state.weave_project_name = weave_project_name
+wandb_api_key = st.sidebar.text_input("Wandb API Key", value="", type="password")
+st.session_state.wandb_api_key = wandb_api_key
+openai_api_key = st.sidebar.text_input("OpenAI API Key", value="", type="password")
+os.environ["OPENAI_API_KEY"] = openai_api_key
+authenticate_button = st.sidebar.button("Authenticate")
+st.session_state.authenticate_button = authenticate_button
+
+if authenticate_button and (
+    st.session_state.wandb_api_key != "" and st.session_state.weave_project_name != ""
+):
+    is_wandb_logged_in = wandb.login(key=st.session_state.wandb_api_key, relogin=True)
+    if is_wandb_logged_in:
+        st.session_state.is_authenticated = True
+        st.success("Logged in to Wandb")
+    else:
+        st.error("Failed to log in to Wandb")
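Every other page keys off the is_authenticated flag this page sets after wandb.login() succeeds. Reduced to its core, the gating pattern looks like this (the page body is a placeholder):

```python
import streamlit as st

# Set by the intro/authentication page once wandb.login() succeeds.
if "is_authenticated" not in st.session_state:
    st.session_state.is_authenticated = False

if st.session_state.is_authenticated:
    st.title("Protected Feature")
    # ... feature UI goes here ...
else:
    st.warning("Please authenticate your WandB account to use this feature.")
```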
guardrails_genie/train/llama_guard.py
CHANGED
@@ -3,12 +3,13 @@ import shutil
 from glob import glob
 from typing import Optional
 
+# import torch.optim as optim
+import bitsandbytes.optim as optim
 import plotly.graph_objects as go
 import streamlit as st
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.optim as optim
 from datasets import load_dataset
 from pydantic import BaseModel
 from rich.progress import track
@@ -335,8 +336,8 @@ class LlamaGuardFineTuner:
 
     def train(
         self,
-        batch_size: int =
-[… removed line truncated in the source rendering …]
+        batch_size: int = 16,
+        starting_lr: float = 1e-7,
         num_classes: int = 2,
         log_interval: int = 1,
         save_interval: int = 50,
@@ -358,7 +359,7 @@ class LlamaGuardFineTuner:
 
         Args:
             batch_size (int, optional): The number of samples per batch during training.
-[… removed line truncated in the source rendering …]
+            starting_lr (float, optional): The starting learning rate for the optimizer.
             num_classes (int, optional): The number of output classes for the classifier.
             log_interval (int, optional): The interval (in batches) at which to log the loss.
             save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
@@ -377,7 +378,7 @@ class LlamaGuardFineTuner:
         wandb.config.dataset_args = self.dataset_args.model_dump()
         wandb.config.model_name = self.model_name
         wandb.config.batch_size = batch_size
-        wandb.config.
+        wandb.config.starting_lr = starting_lr
         wandb.config.num_classes = num_classes
         wandb.config.log_interval = log_interval
         wandb.config.save_interval = save_interval
@@ -387,7 +388,16 @@ class LlamaGuardFineTuner:
         self.model.num_labels = num_classes
         self.model = self.model.to(self.device)
         self.model.train()
-        optimizer = optim.AdamW(self.model.parameters(), lr=
+        # optimizer = optim.AdamW(self.model.parameters(), lr=starting_lr)
+        optimizer = optim.Lion(
+            self.model.parameters(), lr=starting_lr, weight_decay=0.01
+        )
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            optimizer,
+            max_lr=starting_lr,
+            steps_per_epoch=len(self.train_dataset) // batch_size + 1,
+            epochs=1,
+        )
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=batch_size,
@@ -405,9 +415,14 @@ class LlamaGuardFineTuner:
                 loss = outputs.loss
                 optimizer.zero_grad()
                 loss.backward()
+
+                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
+
                 optimizer.step()
+                scheduler.step()
                 if (i + 1) % log_interval == 0:
                     wandb.log({"loss": loss.item()}, step=i + 1)
+                    wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, step=i + 1)
                 if progress_bar:
                     progress_percentage = (i + 1) * 100 // len(data_loader)
                     progress_bar.progress(
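The training change swaps torch's AdamW for the Lion optimizer from bitsandbytes and drives it with a one-cycle learning-rate schedule, stepping the scheduler once per batch. Isolated from the fine-tuner, the interaction looks roughly like this (the linear model and random data are stand-ins, and bitsandbytes must be installed):

```python
import bitsandbytes.optim as optim
import torch

# Stand-in model and data; the real code fine-tunes a Llama Guard classifier.
model = torch.nn.Linear(10, 2)
inputs = torch.randn(64, 10)
targets = torch.randint(0, 2, (64,))
loss_fn = torch.nn.CrossEntropyLoss()

optimizer = optim.Lion(model.parameters(), lr=1e-7, weight_decay=0.01)
# One warmup-then-anneal cycle over a single epoch; because the schedule
# is per batch, scheduler.step() sits right after optimizer.step().
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-7, steps_per_epoch=4, epochs=1
)

for x, y in zip(inputs.split(16), targets.split(16)):
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f"loss={loss.item():.4f} lr={scheduler.get_last_lr()[0]:.2e}")
```

Note that OneCycleLR treats max_lr as the peak of the cycle, so passing starting_lr as max_lr (as the diff does) means the run warms up toward the configured value and then anneals below it, rather than starting at it.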