import os

import streamlit as st

from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner


def initialize_session_state():
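    """Instantiate the LlamaGuardFineTuner from the W&B environment variables and
    seed default values for every widget/state key used on this page."""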
    st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner(
        wandb_project=os.getenv("WANDB_PROJECT_NAME"),
        wandb_entity=os.getenv("WANDB_ENTITY_NAME"),
        streamlit_mode=True,
    )
    if "dataset_address" not in st.session_state:
        st.session_state.dataset_address = ""
    if "train_dataset_range" not in st.session_state:
        st.session_state.train_dataset_range = 0
    if "test_dataset_range" not in st.session_state:
        st.session_state.test_dataset_range = 0
    if "load_fine_tuner_button" not in st.session_state:
        st.session_state.load_fine_tuner_button = False
    if "is_fine_tuner_loaded" not in st.session_state:
        st.session_state.is_fine_tuner_loaded = False
    if "model_name" not in st.session_state:
        st.session_state.model_name = ""
    if "preview_dataset" not in st.session_state:
        st.session_state.preview_dataset = False
    if "evaluate_model" not in st.session_state:
        st.session_state.evaluate_model = False
    if "evaluation_batch_size" not in st.session_state:
        st.session_state.evaluation_batch_size = None
    if "evaluation_temperature" not in st.session_state:
        st.session_state.evaluation_temperature = None
    if "checkpoint" not in st.session_state:
        st.session_state.checkpoint = None
    if "eval_batch_size" not in st.session_state:
        st.session_state.eval_batch_size = 32
    if "eval_positive_label" not in st.session_state:
        st.session_state.eval_positive_label = 2
    if "eval_temperature" not in st.session_state:
        st.session_state.eval_temperature = 1.0


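# Build the page: seed session state, then render the title and sidebar controls.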
initialize_session_state()
st.title(":material/star: Fine-Tune Llama Guard")

dataset_address = st.sidebar.text_input("Dataset Address", value="")
st.session_state.dataset_address = dataset_address

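# The remaining controls appear only after a dataset address has been entered.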
if st.session_state.dataset_address != "":
    train_dataset_range = st.sidebar.number_input(
        "Train Dataset Range", value=0, min_value=0, max_value=252956
    )
    test_dataset_range = st.sidebar.number_input(
        "Test Dataset Range", value=0, min_value=0, max_value=63240
    )
    st.session_state.train_dataset_range = train_dataset_range
    st.session_state.test_dataset_range = test_dataset_range

    model_name = st.sidebar.text_input(
        label="Model Name", value="meta-llama/Prompt-Guard-86M"
    )
    st.session_state.model_name = model_name

    checkpoint = st.sidebar.text_input(label="Fine-tuned Model Checkpoint", value="")
    st.session_state.checkpoint = checkpoint

    preview_dataset = st.sidebar.toggle("Preview Dataset")
    st.session_state.preview_dataset = preview_dataset

    evaluate_model = st.sidebar.toggle("Evaluate Model")
    st.session_state.evaluate_model = evaluate_model

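    # Evaluation settings, shown only when "Evaluate Model" is toggled on.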
    if st.session_state.evaluate_model:
        eval_batch_size = st.sidebar.slider(
            label="Eval Batch Size", min_value=16, max_value=1024, value=32
        )
        st.session_state.eval_batch_size = eval_batch_size

        eval_positive_label = st.sidebar.number_input("Eval Positive Label", value=2)
        st.session_state.eval_positive_label = eval_positive_label

        eval_temperature = st.sidebar.slider(
            label="Eval Temperature", min_value=0.0, max_value=5.0, value=1.0
        )
        st.session_state.eval_temperature = eval_temperature

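    # Load the dataset and model (and optionally preview/evaluate) when the button is pressed.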
    load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
    st.session_state.load_fine_tuner_button = load_fine_tuner_button

    if st.session_state.load_fine_tuner_button:
        with st.status("Loading Fine-Tuner"):
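            # Pull the train/test splits for the configured dataset address.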
            st.session_state.llama_guard_fine_tuner.load_dataset(
                DatasetArgs(
                    dataset_address=st.session_state.dataset_address,
                    train_dataset_range=st.session_state.train_dataset_range,
                    test_dataset_range=st.session_state.test_dataset_range,
                )
            )
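            # Load the base model, or resume from a fine-tuned checkpoint if one was given.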
            st.session_state.llama_guard_fine_tuner.load_model(
                model_name=st.session_state.model_name,
                checkpoint=(
                    None
                    if st.session_state.checkpoint == ""
                    else st.session_state.checkpoint
                ),
            )
            if st.session_state.preview_dataset:
                st.session_state.llama_guard_fine_tuner.show_dataset_sample()
            if st.session_state.evaluate_model:
                st.session_state.llama_guard_fine_tuner.evaluate_model(
                    batch_size=st.session_state.eval_batch_size,
                    positive_label=st.session_state.eval_positive_label,
                    temperature=st.session_state.eval_temperature,
                )
            st.session_state.is_fine_tuner_loaded = True