# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/13a_learner.ipynb.

# %% ../nbs/13a_learner.ipynb 2
from __future__ import annotations
from .data.all import *
from .optimizer import *
from .callback.core import *
import pickle,threading
from collections.abc import MutableSequence

# %% auto 0
__all__ = ['replacing_yield', 'mk_metric', 'save_model', 'load_model', 'SkipToEpoch', 'Learner', 'before_batch_cb',
           'load_learner', 'Metric', 'AvgMetric', 'AvgLoss', 'AvgSmoothLoss', 'ValueMetric', 'Recorder', 'CastToTensor',
           'CancelBackwardException', 'CancelStepException', 'CancelFitException', 'CancelEpochException',
           'CancelTrainException', 'CancelValidException', 'CancelBatchException']

# %% ../nbs/13a_learner.ipynb 4
_all_ = ['CancelBackwardException', 'CancelStepException','CancelFitException','CancelEpochException',
         'CancelTrainException','CancelValidException','CancelBatchException']

# %% ../nbs/13a_learner.ipynb 10
defaults.lr = 1e-3

# %% ../nbs/13a_learner.ipynb 11
@contextmanager
def replacing_yield(o, attr, val):
    "Context manager to temporarily replace an attribute"
    old = getattr(o,attr)
    try:     yield setattr(o,attr,val)
    finally: setattr(o,attr,old)
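
# Usage sketch (illustrative, not part of the exported API): `replacing_yield` swaps an attribute
# for the duration of a `with` block and restores it afterwards, even if the block raises.
# `_Cfg` below is a throwaway class used only for the example.
#
#   class _Cfg: lr = 1e-3
#   with replacing_yield(_Cfg, 'lr', 1e-2):
#       assert _Cfg.lr == 1e-2
#   assert _Cfg.lr == 1e-3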
# %% ../nbs/13a_learner.ipynb 13
def mk_metric(m):
    "Convert `m` to an `AvgMetric`, unless it's already a `Metric`"
    if isinstance(m,type): m = m()
    return m if isinstance(m, Metric) else AvgMetric(m)

# %% ../nbs/13a_learner.ipynb 15
def save_model(file, model, opt, with_opt=True, pickle_protocol=2, **torch_save_kwargs):
    "Save `model` to `file` along with `opt` (if available, and if `with_opt`)"
    if rank_distrib(): return # don't save if child proc
    if opt is None: with_opt=False
    state = get_model(model).state_dict()
    if with_opt: state = {'model': state, 'opt':opt.state_dict()}
    torch.save(state, file, pickle_protocol=pickle_protocol, **torch_save_kwargs)

# %% ../nbs/13a_learner.ipynb 17
def load_model(file, model, opt, with_opt=True, device=None, strict=True, **torch_load_kwargs):
    "Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
    if isinstance(device, int): device = torch.device('cuda', device)
    elif device is None: device = 'cpu'
    state = torch.load(file, map_location=device, **torch_load_kwargs)
    hasopt = set(state)=={'model', 'opt'}
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    if hasopt and with_opt:
        try: opt.load_state_dict(state['opt'])
        except:
            if with_opt: warn("Could not load the optimizer state.")
    elif with_opt: warn("Saved file doesn't contain an optimizer state.")
# %% ../nbs/13a_learner.ipynb 19
def _try_concat(o):
    try:    return torch.cat(o)
    except: return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())

# %% ../nbs/13a_learner.ipynb 20
_before_epoch = [event.before_fit, event.before_epoch]
_after_epoch  = [event.after_epoch, event.after_fit]

# %% ../nbs/13a_learner.ipynb 21
class _ConstantFunc():
    "Returns a function that returns `o`"
    def __init__(self, o): self.o = o
    def __call__(self, *args, **kwargs): return self.o

# %% ../nbs/13a_learner.ipynb 22
class SkipToEpoch(Callback):
    "Skip training up to `epoch`"
    order = 70
    def __init__(self, epoch:int):
        self._skip_to = epoch
    def before_epoch(self):
        if self.epoch < self._skip_to:
            raise CancelEpochException
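
# Usage sketch (illustrative): `SkipToEpoch` cancels every epoch before `epoch`, which is how
# `Learner.fit(..., start_epoch=n)` resumes a run. Both calls assume an existing `learn`.
#
#   learn.fit(5, cbs=SkipToEpoch(2))   # epochs 0-1 are skipped, training runs for epochs 2-4
#   learn.fit(5, start_epoch=2)        # equivalent shortcut handled inside `fit`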
# %% ../nbs/13a_learner.ipynb 24
_loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', 'before_train',
         'Start Batch Loop', 'before_batch', 'after_pred', 'after_loss', 'before_backward', 'before_step',
         'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',
         'after_cancel_train', 'after_train', 'Start Valid', 'before_validate','Start Batch Loop',
         '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',
         'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',
         'after_cancel_fit', 'after_fit']

# %% ../nbs/13a_learner.ipynb 25
class Learner(GetAttr):
    _default='model'
    def __init__(self,
        dls:DataLoaders, # `DataLoaders` containing fastai or PyTorch `DataLoader`s
        model:callable, # PyTorch model for training or inference
        loss_func:callable|None=None, # Loss function. Defaults to `dls` loss
        opt_func:Optimizer|OptimWrapper=Adam, # Optimization function for training
        lr:float|slice=defaults.lr, # Default learning rate
        splitter:callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
        cbs:Callback|MutableSequence|None=None, # `Callback`s to add to `Learner`
        metrics:callable|MutableSequence|None=None, # `Metric`s to calculate on validation set
        path:str|Path|None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
        model_dir:str|Path='models', # Subdirectory to save and load models
        wd:float|int|None=None, # Default weight decay
        wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
        train_bn:bool=True, # Train frozen normalization layers
        moms:tuple=(0.95,0.85,0.95), # Default momentum for schedulers
        default_cbs:bool=True # Include default `Callback`s
    ):
        path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
        if loss_func is None:
            loss_func = getattr(dls.train_ds, 'loss_func', None)
            assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
        self.dls,self.model = dls,model
        store_attr(but='dls,model,cbs')
        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
        if default_cbs: self.add_cbs(L(defaults.callbacks))
        self.add_cbs(cbs)
        self.lock = threading.Lock()
        self("after_create")
    @property
    def metrics(self): return self._metrics
    @metrics.setter
    def metrics(self,v): self._metrics = L(v).map(mk_metric)
    def _grab_cbs(self, cb_cls): return L(cb for cb in self.cbs if isinstance(cb, cb_cls))

    def add_cbs(self, cbs):
        L(cbs).map(self.add_cb)
        return self

    def remove_cbs(self, cbs):
        L(cbs).map(self.remove_cb)
        return self

    def add_cb(self, cb):
        if isinstance(cb, type): cb = cb()
        cb.learn = self
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        if isinstance(cb, type): self.remove_cbs(self._grab_cbs(cb))
        else:
            cb.learn = None
            if hasattr(self, cb.name): delattr(self, cb.name)
            if cb in self.cbs: self.cbs.remove(cb)
        return self
    @contextmanager
    def added_cbs(self, cbs):
        self.add_cbs(cbs)
        try: yield
        finally: self.remove_cbs(cbs)

    @contextmanager
    def removed_cbs(self, cbs):
        self.remove_cbs(cbs)
        try: yield self
        finally: self.add_cbs(cbs)
    def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]

    def __call__(self, event_name): L(event_name).map(self._call_one)

    def _call_one(self, event_name):
        if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
        for cb in self.cbs.sorted('order'): cb(event_name)

    def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

    def create_opt(self):
        if isinstance(self.opt_func, partial):
            if 'lr' in self.opt_func.keywords:
                self.lr = self.opt_func.keywords['lr']
        if isinstance(self.opt_func, OptimWrapper):
            self.opt = self.opt_func
            self.opt.clear_state()
        else:
            self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        if not self.wd_bn_bias:
            for p in self._bn_bias_state(True ): p['do_wd'] = False
        if self.train_bn:
            for p in self._bn_bias_state(False): p['force_train'] = True

    def _split(self, b):
        i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
        self.xb,self.yb = b[:i],b[i:]

    def _with_events(self, f, event_type, ex, final=noop):
        try: self(f'before_{event_type}'); f()
        except ex: self(f'after_cancel_{event_type}')
        self(f'after_{event_type}'); final()

    def all_batches(self):
        self.n_iter = len(self.dl)
        for o in enumerate(self.dl): self.one_batch(*o)

    def _backward(self): self.loss_grad.backward()
    def _step(self): self.opt.step()

    def _do_grad_opt(self):
        self._with_events(self._backward, 'backward', CancelBackwardException)
        self._with_events(self._step, 'step', CancelStepException)
        self.opt.zero_grad()

    def _do_one_batch(self):
        self.pred = self.model(*self.xb)
        self('after_pred')
        if len(self.yb):
            self.loss_grad = self.loss_func(self.pred, *self.yb)
            self.loss = self.loss_grad.clone()
        self('after_loss')
        if not self.training or not len(self.yb): return
        self._do_grad_opt()

    def _set_device(self, b):
        model_device = next(self.model.parameters()).device
        dls_device = getattr(self.dls, 'device', default_device())
        if model_device == dls_device: return to_device(b, dls_device)
        else: return to_device(b, model_device)

    def one_batch(self, i, b):
        self.iter = i
        b = self._set_device(b)
        self._split(b)
        self._with_events(self._do_one_batch, 'batch', CancelBatchException)

    def _do_epoch_train(self):
        self.dl = self.dls.train
        self._with_events(self.all_batches, 'train', CancelTrainException)

    def _do_epoch_validate(self, ds_idx=1, dl=None):
        if dl is None: dl = self.dls[ds_idx]
        self.dl = dl
        with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)

    def _do_epoch(self):
        self._do_epoch_train()
        self._do_epoch_validate()

    def _do_fit(self):
        for epoch in range(self.n_epoch):
            self.epoch=epoch
            self._with_events(self._do_epoch, 'epoch', CancelEpochException)

    def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):
        if start_epoch != 0:
            cbs = L(cbs) + SkipToEpoch(start_epoch)
        with self.added_cbs(cbs):
            if reset_opt or not self.opt: self.create_opt()
            if wd is None: wd = self.wd
            if wd is not None: self.opt.set_hypers(wd=wd)
            self.opt.set_hypers(lr=self.lr if lr is None else lr)
            self.n_epoch = n_epoch
            self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)

    def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
    def __enter__(self): self(_before_epoch); return self
    def __exit__(self, exc_type, exc_value, tb): self(_after_epoch)

    def validation_context(self, cbs=None, inner=False):
        cms = [self.no_logging(),self.no_mbar(), self.lock]
        if cbs: cms.append(self.added_cbs(cbs))
        if not inner: cms.append(self)
        return ContextManagers(cms)

    def validate(self, ds_idx=1, dl=None, cbs=None):
        if dl is None: dl = self.dls[ds_idx]
        with self.validation_context(cbs=cbs): self._do_epoch_validate(ds_idx, dl)
        return getattr(self, 'final_record', None)

    def get_preds(self,
        ds_idx:int=1, # `DataLoader` to use for predictions if `dl` is None. 0: train. 1: valid
        dl=None, # `DataLoader` to use for predictions, defaults to `ds_idx=1` if None
        with_input:bool=False, # Return inputs with predictions
        with_decoded:bool=False, # Return decoded predictions
        with_loss:bool=False, # Return per item loss with predictions
        act=None, # Apply activation to predictions, defaults to `self.loss_func`'s activation
        inner:bool=False, # If False, create progress bar, show logger, use temporary `cbs`
        reorder:bool=True, # Reorder predictions on dataset indicies, if applicable
        cbs:Callback|MutableSequence|None=None, # Temporary `Callback`s to apply during prediction
        **kwargs
    )-> tuple:
        if dl is None: dl = self.dls[ds_idx].new(shuffle=False, drop_last=False)
        else:
            try: len(dl)
            except TypeError as e:
                raise TypeError(f"`dl` is {type(dl)} and doesn't have len(dl)")
        if isinstance(dl, DataLoader):
            if dl.drop_last: dl = dl.new(shuffle=False, drop_last=False)
        if reorder and hasattr(dl, 'get_idxs'):
            idxs = dl.get_idxs()
            dl = dl.new(get_idxs = _ConstantFunc(idxs))
        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
        ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)
        if with_loss: ctx_mgrs.append(self.loss_not_reduced())
        with ContextManagers(ctx_mgrs):
            self._do_epoch_validate(dl=dl)
            if act is None: act = getcallable(self.loss_func, 'activation')
            res = cb.all_tensors()
            pred_i = 1 if with_input else 0
            if res[pred_i] is not None:
                res[pred_i] = act(res[pred_i])
                if with_decoded: res.insert(pred_i+2, getcallable(self.loss_func, 'decodes')(res[pred_i]))
            if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())
            return tuple(res)
        self._end_cleanup()

    def predict(self, item, rm_type_tfms=None, with_input=False):
        dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
        i = getattr(self.dls, 'n_inp', -1)
        inp = (inp,) if i==1 else tuplify(inp)
        dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
        dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
        res = dec_targ,dec_preds[0],preds[0]
        if with_input: res = (dec_inp,) + res
        return res

    def show_results(self, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):
        if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
        b = dl.one_batch()
        _,_,preds = self.get_preds(dl=[b], with_decoded=True)
        dl.show_results(b, preds, max_n=max_n, **kwargs)

    def show_training_loop(self):
        indent = 0
        for s in _loop:
            if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 2
            elif s.startswith('End'): indent -= 2; print(f'{" "*indent}{s}')
            else: print(f'{" "*indent} - {s:15}:', self.ordered_cbs(s))

    def no_logging(self): return replacing_yield(self, 'logger', noop)
    def no_mbar(self): return replacing_yield(self, 'create_mbar', False)

    def loss_not_reduced(self):
        if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')
        else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))

    def to_detach(self,b,cpu=True,gather=True):
        return self.dl.to_detach(b,cpu,gather) if hasattr(getattr(self,'dl',None),'to_detach') else to_detach(b,cpu,gather)

    def __getstate__(self): return {k:v for k,v in self.__dict__.items() if k!='lock'}
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.lock = threading.Lock()

Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))
# %% ../nbs/13a_learner.ipynb 26
add_docs(Learner, "Group together a `model`, some `dls` and a `loss_func` to handle training",
    add_cbs="Add `cbs` to the list of `Callback` and register `self` as their learner",
    add_cb="Add `cb` to the list of `Callback` and register `self` as their learner",
    remove_cbs="Remove `cbs` from the list of `Callback` and deregister `self` as their learner",
    remove_cb="Remove `cb` from the list of `Callback` and deregister `self` as their learner",
    added_cbs="Context manager that temporarily adds `cbs`",
    removed_cbs="Context manager that temporarily removes `cbs`",
    ordered_cbs="List of `Callback`s, in order, for an `event` in the training loop",
    create_opt="Create an optimizer with default hyper-parameters",
    one_batch="Train or evaluate `self.model` on batch `(xb,yb)`",
    all_batches="Train or evaluate `self.model` on all the batches of `self.dl`",
    fit="Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.",
    validate="Validate on `dl` with potential new `cbs`.",
    get_preds="Get the predictions and targets on the `ds_idx`-th dataset or `dl`, optionally `with_input` and `with_loss`",
    predict="Prediction on `item`, fully decoded, loss function decoded and probabilities",
    validation_context="A `ContextManagers` suitable for validation, with optional `cbs`",
    show_results="Show some predictions on `ds_idx`-th dataset or `dl`",
    show_training_loop="Show each step in the training loop",
    no_logging="Context manager to temporarily remove `logger`",
    no_mbar="Context manager to temporarily prevent the master progress bar from being created",
    loss_not_reduced="A context manager to evaluate `loss_func` with reduction set to none.",
    to_detach="Calls `to_detach` if `self.dl` provides a `.to_detach` function otherwise calls global `to_detach`",
    __call__="Call `event_name` for all `Callback`s in `self.cbs`"
)
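
# Usage sketch (illustrative): a minimal `Learner` built from plain PyTorch pieces. It assumes
# `train_dl`/`valid_dl` are existing PyTorch `DataLoader`s yielding `(x, y)` batches with 20
# features and 2 classes; the layer sizes and learning rate are arbitrary.
#
#   from torch import nn
#   import torch.nn.functional as F
#   dls = DataLoaders(train_dl, valid_dl)
#   model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 2))
#   learn = Learner(dls, model, loss_func=F.cross_entropy)
#   learn.fit(1, lr=1e-2)
#   print(learn.validate())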
# %% ../nbs/13a_learner.ipynb 33
if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback]
# %% ../nbs/13a_learner.ipynb 88
def _before_batch_cb(f, self):
    xb,yb = f(self, self.xb, self.yb)
    self.learn.xb,self.learn.yb = xb,yb

# %% ../nbs/13a_learner.ipynb 89
def before_batch_cb(f):
    "Shortcut for creating a Callback on the `before_batch` event, which takes and returns `xb,yb`"
    return Callback(before_batch=partial(_before_batch_cb, f))
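
# Usage sketch (illustrative): `before_batch_cb` turns a plain function over `(xb, yb)` into a
# `Callback` that rewrites the batch right before the model sees it. The noise level is arbitrary
# and `learn` is assumed to exist.
#
#   @before_batch_cb
#   def add_noise(self, xb, yb): return tuple(x + 0.05*torch.randn_like(x) for x in xb), yb
#
#   learn.fit(1, cbs=add_noise)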
# %% ../nbs/13a_learner.ipynb 96
@patch
@delegates(save_model)
def save(self:Learner, file, **kwargs):
    "Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`"
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    save_model(file, self.model, getattr(self,'opt',None), **kwargs)
    return file
# %% ../nbs/13a_learner.ipynb 98
@patch
@delegates(load_model)
def load(self:Learner, file, device=None, **kwargs):
    "Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`"
    if device is None and hasattr(self.dls, 'device'): device = self.dls.device
    if self.opt is None: self.create_opt()
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    distrib_barrier()
    load_model(file, self.model, self.opt, device=device, **kwargs)
    return self
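
# Usage sketch (illustrative): `Learner.save`/`Learner.load` read and write
# `self.path/self.model_dir/<file>.pth`; only model weights (and optionally optimizer state)
# move, not the data. 'stage-1' is an arbitrary name.
#
#   saved = learn.save('stage-1')   # returns the full path to the .pth file
#   learn = learn.load('stage-1')   # restores weights (and optimizer state if it was saved)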
# %% ../nbs/13a_learner.ipynb 102
@patch
def export(self:Learner, fname='export.pkl', pickle_module=pickle, pickle_protocol=2):
    "Export the content of `self` without the items and the optimizer state for inference"
    if rank_distrib(): return # don't export if child proc
    self._end_cleanup()
    old_dbunch = self.dls
    self.dls = self.dls.new_empty()
    state = self.opt.state_dict() if self.opt is not None else None
    self.opt = None
    with warnings.catch_warnings():
        # To avoid the warning that comes from PyTorch about the model not being checked
        warnings.simplefilter("ignore")
        torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
    self.create_opt()
    if state is not None: self.opt.load_state_dict(state)
    self.dls = old_dbunch
# %% ../nbs/13a_learner.ipynb 104
def load_learner(fname, cpu=True, pickle_module=pickle):
    "Load a `Learner` object in `fname`, by default putting it on the `cpu`"
    distrib_barrier()
    map_loc = 'cpu' if cpu else default_device()
    try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
    except AttributeError as e:
        e.args = [f"Custom classes or functions exported with your `Learner` not available in namespace.\nRe-declare/import before loading:\n\t{e.args[0]}"]
        raise
    if cpu:
        res.dls.cpu()
        if hasattr(res, 'channels_last'): res = res.to_contiguous(to_fp32=True)
        elif hasattr(res, 'mixed_precision'): res = res.to_fp32()
        elif hasattr(res, 'non_native_mixed_precision'): res = res.to_non_native_fp32()
    return res
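
# Usage sketch (illustrative): `export` pickles the whole `Learner` (with empty `DataLoaders` and
# no optimizer state) for inference, and `load_learner` restores it, on CPU by default. The file
# name below is the default 'export.pkl'.
#
#   learn.export()
#   inf_learn = load_learner(learn.path/'export.pkl')
#   # inf_learn.predict(item) then works on a single raw item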
# %% ../nbs/13a_learner.ipynb 111
class Metric():
    "Blueprint for defining a metric"
    def reset(self): pass
    def accumulate(self, learn): pass
    @property
    def value(self): raise NotImplementedError
    @property
    def name(self): return class2attr(self, 'Metric')

    _docs = dict(
        reset="Reset inner state to prepare for new computation",
        name="Name of the `Metric`, camel-cased and with Metric removed",
        accumulate="Use `learn` to update the state with new results",
        value="The value of the metric")
# %% ../nbs/13a_learner.ipynb 118
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func): self.func = func
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self): return self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__
# %% ../nbs/13a_learner.ipynb 122
class AvgLoss(Metric):
    "Average the losses taking into account potential different batch sizes"
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += learn.to_detach(learn.loss.mean())*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self): return "loss"
# %% ../nbs/13a_learner.ipynb 126
class AvgSmoothLoss(Metric):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    def __init__(self, beta=0.98): self.beta = beta
    def reset(self): self.count,self.val = 0,tensor(0.)
    def accumulate(self, learn):
        self.count += 1
        self.val = torch.lerp(to_detach(learn.loss.mean()), self.val, self.beta)
    @property
    def value(self): return self.val/(1-self.beta**self.count)
# %% ../nbs/13a_learner.ipynb 129
class ValueMetric(Metric):
    "Use to include a pre-calculated metric value (for instance calculated in a `Callback`) that is returned by `func`"
    def __init__(self, func, metric_name=None): store_attr('func, metric_name')
    @property
    def value(self): return self.func()
    @property
    def name(self): return self.metric_name if self.metric_name else self.func.__name__
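
# Usage sketch (illustrative): metrics passed to a `Learner` can be plain functions (wrapped in
# `AvgMetric` by `mk_metric`), `Metric` subclasses, or a `ValueMetric` reading a value computed
# elsewhere, e.g. in a `Callback`. The error-rate function below is an example, not a fastai name.
#
#   def error_rate_fn(preds, targs): return (preds.argmax(dim=-1) != targs).float().mean()
#   learn = Learner(dls, model, loss_func=F.cross_entropy, metrics=[error_rate_fn])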
# %% ../nbs/13a_learner.ipynb 133
from fastprogress.fastprogress import format_time

# %% ../nbs/13a_learner.ipynb 134
def _maybe_item(t):
    t = t.value
    try: return t.item()
    except: return t
# %% ../nbs/13a_learner.ipynb 135
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    _stateattrs=('lrs','iters','losses','values')
    remove_on_fetch,order = True,50

    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
        store_attr('add_time,train_metrics,valid_metrics')
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)

    def before_fit(self):
        "Prepare state for training"
        self.lrs,self.iters,self.losses,self.values = [],[],[],[]
        names = self.metrics.attrgot('name')
        if self.train_metrics and self.valid_metrics:
            names = L('loss') + names
            names = names.map('train_{}') + names.map('valid_{}')
        elif self.valid_metrics: names = L('train_loss', 'valid_loss') + names
        else: names = L('train_loss') + names
        if self.add_time: names.append('time')
        self.metric_names = 'epoch'+names
        self.smooth_loss.reset()

    def after_batch(self):
        "Update all metrics and record lr and smooth loss in training"
        if len(self.yb) == 0: return
        mets = self._train_mets if self.training else self._valid_mets
        for met in mets: met.accumulate(self.learn)
        if not self.training: return
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.smooth_loss.value)
        self.learn.smooth_loss = self.smooth_loss.value

    def before_epoch(self):
        "Set timer if `self.add_time=True`"
        self.cancel_train,self.cancel_valid = False,False
        if self.add_time: self.start_epoch = time.time()
        self.log = L(getattr(self, 'epoch', 0))

    def before_train   (self): self._train_mets[1:].map(Self.reset())
    def before_validate(self): self._valid_mets.map(Self.reset())
    def after_train    (self): self.log += self._train_mets.map(_maybe_item)
    def after_validate (self): self.log += self._valid_mets.map(_maybe_item)
    def after_cancel_train(self): self.cancel_train = True
    def after_cancel_validate(self): self.cancel_valid = True

    def after_epoch(self):
        "Store and log the loss/metric values"
        self.learn.final_record = self.log[1:].copy()
        self.values.append(self.learn.final_record)
        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))
        self.logger(self.log)
        self.iters.append(self.smooth_loss.count)
    @property
    def _train_mets(self):
        if getattr(self, 'cancel_train', False): return L()
        return L(self.smooth_loss) + (self.metrics if self.train_metrics else L())

    @property
    def _valid_mets(self):
        if getattr(self, 'cancel_valid', False): return L()
        return (L(self.loss) + self.metrics if self.valid_metrics else L())
    def plot_loss(self, skip_start=5, with_valid=True):
        plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
        if with_valid:
            idx = (np.array(self.iters)<skip_start).sum()
            valid_col = self.metric_names.index('valid_loss') - 1
            plt.plot(self.iters[idx:], L(self.values[idx:]).itemgot(valid_col), label='valid')
            plt.legend()
# %% ../nbs/13a_learner.ipynb 136
add_docs(Recorder,
    before_train = "Reset loss and metrics state",
    after_train = "Log loss and metric values on the training set (if `self.train_metrics=True`)",
    before_validate = "Reset loss and metrics state",
    after_validate = "Log loss and metric values on the validation set",
    after_cancel_train = "Ignore training metrics for this epoch",
    after_cancel_validate = "Ignore validation metrics for this epoch",
    plot_loss = "Plot the losses from `skip_start` and onward")

if Recorder not in defaults.callbacks: defaults.callbacks.append(Recorder)
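
# Usage sketch (illustrative): the `Recorder` added by default keeps per-batch learning rates and
# smoothed losses plus per-epoch metric values, so after training you can inspect or plot them.
#
#   learn.fit(2)
#   learn.recorder.plot_loss()           # train (and valid) loss curves
#   print(learn.recorder.values[-1])     # last epoch's [train_loss, valid_loss, *metrics]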
# %% ../nbs/13a_learner.ipynb 152
def _cast_tensor(x):
    if isinstance(x, tuple): return tuple(_cast_tensor(x_) for x_ in x)
    else: return cast(x, Tensor) if isinstance(x,torch.Tensor) else x

# %% ../nbs/13a_learner.ipynb 153
class CastToTensor(Callback):
    "Cast Subclassed Tensors to `Tensor`"
    order=9 # Right before MixedPrecision
    def before_batch(self):
        self.learn.xb,self.learn.yb = _cast_tensor(self.learn.xb),_cast_tensor(self.learn.yb)

# %% ../nbs/13a_learner.ipynb 155
if CastToTensor not in defaults.callbacks: defaults.callbacks.append(CastToTensor)
# %% ../nbs/13a_learner.ipynb 185
@patch
def freeze_to(self:Learner, n):
    if self.opt is None: self.create_opt()
    self.opt.freeze_to(n)
    self.opt.clear_state()

@patch
def freeze(self:Learner): self.freeze_to(-1)

@patch
def unfreeze(self:Learner): self.freeze_to(0)

add_docs(Learner,
    freeze_to="Freeze parameter groups up to `n`",
    freeze="Freeze up to last parameter group",
    unfreeze="Unfreeze the entire model")
# %% ../nbs/13a_learner.ipynb 189
@patch
def tta(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):
    "Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation"
    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    try:
        self(_before_epoch)
        with dl.dataset.set_split_idx(0), self.no_mbar():
            if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
            aug_preds = []
            for i in self.progress.mbar if hasattr(self,'progress') else range(n):
                self.epoch = i # To keep track of progress on mbar since the progress callback will use self.epoch
                aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
        aug_preds = torch.cat(aug_preds)
        aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)
        self.epoch = n
        with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)
    finally: self(event.after_fit)
    if use_max: return torch.stack([preds, aug_preds], 0).max(0)[0],targs
    preds = (aug_preds,preds) if beta is None else torch.lerp(aug_preds, preds, beta)
    return preds,targs
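
# Usage sketch (illustrative): `tta` averages predictions over `n` augmented passes (using the
# training-time transforms) and blends them with a plain validation pass via `beta`;
# `use_max=True` takes the element-wise max instead.
#
#   preds, targs = learn.tta(n=4, beta=0.25)
#   preds, targs = learn.tta(use_max=True)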