geekyrakshit committed on
Commit
8647e3b
·
1 Parent(s): a70d6a8

refactor: classifier training

Browse files
application_pages/train_classifier.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
 
6
- from guardrails_genie.train_classifier import train_binary_classifier
7
 
8
 
9
  def initialize_session_state():
 
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
 
6
+ from guardrails_genie.train.train_classifier import train_binary_classifier
7
 
8
 
9
  def initialize_session_state():
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py CHANGED
@@ -362,7 +362,7 @@ def main():
362
  preprocess_model_input=preprocess_model_input,
363
  )
364
 
365
- results = asyncio.run(evaluation.evaluate(guardrail))
366
 
367
 
368
  if __name__ == "__main__":
 
362
  preprocess_model_input=preprocess_model_input,
363
  )
364
 
365
+ asyncio.run(evaluation.evaluate(guardrail))
366
 
367
 
368
  if __name__ == "__main__":
guardrails_genie/guardrails/injection/classifier_guardrail.py CHANGED
@@ -1,12 +1,11 @@
1
  from typing import Optional
2
 
3
  import torch
 
4
  import weave
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
8
- import wandb
9
-
10
  from ..base import Guardrail
11
 
12
 
 
1
  from typing import Optional
2
 
3
  import torch
4
+ import wandb
5
  import weave
6
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
7
  from transformers.pipelines.base import Pipeline
8
 
 
 
9
  from ..base import Guardrail
10
 
11
 
guardrails_genie/{train_classifier.py → train/train_classifier.py} RENAMED
@@ -1,54 +1,17 @@
1
  import evaluate
2
  import numpy as np
3
  import streamlit as st
 
4
  from datasets import load_dataset
5
  from transformers import (
6
  AutoModelForSequenceClassification,
7
  AutoTokenizer,
8
  DataCollatorWithPadding,
9
  Trainer,
10
- TrainerCallback,
11
  TrainingArguments,
12
  )
13
- from transformers.trainer_callback import TrainerControl, TrainerState
14
-
15
- import wandb
16
-
17
-
18
- class StreamlitProgressbarCallback(TrainerCallback):
19
- """
20
- StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
21
- that integrates a progress bar into a Streamlit application. This class updates
22
- the progress bar at each training step, providing real-time feedback on the
23
- training process within the Streamlit interface.
24
 
25
- Attributes:
26
- progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
27
- bar object initialized to 0 with the text "Training".
28
-
29
- Methods:
30
- on_step_begin(args, state, control, **kwargs):
31
- Updates the progress bar at the beginning of each training step. The progress
32
- is calculated as the percentage of completed steps out of the total steps.
33
- The progress bar text is updated to show the current step and the total steps.
34
- """
35
-
36
- def __init__(self, *args, **kwargs):
37
- super().__init__(*args, **kwargs)
38
- self.progress_bar = st.progress(0, text="Training")
39
-
40
- def on_step_begin(
41
- self,
42
- args: TrainingArguments,
43
- state: TrainerState,
44
- control: TrainerControl,
45
- **kwargs,
46
- ):
47
- super().on_step_begin(args, state, control, **kwargs)
48
- self.progress_bar.progress(
49
- (state.global_step * 100 // state.max_steps) + 1,
50
- text=f"Training {state.global_step} / {state.max_steps}",
51
- )
52
 
53
 
54
  def train_binary_classifier(
 
1
  import evaluate
2
  import numpy as np
3
  import streamlit as st
4
+ import wandb
5
  from datasets import load_dataset
6
  from transformers import (
7
  AutoModelForSequenceClassification,
8
  AutoTokenizer,
9
  DataCollatorWithPadding,
10
  Trainer,
 
11
  TrainingArguments,
12
  )
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ from guardrails_genie.utils import StreamlitProgressbarCallback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def train_binary_classifier(
guardrails_genie/utils.py CHANGED
@@ -1,5 +1,12 @@
1
  import pandas as pd
 
2
  import weave
 
 
 
 
 
 
3
 
4
 
5
  class EvaluationCallManager:
@@ -91,3 +98,39 @@ class EvaluationCallManager:
91
  call["score"]["correct"] for call in guardrail_call["calls"]
92
  ]
93
  return pd.DataFrame(dataframe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
+ import streamlit as st
3
  import weave
4
+ from transformers.trainer_callback import (
5
+ TrainerCallback,
6
+ TrainerControl,
7
+ TrainerState,
8
+ TrainingArguments,
9
+ )
10
 
11
 
12
  class EvaluationCallManager:
 
98
  call["score"]["correct"] for call in guardrail_call["calls"]
99
  ]
100
  return pd.DataFrame(dataframe)
101
+
102
+
103
+ class StreamlitProgressbarCallback(TrainerCallback):
104
+ """
105
+ StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
106
+ that integrates a progress bar into a Streamlit application. This class updates
107
+ the progress bar at each training step, providing real-time feedback on the
108
+ training process within the Streamlit interface.
109
+
110
+ Attributes:
111
+ progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
112
+ bar object initialized to 0 with the text "Training".
113
+
114
+ Methods:
115
+ on_step_begin(args, state, control, **kwargs):
116
+ Updates the progress bar at the beginning of each training step. The progress
117
+ is calculated as the percentage of completed steps out of the total steps.
118
+ The progress bar text is updated to show the current step and the total steps.
119
+ """
120
+
121
+ def __init__(self, *args, **kwargs):
122
+ super().__init__(*args, **kwargs)
123
+ self.progress_bar = st.progress(0, text="Training")
124
+
125
+ def on_step_begin(
126
+ self,
127
+ args: TrainingArguments,
128
+ state: TrainerState,
129
+ control: TrainerControl,
130
+ **kwargs,
131
+ ):
132
+ super().on_step_begin(args, state, control, **kwargs)
133
+ self.progress_bar.progress(
134
+ (state.global_step * 100 // state.max_steps) + 1,
135
+ text=f"Training {state.global_step} / {state.max_steps}",
136
+ )