# Source: is_virtualstaging / fastai/interpret.py — uploaded by AItool ("Upload 127 files", commit a983ebc, 7.77 kB)
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20_interpret.ipynb.
# %% ../nbs/20_interpret.ipynb 2
from __future__ import annotations
from .data.all import *
from .optimizer import *
from .learner import *
from .tabular.core import *
import sklearn.metrics as skm
# %% auto 0
# Public API of this module: the names exported by a star-import
__all__ = [
    'plot_top_losses',
    'Interpretation',
    'ClassificationInterpretation',
    'SegmentationInterpretation',
]
# %% ../nbs/20_interpret.ipynb 7
@typedispatch
def plot_top_losses(x, y, *args, **kwargs):
    "Fallback used when no `plot_top_losses` implementation is registered for the types of `x` and `y`"
    # `NotImplementedError` (a subclass of `Exception`) states the failure precisely while
    # remaining catchable by existing `except Exception` handlers.
    raise NotImplementedError(f"plot_top_losses is not implemented for {type(x)},{type(y)}")
# %% ../nbs/20_interpret.ipynb 8
# nbdev marker listing extra names to export; presumably merged into the generated `__all__` — confirm against nbdev docs
_all_ = ['plot_top_losses']
# %% ../nbs/20_interpret.ipynb 9
class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self,
        learn:Learner,
        dl:DataLoader, # `DataLoader` to run inference over
        losses:TensorBase, # Losses calculated from `dl`
        act=None # Activation function for prediction
    ):
        store_attr()

    def _items_at(self, idxs):
        "Raw `self.dl.items` at `idxs`, indexing positionally via `.iloc` when the items provide it (pandas-style)"
        items = self.dl.items
        # Only wrap in `L` when `.iloc` is missing, rather than building it eagerly as a `getattr` default
        return items.iloc[idxs] if hasattr(items, 'iloc') else L(items)[idxs]

    def __getitem__(self, idxs):
        "Return inputs, preds, targs, decoded outputs, and losses at `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        # Re-run inference on only the selected items. Processing is skipped for a `TabDataLoader`,
        # presumably because its items are already processed — confirm.
        tmp_dl = self.learn.dls.test_dl(self._items_at(idxs), with_labels=True,
                                        process=not isinstance(self.dl, TabDataLoader))
        inps,preds,targs,decoded = self.learn.get_preds(dl=tmp_dl, with_input=True, with_loss=False,
                                                        with_decoded=True, act=self.act, reorder=False)
        return inps, preds, targs, decoded, self.losses[idxs]

    @classmethod
    def from_learner(cls,
        learn, # Model used to create interpretation
        ds_idx:int=1, # Index of `learn.dls` when `dl` is None
        dl:DataLoader=None, # `Dataloader` used to make predictions
        act=None # Override default or set prediction activation function
    ):
        "Construct interpretation object from a learner"
        # `losses` and `dl.items` are later indexed with the same positions, so the loader must
        # keep a fixed order and must not drop a partial last batch.
        if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
        _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,
                                     with_preds=False, with_targs=False, act=act)
        return cls(learn, dl, losses, act)

    def top_losses(self,
        k:int|None=None, # Return `k` losses, defaults to all (capped at the number of losses)
        largest:bool=True, # Sort losses by largest or smallest
        items:bool=False # Whether to return input items
    ):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses."
        n = len(self.losses)
        # Clamp so that asking for more losses than exist returns all of them instead of making `topk` raise
        k = n if k is None else min(k, n)
        losses, idx = self.losses.topk(k, largest=largest)
        if items: return losses, idx, self._items_at(idx)
        return losses, idx

    def plot_top_losses(self,
        k:int|MutableSequence, # Number of losses to plot, or the sorted ranks to plot when listy/range
        largest:bool=True, # Sort losses by largest or smallest
        **kwargs
    ):
        "Show `k` largest(/smallest) preds and losses. Implementation based on type dispatch"
        if is_listy(k) or isinstance(k, range):
            # `k` picks positions within the full sorted ordering rather than giving a count
            losses, idx = (o[k] for o in self.top_losses(None, largest))
        else:
            losses, idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)
        x, y, its = self.dl._pre_show_batch(inps+targs, max_n=len(idx))
        # Only the decoded-output part of this second call is used below
        _, _, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))
        if its is not None:
            # Module-level type-dispatched `plot_top_losses`, not this method
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)
        # TODO: when `its is None` the batch knows how to show itself as a whole and would need the
        # inputs from both `_pre_show_batch` calls (e.g. via `show_results`); that case currently plots nothing.

    def show_results(self,
        idxs:list, # Indices of predictions and targets
        **kwargs
    ):
        "Show predictions and targets of `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        inps, _, targs, decoded, _ = self[idxs]
        b = tuplify(inps)+tuplify(targs)
        self.dl.show_results(b, tuplify(decoded), max_n=len(idxs), **kwargs)
# %% ../nbs/20_interpret.ipynb 22
class ClassificationInterpretation(Interpretation):
    "Interpretation methods for classification models."
    def __init__(self,
        learn:Learner,
        dl:DataLoader, # `DataLoader` to run inference over
        losses:TensorBase, # Losses calculated from `dl`
        act=None # Activation function for prediction
    ):
        super().__init__(learn, dl, losses, act)
        self.vocab = self.dl.vocab
        # A listy vocab presumably holds several vocabs with the target's last — confirm
        if is_listy(self.vocab): self.vocab = self.vocab[-1]

    def _flat_decoded_targs(self):
        "Run inference over `self.dl` and return the flattened (decoded predictions, targets)"
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True,
                                               with_targs=True, act=self.act)
        return flatten_check(decoded, targs)

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`; rows are actual classes, columns are predicted classes."
        n = len(self.vocab)
        d,t = self._flat_decoded_targs()
        d,t = d.long(),t.long()
        # Ids outside `[0, n)` do not correspond to a vocab entry and are not counted
        valid = (t >= 0) & (t < n) & (d >= 0) & (d < n)
        # Encode each (actual, predicted) pair as one flat index so a single `bincount` tallies
        # every cell, avoiding an (n, n, num_samples) boolean intermediate
        cm = torch.bincount(t[valid]*n + d[valid], minlength=n*n).view(n, n)
        return to_np(cm)

    def plot_confusion_matrix(self,
        normalize:bool=False, # Whether to normalize occurrences
        title:str='Confusion matrix', # Title of plot
        cmap:str="Blues", # Colormap from matplotlib
        norm_dec:int=2, # Decimal places for normalized occurrences
        plot_txt:bool=True, # Display occurrence in matrix
        **kwargs
    ):
        "Plot the confusion matrix, with `title` and using `cmap`."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix()
        if normalize:
            # Row-normalize so each actual class sums to 1; classes with no samples get 0 instead of NaN
            row_sums = cm.sum(axis=1, keepdims=True)
            cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = np.arange(len(self.vocab))
        plt.xticks(tick_marks, self.vocab, rotation=90)
        plt.yticks(tick_marks, self.vocab, rotation=0)
        if plot_txt:
            # Switch the text colour at half the maximum so labels stay readable on high-valued cells
            thresh = cm.max() / 2.
            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'
                plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center",
                         color="white" if cm[i, j] > thresh else "black")
        ax = fig.gca()
        # Pin the y-limits to the cell edges with row 0 at the top
        ax.set_ylim(len(self.vocab)-.5,-.5)
        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')
        plt.grid(False)

    def most_confused(self, min_val=1):
        "Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences)"
        cm = self.confusion_matrix()
        np.fill_diagonal(cm, 0)  # correct predictions are not confusions
        res = [(self.vocab[i],self.vocab[j],cm[i,j]) for i,j in zip(*np.where(cm>=min_val))]
        return sorted(res, key=itemgetter(2), reverse=True)

    def print_classification_report(self):
        "Print scikit-learn classification report"
        d,t = self._flat_decoded_targs()
        names = [str(v) for v in self.vocab]
        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=names))
# %% ../nbs/20_interpret.ipynb 27
class SegmentationInterpretation(Interpretation):
    "Interpretation methods for segmentation models."
    # No segmentation-specific behaviour yet: everything is inherited from `Interpretation`.