AItool's picture
Upload 127 files
a983ebc
raw
history blame
29.9 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/13a_learner.ipynb.
# %% ../nbs/13a_learner.ipynb 2
from __future__ import annotations
from .data.all import *
from .optimizer import *
from .callback.core import *
import pickle,threading
from collections.abc import MutableSequence
# %% auto 0
# Names exported by `from <this module> import *`.
__all__ = ['replacing_yield', 'mk_metric', 'save_model', 'load_model', 'SkipToEpoch', 'Learner', 'before_batch_cb',
'load_learner', 'Metric', 'AvgMetric', 'AvgLoss', 'AvgSmoothLoss', 'ValueMetric', 'Recorder', 'CastToTensor',
'CancelBackwardException', 'CancelStepException', 'CancelFitException', 'CancelEpochException',
'CancelTrainException', 'CancelValidException', 'CancelBatchException']
# %% ../nbs/13a_learner.ipynb 4
# NOTE(review): single-underscore `_all_` is presumably read by the notebook export tool to add the
# Cancel*Exception names (defined in another module) to `__all__` above -- confirm against the tool's docs.
_all_ = ['CancelBackwardException', 'CancelStepException','CancelFitException','CancelEpochException',
'CancelTrainException','CancelValidException','CancelBatchException']
# %% ../nbs/13a_learner.ipynb 10
# Library-wide default learning rate; `Learner.__init__` uses it as the default for its `lr` parameter.
defaults.lr = 1e-3
# %% ../nbs/13a_learner.ipynb 11
def replacing_yield(o, attr, val):
    "Generator that sets `o.attr` to `val`, yields once, then restores the previous value"
    previous = getattr(o, attr)
    try:
        # `setattr` returns None, so None is what gets yielded to the caller.
        result = setattr(o, attr, val)
        yield result
    finally:
        # Always put the original value back, even if the caller raised.
        setattr(o, attr, previous)
# %% ../nbs/13a_learner.ipynb 13
def mk_metric(m):
    "Convert `m` to an `AvgMetric`, unless it's already a `Metric`"
    # A class is instantiated with no arguments before the check.
    metric = m() if isinstance(m, type) else m
    if isinstance(metric, Metric):
        return metric
    return AvgMetric(metric)
# %% ../nbs/13a_learner.ipynb 15
def save_model(file, model, opt, with_opt=True, pickle_protocol=2, **torch_save_kwargs):
    "Save `model` to `file` along with `opt` (if available, and if `with_opt`)"
    if rank_distrib():
        return  # don't save if child proc
    model_state = get_model(model).state_dict()
    # The optimizer state is only written when requested *and* an optimizer exists.
    include_opt = with_opt and opt is not None
    if include_opt:
        payload = {'model': model_state, 'opt': opt.state_dict()}
    else:
        payload = model_state
    torch.save(payload, file, pickle_protocol=pickle_protocol, **torch_save_kwargs)
# %% ../nbs/13a_learner.ipynb 17
def load_model(file, model, opt, with_opt=True, device=None, strict=True, **torch_load_kwargs):
    "Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
    # An int `device` is treated as a CUDA device index; `None` means load onto the CPU.
    if isinstance(device, int): device = torch.device('cuda', device)
    elif device is None: device = 'cpu'
    # NOTE: `torch.load` may unpickle arbitrary objects -- only load checkpoints from trusted sources.
    state = torch.load(file, map_location=device, **torch_load_kwargs)
    # A file written with the optimizer is a dict with exactly the keys 'model' and 'opt';
    # anything else is treated as a bare model state dict.
    hasopt = set(state)=={'model', 'opt'}
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    if not with_opt: return
    if not hasopt:
        warn("Saved file doesn't contain an optimizer state.")
        return
    # Best effort: a mismatching optimizer state only produces a warning.
    # `Exception` (not a bare `except`) so KeyboardInterrupt/SystemExit still propagate.
    try: opt.load_state_dict(state['opt'])
    except Exception: warn("Could not load the optimizer state.")
# %% ../nbs/13a_learner.ipynb 19
def _try_concat(o):
try: return torch.cat(o)
except: return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())
# %% ../nbs/13a_learner.ipynb 20
# Event lists fired when a `Learner` is used as a context manager:
# `Learner.__enter__` calls `_before_epoch`, `Learner.__exit__` calls `_after_epoch`.
_before_epoch = [event.before_fit, event.before_epoch]
_after_epoch = [event.after_epoch, event.after_fit]
# %% ../nbs/13a_learner.ipynb 21
class _ConstantFunc():
"Returns a function that returns `o`"
def __init__(self, o): self.o = o
def __call__(self, *args, **kwargs): return self.o
# %% ../nbs/13a_learner.ipynb 22
class SkipToEpoch(Callback):
    "Skip training up to `epoch`"
    order = 70

    def __init__(self, epoch:int):
        # First epoch index that is allowed to run.
        self._skip_to = epoch

    def before_epoch(self):
        # Cancel every epoch whose index is below the target.
        too_early = self.epoch < self._skip_to
        if too_early:
            raise CancelEpochException
# %% ../nbs/13a_learner.ipynb 24
# Ordered script printed by `Learner.show_training_loop`: entries starting with 'Start'/'End'
# open/close an indentation level, every other entry is passed to `Learner.ordered_cbs` as an event name.
_loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', 'before_train',
'Start Batch Loop', 'before_batch', 'after_pred', 'after_loss', 'before_backward', 'before_step',
'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',
'after_cancel_train', 'after_train', 'Start Valid', 'before_validate','Start Batch Loop',
'**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',
'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',
'after_cancel_fit', 'after_fit']
# %% ../nbs/13a_learner.ipynb 25
# Holds `dls`, `model`, `loss_func`, optimizer settings and a list of callbacks, and drives the
# fit / validate / predict loops by firing named events on those callbacks.
# (The class and public-method docstrings are attached separately through `add_docs`.)
class Learner(GetAttr):
    # NOTE(review): presumably tells `GetAttr` to delegate unknown attribute lookups to `self.model` -- confirm in fastcore.
    _default='model'
    def __init__(self,
        dls:DataLoaders, # `DataLoaders` containing fastai or PyTorch `DataLoader`s
        model:callable, # PyTorch model for training or inference
        loss_func:callable|None=None, # Loss function. Defaults to `dls` loss
        opt_func:Optimizer|OptimWrapper=Adam, # Optimization function for training
        lr:float|slice=defaults.lr, # Default learning rate
        splitter:callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
        cbs:Callback|MutableSequence|None=None, # `Callback`s to add to `Learner`
        metrics:callable|MutableSequence|None=None, # `Metric`s to calculate on validation set
        path:str|Path|None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
        model_dir:str|Path='models', # Subdirectory to save and load models
        wd:float|int|None=None, # Default weight decay
        wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
        train_bn:bool=True, # Train frozen normalization layers
        moms:tuple=(0.95,0.85,0.95), # Default momentum for schedulers
        default_cbs:bool=True # Include default `Callback`s
    ):
        # An explicit `path` wins; otherwise use `dls.path`, falling back to the current directory.
        path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
        if loss_func is None:
            # Fall back to the `loss_func` attribute of the training dataset, if it has one.
            loss_func = getattr(dls.train_ds, 'loss_func', None)
            assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
        self.dls,self.model = dls,model
        # NOTE(review): presumably stores the remaining arguments (and the rebound `path`/`loss_func`) as attributes -- confirm in fastcore.
        store_attr(but='dls,model,cbs')
        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
        if default_cbs: self.add_cbs(L(defaults.callbacks))
        self.add_cbs(cbs)
        # Entered as one of the context managers built by `validation_context`.
        self.lock = threading.Lock()
        self("after_create")
    @property
    def metrics(self): return self._metrics
    @metrics.setter
    # Every assigned metric is normalised through `mk_metric`.
    def metrics(self,v): self._metrics = L(v).map(mk_metric)
    def _grab_cbs(self, cb_cls): return L(cb for cb in self.cbs if isinstance(cb, cb_cls))
    def add_cbs(self, cbs):
        L(cbs).map(self.add_cb)
        return self
    def remove_cbs(self, cbs):
        L(cbs).map(self.remove_cb)
        return self
    def add_cb(self, cb):
        # A callback class is instantiated with no arguments.
        if isinstance(cb, type): cb = cb()
        cb.learn = self
        # The callback is also reachable as `self.<cb.name>`.
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self
    def remove_cb(self, cb):
        # Given a class, remove every registered instance of it.
        if isinstance(cb, type): self.remove_cbs(self._grab_cbs(cb))
        else:
            cb.learn = None
            if hasattr(self, cb.name): delattr(self, cb.name)
            if cb in self.cbs: self.cbs.remove(cb)
        return self
    @contextmanager
    def added_cbs(self, cbs):
        self.add_cbs(cbs)
        try: yield
        finally: self.remove_cbs(cbs)
    @contextmanager
    def removed_cbs(self, cbs):
        self.remove_cbs(cbs)
        try: yield self
        finally: self.add_cbs(cbs)
    def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
    # `event_name` may be a single name or a collection of names; each is dispatched in turn.
    def __call__(self, event_name): L(event_name).map(self._call_one)
    def _call_one(self, event_name):
        # Only names that exist as attributes of `event` are accepted.
        if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
        for cb in self.cbs.sorted('order'): cb(event_name)
    def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)
    def create_opt(self):
        # A `partial` opt_func that fixes `lr` overrides `self.lr`.
        if isinstance(self.opt_func, partial):
            if 'lr' in self.opt_func.keywords:
                self.lr = self.opt_func.keywords['lr']
        # An `OptimWrapper` instance is used directly (after clearing its state) instead of being called.
        if isinstance(self.opt_func, OptimWrapper):
            self.opt = self.opt_func
            self.opt.clear_state()
        else:
            self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        if not self.wd_bn_bias:
            for p in self._bn_bias_state(True ): p['do_wd'] = False
        if self.train_bn:
            for p in self._bn_bias_state(False): p['force_train'] = True
    def _split(self, b):
        # Number of input elements: `dls.n_inp` when set, otherwise everything but the last element
        # (or 1 when the batch has a single element).
        i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
        self.xb,self.yb = b[:i],b[i:]
    def _with_events(self, f, event_type, ex, final=noop):
        # Fire `before_<event_type>` and run `f`; if `ex` is raised, fire `after_cancel_<event_type>`.
        # `after_<event_type>` and `final()` then run in both cases; any other exception propagates.
        try: self(f'before_{event_type}'); f()
        except ex: self(f'after_cancel_{event_type}')
        self(f'after_{event_type}'); final()
    def all_batches(self):
        self.n_iter = len(self.dl)
        for o in enumerate(self.dl): self.one_batch(*o)
    def _backward(self): self.loss_grad.backward()
    def _step(self): self.opt.step()
    def _do_grad_opt(self):
        self._with_events(self._backward, 'backward', CancelBackwardException)
        self._with_events(self._step, 'step', CancelStepException)
        self.opt.zero_grad()
    def _do_one_batch(self):
        self.pred = self.model(*self.xb)
        self('after_pred')
        if len(self.yb):
            # `loss_grad` is the tensor `_backward` calls `.backward()` on; `loss` is a clone of it.
            self.loss_grad = self.loss_func(self.pred, *self.yb)
            self.loss = self.loss_grad.clone()
        self('after_loss')
        # No backward pass / optimizer step outside training or when there are no targets.
        if not self.training or not len(self.yb): return
        self._do_grad_opt()
    def _set_device(self, b):
        model_device = next(self.model.parameters()).device
        dls_device = getattr(self.dls, 'device', default_device())
        # Either branch moves the batch to the model's device (in the first one the two devices are equal).
        if model_device == dls_device: return to_device(b, dls_device)
        else: return to_device(b, model_device)
    def one_batch(self, i, b):
        self.iter = i
        b = self._set_device(b)
        self._split(b)
        self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    def _do_epoch_train(self):
        self.dl = self.dls.train
        self._with_events(self.all_batches, 'train', CancelTrainException)
    def _do_epoch_validate(self, ds_idx=1, dl=None):
        if dl is None: dl = self.dls[ds_idx]
        self.dl = dl
        with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    def _do_epoch(self):
        self._do_epoch_train()
        self._do_epoch_validate()
    def _do_fit(self):
        for epoch in range(self.n_epoch):
            self.epoch=epoch
            self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):
        # A non-zero `start_epoch` adds a `SkipToEpoch` callback for the duration of this call.
        if start_epoch != 0:
            cbs = L(cbs) + SkipToEpoch(start_epoch)
        with self.added_cbs(cbs):
            if reset_opt or not self.opt: self.create_opt()
            if wd is None: wd = self.wd
            if wd is not None: self.opt.set_hypers(wd=wd)
            self.opt.set_hypers(lr=self.lr if lr is None else lr)
            self.n_epoch = n_epoch
            self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    # Reset the per-batch state left over from the last loop.
    def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
    # Used as a context manager, the learner fires `_before_epoch` on entry and `_after_epoch` on exit.
    def __enter__(self): self(_before_epoch); return self
    def __exit__(self, exc_type, exc_value, tb): self(_after_epoch)
    def validation_context(self, cbs=None, inner=False):
        cms = [self.no_logging(),self.no_mbar(), self.lock]
        if cbs: cms.append(self.added_cbs(cbs))
        # With `inner=True` the learner itself is not entered, so the fit/epoch events are not fired.
        if not inner: cms.append(self)
        return ContextManagers(cms)
    def validate(self, ds_idx=1, dl=None, cbs=None):
        if dl is None: dl = self.dls[ds_idx]
        with self.validation_context(cbs=cbs): self._do_epoch_validate(ds_idx, dl)
        return getattr(self, 'final_record', None)
    @delegates(GatherPredsCallback.__init__)
    def get_preds(self,
        ds_idx:int=1, # `DataLoader` to use for predictions if `dl` is None. 0: train. 1: valid
        dl=None, # `DataLoader` to use for predictions, defaults to `ds_idx=1` if None
        with_input:bool=False, # Return inputs with predictions
        with_decoded:bool=False, # Return decoded predictions
        with_loss:bool=False, # Return per item loss with predictions
        act=None, # Apply activation to predictions, defaults to `self.loss_func`'s activation
        inner:bool=False, # If False, create progress bar, show logger, use temporary `cbs`
        reorder:bool=True, # Reorder predictions on dataset indicies, if applicable
        cbs:Callback|MutableSequence|None=None, # Temporary `Callback`s to apply during prediction
        **kwargs
    )-> tuple:
        if dl is None: dl = self.dls[ds_idx].new(shuffle=False, drop_last=False)
        else:
            # A user-supplied `dl` must support `len()`.
            try: len(dl)
            except TypeError as e:
                raise TypeError(f"`dl` is {type(dl)} and doesn't have len(dl)")
        if isinstance(dl, DataLoader):
            if dl.drop_last: dl = dl.new(shuffle=False, drop_last=False)
        if reorder and hasattr(dl, 'get_idxs'):
            # Freeze the index order so the iteration below and the final reordering use the same indices.
            idxs = dl.get_idxs()
            dl = dl.new(get_idxs = _ConstantFunc(idxs))
        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
        ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)
        if with_loss: ctx_mgrs.append(self.loss_not_reduced())
        with ContextManagers(ctx_mgrs):
            self._do_epoch_validate(dl=dl)
            if act is None: act = getcallable(self.loss_func, 'activation')
            res = cb.all_tensors()
            # Predictions come after the inputs when inputs are returned.
            pred_i = 1 if with_input else 0
            if res[pred_i] is not None:
                res[pred_i] = act(res[pred_i])
                if with_decoded: res.insert(pred_i+2, getcallable(self.loss_func, 'decodes')(res[pred_i]))
            if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())
            return tuple(res)
        # NOTE(review): appears unreachable on the normal path -- the `with` block above ends in `return`.
        self._end_cleanup()
    def predict(self, item, rm_type_tfms=None, with_input=False):
        dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
        i = getattr(self.dls, 'n_inp', -1)
        inp = (inp,) if i==1 else tuplify(inp)
        dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
        # Split the decoded item back into its input part and its target part.
        dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
        res = dec_targ,dec_preds[0],preds[0]
        if with_input: res = (dec_inp,) + res
        return res
    def show_results(self, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):
        if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
        b = dl.one_batch()
        # Only the decoded predictions (third element) are used.
        _,_,preds = self.get_preds(dl=[b], with_decoded=True)
        dl.show_results(b, preds, max_n=max_n, **kwargs)
    def show_training_loop(self):
        indent = 0
        for s in _loop:
            if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 2
            elif s.startswith('End'): indent -= 2; print(f'{" "*indent}{s}')
            else: print(f'{" "*indent} - {s:15}:', self.ordered_cbs(s))
    # The three helpers below return the `replacing_yield` generator, which `contextmanager` then drives.
    @contextmanager
    def no_logging(self): return replacing_yield(self, 'logger', noop)
    @contextmanager
    def no_mbar(self): return replacing_yield(self, 'create_mbar', False)
    @contextmanager
    def loss_not_reduced(self):
        # Prefer switching the loss function's own `reduction` attribute; otherwise wrap it in a `partial`.
        if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')
        else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))
    def to_detach(self,b,cpu=True,gather=True):
        return self.dl.to_detach(b,cpu,gather) if hasattr(getattr(self,'dl',None),'to_detach') else to_detach(b,cpu,gather)
    # The lock is left out of the pickled state and recreated on load.
    def __getstate__(self): return {k:v for k,v in self.__dict__.items() if k!='lock'}
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.lock = threading.Lock()
# `Learner.x` / `Learner.y` expose `detuplify(xb)` / `detuplify(yb)` (index 0 / 1 of the lambda's tuple).
# NOTE(review): presumably `add_props` builds one property per index -- confirm in fastcore.
Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))
# %% ../nbs/13a_learner.ipynb 26
# Attach the class docstring and the public-method docstrings of `Learner`.
add_docs(Learner, "Group together a `model`, some `dls` and a `loss_func` to handle training",
    add_cbs="Add `cbs` to the list of `Callback` and register `self` as their learner",
    add_cb="Add `cb` to the list of `Callback` and register `self` as their learner",
    remove_cbs="Remove `cbs` from the list of `Callback` and deregister `self` as their learner",
    remove_cb="Remove `cb` from the list of `Callback` and deregister `self` as their learner",
    added_cbs="Context manager that temporarily adds `cbs`",
    removed_cbs="Context manager that temporarily removes `cbs`",
    ordered_cbs="List of `Callback`s, in order, for an `event` in the training loop",
    create_opt="Create an optimizer with default hyper-parameters",
    one_batch="Train or evaluate `self.model` on batch `(xb,yb)`",
    all_batches="Train or evaluate `self.model` on all the batches of `self.dl`",
    fit="Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.",
    validate="Validate on `dl` with potential new `cbs`.",
    get_preds="Get the predictions and targets on the `ds_idx`-th dataset or `dl`, optionally `with_input` and `with_loss`",
    predict="Prediction on `item`, fully decoded, loss function decoded and probabilities",
    validation_context="A `ContextManagers` suitable for validation, with optional `cbs`",
    show_results="Show some predictions on `ds_idx`-th dataset or `dl`",
    show_training_loop="Show each step in the training loop",
    no_logging="Context manager to temporarily remove `logger`",
    no_mbar="Context manager to temporarily prevent the master progress bar from being created",
    loss_not_reduced="A context manager to evaluate `loss_func` with reduction set to none.",
    to_detach="Calls `to_detach` if `self.dl` provides a `.to_detach` function otherwise calls global `to_detach`",
    __call__="Call `event_name` for all `Callback`s in `self.cbs`"
)
# %% ../nbs/13a_learner.ipynb 33
# Seed the default callback list only if nothing defined it earlier; `Learner.__init__` adds these
# callbacks when `default_cbs` is true.
if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback]
# %% ../nbs/13a_learner.ipynb 88
def _before_batch_cb(f, self):
xb,yb = f(self, self.xb, self.yb)
self.learn.xb,self.learn.yb = xb,yb
# %% ../nbs/13a_learner.ipynb 89
def before_batch_cb(f):
    "Shortcut for creating a Callback on the `before_batch` event, which takes and returns `xb,yb`"
    # Bind `f` as the first argument; the remaining argument of `_before_batch_cb` is supplied at call time.
    handler = partial(_before_batch_cb, f)
    return Callback(before_batch=handler)
# %% ../nbs/13a_learner.ipynb 96
@patch
@delegates(save_model)
def save(self:Learner, file, **kwargs):
    "Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`"
    target_dir = self.path/self.model_dir
    file = join_path_file(file, target_dir, ext='.pth')
    # `opt` may be missing or None; `save_model` then stores the model state only.
    save_model(file, self.model, getattr(self, 'opt', None), **kwargs)
    return file
# %% ../nbs/13a_learner.ipynb 98
@patch
@delegates(load_model)
def load(self:Learner, file, device=None, **kwargs):
    "Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`"
    # Without an explicit device, use the one of the `DataLoaders` when it has one.
    if device is None:
        device = getattr(self.dls, 'device', None)
    # An optimizer must exist so that its saved state can be restored into it.
    if self.opt is None:
        self.create_opt()
    source_dir = self.path/self.model_dir
    file = join_path_file(file, source_dir, ext='.pth')
    distrib_barrier()
    load_model(file, self.model, self.opt, device=device, **kwargs)
    return self
# %% ../nbs/13a_learner.ipynb 102
@patch
def export(self:Learner, fname='export.pkl', pickle_module=pickle, pickle_protocol=2):
    "Export the content of `self` without the items and the optimizer state for inference"
    if rank_distrib(): return # don't export if child proc
    self._end_cleanup()
    old_dbunch = self.dls
    # Capture the optimizer state before anything on `self` is swapped out.
    state = self.opt.state_dict() if self.opt is not None else None
    # Serialise with empty `DataLoaders` and no optimizer.
    self.dls = self.dls.new_empty()
    self.opt = None
    try:
        with warnings.catch_warnings():
            #To avoid the warning that come from PyTorch about model not being checked
            warnings.simplefilter("ignore")
            torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
    finally:
        # Restore the optimizer and the original `DataLoaders` even when saving fails, so a failed
        # export does not leave the learner without data or optimizer.
        try:
            self.create_opt()
            if state is not None: self.opt.load_state_dict(state)
        finally: self.dls = old_dbunch
# %% ../nbs/13a_learner.ipynb 104
def load_learner(fname, cpu=True, pickle_module=pickle):
    "Load a `Learner` object in `fname`, by default putting it on the `cpu`"
    distrib_barrier()
    map_loc = 'cpu' if cpu else default_device()
    # SECURITY: this unpickles arbitrary objects from `fname` -- only load files from trusted sources.
    try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
    except AttributeError as e:
        # Unpickling fails with AttributeError when a custom class/function is not importable;
        # prepend a hint and re-raise the same exception.
        e.args = [f"Custom classes or functions exported with your `Learner` not available in namespace.\nRe-declare/import before loading:\n\t{e.args[0]}"]
        raise
    if cpu:
        res.dls.cpu()
        # Convert the loaded learner back to fp32; only the first matching attribute is handled.
        if hasattr(res, 'channels_last'): res = res.to_contiguous(to_fp32=True)
        elif hasattr(res, 'mixed_precision'): res = res.to_fp32()
        elif hasattr(res, 'non_native_mixed_precision'): res = res.to_non_native_fp32()
    return res
# %% ../nbs/13a_learner.ipynb 111
@docs
class Metric():
    "Blueprint for defining a metric"
    def reset(self):
        pass

    def accumulate(self, learn):
        pass

    @property
    def value(self):
        # Subclasses must provide the computed value.
        raise NotImplementedError

    @property
    def name(self):
        return class2attr(self, 'Metric')

    # Docstrings consumed by the `@docs` decorator.
    _docs = {
        'reset': "Reset inner state to prepare for new computation",
        'name': "Name of the `Metric`, camel-cased and with Metric removed",
        'accumulate': "Use `learn` to update the state with new results",
        'value': "The value of the metric",
    }
# %% ../nbs/13a_learner.ipynb 118
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func):
        self.func = func

    def reset(self):
        self.total = 0.
        self.count = 0

    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        # Weight the batch value by the batch size so that `value` is a per-item average.
        batch_value = learn.to_detach(self.func(learn.pred, *learn.yb))
        self.total += batch_value*bs
        self.count += bs

    @property
    def value(self):
        if self.count != 0:
            return self.total/self.count
        return None

    @property
    def name(self):
        # For `partial`-like objects use the name of the wrapped function.
        target = self.func.func if hasattr(self.func, 'func') else self.func
        return target.__name__
# %% ../nbs/13a_learner.ipynb 122
class AvgLoss(Metric):
    "Average the losses taking into account potential different batch sizes"
    def reset(self):
        self.total = 0.
        self.count = 0

    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        # Weight the mean batch loss by the batch size so that `value` is a per-item average.
        batch_loss = learn.to_detach(learn.loss.mean())
        self.total += batch_loss*bs
        self.count += bs

    @property
    def value(self):
        if self.count != 0:
            return self.total/self.count
        return None

    @property
    def name(self):
        return "loss"
# %% ../nbs/13a_learner.ipynb 126
class AvgSmoothLoss(Metric):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    def __init__(self, beta=0.98):
        self.beta = beta

    def reset(self):
        self.count = 0
        self.val = tensor(0.)

    def accumulate(self, learn):
        self.count += 1
        batch_loss = to_detach(learn.loss.mean())
        # lerp(new, old, beta) == new + beta*(old - new)
        self.val = torch.lerp(batch_loss, self.val, self.beta)

    @property
    def value(self):
        # Divide by the bias-correction term of the exponential average.
        correction = 1-self.beta**self.count
        return self.val/correction
# %% ../nbs/13a_learner.ipynb 129
class ValueMetric(Metric):
    "Use to include a pre-calculated metric value (for instance calculated in a `Callback`) and returned by `func`"
    def __init__(self, func, metric_name=None):
        store_attr('func, metric_name')

    @property
    def value(self):
        return self.func()

    @property
    def name(self):
        # Fall back to the function's own name when no (truthy) `metric_name` was given.
        return self.metric_name or self.func.__name__
# %% ../nbs/13a_learner.ipynb 133
from fastprogress.fastprogress import format_time
# %% ../nbs/13a_learner.ipynb 134
def _maybe_item(t):
t = t.value
try: return t.item()
except: return t
# %% ../nbs/13a_learner.ipynb 135
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    _stateattrs=('lrs','iters','losses','values')
    remove_on_fetch = True
    order = 50

    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
        store_attr('add_time,train_metrics,valid_metrics')
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)

    def before_fit(self):
        "Prepare state for training"
        self.lrs,self.iters,self.losses,self.values = [],[],[],[]
        names = self.metrics.attrgot('name')
        if self.train_metrics and self.valid_metrics:
            # Every column (loss + metrics) exists once per split.
            base = L('loss') + names
            names = base.map('train_{}') + base.map('valid_{}')
        elif self.valid_metrics:
            names = L('train_loss', 'valid_loss') + names
        else:
            names = L('train_loss') + names
        if self.add_time:
            names.append('time')
        self.metric_names = 'epoch'+names
        self.smooth_loss.reset()

    def after_batch(self):
        "Update all metrics and records lr and smooth loss in training"
        if len(self.yb) == 0:
            return
        if self.training:
            for met in self._train_mets:
                met.accumulate(self.learn)
            # The lr and smoothed loss are only tracked for training batches.
            self.lrs.append(self.opt.hypers[-1]['lr'])
            self.losses.append(self.smooth_loss.value)
            self.learn.smooth_loss = self.smooth_loss.value
        else:
            for met in self._valid_mets:
                met.accumulate(self.learn)

    def before_epoch(self):
        "Set timer if `self.add_time=True`"
        self.cancel_train,self.cancel_valid = False,False
        if self.add_time:
            self.start_epoch = time.time()
        # The log row starts with the epoch number.
        self.log = L(getattr(self, 'epoch', 0))

    def before_train(self):
        # The first training metric (the smoothed loss) is deliberately not reset here.
        for met in self._train_mets[1:]:
            met.reset()

    def before_validate(self):
        for met in self._valid_mets:
            met.reset()

    def after_train(self):
        self.log += self._train_mets.map(_maybe_item)

    def after_validate(self):
        self.log += self._valid_mets.map(_maybe_item)

    def after_cancel_train(self):
        self.cancel_train = True

    def after_cancel_validate(self):
        self.cancel_valid = True

    def after_epoch(self):
        "Store and log the loss/metric values"
        # The record excludes the epoch number and is copied before the time column is appended.
        self.learn.final_record = self.log[1:].copy()
        self.values.append(self.learn.final_record)
        if self.add_time:
            self.log.append(format_time(time.time() - self.start_epoch))
        self.logger(self.log)
        self.iters.append(self.smooth_loss.count)

    @property
    def _train_mets(self):
        if getattr(self, 'cancel_train', False):
            return L()
        extra = self.metrics if self.train_metrics else L()
        return L(self.smooth_loss) + extra

    @property
    def _valid_mets(self):
        if getattr(self, 'cancel_valid', False):
            return L()
        if self.valid_metrics:
            return L(self.loss) + self.metrics
        return L()

    def plot_loss(self, skip_start=5, with_valid=True):
        xs = list(range(skip_start, len(self.losses)))
        plt.plot(xs, self.losses[skip_start:], label='train')
        if with_valid:
            # Number of recorded epochs that ended before `skip_start` iterations.
            first = (np.array(self.iters)<skip_start).sum()
            # `values` rows have no 'epoch' column, hence the -1 offset.
            col = self.metric_names.index('valid_loss') - 1
            plt.plot(self.iters[first:], L(self.values[first:]).itemgot(col), label='valid')
            plt.legend()
# %% ../nbs/13a_learner.ipynb 136
# Attach docstrings to the `Recorder` event handlers that do not define one inline.
add_docs(Recorder,
    before_train = "Reset loss and metrics state",
    after_train = "Log loss and metric values on the training set (if `self.train_metrics=True`)",
    before_validate = "Reset loss and metrics state",
    after_validate = "Log loss and metric values on the validation set",
    after_cancel_train = "Ignore training metrics for this epoch",
    after_cancel_validate = "Ignore validation metrics for this epoch",
    plot_loss = "Plot the losses from `skip_start` and onward")
# Register `Recorder` as a default callback exactly once.
if Recorder not in defaults.callbacks: defaults.callbacks.append(Recorder)
# %% ../nbs/13a_learner.ipynb 152
def _cast_tensor(x):
if isinstance(x, tuple): return tuple(_cast_tensor(x_) for x_ in x)
else: return cast(x, Tensor) if isinstance(x,torch.Tensor) else x
# %% ../nbs/13a_learner.ipynb 153
class CastToTensor(Callback):
    "Cast Subclassed Tensors to `Tensor`"
    order=9 # Right before MixedPrecision

    def before_batch(self):
        # Compute both casts first, then write them back to the learner together.
        new_xb = _cast_tensor(self.learn.xb)
        new_yb = _cast_tensor(self.learn.yb)
        self.learn.xb,self.learn.yb = new_xb,new_yb
# %% ../nbs/13a_learner.ipynb 155
# Register `CastToTensor` as a default callback exactly once.
if CastToTensor not in defaults.callbacks: defaults.callbacks.append(CastToTensor)
# %% ../nbs/13a_learner.ipynb 185
@patch
def freeze_to(self:Learner, n):
    # Make sure an optimizer exists before delegating to it.
    if self.opt is None:
        self.create_opt()
    optimizer = self.opt
    optimizer.freeze_to(n)
    # The optimizer state is cleared after changing what is frozen.
    optimizer.clear_state()
@patch
def freeze(self:Learner):
    # -1: freeze everything before the last parameter group.
    self.freeze_to(-1)
@patch
def unfreeze(self:Learner):
    # 0: nothing stays frozen.
    self.freeze_to(0)
# Attach docstrings to the freezing helpers patched onto `Learner`.
add_docs(Learner,
freeze_to="Freeze parameter groups up to `n`",
freeze="Freeze up to last parameter group",
unfreeze="Unfreeze the entire model")
# %% ../nbs/13a_learner.ipynb 189
@patch
def tta(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):
    "Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation"
    # `shuffle` (not `shuffled`) is the keyword used by `get_preds`/`show_results` for the same purpose.
    if dl is None: dl = self.dls[ds_idx].new(shuffle=False, drop_last=False)
    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    try:
        self(_before_epoch)
        # First `n` passes with the dataset's split index set to 0.
        with dl.dataset.set_split_idx(0), self.no_mbar():
            if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
            aug_preds = []
            for i in self.progress.mbar if hasattr(self,'progress') else range(n):
                self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
                # `[None]` adds a leading axis so the passes can be concatenated and reduced over it.
                aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
        aug_preds = torch.cat(aug_preds)
        aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)
        self.epoch = n
        # One final pass with the split index set to 1; it also provides the targets.
        with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)
    finally: self(event.after_fit)
    # Combine: element-wise max, or the raw pair when `beta` is None, or a lerp weighted by `beta`.
    if use_max: return torch.stack([preds, aug_preds], 0).max(0)[0],targs
    preds = (aug_preds,preds) if beta is None else torch.lerp(aug_preds, preds, beta)
    return preds,targs