"""
Trainer Class
=============
"""

import collections
import json
import logging
import math
import os

import scipy
import scipy.stats
import torch
import tqdm
import transformers

import textattack
from textattack.shared.utils import logger

from .attack import Attack
from .attack_args import AttackArgs
from .attack_results import MaximizedAttackResult, SuccessfulAttackResult
from .attacker import Attacker
from .model_args import HUGGINGFACE_MODELS
from .models.helpers import LSTMForClassification, WordCNNForClassification
from .models.wrappers import ModelWrapper
from .training_args import CommandLineTrainingArgs, TrainingArgs


class Trainer:
    """Trainer is training and eval loop for adversarial training.

    It is designed to work with PyTorch and Transformers models.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            Model wrapper containing both the model and the tokenizer.
        task_type (:obj:`str`, `optional`, defaults to :obj:`"classification"`):
            The task that the model is trained to perform.
            Currently, :class:`~textattack.Trainer` supports two tasks: (1) :obj:`"classification"`, (2) :obj:`"regression"`.
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to generate adversarial examples for training.
        train_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for training.
        eval_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for evaluation
        training_args (:class:`~textattack.TrainingArgs`):
            Arguments for training.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # We only use DeepWordBugGao2018 for demonstration purposes.
        >>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
        >>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
        >>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Train for 3 epochs with 1 initial clean epochs, 1000 adversarial examples per epoch, learning rate of 5e-5, and effective batch size of 32 (8x4).
        >>> training_args = textattack.TrainingArgs(
        ...     num_epochs=3,
        ...     num_clean_epochs=1,
        ...     num_train_adv_examples=1000,
        ...     learning_rate=5e-5,
        ...     per_device_train_batch_size=8,
        ...     gradient_accumulation_steps=4,
        ...     log_to_tb=True,
        ... )

        >>> trainer = textattack.Trainer(
        ...     model_wrapper,
        ...     "classification",
        ...     attack,
        ...     train_dataset,
        ...     eval_dataset,
        ...     training_args
        ... )
        >>> trainer.train()

    .. note::
        When using :class:`~textattack.Trainer` with `parallel=True` in :class:`~textattack.TrainingArgs`,
        make sure to protect the “entry point” of the program by using :obj:`if __name__ == '__main__':`.
        If not, each worker process used for generating adversarial examples will execute the training code again.
    """

    def __init__(
        self,
        model_wrapper,
        task_type="classification",
        attack=None,
        train_dataset=None,
        eval_dataset=None,
        training_args=None,
    ):
        assert isinstance(
            model_wrapper, ModelWrapper
        ), f"`model_wrapper` must be of type `textattack.models.wrappers.ModelWrapper`, but got type `{type(model_wrapper)}`."

        # TODO: Support seq2seq training
        assert task_type in {
            "classification",
            "regression",
        }, '`task_type` must either be "classification" or "regression"'

        if attack:
            assert isinstance(
                attack, Attack
            ), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."

            # Adversarial examples are only useful for training if they were
            # crafted against the very model being trained.
            if model_wrapper is not attack.goal_function.model:
                logger.warning(
                    "`model_wrapper` and the victim model of `attack` are not the same model."
                )

        if train_dataset:
            assert isinstance(
                train_dataset, textattack.datasets.Dataset
            ), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."

        if eval_dataset:
            assert isinstance(
                eval_dataset, textattack.datasets.Dataset
            ), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."

        if training_args:
            assert isinstance(
                training_args, TrainingArgs
            ), f"`training_args` must be of type `textattack.TrainingArgs`, but got type `{type(training_args)}`."
        else:
            training_args = TrainingArgs()

        if not hasattr(model_wrapper, "model"):
            raise ValueError("Cannot detect `model` in `model_wrapper`")
        else:
            assert isinstance(
                model_wrapper.model, torch.nn.Module
            ), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."

        if not hasattr(model_wrapper, "tokenizer"):
            raise ValueError("Cannot detect `tokenizer` in `model_wrapper`")

        self.model_wrapper = model_wrapper
        self.task_type = task_type
        self.attack = attack
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.training_args = training_args

        self._metric_name = (
            "pearson_correlation" if self.task_type == "regression" else "accuracy"
        )
        # `reduction="none"` keeps one loss value per sample so that adversarial
        # samples can be re-weighted individually in `training_step`.
        if self.task_type == "regression":
            self.loss_fct = torch.nn.MSELoss(reduction="none")
        else:
            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        self._global_step = 0

    def _generate_adversarial_examples(self, epoch):
        """Generate adversarial examples using attacker.

        Args:
            epoch (:obj:`int`): Current epoch; only used to name the attack log files.

        Returns:
            :class:`~textattack.datasets.Dataset` built from the successful and
            maximized attack results. Every example carries an extra
            ``"_example_type"`` input column set to ``"adversarial_example"``.
        """
        assert (
            self.attack is not None
        ), "`attack` is `None` but attempting to generate adversarial examples."
        base_file_name = f"attack-train-{epoch}"
        log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
        logger.info("Attacking model to generate new adversarial training set...")

        if isinstance(self.training_args.num_train_adv_examples, float):
            # A float is interpreted as a fraction of the training set.
            num_train_adv_examples = math.ceil(
                len(self.train_dataset) * self.training_args.num_train_adv_examples
            )
        else:
            num_train_adv_examples = self.training_args.num_train_adv_examples

        # Arguments shared by both attack configurations below.
        shared_attack_kwargs = dict(
            num_examples_offset=0,
            query_budget=self.training_args.query_budget_train,
            shuffle=True,
            parallel=self.training_args.parallel,
            num_workers_per_device=self.training_args.attack_num_workers_per_device,
            disable_stdout=True,
            silent=True,
            log_to_txt=log_file_name + ".txt",
            log_to_csv=log_file_name + ".csv",
        )

        # Use Different AttackArgs based on num_train_adv_examples value.
        # If num_train_adv_examples >= 0 , num_train_adv_examples is
        # set as number of successful examples.
        # If num_train_adv_examples == -1 , num_examples is set to -1 to
        # generate example for all of training data.
        if num_train_adv_examples >= 0:
            attack_args = AttackArgs(
                num_successful_examples=num_train_adv_examples,
                **shared_attack_kwargs,
            )
        elif num_train_adv_examples == -1:
            # set num_examples when num_train_adv_examples = -1
            attack_args = AttackArgs(
                num_examples=num_train_adv_examples,
                **shared_attack_kwargs,
            )
        else:
            # Raised explicitly (rather than `assert False`) so the check is
            # not stripped when Python runs with `-O`.
            raise AssertionError(
                "num_train_adv_examples is negative and not equal to -1."
            )

        attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
        results = attacker.attack_dataset()

        attack_types = collections.Counter(r.__class__.__name__ for r in results)
        num_successes = attack_types["SuccessfulAttackResult"]
        total_attacks = num_successes + attack_types["FailedAttackResult"]
        # Guard against a division by zero when no attack succeeded or failed.
        success_rate = num_successes / total_attacks * 100 if total_attacks else 0.0
        logger.info(f"Total number of attack results: {len(results)}")
        logger.info(
            f"Attack success rate: {success_rate:.2f}% [{num_successes} / {total_attacks}]"
        )
        # TODO: This will produce a bug if we need to manipulate ground truth output.

        # To Fix Issue #498 , We need to add the Non Output columns in one tuple to represent input columns
        # Since adversarial_example won't be an input to the model , we will have to remove it from the input
        # dictionary in collate_fn
        adversarial_examples = [
            (
                tuple(r.perturbed_result.attacked_text._text_input.values())
                + ("adversarial_example",),
                r.perturbed_result.ground_truth_output,
            )
            for r in results
            if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
        ]

        # Name for column indicating if an example is adversarial is set as "_example_type".
        adversarial_dataset = textattack.datasets.Dataset(
            adversarial_examples,
            input_columns=self.train_dataset.input_columns + ("_example_type",),
            label_map=self.train_dataset.label_map,
            label_names=self.train_dataset.label_names,
            output_scale_factor=self.train_dataset.output_scale_factor,
            shuffle=False,
        )
        return adversarial_dataset

    def _print_training_args(
        self, total_training_steps, train_batch_size, num_clean_epochs
    ):
        """Log a summary of the training configuration.

        Args:
            total_training_steps (:obj:`int`): Number of optimization steps planned.
            train_batch_size (:obj:`int`): Batch size per forward pass across all devices.
            num_clean_epochs (:obj:`int`): Number of epochs trained without adversarial examples.
        """
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(self.train_dataset)}")
        logger.info(f" Num epochs = {self.training_args.num_epochs}")
        logger.info(f" Num clean epochs = {num_clean_epochs}")
        logger.info(
            f" Instantaneous batch size per device = {self.training_args.per_device_train_batch_size}"
        )
        logger.info(
            f" Total train batch size (w. parallel, distributed & accumulation) = {train_batch_size * self.training_args.gradient_accumulation_steps}"
        )
        logger.info(
            f" Gradient accumulation steps = {self.training_args.gradient_accumulation_steps}"
        )
        logger.info(f" Total optimization steps = {total_training_steps}")

    def _save_model_checkpoint(
        self, model, tokenizer, step=None, epoch=None, best=False, last=False
    ):
        """Save ``model`` (and its tokenizer when applicable) under ``output_dir``.

        The target directory name is chosen from the given flags; when several
        are set the later one wins: ``step`` < ``epoch`` < ``best`` < ``last``.

        Args:
            model (:obj:`torch.nn.Module`): Model to save; a
                :obj:`torch.nn.DataParallel` wrapper is unwrapped first.
            tokenizer: Tokenizer, saved alongside :obj:`transformers.PreTrainedModel` models.
            step (:obj:`int`, `optional`): Global step for ``checkpoint-step-<step>``.
            epoch (:obj:`int`, `optional`): Epoch for ``checkpoint-epoch-<epoch>``.
            best (:obj:`bool`, `optional`): Save to ``best_model``.
            last (:obj:`bool`, `optional`): Save to ``last_model``.

        Raises:
            ValueError: If none of ``step``, ``epoch``, ``best``, ``last`` is set.
        """
        dir_name = None
        if step:
            dir_name = f"checkpoint-step-{step}"
        if epoch:
            dir_name = f"checkpoint-epoch-{epoch}"
        if best:
            dir_name = "best_model"
        if last:
            dir_name = "last_model"
        if dir_name is None:
            raise ValueError(
                "One of `step`, `epoch`, `best` or `last` must be set to save a checkpoint."
            )

        output_dir = os.path.join(self.training_args.output_dir, dir_name)
        os.makedirs(output_dir, exist_ok=True)

        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if isinstance(model, (WordCNNForClassification, LSTMForClassification)):
            model.save_pretrained(output_dir)
        elif isinstance(model, transformers.PreTrainedModel):
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
        else:
            # Generic PyTorch module: store a CPU copy of the weights.
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save(
                state_dict,
                os.path.join(output_dir, "pytorch_model.bin"),
            )

    def _tb_log(self, log, step):
        """Write every scalar in ``log`` to TensorBoard at ``step``.

        The ``SummaryWriter`` is created lazily on the first call.
        """
        if not hasattr(self, "_tb_writer"):
            from torch.utils.tensorboard import SummaryWriter

            self._tb_writer = SummaryWriter(self.training_args.tb_log_dir)
            self._tb_writer.add_hparams(self.training_args.__dict__, {})
            self._tb_writer.flush()

        for key in log:
            self._tb_writer.add_scalar(key, log[key], step)

    def _wandb_log(self, log, step):
        """Send ``log`` to Weights & Biases at ``step``.

        ``wandb`` is imported and initialised lazily on the first call.
        """
        if not hasattr(self, "_wandb_init"):
            global wandb
            import wandb

            self._wandb_init = True
            wandb.init(
                project=self.training_args.wandb_project,
                config=self.training_args.__dict__,
            )

        wandb.log(log, step=step)

    @staticmethod
    def _forward_logits(model, tokenizer, input_texts):
        """Tokenize ``input_texts`` and return the logits produced by ``model``.

        Handles both :obj:`transformers.PreTrainedModel` models (optionally
        wrapped in :obj:`torch.nn.DataParallel`) and plain PyTorch models whose
        tokenizer returns ids directly.
        """
        device = textattack.shared.utils.device
        base_model = model.module if isinstance(model, torch.nn.DataParallel) else model

        if isinstance(base_model, transformers.PreTrainedModel):
            input_ids = tokenizer(
                input_texts,
                padding="max_length",
                return_tensors="pt",
                truncation=True,
            )
            # `.to` is called for its in-place effect on the encoded batch.
            input_ids.to(device)
            return model(**input_ids)[0]

        input_ids = tokenizer(input_texts)
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids)
        input_ids = input_ids.to(device)
        return model(input_ids)

    @staticmethod
    def _pearson_correlation(preds, targets):
        """Return ``scipy.stats.pearsonr`` of flattened ``preds`` and ``targets``.

        Both tensors are detached and flattened first: regression logits may
        carry a trailing singleton dimension (the loss computation squeezes
        them for the same reason), and NumPy conversion fails on tensors that
        require grad.
        """
        flat_preds = preds.detach().reshape(-1).cpu().numpy()
        flat_targets = targets.detach().reshape(-1).cpu().numpy()
        return scipy.stats.pearsonr(flat_preds, flat_targets)

    def get_optimizer_and_scheduler(self, model, num_training_steps):
        """Returns optimizer and scheduler to use for training. If you are
        overriding this method and do not want to use a scheduler, simply
        return :obj:`None` for scheduler.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to be trained. Pass its parameters to optimizer for training.
            num_training_steps (:obj:`int`):
                Number of total training steps.
        Returns:
            Tuple of optimizer and scheduler :obj:`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`
        """
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if isinstance(model, transformers.PreTrainedModel):
            # Reference https://huggingface.co/transformers/training.html
            param_optimizer = list(model.named_parameters())
            # Biases and LayerNorm parameters are excluded from weight decay.
            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": self.training_args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            # Fall back to the PyTorch implementation if the installed
            # `transformers` does not expose `AdamW`.
            adamw_cls = getattr(transformers.optimization, "AdamW", torch.optim.AdamW)
            optimizer = adamw_cls(
                optimizer_grouped_parameters, lr=self.training_args.learning_rate
            )
            if isinstance(self.training_args.num_warmup_steps, float):
                # A float is interpreted as a fraction of the total steps.
                num_warmup_steps = math.ceil(
                    self.training_args.num_warmup_steps * num_training_steps
                )
            else:
                num_warmup_steps = self.training_args.num_warmup_steps

            scheduler = transformers.optimization.get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
        else:
            optimizer = torch.optim.Adam(
                filter(lambda x: x.requires_grad, model.parameters()),
                lr=self.training_args.learning_rate,
            )
            scheduler = None

        return optimizer, scheduler

    def get_train_dataloader(self, dataset, adv_dataset, batch_size):
        """Returns the :obj:`torch.utils.data.DataLoader` for training.

        Args:
            dataset (:class:`~textattack.datasets.Dataset`):
                Original training dataset.
            adv_dataset (:class:`~textattack.datasets.Dataset`):
                Adversarial examples generated from the original training dataset. :obj:`None` if no adversarial attack takes place.
            batch_size (:obj:`int`):
                Batch size for training.
        Returns:
            :obj:`torch.utils.data.DataLoader`
        """

        # TODO: Add pairing option where we can pair original examples with adversarial examples.
        # Helper functions for collating data
        def collate_fn(data):
            input_texts = []
            targets = []
            is_adv_sample = []
            for item in data:
                if "_example_type" in item[0].keys():
                    # Get example type value from OrderedDict and remove it
                    adv = item[0].pop("_example_type")

                    # with _example_type removed from item[0] OrderedDict
                    # all other keys should be part of input
                    _input, label = item
                    if adv != "adversarial_example":
                        raise ValueError(
                            "`item` has length of 3 but last element is not for marking if the item is an `adversarial example`."
                        )
                    else:
                        is_adv_sample.append(True)
                else:
                    # else `len(item)` is 2.
                    _input, label = item
                    is_adv_sample.append(False)

                if isinstance(_input, collections.OrderedDict):
                    _input = tuple(_input.values())
                else:
                    _input = tuple(_input)

                # Single-column inputs are passed as a bare value, not a 1-tuple.
                if len(_input) == 1:
                    _input = _input[0]
                input_texts.append(_input)
                targets.append(label)

            return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)

        if adv_dataset:
            dataset = torch.utils.data.ConcatDataset([dataset, adv_dataset])

        train_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return train_dataloader

    def get_eval_dataloader(self, dataset, batch_size):
        """Returns the :obj:`torch.utils.data.DataLoader` for evaluation.

        Args:
            dataset (:class:`~textattack.datasets.Dataset`):
                Dataset to use for evaluation.
            batch_size (:obj:`int`):
                Batch size for evaluation.
        Returns:
            :obj:`torch.utils.data.DataLoader`
        """

        # Helper functions for collating data
        def collate_fn(data):
            input_texts = []
            targets = []
            for _input, label in data:
                if isinstance(_input, collections.OrderedDict):
                    _input = tuple(_input.values())
                else:
                    _input = tuple(_input)

                # Single-column inputs are passed as a bare value, not a 1-tuple.
                if len(_input) == 1:
                    _input = _input[0]
                input_texts.append(_input)
                targets.append(label)
            return input_texts, torch.tensor(targets)

        eval_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return eval_dataloader

    def training_step(self, model, tokenizer, batch):
        """Perform a single training step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to train.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
                By default, this will be a tuple of input texts, targets, and boolean tensor indicating if the sample is an adversarial example.

                .. note::
                    If you override the :meth:`get_train_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor, torch.Tensor]` where

            - **loss**: :obj:`torch.FloatTensor` of shape 1 containing the loss.
            - **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
        """
        input_texts, targets, is_adv_sample = batch
        # Keep the original (non-moved) targets to return to the caller.
        _targets = targets
        targets = targets.to(textattack.shared.utils.device)

        logits = self._forward_logits(model, tokenizer, input_texts)

        if self.task_type == "regression":
            loss = self.loss_fct(logits.squeeze(), targets.squeeze())
            preds = logits
        else:
            loss = self.loss_fct(logits, targets)
            preds = logits.argmax(dim=-1)

        # Adversarial samples are weighted by `alpha`; clean samples by 1.
        sample_weights = torch.ones(
            is_adv_sample.size(), device=textattack.shared.utils.device
        )
        sample_weights[is_adv_sample] *= self.training_args.alpha
        loss = loss * sample_weights
        loss = torch.mean(loss)
        # Detach so that collected predictions do not keep the autograd graph
        # alive for the whole epoch and can be converted to NumPy for metrics.
        preds = preds.detach().cpu()

        return loss, preds, _targets

    def evaluate_step(self, model, tokenizer, batch):
        """Perform a single evaluation step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to train.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor]`):
                By default, this will be a tuple of input texts and target tensors.

                .. note::
                    If you override the :meth:`get_eval_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor]` where

            - **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
        """
        input_texts, targets = batch
        _targets = targets

        logits = self._forward_logits(model, tokenizer, input_texts)

        if self.task_type == "regression":
            preds = logits
        else:
            preds = logits.argmax(dim=-1)

        return preds.cpu(), _targets

    def train(self):
        """Train the model on given training dataset.

        Runs ``num_epochs`` epochs (clean epochs first, then adversarial ones
        when an attack is configured), evaluates after every epoch, saves
        checkpoints as configured in the training arguments, stores the final
        model back on ``self.model_wrapper`` and writes a README.

        Raises:
            ValueError: If no ``train_dataset`` was provided.
        """
        if not self.train_dataset:
            raise ValueError("No `train_dataset` available for training.")

        textattack.shared.utils.set_seed(self.training_args.random_seed)
        os.makedirs(self.training_args.output_dir, exist_ok=True)

        # Save logger writes to file
        log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
        fh = logging.FileHandler(log_txt_path)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)
        logger.info(f"Writing logs to {log_txt_path}.")

        # Save original self.training_args to file
        args_save_path = os.path.join(
            self.training_args.output_dir, "training_args.json"
        )
        with open(args_save_path, "w", encoding="utf-8") as f:
            json.dump(self.training_args.__dict__, f)
        logger.info(f"Wrote original training args to {args_save_path}.")

        num_gpus = torch.cuda.device_count()
        tokenizer = self.model_wrapper.tokenizer
        model = self.model_wrapper.model

        if self.training_args.parallel and num_gpus > 1:
            # TODO: torch.nn.parallel.DistributedDataParallel
            # Supposedly faster than DataParallel, but requires more work to setup properly.
            model = torch.nn.DataParallel(model)
            logger.info(f"Training on {num_gpus} GPUs via `torch.nn.DataParallel`.")
            train_batch_size = self.training_args.per_device_train_batch_size * num_gpus
        else:
            train_batch_size = self.training_args.per_device_train_batch_size

        # Without an attack every epoch is a clean epoch.
        if self.attack is None:
            num_clean_epochs = self.training_args.num_epochs
        else:
            num_clean_epochs = self.training_args.num_clean_epochs

        total_clean_training_steps = (
            math.ceil(
                len(self.train_dataset)
                / (train_batch_size * self.training_args.gradient_accumulation_steps)
            )
            * num_clean_epochs
        )

        # calculate total_adv_training_data_length based on type of
        # num_train_adv_examples.
        # if num_train_adv_examples is float , num_train_adv_examples is a portion of train_dataset.
        if isinstance(self.training_args.num_train_adv_examples, float):
            total_adv_training_data_length = (
                len(self.train_dataset) * self.training_args.num_train_adv_examples
            )

        # if num_train_adv_examples is int and >=0 then it is taken as value.
        elif (
            isinstance(self.training_args.num_train_adv_examples, int)
            and self.training_args.num_train_adv_examples >= 0
        ):
            total_adv_training_data_length = self.training_args.num_train_adv_examples

        # if num_train_adv_examples is = -1 , we generate all possible adv examples.
        # Max number of all possible adv examples would be equal to train_dataset.
        else:
            total_adv_training_data_length = len(self.train_dataset)

        # Based on total_adv_training_data_length calculation , find total total_adv_training_steps
        total_adv_training_steps = math.ceil(
            (len(self.train_dataset) + total_adv_training_data_length)
            / (train_batch_size * self.training_args.gradient_accumulation_steps)
        ) * (self.training_args.num_epochs - num_clean_epochs)

        total_training_steps = total_clean_training_steps + total_adv_training_steps

        optimizer, scheduler = self.get_optimizer_and_scheduler(
            model, total_training_steps
        )

        self._print_training_args(
            total_training_steps, train_batch_size, num_clean_epochs
        )

        model.to(textattack.shared.utils.device)

        # Variables across epochs
        self._total_loss = 0.0
        self._current_loss = 0.0
        self._last_log_step = 0

        # `best_score` is used to keep track of the best model across training.
        # Could be loss, accuracy, or other metrics.
        best_eval_score = 0.0
        best_eval_score_epoch = 0
        best_model_path = None
        epochs_since_best_eval_score = 0

        for epoch in range(1, self.training_args.num_epochs + 1):
            logger.info("==========================================================")
            logger.info(f"Epoch {epoch}")

            if self.attack and epoch > num_clean_epochs:
                if (
                    epoch - num_clean_epochs - 1
                ) % self.training_args.attack_epoch_interval == 0:
                    # only generate a new adversarial training set every self.training_args.attack_period epochs after the clean epochs
                    # adv_dataset is instance of `textattack.datasets.Dataset`
                    model.eval()
                    adv_dataset = self._generate_adversarial_examples(epoch)
                    model.train()
                    model.to(textattack.shared.utils.device)
                else:
                    adv_dataset = None
            else:
                logger.info(f"Running clean epoch {epoch}/{num_clean_epochs}")
                adv_dataset = None

            train_dataloader = self.get_train_dataloader(
                self.train_dataset, adv_dataset, train_batch_size
            )
            model.train()
            # Epoch variables
            all_preds = []
            all_targets = []
            prog_bar = tqdm.tqdm(
                train_dataloader,
                desc="Iteration",
                position=0,
                leave=True,
                dynamic_ncols=True,
            )
            for step, batch in enumerate(prog_bar):
                loss, preds, targets = self.training_step(model, tokenizer, batch)

                if isinstance(model, torch.nn.DataParallel):
                    loss = loss.mean()

                # Scale so that accumulated gradients average over the
                # accumulation window instead of summing.
                loss = loss / self.training_args.gradient_accumulation_steps
                loss.backward()
                loss = loss.item()
                self._total_loss += loss
                self._current_loss += loss

                all_preds.append(preds)
                all_targets.append(targets)

                # NOTE(review): gradients of a trailing partial accumulation
                # window are not applied at the end of the epoch and carry
                # over into the next one -- confirm whether this is intended.
                if (step + 1) % self.training_args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    if scheduler:
                        scheduler.step()
                    optimizer.zero_grad()
                    self._global_step += 1

                if self._global_step > 0:
                    prog_bar.set_description(
                        f"Loss {self._total_loss/self._global_step:.5f}"
                    )

                # TODO: Better way to handle TB and Wandb logging
                if (self._global_step > 0) and (
                    self._global_step % self.training_args.logging_interval_step == 0
                ):
                    lr_to_log = (
                        scheduler.get_last_lr()[0]
                        if scheduler
                        else self.training_args.learning_rate
                    )
                    if self._global_step - self._last_log_step >= 1:
                        loss_to_log = round(
                            self._current_loss
                            / (self._global_step - self._last_log_step),
                            4,
                        )
                    else:
                        loss_to_log = round(self._current_loss, 4)

                    log = {"train/loss": loss_to_log, "train/learning_rate": lr_to_log}
                    if self.training_args.log_to_tb:
                        self._tb_log(log, self._global_step)

                    if self.training_args.log_to_wandb:
                        self._wandb_log(log, self._global_step)

                    self._current_loss = 0.0
                    self._last_log_step = self._global_step

                # Save model checkpoint to file.
                if self.training_args.checkpoint_interval_steps:
                    if (
                        self._global_step > 0
                        and (
                            self._global_step
                            % self.training_args.checkpoint_interval_steps
                        )
                        == 0
                    ):
                        self._save_model_checkpoint(
                            model, tokenizer, step=self._global_step
                        )

            preds = torch.cat(all_preds)
            targets = torch.cat(all_targets)
            if self._metric_name == "accuracy":
                correct_predictions = (preds == targets).sum().item()
                accuracy = correct_predictions / len(targets)
                metric_log = {"train/train_accuracy": accuracy}
                logger.info(f"Train accuracy: {accuracy*100:.2f}%")
            else:
                pearson_correlation, pearson_pvalue = self._pearson_correlation(
                    preds, targets
                )
                metric_log = {
                    "train/pearson_correlation": pearson_correlation,
                    "train/pearson_pvalue": pearson_pvalue,
                }
                logger.info(f"Train Pearson correlation: {pearson_correlation:.4f}%")

            if len(targets) > 0:
                if self.training_args.log_to_tb:
                    self._tb_log(metric_log, epoch)
                if self.training_args.log_to_wandb:
                    metric_log["epoch"] = epoch
                    self._wandb_log(metric_log, self._global_step)

            # Evaluate after each epoch.
            eval_score = self.evaluate()

            if self.training_args.log_to_tb:
                self._tb_log({f"eval/{self._metric_name}": eval_score}, epoch)
            if self.training_args.log_to_wandb:
                self._wandb_log(
                    {f"eval/{self._metric_name}": eval_score, "epoch": epoch},
                    self._global_step,
                )

            if (
                self.training_args.checkpoint_interval_epochs
                and (epoch % self.training_args.checkpoint_interval_epochs) == 0
            ):
                self._save_model_checkpoint(model, tokenizer, epoch=epoch)

            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_eval_score_epoch = epoch
                epochs_since_best_eval_score = 0
                self._save_model_checkpoint(model, tokenizer, best=True)
                logger.info(
                    f"Best score found. Saved model to {self.training_args.output_dir}/best_model/"
                )
            else:
                epochs_since_best_eval_score += 1
                if self.training_args.early_stopping_epochs and (
                    epochs_since_best_eval_score
                    > self.training_args.early_stopping_epochs
                ):
                    logger.info(
                        f"Stopping early since it's been {self.training_args.early_stopping_epochs} steps since validation score increased."
                    )
                    break

        # The writer only exists once something has been logged to TensorBoard.
        if self.training_args.log_to_tb and hasattr(self, "_tb_writer"):
            self._tb_writer.flush()

        # Finish training
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if self.training_args.load_best_model_at_end:
            best_model_path = os.path.join(self.training_args.output_dir, "best_model")
            if hasattr(model, "from_pretrained"):
                model = model.__class__.from_pretrained(best_model_path)
            else:
                # `load_state_dict` updates the module in place; its return
                # value is a key report and must not replace `model`.
                model.load_state_dict(
                    torch.load(os.path.join(best_model_path, "pytorch_model.bin"))
                )

        if self.training_args.save_last:
            self._save_model_checkpoint(model, tokenizer, last=True)

        self.model_wrapper.model = model
        self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)

    def evaluate(self):
        """Evaluate the model on given evaluation dataset.

        Returns:
            The evaluation score: accuracy for classification, Pearson
            correlation for regression.

        Raises:
            ValueError: If no ``eval_dataset`` was provided.
        """
        if not self.eval_dataset:
            raise ValueError("No `eval_dataset` available for training.")

        logger.info("Evaluating model on evaluation dataset.")
        model = self.model_wrapper.model
        tokenizer = self.model_wrapper.tokenizer

        model.eval()
        all_preds = []
        all_targets = []

        if isinstance(model, torch.nn.DataParallel):
            num_gpus = torch.cuda.device_count()
            eval_batch_size = self.training_args.per_device_eval_batch_size * num_gpus
        else:
            eval_batch_size = self.training_args.per_device_eval_batch_size

        eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)

        with torch.no_grad():
            for batch in eval_dataloader:
                preds, targets = self.evaluate_step(model, tokenizer, batch)
                all_preds.append(preds)
                all_targets.append(targets)

        preds = torch.cat(all_preds)
        targets = torch.cat(all_targets)

        if self.task_type == "regression":
            pearson_correlation, _pearson_p_value = self._pearson_correlation(
                preds, targets
            )
            eval_score = pearson_correlation
        else:
            correct_predictions = (preds == targets).sum().item()
            accuracy = correct_predictions / len(targets)
            eval_score = accuracy

        if self._metric_name == "accuracy":
            logger.info(f"Eval {self._metric_name}: {eval_score*100:.2f}%")
        else:
            logger.info(f"Eval {self._metric_name}: {eval_score:.4f}%")

        return eval_score

    def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
        """Write a short model card to ``README.md`` in the output directory.

        Args:
            best_eval_score: Best evaluation score reached during training.
            best_eval_score_epoch (:obj:`int`): Epoch at which that score was reached.
            train_batch_size (:obj:`int`): Batch size used for training.
        """
        if isinstance(self.training_args, CommandLineTrainingArgs):
            model_name = self.training_args.model_name_or_path
        elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
            if (
                hasattr(self.model_wrapper.model.config, "_name_or_path")
                and self.model_wrapper.model.config._name_or_path in HUGGINGFACE_MODELS
            ):
                # TODO Better way than just checking HUGGINGFACE_MODELS ?
                model_name = self.model_wrapper.model.config._name_or_path
            elif hasattr(self.model_wrapper.model.config, "model_type"):
                model_name = self.model_wrapper.model.config.model_type
            else:
                model_name = ""
        else:
            model_name = ""

        # Avoid a doubled space in the sentence when the name is unknown.
        model_desc = f"`{model_name}` model" if model_name else "model"

        if (
            isinstance(self.training_args, CommandLineTrainingArgs)
            and self.training_args.model_max_length
        ):
            model_max_length = self.training_args.model_max_length
        elif isinstance(
            self.model_wrapper.model,
            (
                transformers.PreTrainedModel,
                LSTMForClassification,
                WordCNNForClassification,
            ),
        ):
            model_max_length = self.model_wrapper.tokenizer.model_max_length
        else:
            model_max_length = None

        if model_max_length:
            model_max_length_str = f" a maximum sequence length of {model_max_length},"
        else:
            model_max_length_str = ""

        # Prefer the training dataset's name; fall back to the eval dataset.
        dataset_name = None
        for dataset in (self.train_dataset, self.eval_dataset):
            if isinstance(dataset, textattack.datasets.HuggingFaceDataset) and hasattr(
                dataset, "_name"
            ):
                dataset_name = dataset._name
                subset = getattr(dataset, "_subset", None)
                # Only mention the subset when one is actually set.
                if subset:
                    dataset_name += f" ({subset})"
                break

        if dataset_name:
            dataset_str = (
                f" and the `{dataset_name}` dataset loaded using the `datasets` library"
            )
        else:
            dataset_str = ""

        loss_func = (
            "mean squared error" if self.task_type == "regression" else "cross-entropy"
        )
        metric_name = (
            "pearson correlation" if self.task_type == "regression" else "accuracy"
        )
        epoch_info = f"{best_eval_score_epoch} epoch" + (
            "s" if best_eval_score_epoch > 1 else ""
        )
        # Built line by line (no indented triple-quoted block) so that the
        # file does not contain leading whitespace that Markdown would render
        # as a code block.
        readme_text = (
            "## TextAttack Model Card\n"
            "\n"
            f"This {model_desc} was fine-tuned using TextAttack{dataset_str}. "
            f"The model was fine-tuned for {self.training_args.num_epochs} epochs "
            f"with a batch size of {train_batch_size},{model_max_length_str} "
            f"and an initial learning rate of {self.training_args.learning_rate}. "
            f"Since this was a {self.task_type} task, the model was trained with "
            f"a {loss_func} loss function. "
            f"The best score the model achieved on this task was {best_eval_score}, "
            f"as measured by the eval set {metric_name}, found after {epoch_info}.\n"
            "\n"
            "For more information, check out "
            "[TextAttack on Github](https://github.com/QData/TextAttack).\n"
        )

        readme_save_path = os.path.join(self.training_args.output_dir, "README.md")
        with open(readme_save_path, "w", encoding="utf-8") as f:
            f.write(readme_text.strip() + "\n")
        logger.info(f"Wrote README to {readme_save_path}.")