geekyrakshit's picture
add: off-the-shelf prompt injection guardrail
1d117f2
raw
history blame
1.1 kB
from typing import Optional
import torch
import weave
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers.pipelines.base import Pipeline
from ..base import Guardrail
class PromptInjectionProtectAIGuardrail(Guardrail):
    """Guardrail that screens prompts with a Hugging Face text classifier.

    The tokenizer and sequence-classification model named by ``model_name``
    are loaded once and wrapped in a ``text-classification`` pipeline.
    ``predict`` returns the raw pipeline output; ``guard`` reduces it to a
    single ``{"safe": bool}`` verdict.
    """

    # Hugging Face model identifier handed to both `from_pretrained` calls.
    model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
    # Classification pipeline; stays None until `model_post_init` builds it.
    _classifier: Optional[Pipeline] = None

    def model_post_init(self, __context):
        """Load the tokenizer and model and assemble the classifier pipeline.

        The pipeline is created with truncation enabled and ``max_length=512``
        and is placed on CUDA when ``torch.cuda.is_available()`` reports a
        device, otherwise on the CPU.

        NOTE(review): presumably called by the pydantic-style base class once
        the fields are populated -- confirm against ``Guardrail``.

        Args:
            __context: Accepted for the hook signature; not used here.
        """
        hf_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        hf_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name
        )

        if torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"

        self._classifier = pipeline(
            "text-classification",
            model=hf_model,
            tokenizer=hf_tokenizer,
            truncation=True,
            max_length=512,
            device=torch.device(device_name),
        )

    @weave.op()
    def predict(self, prompt: str):
        """Run the classifier on ``prompt`` and return its unmodified output.

        Args:
            prompt: Text to classify.

        Returns:
            Whatever the underlying pipeline returns for ``prompt``.
        """
        classifier = self._classifier
        return classifier(prompt)

    @weave.op()
    def guard(self, prompt: str):
        """Decide whether ``prompt`` is safe according to the classifier.

        Only the first entry of the prediction is inspected; the prompt is
        reported as unsafe exactly when that entry's label is "INJECTION".

        Args:
            prompt: Text to check.

        Returns:
            A dict with a single boolean key, ``"safe"``.
        """
        top_prediction = self.predict(prompt)[0]
        flagged_as_injection = top_prediction["label"] == "INJECTION"
        return {"safe": not flagged_as_injection}