Spaces:
Running
Running
File size: 7,468 Bytes
9269c10 b850722 9269c10 b850722 c89e6e0 cfcefce 2946856 9269c10 c32f628 c89e6e0 c32f628 c89e6e0 3ad3f59 3caf047 3ad3f59 3caf047 3ad3f59 3caf047 3ad3f59 28d8897 3caf047 28d8897 cfcefce 021c1f9 9269c10 2946856 9269c10 0cca41e 6d0856c 0cca41e 2946856 0cca41e af688eb 0cca41e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import importlib
import streamlit as st
from dotenv import load_dotenv
from guardrails_genie.guardrails import GuardrailManager
from guardrails_genie.llm import OpenAIModel
def initialize_session_state():
    """Seed ``st.session_state`` with the defaults this page relies on.

    Calls ``load_dotenv()`` first, then assigns a default value to each key
    in the table below that is not already present in the session state.
    Keys that already exist are left untouched.
    """
    load_dotenv()

    # Built inside the function so the list defaults are fresh objects on
    # every call instead of shared module-level state.
    session_defaults = {
        "guardrails": [],
        "guardrail_names": [],
        "guardrails_manager": None,
        "initialize_guardrails": False,
        "system_prompt": "",
        "user_prompt": "",
        "test_guardrails": False,
        "llm_model": None,
        "llama_guard_checkpoint_name": "",
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
def initialize_guardrails():
    """Instantiate the selected guardrails and wrap them in a manager.

    Resets ``st.session_state.guardrails`` to an empty list, then walks
    ``st.session_state.guardrail_names`` and appends one instance per name.
    Each class is looked up by name in the ``guardrails_genie.guardrails``
    package. Some guardrails render an extra sidebar widget to collect a
    constructor argument:

    * ``PromptInjectionSurveyGuardrail`` - select box for the LLM; the
      guardrail is only created once a non-empty model is chosen.
    * ``PromptInjectionClassifierGuardrail`` - select box for the classifier
      model; only created once a non-empty model is chosen.
    * ``PromptInjectionLlamaGuardrail`` - text input for a checkpoint name;
      an empty value is passed to the constructor as ``None``.

    Finally ``st.session_state.guardrails_manager`` is replaced with a new
    ``GuardrailManager`` built from the collected guardrails.
    """
    guardrails_module_path = "guardrails_genie.guardrails"

    # Guardrails that are all constructed with ``should_anonymize=True``.
    anonymizing_guardrails = (
        "PresidioEntityRecognitionGuardrail",
        "RegexEntityRecognitionGuardrail",
        "TransformersEntityRecognitionGuardrail",
        "RestrictedTermsJudge",
    )

    def resolve_guardrail_class(class_name):
        """Return the attribute named ``class_name`` from the guardrails package."""
        return getattr(importlib.import_module(guardrails_module_path), class_name)

    st.session_state.guardrails = []

    for selected_name in st.session_state.guardrail_names:
        if selected_name == "PromptInjectionSurveyGuardrail":
            survey_llm_name = st.sidebar.selectbox(
                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
            )
            # Nothing is created until the user picks a model.
            if survey_llm_name:
                st.session_state.guardrails.append(
                    resolve_guardrail_class(selected_name)(
                        llm_model=OpenAIModel(model_name=survey_llm_name)
                    )
                )

        elif selected_name == "PromptInjectionClassifierGuardrail":
            classifier_choice = st.sidebar.selectbox(
                "Classifier Guardrail Model",
                [
                    "",
                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
                ],
            )
            # Nothing is created until the user picks a model.
            if classifier_choice != "":
                st.session_state.guardrails.append(
                    resolve_guardrail_class(selected_name)(
                        model_name=classifier_choice
                    )
                )

        elif selected_name in anonymizing_guardrails:
            st.session_state.guardrails.append(
                resolve_guardrail_class(selected_name)(should_anonymize=True)
            )

        elif selected_name == "PromptInjectionLlamaGuardrail":
            st.session_state.llama_guard_checkpoint_name = st.sidebar.text_input(
                "Checkpoint Name", value=""
            )
            checkpoint_name = st.session_state.llama_guard_checkpoint_name
            st.session_state.guardrails.append(
                resolve_guardrail_class(selected_name)(
                    # An empty text box is passed on as ``None``.
                    checkpoint=checkpoint_name if checkpoint_name != "" else None
                )
            )

        else:
            # Every other guardrail is constructed without arguments.
            st.session_state.guardrails.append(
                resolve_guardrail_class(selected_name)()
            )

    st.session_state.guardrails_manager = GuardrailManager(
        guardrails=st.session_state.guardrails
    )
# Page body: Guardrails Genie Playground.
#
# ``is_authenticated`` is not one of the keys seeded by
# ``initialize_session_state`` in this file, so it is read with ``.get`` and a
# ``False`` default. A session where the key was never set then falls through
# to the "please authenticate" warning instead of raising ``AttributeError``.
if st.session_state.get("is_authenticated", False):
    initialize_session_state()
    st.title(":material/robot: Guardrails Genie Playground")

    openai_model = st.sidebar.selectbox(
        "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
    )
    chat_condition = openai_model != ""

    # Offer every class exposed by the guardrails package except the manager,
    # which is the container rather than a guardrail.
    guardrail_names = st.sidebar.multiselect(
        label="Select Guardrails",
        options=[
            cls_name
            for cls_name, cls_obj in vars(
                importlib.import_module("guardrails_genie.guardrails")
            ).items()
            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
        ],
    )
    st.session_state.guardrail_names = guardrail_names

    # The flag is sticky: once set it stays True across reruns, so the
    # guardrail-specific sidebar widgets keep being rendered.
    if st.sidebar.button("Initialize Guardrails") and chat_condition:
        st.session_state.initialize_guardrails = True

    if st.session_state.initialize_guardrails:
        with st.sidebar.status("Initializing Guardrails..."):
            initialize_guardrails()
            st.session_state.llm_model = OpenAIModel(model_name=openai_model)

        # Prompt testing is only offered after initialization, so that
        # ``guardrails_manager`` and ``llm_model`` are populated when used.
        user_prompt = st.text_area("User Prompt", value="")
        st.session_state.user_prompt = user_prompt

        test_guardrails_button = st.button("Test Guardrails")
        st.session_state.test_guardrails = test_guardrails_button

        if st.session_state.test_guardrails:
            with st.sidebar.status("Running Guardrails..."):
                # ``.call`` returns the result together with a call object
                # whose ``ui_url`` is used for the Weave links below.
                guardrails_response, call = (
                    st.session_state.guardrails_manager.guard.call(
                        st.session_state.guardrails_manager,
                        prompt=st.session_state.user_prompt,
                    )
                )

            if guardrails_response["safe"]:
                st.markdown(
                    f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
                )

                with st.sidebar.status("Generating response from LLM..."):
                    response, call = st.session_state.llm_model.predict.call(
                        st.session_state.llm_model,
                        user_prompts=st.session_state.user_prompt,
                    )
                st.markdown(
                    response.choices[0].message.content
                    + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
                )
            else:
                st.warning("Prompt is not safe!")
                st.markdown(guardrails_response["summary"])
                st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
else:
    st.warning("Please authenticate your WandB account to use this feature.")
|