geekyrakshit commited on
Commit
c89e6e0
·
1 Parent(s): 159baa9

update: app

Browse files
.gitignore CHANGED
@@ -167,4 +167,5 @@ test.py
167
  temp.txt
168
  **.csv
169
  binary-classifier/
170
- wandb/
 
 
167
  temp.txt
168
  **.csv
169
  binary-classifier/
170
+ wandb/
171
+ artifacts/
application_pages/chat_app.py CHANGED
@@ -1,4 +1,5 @@
1
  import importlib
 
2
 
3
  import streamlit as st
4
  import weave
@@ -7,27 +8,27 @@ from dotenv import load_dotenv
7
  from guardrails_genie.guardrails import GuardrailManager
8
  from guardrails_genie.llm import OpenAIModel
9
 
10
- st.title(":material/robot: Guardrails Genie Playground")
11
-
12
- load_dotenv()
13
- weave.init(project_name="guardrails-genie")
14
 
15
- if "guardrails" not in st.session_state:
16
- st.session_state.guardrails = []
17
- if "guardrail_names" not in st.session_state:
18
- st.session_state.guardrail_names = []
19
- if "guardrails_manager" not in st.session_state:
20
- st.session_state.guardrails_manager = None
21
- if "initialize_guardrails" not in st.session_state:
22
- st.session_state.initialize_guardrails = False
23
- if "system_prompt" not in st.session_state:
24
- st.session_state.system_prompt = ""
25
- if "user_prompt" not in st.session_state:
26
- st.session_state.user_prompt = ""
27
- if "test_guardrails" not in st.session_state:
28
- st.session_state.test_guardrails = False
29
- if "llm_model" not in st.session_state:
30
- st.session_state.llm_model = None
 
 
 
 
31
 
32
 
33
  def initialize_guardrails():
@@ -44,18 +45,30 @@ def initialize_guardrails():
44
  guardrail_name,
45
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
46
  )
47
- else:
48
- st.session_state.guardrails.append(
49
- getattr(
50
- importlib.import_module("guardrails_genie.guardrails"),
51
- guardrail_name,
52
- )()
 
 
53
  )
 
 
 
 
 
 
 
54
  st.session_state.guardrails_manager = GuardrailManager(
55
  guardrails=st.session_state.guardrails
56
  )
57
 
58
 
 
 
 
59
  openai_model = st.sidebar.selectbox(
60
  "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
61
  )
@@ -97,7 +110,7 @@ if st.session_state.initialize_guardrails:
97
 
98
  if guardrails_response["safe"]:
99
  st.markdown(
100
- f"\n\n---\nPrompt is safe! Explore prompt trace on [Weave]({call.ui_url})\n\n---\n"
101
  )
102
 
103
  with st.sidebar.status("Generating response from LLM..."):
 
1
  import importlib
2
+ import os
3
 
4
  import streamlit as st
5
  import weave
 
8
  from guardrails_genie.guardrails import GuardrailManager
9
  from guardrails_genie.llm import OpenAIModel
10
 
 
 
 
 
11
 
12
+ def initialize_session_state():
13
+ load_dotenv()
14
+ weave.init(project_name=os.getenv("WEAVE_PROJECT"))
15
+
16
+ if "guardrails" not in st.session_state:
17
+ st.session_state.guardrails = []
18
+ if "guardrail_names" not in st.session_state:
19
+ st.session_state.guardrail_names = []
20
+ if "guardrails_manager" not in st.session_state:
21
+ st.session_state.guardrails_manager = None
22
+ if "initialize_guardrails" not in st.session_state:
23
+ st.session_state.initialize_guardrails = False
24
+ if "system_prompt" not in st.session_state:
25
+ st.session_state.system_prompt = ""
26
+ if "user_prompt" not in st.session_state:
27
+ st.session_state.user_prompt = ""
28
+ if "test_guardrails" not in st.session_state:
29
+ st.session_state.test_guardrails = False
30
+ if "llm_model" not in st.session_state:
31
+ st.session_state.llm_model = None
32
 
33
 
34
  def initialize_guardrails():
 
45
  guardrail_name,
46
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
47
  )
48
+ elif guardrail_name == "PromptInjectionClassifierGuardrail":
49
+ classifier_model_name = st.sidebar.selectbox(
50
+ "Classifier Guardrail Model",
51
+ [
52
+ "",
53
+ "ProtectAI/deberta-v3-base-prompt-injection-v2",
54
+ "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
55
+ ],
56
  )
57
+ if classifier_model_name != "":
58
+ st.session_state.guardrails.append(
59
+ getattr(
60
+ importlib.import_module("guardrails_genie.guardrails"),
61
+ guardrail_name,
62
+ )(model_name=classifier_model_name)
63
+ )
64
  st.session_state.guardrails_manager = GuardrailManager(
65
  guardrails=st.session_state.guardrails
66
  )
67
 
68
 
69
+ initialize_session_state()
70
+ st.title(":material/robot: Guardrails Genie Playground")
71
+
72
  openai_model = st.sidebar.selectbox(
73
  "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
74
  )
 
110
 
111
  if guardrails_response["safe"]:
112
  st.markdown(
113
+ f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
114
  )
115
 
116
  with st.sidebar.status("Generating response from LLM..."):
application_pages/evaluation_app.py CHANGED
@@ -64,10 +64,22 @@ def initialize_guardrail():
64
  guardrail_name,
65
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
66
  )
67
- else:
68
- guardrails.append(
69
- getattr(import_module("guardrails_genie.guardrails"), guardrail_name)()
 
 
 
 
 
70
  )
 
 
 
 
 
 
 
71
  st.session_state.guardrails = guardrails
72
  st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
73
 
 
64
  guardrail_name,
65
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
66
  )
67
+ elif guardrail_name == "PromptInjectionClassifierGuardrail":
68
+ classifier_model_name = st.sidebar.selectbox(
69
+ "Classifier Guardrail Model",
70
+ [
71
+ "",
72
+ "ProtectAI/deberta-v3-base-prompt-injection-v2",
73
+ "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
74
+ ],
75
  )
76
+ if classifier_model_name:
77
+ st.session_state.guardrails.append(
78
+ getattr(
79
+ import_module("guardrails_genie.guardrails"),
80
+ guardrail_name,
81
+ )(model_name=classifier_model_name)
82
+ )
83
  st.session_state.guardrails = guardrails
84
  st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
85
 
guardrails_genie/guardrails/__init__.py CHANGED
@@ -1,8 +1,11 @@
1
- from .injection import PromptInjectionProtectAIGuardrail, PromptInjectionSurveyGuardrail
 
 
 
2
  from .manager import GuardrailManager
3
 
4
  __all__ = [
5
  "PromptInjectionSurveyGuardrail",
6
- "PromptInjectionProtectAIGuardrail",
7
  "GuardrailManager",
8
  ]
 
1
+ from .injection import (
2
+ PromptInjectionClassifierGuardrail,
3
+ PromptInjectionSurveyGuardrail,
4
+ )
5
  from .manager import GuardrailManager
6
 
7
  __all__ = [
8
  "PromptInjectionSurveyGuardrail",
9
+ "PromptInjectionClassifierGuardrail",
10
  "GuardrailManager",
11
  ]
guardrails_genie/guardrails/injection/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .protectai_guardrail import PromptInjectionProtectAIGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
- __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionProtectAIGuardrail"]
 
1
+ from .protectai_guardrail import PromptInjectionClassifierGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
+ __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionClassifierGuardrail"]
guardrails_genie/guardrails/injection/protectai_guardrail.py CHANGED
@@ -5,16 +5,25 @@ import weave
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
 
 
8
  from ..base import Guardrail
9
 
10
 
11
- class PromptInjectionProtectAIGuardrail(Guardrail):
12
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
13
  _classifier: Optional[Pipeline] = None
14
 
15
  def model_post_init(self, __context):
16
- tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
- model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
 
 
 
 
 
 
 
18
  self._classifier = pipeline(
19
  "text-classification",
20
  model=model,
 
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
8
+ import wandb
9
+
10
  from ..base import Guardrail
11
 
12
 
13
+ class PromptInjectionClassifierGuardrail(Guardrail):
14
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
15
  _classifier: Optional[Pipeline] = None
16
 
17
  def model_post_init(self, __context):
18
+ if self.model_name.startswith("wandb://"):
19
+ api = wandb.Api()
20
+ artifact = api.artifact(self.model_name.removeprefix("wandb://"))
21
+ artifact_dir = artifact.download()
22
+ tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
23
+ model = AutoModelForSequenceClassification.from_pretrained(artifact_dir)
24
+ else:
25
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
26
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
27
  self._classifier = pipeline(
28
  "text-classification",
29
  model=model,