Spaces:
Running
Running
File size: 2,067 Bytes
053730f 968f4bc 8647e3b 968f4bc 053730f 968f4bc 053730f 968f4bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import os
import streamlit as st
from dotenv import load_dotenv
from guardrails_genie.train.train_classifier import train_binary_classifier
def initialize_session_state():
    """Load ``.env`` variables and seed default session-state entries.

    ``load_dotenv()`` is called first so that environment variables later read
    with ``os.getenv`` (the W&B project/entity names) are available. Each
    session-state key is then given a default only if it is not already
    present, so values set on earlier reruns of the script are preserved.
    """
    load_dotenv()

    # Default value for every session-state key this page relies on.
    session_defaults = {
        "dataset_name": None,
        "base_model_name": None,
        "batch_size": 16,
        "should_start_training": False,
        "training_output": None,
    }
    for state_key, default_value in session_defaults.items():
        # Seed only missing keys; never overwrite state from a previous rerun.
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
initialize_session_state()

st.title(":material/fitness_center: Train Classifier")

# --- Sidebar controls -------------------------------------------------------
# Each widget value is mirrored into session state, and the training call
# below reads its arguments from there.

# Strip surrounding whitespace so input made only of spaces counts as empty
# instead of being passed on as a dataset repository name.
dataset_name = st.sidebar.text_input("Dataset Name", value="").strip()
st.session_state.dataset_name = dataset_name

base_model_name = st.sidebar.selectbox(
    "Base Model",
    options=[
        "distilbert/distilbert-base-uncased",
        "FacebookAI/roberta-base",
        "microsoft/deberta-v3-base",
    ],
)
st.session_state.base_model_name = base_model_name

batch_size = st.sidebar.slider(
    "Batch Size", min_value=4, max_value=256, value=16, step=4
)
st.session_state.batch_size = batch_size

# True only on the rerun triggered by the click.
train_button = st.sidebar.button("Train")

# Coerce to a real bool: a bare `and` chain evaluates to its last truthy
# operand (the model-name string), not a flag.
st.session_state.should_start_training = bool(
    train_button
    and st.session_state.dataset_name
    and st.session_state.base_model_name
)

if train_button and not st.session_state.dataset_name:
    # Without this, clicking "Train" with no dataset silently did nothing.
    st.warning("Please enter a dataset name before starting training.")

# --- Training ---------------------------------------------------------------
if st.session_state.should_start_training:
    with st.expander("Training", expanded=True):
        # W&B project/entity come from the environment; if unset, os.getenv
        # returns None and that is passed through unchanged.
        training_output = train_binary_classifier(
            project_name=os.getenv("WANDB_PROJECT_NAME"),
            entity_name=os.getenv("WANDB_ENTITY_NAME"),
            run_name=f"{st.session_state.base_model_name}-finetuned",
            dataset_repo=st.session_state.dataset_name,
            model_name=st.session_state.base_model_name,
            batch_size=st.session_state.batch_size,
            streamlit_mode=True,
        )
        # NOTE(review): the result is kept in session state but only rendered
        # on the run that produced it; it is not re-displayed on later reruns.
        st.session_state.training_output = training_output
        st.write(training_output)
|