geekyrakshit commited on
Commit
1d117f2
·
1 Parent(s): 3ed3941

add: off-the-shelf prompt injection guardrail

Browse files
guardrails_genie/guardrails/__init__.py CHANGED
@@ -1,4 +1,8 @@
1
- from .injection import PromptInjectionSurveyGuardrail
2
  from .manager import GuardrailManager
3
 
4
- __all__ = ["PromptInjectionSurveyGuardrail", "GuardrailManager"]
 
 
 
 
 
1
+ from .injection import PromptInjectionProtectAIGuardrail, PromptInjectionSurveyGuardrail
2
  from .manager import GuardrailManager
3
 
4
+ __all__ = [
5
+ "PromptInjectionSurveyGuardrail",
6
+ "PromptInjectionProtectAIGuardrail",
7
+ "GuardrailManager",
8
+ ]
guardrails_genie/guardrails/injection/__init__.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from .survey_guardrail import PromptInjectionSurveyGuardrail
2
 
3
- __all__ = ["PromptInjectionSurveyGuardrail"]
 
1
+ from .protectai_guardrail import PromptInjectionProtectAIGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
+ __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionProtectAIGuardrail"]
guardrails_genie/guardrails/injection/protectai_guardrail.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import weave
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
+ from transformers.pipelines.base import Pipeline
7
+
8
+ from ..base import Guardrail
9
+
10
+
11
+ class PromptInjectionProtectAIGuardrail(Guardrail):
12
+ model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
13
+ _classifier: Optional[Pipeline] = None
14
+
15
+ def model_post_init(self, __context):
16
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
18
+ self._classifier = pipeline(
19
+ "text-classification",
20
+ model=model,
21
+ tokenizer=tokenizer,
22
+ truncation=True,
23
+ max_length=512,
24
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
25
+ )
26
+
27
+ @weave.op()
28
+ def predict(self, prompt: str):
29
+ return self._classifier(prompt)
30
+
31
+ @weave.op()
32
+ def guard(self, prompt: str):
33
+ response = self.predict(prompt)
34
+ return {"safe": response[0]["label"] != "INJECTION"}