geekyrakshit's picture
refactor: classifier training
8647e3b
import pandas as pd
import streamlit as st
import weave
from transformers.trainer_callback import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
class EvaluationCallManager:
"""
Manages the evaluation calls for a specific project and entity in Weave.
This class is responsible for initializing and managing evaluation calls associated with a
specific project and entity. It provides functionality to collect guardrail guard calls
from evaluation predictions and scores, and render these calls into a structured format
suitable for display in Streamlit.
Args:
entity (str): The entity name.
project (str): The project name.
call_id (str): The call id.
max_count (int): The maximum number of guardrail guard calls to collect from the evaluation.
"""
def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
self.max_count = max_count
self.show_warning_in_app = False
self.call_list = []
def collect_guardrail_guard_calls_from_eval(self):
"""
Collects guardrail guard calls from evaluation predictions and scores.
This function iterates through the children calls of the base evaluation call,
extracting relevant guardrail guard calls and their associated scores. It stops
collecting calls if it encounters an "Evaluation.summarize" operation or if the
maximum count of guardrail guard calls is reached. The collected calls are stored
in a list of dictionaries, each containing the input prompt, outputs, and score.
Returns:
list: A list of dictionaries, each containing:
- input_prompt (str): The input prompt for the guard call.
- outputs (dict): The outputs of the guard call.
- score (dict): The score of the guard call.
"""
guard_calls, count = [], 0
for eval_predict_and_score_call in self.base_call.children():
if "Evaluation.summarize" in eval_predict_and_score_call._op_name:
break
guardrail_predict_call = eval_predict_and_score_call.children()[0]
guard_call = guardrail_predict_call.children()[0]
score_call = eval_predict_and_score_call.children()[1]
guard_calls.append(
{
"input_prompt": str(guard_call.inputs["prompt"]),
"outputs": dict(guard_call.output),
"score": dict(score_call.output),
}
)
count += 1
if count >= self.max_count:
self.show_warning_in_app = True
break
return guard_calls
def render_calls_to_streamlit(self):
"""
Renders the collected guardrail guard calls into a pandas DataFrame suitable for
display in Streamlit.
This function processes the collected guardrail guard calls stored in `self.call_list` and
organizes them into a dictionary format that can be easily converted into a pandas DataFrame.
The DataFrame contains columns for the input prompts, the safety status of the outputs, and
the correctness of the predictions for each guardrail.
The structure of the DataFrame is as follows:
- The first column contains the input prompts.
- Subsequent columns contain the safety status and prediction correctness for each guardrail.
Returns:
pd.DataFrame: A DataFrame containing the input prompts, safety status, and prediction
correctness for each guardrail.
"""
dataframe = {
"input_prompt": [
call["input_prompt"] for call in self.call_list[0]["calls"]
]
}
for guardrail_call in self.call_list:
dataframe[guardrail_call["guardrail_name"] + ".safe"] = [
call["outputs"]["safe"] for call in guardrail_call["calls"]
]
dataframe[guardrail_call["guardrail_name"] + ".prediction_correctness"] = [
call["score"]["correct"] for call in guardrail_call["calls"]
]
return pd.DataFrame(dataframe)
class StreamlitProgressbarCallback(TrainerCallback):
"""
StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
that integrates a progress bar into a Streamlit application. This class updates
the progress bar at each training step, providing real-time feedback on the
training process within the Streamlit interface.
Attributes:
progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
bar object initialized to 0 with the text "Training".
Methods:
on_step_begin(args, state, control, **kwargs):
Updates the progress bar at the beginning of each training step. The progress
is calculated as the percentage of completed steps out of the total steps.
The progress bar text is updated to show the current step and the total steps.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.progress_bar = st.progress(0, text="Training")
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
super().on_step_begin(args, state, control, **kwargs)
self.progress_bar.progress(
(state.global_step * 100 // state.max_steps) + 1,
text=f"Training {state.global_step} / {state.max_steps}",
)