geekyrakshit committed on
Commit
b207b4c
·
1 Parent(s): 096a26c

add: docs for prompt injection guardrails

Browse files
docs/guardrails/base.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Guardrail Base Class
2
+
3
+ ::: guardrails_genie.guardrails.base
docs/guardrails/manager.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Guardrail Manager
2
+
3
+ ::: guardrails_genie.guardrails.manager
docs/guardrails/prompt_injection/classifier.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Prompt Injection Classifier Guardrail
2
+
3
+ ::: guardrails_genie.guardrails.injection.classifier_guardrail
docs/guardrails/prompt_injection/llm_survey.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Survey Guardrail
2
+
3
+ ::: guardrails_genie.guardrails.injection.survey_guardrail
guardrails_genie/guardrails/base.py CHANGED
@@ -4,6 +4,25 @@ import weave
4
 
5
 
6
  class Guardrail(weave.Model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def __init__(self, *args, **kwargs):
8
  super().__init__(*args, **kwargs)
9
 
 
4
 
5
 
6
  class Guardrail(weave.Model):
7
+ """
8
+ The Guardrail class is an abstract base class that extends the weave.Model.
9
+
10
+ This class is designed to provide a framework for implementing guardrails
11
+ in the form of the `guard` method. The `guard` method is an abstract method
12
+ that must be implemented by any subclass. It takes a prompt string and
13
+ additional keyword arguments, and returns a list of strings. The specific
14
+ implementation of the `guard` method will define the behavior of the guardrail.
15
+
16
+ Attributes:
17
+ None
18
+
19
+ Methods:
20
+ guard(prompt: str, **kwargs) -> list[str]:
21
+ Abstract method that must be implemented by subclasses. It takes a
22
+ prompt string and additional keyword arguments, and returns a list
23
+ of strings.
24
+ """
25
+
26
  def __init__(self, *args, **kwargs):
27
  super().__init__(*args, **kwargs)
28
 
guardrails_genie/guardrails/injection/classifier_guardrail.py CHANGED
@@ -11,6 +11,15 @@ from ..base import Guardrail
11
 
12
 
13
  class PromptInjectionClassifierGuardrail(Guardrail):
 
 
 
 
 
 
 
 
 
14
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
15
  _classifier: Optional[Pipeline] = None
16
 
@@ -39,6 +48,24 @@ class PromptInjectionClassifierGuardrail(Guardrail):
39
 
40
  @weave.op()
41
  def guard(self, prompt: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  response = self.classify(prompt)
43
  confidence_percentage = round(response[0]["score"] * 100, 2)
44
  return {
 
11
 
12
 
13
  class PromptInjectionClassifierGuardrail(Guardrail):
14
+ """
15
+ A guardrail that uses a pre-trained text-classification model to classify prompts
16
+ for potential injection attacks.
17
+
18
+ Args:
19
+ model_name (str): The name of the HuggingFace model or a WandB
20
+ checkpoint artifact path to use for classification.
21
+ """
22
+
23
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
24
  _classifier: Optional[Pipeline] = None
25
 
 
48
 
49
  @weave.op()
50
  def guard(self, prompt: str):
51
+ """
52
+ Analyzes the given prompt to determine if it is safe or potentially an injection attack.
53
+
54
+ This function uses a pre-trained text-classification model to classify the prompt.
55
+ It calls the `classify` method to get the classification result, which includes a label
56
+ and a confidence score. The function then calculates the confidence percentage and
57
+ returns a dictionary with two keys:
58
+
59
+ - "safe": A boolean indicating whether the prompt is safe (True) or an injection (False).
60
+ - "summary": A string summarizing the classification result, including the label and the
61
+ confidence percentage.
62
+
63
+ Args:
64
+ prompt (str): The input prompt to be classified.
65
+
66
+ Returns:
67
+ dict: A dictionary containing the safety status and a summary of the classification result.
68
+ """
69
  response = self.classify(prompt)
70
  confidence_percentage = round(response[0]["score"] * 100, 2)
71
  return {
guardrails_genie/guardrails/injection/survey_guardrail.py CHANGED
@@ -16,10 +16,32 @@ class SurveyGuardrailResponse(BaseModel):
16
 
17
 
18
  class PromptInjectionSurveyGuardrail(Guardrail):
 
 
 
 
 
 
 
 
 
19
  llm_model: OpenAIModel
20
 
21
  @weave.op()
22
  def load_prompt_injection_survey(self) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  prompt_injection_survey_path = os.path.join(
24
  os.getcwd(), "prompts", "injection_paper_1.md"
25
  )
@@ -30,6 +52,30 @@ class PromptInjectionSurveyGuardrail(Guardrail):
30
 
31
  @weave.op()
32
  def format_prompts(self, prompt: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  markdown_text = self.load_prompt_injection_survey()
34
  user_prompt = f"""You are given the following research papers as reference:\n\n{markdown_text}"""
35
  user_prompt += f"""
@@ -62,6 +108,21 @@ Here are some strict instructions that you must follow:
62
 
63
  @weave.op()
64
  def predict(self, prompt: str, **kwargs) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  user_prompt, system_prompt = self.format_prompts(prompt)
66
  chat_completion = self.llm_model.predict(
67
  user_prompts=user_prompt,
@@ -74,6 +135,22 @@ Here are some strict instructions that you must follow:
74
 
75
  @weave.op()
76
  def guard(self, prompt: str, **kwargs) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  response = self.predict(prompt, **kwargs)
78
  summary = (
79
  f"Prompt is deemed safe. {response.explanation}"
 
16
 
17
 
18
  class PromptInjectionSurveyGuardrail(Guardrail):
19
+ """
20
+ A guardrail that uses a summarized version of the research paper
21
+ [An Early Categorization of Prompt Injection Attacks on Large Language Models](https://arxiv.org/abs/2402.00898)
22
+ to assess whether a prompt is a prompt injection attack or not.
23
+
24
+ Args:
25
+ llm_model (OpenAIModel): The LLM model to use for the guardrail.
26
+ """
27
+
28
  llm_model: OpenAIModel
29
 
30
  @weave.op()
31
  def load_prompt_injection_survey(self) -> str:
32
+ """
33
+ Loads the prompt injection survey content from a markdown file, wraps it in
34
+ `<research_paper>...</research_paper>` tags, and returns it as a string.
35
+
36
+ This function constructs the file path to the markdown file containing the
37
+ summarized research paper on prompt injection attacks. It reads the content
38
+ of the file, wraps it in <research_paper> tags, and returns the formatted
39
+ string. This formatted content is used as a reference in the prompt
40
+ assessment process.
41
+
42
+ Returns:
43
+ str: The content of the prompt injection survey wrapped in <research_paper> tags.
44
+ """
45
  prompt_injection_survey_path = os.path.join(
46
  os.getcwd(), "prompts", "injection_paper_1.md"
47
  )
 
52
 
53
  @weave.op()
54
  def format_prompts(self, prompt: str) -> str:
55
+ """
56
+ Formats the user and system prompts for assessing potential prompt injection attacks.
57
+
58
+ This function constructs two types of prompts: a user prompt and a system prompt.
59
+ The user prompt includes the content of a research paper on prompt injection attacks,
60
+ which is loaded using the `load_prompt_injection_survey` method. This content is
61
+ wrapped in a specific format to serve as a reference for the assessment process.
62
+ The user prompt also includes the input prompt that needs to be evaluated for
63
+ potential injection attacks, enclosed within <input_prompt> tags.
64
+
65
+ The system prompt provides detailed instructions to an expert system on how to
66
+ analyze the input prompt. It specifies that the system should use the research
67
+ papers as a reference to determine if the input prompt is a prompt injection attack,
68
+ and if so, classify it as a direct or indirect attack and identify the specific type.
69
+ The system is instructed to provide a detailed explanation of its assessment,
70
+ citing specific parts of the research papers, and to follow strict guidelines
71
+ to ensure accuracy and clarity.
72
+
73
+ Args:
74
+ prompt (str): The input prompt to be assessed for potential injection attacks.
75
+
76
+ Returns:
77
+ tuple: A tuple containing the formatted user prompt and system prompt.
78
+ """
79
  markdown_text = self.load_prompt_injection_survey()
80
  user_prompt = f"""You are given the following research papers as reference:\n\n{markdown_text}"""
81
  user_prompt += f"""
 
108
 
109
  @weave.op()
110
  def predict(self, prompt: str, **kwargs) -> list[str]:
111
+ """
112
+ Predicts whether the given input prompt is a prompt injection attack.
113
+
114
+ This function formats the user and system prompts using the `format_prompts` method,
115
+ which includes the content of research papers and the input prompt to be assessed.
116
+ It then uses the `llm_model` to predict the nature of the input prompt by providing
117
+ the formatted prompts and expecting a response in the `SurveyGuardrailResponse` format.
118
+
119
+ Args:
120
+ prompt (str): The input prompt to be assessed for potential injection attacks.
121
+ **kwargs: Additional keyword arguments to be passed to the `llm_model.predict` method.
122
+
123
+ Returns:
124
+ list[str]: The parsed response from the model, indicating the assessment of the input prompt.
125
+ """
126
  user_prompt, system_prompt = self.format_prompts(prompt)
127
  chat_completion = self.llm_model.predict(
128
  user_prompts=user_prompt,
 
135
 
136
  @weave.op()
137
  def guard(self, prompt: str, **kwargs) -> list[str]:
138
+ """
139
+ Assesses the given input prompt for potential prompt injection attacks and provides a summary.
140
+
141
+ This function uses the `predict` method to determine whether the input prompt is a prompt injection attack.
142
+ It then constructs a summary based on the prediction, indicating whether the prompt is safe or an attack.
143
+ If the prompt is deemed an attack, the summary specifies whether it is a direct or indirect attack and the type of attack.
144
+
145
+ Args:
146
+ prompt (str): The input prompt to be assessed for potential injection attacks.
147
+ **kwargs: Additional keyword arguments to be passed to the `predict` method.
148
+
149
+ Returns:
150
+ dict: A dictionary containing:
151
+ - "safe" (bool): Indicates whether the prompt is safe (True) or an injection attack (False).
152
+ - "summary" (str): A summary of the assessment, including the type of attack and explanation if applicable.
153
+ """
154
  response = self.predict(prompt, **kwargs)
155
  summary = (
156
  f"Prompt is deemed safe. {response.explanation}"
guardrails_genie/guardrails/manager.py CHANGED
@@ -6,10 +6,44 @@ from .base import Guardrail
6
 
7
 
8
  class GuardrailManager(weave.Model):
 
 
 
 
 
 
 
 
 
 
9
  guardrails: list[Guardrail]
10
 
11
  @weave.op()
12
  def guard(self, prompt: str, progress_bar: bool = True, **kwargs) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  alerts, summaries, safe = [], "", True
14
  iterable = (
15
  track(self.guardrails, description="Running guardrails")
@@ -31,4 +65,25 @@ class GuardrailManager(weave.Model):
31
 
32
  @weave.op()
33
  def predict(self, prompt: str, **kwargs) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return self.guard(prompt, progress_bar=False, **kwargs)
 
6
 
7
 
8
  class GuardrailManager(weave.Model):
9
+ """
10
+ GuardrailManager is responsible for managing and executing a series of guardrails
11
+ on a given prompt. It utilizes the `weave` framework to define operations that
12
+ can be applied to the guardrails.
13
+
14
+ Attributes:
15
+ guardrails (list[Guardrail]): A list of Guardrail objects that define the
16
+ rules and checks to be applied to the input prompt.
17
+ """
18
+
19
  guardrails: list[Guardrail]
20
 
21
  @weave.op()
22
  def guard(self, prompt: str, progress_bar: bool = True, **kwargs) -> dict:
23
+ """
24
+ Execute a series of guardrails on a given prompt and return the results.
25
+
26
+ This method iterates over a list of Guardrail objects, applying each guardrail's
27
+ `guard` method to the provided prompt. It collects responses from each guardrail
28
+ and compiles them into a summary report. The function also determines the overall
29
+ safety of the prompt based on the responses from the guardrails.
30
+
31
+ Args:
32
+ prompt (str): The input prompt to be evaluated by the guardrails.
33
+ progress_bar (bool, optional): If True, displays a progress bar while
34
+ processing the guardrails. Defaults to True.
35
+ **kwargs: Additional keyword arguments to be passed to each guardrail's
36
+ `guard` method.
37
+
38
+ Returns:
39
+ dict: A dictionary containing:
40
+ - "safe" (bool): Indicates whether the prompt is considered safe
41
+ based on the guardrails' evaluations.
42
+ - "alerts" (list): A list of dictionaries, each containing the name
43
+ of the guardrail and its response.
44
+ - "summary" (str): A formatted string summarizing the results of
45
+ each guardrail's evaluation.
46
+ """
47
  alerts, summaries, safe = [], "", True
48
  iterable = (
49
  track(self.guardrails, description="Running guardrails")
 
65
 
66
  @weave.op()
67
  def predict(self, prompt: str, **kwargs) -> dict:
68
+ """
69
+ Predicts the safety and potential issues of a given input prompt using the guardrails.
70
+
71
+ This function serves as a wrapper around the `guard` method, providing a simplified
72
+ interface for evaluating the input prompt without displaying a progress bar. It
73
+ applies a series of guardrails to the prompt and returns a detailed assessment.
74
+
75
+ Args:
76
+ prompt (str): The input prompt to be evaluated by the guardrails.
77
+ **kwargs: Additional keyword arguments to be passed to each guardrail's
78
+ `guard` method.
79
+
80
+ Returns:
81
+ dict: A dictionary containing:
82
+ - "safe" (bool): Indicates whether the prompt is considered safe
83
+ based on the guardrails' evaluations.
84
+ - "alerts" (list): A list of dictionaries, each containing the name
85
+ of the guardrail and its response.
86
+ - "summary" (str): A formatted string summarizing the results of
87
+ each guardrail's evaluation.
88
+ """
89
  return self.guard(prompt, progress_bar=False, **kwargs)
guardrails_genie/llm.py CHANGED
@@ -10,13 +10,14 @@ class OpenAIModel(weave.Model):
10
  A class to interface with OpenAI's language models using the Weave framework.
11
 
12
  This class provides methods to create structured messages and generate predictions
13
- using OpenAI's chat completion API. It is designed to work with both single and
14
- multiple user prompts, and optionally includes a system prompt to guide the model's
15
  responses.
16
 
17
  Args:
18
  model_name (str): The name of the OpenAI model to be used for predictions.
19
  """
 
20
  model_name: str
21
  _openai_client: OpenAI
22
 
@@ -34,19 +35,19 @@ class OpenAIModel(weave.Model):
34
  """
35
  Create a list of messages for the OpenAI chat completion API.
36
 
37
- This function constructs a list of messages in the format required by the
38
- OpenAI chat completion API. It takes user prompts, an optional system prompt,
39
- and an optional list of existing messages, and combines them into a single
40
  list of messages.
41
 
42
  Args:
43
- user_prompts (Union[str, list[str]]): A single user prompt or a list of
44
  user prompts to be included in the messages.
45
- system_prompt (Optional[str]): An optional system prompt to guide the
46
- model's responses. If provided, it will be added at the beginning
47
  of the messages list.
48
- messages (Optional[list[dict]]): An optional list of existing messages
49
- to which the new prompts will be appended. If not provided, a new
50
  list will be created.
51
 
52
  Returns:
@@ -71,21 +72,21 @@ class OpenAIModel(weave.Model):
71
  """
72
  Generate a chat completion response using the OpenAI API.
73
 
74
- This function takes user prompts, an optional system prompt, and an optional
75
- list of existing messages to create a list of messages formatted for the
76
- OpenAI chat completion API. It then sends these messages to the OpenAI API
77
  to generate a chat completion response.
78
 
79
  Args:
80
- user_prompts (Union[str, list[str]]): A single user prompt or a list of
81
  user prompts to be included in the messages.
82
- system_prompt (Optional[str]): An optional system prompt to guide the
83
- model's responses. If provided, it will be added at the beginning
84
  of the messages list.
85
- messages (Optional[list[dict]]): An optional list of existing messages
86
- to which the new prompts will be appended. If not provided, a new
87
  list will be created.
88
- **kwargs: Additional keyword arguments to be passed to the OpenAI API
89
  for chat completion.
90
 
91
  Returns:
 
10
  A class to interface with OpenAI's language models using the Weave framework.
11
 
12
  This class provides methods to create structured messages and generate predictions
13
+ using OpenAI's chat completion API. It is designed to work with both single and
14
+ multiple user prompts, and optionally includes a system prompt to guide the model's
15
  responses.
16
 
17
  Args:
18
  model_name (str): The name of the OpenAI model to be used for predictions.
19
  """
20
+
21
  model_name: str
22
  _openai_client: OpenAI
23
 
 
35
  """
36
  Create a list of messages for the OpenAI chat completion API.
37
 
38
+ This function constructs a list of messages in the format required by the
39
+ OpenAI chat completion API. It takes user prompts, an optional system prompt,
40
+ and an optional list of existing messages, and combines them into a single
41
  list of messages.
42
 
43
  Args:
44
+ user_prompts (Union[str, list[str]]): A single user prompt or a list of
45
  user prompts to be included in the messages.
46
+ system_prompt (Optional[str]): An optional system prompt to guide the
47
+ model's responses. If provided, it will be added at the beginning
48
  of the messages list.
49
+ messages (Optional[list[dict]]): An optional list of existing messages
50
+ to which the new prompts will be appended. If not provided, a new
51
  list will be created.
52
 
53
  Returns:
 
72
  """
73
  Generate a chat completion response using the OpenAI API.
74
 
75
+ This function takes user prompts, an optional system prompt, and an optional
76
+ list of existing messages to create a list of messages formatted for the
77
+ OpenAI chat completion API. It then sends these messages to the OpenAI API
78
  to generate a chat completion response.
79
 
80
  Args:
81
+ user_prompts (Union[str, list[str]]): A single user prompt or a list of
82
  user prompts to be included in the messages.
83
+ system_prompt (Optional[str]): An optional system prompt to guide the
84
+ model's responses. If provided, it will be added at the beginning
85
  of the messages list.
86
+ messages (Optional[list[dict]]): An optional list of existing messages
87
+ to which the new prompts will be appended. If not provided, a new
88
  list will be created.
89
+ **kwargs: Additional keyword arguments to be passed to the OpenAI API
90
  for chat completion.
91
 
92
  Returns:
mkdocs.yml CHANGED
@@ -59,6 +59,12 @@ extra_javascript:
59
 
60
  nav:
61
  - Home: 'index.md'
 
 
 
 
 
 
62
  - LLM: 'llm.md'
63
  - Metrics: 'metrics.md'
64
  - RegexModel: 'regex_model.md'
 
59
 
60
  nav:
61
  - Home: 'index.md'
62
+ - Guardrails:
63
+ - Guardrail Base Class: 'guardrails/base.md'
64
+ - Guardrail Manager: 'guardrails/manager.md'
65
+ - Prompt Injection Guardrails:
66
+ - Classifier Guardrail: 'guardrails/prompt_injection/classifier.md'
67
+ - Survey Guardrail: 'guardrails/prompt_injection/llm_survey.md'
68
  - LLM: 'llm.md'
69
  - Metrics: 'metrics.md'
70
  - RegexModel: 'regex_model.md'