Commit 78a1bf0, committed by geekyrakshit
1 Parent(s): 36c3c0f

add: docs for AccuracyMetric
- guardrails_genie/guardrails/__init__.py +5 -5
- guardrails_genie/guardrails/entity_recognition/__init__.py +7 -4
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py +64 -32
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py +12 -10
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py +12 -8
- guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py +45 -21
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py +139 -81
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py +141 -90
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py +27 -23
- guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py +13 -5
- guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py +13 -5
- guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py +20 -5
- guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py +60 -46
- guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py +69 -41
- guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py +81 -47
- guardrails_genie/guardrails/manager.py +1 -1
- guardrails_genie/metrics.py +44 -1
- guardrails_genie/regex_model.py +13 -10
- guardrails_genie/utils.py +12 -11
guardrails_genie/guardrails/__init__.py
CHANGED

@@ -1,12 +1,12 @@
-from .injection import (
-    PromptInjectionClassifierGuardrail,
-    PromptInjectionSurveyGuardrail,
-)
 from .entity_recognition import (
     PresidioEntityRecognitionGuardrail,
     RegexEntityRecognitionGuardrail,
-    TransformersEntityRecognitionGuardrail,
     RestrictedTermsJudge,
+    TransformersEntityRecognitionGuardrail,
+)
+from .injection import (
+    PromptInjectionClassifierGuardrail,
+    PromptInjectionSurveyGuardrail,
 )
 from .manager import GuardrailManager
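Note: this hunk only re-sorts the package exports (isort/black); nothing changes for callers. As a minimal usage sketch against the exported names (the constructor flag and result fields shown are the ones exercised elsewhere in this commit, not a documented API, and the input text is hypothetical):

    from guardrails_genie.guardrails import RegexEntityRecognitionGuardrail

    # guard() and the response fields below appear in the example scripts
    # changed by this commit.
    guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
    result = guardrail.guard("Please email alice@example.com")
    print(result.contains_entities, result.detected_entities)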
guardrails_genie/guardrails/entity_recognition/__init__.py
CHANGED

@@ -1,10 +1,13 @@
-from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
-from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
-from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
 from .llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
+from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
+from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+from .transformers_entity_recognition_guardrail import (
+    TransformersEntityRecognitionGuardrail,
+)
+
 __all__ = [
     "PresidioEntityRecognitionGuardrail",
     "RegexEntityRecognitionGuardrail",
     "TransformersEntityRecognitionGuardrail",
-    "RestrictedTermsJudge"
+    "RestrictedTermsJudge",
 ]
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py
CHANGED

@@ -11,15 +11,22 @@ I think we should implement features similar to Salesforce's Einstein AI
 and Oracle's Cloud Infrastructure. Maybe we could also look at how
 AWS handles their lambda functions.
 """,
-        "custom_terms": ["Salesforce", "Oracle", "AWS", "Einstein AI", "Cloud Infrastructure", "lambda"],
+        "custom_terms": [
+            "Salesforce",
+            "Oracle",
+            "AWS",
+            "Einstein AI",
+            "Cloud Infrastructure",
+            "lambda",
+        ],
         "expected_entities": {
             "Salesforce": ["Salesforce"],
             "Oracle": ["Oracle"],
             "AWS": ["AWS"],
             "Einstein AI": ["Einstein AI"],
             "Cloud Infrastructure": ["Cloud Infrastructure"],
-            "lambda": ["lambda"]
-        }
+            "lambda": ["lambda"],
+        },
     },
     {
         "description": "Inappropriate Language in Support Ticket",
@@ -32,8 +39,8 @@ stupid service? I've wasted so much freaking time on this crap.
             "damn": ["damn"],
             "hell": ["hell"],
             "stupid": ["stupid"],
-            "crap": ["crap"]
-        }
+            "crap": ["crap"],
+        },
     },
     {
         "description": "Confidential Project Names",
@@ -45,9 +52,9 @@ with Project Phoenix team and the Blue Dragon initiative for resource allocation
         "expected_entities": {
             "Project Titan": ["Project Titan"],
             "Project Phoenix": ["Project Phoenix"],
-            "Blue Dragon": ["Blue Dragon"]
-        }
-    }
+            "Blue Dragon": ["Blue Dragon"],
+        },
+    },
 ]
 
 # Edge cases and special formats
@@ -59,15 +66,22 @@ MSFT's Azure and O365 platform is gaining market share.
 Have you seen what GOOGL/GOOG and FB/META are doing with their AI?
 CRM (Salesforce) and ORCL (Oracle) have interesting features too.
 """,
-        "custom_terms": ["Microsoft", "Google", "Meta", "Facebook", "Salesforce", "Oracle"],
+        "custom_terms": [
+            "Microsoft",
+            "Google",
+            "Meta",
+            "Facebook",
+            "Salesforce",
+            "Oracle",
+        ],
         "expected_entities": {
             "Microsoft": ["MSFT"],
             "Google": ["GOOGL", "GOOG"],
             "Meta": ["META"],
             "Facebook": ["FB"],
             "Salesforce": ["CRM", "Salesforce"],
-            "Oracle": ["ORCL"]
-        }
+            "Oracle": ["ORCL"],
+        },
     },
     {
         "description": "L33t Speak and Intentional Obfuscation",
@@ -76,15 +90,22 @@ S4l3sf0rc3 is better than 0r4cl3!
 M1cr0$oft and G00gl3 are the main competitors.
 Let's check F8book and Met@ too.
 """,
-        "custom_terms": ["Salesforce", "Oracle", "Microsoft", "Google", "Facebook", "Meta"],
+        "custom_terms": [
+            "Salesforce",
+            "Oracle",
+            "Microsoft",
+            "Google",
+            "Facebook",
+            "Meta",
+        ],
         "expected_entities": {
             "Salesforce": ["S4l3sf0rc3"],
             "Oracle": ["0r4cl3"],
             "Microsoft": ["M1cr0$oft"],
             "Google": ["G00gl3"],
             "Facebook": ["F8book"],
-            "Meta": ["Met@"]
-        }
+            "Meta": ["Met@"],
+        },
     },
     {
         "description": "Case Variations and Partial Matches",
@@ -98,8 +119,8 @@ Have you tried micro-soft or Google_Cloud?
             "Microsoft": ["MicroSoft", "micro-soft"],
             "Google": ["google", "Google_Cloud"],
             "Salesforce": ["salesFORCE"],
-            "Oracle": ["ORACLE"]
-        }
+            "Oracle": ["ORACLE"],
+        },
     },
     {
         "description": "Common Misspellings and Typos",
@@ -113,8 +134,8 @@ Salezforce and Oracel need checking too.
             "Microsoft": ["Microsft", "Microsooft"],
             "Google": ["Goggle", "Googel", "Gooogle"],
             "Salesforce": ["Salezforce"],
-            "Oracle": ["Oracel"]
-        }
+            "Oracle": ["Oracel"],
+        },
     },
     {
         "description": "Mixed Variations and Context",
@@ -123,7 +144,15 @@ The M$ cloud competes with AWS (Amazon Web Services).
 FB/Meta's social platform and GOOGL's search dominate.
 SF.com and Oracle-DB are industry standards.
 """,
-        "custom_terms": ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"],
+        "custom_terms": [
+            "Microsoft",
+            "Amazon Web Services",
+            "Facebook",
+            "Meta",
+            "Google",
+            "Salesforce",
+            "Oracle",
+        ],
         "expected_entities": {
             "Microsoft": ["M$"],
             "Amazon Web Services": ["AWS"],
@@ -131,37 +160,40 @@ SF.com and Oracle-DB are industry standards.
             "Meta": ["Meta"],
             "Google": ["GOOGL"],
             "Salesforce": ["SF.com"],
-            "Oracle": ["Oracle-DB"]
-        }
-    }
+            "Oracle": ["Oracle-DB"],
+        },
+    },
 ]
 
+
 def validate_entities(detected: dict, expected: dict) -> bool:
     """Compare detected entities with expected entities"""
     if set(detected.keys()) != set(expected.keys()):
         return False
     return all(set(detected[k]) == set(expected[k]) for k in expected.keys())
 
+
 def run_test_case(guardrail, test_case, test_type="Main"):
     """Run a single test case and print results"""
     print(f"\n{test_type} Test Case: {test_case['description']}")
     print("-" * 50)
-
+
     result = guardrail.guard(
-        test_case['input_text'],
-        custom_terms=test_case['custom_terms']
+        test_case["input_text"], custom_terms=test_case["custom_terms"]
     )
-    expected = test_case['expected_entities']
-
+    expected = test_case["expected_entities"]
+
     # Validate results
     matches = validate_entities(result.detected_entities, expected)
-
+
     print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
     print(f"Contains Restricted Terms: {result.contains_entities}")
-
+
     if not matches:
         print("\nEntity Comparison:")
-        all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
+        all_entity_types = set(
+            list(result.detected_entities.keys()) + list(expected.keys())
+        )
         for entity_type in all_entity_types:
             detected = set(result.detected_entities.get(entity_type, []))
             expected_set = set(expected.get(entity_type, []))
@@ -171,8 +203,8 @@ def run_test_case(guardrail, test_case, test_type="Main"):
             if detected != expected_set:
                 print(f"  Missing: {sorted(expected_set - detected)}")
                 print(f"  Extra: {sorted(detected - expected_set)}")
-
+
     if result.anonymized_text:
         print(f"\nAnonymized Text:\n{result.anonymized_text}")
-
+
     return matches
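Reviewer note: validate_entities is strict. A case passes only when the detected dictionary has exactly the expected entity types and, for each type, exactly the expected values; comparison is order-insensitive because both sides are converted to sets. A tiny illustration with hypothetical values:

    detected = {"Google": ["GOOGL", "GOOG"], "Oracle": ["ORCL"]}
    expected = {"Google": ["GOOG", "GOOGL"], "Oracle": ["ORCL"]}
    assert validate_entities(detected, expected)  # order differs, still passes
    assert not validate_entities({"Oracle": ["ORCL"]}, expected)  # missing type fails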
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py
CHANGED

@@ -1,21 +1,22 @@
-from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
+import weave
+
 from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
-    RESTRICTED_TERMS_EXAMPLES,
-    EDGE_CASE_EXAMPLES,
-    run_test_case
+    EDGE_CASE_EXAMPLES,
+    RESTRICTED_TERMS_EXAMPLES,
+    run_test_case,
+)
+from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import (
+    RestrictedTermsJudge,
 )
 from guardrails_genie.llm import OpenAIModel
-import weave
+
 
 def test_restricted_terms_detection():
     """Test restricted terms detection scenarios using predefined test cases"""
     weave.init("guardrails-genie-restricted-terms-llm-judge")
-
+
     # Create the guardrail with OpenAI model
-    llm_judge = RestrictedTermsJudge(
-        should_anonymize=True,
-        llm_model=OpenAIModel()
-    )
+    llm_judge = RestrictedTermsJudge(should_anonymize=True, llm_model=OpenAIModel())
 
     # Test statistics
     total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
@@ -43,5 +44,6 @@ def test_restricted_terms_detection():
     print(f"Failed: {total_tests - passed_tests}")
     print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
 
+
 if __name__ == "__main__":
    test_restricted_terms_detection()
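For reference, the judge can also be driven outside this test harness. A minimal sketch, assuming an OpenAI API key is configured in the environment (guard()'s signature is the one defined in llm_judge_entity_recognition_guardrail.py, changed below; the input string is hypothetical):

    import weave

    from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import (
        RestrictedTermsJudge,
    )
    from guardrails_genie.llm import OpenAIModel

    weave.init("guardrails-genie-restricted-terms-llm-judge")
    judge = RestrictedTermsJudge(should_anonymize=True, llm_model=OpenAIModel())
    response = judge.guard(
        "We should copy Salesforce's approach.", custom_terms=["Salesforce"]
    )
    print(response.contains_entities, response.anonymized_text)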
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py
CHANGED

@@ -1,19 +1,22 @@
-from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+import weave
+
 from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
-    RESTRICTED_TERMS_EXAMPLES,
-    EDGE_CASE_EXAMPLES,
-    run_test_case
+    EDGE_CASE_EXAMPLES,
+    RESTRICTED_TERMS_EXAMPLES,
+    run_test_case,
+)
+from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
+    RegexEntityRecognitionGuardrail,
 )
-import weave
+
 
 def test_restricted_terms_detection():
     """Test restricted terms detection scenarios using predefined test cases"""
     weave.init("guardrails-genie-restricted-terms-regex-model")
-
+
     # Create the guardrail with anonymization enabled
     regex_guardrail = RegexEntityRecognitionGuardrail(
-        use_defaults=False,  # Don't use default PII patterns
-        should_anonymize=True
+        use_defaults=False, should_anonymize=True  # Don't use default PII patterns
     )
 
     # Test statistics
@@ -42,5 +45,6 @@ def test_restricted_terms_detection():
     print(f"Failed: {total_tests - passed_tests}")
     print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
 
+
 if __name__ == "__main__":
     test_restricted_terms_detection()
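One readability nit introduced by the reflow above: the trailing comment now sits after both keyword arguments even though it documents only use_defaults. Keeping the comment attached to its argument avoids the ambiguity:

    regex_guardrail = RegexEntityRecognitionGuardrail(
        use_defaults=False,  # Don't use default PII patterns
        should_anonymize=True,
    )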
guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py
CHANGED

@@ -1,15 +1,16 @@
 from typing import Dict, List, Optional
+
+import instructor
 import weave
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from ...llm import OpenAIModel
 from ..base import Guardrail
-import instructor
 
 
 class TermMatch(BaseModel):
     """Represents a matched term and its variations"""
+
     original_term: str
     matched_text: str
     match_type: str = Field(
@@ -22,19 +23,18 @@ class TermMatch(BaseModel):
 
 class RestrictedTermsAnalysis(BaseModel):
     """Analysis result for restricted terms detection"""
+
     contains_restricted_terms: bool = Field(
         description="Whether any restricted terms were detected"
     )
     detected_matches: List[TermMatch] = Field(
         default_factory=list,
-        description="List of detected term matches with their variations"
-    )
-    explanation: str = Field(
-        description="Detailed explanation of the analysis"
+        description="List of detected term matches with their variations",
     )
+    explanation: str = Field(description="Detailed explanation of the analysis")
     anonymized_text: Optional[str] = Field(
         default=None,
-        description="Text with restricted terms replaced with category tags"
+        description="Text with restricted terms replaced with category tags",
     )
 
     @property
@@ -106,39 +106,57 @@ Return your analysis in the structured format specified by the RestrictedTermsAn
         return user_prompt, system_prompt
 
     @weave.op()
-    def predict(self, text: str, custom_terms: List[str], **kwargs) -> RestrictedTermsAnalysis:
+    def predict(
+        self, text: str, custom_terms: List[str], **kwargs
+    ) -> RestrictedTermsAnalysis:
         user_prompt, system_prompt = self.format_prompts(text, custom_terms)
-
+
         response = self.llm_model.predict(
             user_prompts=user_prompt,
             system_prompt=system_prompt,
             response_format=RestrictedTermsAnalysis,
             temperature=0.1,  # Lower temperature for more consistent analysis
-            **kwargs
+            **kwargs,
         )
-
+
         return response.choices[0].message.parsed
 
-    #TODO: Remove default custom_terms
+    # TODO: Remove default custom_terms
     @weave.op()
-    def guard(self, text: str, custom_terms: List[str] = ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"], aggregate_redaction: bool = True, **kwargs) -> RestrictedTermsRecognitionResponse:
+    def guard(
+        self,
+        text: str,
+        custom_terms: List[str] = [
+            "Microsoft",
+            "Amazon Web Services",
+            "Facebook",
+            "Meta",
+            "Google",
+            "Salesforce",
+            "Oracle",
+        ],
+        aggregate_redaction: bool = True,
+        **kwargs,
+    ) -> RestrictedTermsRecognitionResponse:
         """
         Guard against restricted terms and their variations.
-
+
         Args:
             text: Text to analyze
             custom_terms: List of restricted terms to check for
-
+
         Returns:
             RestrictedTermsRecognitionResponse containing safety assessment and detailed analysis
         """
        analysis = self.predict(text, custom_terms, **kwargs)
-
+
         # Create a summary of findings
         if analysis.contains_restricted_terms:
             summary_parts = ["Restricted terms detected:"]
             for match in analysis.detected_matches:
-                summary_parts.append(f"\n- {match.original_term}: {match.matched_text} ({match.match_type})")
+                summary_parts.append(
+                    f"\n- {match.original_term}: {match.matched_text} ({match.match_type})"
+                )
             summary = "\n".join(summary_parts)
         else:
             summary = "No restricted terms detected."
@@ -148,8 +166,14 @@ Return your analysis in the structured format specified by the RestrictedTermsAn
         if self.should_anonymize and analysis.contains_restricted_terms:
             anonymized_text = text
             for match in analysis.detected_matches:
-                replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
-                anonymized_text = anonymized_text.replace(match.matched_text, replacement)
+                replacement = (
+                    "[redacted]"
+                    if aggregate_redaction
+                    else f"[{match.match_type.upper()}]"
+                )
+                anonymized_text = anonymized_text.replace(
+                    match.matched_text, replacement
+                )
 
         # Convert detected_matches to a dictionary format
         detected_entities = {}
@@ -162,5 +186,5 @@ Return your analysis in the structured format specified by the RestrictedTermsAn
             contains_entities=analysis.contains_restricted_terms,
             detected_entities=detected_entities,
             explanation=summary,
-            anonymized_text=anonymized_text
-        )
+            anonymized_text=anonymized_text,
+        )
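The # TODO above deserves follow-up: a mutable list as a default argument is shared across all calls (the classic Python mutable-default pitfall), and a hard-coded vendor list is a surprising default for a general-purpose guardrail. A conventional sketch of the fix, keeping the current types (the None default is an assumption about the fix, not what this commit does):

    def guard(
        self,
        text: str,
        custom_terms: Optional[List[str]] = None,
        aggregate_redaction: bool = True,
        **kwargs,
    ) -> RestrictedTermsRecognitionResponse:
        # Fresh list per call; empty means nothing to flag.
        custom_terms = custom_terms if custom_terms is not None else []
        ...

Separately, anonymized_text.replace(match.matched_text, replacement) replaces every occurrence of the matched string, which is usually desirable here but can over-redact when a match is a substring of benign text.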
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py
CHANGED

@@ -1,10 +1,11 @@
-from datasets import load_dataset
-from typing import Dict, List, Tuple
-import random
-from tqdm import tqdm
 import json
+import random
 from pathlib import Path
+from typing import Dict, List, Tuple
+
 import weave
+from datasets import load_dataset
+from tqdm import tqdm
 
 # Add this mapping dictionary near the top of the file
 PRESIDIO_TO_TRANSFORMER_MAPPING = {
@@ -32,66 +33,70 @@ PRESIDIO_TO_TRANSFORMER_MAPPING = {
     "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
     "IBAN_CODE": "ACCOUNTNUM",
     "MEDICAL_LICENSE": "IDCARDNUM",
-    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM",
 }
 
-def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
+
+def load_ai4privacy_dataset(
+    num_samples: int = 100, split: str = "validation"
+) -> List[Dict]:
     """
     Load and prepare samples from the ai4privacy dataset.
-
+
     Args:
         num_samples: Number of samples to evaluate
         split: Dataset split to use ("train" or "validation")
-
+
     Returns:
         List of prepared test cases
     """
     # Load the dataset
     dataset = load_dataset("ai4privacy/pii-masking-400k")
-
+
     # Get the specified split
     data_split = dataset[split]
-
+
     # Randomly sample entries if num_samples is less than total
     if num_samples < len(data_split):
         indices = random.sample(range(len(data_split)), num_samples)
         samples = [data_split[i] for i in indices]
     else:
         samples = data_split
-
+
     # Convert to test case format
     test_cases = []
     for sample in samples:
         # Extract entities from privacy_mask
         entities: Dict[str, List[str]] = {}
-        for entity in sample['privacy_mask']:
-            label = entity['label']
-            value = entity['value']
+        for entity in sample["privacy_mask"]:
+            label = entity["label"]
+            value = entity["value"]
             if label not in entities:
                 entities[label] = []
             entities[label].append(value)
-
+
         test_case = {
             "description": f"AI4Privacy Sample (ID: {sample['uid']})",
-            "input_text": sample['source_text'],
+            "input_text": sample["source_text"],
             "expected_entities": entities,
-            "masked_text": sample['masked_text'],
-            "language": sample['language'],
-            "locale": sample['locale']
+            "masked_text": sample["masked_text"],
+            "language": sample["language"],
+            "locale": sample["locale"],
         }
         test_cases.append(test_case)
-
+
     return test_cases
 
+
 @weave.op()
 def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]:
     """
     Evaluate a model on the test cases.
-
+
     Args:
         guardrail: Entity recognition guardrail to evaluate
         test_cases: List of test cases
-
+
     Returns:
         Tuple of (metrics dict, detailed results list)
     """
@@ -99,17 +104,17 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
         "total": len(test_cases),
         "passed": 0,
         "failed": 0,
-        "entity_metrics": {}  # Will store precision/recall per entity type
+        "entity_metrics": {},  # Will store precision/recall per entity type
     }
-
+
     detailed_results = []
-
+
     for test_case in tqdm(test_cases, desc="Evaluating samples"):
         # Run detection
-        result = guardrail.guard(test_case['input_text'])
+        result = guardrail.guard(test_case["input_text"])
         detected = result.detected_entities
-        expected = test_case['expected_entities']
-
+        expected = test_case["expected_entities"]
+
         # Map Presidio entities if this is the Presidio guardrail
         if isinstance(guardrail, PresidioEntityRecognitionGuardrail):
             mapped_detected = {}
@@ -120,44 +125,62 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
                     mapped_detected[mapped_type] = []
                 mapped_detected[mapped_type].extend(values)
             detected = mapped_detected
-
+
         # Track entity-level metrics
         all_entity_types = set(list(detected.keys()) + list(expected.keys()))
        entity_results = {}
-
+
         for entity_type in all_entity_types:
             detected_set = set(detected.get(entity_type, []))
             expected_set = set(expected.get(entity_type, []))
-
+
             # Calculate metrics
             true_positives = len(detected_set & expected_set)
             false_positives = len(detected_set - expected_set)
             false_negatives = len(expected_set - detected_set)
-
-            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
-            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
-            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
+
+            precision = (
+                true_positives / (true_positives + false_positives)
+                if (true_positives + false_positives) > 0
+                else 0
+            )
+            recall = (
+                true_positives / (true_positives + false_negatives)
+                if (true_positives + false_negatives) > 0
+                else 0
+            )
+            f1 = (
+                2 * (precision * recall) / (precision + recall)
+                if (precision + recall) > 0
+                else 0
+            )
+
             entity_results[entity_type] = {
                 "precision": precision,
                 "recall": recall,
                 "f1": f1,
                 "true_positives": true_positives,
                 "false_positives": false_positives,
-                "false_negatives": false_negatives
+                "false_negatives": false_negatives,
             }
-
+
             # Aggregate metrics
             if entity_type not in metrics["entity_metrics"]:
                 metrics["entity_metrics"][entity_type] = {
                     "total_true_positives": 0,
                     "total_false_positives": 0,
-                    "total_false_negatives": 0
+                    "total_false_negatives": 0,
                 }
-            metrics["entity_metrics"][entity_type]["total_true_positives"] += true_positives
-            metrics["entity_metrics"][entity_type]["total_false_positives"] += false_positives
-            metrics["entity_metrics"][entity_type]["total_false_negatives"] += false_negatives
-
+            metrics["entity_metrics"][entity_type][
+                "total_true_positives"
+            ] += true_positives
+            metrics["entity_metrics"][entity_type][
+                "total_false_positives"
+            ] += false_positives
+            metrics["entity_metrics"][entity_type][
+                "total_false_negatives"
+            ] += false_negatives
+
         # Store detailed result
         detailed_result = {
             "id": test_case.get("description", ""),
@@ -167,69 +190,88 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
             "expected_entities": expected,
             "detected_entities": detected,
             "entity_metrics": entity_results,
-            "anonymized_text": result.anonymized_text if result.anonymized_text else None
+            "anonymized_text": (
+                result.anonymized_text if result.anonymized_text else None
+            ),
         }
         detailed_results.append(detailed_result)
-
+
         # Update pass/fail counts
         if all(entity_results[et]["f1"] == 1.0 for et in entity_results):
             metrics["passed"] += 1
         else:
             metrics["failed"] += 1
-
+
     # Calculate final entity metrics and track totals for overall metrics
     total_tp = 0
     total_fp = 0
     total_fn = 0
-
+
     for entity_type, counts in metrics["entity_metrics"].items():
         tp = counts["total_true_positives"]
         fp = counts["total_false_positives"]
         fn = counts["total_false_negatives"]
-
+
         total_tp += tp
         total_fp += fp
         total_fn += fn
-
+
         precision = tp / (tp + fp) if (tp + fp) > 0 else 0
         recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
-        metrics["entity_metrics"][entity_type].update({
-            "precision": precision,
-            "recall": recall,
-            "f1": f1
-        })
-
+        f1 = (
+            2 * (precision * recall) / (precision + recall)
+            if (precision + recall) > 0
+            else 0
+        )
+
+        metrics["entity_metrics"][entity_type].update(
+            {"precision": precision, "recall": recall, "f1": f1}
+        )
+
     # Calculate overall metrics
-    overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
-    overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
-    overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
-
+    overall_precision = (
+        total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+    )
+    overall_recall = (
+        total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+    )
+    overall_f1 = (
+        2 * (overall_precision * overall_recall) / (overall_precision + overall_recall)
+        if (overall_precision + overall_recall) > 0
+        else 0
+    )
+
     metrics["overall"] = {
         "precision": overall_precision,
         "recall": overall_recall,
         "f1": overall_f1,
         "total_true_positives": total_tp,
         "total_false_positives": total_fp,
-        "total_false_negatives": total_fn
+        "total_false_negatives": total_fn,
     }
-
+
     return metrics, detailed_results
 
-def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
+
+def save_results(
+    metrics: Dict,
+    detailed_results: List[Dict],
+    model_name: str,
+    output_dir: str = "evaluation_results",
+):
     """Save evaluation results to files"""
     output_dir = Path(output_dir)
     output_dir.mkdir(exist_ok=True)
-
+
     # Save metrics summary
     with open(output_dir / f"{model_name}_metrics.json", "w") as f:
         json.dump(metrics, f, indent=2)
-
+
     # Save detailed results
     with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
         json.dump(detailed_results, f, indent=2)
 
+
 def print_metrics_summary(metrics: Dict):
     """Print a summary of the evaluation metrics"""
     print("\nEvaluation Summary")
@@ -238,7 +280,7 @@ def print_metrics_summary(metrics: Dict):
     print(f"Passed: {metrics['passed']}")
     print(f"Failed: {metrics['failed']}")
     print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")
-
+
     # Print overall metrics
     print("\nOverall Metrics:")
     print("-" * 80)
@@ -247,40 +289,56 @@ def print_metrics_summary(metrics: Dict):
     print(f"{'Precision':<20} {metrics['overall']['precision']:>10.2f}")
     print(f"{'Recall':<20} {metrics['overall']['recall']:>10.2f}")
     print(f"{'F1':<20} {metrics['overall']['f1']:>10.2f}")
-
+
     print("\nEntity-level Metrics:")
     print("-" * 80)
     print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
     print("-" * 80)
     for entity_type, entity_metrics in metrics["entity_metrics"].items():
-        print(f"{entity_type:<20} {entity_metrics['precision']:>10.2f} {entity_metrics['recall']:>10.2f} {entity_metrics['f1']:>10.2f}")
+        print(
+            f"{entity_type:<20} {entity_metrics['precision']:>10.2f} {entity_metrics['recall']:>10.2f} {entity_metrics['f1']:>10.2f}"
+        )
+
 
 def main():
     """Main evaluation function"""
     weave.init("guardrails-genie-pii-evaluation-demo")
-
+
     # Load test cases
     test_cases = load_ai4privacy_dataset(num_samples=100)
-
+
     # Initialize models to evaluate
     models = {
-        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
-        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
-        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
+        "regex": RegexEntityRecognitionGuardrail(
+            should_anonymize=True, show_available_entities=True
+        ),
+        "presidio": PresidioEntityRecognitionGuardrail(
+            should_anonymize=True, show_available_entities=True
+        ),
+        "transformers": TransformersEntityRecognitionGuardrail(
+            should_anonymize=True, show_available_entities=True
+        ),
     }
-
+
     # Evaluate each model
     for model_name, guardrail in models.items():
         print(f"\nEvaluating {model_name} model...")
         metrics, detailed_results = evaluate_model(guardrail, test_cases)
-
+
         # Print and save results
         print_metrics_summary(metrics)
         save_results(metrics, detailed_results, model_name)
 
+
 if __name__ == "__main__":
-    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
-    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
-    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
-
-    main()
+    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
+        PresidioEntityRecognitionGuardrail,
+    )
+    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
+        RegexEntityRecognitionGuardrail,
+    )
+    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import (
+        TransformersEntityRecognitionGuardrail,
+    )

+    main()
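A quick sanity check on the metric arithmetic in evaluate_model, using hypothetical values:

    detected = {"EMAIL": ["a@x.com", "b@y.com"]}
    expected = {"EMAIL": ["a@x.com"], "PHONE": ["555-0100"]}
    # EMAIL: TP=1, FP=1, FN=0 -> precision 0.5, recall 1.0, f1 = 2*(0.5*1.0)/(0.5+1.0) = 2/3
    # PHONE: TP=0, FP=0, FN=1 -> precision 0, recall 0, f1 0 (zero-denominator guards)

The overall block then recomputes precision/recall/F1 from the summed TP/FP/FN counts (micro-averaging) rather than averaging the per-entity scores, so frequent entity types weigh more heavily in the overall numbers.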
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py
CHANGED

@@ -1,13 +1,13 @@
-from datasets import load_dataset
-from typing import Dict, List, Tuple, Optional
-import random
-from tqdm import tqdm
+import asyncio
 import json
+import random
 from pathlib import Path
+from typing import Dict, List, Optional
+
 import weave
-from weave.scorers import Scorer
+from datasets import load_dataset
 from weave import Evaluation
-import asyncio
+from weave.scorers import Scorer
 
 # Add this mapping dictionary near the top of the file
 PRESIDIO_TO_TRANSFORMER_MAPPING = {
@@ -35,26 +35,29 @@ PRESIDIO_TO_TRANSFORMER_MAPPING = {
     "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
     "IBAN_CODE": "ACCOUNTNUM",
     "MEDICAL_LICENSE": "IDCARDNUM",
-    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM",
 }
 
+
 class EntityRecognitionScorer(Scorer):
     """Scorer for evaluating entity recognition performance"""
-
+
     @weave.op()
-    async def score(self, model_output: Optional[dict], input_text: str, expected_entities: Dict) -> Dict:
+    async def score(
+        self, model_output: Optional[dict], input_text: str, expected_entities: Dict
+    ) -> Dict:
         """Score entity recognition results"""
         if not model_output:
             return {"f1": 0.0}
-
+
         # Convert Pydantic model to dict if necessary
         if hasattr(model_output, "model_dump"):
             model_output = model_output.model_dump()
         elif hasattr(model_output, "dict"):
            model_output = model_output.dict()
-
+
         detected = model_output.get("detected_entities", {})
-
+
         # Map Presidio entities if needed
         if model_output.get("model_type") == "presidio":
             mapped_detected = {}
@@ -65,191 +68,234 @@ class EntityRecognitionScorer(Scorer):
                     mapped_detected[mapped_type] = []
                 mapped_detected[mapped_type].extend(values)
             detected = mapped_detected
-
+
         # Track entity-level metrics
         all_entity_types = set(list(detected.keys()) + list(expected_entities.keys()))
         entity_metrics = {}
-
+
         for entity_type in all_entity_types:
             detected_set = set(detected.get(entity_type, []))
             expected_set = set(expected_entities.get(entity_type, []))
-
+
             # Calculate metrics
             true_positives = len(detected_set & expected_set)
             false_positives = len(detected_set - expected_set)
             false_negatives = len(expected_set - detected_set)
-
+
             if entity_type not in entity_metrics:
                 entity_metrics[entity_type] = {
                     "total_true_positives": 0,
                     "total_false_positives": 0,
-                    "total_false_negatives": 0
+                    "total_false_negatives": 0,
                 }
-
+
             entity_metrics[entity_type]["total_true_positives"] += true_positives
             entity_metrics[entity_type]["total_false_positives"] += false_positives
             entity_metrics[entity_type]["total_false_negatives"] += false_negatives
-
+
             # Calculate per-entity metrics
-            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
-            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
-            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
-            entity_metrics[entity_type].update({
-                "precision": precision,
-                "recall": recall,
-                "f1": f1
-            })
-
+            precision = (
+                true_positives / (true_positives + false_positives)
+                if (true_positives + false_positives) > 0
+                else 0
+            )
+            recall = (
+                true_positives / (true_positives + false_negatives)
+                if (true_positives + false_negatives) > 0
+                else 0
+            )
+            f1 = (
+                2 * (precision * recall) / (precision + recall)
+                if (precision + recall) > 0
+                else 0
+            )
+
+            entity_metrics[entity_type].update(
+                {"precision": precision, "recall": recall, "f1": f1}
+            )
+
         # Calculate overall metrics
-        total_tp = sum(metrics["total_true_positives"] for metrics in entity_metrics.values())
-        total_fp = sum(metrics["total_false_positives"] for metrics in entity_metrics.values())
-        total_fn = sum(metrics["total_false_negatives"] for metrics in entity_metrics.values())
-
-        overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
-        overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
-        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
-
+        total_tp = sum(
+            metrics["total_true_positives"] for metrics in entity_metrics.values()
+        )
+        total_fp = sum(
+            metrics["total_false_positives"] for metrics in entity_metrics.values()
+        )
+        total_fn = sum(
+            metrics["total_false_negatives"] for metrics in entity_metrics.values()
+        )
+
+        overall_precision = (
+            total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+        )
+        overall_recall = (
+            total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+        )
+        overall_f1 = (
+            2
+            * (overall_precision * overall_recall)
+            / (overall_precision + overall_recall)
+            if (overall_precision + overall_recall) > 0
+            else 0
+        )
+
         entity_metrics["overall"] = {
             "precision": overall_precision,
             "recall": overall_recall,
             "f1": overall_f1,
             "total_true_positives": total_tp,
             "total_false_positives": total_fp,
-            "total_false_negatives": total_fn
+            "total_false_negatives": total_fn,
         }
-
+
         return entity_metrics
 
-def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
+
+def load_ai4privacy_dataset(
+    num_samples: int = 100, split: str = "validation"
+) -> List[Dict]:
     """
     Load and prepare samples from the ai4privacy dataset.
-
+
     Args:
         num_samples: Number of samples to evaluate
         split: Dataset split to use ("train" or "validation")
-
+
     Returns:
         List of prepared test cases
     """
     # Load the dataset
     dataset = load_dataset("ai4privacy/pii-masking-400k")
-
+
     # Get the specified split
     data_split = dataset[split]
-
+
     # Randomly sample entries if num_samples is less than total
     if num_samples < len(data_split):
         indices = random.sample(range(len(data_split)), num_samples)
         samples = [data_split[i] for i in indices]
     else:
         samples = data_split
-
+
     # Convert to test case format
     test_cases = []
     for sample in samples:
         # Extract entities from privacy_mask
         entities: Dict[str, List[str]] = {}
-        for entity in sample['privacy_mask']:
-            label = entity['label']
-            value = entity['value']
+        for entity in sample["privacy_mask"]:
+            label = entity["label"]
+            value = entity["value"]
             if label not in entities:
                 entities[label] = []
             entities[label].append(value)
-
+
         test_case = {
             "description": f"AI4Privacy Sample (ID: {sample['uid']})",
-            "input_text": sample['source_text'],
+            "input_text": sample["source_text"],
             "expected_entities": entities,
-            "masked_text": sample['masked_text'],
-            "language": sample['language'],
-            "locale": sample['locale']
+            "masked_text": sample["masked_text"],
+            "language": sample["language"],
+            "locale": sample["locale"],
        }
         test_cases.append(test_case)
-
+
     return test_cases
 
-def save_results(weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"):
+
+def save_results(
+    weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"
+):
     """Save evaluation results to files"""
     output_dir = Path(output_dir)
     output_dir.mkdir(exist_ok=True)
-
+
     # Extract and process results
     scorer_results = weave_results.get("EntityRecognitionScorer", [])
     if not scorer_results or all(r is None for r in scorer_results):
         print(f"No valid results to save for {model_name}")
         return
-
+
     # Calculate summary metrics
     total_samples = len(scorer_results)
     passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))
-
+
     # Aggregate entity-level metrics
     entity_metrics = {}
     for result in scorer_results:
         try:
             if isinstance(result, str) or not result:
                 continue
-
+
             for entity_type, metrics in result.items():
                 if entity_type not in entity_metrics:
                     entity_metrics[entity_type] = {
                         "precision": [],
                         "recall": [],
-                        "f1": []
+                        "f1": [],
                     }
                 entity_metrics[entity_type]["precision"].append(metrics["precision"])
                 entity_metrics[entity_type]["recall"].append(metrics["recall"])
                 entity_metrics[entity_type]["f1"].append(metrics["f1"])
         except (AttributeError, TypeError, KeyError):
             continue
-
+
     # Calculate averages
     summary_metrics = {
         "total": total_samples,
         "passed": passed,
         "failed": total_samples - passed,
-        "success_rate": (passed/total_samples) if total_samples > 0 else 0,
+        "success_rate": (passed / total_samples) if total_samples > 0 else 0,
         "entity_metrics": {
             entity_type: {
                "precision": sum(metrics["precision"]) / len(metrics["precision"]),
                "recall": sum(metrics["recall"]) / len(metrics["recall"]),
                "f1": sum(metrics["f1"]) / len(metrics["f1"]),
             }
             for entity_type, metrics in entity_metrics.items()
-        }
+        },
     }
-
+
     # Save files
     with open(output_dir / f"{model_name}_metrics.json", "w") as f:
         json.dump(summary_metrics, f, indent=2)
-
+
     # Save detailed results, filtering out string results
    detailed_results = [r for r in scorer_results if not isinstance(r, str)]
     with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
         json.dump(detailed_results, f, indent=2)
 
+
 def print_metrics_summary(weave_results: Dict):
     """Print a summary of the evaluation metrics"""
     print("\nEvaluation Summary")
     print("=" * 80)
-
+
     # Extract results from Weave's evaluation format
     scorer_results = weave_results.get("EntityRecognitionScorer", {})
     if not scorer_results:
         print("No valid results available")
         return
-
+
     # Calculate overall metrics
     total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
     passed = total_samples  # Since we have results, all samples passed
     failed = 0
-
+
     print(f"Total Samples: {total_samples}")
     print(f"Passed: {passed}")
     print(f"Failed: {failed}")
     print(f"Success Rate: {(passed/total_samples)*100:.2f}%")
-
+
     # Print overall metrics
     if "overall" in scorer_results:
         overall = scorer_results["overall"]
@@ -260,63 +306,68 @@ def print_metrics_summary(weave_results: Dict):
         print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
         print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
         print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")
-
+
     # Print entity-level metrics
     print("\nEntity-Level Metrics:")
     print("-" * 80)
     print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
     print("-" * 80)
-
+
     for entity_type, metrics in scorer_results.items():
         if entity_type == "overall":
             continue
-
+
         precision = metrics.get("precision", {}).get("mean", 0)
         recall = metrics.get("recall", {}).get("mean", 0)
         f1 = metrics.get("f1", {}).get("mean", 0)
-
+
         print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")
 
+
 def preprocess_model_input(example: Dict) -> Dict:
     """Preprocess dataset example to match model input format."""
     return {
         "prompt": example["input_text"],
-        "model_type": example.get(…
     }
 
+
 def main():
     """Main evaluation function"""
     weave.init("guardrails-genie-pii-evaluation")
-
+
     # Load test cases
     test_cases = load_ai4privacy_dataset(num_samples=100)
-
+
     # Add model type to test cases for Presidio mapping
     models = {
         # "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
         "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
         # "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
     }
-
+
     scorer = EntityRecognitionScorer()
-
+
     # Evaluate each model
     for model_name, guardrail in models.items():
         print(f"\nEvaluating {model_name} model...")
         # Add model type to test cases
        model_test_cases = [{**case, "model_type": model_name} for case in test_cases]
-
+
         evaluation = Evaluation(
             dataset=model_test_cases,
             scorers=[scorer],
-            preprocess_model_input=preprocess_model_input
+            preprocess_model_input=preprocess_model_input,
         )
-
+
         results = asyncio.run(evaluation.evaluate(guardrail))
 
 if __name__ == "__main__":
-    from guardrails_genie.guardrails.entity_recognition.…
-    …
-    …
-    …
-    main()
153 |
+
|
154 |
+
def load_ai4privacy_dataset(
|
155 |
+
num_samples: int = 100, split: str = "validation"
|
156 |
+
) -> List[Dict]:
|
157 |
"""
|
158 |
Load and prepare samples from the ai4privacy dataset.
|
159 |
+
|
160 |
Args:
|
161 |
num_samples: Number of samples to evaluate
|
162 |
split: Dataset split to use ("train" or "validation")
|
163 |
+
|
164 |
Returns:
|
165 |
List of prepared test cases
|
166 |
"""
|
167 |
# Load the dataset
|
168 |
dataset = load_dataset("ai4privacy/pii-masking-400k")
|
169 |
+
|
170 |
# Get the specified split
|
171 |
data_split = dataset[split]
|
172 |
+
|
173 |
# Randomly sample entries if num_samples is less than total
|
174 |
if num_samples < len(data_split):
|
175 |
indices = random.sample(range(len(data_split)), num_samples)
|
176 |
samples = [data_split[i] for i in indices]
|
177 |
else:
|
178 |
samples = data_split
|
179 |
+
|
180 |
# Convert to test case format
|
181 |
test_cases = []
|
182 |
for sample in samples:
|
183 |
# Extract entities from privacy_mask
|
184 |
entities: Dict[str, List[str]] = {}
|
185 |
+
for entity in sample["privacy_mask"]:
|
186 |
+
label = entity["label"]
|
187 |
+
value = entity["value"]
|
188 |
if label not in entities:
|
189 |
entities[label] = []
|
190 |
entities[label].append(value)
|
191 |
+
|
192 |
test_case = {
|
193 |
"description": f"AI4Privacy Sample (ID: {sample['uid']})",
|
194 |
+
"input_text": sample["source_text"],
|
195 |
"expected_entities": entities,
|
196 |
+
"masked_text": sample["masked_text"],
|
197 |
+
"language": sample["language"],
|
198 |
+
"locale": sample["locale"],
|
199 |
}
|
200 |
test_cases.append(test_case)
|
201 |
+
|
202 |
return test_cases
|
203 |
|
204 |
+
|
205 |
+
def save_results(
|
206 |
+
weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"
|
207 |
+
):
|
208 |
"""Save evaluation results to files"""
|
209 |
output_dir = Path(output_dir)
|
210 |
output_dir.mkdir(exist_ok=True)
|
211 |
+
|
212 |
# Extract and process results
|
213 |
scorer_results = weave_results.get("EntityRecognitionScorer", [])
|
214 |
if not scorer_results or all(r is None for r in scorer_results):
|
215 |
print(f"No valid results to save for {model_name}")
|
216 |
return
|
217 |
+
|
218 |
# Calculate summary metrics
|
219 |
total_samples = len(scorer_results)
|
220 |
passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))
|
221 |
+
|
222 |
# Aggregate entity-level metrics
|
223 |
entity_metrics = {}
|
224 |
for result in scorer_results:
|
225 |
try:
|
226 |
if isinstance(result, str) or not result:
|
227 |
continue
|
228 |
+
|
229 |
for entity_type, metrics in result.items():
|
230 |
if entity_type not in entity_metrics:
|
231 |
entity_metrics[entity_type] = {
|
232 |
"precision": [],
|
233 |
"recall": [],
|
234 |
+
"f1": [],
|
235 |
}
|
236 |
entity_metrics[entity_type]["precision"].append(metrics["precision"])
|
237 |
entity_metrics[entity_type]["recall"].append(metrics["recall"])
|
238 |
entity_metrics[entity_type]["f1"].append(metrics["f1"])
|
239 |
except (AttributeError, TypeError, KeyError):
|
240 |
continue
|
241 |
+
|
242 |
# Calculate averages
|
243 |
summary_metrics = {
|
244 |
"total": total_samples,
|
245 |
"passed": passed,
|
246 |
"failed": total_samples - passed,
|
247 |
+
"success_rate": (passed / total_samples) if total_samples > 0 else 0,
|
248 |
"entity_metrics": {
|
249 |
entity_type: {
|
250 |
+
"precision": (
|
251 |
+
sum(metrics["precision"]) / len(metrics["precision"])
|
252 |
+
if metrics["precision"]
|
253 |
+
else 0
|
254 |
+
),
|
255 |
+
"recall": (
|
256 |
+
sum(metrics["recall"]) / len(metrics["recall"])
|
257 |
+
if metrics["recall"]
|
258 |
+
else 0
|
259 |
+
),
|
260 |
+
"f1": sum(metrics["f1"]) / len(metrics["f1"]) if metrics["f1"] else 0,
|
261 |
}
|
262 |
for entity_type, metrics in entity_metrics.items()
|
263 |
+
},
|
264 |
}
|
265 |
+
|
266 |
# Save files
|
267 |
with open(output_dir / f"{model_name}_metrics.json", "w") as f:
|
268 |
json.dump(summary_metrics, f, indent=2)
|
269 |
+
|
270 |
# Save detailed results, filtering out string results
|
271 |
+
detailed_results = [
|
272 |
+
r for r in scorer_results if not isinstance(r, str) and r is not None
|
273 |
+
]
|
274 |
with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
|
275 |
json.dump(detailed_results, f, indent=2)
|
276 |
|
277 |
+
|
278 |
def print_metrics_summary(weave_results: Dict):
|
279 |
"""Print a summary of the evaluation metrics"""
|
280 |
print("\nEvaluation Summary")
|
281 |
print("=" * 80)
|
282 |
+
|
283 |
# Extract results from Weave's evaluation format
|
284 |
scorer_results = weave_results.get("EntityRecognitionScorer", {})
|
285 |
if not scorer_results:
|
286 |
print("No valid results available")
|
287 |
return
|
288 |
+
|
289 |
# Calculate overall metrics
|
290 |
total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
|
291 |
passed = total_samples # Since we have results, all samples passed
|
292 |
failed = 0
|
293 |
+
|
294 |
print(f"Total Samples: {total_samples}")
|
295 |
print(f"Passed: {passed}")
|
296 |
print(f"Failed: {failed}")
|
297 |
print(f"Success Rate: {(passed/total_samples)*100:.2f}%")
|
298 |
+
|
299 |
# Print overall metrics
|
300 |
if "overall" in scorer_results:
|
301 |
overall = scorer_results["overall"]
|
|
|
306 |
print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
|
307 |
print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
|
308 |
print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")
|
309 |
+
|
310 |
# Print entity-level metrics
|
311 |
print("\nEntity-Level Metrics:")
|
312 |
print("-" * 80)
|
313 |
print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
|
314 |
print("-" * 80)
|
315 |
+
|
316 |
for entity_type, metrics in scorer_results.items():
|
317 |
if entity_type == "overall":
|
318 |
continue
|
319 |
+
|
320 |
precision = metrics.get("precision", {}).get("mean", 0)
|
321 |
recall = metrics.get("recall", {}).get("mean", 0)
|
322 |
f1 = metrics.get("f1", {}).get("mean", 0)
|
323 |
+
|
324 |
print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")
|
325 |
|
326 |
+
|
327 |
def preprocess_model_input(example: Dict) -> Dict:
|
328 |
"""Preprocess dataset example to match model input format."""
|
329 |
return {
|
330 |
"prompt": example["input_text"],
|
331 |
+
"model_type": example.get(
|
332 |
+
"model_type", "unknown"
|
333 |
+
), # Add model type for Presidio mapping
|
334 |
}
|
335 |
|
336 |
+
|
337 |
def main():
|
338 |
"""Main evaluation function"""
|
339 |
weave.init("guardrails-genie-pii-evaluation")
|
340 |
+
|
341 |
# Load test cases
|
342 |
test_cases = load_ai4privacy_dataset(num_samples=100)
|
343 |
+
|
344 |
# Add model type to test cases for Presidio mapping
|
345 |
models = {
|
346 |
# "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
|
347 |
"presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
|
348 |
# "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
|
349 |
}
|
350 |
+
|
351 |
scorer = EntityRecognitionScorer()
|
352 |
+
|
353 |
# Evaluate each model
|
354 |
for model_name, guardrail in models.items():
|
355 |
print(f"\nEvaluating {model_name} model...")
|
356 |
# Add model type to test cases
|
357 |
model_test_cases = [{**case, "model_type": model_name} for case in test_cases]
|
358 |
+
|
359 |
evaluation = Evaluation(
|
360 |
dataset=model_test_cases,
|
361 |
scorers=[scorer],
|
362 |
+
preprocess_model_input=preprocess_model_input,
|
363 |
)
|
364 |
+
|
365 |
results = asyncio.run(evaluation.evaluate(guardrail))
|
366 |
|
367 |
+
|
368 |
if __name__ == "__main__":
|
369 |
+
from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
|
370 |
+
PresidioEntityRecognitionGuardrail,
|
371 |
+
)
|
372 |
+
|
373 |
+
main()
|
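
For reference, the per-entity arithmetic the scorer applies to each sample reduces to simple set overlap. A minimal, self-contained sketch; the detected/expected values below are hypothetical:

# Hypothetical per-sample values illustrating EntityRecognitionScorer's arithmetic.
detected = {"EMAIL": ["jane@example.com"], "GIVENNAME": ["John", "Jane"]}
expected = {"EMAIL": ["jane@example.com"], "GIVENNAME": ["John"], "SURNAME": ["Doe"]}

for entity_type in set(detected) | set(expected):
    d = set(detected.get(entity_type, []))
    e = set(expected.get(entity_type, []))
    # True positives are exact string matches; extras and misses drive FP/FN.
    tp, fp, fn = len(d & e), len(d - e), len(e - d)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    print(f"{entity_type}: precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")

Running this prints a perfect score for EMAIL, precision 0.50 for GIVENNAME (one extra detection), and zero for the missed SURNAME.
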
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py
CHANGED
@@ -18,8 +18,8 @@ Emergency Contact: Mary Johnson (Tel: 098-765-4321)
            "SURNAME": ["Smith", "Johnson"],
            "EMAIL": ["[email protected]"],
            "PHONE_NUMBER": ["123-456-7890", "098-765-4321"],
-           "SOCIALNUM": ["123-45-6789"]
-       }
    },
    {
        "description": "Meeting Notes with Attendees",
@@ -39,8 +39,8 @@ Action Items:
            "GIVENNAME": ["Sarah", "Robert", "Tom", "Bob"],
            "SURNAME": ["Williams", "Brown", "Wilson"],
            "EMAIL": ["[email protected]", "[email protected]"],
-           "PHONE_NUMBER": ["555-0123-4567", "777-888-9999"]
-       }
    },
    {
        "description": "Medical Record",
@@ -57,8 +57,8 @@ Emergency Contact: Michael Thompson (555-123-4567)
            "GIVENNAME": ["Emma", "James", "Michael"],
            "SURNAME": ["Thompson", "Wilson", "Thompson"],
            "EMAIL": ["[email protected]"],
-           "PHONE_NUMBER": ["555-123-4567"]
-       }
    },
    {
        "description": "No PII Content",
@@ -68,7 +68,7 @@ Project Status Update:
- Budget is within limits
- Next review scheduled for next week
""",
-       "expected_entities": {}
    },
    {
        "description": "Mixed Format Phone Numbers",
@@ -84,10 +84,10 @@ Emergency: 555 444 3333
                "(555) 123-4567",
                "555.987.6543",
                "+1-555-321-7890",
-               "555 444 3333"
            ]
-       }
-   }
]

# Additional examples can be added to test specific edge cases or formats
@@ -103,37 +103,41 @@ [email protected]
            "EMAIL": [
                "[email protected]",
                "[email protected]",
-               "[email protected]"
            ],
            "GIVENNAME": ["John", "Jane", "Bob"],
-           "SURNAME": ["Doe", "Smith", "Jones"]
-       }
    }
]

def validate_entities(detected: dict, expected: dict) -> bool:
    """Compare detected entities with expected entities"""
    if set(detected.keys()) != set(expected.keys()):
        return False
    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())

def run_test_case(guardrail, test_case, test_type="Main"):
    """Run a single test case and print results"""
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)
-
-   result = guardrail.guard(test_case['input_text'])
-   expected = test_case['expected_entities']
-
    # Validate results
    matches = validate_entities(result.detected_entities, expected)
-
    print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
    print(f"Contains PII: {result.contains_entities}")
-
    if not matches:
        print("\nEntity Comparison:")
-       all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
        for entity_type in all_entity_types:
            detected = set(result.detected_entities.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))
@@ -143,8 +147,8 @@ def run_test_case(guardrail, test_case, test_type="Main"):
        if detected != expected_set:
            print(f"  Missing: {sorted(expected_set - detected)}")
            print(f"  Extra: {sorted(detected - expected_set)}")
-
    if result.anonymized_text:
        print(f"\nAnonymized Text:\n{result.anonymized_text}")
-
-   return matches

            "SURNAME": ["Smith", "Johnson"],
            "EMAIL": ["[email protected]"],
            "PHONE_NUMBER": ["123-456-7890", "098-765-4321"],
+           "SOCIALNUM": ["123-45-6789"],
+       },
    },
    {
        "description": "Meeting Notes with Attendees",
...
            "GIVENNAME": ["Sarah", "Robert", "Tom", "Bob"],
            "SURNAME": ["Williams", "Brown", "Wilson"],
            "EMAIL": ["[email protected]", "[email protected]"],
+           "PHONE_NUMBER": ["555-0123-4567", "777-888-9999"],
+       },
    },
    {
        "description": "Medical Record",
...
            "GIVENNAME": ["Emma", "James", "Michael"],
            "SURNAME": ["Thompson", "Wilson", "Thompson"],
            "EMAIL": ["[email protected]"],
+           "PHONE_NUMBER": ["555-123-4567"],
+       },
    },
    {
        "description": "No PII Content",
...
- Budget is within limits
- Next review scheduled for next week
""",
+       "expected_entities": {},
    },
    {
        "description": "Mixed Format Phone Numbers",
...
                "(555) 123-4567",
                "555.987.6543",
                "+1-555-321-7890",
+               "555 444 3333",
            ]
+       },
+   },
]

# Additional examples can be added to test specific edge cases or formats
...
            "EMAIL": [
                "[email protected]",
                "[email protected]",
+               "[email protected]",
            ],
            "GIVENNAME": ["John", "Jane", "Bob"],
+           "SURNAME": ["Doe", "Smith", "Jones"],
+       },
    }
]

+
def validate_entities(detected: dict, expected: dict) -> bool:
    """Compare detected entities with expected entities"""
    if set(detected.keys()) != set(expected.keys()):
        return False
    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())

+
def run_test_case(guardrail, test_case, test_type="Main"):
    """Run a single test case and print results"""
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)
+
+   result = guardrail.guard(test_case["input_text"])
+   expected = test_case["expected_entities"]
+
    # Validate results
    matches = validate_entities(result.detected_entities, expected)
+
    print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
    print(f"Contains PII: {result.contains_entities}")
+
    if not matches:
        print("\nEntity Comparison:")
+       all_entity_types = set(
+           list(result.detected_entities.keys()) + list(expected.keys())
+       )
        for entity_type in all_entity_types:
            detected = set(result.detected_entities.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))
...
        if detected != expected_set:
            print(f"  Missing: {sorted(expected_set - detected)}")
            print(f"  Extra: {sorted(detected - expected_set)}")
+
    if result.anonymized_text:
        print(f"\nAnonymized Text:\n{result.anonymized_text}")
+
+   return matches
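
A short driver showing how these helpers are meant to be combined; a sketch assuming this repository layout, with the regex guardrail and the Weave project name as arbitrary, hypothetical choices:

import weave

from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
    PII_TEST_EXAMPLES,
    run_test_case,
)
from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)

weave.init("guardrails-genie-pii-examples")  # hypothetical project name
guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
# run_test_case returns a bool, so summing counts the passing cases.
passed = sum(run_test_case(guardrail, case) for case in PII_TEST_EXAMPLES)
print(f"{passed}/{len(PII_TEST_EXAMPLES)} test cases passed")
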
guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py
CHANGED
@@ -1,15 +1,22 @@
- from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
- from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
import weave

def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-presidio-model")
-
    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = PresidioEntityRecognitionGuardrail(
-       should_anonymize=True,
-       show_available_entities=True
    )

    # Test statistics
@@ -38,5 +45,6 @@ def test_pii_detection():
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_pii_detection()

import weave

+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
+     EDGE_CASE_EXAMPLES,
+     PII_TEST_EXAMPLES,
+     run_test_case,
+ )
+ from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
+     PresidioEntityRecognitionGuardrail,
+ )
+
+
def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-presidio-model")
+
    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = PresidioEntityRecognitionGuardrail(
+       should_anonymize=True, show_available_entities=True
    )

    # Test statistics
...
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

+
if __name__ == "__main__":
    test_pii_detection()
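
Outside the test harness, a single Presidio check looks like this; a sketch, and which entities are reported depends on the recognizers Presidio loads:

from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
    PresidioEntityRecognitionGuardrail,
)

guardrail = PresidioEntityRecognitionGuardrail(should_anonymize=True)
response = guardrail.guard("Contact Jane Doe at 123-456-7890.")
print(response.contains_entities)   # True if anything was detected
print(response.detected_entities)   # e.g. phone number and person matches
print(response.anonymized_text)     # redacted copy of the prompt
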
guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py
CHANGED
@@ -1,15 +1,22 @@
- from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
- from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
import weave

def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-regex-model")
-
    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = RegexEntityRecognitionGuardrail(
-       should_anonymize=True,
-       show_available_entities=True
    )

    # Test statistics
@@ -38,5 +45,6 @@ def test_pii_detection():
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_pii_detection()

import weave

+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
+     EDGE_CASE_EXAMPLES,
+     PII_TEST_EXAMPLES,
+     run_test_case,
+ )
+ from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
+     RegexEntityRecognitionGuardrail,
+ )
+
+
def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-regex-model")
+
    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = RegexEntityRecognitionGuardrail(
+       should_anonymize=True, show_available_entities=True
    )

    # Test statistics
...
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

+
if __name__ == "__main__":
    test_pii_detection()
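
The regex guardrail's guard method also accepts caller-supplied terms via custom_terms, in which case only those terms are checked and the default patterns are ignored. A hedged sketch; the term is hypothetical and the custom-terms matching details are not shown in this diff:

from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)

guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
response = guardrail.guard(
    "Project Aurora ships next week.",
    custom_terms=["Project Aurora"],  # hypothetical restricted term
)
print(response.contains_entities)
print(response.anonymized_text)
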
guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py
CHANGED
@@ -1,16 +1,30 @@
- from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
- from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
import weave

def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-transformers-pipeline-model")
-
    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = TransformersEntityRecognitionGuardrail(
-       selected_entities=["GIVENNAME", "SURNAME", "EMAIL", "TELEPHONENUM", "SOCIALNUM"],
        should_anonymize=True,
-       show_available_entities=True
    )

    # Test statistics
@@ -39,5 +53,6 @@ def test_pii_detection():
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_pii_detection()

import weave

+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
+     EDGE_CASE_EXAMPLES,
+     PII_TEST_EXAMPLES,
+     run_test_case,
+ )
+ from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import (
+     TransformersEntityRecognitionGuardrail,
+ )
+
+
def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-transformers-pipeline-model")
+
    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = TransformersEntityRecognitionGuardrail(
+       selected_entities=[
+           "GIVENNAME",
+           "SURNAME",
+           "EMAIL",
+           "TELEPHONENUM",
+           "SOCIALNUM",
+       ],
        should_anonymize=True,
+       show_available_entities=True,
    )

    # Test statistics
...
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

+
if __name__ == "__main__":
    test_pii_detection()
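
Equivalent one-off usage of the transformers guardrail; a sketch, noting that the default piiranha model is downloaded from the Hugging Face Hub on first use:

from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import (
    TransformersEntityRecognitionGuardrail,
)

guardrail = TransformersEntityRecognitionGuardrail(
    selected_entities=["GIVENNAME", "SURNAME", "EMAIL"],
    should_anonymize=True,
)
response = guardrail.guard("Write to Jane Doe at jane.doe@example.com")
print(response.detected_entities)
print(response.anonymized_text)
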
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py
CHANGED
@@ -1,12 +1,18 @@
- from typing import Dict, List, Optional, Any
- import weave
- from pydantic import BaseModel

- from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, PatternRecognizer, Pattern
from presidio_anonymizer import AnonymizerEngine

from ..base import Guardrail

class PresidioEntityRecognitionResponse(BaseModel):
    contains_entities: bool
    detected_entities: Dict[str, List[str]]
@@ -17,6 +23,7 @@ class PresidioEntityRecognitionResponse(BaseModel):
    def safe(self) -> bool:
        return not self.contains_entities

class PresidioEntityRecognitionSimpleResponse(BaseModel):
    contains_entities: bool
    explanation: str
@@ -26,21 +33,24 @@ class PresidioEntityRecognitionSimpleResponse(BaseModel):
    def safe(self) -> bool:
        return not self.contains_entities

-
class PresidioEntityRecognitionGuardrail(Guardrail):
    @staticmethod
    def get_available_entities() -> List[str]:
        registry = RecognizerRegistry()
        analyzer = AnalyzerEngine(registry=registry)
-       return [recognizer.supported_entities[0]
-               for recognizer in analyzer.registry.recognizers]
-
    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    selected_entities: List[str]
    should_anonymize: bool
    language: str
-
    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
@@ -49,7 +59,7 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
-       show_available_entities: bool = False
    ):
        # If show_available_entities is True, print available entities
        if show_available_entities:
@@ -63,36 +73,37 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
        # Initialize default values to all available entities
        if selected_entities is None:
            selected_entities = self.get_available_entities()
-
        # Get available entities dynamically
        available_entities = self.get_available_entities()
-
        # Filter out invalid entities and warn user
        invalid_entities = [e for e in selected_entities if e not in available_entities]
        valid_entities = [e for e in selected_entities if e in available_entities]
-
        if invalid_entities:
-           print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities
-
        # Initialize analyzer with default recognizers
        analyzer = AnalyzerEngine()
-
        # Add custom recognizers if provided
        if custom_recognizers:
            for recognizer in custom_recognizers:
                analyzer.registry.add_recognizer(recognizer)
-
        # Add deny list recognizers if provided
        if deny_lists:
            for entity_type, tokens in deny_lists.items():
                deny_list_recognizer = PatternRecognizer(
-                   supported_entity=entity_type,
-                   deny_list=tokens
                )
                analyzer.registry.add_recognizer(deny_list_recognizer)
-
        # Add regex pattern recognizers if provided
        if regex_patterns:
            for entity_type, patterns in regex_patterns.items():
@@ -100,89 +111,92 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
                    Pattern(
                        name=pattern.get("name", f"pattern_{i}"),
                        regex=pattern["regex"],
-                       score=pattern.get("score", 0.5)
-                   ) for i, pattern in enumerate(patterns)
                ]
                regex_recognizer = PatternRecognizer(
-                   supported_entity=entity_type,
-                   patterns=presidio_patterns
                )
                analyzer.registry.add_recognizer(regex_recognizer)
-
        # Initialize Presidio engines
        anonymizer = AnonymizerEngine()
-
        # Call parent class constructor with all fields
        super().__init__(
            analyzer=analyzer,
            anonymizer=anonymizer,
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
-           language=language
        )

    @weave.op()
-   def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities using Presidio.
-
        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
        """
        # Analyze text for entities
        analyzer_results = self.analyzer.analyze(
-           text=str(prompt),
-           entities=self.selected_entities,
-           language=self.language
        )
-
        # Group results by entity type
        detected_entities = {}
        for result in analyzer_results:
            entity_type = result.entity_type
-           text_slice = prompt[result.start:result.end]
            if entity_type not in detected_entities:
                detected_entities[entity_type] = []
            detected_entities[entity_type].append(text_slice)
-
        # Create explanation
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
-               explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")
-
        # Add information about what was checked
        explanation_parts.append("\nChecked for these entity types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")
-
        # Anonymize if requested
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_result = self.anonymizer.anonymize(
-               text=prompt,
-               analyzer_results=analyzer_results
            )
            anonymized_text = anonymized_result.text
-
        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation="\n".join(explanation_parts),
-               anonymized_text=anonymized_text
            )
        else:
            return PresidioEntityRecognitionSimpleResponse(
                contains_entities=bool(detected_entities),
                explanation="\n".join(explanation_parts),
-               anonymized_text=anonymized_text
            )
-
    @weave.op()
-   def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
-       return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)

+ from typing import Any, Dict, List, Optional

+ import weave
+ from presidio_analyzer import (
+     AnalyzerEngine,
+     Pattern,
+     PatternRecognizer,
+     RecognizerRegistry,
+ )
from presidio_anonymizer import AnonymizerEngine
+ from pydantic import BaseModel

from ..base import Guardrail

+
class PresidioEntityRecognitionResponse(BaseModel):
    contains_entities: bool
    detected_entities: Dict[str, List[str]]
...
    def safe(self) -> bool:
        return not self.contains_entities

+
class PresidioEntityRecognitionSimpleResponse(BaseModel):
    contains_entities: bool
    explanation: str
...
    def safe(self) -> bool:
        return not self.contains_entities

+
+ # TODO: Add support for transformers workflow and not just Spacy
class PresidioEntityRecognitionGuardrail(Guardrail):
    @staticmethod
    def get_available_entities() -> List[str]:
        registry = RecognizerRegistry()
        analyzer = AnalyzerEngine(registry=registry)
+       return [
+           recognizer.supported_entities[0]
+           for recognizer in analyzer.registry.recognizers
+       ]
+
    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    selected_entities: List[str]
    should_anonymize: bool
    language: str
+
    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
...
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
+       show_available_entities: bool = False,
    ):
        # If show_available_entities is True, print available entities
        if show_available_entities:
...
        # Initialize default values to all available entities
        if selected_entities is None:
            selected_entities = self.get_available_entities()
+
        # Get available entities dynamically
        available_entities = self.get_available_entities()
+
        # Filter out invalid entities and warn user
        invalid_entities = [e for e in selected_entities if e not in available_entities]
        valid_entities = [e for e in selected_entities if e in available_entities]
+
        if invalid_entities:
+           print(
+               f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
+           )
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities
+
        # Initialize analyzer with default recognizers
        analyzer = AnalyzerEngine()
+
        # Add custom recognizers if provided
        if custom_recognizers:
            for recognizer in custom_recognizers:
                analyzer.registry.add_recognizer(recognizer)
+
        # Add deny list recognizers if provided
        if deny_lists:
            for entity_type, tokens in deny_lists.items():
                deny_list_recognizer = PatternRecognizer(
+                   supported_entity=entity_type, deny_list=tokens
                )
                analyzer.registry.add_recognizer(deny_list_recognizer)
+
        # Add regex pattern recognizers if provided
        if regex_patterns:
            for entity_type, patterns in regex_patterns.items():
...
                    Pattern(
                        name=pattern.get("name", f"pattern_{i}"),
                        regex=pattern["regex"],
+                       score=pattern.get("score", 0.5),
+                   )
+                   for i, pattern in enumerate(patterns)
                ]
                regex_recognizer = PatternRecognizer(
+                   supported_entity=entity_type, patterns=presidio_patterns
                )
                analyzer.registry.add_recognizer(regex_recognizer)
+
        # Initialize Presidio engines
        anonymizer = AnonymizerEngine()
+
        # Call parent class constructor with all fields
        super().__init__(
            analyzer=analyzer,
            anonymizer=anonymizer,
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
+           language=language,
        )

    @weave.op()
+   def guard(
+       self, prompt: str, return_detected_types: bool = True, **kwargs
+   ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities using Presidio.
+
        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
        """
        # Analyze text for entities
        analyzer_results = self.analyzer.analyze(
+           text=str(prompt), entities=self.selected_entities, language=self.language
        )
+
        # Group results by entity type
        detected_entities = {}
        for result in analyzer_results:
            entity_type = result.entity_type
+           text_slice = prompt[result.start : result.end]
            if entity_type not in detected_entities:
                detected_entities[entity_type] = []
            detected_entities[entity_type].append(text_slice)
+
        # Create explanation
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
+               explanation_parts.append(
+                   f"- {entity_type}: {len(instances)} instance(s)"
+               )
        else:
            explanation_parts.append("No entities detected in the text.")
+
        # Add information about what was checked
        explanation_parts.append("\nChecked for these entity types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")
+
        # Anonymize if requested
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_result = self.anonymizer.anonymize(
+               text=prompt, analyzer_results=analyzer_results
            )
            anonymized_text = anonymized_result.text
+
        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation="\n".join(explanation_parts),
+               anonymized_text=anonymized_text,
            )
        else:
            return PresidioEntityRecognitionSimpleResponse(
                contains_entities=bool(detected_entities),
                explanation="\n".join(explanation_parts),
+               anonymized_text=anonymized_text,
            )
+
    @weave.op()
+   def predict(
+       self, prompt: str, return_detected_types: bool = True, **kwargs
+   ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
+       return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
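
A hedged configuration sketch exercising the deny_lists constructor argument defined above. PERSON, EMAIL_ADDRESS, and PHONE_NUMBER are standard Presidio entity types; the deny-list term is hypothetical:

from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
    PresidioEntityRecognitionGuardrail,
)

guardrail = PresidioEntityRecognitionGuardrail(
    selected_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON"],
    deny_lists={"PERSON": ["Project Aurora"]},  # flag a codename as a PERSON hit
    should_anonymize=True,
)
response = guardrail.guard("Email jane@example.com about Project Aurora.")
print(response.detected_entities)
print(response.anonymized_text)

Note that selected_entities is validated against the default registry, so deny-list and regex recognizers are easiest to use when registered under one of the standard entity types.
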
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py
CHANGED
@@ -1,11 +1,11 @@
- from typing import Dict, List, ClassVar, Optional

import weave
from pydantic import BaseModel

from ...regex_model import RegexModel
from ..base import Guardrail
- import re


class RegexEntityRecognitionResponse(BaseModel):
@@ -33,28 +33,34 @@ class RegexEntityRecognitionGuardrail(Guardrail):
    regex_model: RegexModel
    patterns: Dict[str, str] = {}
    should_anonymize: bool = False
-
    DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
-       "EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
-       "TELEPHONENUM": r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b',
-       "SOCIALNUM": r'\b\d{3}[-]?\d{2}[-]?\d{4}\b',
-       "CREDITCARDNUMBER": r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
-       "DATEOFBIRTH": r'\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b',
-       "DRIVERLICENSENUM": r'[A-Z]\d{7}',  # Example pattern, adjust for your needs
-       "ACCOUNTNUM": r'\b\d{10,12}\b',  # Example pattern for bank accounts
-       "ZIPCODE": r'\b\d{5}(?:-\d{4})?\b',
-       "GIVENNAME": r'\b[A-Z][a-z]+\b',  # Basic pattern for first names
-       "SURNAME": r'\b[A-Z][a-z]+\b',  # Basic pattern for last names
-       "CITY": r'\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b',
-       "STREET": r'\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b',
-       "IDCARDNUM": r'[A-Z]\d{7,8}',  # Generic pattern for ID cards
-       "USERNAME": r'@[A-Za-z]\w{3,}',  # Basic username pattern
-       "PASSWORD": r'[A-Za-z0-9@#$%^&+=]{8,}',  # Basic password pattern
-       "TAXNUM": r'\b\d{2}[-]\d{7}\b',  # Example tax number pattern
-       "BUILDINGNUM": r'\b\d+[A-Za-z]?\b',  # Basic building number pattern
    }
-
-   def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, show_available_entities: bool = False, **kwargs):
        patterns = {}
        if use_defaults:
            patterns = self.DEFAULT_PATTERNS.copy()
@@ -63,15 +69,15 @@ class RegexEntityRecognitionGuardrail(Guardrail):

        if show_available_entities:
            self._print_available_entities(patterns.keys())
-
        # Create the RegexModel instance
        regex_model = RegexModel(patterns=patterns)
-
        # Initialize the base class with both the regex_model and patterns
        super().__init__(
-           regex_model=regex_model,
            patterns=patterns,
-           should_anonymize=should_anonymize
        )

    def text_to_pattern(self, text: str) -> str:
@@ -82,7 +88,7 @@ class RegexEntityRecognitionGuardrail(Guardrail):
        escaped_text = re.escape(text)
        # Create a pattern that matches the exact text, case-insensitive
        return rf"\b{escaped_text}\b"
-
    def _print_available_entities(self, entities: List[str]):
        """Print available entities"""
        print("\nAvailable entity types:")
@@ -92,16 +98,23 @@ class RegexEntityRecognitionGuardrail(Guardrail):
        print("=" * 25 + "\n")

    @weave.op()
-   def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities based on the regex patterns.
-
        Args:
            prompt: Input text to check for entities
-           custom_terms: List of custom terms to be converted into regex patterns. If provided,
                only these terms will be checked, ignoring default patterns.
            return_detected_types: If True, returns detailed entity type information
-
        Returns:
            RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
        """
@@ -113,7 +126,7 @@ class RegexEntityRecognitionGuardrail(Guardrail):
        else:
            # Use the original regex_model if no custom terms provided
            result = self.regex_model.check(prompt)
-
        # Create detailed explanation
        explanation_parts = []
        if result.matched_patterns:
@@ -122,35 +135,50 @@ class RegexEntityRecognitionGuardrail(Guardrail):
            explanation_parts.append(f"- {entity_type}: {len(matches)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")
-
        if result.failed_patterns:
            explanation_parts.append("\nChecked but did not find these entity types:")
            for pattern in result.failed_patterns:
                explanation_parts.append(f"- {pattern}")
-
        # Updated anonymization logic
        anonymized_text = None
-       if getattr(self, 'should_anonymize', False) and result.matched_patterns:
            anonymized_text = prompt
            for entity_type, matches in result.matched_patterns.items():
                for match in matches:
-                   replacement = '[redacted]' if aggregate_redaction else f'[{entity_type.upper()}]'
                    anonymized_text = anonymized_text.replace(match, replacement)
-
        if return_detected_types:
            return RegexEntityRecognitionResponse(
                contains_entities=not result.passed,
                detected_entities=result.matched_patterns,
                explanation="\n".join(explanation_parts),
-               anonymized_text=anonymized_text
            )
        else:
            return RegexEntityRecognitionSimpleResponse(
                contains_entities=not result.passed,
                explanation="\n".join(explanation_parts),
-               anonymized_text=anonymized_text
            )

    @weave.op()
-   def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
-       return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)

+ import re
+ from typing import ClassVar, Dict, List, Optional

import weave
from pydantic import BaseModel

from ...regex_model import RegexModel
from ..base import Guardrail


class RegexEntityRecognitionResponse(BaseModel):
...
    regex_model: RegexModel
    patterns: Dict[str, str] = {}
    should_anonymize: bool = False
+
    DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
+       "EMAIL": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
+       "TELEPHONENUM": r"\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b",
+       "SOCIALNUM": r"\b\d{3}[-]?\d{2}[-]?\d{4}\b",
+       "CREDITCARDNUMBER": r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b",
+       "DATEOFBIRTH": r"\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b",
+       "DRIVERLICENSENUM": r"[A-Z]\d{7}",  # Example pattern, adjust for your needs
+       "ACCOUNTNUM": r"\b\d{10,12}\b",  # Example pattern for bank accounts
+       "ZIPCODE": r"\b\d{5}(?:-\d{4})?\b",
+       "GIVENNAME": r"\b[A-Z][a-z]+\b",  # Basic pattern for first names
+       "SURNAME": r"\b[A-Z][a-z]+\b",  # Basic pattern for last names
+       "CITY": r"\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b",
+       "STREET": r"\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b",
+       "IDCARDNUM": r"[A-Z]\d{7,8}",  # Generic pattern for ID cards
+       "USERNAME": r"@[A-Za-z]\w{3,}",  # Basic username pattern
+       "PASSWORD": r"[A-Za-z0-9@#$%^&+=]{8,}",  # Basic password pattern
+       "TAXNUM": r"\b\d{2}[-]\d{7}\b",  # Example tax number pattern
+       "BUILDINGNUM": r"\b\d+[A-Za-z]?\b",  # Basic building number pattern
    }
+
+   def __init__(
+       self,
+       use_defaults: bool = True,
+       should_anonymize: bool = False,
+       show_available_entities: bool = False,
+       **kwargs,
+   ):
        patterns = {}
        if use_defaults:
            patterns = self.DEFAULT_PATTERNS.copy()
...

        if show_available_entities:
            self._print_available_entities(patterns.keys())
+
        # Create the RegexModel instance
        regex_model = RegexModel(patterns=patterns)
+
        # Initialize the base class with both the regex_model and patterns
        super().__init__(
+           regex_model=regex_model,
            patterns=patterns,
+           should_anonymize=should_anonymize,
        )

    def text_to_pattern(self, text: str) -> str:
...
        escaped_text = re.escape(text)
        # Create a pattern that matches the exact text, case-insensitive
        return rf"\b{escaped_text}\b"
+
    def _print_available_entities(self, entities: List[str]):
        """Print available entities"""
        print("\nAvailable entity types:")
...
        print("=" * 25 + "\n")

    @weave.op()
+   def guard(
+       self,
+       prompt: str,
+       custom_terms: Optional[list[str]] = None,
+       return_detected_types: bool = True,
+       aggregate_redaction: bool = True,
+       **kwargs,
+   ) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities based on the regex patterns.
+
        Args:
            prompt: Input text to check for entities
+           custom_terms: List of custom terms to be converted into regex patterns. If provided,
                only these terms will be checked, ignoring default patterns.
            return_detected_types: If True, returns detailed entity type information
+
        Returns:
            RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
        """
...
        else:
            # Use the original regex_model if no custom terms provided
            result = self.regex_model.check(prompt)
+
        # Create detailed explanation
        explanation_parts = []
        if result.matched_patterns:
...
            explanation_parts.append(f"- {entity_type}: {len(matches)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")
+
        if result.failed_patterns:
            explanation_parts.append("\nChecked but did not find these entity types:")
            for pattern in result.failed_patterns:
                explanation_parts.append(f"- {pattern}")
+
        # Updated anonymization logic
        anonymized_text = None
+       if getattr(self, "should_anonymize", False) and result.matched_patterns:
            anonymized_text = prompt
            for entity_type, matches in result.matched_patterns.items():
                for match in matches:
+                   replacement = (
+                       "[redacted]"
+                       if aggregate_redaction
+                       else f"[{entity_type.upper()}]"
+                   )
                    anonymized_text = anonymized_text.replace(match, replacement)
+
        if return_detected_types:
            return RegexEntityRecognitionResponse(
                contains_entities=not result.passed,
                detected_entities=result.matched_patterns,
                explanation="\n".join(explanation_parts),
+               anonymized_text=anonymized_text,
            )
        else:
            return RegexEntityRecognitionSimpleResponse(
                contains_entities=not result.passed,
                explanation="\n".join(explanation_parts),
+               anonymized_text=anonymized_text,
            )

    @weave.op()
+   def predict(
+       self,
+       prompt: str,
+       return_detected_types: bool = True,
+       aggregate_redaction: bool = True,
+       **kwargs,
+   ) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
+       return self.guard(
+           prompt,
+           return_detected_types=return_detected_types,
+           aggregate_redaction=aggregate_redaction,
+           **kwargs,
+       )
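
The aggregate_redaction flag above selects between a uniform "[redacted]" marker and per-entity-type markers. A sketch; note that the broad GIVENNAME/SURNAME default patterns also match ordinary capitalized words, so the output will include extra redactions:

from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)

guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
text = "Reach Jane at jane.doe@example.com or 555-123-4567."
# Uniform marker: every match becomes [redacted].
print(guardrail.guard(text, aggregate_redaction=True).anonymized_text)
# Typed markers: matches become [EMAIL], [TELEPHONENUM], [GIVENNAME], etc.
print(guardrail.guard(text, aggregate_redaction=False).anonymized_text)
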
guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py
CHANGED
@@ -1,9 +1,11 @@
- from typing import Dict, List, Optional
- from transformers import pipeline, AutoConfig
- import weave
from pydantic import BaseModel
from ..base import Guardrail
-

class TransformersEntityRecognitionResponse(BaseModel):
    contains_entities: bool
@@ -15,6 +17,7 @@ class TransformersEntityRecognitionResponse(BaseModel):
    def safe(self) -> bool:
        return not self.contains_entities

class TransformersEntityRecognitionSimpleResponse(BaseModel):
    contains_entities: bool
    explanation: str
@@ -24,14 +27,15 @@ class TransformersEntityRecognitionSimpleResponse(BaseModel):
    def safe(self) -> bool:
        return not self.contains_entities

class TransformersEntityRecognitionGuardrail(Guardrail):
    """Generic guardrail for detecting entities using any token classification model."""
-
    _pipeline: Optional[object] = None
    selected_entities: List[str]
    should_anonymize: bool
    available_entities: List[str]
-
    def __init__(
        self,
        model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
@@ -42,50 +46,52 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
        # Load model config and extract available entities
        config = AutoConfig.from_pretrained(model_name)
        entities = self._extract_entities_from_config(config)
-
        if show_available_entities:
            self._print_available_entities(entities)
-
        # Initialize default values if needed
        if selected_entities is None:
            selected_entities = entities  # Use all available entities by default
-
        # Filter out invalid entities and warn user
        invalid_entities = [e for e in selected_entities if e not in entities]
        valid_entities = [e for e in selected_entities if e in entities]
-
        if invalid_entities:
-           print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities
-
        # Call parent class constructor
        super().__init__(
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
-           available_entities=entities
        )
-
        # Initialize pipeline
        self._pipeline = pipeline(
            task="token-classification",
            model=model_name,
-           aggregation_strategy="simple"  # Merge same entities
        )

    def _extract_entities_from_config(self, config) -> List[str]:
        """Extract unique entity types from the model config."""
        # Get id2label mapping from config
        id2label = config.id2label
-
        # Extract unique entity types (removing B- and I- prefixes)
        entities = set()
        for label in id2label.values():
-           if label.startswith(('B-', 'I-')):
                entities.add(label[2:])  # Remove prefix
-           elif label != 'O':  # Skip the 'O' (Outside) label
                entities.add(label)
-
        return sorted(list(entities))

    def _print_available_entities(self, entities: List[str]):
@@ -103,48 +109,60 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
    def _detect_entities(self, text: str) -> Dict[str, List[str]]:
        """Detect entities in the text using the pipeline."""
        results = self._pipeline(text)
-
        # Group findings by entity type
        detected_entities = {}
        for entity in results:
-           entity_type = entity['entity_group']
            if entity_type in self.selected_entities:
                if entity_type not in detected_entities:
                    detected_entities[entity_type] = []
-               detected_entities[entity_type].append(entity['word'])
-
        return detected_entities

    def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
        """Anonymize detected entities in text using the pipeline."""
        results = self._pipeline(text)
-
        # Sort entities by start position in reverse order to avoid offset issues
-       entities = sorted(results, key=lambda x: x['start'], reverse=True)
-
        # Create a mutable list of characters
        chars = list(text)
-
        # Apply redactions
        for entity in entities:
-           if entity['entity_group'] in self.selected_entities:
-               start, end = entity['start'], entity['end']
-               replacement = ' [redacted] ' if aggregate_redaction else f" [{entity['entity_group']}] "
-
                # Replace the entity with the redaction marker
                chars[start:end] = replacement
-
        # Join characters and clean up only consecutive spaces (preserving newlines)
-       result = ''.join(chars)
        # Replace multiple spaces with single space, but preserve newlines
-       lines = result.split('\n')
-       cleaned_lines = [' '.join(line.split()) for line in lines]
-       return '\n'.join(cleaned_lines)

    @weave.op()
-   def guard(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
        """Check if the input prompt contains any entities using the transformer pipeline.
-
        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
@@ -152,39 +170,55 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
        """
        # Detect entities
        detected_entities = self._detect_entities(prompt)
-
        # Create explanation
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
-               explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")
-
        explanation_parts.append("\nChecked for these entities:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")
-
        # Anonymize if requested
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_text = self._anonymize_text(prompt, aggregate_redaction)
-
        if return_detected_types:
            return TransformersEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation="\n".join(explanation_parts),
-               anonymized_text=anonymized_text
            )
        else:
            return TransformersEntityRecognitionSimpleResponse(
                contains_entities=bool(detected_entities),
                explanation="\n".join(explanation_parts),
-               anonymized_text=anonymized_text
            )

    @weave.op()
-   def predict(

+ from typing import Dict, List, Optional
+
+ import weave
from pydantic import BaseModel
+ from transformers import AutoConfig, pipeline
+
from ..base import Guardrail
+

class TransformersEntityRecognitionResponse(BaseModel):
    contains_entities: bool
...
    def safe(self) -> bool:
        return not self.contains_entities

+
class TransformersEntityRecognitionSimpleResponse(BaseModel):
    contains_entities: bool
    explanation: str
...
    def safe(self) -> bool:
        return not self.contains_entities

+
class TransformersEntityRecognitionGuardrail(Guardrail):
    """Generic guardrail for detecting entities using any token classification model."""
+
    _pipeline: Optional[object] = None
    selected_entities: List[str]
    should_anonymize: bool
    available_entities: List[str]
+
    def __init__(
        self,
        model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
...
        # Load model config and extract available entities
        config = AutoConfig.from_pretrained(model_name)
        entities = self._extract_entities_from_config(config)
+
        if show_available_entities:
            self._print_available_entities(entities)
+
        # Initialize default values if needed
        if selected_entities is None:
            selected_entities = entities  # Use all available entities by default
+
        # Filter out invalid entities and warn user
        invalid_entities = [e for e in selected_entities if e not in entities]
        valid_entities = [e for e in selected_entities if e in entities]
+
        if invalid_entities:
+           print(
+               f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
+           )
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities
+
        # Call parent class constructor
        super().__init__(
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
+           available_entities=entities,
        )
+
        # Initialize pipeline
        self._pipeline = pipeline(
            task="token-classification",
            model=model_name,
+           aggregation_strategy="simple",  # Merge same entities
        )

    def _extract_entities_from_config(self, config) -> List[str]:
        """Extract unique entity types from the model config."""
        # Get id2label mapping from config
        id2label = config.id2label
+
        # Extract unique entity types (removing B- and I- prefixes)
        entities = set()
        for label in id2label.values():
+           if label.startswith(("B-", "I-")):
                entities.add(label[2:])  # Remove prefix
+           elif label != "O":  # Skip the 'O' (Outside) label
                entities.add(label)
+
        return sorted(list(entities))

    def _print_available_entities(self, entities: List[str]):
...
    def _detect_entities(self, text: str) -> Dict[str, List[str]]:
        """Detect entities in the text using the pipeline."""
        results = self._pipeline(text)
+
        # Group findings by entity type
        detected_entities = {}
        for entity in results:
+           entity_type = entity["entity_group"]
            if entity_type in self.selected_entities:
                if entity_type not in detected_entities:
                    detected_entities[entity_type] = []
+               detected_entities[entity_type].append(entity["word"])
+
        return detected_entities

    def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
        """Anonymize detected entities in text using the pipeline."""
        results = self._pipeline(text)
+
        # Sort entities by start position in reverse order to avoid offset issues
+       entities = sorted(results, key=lambda x: x["start"], reverse=True)
+
        # Create a mutable list of characters
        chars = list(text)
+
        # Apply redactions
        for entity in entities:
+           if entity["entity_group"] in self.selected_entities:
+               start, end = entity["start"], entity["end"]
+               replacement = (
+                   " [redacted] "
+                   if aggregate_redaction
+                   else f" [{entity['entity_group']}] "
+               )
+
                # Replace the entity with the redaction marker
                chars[start:end] = replacement
+
        # Join characters and clean up only consecutive spaces (preserving newlines)
+       result = "".join(chars)
        # Replace multiple spaces with single space, but preserve newlines
+       lines = result.split("\n")
+       cleaned_lines = [" ".join(line.split()) for line in lines]
+       return "\n".join(cleaned_lines)

    @weave.op()
+   def guard(
+       self,
+       prompt: str,
+       return_detected_types: bool = True,
+       aggregate_redaction: bool = True,
+   ) -> (
+       TransformersEntityRecognitionResponse
+       | TransformersEntityRecognitionSimpleResponse
+   ):
        """Check if the input prompt contains any entities using the transformer pipeline.
+
        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
170 |
"""
|
171 |
# Detect entities
|
172 |
detected_entities = self._detect_entities(prompt)
|
173 |
+
|
174 |
# Create explanation
|
175 |
explanation_parts = []
|
176 |
if detected_entities:
|
177 |
explanation_parts.append("Found the following entities in the text:")
|
178 |
for entity_type, instances in detected_entities.items():
|
179 |
+
explanation_parts.append(
|
180 |
+
f"- {entity_type}: {len(instances)} instance(s)"
|
181 |
+
)
|
182 |
else:
|
183 |
explanation_parts.append("No entities detected in the text.")
|
184 |
+
|
185 |
explanation_parts.append("\nChecked for these entities:")
|
186 |
for entity in self.selected_entities:
|
187 |
explanation_parts.append(f"- {entity}")
|
188 |
+
|
189 |
# Anonymize if requested
|
190 |
anonymized_text = None
|
191 |
if self.should_anonymize and detected_entities:
|
192 |
anonymized_text = self._anonymize_text(prompt, aggregate_redaction)
|
193 |
+
|
194 |
if return_detected_types:
|
195 |
return TransformersEntityRecognitionResponse(
|
196 |
contains_entities=bool(detected_entities),
|
197 |
detected_entities=detected_entities,
|
198 |
explanation="\n".join(explanation_parts),
|
199 |
+
anonymized_text=anonymized_text,
|
200 |
)
|
201 |
else:
|
202 |
return TransformersEntityRecognitionSimpleResponse(
|
203 |
contains_entities=bool(detected_entities),
|
204 |
explanation="\n".join(explanation_parts),
|
205 |
+
anonymized_text=anonymized_text,
|
206 |
)
|
207 |
|
208 |
@weave.op()
|
209 |
+
def predict(
|
210 |
+
self,
|
211 |
+
prompt: str,
|
212 |
+
return_detected_types: bool = True,
|
213 |
+
aggregate_redaction: bool = True,
|
214 |
+
**kwargs,
|
215 |
+
) -> (
|
216 |
+
TransformersEntityRecognitionResponse
|
217 |
+
| TransformersEntityRecognitionSimpleResponse
|
218 |
+
):
|
219 |
+
return self.guard(
|
220 |
+
prompt,
|
221 |
+
return_detected_types=return_detected_types,
|
222 |
+
aggregate_redaction=aggregate_redaction,
|
223 |
+
**kwargs,
|
224 |
+
)
|
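For reference, a minimal usage sketch of the guardrail reformatted above; the prompt is illustrative, and the entity label shown in the comment depends on the piiranha model's actual label set:

    from guardrails_genie.guardrails.entity_recognition import (
        TransformersEntityRecognitionGuardrail,
    )

    # Uses the default piiranha model; weights are downloaded on first use.
    guardrail = TransformersEntityRecognitionGuardrail(should_anonymize=True)

    response = guardrail.guard("Contact me at john.doe@example.com")
    print(response.contains_entities)  # True when any selected entity is found
    print(response.detected_entities)  # e.g. {"EMAIL": [...]} (label names are model-dependent)
    print(response.anonymized_text)    # input with matches replaced by " [redacted] "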
guardrails_genie/guardrails/manager.py
CHANGED
@@ -1,6 +1,6 @@
  1   import weave
  2 - from rich.progress import track
  3   from pydantic import BaseModel
  4
  5   from .base import Guardrail
  6

  1   import weave
  2   from pydantic import BaseModel
  3 + from rich.progress import track
  4
  5   from .base import Guardrail
  6
guardrails_genie/metrics.py
CHANGED
@@ -5,12 +5,55 @@ import weave
  5
  6
  7   class AccuracyMetric(weave.Scorer):
  8       @weave.op()
  9       def score(self, output: dict, label: int):
 10 -         return {"correct": label == int(output["safe"])}
 11
 12       @weave.op()
 13       def summarize(self, score_rows: list) -> Optional[dict]:
 14           valid_data = [
 15               x.get("correct") for x in score_rows if x.get("correct") is not None
 16           ]

  5
  6
  7   class AccuracyMetric(weave.Scorer):
  8 +     """
  9 +     A class to compute and summarize accuracy-related metrics for model outputs.
 10 +
 11 +     This class extends the `weave.Scorer` and provides operations to score
 12 +     individual predictions and summarize the results across multiple predictions.
 13 +     It calculates the accuracy, precision, recall, and F1 score based on the
 14 +     comparison between predicted outputs and true labels.
 15 +     """
 16 +
 17       @weave.op()
 18       def score(self, output: dict, label: int):
 19 +         """
 20 +         Evaluate the correctness of a single prediction.
 21 +
 22 +         This method compares a model's predicted output with the true label
 23 +         to determine if the prediction is correct. It checks if the 'safe'
 24 +         field in the output dictionary, when converted to an integer, matches
 25 +         the provided label.
 26 +
 27 +         Args:
 28 +             output (dict): A dictionary containing the model's prediction,
 29 +                 specifically the 'safe' key which holds the predicted value.
 30 +             label (int): The true label against which the prediction is compared.
 31 +
 32 +         Returns:
 33 +             dict: A dictionary with a single key 'correct', which is True if the
 34 +                 prediction matches the label, otherwise False.
 35 +         """
 36 +         return {"correct": label == int(output["safe"])}
 37
 38       @weave.op()
 39       def summarize(self, score_rows: list) -> Optional[dict]:
 40 +         """
 41 +         Summarize the accuracy-related metrics from a list of prediction scores.
 42 +
 43 +         This method processes a list of score dictionaries, each containing a
 44 +         'correct' key indicating whether a prediction was correct. It calculates
 45 +         several metrics: accuracy, precision, recall, and F1 score, based on the
 46 +         number of true positives, false positives, and false negatives.
 47 +
 48 +         Args:
 49 +             score_rows (list): A list of dictionaries, each with a 'correct' key
 50 +                 indicating the correctness of individual predictions.
 51 +
 52 +         Returns:
 53 +             Optional[dict]: A dictionary containing the calculated metrics:
 54 +                 'accuracy', 'precision', 'recall', and 'f1_score'. If no valid data
 55 +                 is present, all metrics default to 0.
 56 +         """
 57           valid_data = [
 58               x.get("correct") for x in score_rows if x.get("correct") is not None
 59           ]
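Since the commit's headline change is the AccuracyMetric documentation, a small sketch of the documented behavior may help; the `score` calls below follow directly from the code shown, while the `weave.Evaluation` wiring (left commented out) is an assumed, typical usage rather than something taken from this diff:

    from guardrails_genie.metrics import AccuracyMetric

    metric = AccuracyMetric()
    # int(True) == 1, so a guardrail that reports safe=True matches label 1.
    print(metric.score(output={"safe": True}, label=1))   # {'correct': True}
    print(metric.score(output={"safe": False}, label=1))  # {'correct': False}

    # Assumed wiring: a Weave evaluation calls score() once per dataset row and
    # then summarize() over the collected rows to report accuracy/precision/recall/F1.
    # evaluation = weave.Evaluation(dataset=rows, scorers=[metric])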
guardrails_genie/regex_model.py
CHANGED
@@ -1,5 +1,6 @@
  1 - from typing import List, Dict, Optional
  2   import re
  3   import weave
  4   from pydantic import BaseModel
  5
@@ -16,7 +17,7 @@ class RegexModel(weave.Model):
 16       def __init__(self, patterns: Dict[str, str]) -> None:
 17           """
 18           Initialize RegexModel with a dictionary of patterns.
 19 -
 20           Args:
 21               patterns: Dictionary where key is pattern name and value is regex pattern
 22                   Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
@@ -31,35 +32,37 @@ class RegexModel(weave.Model):
 31       def check(self, prompt: str) -> RegexResult:
 32           """
 33           Check text against all patterns and return detailed results.
 34 -
 35           Args:
 36               text: Input text to check against patterns
 37 -
 38           Returns:
 39               RegexResult containing pass/fail status and details about matches
 40           """
 41           matched_patterns = {}
 42           failed_patterns = []
 43 -
 44           for pattern_name, pattern in self.patterns.items():
 45               matches = []
 46               for match in re.finditer(pattern, prompt):
 47                   if match.groups():
 48                       # If there are capture groups, join them with a separator
 49 -                     matches.append("-".join(str(g) for g in match.groups() if g is not None))
 50                   else:
 51                       # If no capture groups, use the full match
 52                       matches.append(match.group(0))
 53 -
 54               if matches:
 55                   matched_patterns[pattern_name] = matches
 56               else:
 57                   failed_patterns.append(pattern_name)
 58 -
 59           return RegexResult(
 60               matched_patterns=matched_patterns,
 61               failed_patterns=failed_patterns,
 62 -             passed=len(matched_patterns) == 0
 63           )
 64
 65       @weave.op()
@@ -67,4 +70,4 @@ class RegexModel(weave.Model):
 67           """
 68           Alias for check() to maintain consistency with other models.
 69           """
 70 -         return self.check(text)

  1   import re
  2 + from typing import Dict, List
  3 +
  4   import weave
  5   from pydantic import BaseModel
  6

 17       def __init__(self, patterns: Dict[str, str]) -> None:
 18           """
 19           Initialize RegexModel with a dictionary of patterns.
 20 +
 21           Args:
 22               patterns: Dictionary where key is pattern name and value is regex pattern
 23                   Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",

 32       def check(self, prompt: str) -> RegexResult:
 33           """
 34           Check text against all patterns and return detailed results.
 35 +
 36           Args:
 37               text: Input text to check against patterns
 38 +
 39           Returns:
 40               RegexResult containing pass/fail status and details about matches
 41           """
 42           matched_patterns = {}
 43           failed_patterns = []
 44 +
 45           for pattern_name, pattern in self.patterns.items():
 46               matches = []
 47               for match in re.finditer(pattern, prompt):
 48                   if match.groups():
 49                       # If there are capture groups, join them with a separator
 50 +                     matches.append(
 51 +                         "-".join(str(g) for g in match.groups() if g is not None)
 52 +                     )
 53                   else:
 54                       # If no capture groups, use the full match
 55                       matches.append(match.group(0))
 56 +
 57               if matches:
 58                   matched_patterns[pattern_name] = matches
 59               else:
 60                   failed_patterns.append(pattern_name)
 61 +
 62           return RegexResult(
 63               matched_patterns=matched_patterns,
 64               failed_patterns=failed_patterns,
 65 +             passed=len(matched_patterns) == 0,
 66           )
 67
 68       @weave.op()

 70           """
 71           Alias for check() to maintain consistency with other models.
 72           """
 73 +         return self.check(text)
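The reformatted `check` method above is easiest to see end to end with a quick sketch; the email pattern is the one from the class's own docstring example, and the input string is illustrative:

    from guardrails_genie.regex_model import RegexModel

    model = RegexModel(
        patterns={"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+"}
    )
    result = model.check("Reach me at jane@example.com")
    print(result.passed)            # False: passed is True only when nothing matches
    print(result.matched_patterns)  # {'email': ['jane@example.com']}
    print(result.failed_patterns)   # []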
guardrails_genie/utils.py
CHANGED
@@ -19,9 +19,9 @@ class EvaluationCallManager:
 19       """
 20       Manages the evaluation calls for a specific project and entity in Weave.
 21
 22 -     This class is responsible for initializing and managing evaluation calls associated with a
 23 -     specific project and entity. It provides functionality to collect guardrail guard calls
 24 -     from evaluation predictions and scores, and render these calls into a structured format
 25       suitable for display in Streamlit.
 26
 27       Args:
@@ -30,6 +30,7 @@ class EvaluationCallManager:
 30           call_id (str): The call id.
 31           max_count (int): The maximum number of guardrail guard calls to collect from the evaluation.
 32       """
 33       def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
 34           self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
 35           self.max_count = max_count
@@ -40,10 +41,10 @@ class EvaluationCallManager:
 40       """
 41       Collects guardrail guard calls from evaluation predictions and scores.
 42
 43 -     This function iterates through the children calls of the base evaluation call,
 44 -     extracting relevant guardrail guard calls and their associated scores. It stops
 45 -     collecting calls if it encounters an "Evaluation.summarize" operation or if the
 46 -     maximum count of guardrail guard calls is reached. The collected calls are stored
 47       in a list of dictionaries, each containing the input prompt, outputs, and score.
 48
 49       Returns:
@@ -77,9 +78,9 @@ class EvaluationCallManager:
 77       Renders the collected guardrail guard calls into a pandas DataFrame suitable for
 78       display in Streamlit.
 79
 80 -     This function processes the collected guardrail guard calls stored in `self.call_list` and
 81 -     organizes them into a dictionary format that can be easily converted into a pandas DataFrame.
 82 -     The DataFrame contains columns for the input prompts, the safety status of the outputs, and
 83       the correctness of the predictions for each guardrail.
 84
 85       The structure of the DataFrame is as follows:
@@ -87,7 +88,7 @@ class EvaluationCallManager:
 87       - Subsequent columns contain the safety status and prediction correctness for each guardrail.
 88
 89       Returns:
 90 -         pd.DataFrame: A DataFrame containing the input prompts, safety status, and prediction
 91           correctness for each guardrail.
 92       """
 93       dataframe = {

 19       """
 20       Manages the evaluation calls for a specific project and entity in Weave.
 21
 22 +     This class is responsible for initializing and managing evaluation calls associated with a
 23 +     specific project and entity. It provides functionality to collect guardrail guard calls
 24 +     from evaluation predictions and scores, and render these calls into a structured format
 25       suitable for display in Streamlit.
 26
 27       Args:

 30           call_id (str): The call id.
 31           max_count (int): The maximum number of guardrail guard calls to collect from the evaluation.
 32       """
 33 +
 34       def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
 35           self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
 36           self.max_count = max_count

 41       """
 42       Collects guardrail guard calls from evaluation predictions and scores.
 43
 44 +     This function iterates through the children calls of the base evaluation call,
 45 +     extracting relevant guardrail guard calls and their associated scores. It stops
 46 +     collecting calls if it encounters an "Evaluation.summarize" operation or if the
 47 +     maximum count of guardrail guard calls is reached. The collected calls are stored
 48       in a list of dictionaries, each containing the input prompt, outputs, and score.
 49
 50       Returns:

 78       Renders the collected guardrail guard calls into a pandas DataFrame suitable for
 79       display in Streamlit.
 80
 81 +     This function processes the collected guardrail guard calls stored in `self.call_list` and
 82 +     organizes them into a dictionary format that can be easily converted into a pandas DataFrame.
 83 +     The DataFrame contains columns for the input prompts, the safety status of the outputs, and
 84       the correctness of the predictions for each guardrail.
 85
 86       The structure of the DataFrame is as follows:

 88       - Subsequent columns contain the safety status and prediction correctness for each guardrail.
 89
 90       Returns:
 91 +         pd.DataFrame: A DataFrame containing the input prompts, safety status, and prediction
 92           correctness for each guardrail.
 93       """
 94       dataframe = {
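A rough usage sketch for the docstrings above; the constructor signature is taken from the diff, but the two method calls are hypothetical names paraphrased from their docstrings and may not match the actual API, so they are left commented out:

    from guardrails_genie.utils import EvaluationCallManager

    # entity, project, and call_id below are placeholders for a real Weave evaluation call.
    manager = EvaluationCallManager(
        entity="my-entity", project="guardrails-genie", call_id="<call-id>", max_count=10
    )
    # Hypothetical method names (inferred from the docstrings, not shown in this diff):
    # manager.collect_guardrail_calls()  # walks children of the base evaluation call
    # df = manager.render_calls()        # builds the pandas DataFrame for Streamlit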