geekyrakshit commited on
Commit
306b50d
·
1 Parent(s): 98a3259

add: AccuracyMetric

Browse files
guardrails_genie/guardrails/injection/protectai_guardrail.py CHANGED
@@ -26,9 +26,9 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
26
 
27
  @weave.op()
28
  def predict(self, prompt: str):
29
- return self._classifier(prompt)
 
30
 
31
  @weave.op()
32
  def guard(self, prompt: str):
33
- response = self.predict(prompt)
34
- return {"safe": response[0]["label"] != "INJECTION"}
 
26
 
27
  @weave.op()
28
  def predict(self, prompt: str):
29
+ response = weave.op()(self._classifier)(prompt)
30
+ return {"safe": response[0]["label"] != "INJECTION"}
31
 
32
  @weave.op()
33
  def guard(self, prompt: str):
34
+ return self.predict(prompt)
 
guardrails_genie/guardrails/injection/survey_guardrail.py CHANGED
@@ -69,9 +69,9 @@ Here are some strict instructions that you must follow:
69
  response_format=SurveyGuardrailResponse,
70
  **kwargs,
71
  )
72
- return chat_completion.choices[0].message.parsed
 
73
 
74
  @weave.op()
75
  def guard(self, prompt: str, **kwargs) -> list[str]:
76
- response = self.predict(prompt, **kwargs)
77
- return {"safe": not response.injection_prompt}
 
69
  response_format=SurveyGuardrailResponse,
70
  **kwargs,
71
  )
72
+ response = chat_completion.choices[0].message.parsed
73
+ return {"safe": not response.injection_prompt}
74
 
75
  @weave.op()
76
  def guard(self, prompt: str, **kwargs) -> list[str]:
77
+ return self.predict(prompt, **kwargs)
 
guardrails_genie/metrics.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import weave
5
+
6
+
7
+ class AccuracyMetric(weave.Scorer):
8
+ @weave.op()
9
+ def score(self, output: dict, label: int):
10
+ return {"correct": bool(label) == output["safe"]}
11
+
12
+ @weave.op()
13
+ def summarize(self, score_rows: list) -> Optional[dict]:
14
+ valid_data = [
15
+ x.get("correct") for x in score_rows if x.get("correct") is not None
16
+ ]
17
+ count_true = list(valid_data).count(True)
18
+ int_data = [int(x) for x in valid_data]
19
+
20
+ sample_mean = np.mean(int_data) if int_data else 0
21
+ sample_variance = np.var(int_data) if int_data else 0
22
+ sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
23
+
24
+ # Calculate precision, recall, and F1 score
25
+ true_positives = count_true
26
+ false_positives = len(valid_data) - count_true
27
+ false_negatives = len(score_rows) - len(valid_data)
28
+
29
+ precision = (
30
+ true_positives / (true_positives + false_positives)
31
+ if (true_positives + false_positives) > 0
32
+ else 0
33
+ )
34
+ recall = (
35
+ true_positives / (true_positives + false_negatives)
36
+ if (true_positives + false_negatives) > 0
37
+ else 0
38
+ )
39
+ f1_score = (
40
+ (2 * precision * recall) / (precision + recall)
41
+ if (precision + recall) > 0
42
+ else 0
43
+ )
44
+
45
+ return {
46
+ "correct": {
47
+ "true_count": count_true,
48
+ "false_count": len(score_rows) - count_true,
49
+ "true_fraction": float(sample_mean),
50
+ "false_fraction": 1.0 - float(sample_mean),
51
+ "stderr": float(sample_error),
52
+ "precision": precision,
53
+ "recall": recall,
54
+ "f1_score": f1_score,
55
+ }
56
+ }