geekyrakshit committed
Commit ec05364
1 Parent(s): 9e04c4b

update: eval ui

application_pages/evaluation_app.py CHANGED
```diff
@@ -1,6 +1,5 @@
 import asyncio
 import os
-import time
 from importlib import import_module
 
 import pandas as pd
@@ -12,212 +11,81 @@ from dotenv import load_dotenv
 from guardrails_genie.guardrails import GuardrailManager
 from guardrails_genie.llm import OpenAIModel
 from guardrails_genie.metrics import AccuracyMetric
-from guardrails_genie.utils import EvaluationCallManager
 
 
 def initialize_session_state():
     load_dotenv()
+    if "weave_project_name" not in st.session_state:
+        st.session_state.weave_project_name = "guardrails-genie"
     if "uploaded_file" not in st.session_state:
         st.session_state.uploaded_file = None
     if "dataset_name" not in st.session_state:
-        st.session_state.dataset_name = ""
+        st.session_state.dataset_name = None
     if "preview_in_app" not in st.session_state:
         st.session_state.preview_in_app = False
+    if "is_dataset_published" not in st.session_state:
+        st.session_state.is_dataset_published = False
+    if "publish_dataset_button" not in st.session_state:
+        st.session_state.publish_dataset_button = False
     if "dataset_ref" not in st.session_state:
         st.session_state.dataset_ref = None
-    if "dataset_previewed" not in st.session_state:
-        st.session_state.dataset_previewed = False
-    if "guardrail_names" not in st.session_state:
-        st.session_state.guardrail_names = []
-    if "guardrails" not in st.session_state:
-        st.session_state.guardrails = []
-    if "start_evaluation" not in st.session_state:
-        st.session_state.start_evaluation = False
-    if "evaluation_summary" not in st.session_state:
-        st.session_state.evaluation_summary = None
-    if "guardrail_manager" not in st.session_state:
-        st.session_state.guardrail_manager = None
-    if "evaluation_name" not in st.session_state:
-        st.session_state.evaluation_name = ""
-    if "show_result_table" not in st.session_state:
-        st.session_state.show_result_table = False
-    if "weave_client" not in st.session_state:
-        st.session_state.weave_client = weave.init(
-            project_name=os.getenv("WEAVE_PROJECT")
-        )
-    if "evaluation_call_manager" not in st.session_state:
-        st.session_state.evaluation_call_manager = None
-    if "call_id" not in st.session_state:
-        st.session_state.call_id = None
-    if "llama_guardrail_checkpoint" not in st.session_state:
-        st.session_state.llama_guardrail_checkpoint = None
-
-
-def initialize_guardrail():
-    guardrails = []
-    for guardrail_name in st.session_state.guardrail_names:
-        if guardrail_name == "PromptInjectionSurveyGuardrail":
-            survey_guardrail_model = st.sidebar.selectbox(
-                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
-            )
-            if survey_guardrail_model:
-                guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
-                )
-        elif guardrail_name == "PromptInjectionClassifierGuardrail":
-            classifier_model_name = st.sidebar.selectbox(
-                "Classifier Guardrail Model",
-                [
-                    "",
-                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
-                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
-                ],
-            )
-            if classifier_model_name:
-                st.session_state.guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(model_name=classifier_model_name)
-                )
-        elif guardrail_name == "PromptInjectionLlamaGuardrail":
-            llama_guardrail_checkpoint = st.sidebar.text_input(
-                "Llama Guardrail Checkpoint",
-                value=None,
-            )
-            st.session_state.llama_guardrail_checkpoint = llama_guardrail_checkpoint
-            if st.session_state.llama_guardrail_checkpoint is not None:
-                st.session_state.guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(checkpoint=st.session_state.llama_guardrail_checkpoint)
-                )
-        else:
-            st.session_state.guardrails.append(
-                getattr(
-                    import_module("guardrails_genie.guardrails"),
-                    guardrail_name,
-                )()
-            )
-    st.session_state.guardrails = guardrails
-    st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
 
 
 initialize_session_state()
 st.title(":material/monitoring: Evaluation")
 
+weave_project_name = st.sidebar.text_input(
+    "Weave project name", value=st.session_state.weave_project_name
+)
+st.session_state.weave_project_name = weave_project_name
+if st.session_state.weave_project_name != "":
+    weave.init(project_name=st.session_state.weave_project_name)
+
 uploaded_file = st.sidebar.file_uploader(
     "Upload the evaluation dataset as a CSV file", type="csv"
)
 st.session_state.uploaded_file = uploaded_file
-dataset_name = st.sidebar.text_input("Evaluation dataset name", value="")
-st.session_state.dataset_name = dataset_name
-preview_in_app = st.sidebar.toggle("Preview in app", value=False)
-st.session_state.preview_in_app = preview_in_app
 
-if st.session_state.uploaded_file is not None and st.session_state.dataset_name != "":
-    with st.expander("Evaluation Dataset Preview", expanded=True):
-        dataframe = pd.read_csv(st.session_state.uploaded_file)
-        data_list = dataframe.to_dict(orient="records")
-
-        dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
-        st.session_state.dataset_ref = weave.publish(dataset)
-
-        entity = st.session_state.dataset_ref.entity
-        project = st.session_state.dataset_ref.project
-        dataset_name = st.session_state.dataset_name
-        digest = st.session_state.dataset_ref._digest
-        st.markdown(
-            f"Dataset published to [**Weave**](https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest})"
-        )
-
-        if preview_in_app:
-            st.dataframe(dataframe)
-
-        st.session_state.dataset_previewed = True
-
-if st.session_state.dataset_previewed:
-    guardrail_names = st.sidebar.multiselect(
-        "Select Guardrails",
-        options=[
-            cls_name
-            for cls_name, cls_obj in vars(
-                import_module("guardrails_genie.guardrails")
-            ).items()
-            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
-        ],
-    )
-    st.session_state.guardrail_names = guardrail_names
-
-    if st.session_state.guardrail_names != []:
-        initialize_guardrail()
-        evaluation_name = st.sidebar.text_input("Evaluation name", value="")
-        st.session_state.evaluation_name = evaluation_name
-        if st.session_state.guardrail_manager is not None:
-            if st.sidebar.button("Start Evaluation"):
-                st.session_state.start_evaluation = True
-            if st.session_state.start_evaluation:
-                evaluation = weave.Evaluation(
-                    dataset=st.session_state.dataset_ref,
-                    scorers=[AccuracyMetric()],
-                    streamlit_mode=True,
-                )
-                with st.expander("Evaluation Results", expanded=True):
-                    evaluation_summary, call = asyncio.run(
-                        evaluation.evaluate.call(
-                            evaluation,
-                            st.session_state.guardrail_manager,
-                            __weave={
-                                "display_name": "Evaluation.evaluate:"
-                                + st.session_state.evaluation_name
-                            },
-                        )
-                    )
-                    x_axis = list(evaluation_summary["AccuracyMetric"].keys())
-                    y_axis = [
-                        evaluation_summary["AccuracyMetric"][x_axis_item]
-                        for x_axis_item in x_axis
-                    ]
-                    st.bar_chart(
-                        pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
-                        x="Metric",
-                        y="Score",
-                    )
-                    st.session_state.evaluation_summary = evaluation_summary
-                    st.session_state.call_id = call.id
-                    st.session_state.start_evaluation = False
-
-                if not st.session_state.start_evaluation:
-                    time.sleep(5)
-                    st.session_state.evaluation_call_manager = (
-                        EvaluationCallManager(
-                            entity="geekyrakshit",
-                            project="guardrails-genie",
-                            call_id=st.session_state.call_id,
-                        )
-                    )
-                    for guardrail_name in st.session_state.guardrail_names:
-                        st.session_state.evaluation_call_manager.call_list.append(
-                            {
-                                "guardrail_name": guardrail_name,
-                                "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
-                            }
-                        )
-                    rich.print(
-                        st.session_state.evaluation_call_manager.call_list
-                    )
-                    st.dataframe(
-                        st.session_state.evaluation_call_manager.render_calls_to_streamlit()
-                    )
-                    if st.session_state.evaluation_call_manager.show_warning_in_app:
-                        st.warning(
-                            f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
-                        )
-                    st.markdown(
-                        f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
-                    )
-                    st.session_state.evaluation_call_manager = None
+if st.session_state.uploaded_file is not None:
+    dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
+    st.session_state.dataset_name = dataset_name
+    preview_in_app = st.sidebar.toggle("Preview in app", value=False)
+    st.session_state.preview_in_app = preview_in_app
+    publish_dataset_button = st.sidebar.button("Publish dataset")
+    st.session_state.publish_dataset_button = publish_dataset_button
+
+    if (
+        st.session_state.publish_dataset_button
+        and (
+            st.session_state.dataset_name is not None
+            and st.session_state.dataset_name != ""
+        )
+    ):
+
+        with st.expander("Evaluation Dataset Preview", expanded=True):
+            dataframe = pd.read_csv(st.session_state.uploaded_file)
+            data_list = dataframe.to_dict(orient="records")
+
+            dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
+            st.session_state.dataset_ref = weave.publish(dataset)
+
+            entity = st.session_state.dataset_ref.entity
+            project = st.session_state.dataset_ref.project
+            dataset_name = st.session_state.dataset_name
+            digest = st.session_state.dataset_ref._digest
+            dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
+            st.markdown(
+                f"Dataset published to [**Weave**]({dataset_url})"
+            )
+
+            if preview_in_app:
+                st.dataframe(dataframe.head(20))
+                if len(dataframe) > 20:
+                    st.markdown(
+                        f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
+                    )
+
+        st.session_state.is_dataset_published = True
+
+if st.session_state.is_dataset_published:
+    st.write("Maza Ayega")
```
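The diff above replaces the old publish-on-upload flow with a user-selectable Weave project and an explicit "Publish dataset" button. For reference, the round trip the new UI wraps is just a handful of Weave calls; here is a minimal sketch outside Streamlit (the project name, `eval.csv` path, and `my-eval-dataset` name are illustrative, and the trailing `weave.ref(...).get()` is shown only as the usual way a published dataset is resolved later):

```python
import pandas as pd
import weave

# Same entry point the app calls once a project name is entered in the sidebar.
weave.init(project_name="guardrails-genie")

# Mirror the app's CSV -> list-of-dicts conversion.
rows = pd.read_csv("eval.csv").to_dict(orient="records")

# weave.publish returns a ref carrying entity, project, and digest, which the
# app uses to build the wandb.ai URL it renders with st.markdown.
dataset_ref = weave.publish(weave.Dataset(name="my-eval-dataset", rows=rows))
print(dataset_ref.entity, dataset_ref.project, dataset_ref._digest)

# A later run (e.g., the evaluation step) can resolve the dataset by name.
dataset = weave.ref("my-eval-dataset").get()
```

Gating the publish on `st.sidebar.button` means `weave.publish` runs only on the rerun triggered by the click, while the `is_dataset_published` flag in session state keeps the downstream UI visible on subsequent reruns.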
 
guardrails_genie/train/llama_guard.py CHANGED
```diff
@@ -3,12 +3,13 @@ import shutil
 from glob import glob
 from typing import Optional
 
+# import torch.optim as optim
+import bitsandbytes.optim as optim
 import plotly.graph_objects as go
 import streamlit as st
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.optim as optim
 from datasets import load_dataset
 from pydantic import BaseModel
 from rich.progress import track
@@ -335,8 +336,8 @@ class LlamaGuardFineTuner:
 
     def train(
         self,
-        batch_size: int = 32,
-        lr: float = 5e-6,
+        batch_size: int = 16,
+        starting_lr: float = 1e-7,
         num_classes: int = 2,
         log_interval: int = 1,
         save_interval: int = 50,
@@ -358,7 +359,7 @@
 
         Args:
             batch_size (int, optional): The number of samples per batch during training.
-            lr (float, optional): The learning rate for the optimizer.
+            starting_lr (float, optional): The starting learning rate for the optimizer.
             num_classes (int, optional): The number of output classes for the classifier.
             log_interval (int, optional): The interval (in batches) at which to log the loss.
             save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
@@ -377,7 +378,7 @@
         wandb.config.dataset_args = self.dataset_args.model_dump()
         wandb.config.model_name = self.model_name
         wandb.config.batch_size = batch_size
-        wandb.config.lr = lr
+        wandb.config.starting_lr = starting_lr
         wandb.config.num_classes = num_classes
         wandb.config.log_interval = log_interval
         wandb.config.save_interval = save_interval
@@ -387,7 +388,16 @@
         self.model.num_labels = num_classes
         self.model = self.model.to(self.device)
         self.model.train()
-        optimizer = optim.AdamW(self.model.parameters(), lr=lr)
+        # optimizer = optim.AdamW(self.model.parameters(), lr=starting_lr)
+        optimizer = optim.Lion(
+            self.model.parameters(), lr=starting_lr, weight_decay=0.01
+        )
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            optimizer,
+            max_lr=starting_lr,
+            steps_per_epoch=len(self.train_dataset) // batch_size + 1,
+            epochs=1,
+        )
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=batch_size,
@@ -405,9 +415,14 @@
             loss = outputs.loss
             optimizer.zero_grad()
             loss.backward()
+
+            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
+
             optimizer.step()
+            scheduler.step()
             if (i + 1) % log_interval == 0:
                 wandb.log({"loss": loss.item()}, step=i + 1)
+                wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, step=i + 1)
             if progress_bar:
                 progress_percentage = (i + 1) * 100 // len(data_loader)
                 progress_bar.progress(
```
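This training change swaps `torch.optim.AdamW` for the Lion optimizer from `bitsandbytes` and adds a one-cycle LR schedule stepped once per batch. A self-contained sketch of the same pairing on a toy model (the linear layer and sizes are stand-ins for the fine-tuner's Llama backbone and dataset; note that `OneCycleLR` treats `max_lr` as the peak of the cycle, so the run actually starts near `max_lr / div_factor` and warms up to `starting_lr` before annealing back down):

```python
import bitsandbytes.optim as optim  # swapped in for torch.optim, as in the diff
import torch
import torch.nn as nn

# Toy stand-ins for the fine-tuner's model and data.
model = nn.Linear(128, 2)
num_samples, batch_size, starting_lr = 1024, 16, 1e-7
steps = num_samples // batch_size + 1

optimizer = optim.Lion(model.parameters(), lr=starting_lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=starting_lr,  # peak LR of the cycle, not the initial LR
    steps_per_epoch=steps,
    epochs=1,
)

for step in range(steps):
    loss = model(torch.randn(batch_size, 128)).square().mean()  # dummy loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the LR schedule once per optimizer step
    current_lr = scheduler.get_last_lr()[0]  # the value the diff logs to wandb
```

Lion is generally run with a noticeably smaller learning rate and larger weight decay than AdamW, which is consistent with the default dropping from 5e-6 to 1e-7 in this commit.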