import os

import streamlit as st
from dotenv import load_dotenv

from guardrails_genie.train.train_classifier import train_binary_classifier


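# Load environment variables from .env and set default values in the Streamlit session state.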
def initialize_session_state():
    load_dotenv()
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = None
    if "base_model_name" not in st.session_state:
        st.session_state.base_model_name = None
    if "batch_size" not in st.session_state:
        st.session_state.batch_size = 16
    if "should_start_training" not in st.session_state:
        st.session_state.should_start_training = False
    if "training_output" not in st.session_state:
        st.session_state.training_output = None


initialize_session_state()
st.title(":material/fitness_center: Train Classifier")

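# Sidebar controls for the dataset, base model, and batch size.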
dataset_name = st.sidebar.text_input("Dataset Name", value="")
st.session_state.dataset_name = dataset_name

base_model_name = st.sidebar.selectbox(
    "Base Model",
    options=[
        "distilbert/distilbert-base-uncased",
        "FacebookAI/roberta-base",
        "microsoft/deberta-v3-base",
    ],
)
st.session_state.base_model_name = base_model_name

batch_size = st.sidebar.slider(
    "Batch Size", min_value=4, max_value=256, value=16, step=4
)
st.session_state.batch_size = batch_size

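# Training starts only when the button is pressed and both a dataset and a base model are set.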
train_button = st.sidebar.button("Train")
st.session_state.should_start_training = (
    train_button and st.session_state.dataset_name and st.session_state.base_model_name
)

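# When training is triggered, run the fine-tuning job and render its output inside an expander.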
if st.session_state.should_start_training:
    with st.expander("Training", expanded=True):
        training_output = train_binary_classifier(
            project_name=os.getenv("WANDB_PROJECT_NAME"),
            entity_name=os.getenv("WANDB_ENTITY_NAME"),
            run_name=f"{st.session_state.base_model_name}-finetuned",
            dataset_repo=st.session_state.dataset_name,
            model_name=st.session_state.base_model_name,
            batch_size=st.session_state.batch_size,
            streamlit_mode=True,
        )
        st.session_state.training_output = training_output
        st.write(training_output)