import evaluate
import numpy as np
import streamlit as st
import wandb
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from guardrails_genie.utils import StreamlitProgressbarCallback


def train_binary_classifier(
    project_name: str,
    entity_name: str,
    run_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    prompt_column_name: str = "prompt",
    id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
    label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
    learning_rate: float = 1e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    save_steps: int = 1000,
    streamlit_mode: bool = False,
):
    """
    Trains a binary classifier using a specified dataset and model architecture.

    This function sets up and trains a binary sequence-classification model
    using the Hugging Face Transformers library. It integrates with Weights &
    Biases for experiment tracking and optionally displays a progress bar in a
    Streamlit app.

    Args:
        project_name (str): The name of the Weights & Biases project.
        entity_name (str): The Weights & Biases entity (user or team).
        run_name (str): The name of the Weights & Biases run.
        dataset_repo (str, optional): The Hugging Face dataset repository to load.
        model_name (str, optional): The pre-trained model to use.
        prompt_column_name (str, optional): The column name in the dataset
            containing the text prompts.
        id2label (dict[int, str], optional): Mapping from label IDs to label names.
        label2id (dict[str, int], optional): Mapping from label names to label IDs.
        learning_rate (float, optional): The learning rate for training.
        batch_size (int, optional): The batch size for training and evaluation.
        num_epochs (int, optional): The number of training epochs.
        weight_decay (float, optional): The weight decay for the optimizer.
        save_steps (int, optional): The number of steps between model checkpoints.
            Evaluation runs on the same schedule so that the best checkpoint can
            be restored at the end of training.
        streamlit_mode (bool, optional): If True, integrates with Streamlit to
            display a progress bar.

    Returns:
        The output of `trainer.train()`, including the global step, training
        loss, and metrics.

    Raises:
        Exception: If an error occurs during training, the exception is
            re-raised after the Weights & Biases run is finished.
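
    Example:
        A minimal invocation; the project, entity, and run names below are
        illustrative placeholders:

        >>> train_binary_classifier(
        ...     project_name="my-wandb-project",
        ...     entity_name="my-wandb-entity",
        ...     run_name="distilbert-prompt-injection",
        ... )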
""" wandb.init( project=project_name, entity=entity_name, name=run_name, job_type="train-binary-classifier", ) if streamlit_mode: st.markdown( f"Explore your training logs on [Weights & Biases]({wandb.run.url})" ) dataset = load_dataset(dataset_repo) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenized_datasets = dataset.map( lambda examples: tokenizer(examples[prompt_column_name], truncation=True), batched=True, ) data_collator = DataCollatorWithPadding(tokenizer=tokenizer) accuracy = evaluate.load("accuracy") def compute_metrics(eval_pred): predictions, labels = eval_pred predictions = np.argmax(predictions, axis=1) return accuracy.compute(predictions=predictions, references=labels) model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=2, id2label=id2label, label2id=label2id, ) trainer = Trainer( model=model, args=TrainingArguments( output_dir="binary-classifier", learning_rate=learning_rate, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, num_train_epochs=num_epochs, weight_decay=weight_decay, eval_strategy="epoch", save_strategy="steps", save_steps=save_steps, load_best_model_at_end=True, push_to_hub=False, report_to="wandb", logging_strategy="steps", logging_steps=1, ), train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], processing_class=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [], ) try: training_output = trainer.train() except Exception as e: wandb.finish() raise e wandb.finish() return training_output