"""Streamlit page that collects training settings and launches classifier training.

The sidebar gathers a dataset name, a base model, and a batch size. Pressing
"Train" with a non-empty dataset name calls ``train_binary_classifier`` and
writes whatever it returns to the page.
"""

import os

import streamlit as st
from dotenv import load_dotenv

from guardrails_genie.train.train_classifier import train_binary_classifier

# Keys this page keeps in ``st.session_state`` and the value each one starts with.
_SESSION_DEFAULTS = {
    "dataset_name": None,
    "base_model_name": None,
    "batch_size": 16,
    "should_start_training": False,
    "training_output": None,
}


def initialize_session_state() -> None:
    """Load ``.env`` variables and seed any missing session-state keys.

    Keys that already exist in ``st.session_state`` are left untouched, so
    values survive Streamlit reruns.
    """
    load_dotenv()
    for key, default in _SESSION_DEFAULTS.items():
        if key not in st.session_state:
            st.session_state[key] = default


initialize_session_state()
st.title(":material/fitness_center: Train Classifier")

# Strip surrounding whitespace so a blank-looking entry is treated as empty
# instead of being forwarded as a dataset repo name.
dataset_name = st.sidebar.text_input("Dataset Name", value="").strip()
st.session_state.dataset_name = dataset_name

base_model_name = st.sidebar.selectbox(
    "Base Model",
    options=[
        "distilbert/distilbert-base-uncased",
        "FacebookAI/roberta-base",
        "microsoft/deberta-v3-base",
    ],
)
st.session_state.base_model_name = base_model_name

batch_size = st.sidebar.slider(
    "Batch Size", min_value=4, max_value=256, value=16, step=4
)
st.session_state.batch_size = batch_size

train_button = st.sidebar.button("Train")

# Coerce to a real bool: a bare ``and`` chain would store one of the operand
# strings (e.g. "" or the model name) in session state instead of True/False.
st.session_state.should_start_training = bool(
    train_button
    and st.session_state.dataset_name
    and st.session_state.base_model_name
)

if train_button and not st.session_state.dataset_name:
    # Give feedback rather than silently ignoring the click.
    st.sidebar.warning("Please enter a dataset name before training.")

if st.session_state.should_start_training:
    with st.expander("Training", expanded=True):
        # WANDB_* values come from the environment (optionally populated by
        # load_dotenv above); os.getenv yields None when a variable is unset.
        training_output = train_binary_classifier(
            project_name=os.getenv("WANDB_PROJECT_NAME"),
            entity_name=os.getenv("WANDB_ENTITY_NAME"),
            run_name=f"{st.session_state.base_model_name}-finetuned",
            dataset_repo=st.session_state.dataset_name,
            model_name=st.session_state.base_model_name,
            batch_size=st.session_state.batch_size,
            streamlit_mode=True,
        )
        st.session_state.training_output = training_output
        st.write(training_output)