Spaces:
Running
Running
File size: 4,911 Bytes
5e33295 573a89c 5e33295 573a89c a202ba5 5e33295 573a89c a202ba5 5e33295 a202ba5 5e33295 a202ba5 5e33295 a202ba5 5e33295 573a89c a202ba5 5e33295 a202ba5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
import streamlit as st
from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner
def initialize_session_state():
    """Seed ``st.session_state`` with every key this page relies on.

    Keys that already exist are left untouched, so values survive the
    script re-execution Streamlit performs on each widget interaction.

    The fine-tuner is configured from the ``WANDB_PROJECT_NAME`` and
    ``WANDB_ENTITY_NAME`` environment variables (``None`` when unset) and
    is created with ``streamlit_mode=True``.
    """
    state = st.session_state

    # Build the fine-tuner once per session. Re-creating it on every rerun
    # would throw away whatever had been loaded into the previous instance,
    # and it would be inconsistent with how every other key is handled here.
    if "llama_guard_fine_tuner" not in state:
        state.llama_guard_fine_tuner = LlamaGuardFineTuner(
            wandb_project=os.getenv("WANDB_PROJECT_NAME"),
            wandb_entity=os.getenv("WANDB_ENTITY_NAME"),
            streamlit_mode=True,
        )

    # Default value for each remaining key, in the original declaration order.
    defaults = {
        "dataset_address": "",
        "train_dataset_range": 0,
        "test_dataset_range": 0,
        "load_fine_tuner_button": False,
        "is_fine_tuner_loaded": False,
        "model_name": "",
        "preview_dataset": False,
        "evaluate_model": False,
        # NOTE(review): the two "evaluation_*" keys are not read anywhere in
        # the visible page code; kept so existing readers (if any) still work.
        "evaluation_batch_size": None,
        "evaluation_temperature": None,
        "checkpoint": None,
        "eval_batch_size": 32,
        "eval_positive_label": 2,
        "eval_temperature": 1.0,
    }
    for key, default in defaults.items():
        if key not in state:
            state[key] = default
# Page body: collect the fine-tuning settings from the sidebar, mirror them
# into session state, and run the load / preview / evaluate steps when the
# "Load Fine-Tuner" button is pressed.
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation; confirm it against the repository copy.
initialize_session_state()
st.title(":material/star: Fine-Tune LLama Guard")

state = st.session_state
sidebar = st.sidebar

state.dataset_address = sidebar.text_input("Dataset Address", value="")

# The remaining controls only appear once a dataset address has been entered.
if state.dataset_address:
    state.train_dataset_range = sidebar.number_input(
        "Train Dataset Range", value=0, min_value=0, max_value=252956
    )
    state.test_dataset_range = sidebar.number_input(
        "Test Dataset Range", value=0, min_value=0, max_value=63240
    )

    state.model_name = sidebar.text_input(
        label="Model Name", value="meta-llama/Prompt-Guard-86M"
    )
    state.checkpoint = sidebar.text_input(
        label="Fine-tuned Model Checkpoint", value=""
    )

    state.preview_dataset = sidebar.toggle("Preview Dataset")
    state.evaluate_model = sidebar.toggle("Evaluate Model")

    # Evaluation settings are only shown while the evaluation toggle is on;
    # otherwise the values already held in session state are used.
    if state.evaluate_model:
        state.eval_batch_size = sidebar.slider(
            label="Eval Batch Size", min_value=16, max_value=1024, value=32
        )
        state.eval_positive_label = sidebar.number_input(
            "EVal Positive Label", value=2
        )
        state.eval_temperature = sidebar.slider(
            label="Eval Temperature", min_value=0.0, max_value=5.0, value=1.0
        )

    state.load_fine_tuner_button = sidebar.button("Load Fine-Tuner")

    if state.load_fine_tuner_button:
        fine_tuner = state.llama_guard_fine_tuner
        with st.status("Loading Fine-Tuner"):
            dataset_args = DatasetArgs(
                dataset_address=state.dataset_address,
                train_dataset_range=state.train_dataset_range,
                test_dataset_range=state.test_dataset_range,
            )
            fine_tuner.load_dataset(dataset_args)

            # An empty checkpoint field is passed on as None.
            fine_tuner.load_model(
                model_name=state.model_name,
                checkpoint=state.checkpoint or None,
            )

            if state.preview_dataset:
                fine_tuner.show_dataset_sample()

            if state.evaluate_model:
                fine_tuner.evaluate_model(
                    batch_size=state.eval_batch_size,
                    positive_label=state.eval_positive_label,
                    temperature=state.eval_temperature,
                )

            state.is_fine_tuner_loaded = True
|