ash0ts committed on
Commit
3a97187
·
2 Parent(s): 28e7022 b8c0bf9

Merge branch 'main' into feat/pii-banned-words

Browse files
.gitignore CHANGED
@@ -165,4 +165,7 @@ cursor_prompts/
165
  uv.lock
166
  test.py
167
  temp.txt
168
- **.csv
 
 
 
 
165
  uv.lock
166
  test.py
167
  temp.txt
168
+ **.csv
169
+ binary-classifier/
170
+ wandb/
171
+ artifacts/
README.md CHANGED
@@ -18,7 +18,12 @@ source .venv/bin/activate
18
  ## Run the App
19
 
20
  ```bash
21
- OPENAI_API_KEY="YOUR_OPENAI_API_KEY" streamlit run app.py
 
 
 
 
 
22
  ```
23
 
24
  ## Use the Library
 
18
  ## Run the App
19
 
20
  ```bash
21
+ export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
22
+ export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
23
+ export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
24
+ export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
25
+ export WANDB_LOG_MODEL="checkpoint"
26
+ streamlit run app.py
27
  ```
28
 
29
  ## Use the Library
app.py CHANGED
@@ -4,13 +4,22 @@ intro_page = st.Page(
4
  "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
- "application_pages/chat_app.py", title="Chat", icon=":material/robot:"
 
 
8
  )
9
  evaluation_page = st.Page(
10
  "application_pages/evaluation_app.py",
11
  title="Evaluation",
12
  icon=":material/monitoring:",
13
  )
14
- page_navigation = st.navigation([intro_page, chat_page, evaluation_page])
 
 
 
 
 
 
 
15
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
16
  page_navigation.run()
 
4
  "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
+ "application_pages/chat_app.py",
8
+ title="Playground",
9
+ icon=":material/sports_esports:",
10
  )
11
  evaluation_page = st.Page(
12
  "application_pages/evaluation_app.py",
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
16
+ train_classifier_page = st.Page(
17
+ "application_pages/train_classifier.py",
18
+ title="Train Classifier",
19
+ icon=":material/fitness_center:",
20
+ )
21
+ page_navigation = st.navigation(
22
+ [intro_page, chat_page, evaluation_page, train_classifier_page]
23
+ )
24
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
25
  page_navigation.run()
application_pages/chat_app.py CHANGED
@@ -1,4 +1,5 @@
1
  import importlib
 
2
 
3
  import streamlit as st
4
  import weave
@@ -7,19 +8,27 @@ from dotenv import load_dotenv
7
  from guardrails_genie.guardrails import GuardrailManager
8
  from guardrails_genie.llm import OpenAIModel
9
 
10
- load_dotenv()
11
- weave.init(project_name="guardrails-genie")
12
 
13
- st.title(":material/robot: Guardrails Genie")
14
-
15
- if "guardrails" not in st.session_state:
16
- st.session_state.guardrails = []
17
- if "guardrail_names" not in st.session_state:
18
- st.session_state.guardrail_names = []
19
- if "guardrails_manager" not in st.session_state:
20
- st.session_state.guardrails_manager = None
21
- if "chat_started" not in st.session_state:
22
- st.session_state.chat_started = False
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def initialize_guardrails():
@@ -36,18 +45,30 @@ def initialize_guardrails():
36
  guardrail_name,
37
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
38
  )
39
- else:
40
- st.session_state.guardrails.append(
41
- getattr(
42
- importlib.import_module("guardrails_genie.guardrails"),
43
- guardrail_name,
44
- )()
 
 
45
  )
 
 
 
 
 
 
 
46
  st.session_state.guardrails_manager = GuardrailManager(
47
  guardrails=st.session_state.guardrails
48
  )
49
 
50
 
 
 
 
51
  openai_model = st.sidebar.selectbox(
52
  "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
53
  )
@@ -67,48 +88,41 @@ guardrail_names = st.sidebar.multiselect(
67
  )
68
  st.session_state.guardrail_names = guardrail_names
69
 
70
- if st.sidebar.button("Start Chat") and chat_condition:
71
- st.session_state.chat_started = True
72
 
73
- if st.session_state.chat_started:
74
  with st.sidebar.status("Initializing Guardrails..."):
75
  initialize_guardrails()
 
76
 
77
- # Initialize chat history
78
- if "messages" not in st.session_state:
79
- st.session_state.messages = []
80
-
81
- llm_model = OpenAIModel(model_name=openai_model)
82
-
83
- # Display chat messages from history on app rerun
84
- for message in st.session_state.messages:
85
- with st.chat_message(message["role"]):
86
- st.markdown(message["content"])
87
 
88
- # React to user input
89
- if prompt := st.chat_input("What is up?"):
90
- # Display user message in chat message container
91
- st.chat_message("user").markdown(prompt)
92
- # Add user message to chat history
93
- st.session_state.messages.append({"role": "user", "content": prompt})
94
 
95
- guardrails_response, call = st.session_state.guardrails_manager.guard.call(
96
- st.session_state.guardrails_manager, prompt=prompt
97
- )
 
 
98
 
99
  if guardrails_response["safe"]:
100
- response, call = llm_model.predict.call(
101
- llm_model, user_prompts=prompt, messages=st.session_state.messages
102
  )
103
- response = response.choices[0].message.content
104
 
105
- # Display assistant response in chat message container
106
- with st.chat_message("assistant"):
107
- st.markdown(response + f"\n\n---\n[Explore in Weave]({call.ui_url})")
108
- # Add assistant response to chat history
109
- st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
110
  else:
111
- st.error("Guardrails detected an issue with the prompt.")
112
- for alert in guardrails_response["alerts"]:
113
- st.error(f"{alert['guardrail_name']}: {alert['response']}")
114
- st.error(f"For details, explore in Weave at {call.ui_url}")
 
1
  import importlib
2
+ import os
3
 
4
  import streamlit as st
5
  import weave
 
8
  from guardrails_genie.guardrails import GuardrailManager
9
  from guardrails_genie.llm import OpenAIModel
10
 
 
 
11
 
12
+ def initialize_session_state():
13
+ load_dotenv()
14
+ weave.init(project_name=os.getenv("WEAVE_PROJECT"))
15
+
16
+ if "guardrails" not in st.session_state:
17
+ st.session_state.guardrails = []
18
+ if "guardrail_names" not in st.session_state:
19
+ st.session_state.guardrail_names = []
20
+ if "guardrails_manager" not in st.session_state:
21
+ st.session_state.guardrails_manager = None
22
+ if "initialize_guardrails" not in st.session_state:
23
+ st.session_state.initialize_guardrails = False
24
+ if "system_prompt" not in st.session_state:
25
+ st.session_state.system_prompt = ""
26
+ if "user_prompt" not in st.session_state:
27
+ st.session_state.user_prompt = ""
28
+ if "test_guardrails" not in st.session_state:
29
+ st.session_state.test_guardrails = False
30
+ if "llm_model" not in st.session_state:
31
+ st.session_state.llm_model = None
32
 
33
 
34
  def initialize_guardrails():
 
45
  guardrail_name,
46
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
47
  )
48
+ elif guardrail_name == "PromptInjectionClassifierGuardrail":
49
+ classifier_model_name = st.sidebar.selectbox(
50
+ "Classifier Guardrail Model",
51
+ [
52
+ "",
53
+ "ProtectAI/deberta-v3-base-prompt-injection-v2",
54
+ "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
55
+ ],
56
  )
57
+ if classifier_model_name != "":
58
+ st.session_state.guardrails.append(
59
+ getattr(
60
+ importlib.import_module("guardrails_genie.guardrails"),
61
+ guardrail_name,
62
+ )(model_name=classifier_model_name)
63
+ )
64
  st.session_state.guardrails_manager = GuardrailManager(
65
  guardrails=st.session_state.guardrails
66
  )
67
 
68
 
69
+ initialize_session_state()
70
+ st.title(":material/robot: Guardrails Genie Playground")
71
+
72
  openai_model = st.sidebar.selectbox(
73
  "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
74
  )
 
88
  )
89
  st.session_state.guardrail_names = guardrail_names
90
 
91
+ if st.sidebar.button("Initialize Guardrails") and chat_condition:
92
+ st.session_state.initialize_guardrails = True
93
 
94
+ if st.session_state.initialize_guardrails:
95
  with st.sidebar.status("Initializing Guardrails..."):
96
  initialize_guardrails()
97
+ st.session_state.llm_model = OpenAIModel(model_name=openai_model)
98
 
99
+ user_prompt = st.text_area("User Prompt", value="")
100
+ st.session_state.user_prompt = user_prompt
 
 
 
 
 
 
 
 
101
 
102
+ test_guardrails_button = st.button("Test Guardrails")
103
+ st.session_state.test_guardrails = test_guardrails_button
 
 
 
 
104
 
105
+ if st.session_state.test_guardrails:
106
+ with st.sidebar.status("Running Guardrails..."):
107
+ guardrails_response, call = st.session_state.guardrails_manager.guard.call(
108
+ st.session_state.guardrails_manager, prompt=st.session_state.user_prompt
109
+ )
110
 
111
  if guardrails_response["safe"]:
112
+ st.markdown(
113
+ f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
114
  )
 
115
 
116
+ with st.sidebar.status("Generating response from LLM..."):
117
+ response, call = st.session_state.llm_model.predict.call(
118
+ st.session_state.llm_model,
119
+ user_prompts=st.session_state.user_prompt,
120
+ )
121
+ st.markdown(
122
+ response.choices[0].message.content
123
+ + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
124
+ )
125
  else:
126
+ st.warning("Prompt is not safe!")
127
+ st.markdown(guardrails_response["summary"])
128
+ st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
 
application_pages/evaluation_app.py CHANGED
@@ -1,7 +1,10 @@
1
  import asyncio
 
 
2
  from importlib import import_module
3
 
4
  import pandas as pd
 
5
  import streamlit as st
6
  import weave
7
  from dotenv import load_dotenv
@@ -9,12 +12,11 @@ from dotenv import load_dotenv
9
  from guardrails_genie.guardrails import GuardrailManager
10
  from guardrails_genie.llm import OpenAIModel
11
  from guardrails_genie.metrics import AccuracyMetric
12
-
13
- load_dotenv()
14
- weave.init(project_name="guardrails-genie")
15
 
16
 
17
  def initialize_session_state():
 
18
  if "uploaded_file" not in st.session_state:
19
  st.session_state.uploaded_file = None
20
  if "dataset_name" not in st.session_state:
@@ -35,6 +37,18 @@ def initialize_session_state():
35
  st.session_state.evaluation_summary = None
36
  if "guardrail_manager" not in st.session_state:
37
  st.session_state.guardrail_manager = None
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  def initialize_guardrail():
@@ -51,10 +65,22 @@ def initialize_guardrail():
51
  guardrail_name,
52
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
53
  )
54
- else:
55
- guardrails.append(
56
- getattr(import_module("guardrails_genie.guardrails"), guardrail_name)()
 
 
 
 
 
57
  )
 
 
 
 
 
 
 
58
  st.session_state.guardrails = guardrails
59
  st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
60
 
@@ -107,6 +133,8 @@ if st.session_state.dataset_previewed:
107
 
108
  if st.session_state.guardrail_names != []:
109
  initialize_guardrail()
 
 
110
  if st.session_state.guardrail_manager is not None:
111
  if st.sidebar.button("Start Evaluation"):
112
  st.session_state.start_evaluation = True
@@ -119,10 +147,55 @@ if st.session_state.dataset_previewed:
119
  with st.expander("Evaluation Results", expanded=True):
120
  evaluation_summary, call = asyncio.run(
121
  evaluation.evaluate.call(
122
- evaluation, st.session_state.guardrail_manager
 
 
 
 
 
123
  )
124
  )
125
- st.markdown(f"[Explore evaluation in Weave]({call.ui_url})")
126
- st.write(evaluation_summary)
127
- st.session_state.evaluation_summary = evaluation_summary
128
- st.session_state.start_evaluation = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import os
3
+ import time
4
  from importlib import import_module
5
 
6
  import pandas as pd
7
+ import rich
8
  import streamlit as st
9
  import weave
10
  from dotenv import load_dotenv
 
12
  from guardrails_genie.guardrails import GuardrailManager
13
  from guardrails_genie.llm import OpenAIModel
14
  from guardrails_genie.metrics import AccuracyMetric
15
+ from guardrails_genie.utils import EvaluationCallManager
 
 
16
 
17
 
18
  def initialize_session_state():
19
+ load_dotenv()
20
  if "uploaded_file" not in st.session_state:
21
  st.session_state.uploaded_file = None
22
  if "dataset_name" not in st.session_state:
 
37
  st.session_state.evaluation_summary = None
38
  if "guardrail_manager" not in st.session_state:
39
  st.session_state.guardrail_manager = None
40
+ if "evaluation_name" not in st.session_state:
41
+ st.session_state.evaluation_name = ""
42
+ if "show_result_table" not in st.session_state:
43
+ st.session_state.show_result_table = False
44
+ if "weave_client" not in st.session_state:
45
+ st.session_state.weave_client = weave.init(
46
+ project_name=os.getenv("WEAVE_PROJECT")
47
+ )
48
+ if "evaluation_call_manager" not in st.session_state:
49
+ st.session_state.evaluation_call_manager = None
50
+ if "call_id" not in st.session_state:
51
+ st.session_state.call_id = None
52
 
53
 
54
  def initialize_guardrail():
 
65
  guardrail_name,
66
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
67
  )
68
+ elif guardrail_name == "PromptInjectionClassifierGuardrail":
69
+ classifier_model_name = st.sidebar.selectbox(
70
+ "Classifier Guardrail Model",
71
+ [
72
+ "",
73
+ "ProtectAI/deberta-v3-base-prompt-injection-v2",
74
+ "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
75
+ ],
76
  )
77
+ if classifier_model_name:
78
+ st.session_state.guardrails.append(
79
+ getattr(
80
+ import_module("guardrails_genie.guardrails"),
81
+ guardrail_name,
82
+ )(model_name=classifier_model_name)
83
+ )
84
  st.session_state.guardrails = guardrails
85
  st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
86
 
 
133
 
134
  if st.session_state.guardrail_names != []:
135
  initialize_guardrail()
136
+ evaluation_name = st.sidebar.text_input("Evaluation name", value="")
137
+ st.session_state.evaluation_name = evaluation_name
138
  if st.session_state.guardrail_manager is not None:
139
  if st.sidebar.button("Start Evaluation"):
140
  st.session_state.start_evaluation = True
 
147
  with st.expander("Evaluation Results", expanded=True):
148
  evaluation_summary, call = asyncio.run(
149
  evaluation.evaluate.call(
150
+ evaluation,
151
+ st.session_state.guardrail_manager,
152
+ __weave={
153
+ "display_name": "Evaluation.evaluate:"
154
+ + st.session_state.evaluation_name
155
+ },
156
  )
157
  )
158
+ x_axis = list(evaluation_summary["AccuracyMetric"].keys())
159
+ y_axis = [
160
+ evaluation_summary["AccuracyMetric"][x_axis_item]
161
+ for x_axis_item in x_axis
162
+ ]
163
+ st.bar_chart(
164
+ pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
165
+ x="Metric",
166
+ y="Score",
167
+ )
168
+ st.session_state.evaluation_summary = evaluation_summary
169
+ st.session_state.call_id = call.id
170
+ st.session_state.start_evaluation = False
171
+
172
+ if not st.session_state.start_evaluation:
173
+ time.sleep(5)
174
+ st.session_state.evaluation_call_manager = (
175
+ EvaluationCallManager(
176
+ entity="geekyrakshit",
177
+ project="guardrails-genie",
178
+ call_id=st.session_state.call_id,
179
+ )
180
+ )
181
+ for guardrail_name in st.session_state.guardrail_names:
182
+ st.session_state.evaluation_call_manager.call_list.append(
183
+ {
184
+ "guardrail_name": guardrail_name,
185
+ "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
186
+ }
187
+ )
188
+ rich.print(
189
+ st.session_state.evaluation_call_manager.call_list
190
+ )
191
+ st.dataframe(
192
+ st.session_state.evaluation_call_manager.render_calls_to_streamlit()
193
+ )
194
+ if st.session_state.evaluation_call_manager.show_warning_in_app:
195
+ st.warning(
196
+ f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
197
+ )
198
+ st.markdown(
199
+ f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
200
+ )
201
+ st.session_state.evaluation_call_manager = None
application_pages/train_classifier.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ from dotenv import load_dotenv
5
+
6
+ from guardrails_genie.train_classifier import train_binary_classifier
7
+
8
+
9
+ def initialize_session_state():
10
+ load_dotenv()
11
+ if "dataset_name" not in st.session_state:
12
+ st.session_state.dataset_name = None
13
+ if "base_model_name" not in st.session_state:
14
+ st.session_state.base_model_name = None
15
+ if "batch_size" not in st.session_state:
16
+ st.session_state.batch_size = 16
17
+ if "should_start_training" not in st.session_state:
18
+ st.session_state.should_start_training = False
19
+ if "training_output" not in st.session_state:
20
+ st.session_state.training_output = None
21
+
22
+
23
+ initialize_session_state()
24
+ st.title(":material/fitness_center: Train Classifier")
25
+
26
+ dataset_name = st.sidebar.text_input("Dataset Name", value="")
27
+ st.session_state.dataset_name = dataset_name
28
+
29
+ base_model_name = st.sidebar.selectbox(
30
+ "Base Model",
31
+ options=[
32
+ "distilbert/distilbert-base-uncased",
33
+ "FacebookAI/roberta-base",
34
+ "microsoft/deberta-v3-base",
35
+ ],
36
+ )
37
+ st.session_state.base_model_name = base_model_name
38
+
39
+ batch_size = st.sidebar.slider(
40
+ "Batch Size", min_value=4, max_value=256, value=16, step=4
41
+ )
42
+ st.session_state.batch_size = batch_size
43
+
44
+ train_button = st.sidebar.button("Train")
45
+ st.session_state.should_start_training = (
46
+ train_button and st.session_state.dataset_name and st.session_state.base_model_name
47
+ )
48
+
49
+ if st.session_state.should_start_training:
50
+ with st.expander("Training", expanded=True):
51
+ training_output = train_binary_classifier(
52
+ project_name=os.getenv("WANDB_PROJECT_NAME"),
53
+ entity_name=os.getenv("WANDB_ENTITY_NAME"),
54
+ run_name=f"{st.session_state.base_model_name}-finetuned",
55
+ dataset_repo=st.session_state.dataset_name,
56
+ model_name=st.session_state.base_model_name,
57
+ batch_size=st.session_state.batch_size,
58
+ streamlit_mode=True,
59
+ )
60
+ st.session_state.training_output = training_output
61
+ st.write(training_output)
guardrails_genie/guardrails/__init__.py CHANGED
@@ -1,8 +1,11 @@
1
- from .injection import PromptInjectionProtectAIGuardrail, PromptInjectionSurveyGuardrail
 
 
 
2
  from .manager import GuardrailManager
3
 
4
  __all__ = [
5
  "PromptInjectionSurveyGuardrail",
6
- "PromptInjectionProtectAIGuardrail",
7
  "GuardrailManager",
8
  ]
 
1
+ from .injection import (
2
+ PromptInjectionClassifierGuardrail,
3
+ PromptInjectionSurveyGuardrail,
4
+ )
5
  from .manager import GuardrailManager
6
 
7
  __all__ = [
8
  "PromptInjectionSurveyGuardrail",
9
+ "PromptInjectionClassifierGuardrail",
10
  "GuardrailManager",
11
  ]
guardrails_genie/guardrails/injection/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .protectai_guardrail import PromptInjectionProtectAIGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
- __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionProtectAIGuardrail"]
 
1
+ from .classifier_guardrail import PromptInjectionClassifierGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
+ __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionClassifierGuardrail"]
guardrails_genie/guardrails/injection/{protectai_guardrail.py → classifier_guardrail.py} RENAMED
@@ -5,16 +5,25 @@ import weave
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
 
 
8
  from ..base import Guardrail
9
 
10
 
11
- class PromptInjectionProtectAIGuardrail(Guardrail):
12
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
13
  _classifier: Optional[Pipeline] = None
14
 
15
  def model_post_init(self, __context):
16
- tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
- model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
 
 
 
 
 
 
 
18
  self._classifier = pipeline(
19
  "text-classification",
20
  model=model,
@@ -29,10 +38,14 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
29
  return self._classifier(prompt)
30
 
31
  @weave.op()
32
- def predict(self, prompt: str):
33
  response = self.classify(prompt)
34
- return {"safe": response[0]["label"] != "INJECTION"}
 
 
 
 
35
 
36
  @weave.op()
37
- def guard(self, prompt: str):
38
- return self.predict(prompt)
 
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
8
+ import wandb
9
+
10
  from ..base import Guardrail
11
 
12
 
13
+ class PromptInjectionClassifierGuardrail(Guardrail):
14
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
15
  _classifier: Optional[Pipeline] = None
16
 
17
  def model_post_init(self, __context):
18
+ if self.model_name.startswith("wandb://"):
19
+ api = wandb.Api()
20
+ artifact = api.artifact(self.model_name.removeprefix("wandb://"))
21
+ artifact_dir = artifact.download()
22
+ tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
23
+ model = AutoModelForSequenceClassification.from_pretrained(artifact_dir)
24
+ else:
25
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
26
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
27
  self._classifier = pipeline(
28
  "text-classification",
29
  model=model,
 
38
  return self._classifier(prompt)
39
 
40
  @weave.op()
41
+ def guard(self, prompt: str):
42
  response = self.classify(prompt)
43
+ confidence_percentage = round(response[0]["score"] * 100, 2)
44
+ return {
45
+ "safe": response[0]["label"] != "INJECTION",
46
+ "summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
47
+ }
48
 
49
  @weave.op()
50
+ def predict(self, prompt: str):
51
+ return self.guard(prompt)
guardrails_genie/guardrails/injection/survey_guardrail.py CHANGED
@@ -70,8 +70,17 @@ Here are some strict instructions that you must follow:
70
  **kwargs,
71
  )
72
  response = chat_completion.choices[0].message.parsed
73
- return {"safe": not response.injection_prompt}
74
 
75
  @weave.op()
76
  def guard(self, prompt: str, **kwargs) -> list[str]:
77
- return self.predict(prompt, **kwargs)
 
 
 
 
 
 
 
 
 
 
70
  **kwargs,
71
  )
72
  response = chat_completion.choices[0].message.parsed
73
+ return response
74
 
75
  @weave.op()
76
  def guard(self, prompt: str, **kwargs) -> list[str]:
77
+ response = self.predict(prompt, **kwargs)
78
+ summary = (
79
+ f"Prompt is deemed safe. {response.explanation}"
80
+ if not response.injection_prompt
81
+ else f"Prompt is deemed a {'direct attack' if response.is_direct_attack else 'indirect attack'} of type {response.attack_type}. {response.explanation}"
82
+ )
83
+ return {
84
+ "safe": not response.injection_prompt,
85
+ "summary": summary,
86
+ }
guardrails_genie/guardrails/manager.py CHANGED
@@ -9,7 +9,7 @@ class GuardrailManager(weave.Model):
9
 
10
  @weave.op()
11
  def guard(self, prompt: str, progress_bar: bool = True, **kwargs) -> dict:
12
- alerts, safe = [], True
13
  iterable = (
14
  track(self.guardrails, description="Running guardrails")
15
  if progress_bar
@@ -21,7 +21,10 @@ class GuardrailManager(weave.Model):
21
  {"guardrail_name": guardrail.__class__.__name__, "response": response}
22
  )
23
  safe = safe and response["safe"]
24
- return {"safe": safe, "alerts": alerts}
 
 
 
25
 
26
  @weave.op()
27
  def predict(self, prompt: str, **kwargs) -> dict:
 
9
 
10
  @weave.op()
11
  def guard(self, prompt: str, progress_bar: bool = True, **kwargs) -> dict:
12
+ alerts, summaries, safe = [], "", True
13
  iterable = (
14
  track(self.guardrails, description="Running guardrails")
15
  if progress_bar
 
21
  {"guardrail_name": guardrail.__class__.__name__, "response": response}
22
  )
23
  safe = safe and response["safe"]
24
+ summaries += (
25
+ f"**{guardrail.__class__.__name__}**: {response['summary']}\n\n---\n\n"
26
+ )
27
+ return {"safe": safe, "alerts": alerts, "summary": summaries}
28
 
29
  @weave.op()
30
  def predict(self, prompt: str, **kwargs) -> dict:
guardrails_genie/train_classifier.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import numpy as np
3
+ import streamlit as st
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoModelForSequenceClassification,
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ Trainer,
10
+ TrainerCallback,
11
+ TrainingArguments,
12
+ )
13
+ from transformers.trainer_callback import TrainerControl, TrainerState
14
+
15
+ import wandb
16
+
17
+
18
+ class StreamlitProgressbarCallback(TrainerCallback):
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+ self.progress_bar = st.progress(0, text="Training")
23
+
24
+ def on_step_begin(
25
+ self,
26
+ args: TrainingArguments,
27
+ state: TrainerState,
28
+ control: TrainerControl,
29
+ **kwargs,
30
+ ):
31
+ super().on_step_begin(args, state, control, **kwargs)
32
+ self.progress_bar.progress(
33
+ (state.global_step * 100 // state.max_steps) + 1,
34
+ text=f"Training {state.global_step} / {state.max_steps}",
35
+ )
36
+
37
+
38
+ def train_binary_classifier(
39
+ project_name: str,
40
+ entity_name: str,
41
+ run_name: str,
42
+ dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
+ model_name: str = "distilbert/distilbert-base-uncased",
44
+ learning_rate: float = 2e-5,
45
+ batch_size: int = 16,
46
+ num_epochs: int = 2,
47
+ weight_decay: float = 0.01,
48
+ streamlit_mode: bool = False,
49
+ ):
50
+ wandb.init(project=project_name, entity=entity_name, name=run_name)
51
+ if streamlit_mode:
52
+ st.markdown(
53
+ f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
54
+ )
55
+ dataset = load_dataset(dataset_repo)
56
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
57
+
58
+ def preprocess_function(examples):
59
+ return tokenizer(examples["prompt"], truncation=True)
60
+
61
+ tokenized_datasets = dataset.map(preprocess_function, batched=True)
62
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
63
+ accuracy = evaluate.load("accuracy")
64
+
65
+ def compute_metrics(eval_pred):
66
+ predictions, labels = eval_pred
67
+ predictions = np.argmax(predictions, axis=1)
68
+ return accuracy.compute(predictions=predictions, references=labels)
69
+
70
+ id2label = {0: "SAFE", 1: "INJECTION"}
71
+ label2id = {"SAFE": 0, "INJECTION": 1}
72
+
73
+ model = AutoModelForSequenceClassification.from_pretrained(
74
+ model_name,
75
+ num_labels=2,
76
+ id2label=id2label,
77
+ label2id=label2id,
78
+ )
79
+
80
+ trainer = Trainer(
81
+ model=model,
82
+ args=TrainingArguments(
83
+ output_dir="binary-classifier",
84
+ learning_rate=learning_rate,
85
+ per_device_train_batch_size=batch_size,
86
+ per_device_eval_batch_size=batch_size,
87
+ num_train_epochs=num_epochs,
88
+ weight_decay=weight_decay,
89
+ eval_strategy="epoch",
90
+ save_strategy="epoch",
91
+ load_best_model_at_end=True,
92
+ push_to_hub=True,
93
+ report_to="wandb",
94
+ logging_strategy="steps",
95
+ logging_steps=1,
96
+ ),
97
+ train_dataset=tokenized_datasets["train"],
98
+ eval_dataset=tokenized_datasets["test"],
99
+ processing_class=tokenizer,
100
+ data_collator=data_collator,
101
+ compute_metrics=compute_metrics,
102
+ callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
103
+ )
104
+ try:
105
+ training_output = trainer.train()
106
+ except Exception as e:
107
+ wandb.finish()
108
+ raise e
109
+ wandb.finish()
110
+ return training_output
guardrails_genie/utils.py CHANGED
@@ -1,7 +1,9 @@
1
  import os
2
 
 
3
  import pymupdf4llm
4
  import weave
 
5
  from firerequests import FireRequests
6
 
7
 
@@ -11,3 +13,47 @@ def get_markdown_from_pdf_url(url: str) -> str:
11
  markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
12
  os.remove("temp.pdf")
13
  return markdown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
 
3
+ import pandas as pd
4
  import pymupdf4llm
5
  import weave
6
+ import weave.trace
7
  from firerequests import FireRequests
8
 
9
 
 
13
  markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
14
  os.remove("temp.pdf")
15
  return markdown
16
+
17
+
18
+ class EvaluationCallManager:
19
+ def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
20
+ self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
21
+ self.max_count = max_count
22
+ self.show_warning_in_app = False
23
+ self.call_list = []
24
+
25
+ def collect_guardrail_guard_calls_from_eval(self):
26
+ guard_calls, count = [], 0
27
+ for eval_predict_and_score_call in self.base_call.children():
28
+ if "Evaluation.summarize" in eval_predict_and_score_call._op_name:
29
+ break
30
+ guardrail_predict_call = eval_predict_and_score_call.children()[0]
31
+ guard_call = guardrail_predict_call.children()[0]
32
+ score_call = eval_predict_and_score_call.children()[1]
33
+ guard_calls.append(
34
+ {
35
+ "input_prompt": str(guard_call.inputs["prompt"]),
36
+ "outputs": dict(guard_call.output),
37
+ "score": dict(score_call.output),
38
+ }
39
+ )
40
+ count += 1
41
+ if count >= self.max_count:
42
+ self.show_warning_in_app = True
43
+ break
44
+ return guard_calls
45
+
46
+ def render_calls_to_streamlit(self):
47
+ dataframe = {
48
+ "input_prompt": [
49
+ call["input_prompt"] for call in self.call_list[0]["calls"]
50
+ ]
51
+ }
52
+ for guardrail_call in self.call_list:
53
+ dataframe[guardrail_call["guardrail_name"] + ".safe"] = [
54
+ call["outputs"]["safe"] for call in guardrail_call["calls"]
55
+ ]
56
+ dataframe[guardrail_call["guardrail_name"] + ".prediction_correctness"] = [
57
+ call["score"]["correct"] for call in guardrail_call["calls"]
58
+ ]
59
+ return pd.DataFrame(dataframe)
pyproject.toml CHANGED
@@ -12,7 +12,7 @@ dependencies = [
12
  "ruff>=0.6.9",
13
  "pip>=24.2",
14
  "uv>=0.4.20",
15
- "git+https://github.com/wandb/weave@feat/eval-progressbar",
16
  "streamlit>=1.40.1",
17
  "python-dotenv>=1.0.1",
18
  "watchdog>=6.0.0",
@@ -23,4 +23,4 @@ dependencies = [
23
  ]
24
 
25
  [tool.setuptools]
26
- py-modules = ["guardrails_genie"]
 
12
  "ruff>=0.6.9",
13
  "pip>=24.2",
14
  "uv>=0.4.20",
15
+ "weave @ git+https://github.com/wandb/weave@feat/eval-progressbar",
16
  "streamlit>=1.40.1",
17
  "python-dotenv>=1.0.1",
18
  "watchdog>=6.0.0",
 
23
  ]
24
 
25
  [tool.setuptools]
26
+ py-modules = ["guardrails_genie"]