param-bharat committed on
Commit
46466a5
·
1 Parent(s): b2417a0

feat: upgrade regex detection model to multi-pattern detection

Browse files
Files changed (1) hide show
  1. guardrails_genie/regex_model.py +37 -26
guardrails_genie/regex_model.py CHANGED
@@ -1,65 +1,76 @@
1
- from typing import List, Dict, Optional
2
- import re
 
3
  import weave
4
  from pydantic import BaseModel
5
 
6
 
7
  class RegexResult(BaseModel):
8
  passed: bool
9
- matched_patterns: Dict[str, List[str]]
10
- failed_patterns: List[str]
11
 
12
 
13
  class RegexModel(weave.Model):
14
- patterns: Dict[str, str]
15
 
16
- def __init__(self, patterns: Dict[str, str]) -> None:
 
 
17
  """
18
  Initialize RegexModel with a dictionary of patterns.
19
-
20
  Args:
21
- patterns: Dictionary where key is pattern name and value is regex pattern
22
  Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
23
- "phone": r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"}
 
24
  """
25
  super().__init__(patterns=patterns)
 
 
 
26
  self._compiled_patterns = {
27
- name: re.compile(pattern) for name, pattern in patterns.items()
 
28
  }
29
 
30
  @weave.op()
31
- def check(self, prompt: str) -> RegexResult:
32
  """
33
  Check text against all patterns and return detailed results.
34
-
35
  Args:
36
  text: Input text to check against patterns
37
-
38
  Returns:
39
  RegexResult containing pass/fail status and details about matches
40
  """
41
  matched_patterns = {}
42
  failed_patterns = []
43
-
44
- for pattern_name, pattern in self.patterns.items():
45
  matches = []
46
- for match in re.finditer(pattern, prompt):
47
- if match.groups():
48
- # If there are capture groups, join them with a separator
49
- matches.append('-'.join(str(g) for g in match.groups() if g is not None))
50
- else:
51
- # If no capture groups, use the full match
52
- matches.append(match.group(0))
53
-
 
 
 
54
  if matches:
55
  matched_patterns[pattern_name] = matches
56
  else:
57
  failed_patterns.append(pattern_name)
58
-
59
  return RegexResult(
60
  matched_patterns=matched_patterns,
61
  failed_patterns=failed_patterns,
62
- passed=len(matched_patterns) == 0
63
  )
64
 
65
  @weave.op()
@@ -67,4 +78,4 @@ class RegexModel(weave.Model):
67
  """
68
  Alias for check() to maintain consistency with other models.
69
  """
70
- return self.check(text)
 
1
+ from typing import Union, Optional
2
+
3
+ import regex as re
4
  import weave
5
  from pydantic import BaseModel
6
 
7
 
8
  class RegexResult(BaseModel):
9
  passed: bool
10
+ matched_patterns: dict[str, list[str]]
11
+ failed_patterns: list[str]
12
 
13
 
14
  class RegexModel(weave.Model):
15
+ patterns: Optional[Union[dict[str, str], dict[str, list[str]]]] = None
16
 
17
+ def __init__(
18
+ self, patterns: Optional[Union[dict[str, str], dict[str, list[str]]]] = None
19
+ ) -> None:
20
  """
21
  Initialize RegexModel with a dictionary of patterns.
22
+
23
  Args:
24
+ patterns: Dictionary where key is pattern name and value is regex pattern or a list of patterns.
25
  Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
26
+ "phone": [r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"]}
27
+
28
  """
29
  super().__init__(patterns=patterns)
30
+ normalized_patterns = {}
31
+ for k, v in patterns.items():
32
+ normalized_patterns[k] = v if isinstance(v, list) else [v]
33
  self._compiled_patterns = {
34
+ name: [re.compile(p) for p in pattern]
35
+ for name, pattern in normalized_patterns.items()
36
  }
37
 
38
  @weave.op()
39
+ def check(self, text: str) -> RegexResult:
40
  """
41
  Check text against all patterns and return detailed results.
42
+
43
  Args:
44
  text: Input text to check against patterns
45
+
46
  Returns:
47
  RegexResult containing pass/fail status and details about matches
48
  """
49
  matched_patterns = {}
50
  failed_patterns = []
51
+
52
+ for pattern_name, pats in self._compiled_patterns.items():
53
  matches = []
54
+ for pattern in pats:
55
+ for match in pattern.finditer(text):
56
+ if match.groups():
57
+ # If there are capture groups, join them with a separator
58
+ matches.append(
59
+ "-".join(str(g) for g in match.groups() if g is not None)
60
+ )
61
+ else:
62
+ # If no capture groups, use the full match
63
+ matches.append(match.group(0))
64
+
65
  if matches:
66
  matched_patterns[pattern_name] = matches
67
  else:
68
  failed_patterns.append(pattern_name)
69
+
70
  return RegexResult(
71
  matched_patterns=matched_patterns,
72
  failed_patterns=failed_patterns,
73
+ passed=len(matched_patterns) == 0,
74
  )
75
 
76
  @weave.op()
 
78
  """
79
  Alias for check() to maintain consistency with other models.
80
  """
81
+ return self.check(text)