geekyrakshit committed
Commit ec05364
Parent(s): 9e04c4b
update: eval ui
application_pages/evaluation_app.py
CHANGED
@@ -1,6 +1,5 @@
 import asyncio
 import os
-import time
 from importlib import import_module

 import pandas as pd
@@ -12,212 +11,81 @@ from dotenv import load_dotenv
 from guardrails_genie.guardrails import GuardrailManager
 from guardrails_genie.llm import OpenAIModel
 from guardrails_genie.metrics import AccuracyMetric
-from guardrails_genie.utils import EvaluationCallManager


 def initialize_session_state():
     load_dotenv()
+    if "weave_project_name" not in st.session_state:
+        st.session_state.weave_project_name = "guardrails-genie"
     if "uploaded_file" not in st.session_state:
         st.session_state.uploaded_file = None
     if "dataset_name" not in st.session_state:
-        st.session_state.dataset_name = ""
+        st.session_state.dataset_name = None
     if "preview_in_app" not in st.session_state:
         st.session_state.preview_in_app = False
+    if "is_dataset_published" not in st.session_state:
+        st.session_state.is_dataset_published = False
+    if "publish_dataset_button" not in st.session_state:
+        st.session_state.publish_dataset_button = False
     if "dataset_ref" not in st.session_state:
         st.session_state.dataset_ref = None
-    if "dataset_previewed" not in st.session_state:
-        st.session_state.dataset_previewed = False
-    if "guardrail_names" not in st.session_state:
-        st.session_state.guardrail_names = []
-    if "guardrails" not in st.session_state:
-        st.session_state.guardrails = []
-    if "start_evaluation" not in st.session_state:
-        st.session_state.start_evaluation = False
-    if "evaluation_summary" not in st.session_state:
-        st.session_state.evaluation_summary = None
-    if "guardrail_manager" not in st.session_state:
-        st.session_state.guardrail_manager = None
-    if "evaluation_name" not in st.session_state:
-        st.session_state.evaluation_name = ""
-    if "show_result_table" not in st.session_state:
-        st.session_state.show_result_table = False
-    if "weave_client" not in st.session_state:
-        st.session_state.weave_client = weave.init(
-            project_name=os.getenv("WEAVE_PROJECT")
-        )
-    if "evaluation_call_manager" not in st.session_state:
-        st.session_state.evaluation_call_manager = None
-    if "call_id" not in st.session_state:
-        st.session_state.call_id = None
-    if "llama_guardrail_checkpoint" not in st.session_state:
-        st.session_state.llama_guardrail_checkpoint = None
-
-
-def initialize_guardrail():
-    guardrails = []
-    for guardrail_name in st.session_state.guardrail_names:
-        if guardrail_name == "PromptInjectionSurveyGuardrail":
-            survey_guardrail_model = st.sidebar.selectbox(
-                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
-            )
-            if survey_guardrail_model:
-                guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
-                )
-        elif guardrail_name == "PromptInjectionClassifierGuardrail":
-            classifier_model_name = st.sidebar.selectbox(
-                "Classifier Guardrail Model",
-                [
-                    "",
-                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
-                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
-                ],
-            )
-            if classifier_model_name:
-                st.session_state.guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(model_name=classifier_model_name)
-                )
-        elif guardrail_name == "PromptInjectionLlamaGuardrail":
-            llama_guardrail_checkpoint = st.sidebar.text_input(
-                "Llama Guardrail Checkpoint",
-                value=None,
-            )
-            st.session_state.llama_guardrail_checkpoint = llama_guardrail_checkpoint
-            if st.session_state.llama_guardrail_checkpoint is not None:
-                st.session_state.guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(checkpoint=st.session_state.llama_guardrail_checkpoint)
-                )
-        else:
-            st.session_state.guardrails.append(
-                getattr(
-                    import_module("guardrails_genie.guardrails"),
-                    guardrail_name,
-                )()
-            )
-    st.session_state.guardrails = guardrails
-    st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)


 initialize_session_state()
 st.title(":material/monitoring: Evaluation")

+weave_project_name = st.sidebar.text_input(
+    "Weave project name", value=st.session_state.weave_project_name
+)
+st.session_state.weave_project_name = weave_project_name
+if st.session_state.weave_project_name != "":
+    weave.init(project_name=st.session_state.weave_project_name)
+
 uploaded_file = st.sidebar.file_uploader(
     "Upload the evaluation dataset as a CSV file", type="csv"
 )
 st.session_state.uploaded_file = uploaded_file
-dataset_name = st.sidebar.text_input("Evaluation dataset name", value="")
-st.session_state.dataset_name = dataset_name
-preview_in_app = st.sidebar.toggle("Preview in app", value=False)
-st.session_state.preview_in_app = preview_in_app
-
-if st.session_state.uploaded_file is not None and st.session_state.dataset_name != "":
-    with st.expander("Evaluation Dataset Preview", expanded=True):
-        dataframe = pd.read_csv(st.session_state.uploaded_file)
-        data_list = dataframe.to_dict(orient="records")

+if st.session_state.uploaded_file is not None:
+    dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
+    st.session_state.dataset_name = dataset_name
+    preview_in_app = st.sidebar.toggle("Preview in app", value=False)
+    st.session_state.preview_in_app = preview_in_app
+    publish_dataset_button = st.sidebar.button("Publish dataset")
+    st.session_state.publish_dataset_button = publish_dataset_button
+
+    if (
+        st.session_state.publish_dataset_button
+        and (
+            st.session_state.dataset_name is not None
+            and st.session_state.dataset_name != ""
         )
+    ):
+
+        with st.expander("Evaluation Dataset Preview", expanded=True):
+            dataframe = pd.read_csv(st.session_state.uploaded_file)
+            data_list = dataframe.to_dict(orient="records")
+
+            dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
+            st.session_state.dataset_ref = weave.publish(dataset)
+
+            entity = st.session_state.dataset_ref.entity
+            project = st.session_state.dataset_ref.project
+            dataset_name = st.session_state.dataset_name
+            digest = st.session_state.dataset_ref._digest
+            dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
+            st.markdown(
+                f"Dataset published to [**Weave**]({dataset_url})"
+            )

-if st.session_state.dataset_previewed:
-    guardrail_names = st.sidebar.multiselect(
-        "Select Guardrails",
-        options=[
-            cls_name
-            for cls_name, cls_obj in vars(
-                import_module("guardrails_genie.guardrails")
-            ).items()
-            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
-        ],
-    )
-    st.session_state.guardrail_names = guardrail_names
-
-    if st.session_state.guardrail_names != []:
-        initialize_guardrail()
-        evaluation_name = st.sidebar.text_input("Evaluation name", value="")
-        st.session_state.evaluation_name = evaluation_name
-        if st.session_state.guardrail_manager is not None:
-            if st.sidebar.button("Start Evaluation"):
-                st.session_state.start_evaluation = True
-            if st.session_state.start_evaluation:
-                evaluation = weave.Evaluation(
-                    dataset=st.session_state.dataset_ref,
-                    scorers=[AccuracyMetric()],
-                    streamlit_mode=True,
-                )
-                with st.expander("Evaluation Results", expanded=True):
-                    evaluation_summary, call = asyncio.run(
-                        evaluation.evaluate.call(
-                            evaluation,
-                            st.session_state.guardrail_manager,
-                            __weave={
-                                "display_name": "Evaluation.evaluate:"
-                                + st.session_state.evaluation_name
-                            },
-                        )
-                    )
-                    x_axis = list(evaluation_summary["AccuracyMetric"].keys())
-                    y_axis = [
-                        evaluation_summary["AccuracyMetric"][x_axis_item]
-                        for x_axis_item in x_axis
-                    ]
-                    st.bar_chart(
-                        pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
-                        x="Metric",
-                        y="Score",
+            if preview_in_app:
+                st.dataframe(dataframe.head(20))
+                if len(dataframe) > 20:
+                    st.markdown(
+                        f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
                     )
-                    st.session_state.evaluation_summary = evaluation_summary
-                    st.session_state.call_id = call.id
-                st.session_state.start_evaluation = False

-                entity="geekyrakshit",
-                project="guardrails-genie",
-                call_id=st.session_state.call_id,
-            )
-        )
-        for guardrail_name in st.session_state.guardrail_names:
-            st.session_state.evaluation_call_manager.call_list.append(
-                {
-                    "guardrail_name": guardrail_name,
-                    "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
-                }
-            )
-        rich.print(
-            st.session_state.evaluation_call_manager.call_list
-        )
-        st.dataframe(
-            st.session_state.evaluation_call_manager.render_calls_to_streamlit()
-        )
-        if st.session_state.evaluation_call_manager.show_warning_in_app:
-            st.warning(
-                f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
-            )
-        st.markdown(
-            f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
-        )
-        st.session_state.evaluation_call_manager = None
+        st.session_state.is_dataset_published = True
+
+if st.session_state.is_dataset_published:
+    st.write("Maza Ayega")
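The heart of the new eval UI is the publish flow: read the uploaded CSV, publish its rows as a weave.Dataset, and build a versioned object URL from the returned ref. A minimal sketch of that flow outside Streamlit, assuming the same weave SDK surface the diff uses (weave.init, weave.Dataset, weave.publish, and a ref exposing entity, project, and _digest); the helper name and its arguments are illustrative:

# Hypothetical helper mirroring the publish flow in evaluation_app.py.
import pandas as pd
import weave


def publish_csv_as_weave_dataset(csv_path: str, dataset_name: str, project: str) -> str:
    """Publish a CSV as a weave.Dataset and return its versioned URL."""
    weave.init(project_name=project)
    # Rows are plain dicts, one per CSV record, as in the app.
    rows = pd.read_csv(csv_path).to_dict(orient="records")
    ref = weave.publish(weave.Dataset(name=dataset_name, rows=rows))
    # Same URL construction as the app: pin to the exact published digest.
    return (
        f"https://wandb.ai/{ref.entity}/{ref.project}"
        f"/weave/objects/{dataset_name}/versions/{ref._digest}"
    )

Reading the digest off the ref (note that _digest is a private attribute) pins the link to the exact published version rather than whatever the latest alias resolves to later.
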
guardrails_genie/train/llama_guard.py
CHANGED
@@ -3,12 +3,13 @@ import shutil
 from glob import glob
 from typing import Optional

+# import torch.optim as optim
+import bitsandbytes.optim as optim
 import plotly.graph_objects as go
 import streamlit as st
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.optim as optim
 from datasets import load_dataset
 from pydantic import BaseModel
 from rich.progress import track
@@ -335,8 +336,8 @@ class LlamaGuardFineTuner:

     def train(
         self,
-        batch_size: int =
+        batch_size: int = 16,
+        starting_lr: float = 1e-7,
         num_classes: int = 2,
         log_interval: int = 1,
         save_interval: int = 50,
@@ -358,7 +359,7 @@ class LlamaGuardFineTuner:

         Args:
             batch_size (int, optional): The number of samples per batch during training.
+            starting_lr (float, optional): The starting learning rate for the optimizer.
             num_classes (int, optional): The number of output classes for the classifier.
             log_interval (int, optional): The interval (in batches) at which to log the loss.
             save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
@@ -377,7 +378,7 @@ class LlamaGuardFineTuner:
         wandb.config.dataset_args = self.dataset_args.model_dump()
         wandb.config.model_name = self.model_name
         wandb.config.batch_size = batch_size
-        wandb.config.
+        wandb.config.starting_lr = starting_lr
         wandb.config.num_classes = num_classes
         wandb.config.log_interval = log_interval
         wandb.config.save_interval = save_interval
@@ -387,7 +388,16 @@ class LlamaGuardFineTuner:
         self.model.num_labels = num_classes
         self.model = self.model.to(self.device)
         self.model.train()
-        optimizer = optim.AdamW(self.model.parameters(), lr=
+        # optimizer = optim.AdamW(self.model.parameters(), lr=starting_lr)
+        optimizer = optim.Lion(
+            self.model.parameters(), lr=starting_lr, weight_decay=0.01
+        )
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            optimizer,
+            max_lr=starting_lr,
+            steps_per_epoch=len(self.train_dataset) // batch_size + 1,
+            epochs=1,
+        )
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=batch_size,
@@ -405,9 +415,14 @@ class LlamaGuardFineTuner:
             loss = outputs.loss
             optimizer.zero_grad()
             loss.backward()
+
+            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
+
             optimizer.step()
+            scheduler.step()
             if (i + 1) % log_interval == 0:
                 wandb.log({"loss": loss.item()}, step=i + 1)
+                wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, step=i + 1)
             if progress_bar:
                 progress_percentage = (i + 1) * 100 // len(data_loader)
                 progress_bar.progress(
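
train() now pairs bitsandbytes' Lion optimizer with a OneCycleLR schedule stepped once per batch. A self-contained sketch of that pairing under the same hyperparameters; the toy linear model, random batches, and sample count are placeholders for the fine-tuner's real model and DataLoader:

# Sketch of the Lion + OneCycleLR pairing from LlamaGuardFineTuner.train().
import bitsandbytes.optim as optim
import torch

model = torch.nn.Linear(128, 2)        # stand-in for the classifier
num_samples, batch_size = 1024, 16     # stand-in dataset size
starting_lr = 1e-7
steps = num_samples // batch_size + 1  # matches steps_per_epoch in the diff

optimizer = optim.Lion(model.parameters(), lr=starting_lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=starting_lr, steps_per_epoch=steps, epochs=1
)

for step in range(steps):
    inputs = torch.randn(batch_size, 128)
    loss = model(inputs).square().mean()  # stand-in loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the one-cycle schedule every batch

One thing to note: OneCycleLR treats max_lr as the peak of the cycle (it starts at max_lr / div_factor, 25 by default, warms up, then anneals), so starting_lr here effectively sets the cycle's peak rather than its initial value.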