import math
import os
import sys
import time
from datetime import datetime as dt
from itertools import chain

import torch
from torch.utils.data import DataLoader, Dataset
from termcolor import colored
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from accelerate import Accelerator

sys.path.append("..")

from params import *
from utils.logger import get_logger
from utils.metrics import get_metric_for_tfm
from models.model import ModelWrapper
from models.sampler import RandomBatchSampler, BucketBatchSampler
from dataset.autocorrect_dataset import SpellCorrectDataset
from dataset.util import load_epoch_dataset


class Trainer():
    def __init__(self, model_wrapper: ModelWrapper, data_path, dataset_name, valid_dataset: Dataset):
        self.model_wrapper = model_wrapper
        self.model = model_wrapper.model
        self.model_name = model_wrapper.model_name
        self.data_path = data_path
        self.incorrect_file = f'{dataset_name}.train.noise'
        self.correct_file = f'{dataset_name}.train'
        self.length_file = f'{dataset_name}.length.train'

        train_dataset = load_epoch_dataset(data_path, self.correct_file,
                                           self.incorrect_file, self.length_file, 1, EPOCHS)
        train_dataset = SpellCorrectDataset(dataset=train_dataset)
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset

        if not BUCKET_SAMPLING:
            self.train_sampler = RandomBatchSampler(train_dataset, TRAIN_BATCH_SIZE)
            self.valid_sampler = RandomBatchSampler(valid_dataset, VALID_BATCH_SIZE, shuffle=False)
        else:
            self.train_sampler = BucketBatchSampler(train_dataset)
            self.valid_sampler = BucketBatchSampler(valid_dataset, shuffle=False)

        self.train_data = DataLoader(dataset=train_dataset,
                                     batch_sampler=self.train_sampler,
                                     collate_fn=model_wrapper.collator.collate,
                                     num_workers=2, pin_memory=True)
        self.valid_data = DataLoader(dataset=valid_dataset,
                                     batch_sampler=self.valid_sampler,
                                     collate_fn=model_wrapper.collator.collate,
                                     num_workers=2, pin_memory=True)

        self.total_training_steps = len(self.train_dataset) * EPOCHS
        self.checkpoint_cycle = math.ceil(
            (len(self.train_data) * EPOCHS / CHECKPOINT_FREQ) / PRINT_PER_ITER) * PRINT_PER_ITER
        self.print_every = PRINT_PER_ITER

        self.iter = 0
        self.scratch_iter = 0
        self.start_epoch = 1
        self.best_F1 = -1
        self.current_epoch = 1
        self.progress_epoch = None
        self.max_epochs = EPOCHS
        self.learning_rate = MAX_LR

        self.optimizer = AdamW(self.model.parameters(), lr=self.learning_rate,
                               weight_decay=0.01, correct_bias=False)
        self.num_warmup_steps = WARMUP_PERCENT * self.total_training_steps
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.total_training_steps)

        self.train_losses = []
        self.accelerator = Accelerator(cpu=(DEVICE == "cpu"))
        self.device = self.accelerator.device
        self.total_fw_time = 0

        log_path = LOG + \
            f'/pytorch.{self.model_name}.lr.{self.learning_rate}.train.log'
        if log_path:
            self.logger = get_logger(log_path)
        self.logger.log(f'DEVICE is: {self.device}')
        self.logger.log(
            f"============TOTAL TRAINING STEPS===========\n{self.total_training_steps}")
        self.logger.log(f"CHECKPOINT CYCLE: {self.checkpoint_cycle} ITER")

    def load_lazy_dataset(self, epoch):
        train_dataset = load_epoch_dataset(self.data_path, self.correct_file,
                                           self.incorrect_file, self.length_file, epoch, EPOCHS)
        self.train_dataset = SpellCorrectDataset(dataset=train_dataset)
        if not BUCKET_SAMPLING:
            self.train_sampler = RandomBatchSampler(self.train_dataset, TRAIN_BATCH_SIZE)
        else:
            self.train_sampler = BucketBatchSampler(self.train_dataset)
        self.train_data = DataLoader(dataset=self.train_dataset,
                                     batch_sampler=self.train_sampler,
                                     collate_fn=self.model_wrapper.collator.collate,
                                     num_workers=2, pin_memory=True)
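
    # step() processes a single batch. In training mode it runs one forward
    # pass, backpropagates through the accelerator, clips gradients to
    # max_norm=1.0, then advances the optimizer and the linear warmup/decay
    # scheduler (self.iter counts examples, the same unit used for
    # total_training_steps). In eval mode it only returns the loss,
    # predictions, label ids and sequence lengths.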
    def step(self, batch, training=True):
        if training:
            self.model.train()
            start = time.time()
            outputs = self.model(batch['batch_src'], batch['attn_masks'],
                                 batch['batch_tgt'])  # outputs['logits'], outputs['loss']
            self.total_fw_time += time.time() - start
            loss = outputs['loss']
            batch_loss = outputs['loss'].cpu().detach().numpy()

            self.optimizer.zero_grad()
            self.accelerator.backward(loss)
            # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step(self.iter)
            return batch_loss
        else:
            self.model.eval()
            outputs = self.model(batch['batch_src'], batch['attn_masks'], batch['batch_tgt'])
            return outputs['loss'], outputs['preds'], \
                batch['batch_tgt'].cpu().detach().numpy(), batch['lengths']

    def train(self):
        self.logger.log("Loading model to device")
        self.model, self.optimizer, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.scheduler)
        self.logger.log(f"Begin training from epoch: {self.start_epoch}")

        total_time = 0
        total_loss = 0
        overall_loss, overall_iter = 0, 0
        patience = 0

        for epoch_id in range(self.start_epoch, self.max_epochs + 1):
            self.current_epoch = epoch_id
            if self.progress_epoch and self.progress_epoch == epoch_id:
                # Resuming from a checkpoint: the dataset for this epoch was
                # already restored in load_checkpoint().
                self.progress_epoch = None
            elif self.current_epoch != 1:
                self.load_lazy_dataset(epoch_id)
                self.logger.log(f"Loaded lazy dataset {epoch_id} / {self.max_epochs}")

            self.logger.log(f"START OF EPOCH {epoch_id}")
            for step, batch in enumerate(self.train_data):
                start = time.time()
                self.iter += batch['batch_tgt'].size(0)
                self.scratch_iter += batch['batch_tgt'].size(0)
                overall_iter += batch['batch_tgt'].size(0)

                batch_loss = self.step(batch)
                total_time += time.time() - start
                total_loss += batch_loss
                overall_loss += batch_loss

                if step % self.print_every == 0:
                    info = '{} - epoch: {} - step: {} - iter: {:08d}/{:08d} - train loss: {:.5f} - lr: {:.5e} - {} time: {:.2f}s'.format(
                        colored(str(dt.now()), "green"), epoch_id, step, self.iter,
                        self.total_training_steps, total_loss / self.print_every,
                        self.optimizer.param_groups[0]['lr'], self.device, total_time)
                    total_loss = 0
                    total_time = 0
                    self.logger.log(info)

                if step % self.checkpoint_cycle == 0:
                    torch.cuda.empty_cache()
                    if step == 0:
                        continue
                    # <---- validate ----->
                    val_loss, val_accu, val_mean_time = self.validate()
                    info = '{} - epoch: {} - valid loss: {:.5f} - valid accuracy: {:.4f}'.format(
                        colored(str(dt.now()), "green"), epoch_id, val_loss, val_accu)
                    self.logger.log(info)

                    if overall_iter != 0 and overall_loss != 0:
                        self.logger.log(f"Overall training loss between two checkpoints: {overall_loss / overall_iter}")
                        overall_loss, overall_iter = 0, 0

                    if val_accu > self.best_F1:
                        self.best_F1 = val_accu
                        info = 'Saving weights to disk......'
                        self.logger.log(info)
                        self.save_weights(self.checkpoint_dir, epoch_id, self.best_F1)
                        info = 'Saving checkpoint to disk......'
                        self.logger.log(info)
                        self.save_checkpoint(self.checkpoint_dir, epoch_id, self.best_F1)
                        patience = 0
                    else:
                        patience += 1

                    self.logger.log("Mean forward time: {:.5f}".format(
                        self.total_fw_time / VALID_BATCH_SIZE))
                    self.total_fw_time = 0
                    if patience >= PATIENCE:
                        break
                    torch.cuda.empty_cache()

            ## Validation before next epoch
            torch.cuda.empty_cache()
            val_loss, val_accu, val_mean_time = self.validate()
            info = '{} - epoch: {} - valid loss: {:.5f} - valid accuracy: {:.4f}'.format(
                colored(str(dt.now()), "green"), epoch_id, val_loss, val_accu)
            self.logger.log(info)

            if overall_iter != 0 and overall_loss != 0:
                self.logger.log(f"Overall training loss between two checkpoints: {overall_loss / overall_iter}")
                overall_loss, overall_iter = 0, 0

            if val_accu > self.best_F1:
                self.best_F1 = val_accu
                info = 'Saving weights to disk......'
                self.logger.log(info)
                self.save_weights(self.checkpoint_dir, epoch_id, self.best_F1)
                info = 'Saving checkpoint to disk......'
                self.logger.log(info)
                self.save_checkpoint(self.checkpoint_dir, epoch_id, self.best_F1)
                patience = 0
            else:
                patience += 1

            self.logger.log("Mean forward time: {:.5f}".format(
                self.total_fw_time / VALID_BATCH_SIZE))
            self.total_fw_time = 0
            if patience >= PATIENCE:
                break
            torch.cuda.empty_cache()

            self.scratch_iter = 0
            self.logger.log(f"END OF EPOCH {epoch_id}")

        self.logger.log("Train complete!")
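
    # validate() runs the model over the whole validation DataLoader with
    # gradients disabled. Correct/incorrect counts come from
    # get_metric_for_tfm(), so the reported validation accuracy is
    # num_correct / (num_correct + num_wrong) accumulated over all batches.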
    def validate(self):
        total_loss = 0
        valid_loss = 0
        valid_time = 0
        total_time = 0
        total_examples = 0
        num_correct, num_wrong = 0, 0

        with torch.no_grad():
            for step, batch in enumerate(self.valid_data):
                start = time.time()
                total_examples += batch['batch_tgt'].size(0)
                batch_loss, batch_predictions, \
                    batch_label_ids, batch_lengths = self.step(batch, training=False)
                valid_time += time.time() - start

                batch_token_lens = batch['lengths']
                batch_label_ids = batch['batch_tgt'].cpu().detach().numpy()
                _num_correct, _num_wrong = get_metric_for_tfm(batch_predictions,
                                                              batch_label_ids,
                                                              batch_token_lens)
                num_correct += _num_correct
                num_wrong += _num_wrong
                valid_loss += batch_loss
                total_loss += batch_loss

                if step % self.print_every == 0:
                    info = '{} Validation - iter: {:08d}/{:08d} - valid loss: {:.5f} - {} time: {:.2f}s'.format(
                        colored(str(dt.now()), "green"), step, len(self.valid_data),
                        valid_loss / self.print_every, self.device,
                        valid_time / self.print_every)
                    valid_loss = 0
                    total_time += valid_time
                    valid_time = 0
                    self.logger.log(info)
                del batch_loss

        avg_loss = total_loss / len(self.valid_data)
        avg_accu = num_correct / (num_correct + num_wrong)
        avg_time = total_time / total_examples
        return avg_loss, avg_accu, avg_time

    def load_checkpoint(self, checkpoint_dir, dataset_name, start_epoch=0):
        self.checkpoint_dir = checkpoint_dir
        self.dataset_name = dataset_name
        checkpoint_path = checkpoint_dir + \
            f'/{dataset_name}.model.epoch_{start_epoch - 1}.pth'
        if start_epoch > 0 and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
            assert EPOCHS == checkpoint['num_epochs']

            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.base_lrs = [MAX_LR]
            self.scheduler.base_lrs = [MAX_LR]
            self.model.load_state_dict(checkpoint['state_dict'])

            self.iter = checkpoint['iter']
            self.remained_indies = checkpoint['remained_indies']
            self.start_epoch = checkpoint['epoch']
            self.progress_epoch = self.start_epoch
            self.scratch_iter = checkpoint['scratch_iter']

            train_dataset = load_epoch_dataset(self.data_path, self.correct_file,
                                               self.incorrect_file, self.length_file,
                                               self.start_epoch, EPOCHS)
            self.train_dataset = SpellCorrectDataset(dataset=train_dataset)
            if not BUCKET_SAMPLING:
                assert checkpoint['strategy'] == "random_sampling"
                self.train_sampler = RandomBatchSampler(self.train_dataset, TRAIN_BATCH_SIZE)
                self.train_sampler.load_checkpoints(self.scratch_iter)
            else:
                assert checkpoint['strategy'] == "bucket_sampling"
                self.train_sampler = BucketBatchSampler(self.train_dataset)
                self.train_sampler.load_checkpoints(self.remained_indies)
            self.train_data = DataLoader(dataset=self.train_dataset,
                                         batch_sampler=self.train_sampler,
                                         collate_fn=self.model_wrapper.collator.collate,
                                         num_workers=2, pin_memory=True)
            self.best_F1 = checkpoint['best_F1']
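
    # save_checkpoint() persists everything needed to resume mid-epoch: model,
    # optimizer and scheduler state, the global and per-epoch iteration
    # counters, the best validation score so far, the sampling strategy in
    # use, and the not-yet-visited sampler indices ('remained_indies').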
    def save_checkpoint(self, checkpoint_dir, epoch, best_F1):
        checkpoint_path = checkpoint_dir + \
            f'/{self.dataset_name}.model.epoch_{epoch}.pth'
        flatten_iterator_indies = list(chain.from_iterable(self.train_sampler.seq))
        remained_indies = flatten_iterator_indies[self.scratch_iter:]
        self.logger.log(f"Traversed iter from beginning: {self.scratch_iter}")

        state = {
            'epoch': epoch,
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'scratch_iter': self.scratch_iter,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'best_F1': best_F1,
            'remained_indies': remained_indies,
            'strategy': 'bucket_sampling' if BUCKET_SAMPLING else 'random_sampling',
            'num_epochs': EPOCHS
        }

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)
        info = f'Saving model checkpoint to: {checkpoint_path}'
        self.logger.log(info)
        torch.save(state, checkpoint_path)

    def save_weights(self, checkpoint_dir, epoch, best_F1):
        weight_path = checkpoint_dir + \
            f'/{self.dataset_name}.weights.pth'
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'best_F1': best_F1
        }
        info = f'Saving model weights to: {weight_path}'
        self.logger.log(info)
        torch.save(state, weight_path)
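
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original trainer).
# DATA_PATH, DATASET_NAME, CHECKPOINT_DIR and START_EPOCH are hypothetical
# stand-ins for values from params.py, the no-arg ModelWrapper() constructor
# is assumed, and valid_dataset is assumed to be a SpellCorrectDataset built
# the same way as the training split. load_checkpoint() is called before
# train() because it also sets self.checkpoint_dir / self.dataset_name, which
# save_weights() and save_checkpoint() rely on.
#
#     model_wrapper = ModelWrapper()  # assumed constructor
#     trainer = Trainer(model_wrapper, DATA_PATH, DATASET_NAME, valid_dataset)
#     trainer.load_checkpoint(CHECKPOINT_DIR, DATASET_NAME, start_epoch=START_EPOCH)
#     trainer.train()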