geekyrakshit committed
Commit 78a1bf0 · 1 Parent(s): 36c3c0f

add: docs for AccuracyMetric

guardrails_genie/guardrails/__init__.py CHANGED
@@ -1,12 +1,12 @@
-from .injection import (
-    PromptInjectionClassifierGuardrail,
-    PromptInjectionSurveyGuardrail,
-)
 from .entity_recognition import (
     PresidioEntityRecognitionGuardrail,
     RegexEntityRecognitionGuardrail,
-    TransformersEntityRecognitionGuardrail,
     RestrictedTermsJudge,
+    TransformersEntityRecognitionGuardrail,
+)
+from .injection import (
+    PromptInjectionClassifierGuardrail,
+    PromptInjectionSurveyGuardrail,
 )
 from .manager import GuardrailManager

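
The reordering above looks like isort's alphabetical import grouping; the package's public surface is unchanged. As a quick sanity sketch (not part of the commit), every re-exported name should still resolve from the package root:

# Import smoke test (illustrative only): all names re-exported by the
# __init__.py above should resolve from the package root after the reorder.
from guardrails_genie.guardrails import (
    GuardrailManager,
    PresidioEntityRecognitionGuardrail,
    PromptInjectionClassifierGuardrail,
    PromptInjectionSurveyGuardrail,
    RegexEntityRecognitionGuardrail,
    RestrictedTermsJudge,
    TransformersEntityRecognitionGuardrail,
)

print(GuardrailManager, RestrictedTermsJudge)  # both names resolve
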
guardrails_genie/guardrails/entity_recognition/__init__.py CHANGED
@@ -1,10 +1,13 @@
-from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
-from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
-from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
 from .llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
+from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
+from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+from .transformers_entity_recognition_guardrail import (
+    TransformersEntityRecognitionGuardrail,
+)
+
 __all__ = [
     "PresidioEntityRecognitionGuardrail",
     "RegexEntityRecognitionGuardrail",
     "TransformersEntityRecognitionGuardrail",
-    "RestrictedTermsJudge"
+    "RestrictedTermsJudge",
 ]
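
The trailing comma added to `__all__` is purely stylistic; `__all__` still determines what a star import exposes. A minimal sketch of that behavior, assuming the package is installed:

# Star imports only pick up the names listed in __all__ above.
from guardrails_genie.guardrails.entity_recognition import *

print(PresidioEntityRecognitionGuardrail)
print(RegexEntityRecognitionGuardrail)
print(TransformersEntityRecognitionGuardrail)
print(RestrictedTermsJudge)
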
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py CHANGED
@@ -11,15 +11,22 @@ I think we should implement features similar to Salesforce's Einstein AI
 and Oracle's Cloud Infrastructure. Maybe we could also look at how
 AWS handles their lambda functions.
 """,
-        "custom_terms": ["Salesforce", "Oracle", "AWS", "Einstein AI", "Cloud Infrastructure", "lambda"],
+        "custom_terms": [
+            "Salesforce",
+            "Oracle",
+            "AWS",
+            "Einstein AI",
+            "Cloud Infrastructure",
+            "lambda",
+        ],
         "expected_entities": {
             "Salesforce": ["Salesforce"],
             "Oracle": ["Oracle"],
             "AWS": ["AWS"],
             "Einstein AI": ["Einstein AI"],
             "Cloud Infrastructure": ["Cloud Infrastructure"],
-            "lambda": ["lambda"]
-        }
+            "lambda": ["lambda"],
+        },
     },
     {
         "description": "Inappropriate Language in Support Ticket",
@@ -32,8 +39,8 @@ stupid service? I've wasted so much freaking time on this crap.
             "damn": ["damn"],
             "hell": ["hell"],
             "stupid": ["stupid"],
-            "crap": ["crap"]
-        }
+            "crap": ["crap"],
+        },
     },
     {
         "description": "Confidential Project Names",
@@ -45,9 +52,9 @@ with Project Phoenix team and the Blue Dragon initiative for resource allocation
         "expected_entities": {
             "Project Titan": ["Project Titan"],
             "Project Phoenix": ["Project Phoenix"],
-            "Blue Dragon": ["Blue Dragon"]
-        }
-    }
+            "Blue Dragon": ["Blue Dragon"],
+        },
+    },
 ]

 # Edge cases and special formats
@@ -59,15 +66,22 @@ MSFT's Azure and O365 platform is gaining market share.
 Have you seen what GOOGL/GOOG and FB/META are doing with their AI?
 CRM (Salesforce) and ORCL (Oracle) have interesting features too.
 """,
-        "custom_terms": ["Microsoft", "Google", "Meta", "Facebook", "Salesforce", "Oracle"],
+        "custom_terms": [
+            "Microsoft",
+            "Google",
+            "Meta",
+            "Facebook",
+            "Salesforce",
+            "Oracle",
+        ],
         "expected_entities": {
             "Microsoft": ["MSFT"],
             "Google": ["GOOGL", "GOOG"],
             "Meta": ["META"],
             "Facebook": ["FB"],
             "Salesforce": ["CRM", "Salesforce"],
-            "Oracle": ["ORCL"]
-        }
+            "Oracle": ["ORCL"],
+        },
     },
     {
         "description": "L33t Speak and Intentional Obfuscation",
@@ -76,15 +90,22 @@ S4l3sf0rc3 is better than 0r4cl3!
 M1cr0$oft and G00gl3 are the main competitors.
 Let's check F8book and Met@ too.
 """,
-        "custom_terms": ["Salesforce", "Oracle", "Microsoft", "Google", "Facebook", "Meta"],
+        "custom_terms": [
+            "Salesforce",
+            "Oracle",
+            "Microsoft",
+            "Google",
+            "Facebook",
+            "Meta",
+        ],
         "expected_entities": {
             "Salesforce": ["S4l3sf0rc3"],
             "Oracle": ["0r4cl3"],
             "Microsoft": ["M1cr0$oft"],
             "Google": ["G00gl3"],
             "Facebook": ["F8book"],
-            "Meta": ["Met@"]
-        }
+            "Meta": ["Met@"],
+        },
     },
     {
         "description": "Case Variations and Partial Matches",
@@ -98,8 +119,8 @@ Have you tried micro-soft or Google_Cloud?
             "Microsoft": ["MicroSoft", "micro-soft"],
             "Google": ["google", "Google_Cloud"],
             "Salesforce": ["salesFORCE"],
-            "Oracle": ["ORACLE"]
-        }
+            "Oracle": ["ORACLE"],
+        },
     },
     {
         "description": "Common Misspellings and Typos",
@@ -113,8 +134,8 @@ Salezforce and Oracel need checking too.
             "Microsoft": ["Microsft", "Microsooft"],
             "Google": ["Goggle", "Googel", "Gooogle"],
             "Salesforce": ["Salezforce"],
-            "Oracle": ["Oracel"]
-        }
+            "Oracle": ["Oracel"],
+        },
     },
     {
         "description": "Mixed Variations and Context",
@@ -123,7 +144,15 @@ The M$ cloud competes with AWS (Amazon Web Services).
 FB/Meta's social platform and GOOGL's search dominate.
 SF.com and Oracle-DB are industry standards.
 """,
-        "custom_terms": ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"],
+        "custom_terms": [
+            "Microsoft",
+            "Amazon Web Services",
+            "Facebook",
+            "Meta",
+            "Google",
+            "Salesforce",
+            "Oracle",
+        ],
         "expected_entities": {
             "Microsoft": ["M$"],
             "Amazon Web Services": ["AWS"],
@@ -131,37 +160,40 @@ SF.com and Oracle-DB are industry standards.
             "Meta": ["Meta"],
             "Google": ["GOOGL"],
             "Salesforce": ["SF.com"],
-            "Oracle": ["Oracle-DB"]
-        }
-    }
+            "Oracle": ["Oracle-DB"],
+        },
+    },
 ]

+
 def validate_entities(detected: dict, expected: dict) -> bool:
     """Compare detected entities with expected entities"""
     if set(detected.keys()) != set(expected.keys()):
         return False
     return all(set(detected[k]) == set(expected[k]) for k in expected.keys())

+
 def run_test_case(guardrail, test_case, test_type="Main"):
     """Run a single test case and print results"""
     print(f"\n{test_type} Test Case: {test_case['description']}")
     print("-" * 50)
-
+
     result = guardrail.guard(
-        test_case['input_text'],
-        custom_terms=test_case['custom_terms']
+        test_case["input_text"], custom_terms=test_case["custom_terms"]
     )
-    expected = test_case['expected_entities']
-
+    expected = test_case["expected_entities"]
+
     # Validate results
     matches = validate_entities(result.detected_entities, expected)
-
+
     print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
     print(f"Contains Restricted Terms: {result.contains_entities}")
-
+
     if not matches:
         print("\nEntity Comparison:")
-        all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
+        all_entity_types = set(
+            list(result.detected_entities.keys()) + list(expected.keys())
+        )
         for entity_type in all_entity_types:
             detected = set(result.detected_entities.get(entity_type, []))
             expected_set = set(expected.get(entity_type, []))
@@ -171,8 +203,8 @@ def run_test_case(guardrail, test_case, test_type="Main"):
             if detected != expected_set:
                 print(f"  Missing: {sorted(expected_set - detected)}")
                 print(f"  Extra: {sorted(detected - expected_set)}")
-
+
     if result.anonymized_text:
         print(f"\nAnonymized Text:\n{result.anonymized_text}")
-
+
     return matches
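
Note that validate_entities applies strict matching: the detected and expected entity-type key sets must be identical, and each type's surface forms are compared as sets, so order and duplicates are ignored while any extra or missing mention fails the case. A standalone sketch (the function is copied from above; the example data is made up):

def validate_entities(detected: dict, expected: dict) -> bool:
    """Compare detected entities with expected entities"""
    if set(detected.keys()) != set(expected.keys()):
        return False
    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())


expected = {"Google": ["GOOGL", "GOOG"], "Meta": ["META"]}
print(validate_entities({"Google": ["GOOG", "GOOGL"], "Meta": ["META"]}, expected))  # True: order ignored
print(validate_entities({"Google": ["GOOGL"], "Meta": ["META"]}, expected))  # False: "GOOG" missed
print(validate_entities({"Google": ["GOOGL", "GOOG"]}, expected))  # False: key sets differ
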
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py CHANGED
@@ -1,21 +1,22 @@
-from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
+import weave
+
 from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
-    RESTRICTED_TERMS_EXAMPLES,
-    EDGE_CASE_EXAMPLES,
-    run_test_case
+    EDGE_CASE_EXAMPLES,
+    RESTRICTED_TERMS_EXAMPLES,
+    run_test_case,
+)
+from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import (
+    RestrictedTermsJudge,
 )
 from guardrails_genie.llm import OpenAIModel
-import weave
+

 def test_restricted_terms_detection():
     """Test restricted terms detection scenarios using predefined test cases"""
     weave.init("guardrails-genie-restricted-terms-llm-judge")
-
+
     # Create the guardrail with OpenAI model
-    llm_judge = RestrictedTermsJudge(
-        should_anonymize=True,
-        llm_model=OpenAIModel()
-    )
+    llm_judge = RestrictedTermsJudge(should_anonymize=True, llm_model=OpenAIModel())

     # Test statistics
     total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
@@ -43,5 +44,6 @@ def test_restricted_terms_detection():
     print(f"Failed: {total_tests - passed_tests}")
     print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

+
 if __name__ == "__main__":
     test_restricted_terms_detection()
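
For a single ad-hoc input rather than the whole suite, the judge can also be called directly. A hedged sketch, not part of this commit, assuming OpenAI credentials are configured for OpenAIModel (e.g. via OPENAI_API_KEY):

import weave

from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import (
    RestrictedTermsJudge,
)
from guardrails_genie.llm import OpenAIModel

weave.init("guardrails-genie-restricted-terms-llm-judge")
judge = RestrictedTermsJudge(should_anonymize=True, llm_model=OpenAIModel())

# guard() takes the text plus a custom_terms list, as in run_test_case above.
response = judge.guard(
    "We should benchmark against Salesforce and Oracle.",
    custom_terms=["Salesforce", "Oracle"],
)
print(response.contains_entities)
print(response.detected_entities)
print(response.anonymized_text)
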
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py CHANGED
@@ -1,19 +1,22 @@
-from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+import weave
+
 from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
-    RESTRICTED_TERMS_EXAMPLES,
-    EDGE_CASE_EXAMPLES,
-    run_test_case
+    EDGE_CASE_EXAMPLES,
+    RESTRICTED_TERMS_EXAMPLES,
+    run_test_case,
+)
+from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
+    RegexEntityRecognitionGuardrail,
 )
-import weave
+

 def test_restricted_terms_detection():
     """Test restricted terms detection scenarios using predefined test cases"""
     weave.init("guardrails-genie-restricted-terms-regex-model")
-
+
     # Create the guardrail with anonymization enabled
     regex_guardrail = RegexEntityRecognitionGuardrail(
-        use_defaults=False,  # Don't use default PII patterns
-        should_anonymize=True
+        use_defaults=False, should_anonymize=True  # Don't use default PII patterns
     )

     # Test statistics
@@ -42,5 +45,6 @@ def test_restricted_terms_detection():
     print(f"Failed: {total_tests - passed_tests}")
     print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

+
 if __name__ == "__main__":
     test_restricted_terms_detection()
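
The same ad-hoc pattern works for the regex guardrail; per the comment above, use_defaults=False limits matching to the supplied terms instead of the built-in PII patterns. A sketch under the same caveats as the judge example:

import weave

from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)

weave.init("guardrails-genie-restricted-terms-regex-model")

# use_defaults=False: match only the custom terms, not the default PII patterns.
guardrail = RegexEntityRecognitionGuardrail(use_defaults=False, should_anonymize=True)
result = guardrail.guard(
    "Salesforce and Oracle both ship CRM products.",
    custom_terms=["Salesforce", "Oracle"],
)
print(result.contains_entities, result.detected_entities)
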
guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py CHANGED
@@ -1,15 +1,16 @@
 from typing import Dict, List, Optional
+
+import instructor
 import weave
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from ...llm import OpenAIModel
 from ..base import Guardrail
-import instructor


 class TermMatch(BaseModel):
     """Represents a matched term and its variations"""
+
     original_term: str
     matched_text: str
     match_type: str = Field(
@@ -22,19 +23,18 @@

 class RestrictedTermsAnalysis(BaseModel):
     """Analysis result for restricted terms detection"""
+
     contains_restricted_terms: bool = Field(
         description="Whether any restricted terms were detected"
     )
     detected_matches: List[TermMatch] = Field(
         default_factory=list,
-        description="List of detected term matches with their variations"
-    )
-    explanation: str = Field(
-        description="Detailed explanation of the analysis"
+        description="List of detected term matches with their variations",
     )
+    explanation: str = Field(description="Detailed explanation of the analysis")
     anonymized_text: Optional[str] = Field(
         default=None,
-        description="Text with restricted terms replaced with category tags"
+        description="Text with restricted terms replaced with category tags",
     )

     @property
@@ -106,39 +106,57 @@ Return your analysis in the structured format specified by the RestrictedTermsAn
         return user_prompt, system_prompt

     @weave.op()
-    def predict(self, text: str, custom_terms: List[str], **kwargs) -> RestrictedTermsAnalysis:
+    def predict(
+        self, text: str, custom_terms: List[str], **kwargs
+    ) -> RestrictedTermsAnalysis:
         user_prompt, system_prompt = self.format_prompts(text, custom_terms)
-
+
         response = self.llm_model.predict(
             user_prompts=user_prompt,
             system_prompt=system_prompt,
             response_format=RestrictedTermsAnalysis,
             temperature=0.1,  # Lower temperature for more consistent analysis
-            **kwargs
+            **kwargs,
         )
-
+
         return response.choices[0].message.parsed

-    #TODO: Remove default custom_terms
+    # TODO: Remove default custom_terms
     @weave.op()
-    def guard(self, text: str, custom_terms: List[str] = ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"], aggregate_redaction: bool = True, **kwargs) -> RestrictedTermsRecognitionResponse:
+    def guard(
+        self,
+        text: str,
+        custom_terms: List[str] = [
+            "Microsoft",
+            "Amazon Web Services",
+            "Facebook",
+            "Meta",
+            "Google",
+            "Salesforce",
+            "Oracle",
+        ],
+        aggregate_redaction: bool = True,
+        **kwargs,
+    ) -> RestrictedTermsRecognitionResponse:
         """
         Guard against restricted terms and their variations.
-
+
         Args:
             text: Text to analyze
             custom_terms: List of restricted terms to check for
-
+
         Returns:
             RestrictedTermsRecognitionResponse containing safety assessment and detailed analysis
         """
         analysis = self.predict(text, custom_terms, **kwargs)
-
+
         # Create a summary of findings
         if analysis.contains_restricted_terms:
             summary_parts = ["Restricted terms detected:"]
             for match in analysis.detected_matches:
-                summary_parts.append(f"\n- {match.original_term}: {match.matched_text} ({match.match_type})")
+                summary_parts.append(
+                    f"\n- {match.original_term}: {match.matched_text} ({match.match_type})"
+                )
             summary = "\n".join(summary_parts)
         else:
             summary = "No restricted terms detected."
@@ -148,8 +166,14 @@ Return your analysis in the structured format specified by the RestrictedTermsAn
         if self.should_anonymize and analysis.contains_restricted_terms:
             anonymized_text = text
             for match in analysis.detected_matches:
-                replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
-                anonymized_text = anonymized_text.replace(match.matched_text, replacement)
+                replacement = (
+                    "[redacted]"
+                    if aggregate_redaction
+                    else f"[{match.match_type.upper()}]"
+                )
+                anonymized_text = anonymized_text.replace(
+                    match.matched_text, replacement
+                )

         # Convert detected_matches to a dictionary format
         detected_entities = {}
@@ -162,5 +186,5 @@ Return your analysis in the structured format specified by the RestrictedTermsAn
             contains_entities=analysis.contains_restricted_terms,
             detected_entities=detected_entities,
             explanation=summary,
-            anonymized_text=anonymized_text
-        )
+            anonymized_text=anonymized_text,
+        )
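
The `# TODO: Remove default custom_terms` above is worth acting on: a list literal used as a default argument is created once at function-definition time and shared by every call, so any mutation leaks across calls. A hedged sketch of the conventional fix (illustrative only, not what this commit ships):

from typing import List, Optional


def guard(text: str, custom_terms: Optional[List[str]] = None) -> None:
    # None sentinel: build a fresh list per call instead of sharing one
    # module-lifetime default list across all invocations.
    if custom_terms is None:
        custom_terms = ["Microsoft", "Amazon Web Services", "Facebook"]
    print(f"checking {len(custom_terms)} terms against {len(text)} characters")


guard("some prompt")  # uses a fresh default list each call
guard("another prompt", custom_terms=["Salesforce"])  # caller-supplied terms
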
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py CHANGED
@@ -1,10 +1,11 @@
-from datasets import load_dataset
-from typing import Dict, List, Tuple
-import random
-from tqdm import tqdm
 import json
+import random
 from pathlib import Path
+from typing import Dict, List, Tuple
+
 import weave
+from datasets import load_dataset
+from tqdm import tqdm

 # Add this mapping dictionary near the top of the file
 PRESIDIO_TO_TRANSFORMER_MAPPING = {
@@ -32,66 +33,70 @@ PRESIDIO_TO_TRANSFORMER_MAPPING = {
     "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
     "IBAN_CODE": "ACCOUNTNUM",
     "MEDICAL_LICENSE": "IDCARDNUM",
-    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM",
 }

-def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
+
+def load_ai4privacy_dataset(
+    num_samples: int = 100, split: str = "validation"
+) -> List[Dict]:
     """
     Load and prepare samples from the ai4privacy dataset.
-
+
     Args:
         num_samples: Number of samples to evaluate
         split: Dataset split to use ("train" or "validation")
-
+
     Returns:
         List of prepared test cases
     """
     # Load the dataset
     dataset = load_dataset("ai4privacy/pii-masking-400k")
-
+
     # Get the specified split
     data_split = dataset[split]
-
+
     # Randomly sample entries if num_samples is less than total
     if num_samples < len(data_split):
         indices = random.sample(range(len(data_split)), num_samples)
         samples = [data_split[i] for i in indices]
     else:
         samples = data_split
-
+
     # Convert to test case format
     test_cases = []
     for sample in samples:
         # Extract entities from privacy_mask
         entities: Dict[str, List[str]] = {}
-        for entity in sample['privacy_mask']:
-            label = entity['label']
-            value = entity['value']
+        for entity in sample["privacy_mask"]:
+            label = entity["label"]
+            value = entity["value"]
             if label not in entities:
                 entities[label] = []
             entities[label].append(value)
-
+
         test_case = {
             "description": f"AI4Privacy Sample (ID: {sample['uid']})",
-            "input_text": sample['source_text'],
+            "input_text": sample["source_text"],
             "expected_entities": entities,
-            "masked_text": sample['masked_text'],
-            "language": sample['language'],
-            "locale": sample['locale']
+            "masked_text": sample["masked_text"],
+            "language": sample["language"],
+            "locale": sample["locale"],
         }
         test_cases.append(test_case)
-
+
     return test_cases

+
 @weave.op()
 def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]:
     """
     Evaluate a model on the test cases.
-
+
     Args:
         guardrail: Entity recognition guardrail to evaluate
         test_cases: List of test cases
-
+
     Returns:
         Tuple of (metrics dict, detailed results list)
     """
@@ -99,17 +104,17 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
         "total": len(test_cases),
         "passed": 0,
         "failed": 0,
-        "entity_metrics": {}  # Will store precision/recall per entity type
+        "entity_metrics": {},  # Will store precision/recall per entity type
     }
-
+
     detailed_results = []
-
+
     for test_case in tqdm(test_cases, desc="Evaluating samples"):
         # Run detection
-        result = guardrail.guard(test_case['input_text'])
+        result = guardrail.guard(test_case["input_text"])
         detected = result.detected_entities
-        expected = test_case['expected_entities']
-
+        expected = test_case["expected_entities"]
+
         # Map Presidio entities if this is the Presidio guardrail
         if isinstance(guardrail, PresidioEntityRecognitionGuardrail):
             mapped_detected = {}
@@ -120,44 +125,62 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
                     mapped_detected[mapped_type] = []
                 mapped_detected[mapped_type].extend(values)
             detected = mapped_detected
-
+
         # Track entity-level metrics
         all_entity_types = set(list(detected.keys()) + list(expected.keys()))
         entity_results = {}
-
+
         for entity_type in all_entity_types:
             detected_set = set(detected.get(entity_type, []))
             expected_set = set(expected.get(entity_type, []))
-
+
             # Calculate metrics
             true_positives = len(detected_set & expected_set)
             false_positives = len(detected_set - expected_set)
             false_negatives = len(expected_set - detected_set)
-
-            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
-            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
-            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
+
+            precision = (
+                true_positives / (true_positives + false_positives)
+                if (true_positives + false_positives) > 0
+                else 0
+            )
+            recall = (
+                true_positives / (true_positives + false_negatives)
+                if (true_positives + false_negatives) > 0
+                else 0
+            )
+            f1 = (
+                2 * (precision * recall) / (precision + recall)
+                if (precision + recall) > 0
+                else 0
+            )
+
             entity_results[entity_type] = {
                 "precision": precision,
                 "recall": recall,
                 "f1": f1,
                 "true_positives": true_positives,
                 "false_positives": false_positives,
-                "false_negatives": false_negatives
+                "false_negatives": false_negatives,
             }
-
+
             # Aggregate metrics
             if entity_type not in metrics["entity_metrics"]:
                 metrics["entity_metrics"][entity_type] = {
                     "total_true_positives": 0,
                     "total_false_positives": 0,
-                    "total_false_negatives": 0
+                    "total_false_negatives": 0,
                 }
-            metrics["entity_metrics"][entity_type]["total_true_positives"] += true_positives
-            metrics["entity_metrics"][entity_type]["total_false_positives"] += false_positives
-            metrics["entity_metrics"][entity_type]["total_false_negatives"] += false_negatives
-
+            metrics["entity_metrics"][entity_type][
+                "total_true_positives"
+            ] += true_positives
+            metrics["entity_metrics"][entity_type][
+                "total_false_positives"
+            ] += false_positives
+            metrics["entity_metrics"][entity_type][
+                "total_false_negatives"
+            ] += false_negatives
+
         # Store detailed result
         detailed_result = {
             "id": test_case.get("description", ""),
@@ -167,69 +190,88 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
             "expected_entities": expected,
             "detected_entities": detected,
             "entity_metrics": entity_results,
-            "anonymized_text": result.anonymized_text if result.anonymized_text else None
+            "anonymized_text": (
+                result.anonymized_text if result.anonymized_text else None
+            ),
         }
         detailed_results.append(detailed_result)
-
+
         # Update pass/fail counts
         if all(entity_results[et]["f1"] == 1.0 for et in entity_results):
            metrics["passed"] += 1
         else:
             metrics["failed"] += 1
-
+
     # Calculate final entity metrics and track totals for overall metrics
     total_tp = 0
     total_fp = 0
     total_fn = 0
-
+
     for entity_type, counts in metrics["entity_metrics"].items():
         tp = counts["total_true_positives"]
         fp = counts["total_false_positives"]
         fn = counts["total_false_negatives"]
-
+
         total_tp += tp
         total_fp += fp
         total_fn += fn
-
+
         precision = tp / (tp + fp) if (tp + fp) > 0 else 0
         recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
-        metrics["entity_metrics"][entity_type].update({
-            "precision": precision,
-            "recall": recall,
-            "f1": f1
-        })
-
+        f1 = (
+            2 * (precision * recall) / (precision + recall)
+            if (precision + recall) > 0
+            else 0
+        )
+
+        metrics["entity_metrics"][entity_type].update(
+            {"precision": precision, "recall": recall, "f1": f1}
+        )
+
     # Calculate overall metrics
-    overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
-    overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
-    overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
-
+    overall_precision = (
+        total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+    )
+    overall_recall = (
+        total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+    )
+    overall_f1 = (
+        2 * (overall_precision * overall_recall) / (overall_precision + overall_recall)
+        if (overall_precision + overall_recall) > 0
+        else 0
+    )
+
     metrics["overall"] = {
         "precision": overall_precision,
         "recall": overall_recall,
         "f1": overall_f1,
         "total_true_positives": total_tp,
         "total_false_positives": total_fp,
-        "total_false_negatives": total_fn
+        "total_false_negatives": total_fn,
    }
-
+
     return metrics, detailed_results

-def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
+
+def save_results(
+    metrics: Dict,
+    detailed_results: List[Dict],
+    model_name: str,
+    output_dir: str = "evaluation_results",
+):
     """Save evaluation results to files"""
     output_dir = Path(output_dir)
     output_dir.mkdir(exist_ok=True)
-
+
     # Save metrics summary
     with open(output_dir / f"{model_name}_metrics.json", "w") as f:
         json.dump(metrics, f, indent=2)
-
+
     # Save detailed results
     with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
         json.dump(detailed_results, f, indent=2)

+
 def print_metrics_summary(metrics: Dict):
     """Print a summary of the evaluation metrics"""
     print("\nEvaluation Summary")
@@ -238,7 +280,7 @@ def print_metrics_summary(metrics: Dict):
     print(f"Passed: {metrics['passed']}")
     print(f"Failed: {metrics['failed']}")
     print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")
-
+
     # Print overall metrics
     print("\nOverall Metrics:")
     print("-" * 80)
@@ -247,40 +289,56 @@ def print_metrics_summary(metrics: Dict):
     print(f"{'Precision':<20} {metrics['overall']['precision']:>10.2f}")
     print(f"{'Recall':<20} {metrics['overall']['recall']:>10.2f}")
     print(f"{'F1':<20} {metrics['overall']['f1']:>10.2f}")
-
+
     print("\nEntity-level Metrics:")
     print("-" * 80)
     print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
     print("-" * 80)
     for entity_type, entity_metrics in metrics["entity_metrics"].items():
-        print(f"{entity_type:<20} {entity_metrics['precision']:>10.2f} {entity_metrics['recall']:>10.2f} {entity_metrics['f1']:>10.2f}")
+        print(
+            f"{entity_type:<20} {entity_metrics['precision']:>10.2f} {entity_metrics['recall']:>10.2f} {entity_metrics['f1']:>10.2f}"
+        )
+

 def main():
     """Main evaluation function"""
     weave.init("guardrails-genie-pii-evaluation-demo")
-
+
     # Load test cases
     test_cases = load_ai4privacy_dataset(num_samples=100)
-
+
     # Initialize models to evaluate
     models = {
-        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
-        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
-        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True)
+        "regex": RegexEntityRecognitionGuardrail(
+            should_anonymize=True, show_available_entities=True
+        ),
+        "presidio": PresidioEntityRecognitionGuardrail(
+            should_anonymize=True, show_available_entities=True
+        ),
+        "transformers": TransformersEntityRecognitionGuardrail(
+            should_anonymize=True, show_available_entities=True
+        ),
     }
-
+
     # Evaluate each model
     for model_name, guardrail in models.items():
         print(f"\nEvaluating {model_name} model...")
         metrics, detailed_results = evaluate_model(guardrail, test_cases)
-
+
         # Print and save results
         print_metrics_summary(metrics)
         save_results(metrics, detailed_results, model_name)

+
 if __name__ == "__main__":
-    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
-    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
-    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
-
-    main()
+    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
+        PresidioEntityRecognitionGuardrail,
+    )
+    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
+        RegexEntityRecognitionGuardrail,
+    )
+    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import (
+        TransformersEntityRecognitionGuardrail,
+    )
+
+    main()
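
The guarded divisions in evaluate_model are the standard per-entity precision/recall/F1; a small worked example with made-up detection sets:

# Worked example of the metric arithmetic in evaluate_model (invented data).
detected_set = {"MSFT", "M$", "Windows"}  # surface forms the guardrail found
expected_set = {"MSFT", "M$"}  # ground-truth surface forms

true_positives = len(detected_set & expected_set)  # 2
false_positives = len(detected_set - expected_set)  # 1 ("Windows")
false_negatives = len(expected_set - detected_set)  # 0

precision = true_positives / (true_positives + false_positives)  # 2/3
recall = true_positives / (true_positives + false_negatives)  # 2/2 = 1.0
f1 = 2 * (precision * recall) / (precision + recall)  # 0.80
print(f"P={precision:.2f} R={recall:.2f} F1={f1:.2f}")  # P=0.67 R=1.00 F1=0.80
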
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py CHANGED
@@ -1,13 +1,13 @@
-from datasets import load_dataset
-from typing import Dict, List, Tuple, Optional
-import random
-from tqdm import tqdm
+import asyncio
 import json
+import random
 from pathlib import Path
+from typing import Dict, List, Optional
+
 import weave
-from weave.scorers import Scorer
+from datasets import load_dataset
 from weave import Evaluation
-import asyncio
+from weave.scorers import Scorer

 # Add this mapping dictionary near the top of the file
 PRESIDIO_TO_TRANSFORMER_MAPPING = {
@@ -35,26 +35,29 @@ PRESIDIO_TO_TRANSFORMER_MAPPING = {
     "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
     "IBAN_CODE": "ACCOUNTNUM",
     "MEDICAL_LICENSE": "IDCARDNUM",
-    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM",
 }

+
 class EntityRecognitionScorer(Scorer):
     """Scorer for evaluating entity recognition performance"""
-
+
     @weave.op()
-    async def score(self, model_output: Optional[dict], input_text: str, expected_entities: Dict) -> Dict:
+    async def score(
+        self, model_output: Optional[dict], input_text: str, expected_entities: Dict
+    ) -> Dict:
         """Score entity recognition results"""
         if not model_output:
             return {"f1": 0.0}
-
+
         # Convert Pydantic model to dict if necessary
         if hasattr(model_output, "model_dump"):
             model_output = model_output.model_dump()
         elif hasattr(model_output, "dict"):
             model_output = model_output.dict()
-
+
         detected = model_output.get("detected_entities", {})
-
+
         # Map Presidio entities if needed
         if model_output.get("model_type") == "presidio":
             mapped_detected = {}
@@ -65,191 +68,234 @@ class EntityRecognitionScorer(Scorer):
                     mapped_detected[mapped_type] = []
                 mapped_detected[mapped_type].extend(values)
             detected = mapped_detected
-
+
         # Track entity-level metrics
         all_entity_types = set(list(detected.keys()) + list(expected_entities.keys()))
         entity_metrics = {}
-
+
         for entity_type in all_entity_types:
             detected_set = set(detected.get(entity_type, []))
             expected_set = set(expected_entities.get(entity_type, []))
-
+
             # Calculate metrics
             true_positives = len(detected_set & expected_set)
             false_positives = len(detected_set - expected_set)
             false_negatives = len(expected_set - detected_set)
-
+
             if entity_type not in entity_metrics:
                 entity_metrics[entity_type] = {
                     "total_true_positives": 0,
                     "total_false_positives": 0,
-                    "total_false_negatives": 0
+                    "total_false_negatives": 0,
                 }
-
+
             entity_metrics[entity_type]["total_true_positives"] += true_positives
             entity_metrics[entity_type]["total_false_positives"] += false_positives
             entity_metrics[entity_type]["total_false_negatives"] += false_negatives
-
+
             # Calculate per-entity metrics
-            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
-            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
-            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
-            entity_metrics[entity_type].update({
-                "precision": precision,
-                "recall": recall,
-                "f1": f1
-            })
-
+            precision = (
+                true_positives / (true_positives + false_positives)
+                if (true_positives + false_positives) > 0
+                else 0
+            )
+            recall = (
+                true_positives / (true_positives + false_negatives)
+                if (true_positives + false_negatives) > 0
+                else 0
+            )
+            f1 = (
+                2 * (precision * recall) / (precision + recall)
+                if (precision + recall) > 0
+                else 0
+            )
+
+            entity_metrics[entity_type].update(
+                {"precision": precision, "recall": recall, "f1": f1}
+            )
+
         # Calculate overall metrics
-        total_tp = sum(metrics["total_true_positives"] for metrics in entity_metrics.values())
-        total_fp = sum(metrics["total_false_positives"] for metrics in entity_metrics.values())
-        total_fn = sum(metrics["total_false_negatives"] for metrics in entity_metrics.values())
-
-        overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
-        overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
-        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
-
+        total_tp = sum(
+            metrics["total_true_positives"] for metrics in entity_metrics.values()
+        )
+        total_fp = sum(
+            metrics["total_false_positives"] for metrics in entity_metrics.values()
+        )
+        total_fn = sum(
+            metrics["total_false_negatives"] for metrics in entity_metrics.values()
+        )
+
+        overall_precision = (
+            total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+        )
+        overall_recall = (
+            total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+        )
+        overall_f1 = (
+            2
+            * (overall_precision * overall_recall)
+            / (overall_precision + overall_recall)
+            if (overall_precision + overall_recall) > 0
+            else 0
+        )
+
         entity_metrics["overall"] = {
             "precision": overall_precision,
             "recall": overall_recall,
             "f1": overall_f1,
             "total_true_positives": total_tp,
             "total_false_positives": total_fp,
-            "total_false_negatives": total_fn
+            "total_false_negatives": total_fn,
         }
-
+
         return entity_metrics

-def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
+
+def load_ai4privacy_dataset(
+    num_samples: int = 100, split: str = "validation"
+) -> List[Dict]:
     """
     Load and prepare samples from the ai4privacy dataset.
-
+
     Args:
         num_samples: Number of samples to evaluate
         split: Dataset split to use ("train" or "validation")
-
+
     Returns:
         List of prepared test cases
     """
     # Load the dataset
     dataset = load_dataset("ai4privacy/pii-masking-400k")
-
+
     # Get the specified split
     data_split = dataset[split]
-
+
     # Randomly sample entries if num_samples is less than total
     if num_samples < len(data_split):
         indices = random.sample(range(len(data_split)), num_samples)
         samples = [data_split[i] for i in indices]
     else:
         samples = data_split
-
+
     # Convert to test case format
     test_cases = []
     for sample in samples:
         # Extract entities from privacy_mask
         entities: Dict[str, List[str]] = {}
-        for entity in sample['privacy_mask']:
-            label = entity['label']
-            value = entity['value']
+        for entity in sample["privacy_mask"]:
+            label = entity["label"]
+            value = entity["value"]
             if label not in entities:
                 entities[label] = []
             entities[label].append(value)
-
+
         test_case = {
             "description": f"AI4Privacy Sample (ID: {sample['uid']})",
-            "input_text": sample['source_text'],
+            "input_text": sample["source_text"],
             "expected_entities": entities,
-            "masked_text": sample['masked_text'],
-            "language": sample['language'],
-            "locale": sample['locale']
+            "masked_text": sample["masked_text"],
+            "language": sample["language"],
+            "locale": sample["locale"],
        }
         test_cases.append(test_case)
-
+
     return test_cases

-def save_results(weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"):
     """Save evaluation results to files"""
     output_dir = Path(output_dir)
     output_dir.mkdir(exist_ok=True)
-
     # Extract and process results
     scorer_results = weave_results.get("EntityRecognitionScorer", [])
     if not scorer_results or all(r is None for r in scorer_results):
         print(f"No valid results to save for {model_name}")
         return
-
     # Calculate summary metrics
     total_samples = len(scorer_results)
     passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))
-
     # Aggregate entity-level metrics
     entity_metrics = {}
     for result in scorer_results:
         try:
             if isinstance(result, str) or not result:
                 continue
-
             for entity_type, metrics in result.items():
                 if entity_type not in entity_metrics:
                     entity_metrics[entity_type] = {
                         "precision": [],
                         "recall": [],
-                        "f1": []
                     }
                 entity_metrics[entity_type]["precision"].append(metrics["precision"])
                 entity_metrics[entity_type]["recall"].append(metrics["recall"])
                 entity_metrics[entity_type]["f1"].append(metrics["f1"])
         except (AttributeError, TypeError, KeyError):
             continue
-
     # Calculate averages
     summary_metrics = {
         "total": total_samples,
         "passed": passed,
         "failed": total_samples - passed,
-        "success_rate": (passed/total_samples) if total_samples > 0 else 0,
         "entity_metrics": {
             entity_type: {
-                "precision": sum(metrics["precision"]) / len(metrics["precision"]) if metrics["precision"] else 0,
-                "recall": sum(metrics["recall"]) / len(metrics["recall"]) if metrics["recall"] else 0,
-                "f1": sum(metrics["f1"]) / len(metrics["f1"]) if metrics["f1"] else 0
             }
             for entity_type, metrics in entity_metrics.items()
-        }
     }
-
     # Save files
     with open(output_dir / f"{model_name}_metrics.json", "w") as f:
         json.dump(summary_metrics, f, indent=2)
-
     # Save detailed results, filtering out string results
-    detailed_results = [r for r in scorer_results if not isinstance(r, str) and r is not None]
     with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
         json.dump(detailed_results, f, indent=2)

 def print_metrics_summary(weave_results: Dict):
     """Print a summary of the evaluation metrics"""
     print("\nEvaluation Summary")
     print("=" * 80)
-
     # Extract results from Weave's evaluation format
     scorer_results = weave_results.get("EntityRecognitionScorer", {})
     if not scorer_results:
         print("No valid results available")
         return
-
     # Calculate overall metrics
     total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
     passed = total_samples  # Since we have results, all samples passed
     failed = 0
-
     print(f"Total Samples: {total_samples}")
     print(f"Passed: {passed}")
     print(f"Failed: {failed}")
     print(f"Success Rate: {(passed/total_samples)*100:.2f}%")
-
     # Print overall metrics
     if "overall" in scorer_results:
         overall = scorer_results["overall"]
@@ -260,63 +306,68 @@ def print_metrics_summary(weave_results: Dict):
         print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
         print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
         print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")
-
     # Print entity-level metrics
     print("\nEntity-Level Metrics:")
     print("-" * 80)
     print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
     print("-" * 80)
-
     for entity_type, metrics in scorer_results.items():
         if entity_type == "overall":
             continue
-
         precision = metrics.get("precision", {}).get("mean", 0)
         recall = metrics.get("recall", {}).get("mean", 0)
         f1 = metrics.get("f1", {}).get("mean", 0)
-
         print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")

 def preprocess_model_input(example: Dict) -> Dict:
     """Preprocess dataset example to match model input format."""
     return {
         "prompt": example["input_text"],
-        "model_type": example.get("model_type", "unknown")  # Add model type for Presidio mapping
     }

 def main():
     """Main evaluation function"""
     weave.init("guardrails-genie-pii-evaluation")
-
     # Load test cases
     test_cases = load_ai4privacy_dataset(num_samples=100)
-
     # Add model type to test cases for Presidio mapping
     models = {
         # "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
         "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
         # "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
     }
-
     scorer = EntityRecognitionScorer()
-
     # Evaluate each model
     for model_name, guardrail in models.items():
         print(f"\nEvaluating {model_name} model...")
         # Add model type to test cases
         model_test_cases = [{**case, "model_type": model_name} for case in test_cases]
-
         evaluation = Evaluation(
             dataset=model_test_cases,
             scorers=[scorer],
-            preprocess_model_input=preprocess_model_input
         )
-
         results = asyncio.run(evaluation.evaluate(guardrail))

 if __name__ == "__main__":
-    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
-    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
-    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
-
-    main()
189
  entities[label] = []
190
  entities[label].append(value)
191
+
192
  test_case = {
193
  "description": f"AI4Privacy Sample (ID: {sample['uid']})",
194
+ "input_text": sample["source_text"],
195
  "expected_entities": entities,
196
+ "masked_text": sample["masked_text"],
197
+ "language": sample["language"],
198
+ "locale": sample["locale"],
199
  }
200
  test_cases.append(test_case)
201
+
202
  return test_cases
203
 
204
+
205
+ def save_results(
206
+ weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"
207
+ ):
208
  """Save evaluation results to files"""
209
  output_dir = Path(output_dir)
210
  output_dir.mkdir(exist_ok=True)
211
+
212
  # Extract and process results
213
  scorer_results = weave_results.get("EntityRecognitionScorer", [])
214
  if not scorer_results or all(r is None for r in scorer_results):
215
  print(f"No valid results to save for {model_name}")
216
  return
217
+
218
  # Calculate summary metrics
219
  total_samples = len(scorer_results)
220
  passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))
221
+
222
  # Aggregate entity-level metrics
223
  entity_metrics = {}
224
  for result in scorer_results:
225
  try:
226
  if isinstance(result, str) or not result:
227
  continue
228
+
229
  for entity_type, metrics in result.items():
230
  if entity_type not in entity_metrics:
231
  entity_metrics[entity_type] = {
232
  "precision": [],
233
  "recall": [],
234
+ "f1": [],
235
  }
236
  entity_metrics[entity_type]["precision"].append(metrics["precision"])
237
  entity_metrics[entity_type]["recall"].append(metrics["recall"])
238
  entity_metrics[entity_type]["f1"].append(metrics["f1"])
239
  except (AttributeError, TypeError, KeyError):
240
  continue
241
+
242
  # Calculate averages
243
  summary_metrics = {
244
  "total": total_samples,
245
  "passed": passed,
246
  "failed": total_samples - passed,
247
+ "success_rate": (passed / total_samples) if total_samples > 0 else 0,
248
  "entity_metrics": {
249
  entity_type: {
250
+ "precision": (
251
+ sum(metrics["precision"]) / len(metrics["precision"])
252
+ if metrics["precision"]
253
+ else 0
254
+ ),
255
+ "recall": (
256
+ sum(metrics["recall"]) / len(metrics["recall"])
257
+ if metrics["recall"]
258
+ else 0
259
+ ),
260
+ "f1": sum(metrics["f1"]) / len(metrics["f1"]) if metrics["f1"] else 0,
261
  }
262
  for entity_type, metrics in entity_metrics.items()
263
+ },
264
  }
265
+
266
  # Save files
267
  with open(output_dir / f"{model_name}_metrics.json", "w") as f:
268
  json.dump(summary_metrics, f, indent=2)
269
+
270
  # Save detailed results, filtering out string results
271
+ detailed_results = [
272
+ r for r in scorer_results if not isinstance(r, str) and r is not None
273
+ ]
274
  with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
275
  json.dump(detailed_results, f, indent=2)
276
 
277
+
278
  def print_metrics_summary(weave_results: Dict):
279
  """Print a summary of the evaluation metrics"""
280
  print("\nEvaluation Summary")
281
  print("=" * 80)
282
+
283
  # Extract results from Weave's evaluation format
284
  scorer_results = weave_results.get("EntityRecognitionScorer", {})
285
  if not scorer_results:
286
  print("No valid results available")
287
  return
288
+
289
  # Calculate overall metrics
290
  total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
291
  passed = total_samples # Since we have results, all samples passed
292
  failed = 0
293
+
294
  print(f"Total Samples: {total_samples}")
295
  print(f"Passed: {passed}")
296
  print(f"Failed: {failed}")
297
  print(f"Success Rate: {(passed/total_samples)*100:.2f}%")
298
+
299
  # Print overall metrics
300
  if "overall" in scorer_results:
301
  overall = scorer_results["overall"]
 
306
  print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
307
  print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
308
  print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")
309
+
310
  # Print entity-level metrics
311
  print("\nEntity-Level Metrics:")
312
  print("-" * 80)
313
  print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
314
  print("-" * 80)
315
+
316
  for entity_type, metrics in scorer_results.items():
317
  if entity_type == "overall":
318
  continue
319
+
320
  precision = metrics.get("precision", {}).get("mean", 0)
321
  recall = metrics.get("recall", {}).get("mean", 0)
322
  f1 = metrics.get("f1", {}).get("mean", 0)
323
+
324
  print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")
325
 
326
+
327
  def preprocess_model_input(example: Dict) -> Dict:
328
  """Preprocess dataset example to match model input format."""
329
  return {
330
  "prompt": example["input_text"],
331
+ "model_type": example.get(
332
+ "model_type", "unknown"
333
+ ), # Add model type for Presidio mapping
334
  }
335
 
336
+
337
  def main():
338
  """Main evaluation function"""
339
  weave.init("guardrails-genie-pii-evaluation")
340
+
341
  # Load test cases
342
  test_cases = load_ai4privacy_dataset(num_samples=100)
343
+
344
  # Models to evaluate; model_type is attached to each test case below for Presidio mapping
345
  models = {
346
  # "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
347
  "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
348
  # "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
349
  }
350
+
351
  scorer = EntityRecognitionScorer()
352
+
353
  # Evaluate each model
354
  for model_name, guardrail in models.items():
355
  print(f"\nEvaluating {model_name} model...")
356
  # Add model type to test cases
357
  model_test_cases = [{**case, "model_type": model_name} for case in test_cases]
358
+
359
  evaluation = Evaluation(
360
  dataset=model_test_cases,
361
  scorers=[scorer],
362
+ preprocess_model_input=preprocess_model_input,
363
  )
364
+
365
  results = asyncio.run(evaluation.evaluate(guardrail))
366
 
367
+
368
  if __name__ == "__main__":
369
+ from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
370
+ PresidioEntityRecognitionGuardrail,
371
+ )
372
+
373
+ main()
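
The per-entity precision/recall/F1 computed by EntityRecognitionScorer above reduces to set arithmetic over detected vs. expected values. A minimal standalone sketch of that arithmetic, with toy detected/expected dicts (illustrative values only, not dataset output):

# Toy inputs, for illustration only
detected = {"EMAIL": ["jane@example.com"], "GIVENNAME": ["John", "Jane"]}
expected = {"EMAIL": ["jane@example.com"], "GIVENNAME": ["John"], "SURNAME": ["Doe"]}

for entity_type in set(detected) | set(expected):
    detected_set = set(detected.get(entity_type, []))
    expected_set = set(expected.get(entity_type, []))
    tp = len(detected_set & expected_set)  # true positives
    fp = len(detected_set - expected_set)  # false positives
    fn = len(expected_set - detected_set)  # false negatives
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    print(f"{entity_type}: P={precision:.2f} R={recall:.2f} F1={f1:.2f}")
# GIVENNAME: P=0.50 R=1.00 F1=0.67; SURNAME: all 0.00; EMAIL: all 1.00

The "overall" row in the scorer aggregates the raw counts before dividing (micro-averaging), so frequent entity types dominate the overall F1.
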
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py CHANGED
@@ -18,8 +18,8 @@ Emergency Contact: Mary Johnson (Tel: 098-765-4321)
18
  "SURNAME": ["Smith", "Johnson"],
19
  "EMAIL": ["[email protected]"],
20
  "PHONE_NUMBER": ["123-456-7890", "098-765-4321"],
21
- "SOCIALNUM": ["123-45-6789"]
22
- }
23
  },
24
  {
25
  "description": "Meeting Notes with Attendees",
@@ -39,8 +39,8 @@ Action Items:
39
  "GIVENNAME": ["Sarah", "Robert", "Tom", "Bob"],
40
  "SURNAME": ["Williams", "Brown", "Wilson"],
41
42
- "PHONE_NUMBER": ["555-0123-4567", "777-888-9999"]
43
- }
44
  },
45
  {
46
  "description": "Medical Record",
@@ -57,8 +57,8 @@ Emergency Contact: Michael Thompson (555-123-4567)
57
  "GIVENNAME": ["Emma", "James", "Michael"],
58
  "SURNAME": ["Thompson", "Wilson", "Thompson"],
59
  "EMAIL": ["[email protected]"],
60
- "PHONE_NUMBER": ["555-123-4567"]
61
- }
62
  },
63
  {
64
  "description": "No PII Content",
@@ -68,7 +68,7 @@ Project Status Update:
68
  - Budget is within limits
69
  - Next review scheduled for next week
70
  """,
71
- "expected_entities": {}
72
  },
73
  {
74
  "description": "Mixed Format Phone Numbers",
@@ -84,10 +84,10 @@ Emergency: 555 444 3333
84
  "(555) 123-4567",
85
  "555.987.6543",
86
  "+1-555-321-7890",
87
- "555 444 3333"
88
  ]
89
- }
90
- }
91
  ]
92
 
93
  # Additional examples can be added to test specific edge cases or formats
@@ -103,37 +103,41 @@ [email protected]
103
  "EMAIL": [
104
105
106
107
  ],
108
  "GIVENNAME": ["John", "Jane", "Bob"],
109
- "SURNAME": ["Doe", "Smith", "Jones"]
110
- }
111
  }
112
  ]
113
 
 
114
  def validate_entities(detected: dict, expected: dict) -> bool:
115
  """Compare detected entities with expected entities"""
116
  if set(detected.keys()) != set(expected.keys()):
117
  return False
118
  return all(set(detected[k]) == set(expected[k]) for k in expected.keys())
119
 
 
120
  def run_test_case(guardrail, test_case, test_type="Main"):
121
  """Run a single test case and print results"""
122
  print(f"\n{test_type} Test Case: {test_case['description']}")
123
  print("-" * 50)
124
-
125
- result = guardrail.guard(test_case['input_text'])
126
- expected = test_case['expected_entities']
127
-
128
  # Validate results
129
  matches = validate_entities(result.detected_entities, expected)
130
-
131
  print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
132
  print(f"Contains PII: {result.contains_entities}")
133
-
134
  if not matches:
135
  print("\nEntity Comparison:")
136
- all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
 
 
137
  for entity_type in all_entity_types:
138
  detected = set(result.detected_entities.get(entity_type, []))
139
  expected_set = set(expected.get(entity_type, []))
@@ -143,8 +147,8 @@ def run_test_case(guardrail, test_case, test_type="Main"):
143
  if detected != expected_set:
144
  print(f" Missing: {sorted(expected_set - detected)}")
145
  print(f" Extra: {sorted(detected - expected_set)}")
146
-
147
  if result.anonymized_text:
148
  print(f"\nAnonymized Text:\n{result.anonymized_text}")
149
-
150
- return matches
 
18
  "SURNAME": ["Smith", "Johnson"],
19
  "EMAIL": ["[email protected]"],
20
  "PHONE_NUMBER": ["123-456-7890", "098-765-4321"],
21
+ "SOCIALNUM": ["123-45-6789"],
22
+ },
23
  },
24
  {
25
  "description": "Meeting Notes with Attendees",
 
39
  "GIVENNAME": ["Sarah", "Robert", "Tom", "Bob"],
40
  "SURNAME": ["Williams", "Brown", "Wilson"],
41
42
+ "PHONE_NUMBER": ["555-0123-4567", "777-888-9999"],
43
+ },
44
  },
45
  {
46
  "description": "Medical Record",
 
57
  "GIVENNAME": ["Emma", "James", "Michael"],
58
  "SURNAME": ["Thompson", "Wilson", "Thompson"],
59
  "EMAIL": ["[email protected]"],
60
+ "PHONE_NUMBER": ["555-123-4567"],
61
+ },
62
  },
63
  {
64
  "description": "No PII Content",
 
68
  - Budget is within limits
69
  - Next review scheduled for next week
70
  """,
71
+ "expected_entities": {},
72
  },
73
  {
74
  "description": "Mixed Format Phone Numbers",
 
84
  "(555) 123-4567",
85
  "555.987.6543",
86
  "+1-555-321-7890",
87
+ "555 444 3333",
88
  ]
89
+ },
90
+ },
91
  ]
92
 
93
  # Additional examples can be added to test specific edge cases or formats
 
103
  "EMAIL": [
104
105
106
107
  ],
108
  "GIVENNAME": ["John", "Jane", "Bob"],
109
+ "SURNAME": ["Doe", "Smith", "Jones"],
110
+ },
111
  }
112
  ]
113
 
114
+
115
  def validate_entities(detected: dict, expected: dict) -> bool:
116
  """Compare detected entities with expected entities"""
117
  if set(detected.keys()) != set(expected.keys()):
118
  return False
119
  return all(set(detected[k]) == set(expected[k]) for k in expected.keys())
120
 
121
+
122
  def run_test_case(guardrail, test_case, test_type="Main"):
123
  """Run a single test case and print results"""
124
  print(f"\n{test_type} Test Case: {test_case['description']}")
125
  print("-" * 50)
126
+
127
+ result = guardrail.guard(test_case["input_text"])
128
+ expected = test_case["expected_entities"]
129
+
130
  # Validate results
131
  matches = validate_entities(result.detected_entities, expected)
132
+
133
  print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
134
  print(f"Contains PII: {result.contains_entities}")
135
+
136
  if not matches:
137
  print("\nEntity Comparison:")
138
+ all_entity_types = set(
139
+ list(result.detected_entities.keys()) + list(expected.keys())
140
+ )
141
  for entity_type in all_entity_types:
142
  detected = set(result.detected_entities.get(entity_type, []))
143
  expected_set = set(expected.get(entity_type, []))
 
147
  if detected != expected_set:
148
  print(f" Missing: {sorted(expected_set - detected)}")
149
  print(f" Extra: {sorted(detected - expected_set)}")
150
+
151
  if result.anonymized_text:
152
  print(f"\nAnonymized Text:\n{result.anonymized_text}")
153
+
154
+ return matches
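
These helpers drive the run_*_model.py scripts that follow. A minimal usage sketch, assuming any of the entity-recognition guardrails with a guard() method (the regex guardrail here is just an example choice):

from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)

guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
# run_test_case returns True on a pass, so the bools sum to a pass count
passed = sum(run_test_case(guardrail, case) for case in PII_TEST_EXAMPLES)
print(f"{passed}/{len(PII_TEST_EXAMPLES)} main test cases passed")
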
guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py CHANGED
@@ -1,15 +1,22 @@
1
- from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
2
- from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
3
  import weave
4

5
  def test_pii_detection():
6
  """Test PII detection scenarios using predefined test cases"""
7
  weave.init("guardrails-genie-pii-presidio-model")
8
-
9
  # Create the guardrail with default entities and anonymization enabled
10
  pii_guardrail = PresidioEntityRecognitionGuardrail(
11
- should_anonymize=True,
12
- show_available_entities=True
13
  )
14
 
15
  # Test statistics
@@ -38,5 +45,6 @@ def test_pii_detection():
38
  print(f"Failed: {total_tests - passed_tests}")
39
  print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
40
 
 
41
  if __name__ == "__main__":
42
  test_pii_detection()
 
 
 
1
  import weave
2
 
3
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
4
+ EDGE_CASE_EXAMPLES,
5
+ PII_TEST_EXAMPLES,
6
+ run_test_case,
7
+ )
8
+ from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import (
9
+ PresidioEntityRecognitionGuardrail,
10
+ )
11
+
12
+
13
  def test_pii_detection():
14
  """Test PII detection scenarios using predefined test cases"""
15
  weave.init("guardrails-genie-pii-presidio-model")
16
+
17
  # Create the guardrail with default entities and anonymization enabled
18
  pii_guardrail = PresidioEntityRecognitionGuardrail(
19
+ should_anonymize=True, show_available_entities=True
 
20
  )
21
 
22
  # Test statistics
 
45
  print(f"Failed: {total_tests - passed_tests}")
46
  print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
47
 
48
+
49
  if __name__ == "__main__":
50
  test_pii_detection()
guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py CHANGED
@@ -1,15 +1,22 @@
1
- from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
2
- from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
3
  import weave
4

5
  def test_pii_detection():
6
  """Test PII detection scenarios using predefined test cases"""
7
  weave.init("guardrails-genie-pii-regex-model")
8
-
9
  # Create the guardrail with default entities and anonymization enabled
10
  pii_guardrail = RegexEntityRecognitionGuardrail(
11
- should_anonymize=True,
12
- show_available_entities=True
13
  )
14
 
15
  # Test statistics
@@ -38,5 +45,6 @@ def test_pii_detection():
38
  print(f"Failed: {total_tests - passed_tests}")
39
  print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
40
 
 
41
  if __name__ == "__main__":
42
  test_pii_detection()
 
 
 
1
  import weave
2
 
3
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
4
+ EDGE_CASE_EXAMPLES,
5
+ PII_TEST_EXAMPLES,
6
+ run_test_case,
7
+ )
8
+ from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
9
+ RegexEntityRecognitionGuardrail,
10
+ )
11
+
12
+
13
  def test_pii_detection():
14
  """Test PII detection scenarios using predefined test cases"""
15
  weave.init("guardrails-genie-pii-regex-model")
16
+
17
  # Create the guardrail with default entities and anonymization enabled
18
  pii_guardrail = RegexEntityRecognitionGuardrail(
19
+ should_anonymize=True, show_available_entities=True
 
20
  )
21
 
22
  # Test statistics
 
45
  print(f"Failed: {total_tests - passed_tests}")
46
  print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
47
 
48
+
49
  if __name__ == "__main__":
50
  test_pii_detection()
guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py CHANGED
@@ -1,16 +1,30 @@
1
- from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
2
- from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
3
  import weave
4

5
  def test_pii_detection():
6
  """Test PII detection scenarios using predefined test cases"""
7
  weave.init("guardrails-genie-pii-transformers-pipeline-model")
8
-
9
  # Create the guardrail with default entities and anonymization enabled
10
  pii_guardrail = TransformersEntityRecognitionGuardrail(
11
- selected_entities=["GIVENNAME", "SURNAME", "EMAIL", "TELEPHONENUM", "SOCIALNUM"],

12
  should_anonymize=True,
13
- show_available_entities=True
14
  )
15
 
16
  # Test statistics
@@ -39,5 +53,6 @@ def test_pii_detection():
39
  print(f"Failed: {total_tests - passed_tests}")
40
  print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
41
 
 
42
  if __name__ == "__main__":
43
  test_pii_detection()
 
 
 
1
  import weave
2
 
3
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import (
4
+ EDGE_CASE_EXAMPLES,
5
+ PII_TEST_EXAMPLES,
6
+ run_test_case,
7
+ )
8
+ from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import (
9
+ TransformersEntityRecognitionGuardrail,
10
+ )
11
+
12
+
13
  def test_pii_detection():
14
  """Test PII detection scenarios using predefined test cases"""
15
  weave.init("guardrails-genie-pii-transformers-pipeline-model")
16
+
17
  # Create the guardrail with default entities and anonymization enabled
18
  pii_guardrail = TransformersEntityRecognitionGuardrail(
19
+ selected_entities=[
20
+ "GIVENNAME",
21
+ "SURNAME",
22
+ "EMAIL",
23
+ "TELEPHONENUM",
24
+ "SOCIALNUM",
25
+ ],
26
  should_anonymize=True,
27
+ show_available_entities=True,
28
  )
29
 
30
  # Test statistics
 
53
  print(f"Failed: {total_tests - passed_tests}")
54
  print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
55
 
56
+
57
  if __name__ == "__main__":
58
  test_pii_detection()
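
Restricting selected_entities, as this script does, narrows both detection and redaction to those labels. A quick sketch of the effect; the outputs are indicative only, since they depend on the model's actual predictions:

guardrail = TransformersEntityRecognitionGuardrail(
    selected_entities=["GIVENNAME", "SURNAME", "EMAIL"],
    should_anonymize=True,
)
response = guardrail.guard(
    "Contact Jane Doe at jane@example.com", aggregate_redaction=False
)
print(response.detected_entities)  # e.g. {"GIVENNAME": ["Jane"], "SURNAME": ["Doe"], ...}
print(response.anonymized_text)    # e.g. "Contact [GIVENNAME] [SURNAME] at [EMAIL]"
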
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py CHANGED
@@ -1,12 +1,18 @@
1
- from typing import List, Dict, Optional, ClassVar, Any
2
- import weave
3
- from pydantic import BaseModel
4
 
5
- from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, Pattern, PatternRecognizer

6
  from presidio_anonymizer import AnonymizerEngine
 
7
 
8
  from ..base import Guardrail
9
 
 
10
  class PresidioEntityRecognitionResponse(BaseModel):
11
  contains_entities: bool
12
  detected_entities: Dict[str, List[str]]
@@ -17,6 +23,7 @@ class PresidioEntityRecognitionResponse(BaseModel):
17
  def safe(self) -> bool:
18
  return not self.contains_entities
19
 
 
20
  class PresidioEntityRecognitionSimpleResponse(BaseModel):
21
  contains_entities: bool
22
  explanation: str
@@ -26,21 +33,24 @@ class PresidioEntityRecognitionSimpleResponse(BaseModel):
26
  def safe(self) -> bool:
27
  return not self.contains_entities
28
 
29
- #TODO: Add support for transformers workflow and not just Spacy
 
30
  class PresidioEntityRecognitionGuardrail(Guardrail):
31
  @staticmethod
32
  def get_available_entities() -> List[str]:
33
  registry = RecognizerRegistry()
34
  analyzer = AnalyzerEngine(registry=registry)
35
- return [recognizer.supported_entities[0]
36
- for recognizer in analyzer.registry.recognizers]
37
-
 
 
38
  analyzer: AnalyzerEngine
39
  anonymizer: AnonymizerEngine
40
  selected_entities: List[str]
41
  should_anonymize: bool
42
  language: str
43
-
44
  def __init__(
45
  self,
46
  selected_entities: Optional[List[str]] = None,
@@ -49,7 +59,7 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
49
  deny_lists: Optional[Dict[str, List[str]]] = None,
50
  regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
51
  custom_recognizers: Optional[List[Any]] = None,
52
- show_available_entities: bool = False
53
  ):
54
  # If show_available_entities is True, print available entities
55
  if show_available_entities:
@@ -63,36 +73,37 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
63
  # Initialize default values to all available entities
64
  if selected_entities is None:
65
  selected_entities = self.get_available_entities()
66
-
67
  # Get available entities dynamically
68
  available_entities = self.get_available_entities()
69
-
70
  # Filter out invalid entities and warn user
71
  invalid_entities = [e for e in selected_entities if e not in available_entities]
72
  valid_entities = [e for e in selected_entities if e in available_entities]
73
-
74
  if invalid_entities:
75
- print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
 
 
76
  print(f"Continuing with valid entities: {valid_entities}")
77
  selected_entities = valid_entities
78
-
79
  # Initialize analyzer with default recognizers
80
  analyzer = AnalyzerEngine()
81
-
82
  # Add custom recognizers if provided
83
  if custom_recognizers:
84
  for recognizer in custom_recognizers:
85
  analyzer.registry.add_recognizer(recognizer)
86
-
87
  # Add deny list recognizers if provided
88
  if deny_lists:
89
  for entity_type, tokens in deny_lists.items():
90
  deny_list_recognizer = PatternRecognizer(
91
- supported_entity=entity_type,
92
- deny_list=tokens
93
  )
94
  analyzer.registry.add_recognizer(deny_list_recognizer)
95
-
96
  # Add regex pattern recognizers if provided
97
  if regex_patterns:
98
  for entity_type, patterns in regex_patterns.items():
@@ -100,89 +111,92 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
100
  Pattern(
101
  name=pattern.get("name", f"pattern_{i}"),
102
  regex=pattern["regex"],
103
- score=pattern.get("score", 0.5)
104
- ) for i, pattern in enumerate(patterns)
 
105
  ]
106
  regex_recognizer = PatternRecognizer(
107
- supported_entity=entity_type,
108
- patterns=presidio_patterns
109
  )
110
  analyzer.registry.add_recognizer(regex_recognizer)
111
-
112
  # Initialize Presidio engines
113
  anonymizer = AnonymizerEngine()
114
-
115
  # Call parent class constructor with all fields
116
  super().__init__(
117
  analyzer=analyzer,
118
  anonymizer=anonymizer,
119
  selected_entities=selected_entities,
120
  should_anonymize=should_anonymize,
121
- language=language
122
  )
123
 
124
  @weave.op()
125
- def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
 
 
126
  """
127
  Check if the input prompt contains any entities using Presidio.
128
-
129
  Args:
130
  prompt: The text to analyze
131
  return_detected_types: If True, returns detailed entity type information
132
  """
133
  # Analyze text for entities
134
  analyzer_results = self.analyzer.analyze(
135
- text=str(prompt),
136
- entities=self.selected_entities,
137
- language=self.language
138
  )
139
-
140
  # Group results by entity type
141
  detected_entities = {}
142
  for result in analyzer_results:
143
  entity_type = result.entity_type
144
- text_slice = prompt[result.start:result.end]
145
  if entity_type not in detected_entities:
146
  detected_entities[entity_type] = []
147
  detected_entities[entity_type].append(text_slice)
148
-
149
  # Create explanation
150
  explanation_parts = []
151
  if detected_entities:
152
  explanation_parts.append("Found the following entities in the text:")
153
  for entity_type, instances in detected_entities.items():
154
- explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
 
 
155
  else:
156
  explanation_parts.append("No entities detected in the text.")
157
-
158
  # Add information about what was checked
159
  explanation_parts.append("\nChecked for these entity types:")
160
  for entity in self.selected_entities:
161
  explanation_parts.append(f"- {entity}")
162
-
163
  # Anonymize if requested
164
  anonymized_text = None
165
  if self.should_anonymize and detected_entities:
166
  anonymized_result = self.anonymizer.anonymize(
167
- text=prompt,
168
- analyzer_results=analyzer_results
169
  )
170
  anonymized_text = anonymized_result.text
171
-
172
  if return_detected_types:
173
  return PresidioEntityRecognitionResponse(
174
  contains_entities=bool(detected_entities),
175
  detected_entities=detected_entities,
176
  explanation="\n".join(explanation_parts),
177
- anonymized_text=anonymized_text
178
  )
179
  else:
180
  return PresidioEntityRecognitionSimpleResponse(
181
  contains_entities=bool(detected_entities),
182
  explanation="\n".join(explanation_parts),
183
- anonymized_text=anonymized_text
184
  )
185
-
186
  @weave.op()
187
- def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
188
- return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
 
 
 
1
+ from typing import Any, Dict, List, Optional
 
 
2
 
3
+ import weave
4
+ from presidio_analyzer import (
5
+ AnalyzerEngine,
6
+ Pattern,
7
+ PatternRecognizer,
8
+ RecognizerRegistry,
9
+ )
10
  from presidio_anonymizer import AnonymizerEngine
11
+ from pydantic import BaseModel
12
 
13
  from ..base import Guardrail
14
 
15
+
16
  class PresidioEntityRecognitionResponse(BaseModel):
17
  contains_entities: bool
18
  detected_entities: Dict[str, List[str]]
 
23
  def safe(self) -> bool:
24
  return not self.contains_entities
25
 
26
+
27
  class PresidioEntityRecognitionSimpleResponse(BaseModel):
28
  contains_entities: bool
29
  explanation: str
 
33
  def safe(self) -> bool:
34
  return not self.contains_entities
35
 
36
+
37
+ # TODO: Add support for the transformers workflow, not just spaCy
38
  class PresidioEntityRecognitionGuardrail(Guardrail):
39
  @staticmethod
40
  def get_available_entities() -> List[str]:
41
  registry = RecognizerRegistry()
42
  analyzer = AnalyzerEngine(registry=registry)
43
+ return [
44
+ recognizer.supported_entities[0]
45
+ for recognizer in analyzer.registry.recognizers
46
+ ]
47
+
48
  analyzer: AnalyzerEngine
49
  anonymizer: AnonymizerEngine
50
  selected_entities: List[str]
51
  should_anonymize: bool
52
  language: str
53
+
54
  def __init__(
55
  self,
56
  selected_entities: Optional[List[str]] = None,
 
59
  deny_lists: Optional[Dict[str, List[str]]] = None,
60
  regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
61
  custom_recognizers: Optional[List[Any]] = None,
62
+ show_available_entities: bool = False,
63
  ):
64
  # If show_available_entities is True, print available entities
65
  if show_available_entities:
 
73
  # Initialize default values to all available entities
74
  if selected_entities is None:
75
  selected_entities = self.get_available_entities()
76
+
77
  # Get available entities dynamically
78
  available_entities = self.get_available_entities()
79
+
80
  # Filter out invalid entities and warn user
81
  invalid_entities = [e for e in selected_entities if e not in available_entities]
82
  valid_entities = [e for e in selected_entities if e in available_entities]
83
+
84
  if invalid_entities:
85
+ print(
86
+ f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
87
+ )
88
  print(f"Continuing with valid entities: {valid_entities}")
89
  selected_entities = valid_entities
90
+
91
  # Initialize analyzer with default recognizers
92
  analyzer = AnalyzerEngine()
93
+
94
  # Add custom recognizers if provided
95
  if custom_recognizers:
96
  for recognizer in custom_recognizers:
97
  analyzer.registry.add_recognizer(recognizer)
98
+
99
  # Add deny list recognizers if provided
100
  if deny_lists:
101
  for entity_type, tokens in deny_lists.items():
102
  deny_list_recognizer = PatternRecognizer(
103
+ supported_entity=entity_type, deny_list=tokens
 
104
  )
105
  analyzer.registry.add_recognizer(deny_list_recognizer)
106
+
107
  # Add regex pattern recognizers if provided
108
  if regex_patterns:
109
  for entity_type, patterns in regex_patterns.items():
 
111
  Pattern(
112
  name=pattern.get("name", f"pattern_{i}"),
113
  regex=pattern["regex"],
114
+ score=pattern.get("score", 0.5),
115
+ )
116
+ for i, pattern in enumerate(patterns)
117
  ]
118
  regex_recognizer = PatternRecognizer(
119
+ supported_entity=entity_type, patterns=presidio_patterns
 
120
  )
121
  analyzer.registry.add_recognizer(regex_recognizer)
122
+
123
  # Initialize Presidio engines
124
  anonymizer = AnonymizerEngine()
125
+
126
  # Call parent class constructor with all fields
127
  super().__init__(
128
  analyzer=analyzer,
129
  anonymizer=anonymizer,
130
  selected_entities=selected_entities,
131
  should_anonymize=should_anonymize,
132
+ language=language,
133
  )
134
 
135
  @weave.op()
136
+ def guard(
137
+ self, prompt: str, return_detected_types: bool = True, **kwargs
138
+ ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
139
  """
140
  Check if the input prompt contains any entities using Presidio.
141
+
142
  Args:
143
  prompt: The text to analyze
144
  return_detected_types: If True, returns detailed entity type information
145
  """
146
  # Analyze text for entities
147
  analyzer_results = self.analyzer.analyze(
148
+ text=str(prompt), entities=self.selected_entities, language=self.language
 
 
149
  )
150
+
151
  # Group results by entity type
152
  detected_entities = {}
153
  for result in analyzer_results:
154
  entity_type = result.entity_type
155
+ text_slice = prompt[result.start : result.end]
156
  if entity_type not in detected_entities:
157
  detected_entities[entity_type] = []
158
  detected_entities[entity_type].append(text_slice)
159
+
160
  # Create explanation
161
  explanation_parts = []
162
  if detected_entities:
163
  explanation_parts.append("Found the following entities in the text:")
164
  for entity_type, instances in detected_entities.items():
165
+ explanation_parts.append(
166
+ f"- {entity_type}: {len(instances)} instance(s)"
167
+ )
168
  else:
169
  explanation_parts.append("No entities detected in the text.")
170
+
171
  # Add information about what was checked
172
  explanation_parts.append("\nChecked for these entity types:")
173
  for entity in self.selected_entities:
174
  explanation_parts.append(f"- {entity}")
175
+
176
  # Anonymize if requested
177
  anonymized_text = None
178
  if self.should_anonymize and detected_entities:
179
  anonymized_result = self.anonymizer.anonymize(
180
+ text=prompt, analyzer_results=analyzer_results
 
181
  )
182
  anonymized_text = anonymized_result.text
183
+
184
  if return_detected_types:
185
  return PresidioEntityRecognitionResponse(
186
  contains_entities=bool(detected_entities),
187
  detected_entities=detected_entities,
188
  explanation="\n".join(explanation_parts),
189
+ anonymized_text=anonymized_text,
190
  )
191
  else:
192
  return PresidioEntityRecognitionSimpleResponse(
193
  contains_entities=bool(detected_entities),
194
  explanation="\n".join(explanation_parts),
195
+ anonymized_text=anonymized_text,
196
  )
197
+
198
  @weave.op()
199
+ def predict(
200
+ self, prompt: str, return_detected_types: bool = True, **kwargs
201
+ ) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
202
+ return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
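
A construction sketch for the guardrail above. EMAIL_ADDRESS and PHONE_NUMBER are standard Presidio entity types; anything not reported by get_available_entities() is dropped with the warning printed in __init__:

guardrail = PresidioEntityRecognitionGuardrail(
    selected_entities=["EMAIL_ADDRESS", "PHONE_NUMBER"],
    should_anonymize=True,
)
response = guardrail.guard("Reach me at jane@example.com or 555-123-4567")
print(response.contains_entities)  # True if Presidio matched either entity
print(response.anonymized_text)    # entities replaced by the AnonymizerEngine

Note that deny_lists and regex_patterns recognizers are registered with the analyzer, but analyze() is called with entities=self.selected_entities, so a custom entity type only surfaces in results if it is also listed in selected_entities.
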
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py CHANGED
@@ -1,11 +1,11 @@
1
- from typing import Dict, Optional, ClassVar, List
 
2
 
3
  import weave
4
  from pydantic import BaseModel
5
 
6
  from ...regex_model import RegexModel
7
  from ..base import Guardrail
8
- import re
9
 
10
 
11
  class RegexEntityRecognitionResponse(BaseModel):
@@ -33,28 +33,34 @@ class RegexEntityRecognitionGuardrail(Guardrail):
33
  regex_model: RegexModel
34
  patterns: Dict[str, str] = {}
35
  should_anonymize: bool = False
36
-
37
  DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
38
- "EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
39
- "TELEPHONENUM": r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b',
40
- "SOCIALNUM": r'\b\d{3}[-]?\d{2}[-]?\d{4}\b',
41
- "CREDITCARDNUMBER": r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
42
- "DATEOFBIRTH": r'\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b',
43
- "DRIVERLICENSENUM": r'[A-Z]\d{7}', # Example pattern, adjust for your needs
44
- "ACCOUNTNUM": r'\b\d{10,12}\b', # Example pattern for bank accounts
45
- "ZIPCODE": r'\b\d{5}(?:-\d{4})?\b',
46
- "GIVENNAME": r'\b[A-Z][a-z]+\b', # Basic pattern for first names
47
- "SURNAME": r'\b[A-Z][a-z]+\b', # Basic pattern for last names
48
- "CITY": r'\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b',
49
- "STREET": r'\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b',
50
- "IDCARDNUM": r'[A-Z]\d{7,8}', # Generic pattern for ID cards
51
- "USERNAME": r'@[A-Za-z]\w{3,}', # Basic username pattern
52
- "PASSWORD": r'[A-Za-z0-9@#$%^&+=]{8,}', # Basic password pattern
53
- "TAXNUM": r'\b\d{2}[-]\d{7}\b', # Example tax number pattern
54
- "BUILDINGNUM": r'\b\d+[A-Za-z]?\b' # Basic building number pattern
55
  }
56
-
57
- def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, show_available_entities: bool = False, **kwargs):
 
 
 
 
 
 
58
  patterns = {}
59
  if use_defaults:
60
  patterns = self.DEFAULT_PATTERNS.copy()
@@ -63,15 +69,15 @@ class RegexEntityRecognitionGuardrail(Guardrail):
63
 
64
  if show_available_entities:
65
  self._print_available_entities(patterns.keys())
66
-
67
  # Create the RegexModel instance
68
  regex_model = RegexModel(patterns=patterns)
69
-
70
  # Initialize the base class with both the regex_model and patterns
71
  super().__init__(
72
- regex_model=regex_model,
73
  patterns=patterns,
74
- should_anonymize=should_anonymize
75
  )
76
 
77
  def text_to_pattern(self, text: str) -> str:
@@ -82,7 +88,7 @@ class RegexEntityRecognitionGuardrail(Guardrail):
82
  escaped_text = re.escape(text)
83
  # Create a pattern that matches the exact text, case-insensitive
84
  return rf"\b{escaped_text}\b"
85
-
86
  def _print_available_entities(self, entities: List[str]):
87
  """Print available entities"""
88
  print("\nAvailable entity types:")
@@ -92,16 +98,23 @@ class RegexEntityRecognitionGuardrail(Guardrail):
92
  print("=" * 25 + "\n")
93
 
94
  @weave.op()
95
- def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
 
 
 
 
 
 
 
96
  """
97
  Check if the input prompt contains any entities based on the regex patterns.
98
-
99
  Args:
100
  prompt: Input text to check for entities
101
- custom_terms: List of custom terms to be converted into regex patterns. If provided,
102
  only these terms will be checked, ignoring default patterns.
103
  return_detected_types: If True, returns detailed entity type information
104
-
105
  Returns:
106
  RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
107
  """
@@ -113,7 +126,7 @@ class RegexEntityRecognitionGuardrail(Guardrail):
113
  else:
114
  # Use the original regex_model if no custom terms provided
115
  result = self.regex_model.check(prompt)
116
-
117
  # Create detailed explanation
118
  explanation_parts = []
119
  if result.matched_patterns:
@@ -122,35 +135,50 @@ class RegexEntityRecognitionGuardrail(Guardrail):
122
  explanation_parts.append(f"- {entity_type}: {len(matches)} instance(s)")
123
  else:
124
  explanation_parts.append("No entities detected in the text.")
125
-
126
  if result.failed_patterns:
127
  explanation_parts.append("\nChecked but did not find these entity types:")
128
  for pattern in result.failed_patterns:
129
  explanation_parts.append(f"- {pattern}")
130
-
131
  # Updated anonymization logic
132
  anonymized_text = None
133
- if getattr(self, 'should_anonymize', False) and result.matched_patterns:
134
  anonymized_text = prompt
135
  for entity_type, matches in result.matched_patterns.items():
136
  for match in matches:
137
- replacement = "[redacted]" if aggregate_redaction else f"[{entity_type.upper()}]"
 
 
 
 
138
  anonymized_text = anonymized_text.replace(match, replacement)
139
-
140
  if return_detected_types:
141
  return RegexEntityRecognitionResponse(
142
  contains_entities=not result.passed,
143
  detected_entities=result.matched_patterns,
144
  explanation="\n".join(explanation_parts),
145
- anonymized_text=anonymized_text
146
  )
147
  else:
148
  return RegexEntityRecognitionSimpleResponse(
149
  contains_entities=not result.passed,
150
  explanation="\n".join(explanation_parts),
151
- anonymized_text=anonymized_text
152
  )
153
 
154
  @weave.op()
155
- def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
156
- return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)

1
+ import re
2
+ from typing import ClassVar, Dict, List, Optional
3
 
4
  import weave
5
  from pydantic import BaseModel
6
 
7
  from ...regex_model import RegexModel
8
  from ..base import Guardrail
 
9
 
10
 
11
  class RegexEntityRecognitionResponse(BaseModel):
 
33
  regex_model: RegexModel
34
  patterns: Dict[str, str] = {}
35
  should_anonymize: bool = False
36
+
37
  DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
38
+ "EMAIL": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
39
+ "TELEPHONENUM": r"\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b",
40
+ "SOCIALNUM": r"\b\d{3}[-]?\d{2}[-]?\d{4}\b",
41
+ "CREDITCARDNUMBER": r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b",
42
+ "DATEOFBIRTH": r"\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b",
43
+ "DRIVERLICENSENUM": r"[A-Z]\d{7}", # Example pattern, adjust for your needs
44
+ "ACCOUNTNUM": r"\b\d{10,12}\b", # Example pattern for bank accounts
45
+ "ZIPCODE": r"\b\d{5}(?:-\d{4})?\b",
46
+ "GIVENNAME": r"\b[A-Z][a-z]+\b", # Basic pattern for first names
47
+ "SURNAME": r"\b[A-Z][a-z]+\b", # Basic pattern for last names
48
+ "CITY": r"\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b",
49
+ "STREET": r"\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b",
50
+ "IDCARDNUM": r"[A-Z]\d{7,8}", # Generic pattern for ID cards
51
+ "USERNAME": r"@[A-Za-z]\w{3,}", # Basic username pattern
52
+ "PASSWORD": r"[A-Za-z0-9@#$%^&+=]{8,}", # Basic password pattern
53
+ "TAXNUM": r"\b\d{2}[-]\d{7}\b", # Example tax number pattern
54
+ "BUILDINGNUM": r"\b\d+[A-Za-z]?\b", # Basic building number pattern
55
  }
56
+
57
+ def __init__(
58
+ self,
59
+ use_defaults: bool = True,
60
+ should_anonymize: bool = False,
61
+ show_available_entities: bool = False,
62
+ **kwargs,
63
+ ):
64
  patterns = {}
65
  if use_defaults:
66
  patterns = self.DEFAULT_PATTERNS.copy()
 
69
 
70
  if show_available_entities:
71
  self._print_available_entities(patterns.keys())
72
+
73
  # Create the RegexModel instance
74
  regex_model = RegexModel(patterns=patterns)
75
+
76
  # Initialize the base class with both the regex_model and patterns
77
  super().__init__(
78
+ regex_model=regex_model,
79
  patterns=patterns,
80
+ should_anonymize=should_anonymize,
81
  )
82
 
83
  def text_to_pattern(self, text: str) -> str:
 
88
  escaped_text = re.escape(text)
89
  # Create a pattern that matches the exact text, case-insensitive
90
  return rf"\b{escaped_text}\b"
91
+
92
  def _print_available_entities(self, entities: List[str]):
93
  """Print available entities"""
94
  print("\nAvailable entity types:")
 
98
  print("=" * 25 + "\n")
99
 
100
  @weave.op()
101
+ def guard(
102
+ self,
103
+ prompt: str,
104
+ custom_terms: Optional[list[str]] = None,
105
+ return_detected_types: bool = True,
106
+ aggregate_redaction: bool = True,
107
+ **kwargs,
108
+ ) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
109
  """
110
  Check if the input prompt contains any entities based on the regex patterns.
111
+
112
  Args:
113
  prompt: Input text to check for entities
114
+ custom_terms: List of custom terms to be converted into regex patterns. If provided,
115
  only these terms will be checked, ignoring default patterns.
116
  return_detected_types: If True, returns detailed entity type information
117
+
118
  Returns:
119
  RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
120
  """
 
126
  else:
127
  # Use the original regex_model if no custom terms provided
128
  result = self.regex_model.check(prompt)
129
+
130
  # Create detailed explanation
131
  explanation_parts = []
132
  if result.matched_patterns:
 
135
  explanation_parts.append(f"- {entity_type}: {len(matches)} instance(s)")
136
  else:
137
  explanation_parts.append("No entities detected in the text.")
138
+
139
  if result.failed_patterns:
140
  explanation_parts.append("\nChecked but did not find these entity types:")
141
  for pattern in result.failed_patterns:
142
  explanation_parts.append(f"- {pattern}")
143
+
144
  # Updated anonymization logic
145
  anonymized_text = None
146
+ if getattr(self, "should_anonymize", False) and result.matched_patterns:
147
  anonymized_text = prompt
148
  for entity_type, matches in result.matched_patterns.items():
149
  for match in matches:
150
+ replacement = (
151
+ "[redacted]"
152
+ if aggregate_redaction
153
+ else f"[{entity_type.upper()}]"
154
+ )
155
  anonymized_text = anonymized_text.replace(match, replacement)
156
+
157
  if return_detected_types:
158
  return RegexEntityRecognitionResponse(
159
  contains_entities=not result.passed,
160
  detected_entities=result.matched_patterns,
161
  explanation="\n".join(explanation_parts),
162
+ anonymized_text=anonymized_text,
163
  )
164
  else:
165
  return RegexEntityRecognitionSimpleResponse(
166
  contains_entities=not result.passed,
167
  explanation="\n".join(explanation_parts),
168
+ anonymized_text=anonymized_text,
169
  )
170
 
171
  @weave.op()
172
+ def predict(
173
+ self,
174
+ prompt: str,
175
+ return_detected_types: bool = True,
176
+ aggregate_redaction: bool = True,
177
+ **kwargs,
178
+ ) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
179
+ return self.guard(
180
+ prompt,
181
+ return_detected_types=return_detected_types,
182
+ aggregate_redaction=aggregate_redaction,
183
+ **kwargs,
184
+ )
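
A usage sketch for the two anonymization modes and the custom_terms path above; the printed outputs are indicative, assuming the default patterns match as intended:

guardrail = RegexEntityRecognitionGuardrail(use_defaults=True, should_anonymize=True)

# Per-entity markers instead of the aggregated [redacted] token
response = guardrail.guard(
    "Call 555-123-4567 or write to jane@example.com",
    aggregate_redaction=False,
)
print(response.anonymized_text)  # e.g. "Call [TELEPHONENUM] or write to [EMAIL]"

# custom_terms builds exact-match patterns and ignores the default ones
response = guardrail.guard("Status update on Project Titan", custom_terms=["Project Titan"])
print(response.contains_entities)  # True if the term appears verbatim
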
guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py CHANGED
@@ -1,9 +1,11 @@
1
- from typing import List, Dict, Optional, ClassVar
2
- from transformers import pipeline, AutoConfig
3
- import json
4
  from pydantic import BaseModel
 
 
5
  from ..base import Guardrail
6
- import weave
7
 
8
  class TransformersEntityRecognitionResponse(BaseModel):
9
  contains_entities: bool
@@ -15,6 +17,7 @@ class TransformersEntityRecognitionResponse(BaseModel):
15
  def safe(self) -> bool:
16
  return not self.contains_entities
17
 
 
18
  class TransformersEntityRecognitionSimpleResponse(BaseModel):
19
  contains_entities: bool
20
  explanation: str
@@ -24,14 +27,15 @@ class TransformersEntityRecognitionSimpleResponse(BaseModel):
24
  def safe(self) -> bool:
25
  return not self.contains_entities
26
 
 
27
  class TransformersEntityRecognitionGuardrail(Guardrail):
28
  """Generic guardrail for detecting entities using any token classification model."""
29
-
30
  _pipeline: Optional[object] = None
31
  selected_entities: List[str]
32
  should_anonymize: bool
33
  available_entities: List[str]
34
-
35
  def __init__(
36
  self,
37
  model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
@@ -42,50 +46,52 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
42
  # Load model config and extract available entities
43
  config = AutoConfig.from_pretrained(model_name)
44
  entities = self._extract_entities_from_config(config)
45
-
46
  if show_available_entities:
47
  self._print_available_entities(entities)
48
-
49
  # Initialize default values if needed
50
  if selected_entities is None:
51
  selected_entities = entities # Use all available entities by default
52
-
53
  # Filter out invalid entities and warn user
54
  invalid_entities = [e for e in selected_entities if e not in entities]
55
  valid_entities = [e for e in selected_entities if e in entities]
56
-
57
  if invalid_entities:
58
- print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
 
 
59
  print(f"Continuing with valid entities: {valid_entities}")
60
  selected_entities = valid_entities
61
-
62
  # Call parent class constructor
63
  super().__init__(
64
  selected_entities=selected_entities,
65
  should_anonymize=should_anonymize,
66
- available_entities=entities
67
  )
68
-
69
  # Initialize pipeline
70
  self._pipeline = pipeline(
71
  task="token-classification",
72
  model=model_name,
73
- aggregation_strategy="simple" # Merge same entities
74
  )
75
 
76
  def _extract_entities_from_config(self, config) -> List[str]:
77
  """Extract unique entity types from the model config."""
78
  # Get id2label mapping from config
79
  id2label = config.id2label
80
-
81
  # Extract unique entity types (removing B- and I- prefixes)
82
  entities = set()
83
  for label in id2label.values():
84
- if label.startswith(('B-', 'I-')):
85
  entities.add(label[2:]) # Remove prefix
86
- elif label != 'O': # Skip the 'O' (Outside) label
87
  entities.add(label)
88
-
89
  return sorted(list(entities))
90
 
91
  def _print_available_entities(self, entities: List[str]):
@@ -103,48 +109,60 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
103
  def _detect_entities(self, text: str) -> Dict[str, List[str]]:
104
  """Detect entities in the text using the pipeline."""
105
  results = self._pipeline(text)
106
-
107
  # Group findings by entity type
108
  detected_entities = {}
109
  for entity in results:
110
- entity_type = entity['entity_group']
111
  if entity_type in self.selected_entities:
112
  if entity_type not in detected_entities:
113
  detected_entities[entity_type] = []
114
- detected_entities[entity_type].append(entity['word'])
115
-
116
  return detected_entities
117
 
118
  def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
119
  """Anonymize detected entities in text using the pipeline."""
120
  results = self._pipeline(text)
121
-
122
  # Sort entities by start position in reverse order to avoid offset issues
123
- entities = sorted(results, key=lambda x: x['start'], reverse=True)
124
-
125
  # Create a mutable list of characters
126
  chars = list(text)
127
-
128
  # Apply redactions
129
  for entity in entities:
130
- if entity['entity_group'] in self.selected_entities:
131
- start, end = entity['start'], entity['end']
132
- replacement = ' [redacted] ' if aggregate_redaction else f" [{entity['entity_group']}] "
133
-
 
 
 
 
134
  # Replace the entity with the redaction marker
135
  chars[start:end] = replacement
136
-
137
  # Join characters and clean up only consecutive spaces (preserving newlines)
138
- result = ''.join(chars)
139
  # Replace multiple spaces with single space, but preserve newlines
140
- lines = result.split('\n')
141
- cleaned_lines = [' '.join(line.split()) for line in lines]
142
- return '\n'.join(cleaned_lines)
143
 
144
  @weave.op()
145
- def guard(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
 
 
 
 
 
 
 
 
146
  """Check if the input prompt contains any entities using the transformer pipeline.
147
-
148
  Args:
149
  prompt: The text to analyze
150
  return_detected_types: If True, returns detailed entity type information
@@ -152,39 +170,55 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
152
  """
153
  # Detect entities
154
  detected_entities = self._detect_entities(prompt)
155
-
156
  # Create explanation
157
  explanation_parts = []
158
  if detected_entities:
159
  explanation_parts.append("Found the following entities in the text:")
160
  for entity_type, instances in detected_entities.items():
161
- explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
 
 
162
  else:
163
  explanation_parts.append("No entities detected in the text.")
164
-
165
  explanation_parts.append("\nChecked for these entities:")
166
  for entity in self.selected_entities:
167
  explanation_parts.append(f"- {entity}")
168
-
169
  # Anonymize if requested
170
  anonymized_text = None
171
  if self.should_anonymize and detected_entities:
172
  anonymized_text = self._anonymize_text(prompt, aggregate_redaction)
173
-
174
  if return_detected_types:
175
  return TransformersEntityRecognitionResponse(
176
  contains_entities=bool(detected_entities),
177
  detected_entities=detected_entities,
178
  explanation="\n".join(explanation_parts),
179
- anonymized_text=anonymized_text
180
  )
181
  else:
182
  return TransformersEntityRecognitionSimpleResponse(
183
  contains_entities=bool(detected_entities),
184
  explanation="\n".join(explanation_parts),
185
- anonymized_text=anonymized_text
186
  )
187
 
188
  @weave.op()
189
- def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
190
- return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)

1
+ from typing import Dict, List, Optional
2
+
3
+ import weave
4
  from pydantic import BaseModel
5
+ from transformers import AutoConfig, pipeline
6
+
7
  from ..base import Guardrail
8
+
9
 
10
  class TransformersEntityRecognitionResponse(BaseModel):
11
  contains_entities: bool
 
17
  def safe(self) -> bool:
18
  return not self.contains_entities
19
 
20
+
21
  class TransformersEntityRecognitionSimpleResponse(BaseModel):
22
  contains_entities: bool
23
  explanation: str
 
27
  def safe(self) -> bool:
28
  return not self.contains_entities
29
 
30
+
31
  class TransformersEntityRecognitionGuardrail(Guardrail):
32
  """Generic guardrail for detecting entities using any token classification model."""
33
+
34
  _pipeline: Optional[object] = None
35
  selected_entities: List[str]
36
  should_anonymize: bool
37
  available_entities: List[str]
38
+
39
  def __init__(
40
  self,
41
  model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
 
46
  # Load model config and extract available entities
47
  config = AutoConfig.from_pretrained(model_name)
48
  entities = self._extract_entities_from_config(config)
49
+
50
  if show_available_entities:
51
  self._print_available_entities(entities)
52
+
53
  # Initialize default values if needed
54
  if selected_entities is None:
55
  selected_entities = entities # Use all available entities by default
56
+
57
  # Filter out invalid entities and warn user
58
  invalid_entities = [e for e in selected_entities if e not in entities]
59
  valid_entities = [e for e in selected_entities if e in entities]
60
+
61
  if invalid_entities:
62
+ print(
63
+ f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}"
64
+ )
65
  print(f"Continuing with valid entities: {valid_entities}")
66
  selected_entities = valid_entities
67
+
68
  # Call parent class constructor
69
  super().__init__(
70
  selected_entities=selected_entities,
71
  should_anonymize=should_anonymize,
72
+ available_entities=entities,
73
  )
74
+
75
  # Initialize pipeline
76
  self._pipeline = pipeline(
77
  task="token-classification",
78
  model=model_name,
79
+ aggregation_strategy="simple", # Merge same entities
80
  )
81
 
82
  def _extract_entities_from_config(self, config) -> List[str]:
83
  """Extract unique entity types from the model config."""
84
  # Get id2label mapping from config
85
  id2label = config.id2label
86
+
87
  # Extract unique entity types (removing B- and I- prefixes)
88
  entities = set()
89
  for label in id2label.values():
90
+ if label.startswith(("B-", "I-")):
91
  entities.add(label[2:]) # Remove prefix
92
+ elif label != "O": # Skip the 'O' (Outside) label
93
  entities.add(label)
94
+
95
  return sorted(list(entities))
96
 
97
  def _print_available_entities(self, entities: List[str]):
 
109
  def _detect_entities(self, text: str) -> Dict[str, List[str]]:
110
  """Detect entities in the text using the pipeline."""
111
  results = self._pipeline(text)
112
+
113
  # Group findings by entity type
114
  detected_entities = {}
115
  for entity in results:
116
+ entity_type = entity["entity_group"]
117
  if entity_type in self.selected_entities:
118
  if entity_type not in detected_entities:
119
  detected_entities[entity_type] = []
120
+ detected_entities[entity_type].append(entity["word"])
121
+
122
  return detected_entities
123
 
124
  def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
125
  """Anonymize detected entities in text using the pipeline."""
126
  results = self._pipeline(text)
127
+
128
  # Sort entities by start position in reverse order to avoid offset issues
129
+ entities = sorted(results, key=lambda x: x["start"], reverse=True)
130
+
131
  # Create a mutable list of characters
132
  chars = list(text)
133
+
134
  # Apply redactions
135
  for entity in entities:
136
+ if entity["entity_group"] in self.selected_entities:
137
+ start, end = entity["start"], entity["end"]
138
+ replacement = (
139
+ " [redacted] "
140
+ if aggregate_redaction
141
+ else f" [{entity['entity_group']}] "
142
+ )
143
+
144
  # Replace the entity with the redaction marker
145
  chars[start:end] = replacement
146
+
147
  # Join characters and clean up only consecutive spaces (preserving newlines)
148
+ result = "".join(chars)
149
  # Replace multiple spaces with single space, but preserve newlines
150
+ lines = result.split("\n")
151
+ cleaned_lines = [" ".join(line.split()) for line in lines]
152
+ return "\n".join(cleaned_lines)
153
 
154
  @weave.op()
155
+ def guard(
156
+ self,
157
+ prompt: str,
158
+ return_detected_types: bool = True,
159
+ aggregate_redaction: bool = True,
160
+ ) -> (
161
+ TransformersEntityRecognitionResponse
162
+ | TransformersEntityRecognitionSimpleResponse
163
+ ):
164
  """Check if the input prompt contains any entities using the transformer pipeline.
165
+
166
  Args:
167
  prompt: The text to analyze
168
  return_detected_types: If True, returns detailed entity type information
 
170
  """
171
  # Detect entities
172
  detected_entities = self._detect_entities(prompt)
173
+
174
  # Create explanation
175
  explanation_parts = []
176
  if detected_entities:
177
  explanation_parts.append("Found the following entities in the text:")
178
  for entity_type, instances in detected_entities.items():
179
+ explanation_parts.append(
180
+ f"- {entity_type}: {len(instances)} instance(s)"
181
+ )
182
  else:
183
  explanation_parts.append("No entities detected in the text.")
184
+
185
  explanation_parts.append("\nChecked for these entities:")
186
  for entity in self.selected_entities:
187
  explanation_parts.append(f"- {entity}")
188
+
189
  # Anonymize if requested
190
  anonymized_text = None
191
  if self.should_anonymize and detected_entities:
192
  anonymized_text = self._anonymize_text(prompt, aggregate_redaction)
193
+
194
  if return_detected_types:
195
  return TransformersEntityRecognitionResponse(
196
  contains_entities=bool(detected_entities),
197
  detected_entities=detected_entities,
198
  explanation="\n".join(explanation_parts),
199
+ anonymized_text=anonymized_text,
200
  )
201
  else:
202
  return TransformersEntityRecognitionSimpleResponse(
203
  contains_entities=bool(detected_entities),
204
  explanation="\n".join(explanation_parts),
205
+ anonymized_text=anonymized_text,
206
  )
207
 
208
  @weave.op()
209
+ def predict(
210
+ self,
211
+ prompt: str,
212
+ return_detected_types: bool = True,
213
+ aggregate_redaction: bool = True,
214
+ **kwargs,
215
+ ) -> (
216
+ TransformersEntityRecognitionResponse
217
+ | TransformersEntityRecognitionSimpleResponse
218
+ ):
219
+ return self.guard(
220
+ prompt,
221
+ return_detected_types=return_detected_types,
222
+ aggregate_redaction=aggregate_redaction,
223
+ **kwargs,
224
+ )
guardrails_genie/guardrails/manager.py CHANGED
@@ -1,6 +1,6 @@
1
  import weave
2
- from rich.progress import track
3
  from pydantic import BaseModel
 
4
 
5
  from .base import Guardrail
6
 
 
1
  import weave
 
2
  from pydantic import BaseModel
3
+ from rich.progress import track
4
 
5
  from .base import Guardrail
6
 
guardrails_genie/metrics.py CHANGED
@@ -5,12 +5,55 @@ import weave
5
 
6
 
7
  class AccuracyMetric(weave.Scorer):
8
  @weave.op()
9
  def score(self, output: dict, label: int):
10
- return {"correct": bool(label) == output["safe"]}
11
 
12
  @weave.op()
13
  def summarize(self, score_rows: list) -> Optional[dict]:
14
  valid_data = [
15
  x.get("correct") for x in score_rows if x.get("correct") is not None
16
  ]
 
5
 
6
 
7
  class AccuracyMetric(weave.Scorer):
8
+ """
9
+ A class to compute and summarize accuracy-related metrics for model outputs.
10
+
11
+ This class extends the `weave.Scorer` and provides operations to score
12
+ individual predictions and summarize the results across multiple predictions.
13
+ It calculates the accuracy, precision, recall, and F1 score based on the
14
+ comparison between predicted outputs and true labels.
15
+ """
16
+
17
  @weave.op()
18
  def score(self, output: dict, label: int):
19
+ """
20
+ Evaluate the correctness of a single prediction.
21
+
22
+ This method compares a model's predicted output with the true label
23
+ to determine if the prediction is correct. It checks if the 'safe'
24
+ field in the output dictionary, when converted to an integer, matches
25
+ the provided label.
26
+
27
+ Args:
28
+ output (dict): A dictionary containing the model's prediction,
29
+ specifically the 'safe' key which holds the predicted value.
30
+ label (int): The true label against which the prediction is compared.
31
+
32
+ Returns:
33
+ dict: A dictionary with a single key 'correct', which is True if the
34
+ prediction matches the label, otherwise False.
35
+ """
36
+ return {"correct": label == int(output["safe"])}
37
 
38
  @weave.op()
39
  def summarize(self, score_rows: list) -> Optional[dict]:
40
+ """
41
+ Summarize the accuracy-related metrics from a list of prediction scores.
42
+
43
+ This method processes a list of score dictionaries, each containing a
44
+ 'correct' key indicating whether a prediction was correct. It calculates
45
+ several metrics: accuracy, precision, recall, and F1 score, based on the
46
+ number of true positives, false positives, and false negatives.
47
+
48
+ Args:
49
+ score_rows (list): A list of dictionaries, each with a 'correct' key
50
+ indicating the correctness of individual predictions.
51
+
52
+ Returns:
53
+ Optional[dict]: A dictionary containing the calculated metrics:
54
+ 'accuracy', 'precision', 'recall', and 'f1_score'. If no valid data
55
+ is present, all metrics default to 0.
56
+ """
57
  valid_data = [
58
  x.get("correct") for x in score_rows if x.get("correct") is not None
59
  ]
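The summarize docstring above names accuracy, precision, recall, and F1, but this hunk only shows the valid_data extraction. As a reference point, here is a hedged sketch of the standard formulas with zero-division guards; this is a hypothetical standalone helper, not the library's implementation:

from typing import Dict


def classification_summary(tp: int, fp: int, fn: int, tn: int) -> Dict[str, float]:
    # Standard metric definitions; each ratio is guarded against a zero denominator.
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1}


# Worked example: 8 TP, 1 FP, 1 FN, 10 TN
# -> accuracy 0.90, precision ~0.889, recall ~0.889, f1 ~0.889
print(classification_summary(tp=8, fp=1, fn=1, tn=10))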
guardrails_genie/regex_model.py CHANGED
@@ -1,5 +1,6 @@
1
- from typing import List, Dict, Optional
2
  import re
 
 
3
  import weave
4
  from pydantic import BaseModel
5
 
@@ -16,7 +17,7 @@ class RegexModel(weave.Model):
16
  def __init__(self, patterns: Dict[str, str]) -> None:
17
  """
18
  Initialize RegexModel with a dictionary of patterns.
19
-
20
  Args:
21
  patterns: Dictionary where key is pattern name and value is regex pattern
22
  Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
@@ -31,35 +32,37 @@ class RegexModel(weave.Model):
31
  def check(self, prompt: str) -> RegexResult:
32
  """
33
  Check text against all patterns and return detailed results.
34
-
35
  Args:
36
  text: Input text to check against patterns
37
-
38
  Returns:
39
  RegexResult containing pass/fail status and details about matches
40
  """
41
  matched_patterns = {}
42
  failed_patterns = []
43
-
44
  for pattern_name, pattern in self.patterns.items():
45
  matches = []
46
  for match in re.finditer(pattern, prompt):
47
  if match.groups():
48
  # If there are capture groups, join them with a separator
49
- matches.append('-'.join(str(g) for g in match.groups() if g is not None))
 
 
50
  else:
51
  # If no capture groups, use the full match
52
  matches.append(match.group(0))
53
-
54
  if matches:
55
  matched_patterns[pattern_name] = matches
56
  else:
57
  failed_patterns.append(pattern_name)
58
-
59
  return RegexResult(
60
  matched_patterns=matched_patterns,
61
  failed_patterns=failed_patterns,
62
- passed=len(matched_patterns) == 0
63
  )
64
 
65
  @weave.op()
@@ -67,4 +70,4 @@ class RegexModel(weave.Model):
67
  """
68
  Alias for check() to maintain consistency with other models.
69
  """
70
- return self.check(text)
 
 
1
  import re
2
+ from typing import Dict, List
3
+
4
  import weave
5
  from pydantic import BaseModel
6
 
 
17
  def __init__(self, patterns: Dict[str, str]) -> None:
18
  """
19
  Initialize RegexModel with a dictionary of patterns.
20
+
21
  Args:
22
  patterns: Dictionary where key is pattern name and value is regex pattern
23
  Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
 
32
  def check(self, prompt: str) -> RegexResult:
33
  """
34
  Check text against all patterns and return detailed results.
35
+
36
  Args:
37
  text: Input text to check against patterns
38
+
39
  Returns:
40
  RegexResult containing pass/fail status and details about matches
41
  """
42
  matched_patterns = {}
43
  failed_patterns = []
44
+
45
  for pattern_name, pattern in self.patterns.items():
46
  matches = []
47
  for match in re.finditer(pattern, prompt):
48
  if match.groups():
49
  # If there are capture groups, join them with a separator
50
+ matches.append(
51
+ "-".join(str(g) for g in match.groups() if g is not None)
52
+ )
53
  else:
54
  # If no capture groups, use the full match
55
  matches.append(match.group(0))
56
+
57
  if matches:
58
  matched_patterns[pattern_name] = matches
59
  else:
60
  failed_patterns.append(pattern_name)
61
+
62
  return RegexResult(
63
  matched_patterns=matched_patterns,
64
  failed_patterns=failed_patterns,
65
+ passed=len(matched_patterns) == 0,
66
  )
67
 
68
  @weave.op()
 
70
  """
71
  Alias for check() to maintain consistency with other models.
72
  """
73
+ return self.check(text)
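A short usage sketch of the cleaned-up RegexModel; the email pattern is the one from the class docstring, and the prompt is illustrative (weave logging is optional for direct calls):

from guardrails_genie.regex_model import RegexModel

model = RegexModel(
    patterns={"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+"}
)
result = model.check("Reach me at jane@example.com")

print(result.passed)            # False: at least one pattern matched
print(result.matched_patterns)  # {"email": ["jane@example.com"]}
print(result.failed_patterns)   # []: every pattern found a match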
guardrails_genie/utils.py CHANGED
@@ -19,9 +19,9 @@ class EvaluationCallManager:
19
  """
20
  Manages the evaluation calls for a specific project and entity in Weave.
21
 
22
- This class is responsible for initializing and managing evaluation calls associated with a
23
- specific project and entity. It provides functionality to collect guardrail guard calls
24
- from evaluation predictions and scores, and render these calls into a structured format
25
  suitable for display in Streamlit.
26
 
27
  Args:
@@ -30,6 +30,7 @@ class EvaluationCallManager:
30
  call_id (str): The call id.
31
  max_count (int): The maximum number of guardrail guard calls to collect from the evaluation.
32
  """
 
33
  def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
34
  self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
35
  self.max_count = max_count
@@ -40,10 +41,10 @@ class EvaluationCallManager:
40
  """
41
  Collects guardrail guard calls from evaluation predictions and scores.
42
 
43
- This function iterates through the children calls of the base evaluation call,
44
- extracting relevant guardrail guard calls and their associated scores. It stops
45
- collecting calls if it encounters an "Evaluation.summarize" operation or if the
46
- maximum count of guardrail guard calls is reached. The collected calls are stored
47
  in a list of dictionaries, each containing the input prompt, outputs, and score.
48
 
49
  Returns:
@@ -77,9 +78,9 @@ class EvaluationCallManager:
77
  Renders the collected guardrail guard calls into a pandas DataFrame suitable for
78
  display in Streamlit.
79
 
80
- This function processes the collected guardrail guard calls stored in `self.call_list` and
81
- organizes them into a dictionary format that can be easily converted into a pandas DataFrame.
82
- The DataFrame contains columns for the input prompts, the safety status of the outputs, and
83
  the correctness of the predictions for each guardrail.
84
 
85
  The structure of the DataFrame is as follows:
@@ -87,7 +88,7 @@ class EvaluationCallManager:
87
  - Subsequent columns contain the safety status and prediction correctness for each guardrail.
88
 
89
  Returns:
90
- pd.DataFrame: A DataFrame containing the input prompts, safety status, and prediction
91
  correctness for each guardrail.
92
  """
93
  dataframe = {
 
19
  """
20
  Manages the evaluation calls for a specific project and entity in Weave.
21
 
22
+ This class is responsible for initializing and managing evaluation calls associated with a
23
+ specific project and entity. It provides functionality to collect guardrail guard calls
24
+ from evaluation predictions and scores, and render these calls into a structured format
25
  suitable for display in Streamlit.
26
 
27
  Args:
 
30
  call_id (str): The call id.
31
  max_count (int): The maximum number of guardrail guard calls to collect from the evaluation.
32
  """
33
+
34
  def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
35
  self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
36
  self.max_count = max_count
 
41
  """
42
  Collects guardrail guard calls from evaluation predictions and scores.
43
 
44
+ This function iterates through the children calls of the base evaluation call,
45
+ extracting relevant guardrail guard calls and their associated scores. It stops
46
+ collecting calls if it encounters an "Evaluation.summarize" operation or if the
47
+ maximum count of guardrail guard calls is reached. The collected calls are stored
48
  in a list of dictionaries, each containing the input prompt, outputs, and score.
49
 
50
  Returns:
 
78
  Renders the collected guardrail guard calls into a pandas DataFrame suitable for
79
  display in Streamlit.
80
 
81
+ This function processes the collected guardrail guard calls stored in `self.call_list` and
82
+ organizes them into a dictionary format that can be easily converted into a pandas DataFrame.
83
+ The DataFrame contains columns for the input prompts, the safety status of the outputs, and
84
  the correctness of the predictions for each guardrail.
85
 
86
  The structure of the DataFrame is as follows:
 
88
  - Subsequent columns contain the safety status and prediction correctness for each guardrail.
89
 
90
  Returns:
91
+ pd.DataFrame: A DataFrame containing the input prompts, safety status, and prediction
92
  correctness for each guardrail.
93
  """
94
  dataframe = {