import time |
import pickle |
import dill |
import os |
import gradio as gr |
import spaces |
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification |
from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace |
from pnpxai.explainers.utils.baselines import BASELINE_FUNCTIONS_FOR_IMAGE |
from pnpxai.explainers.utils.feature_masks import FEATURE_MASK_FUNCTIONS_FOR_IMAGE |
import matplotlib.pyplot as plt |
import plotly.graph_objects as go |
import plotly.express as px |
import networkx as nx |
import secrets |
# UI / optimisation settings referenced by the classes below.
# NOTE(review): PLOT_PER_LINE, OBJECTIVE_METRIC, SAMPLE_METHOD and OPT_N_TRIALS
# are used throughout this module but were not defined anywhere in it, which
# raises NameError while the UI is being built. The values below are
# assumptions — confirm them against the intended configuration.
PLOT_PER_LINE = 4  # explainer widgets / result images per gr.Row
OPT_N_TRIALS = 10  # trials per "Optimize" click (forwarded to experiment.optimize)
OBJECTIVE_METRIC = "AbPC"  # metric class name the optimizer maximises
SAMPLE_METHOD = "tpe"  # sampler name forwarded to experiment.optimize

# Explainers that start checked (and whose accordions start open) in the UI.
DEFAULT_EXPLAINER = ["GradientXInput", "IntegratedGradients", "LRPEpsilonPlus"]
class App:
    """Base class for the demo applications; it holds no state of its own."""

    def __init__(self):
        """Nothing to initialise at the base level."""
class Component:
    """Base class for the UI pieces of the app (tabs, panels, widgets)."""

    def __init__(self):
        """Base components carry no state."""
class Tab(Component):
    """Base class for the top-level tabs of the app."""

    def __init__(self):
        """Tabs keep no shared state; subclasses store what they render."""
class OverviewTab(Tab):
    """Static "Overview" tab whose body is read from an HTML file."""

    def __init__(self):
        pass

    def show(self):
        """Render the tab: a header label followed by the HTML description."""
        with gr.Tab(label="Overview"):
            gr.Label("This is the overview tab.")
            gr.HTML(self.desc())

    def desc(self):
        """Return the contents of ``static/overview.html``.

        The relative path is resolved against the current working directory.
        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's default locale encoding.
        """
        with open("static/overview.html", "r", encoding="utf-8") as f:
            return f.read()
class DetectionTab(Tab):
    """Tab showing the traced model architecture of every experiment."""

    def __init__(self, experiments):
        # experiments: mapping of name -> info dict holding an 'experiment' entry.
        self.experiments = experiments

    def show(self):
        """Render one DetectorRes panel per experiment, in mapping order."""
        with gr.Tab(label="Detection"):
            gr.Label("This is the detection tab.")
            for info in self.experiments.values():
                DetectorRes(info['experiment']).show()
class LocalExpTab(Tab):
    """Tab holding one local-explanation panel per experiment."""

    def __init__(self, experiments):
        self.experiments = experiments
        # One Experiment panel per registered experiment, in mapping order.
        self.experiment_components = [Experiment(info) for info in experiments.values()]

    def description(self):
        return "This tab shows the local explanation."

    def show(self):
        """Render the tab and every experiment panel inside it."""
        with gr.Tab(label="Local Explanation"):
            gr.Label("This is the local explanation tab.")
            for component in self.experiment_components:
                component.show()
class DetectorRes(Component):
    """Panel showing the task, model name and traced architecture of one experiment."""

    def __init__(self, experiment):
        self.experiment = experiment
        # Trace the model once and keep its node/edge description.
        graph_module = symbolic_trace(experiment.model)
        self.graph_data = extract_graph_data(graph_module)

    def describe(self):
        return "This component shows the detection result."

    @staticmethod
    def _build_graph(graph_data):
        """Build a DiGraph from the extracted node/edge records.

        Edges whose endpoints are not among the recorded nodes are skipped.
        """
        graph = nx.DiGraph()
        for node in graph_data['nodes']:
            graph.add_node(node['name'])
        for edge in graph_data['edges']:
            if edge['source'] in graph.nodes and edge['target'] in graph.nodes:
                graph.add_edge(edge['source'], edge['target'])
        return graph

    @staticmethod
    def _layered_positions(graph):
        """Place nodes in horizontal bands, one band per topological generation.

        Generations are enumerated in reverse, so the last generation gets
        layer 0. nx.topological_generations requires a DAG.
        """
        graph = graph.copy()
        generations = reversed(tuple(nx.topological_generations(graph)))
        for layer, nodes in enumerate(generations):
            for node in nodes:
                graph.nodes[node]["layer"] = layer
        return nx.multipartite_layout(graph, subset_key="layer", align='horizontal')

    @staticmethod
    def _plot_graph(graph, pos):
        """Draw ``graph`` at ``pos`` on a new tall matplotlib figure and return it."""
        fig = plt.figure(figsize=(12, 24))
        ax = fig.gca()
        nx.draw(graph, pos=pos, with_labels=True, node_size=60, font_size=8, ax=ax)
        fig.tight_layout()
        return fig

    def show(self):
        """Render the task/model textboxes and the architecture plot."""
        graph = self._build_graph(self.graph_data)
        fig = self._plot_graph(graph, self._layered_positions(graph))
        model_name = self.experiment.model.__class__.__name__
        with gr.Row():
            gr.Textbox(value="Image Classification", label="Task")
            gr.Textbox(value=f"{model_name}", label="Model")
            gr.Plot(value=fig, label=f"Model Architecture of {model_name}", visible=True)
class ImgGallery(Component):
    """Thumbnail gallery of input images that records the clicked index.

    ``selected_index`` is a hidden gr.Number that receives the index of the
    clicked item; other components consume it as an event input.
    """

    def __init__(self, imgs):
        self.imgs = imgs
        self.selected_index = gr.Number(value=0, label="Selected Index", visible=False)

    def on_select(self, evt: gr.SelectData):
        """Return the index of the gallery item that was clicked."""
        clicked = evt.index
        return clicked

    def show(self):
        """Render the gallery and route its selections into ``selected_index``."""
        gallery = gr.Gallery(value=self.imgs, label="Input Data Gallery", columns=6, height=200)
        self.gallery_obj = gallery
        gallery.select(self.on_select, outputs=self.selected_index)
class Experiment(Component):
    """"Local Explanation" panel for one experiment.

    Shows a gallery of the experiment's dataset, explainer and metric
    selectors, and, on "Explain", the top-k prediction plus one heatmap per
    checked explainer.
    """

    def __init__(self, exp_info):
        self.exp_info = exp_info
        self.experiment = exp_info['experiment']
        # Callables that turn an input tensor into a displayable image and a
        # label tensor into a readable class name.
        self.input_visualizer = exp_info['input_visualizer']
        self.target_visualizer = exp_info['target_visualizer']

    def viz_input(self, input, data_id):
        """Return a tick-free plotly figure of ``input`` titled with ``data_id``."""
        orig_img_np = self.input_visualizer(input)
        orig_img = px.imshow(orig_img_np)
        orig_img.update_layout(
            title=f"Data ID: {data_id}",
            width=400,
            height=350,
            xaxis=dict(showticklabels=False, ticks='', showgrid=False),
            yaxis=dict(showticklabels=False, ticks='', showgrid=False),
        )
        return orig_img

    def get_prediction(self, record, topk=3):
        """Format the ground-truth label and the ``topk`` most probable classes.

        ``record['output']`` is presumably the logits of a single sample. It is
        moved to the CPU before NumPy conversion, so this also works when the
        model runs on a GPU.
        """
        probs = record['output'].softmax(-1).squeeze().detach().cpu().numpy()
        text = f"Ground Truth Label: {self.target_visualizer(record['label'])}\n"
        for ind, pred in enumerate(probs.argsort()[-topk:][::-1]):
            label = self.target_visualizer(torch.tensor(pred))
            prob = probs[pred]
            text += f"Top {ind+1} Prediction: {label} ({prob:.2f})\n"
        return text

    def get_exp_plot(self, data_index, exp_res):
        """Render one explanation to an image file and return its path."""
        return ExpRes(data_index, exp_res).show()

    def get_metric_id_by_name(self, metric_name):
        """Return the manager id of the first metric whose class name matches."""
        metric_info = self.experiment.manager.get_metrics()
        idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
        return metric_info[1][idx]

    @staticmethod
    def _first_score(explanation):
        """Sort key: the first metric value as a float, with None ranked lowest."""
        value = explanation['evaluations'][0]['value']
        return float('-inf') if value is None else float(value)

    def generate_record(self, checkbox_group_info, data_id, metric_names):
        """Run every checked explainer, and each requested metric, on one sample.

        Returns a dict with the sample's input/label/output/target and an
        ``explanations`` list. Each entry has ``explainer_nm``, ``value``,
        ``mode`` and ``evaluations`` (``[{'metric_nm', 'value'}, ...]``).
        When at least one explainer ran and metrics were requested, the
        explanations are sorted by their first metric, highest first.
        """
        record = {}
        # Base run: only its inputs/labels/outputs/targets are used.
        _base = self.experiment.run_batch([data_id], 0, 0, 0)
        record['data_id'] = data_id
        record['input'] = _base['inputs']
        record['label'] = _base['labels']
        record['output'] = _base['outputs']
        record['target'] = _base['targets']
        record['explanations'] = []

        metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
        for info in checkbox_group_info:
            if not info['checked']:
                continue
            base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
            explanation = {
                'explainer_nm': base['explainer'].__class__.__name__,
                'value': base['postprocessed'],
                'mode': info['mode'],
                'evaluations': [],
            }
            for metric_id in metrics_ids:
                res = self.experiment.run_batch([data_id], info['id'], info['pp_id'], metric_id)
                explanation['evaluations'].append({
                    'metric_nm': res['metric'].__class__.__name__,
                    'value': res['evaluation'],
                })
            record['explanations'].append(explanation)

        # Nothing to sort when no explainer was checked or no metric requested.
        explanations = record['explanations']
        if explanations and explanations[0]['evaluations']:
            record['explanations'] = sorted(explanations, key=self._first_score, reverse=True)
        return record

    def show(self):
        """Build the panel's components and wire the "Explain" button."""
        with gr.Row():
            gr.Textbox(value="Image Classification", label="Task")
            gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")
            gr.Textbox(value="Heatmap", label="Explanation Type")

        # Pre-render every dataset sample for the gallery.
        dset = self.experiment.manager._data.dataset
        imgs = [self.input_visualizer(dset[i][0]) for i in range(len(dset))]
        gallery = ImgGallery(imgs)
        gallery.show()

        explainers, _ = self.experiment.manager.get_explainers()
        explainer_names = [exp.__class__.__name__ for exp in explainers]

        self.explainer_checkbox_group = ExplainerCheckboxGroup(explainer_names, self.experiment, gallery)
        self.explainer_checkbox_group.show()

        cr_metrics_names = ["AbPC", "MoRF", "LeRF", "MuFidelity"]
        cn_metrics_names = ["Sensitivity"]
        cp_metrics_names = ["Complexity"]
        with gr.Accordion("Evaluators", open=True):
            with gr.Row():
                cr_metrics = gr.CheckboxGroup(choices=cr_metrics_names, value=[cr_metrics_names[0]], label="Correctness")

                def on_select(metrics):
                    # Re-add AbPC if the user deselects it (see warning text).
                    if cr_metrics_names[0] not in metrics:
                        gr.Warning(f"{cr_metrics_names[0]} is required for the sorting the explanations.")
                        return [cr_metrics_names[0]] + metrics
                    return metrics

                cr_metrics.select(on_select, inputs=cr_metrics, outputs=cr_metrics)
            with gr.Row():
                cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, label="Continuity")
            with gr.Row():
                cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, label="Compactness")

        metric_inputs = [cr_metrics, cn_metrics, cp_metrics]

        data_id = gallery.selected_index
        bttn = gr.Button("Explain", variant="primary")

        def n_rows_for(count):
            """Number of PLOT_PER_LINE-wide rows needed to hold ``count`` plots."""
            return count // PLOT_PER_LINE + (count % PLOT_PER_LINE != 0)

        # Pre-allocate hidden image slots: room for two results per explainer.
        buffer_size = 2 * len(explainer_names)
        buffer_n_rows = n_rows_for(buffer_size)
        plots = [gr.Textbox(label="Prediction result", visible=False)]
        for _ in range(buffer_n_rows):
            with gr.Row():
                for _ in range(PLOT_PER_LINE):
                    plots.append(gr.Image(value=None, label="Blank", visible=False))

        def show_plots(checkbox_group_info):
            """Reveal blank slots for the checked explainers; hide the rest."""
            num_plots = sum(1 for info in checkbox_group_info if info['checked'])
            n_rows = n_rows_for(num_plots)
            _plots = [gr.Textbox(label="Prediction result", visible=False)]
            _plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
            _plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
            return _plots

        @spaces.GPU
        def render_plots(data_id, checkbox_group_info, *metric_inputs):
            """Run the explainers/metrics and fill the slots with heatmaps."""
            cache_dir = f"{os.environ['GRADIO_TEMP_DIR']}/res"
            os.makedirs(cache_dir, exist_ok=True)
            # Drop heatmaps from earlier runs (ExpRes names them with 16 hex chars).
            for f in os.listdir(cache_dir):
                if len(f.split(".")[0]) == 16:
                    os.remove(os.path.join(cache_dir, f))

            # Flatten the three checkbox groups into one list of metric names.
            metric_input = []
            for metric in metric_inputs:
                if metric:
                    metric_input += metric

            record = self.generate_record(checkbox_group_info, data_id, metric_input)
            pred = self.get_prediction(record)
            plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]

            num_plots = sum(1 for info in checkbox_group_info if info['checked'])
            n_rows = n_rows_for(num_plots)
            explanations = record['explanations']
            for slot in range(n_rows * PLOT_PER_LINE):
                if slot < len(explanations):
                    exp_res = explanations[slot]
                    path = self.get_exp_plot(data_id, exp_res)
                    plots.append(gr.Image(value=path, label=f"{exp_res['explainer_nm']} ({exp_res['mode']})", visible=True))
                else:
                    plots.append(gr.Image(value=None, label="Blank", visible=True))
            plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
            return plots

        # Both handlers fire on click: show_plots reveals placeholder slots and
        # render_plots fills them once the explanations are computed.
        bttn.click(show_plots, inputs=[self.explainer_checkbox_group.info], outputs=plots)
        bttn.click(render_plots, inputs=[data_id, self.explainer_checkbox_group.info] + metric_inputs, outputs=plots)
class ExplainerCheckboxGroup(Component):
    """Grid of ExplainerCheckbox widgets backed by one shared gr.State.

    ``self.info`` holds a list of dicts ``{'nm', 'id', 'pp_id', 'mode',
    'checked'}``, one per selectable explainer variant. The "Explain" button
    reads it to decide which explainers to run.
    """

    def __init__(self, explainer_names, experiment, gallery):
        super().__init__()
        self.explainer_names = explainer_names
        self.explainer_objs = []
        self.experiment = experiment
        self.gallery = gallery
        explainers, exp_ids = self.experiment.manager.get_explainers()
        info = [
            {
                'nm': exp.__class__.__name__,
                'id': exp_id,
                'pp_id': 0,
                'mode': 'default',
                'checked': exp.__class__.__name__ in DEFAULT_EXPLAINER,
            }
            for exp, exp_id in zip(explainers, exp_ids)
        ]
        # Display order: default explainers first, then alphabetical.
        self.static_info = sorted(info, key=lambda x: (x['nm'] not in DEFAULT_EXPLAINER, x['nm']))
        self.info = gr.State(info)

    def update_check(self, checkbox_group_info, exp_id, val=None):
        """Set (or toggle, when ``val`` is None) 'checked' on every entry with ``exp_id``.

        Mutates the list in place and returns it.
        """
        for info in checkbox_group_info:
            if info['id'] == exp_id:
                info['checked'] = (not info['checked']) if val is None else val
        return checkbox_group_info

    def insert_check(self, checkbox_group_info, exp_nm, exp_id, pp_id):
        """Append an unchecked 'optimal' entry for ``exp_id`` unless one already exists.

        Mutates the list in place and always returns it. Previously the
        duplicate path returned None.
        """
        if all(info['id'] != exp_id for info in checkbox_group_info):
            checkbox_group_info.append({'nm': exp_nm, 'id': exp_id, 'pp_id': pp_id, 'mode': 'optimal', 'checked': False})
        return checkbox_group_info

    def update_gallery_change(self, checkbox_group_info):
        """Reset every explainer widget after a new gallery image is selected.

        Default checkboxes return to their initial values. Optimized variants
        are unchecked and disabled, and the Optimize buttons are restored.
        Presumably this is because an optimization result belongs to the
        previously selected image. Updates are returned in the order
        ``get_checkboxes() + get_bttns() + [checkbox_group_info]``.
        """
        n_objs = len(self.explainer_objs)
        defaults = [exp.explainer_name in DEFAULT_EXPLAINER for exp in self.explainer_objs]
        checkboxes = [gr.Checkbox(label="Default Parameter", value=val, interactive=True) for val in defaults]
        checkboxes += [gr.Checkbox(label="Optimized Parameter (Not Optimal)", value=False, interactive=False)] * n_objs
        bttns = [gr.Button(value="Optimize", size="sm", variant="primary")] * n_objs
        for exp, val in zip(self.explainer_objs, defaults):
            checkbox_group_info = self.update_check(checkbox_group_info, exp.default_exp_id, val)
            if hasattr(exp, "optimal_exp_id"):
                checkbox_group_info = self.update_check(checkbox_group_info, exp.optimal_exp_id, False)
        return checkboxes + bttns + [checkbox_group_info]

    def get_checkboxes(self):
        """All default-parameter checkboxes, followed by all optimized-parameter ones."""
        checkboxes = [exp.default_check for exp in self.explainer_objs]
        checkboxes += [exp.opt_check for exp in self.explainer_objs]
        return checkboxes

    def get_bttns(self):
        return [exp.bttn for exp in self.explainer_objs]

    def show(self):
        """Render PLOT_PER_LINE explainer widgets per row and wire the gallery reset."""
        cnt = 0
        with gr.Accordion("Explainers", open=True):
            while cnt * PLOT_PER_LINE < len(self.explainer_names):
                with gr.Row():
                    for info in self.static_info[cnt * PLOT_PER_LINE:(cnt + 1) * PLOT_PER_LINE]:
                        explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
                        self.explainer_objs.append(explainer_obj)
                        explainer_obj.show()
                cnt += 1
        checkboxes = self.get_checkboxes()
        bttns = self.get_bttns()
        self.gallery.gallery_obj.select(
            fn=self.update_gallery_change,
            inputs=self.info,
            outputs=checkboxes + bttns + [self.info],
        )
class ExplainerCheckbox(Component):
    """Accordion for one explainer with default/optimized toggles and an Optimize button.

    The toggles edit the shared ``groups.info`` state owned by the parent
    ExplainerCheckboxGroup. "Optimize" runs ``experiment.optimize`` on the
    currently selected gallery item. It then registers a reconstructed
    explainer with the manager under a fresh id.
    """

    def __init__(self, explainer_name, groups, experiment, gallery):
        self.explainer_name = explainer_name
        self.groups = groups
        self.experiment = experiment
        self.gallery = gallery
        # Holds {'id', 'class', 'params'} of the latest optimization run.
        self.opt_res = gr.State(None)
        self.default_exp_id = self.get_explainer_id_by_name(explainer_name)
        self.obj_metric = self.get_metric_id_by_name(OBJECTIVE_METRIC)

    def get_explainer_id_by_name(self, explainer_name):
        """Return the manager id of the first explainer whose class name matches."""
        explainer_info = self.experiment.manager.get_explainers()
        idx = [exp.__class__.__name__ for exp in explainer_info[0]].index(explainer_name)
        return explainer_info[1][idx]

    def get_metric_id_by_name(self, metric_name):
        """Return the manager id of the first metric whose class name matches."""
        metric_info = self.experiment.manager.get_metrics()
        idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
        return metric_info[1][idx]

    def get_str_ppid(self, pp_obj):
        """Identify a postprocessor by its pooling + normalization class names."""
        return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__

    def default_on_select(self, evt: gr.EventData, checkbox_group_info):
        checkbox_group_info = self.groups.update_check(checkbox_group_info, self.default_exp_id, evt._data['value'])
        return checkbox_group_info

    def optimal_on_select(self, evt: gr.EventData, checkbox_group_info, opt_res):
        if hasattr(self, "optimal_exp_id"):
            checkbox_group_info = self.groups.update_check(checkbox_group_info, self.optimal_exp_id, evt._data['value'])
        else:
            raise ValueError("Optimal result is not found.")
        return checkbox_group_info

    def show(self):
        """Render the accordion and wire its checkboxes and Optimize button."""
        val = self.explainer_name in DEFAULT_EXPLAINER
        with gr.Accordion(self.explainer_name, open=val):
            checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.static_info))['checked']
            self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
            self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)

            self.default_check.select(self.default_on_select, self.groups.info, self.groups.info)
            self.opt_check.select(self.optimal_on_select, [self.groups.info, self.opt_res], self.groups.info)

            self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")

            @spaces.GPU
            def optimize(checkbox_group_info, data_id):
                # ``data_id`` is the session's current value of the gallery's
                # hidden index component, delivered as an event input. Reading
                # ``self.gallery.selected_index.value`` instead yields only the
                # component's initial value (0), whatever the user selected.
                data_id = 0 if data_id is None else int(data_id)
                opt_output = self.experiment.optimize(
                    data_ids=data_id,
                    explainer_id=self.default_exp_id,
                    metric_id=self.obj_metric,
                    direction='maximize',
                    sampler=SAMPLE_METHOD,
                    n_trials=OPT_N_TRIALS,
                )

                # Map the optimized postprocessor back to a registered id.
                str_id = self.get_str_ppid(opt_output.postprocessor)
                opt_postprocessor_id = next(
                    (
                        pp_id
                        for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors())
                        if self.get_str_ppid(pp_obj) == str_id
                    ),
                    None,
                )
                if opt_postprocessor_id is None:
                    raise gr.Error(f"No registered postprocessor matches the optimized one ({str_id}).")

                opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
                opt_res = {
                    'id': opt_exp_id,
                    'class': opt_output.explainer.__class__,
                    'params': opt_output.study.best_trial.params,
                }
                self.groups.insert_check(checkbox_group_info, self.explainer_name, opt_exp_id, opt_postprocessor_id)
                checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
                bttn = gr.update(value="Optimized", variant="secondary")
                return [opt_res, checkbox_group_info, checkbox, bttn]

            def update_exp(exp_res):
                # Rebuild the optimized explainer from the study's best params
                # and register it with the manager under exp_res['id'].
                _id = exp_res['id']
                try:
                    kwargs = {}
                    has_baseline = False
                    has_feature_mask = False
                    # Keep only "explainer.*" params, with that prefix stripped.
                    for k, v in exp_res['params'].items():
                        if "explainer" in k:
                            _key = k.split("explainer.")[1]
                            kwargs[_key] = v
                            if "baseline_fn" in _key:
                                has_baseline = True
                            if "feature_mask_fn" in _key:
                                has_feature_mask = True
                    # Collapse "baseline_fn.*" params into a baseline function.
                    if has_baseline:
                        method = kwargs['baseline_fn.method']
                        del kwargs['baseline_fn.method']
                        baseline_kwargs = {}
                        for k in list(kwargs.keys()):
                            v = kwargs[k]
                            if "baseline_fn" in k:
                                baseline_kwargs[k.split("baseline_fn.")[1]] = v
                                del kwargs[k]
                        if method == "mean":
                            baseline_kwargs['dim'] = 1
                        baseline_fn = BASELINE_FUNCTIONS_FOR_IMAGE[method](**baseline_kwargs)
                        kwargs['baseline_fn'] = baseline_fn
                    # Collapse "feature_mask_fn.*" params into a mask function.
                    if has_feature_mask:
                        method = kwargs['feature_mask_fn.method']
                        del kwargs['feature_mask_fn.method']
                        mask_kwargs = {}
                        for k in list(kwargs.keys()):
                            v = kwargs[k]
                            if "feature_mask_fn" in k:
                                mask_kwargs[k.split("feature_mask_fn.")[1]] = v
                                del kwargs[k]
                        mask_fn = FEATURE_MASK_FUNCTIONS_FOR_IMAGE[method](**mask_kwargs)
                        kwargs['feature_mask_fn'] = mask_fn
                    kwargs['model'] = self.experiment.model
                    explainer = exp_res['class'](**kwargs)
                except Exception as e:
                    # Best effort: fall back to the default explainer instance.
                    print(f"[Optimizer] Explainer Reconstrcution Error Catched : {e}")
                    explainer = self.experiment.manager._explainers[self.default_exp_id]
                self.experiment.manager._explainers.append(explainer)
                self.experiment.manager._explainer_ids.append(_id)
                self.optimal_exp_id = _id

            self.bttn.click(
                optimize,
                inputs=[self.groups.info, self.gallery.selected_index],
                outputs=[self.opt_res, self.groups.info, self.opt_check, self.bttn],
                queue=True,
                concurrency_limit=1,
            )
            self.opt_res.change(update_exp, self.opt_res)
class ExpRes(Component):
    """Renders one explanation as an annotated heatmap image saved to disk."""

    def __init__(self, data_index, exp_res):
        self.data_index = data_index
        # One entry of the 'explanations' list built by Experiment.generate_record.
        self.exp_res = exp_res

    def show(self):
        """Save the heatmap as ``$GRADIO_TEMP_DIR/res/<16 hex chars>.png`` and return the path.

        Metric scores (name truncated to 4 characters) are written below the
        plot, three per annotation line. Scores that are None are skipped.
        """
        # Move to CPU before NumPy conversion so GPU-resident attributions work.
        attribution = self.exp_res['value'][0].detach().cpu().numpy()
        fig = go.Figure(data=go.Heatmap(
            # go.Heatmap puts row 0 at the bottom; flip so the map reads like an image.
            z=np.flipud(attribution),
            colorscale='Reds',
            showscale=False,
        ))

        metric_values = [
            f"{ev['metric_nm'][:4]}: {ev['value'].item():.2f}"
            for ev in self.exp_res['evaluations']
            if ev['value'] is not None
        ]
        per_line = 3
        n_lines = 0
        while n_lines * per_line < len(metric_values):
            fig.add_annotation(
                x=0,
                y=-0.1 * (n_lines + 1),
                xref='paper',
                yref='paper',
                text=', '.join(metric_values[n_lines * per_line:(n_lines + 1) * per_line]),
                showarrow=False,
                font=dict(size=18),
            )
            n_lines += 1

        fig = fig.update_layout(
            width=380,
            height=400,
            xaxis=dict(showticklabels=False, ticks='', showgrid=False),
            yaxis=dict(showticklabels=False, ticks='', showgrid=False),
            # Grow the bottom margin to make room for the annotation lines.
            margin=dict(t=40, b=40 * n_lines, l=20, r=20),
        )

        root = f"{os.environ['GRADIO_TEMP_DIR']}/res"
        os.makedirs(root, exist_ok=True)
        path = f"{root}/{secrets.token_hex(8)}.png"
        fig.write_image(path)
        return path
class ImageClsApp(App):
    """Gradio front-end with Overview, Detection and Local Explanation tabs."""

    def __init__(self, experiments, **kwargs):
        self.name = "Image Classification App"
        super().__init__(**kwargs)
        self.experiments = experiments
        # Tabs are created eagerly and rendered later, inside launch().
        self.overview_tab = OverviewTab()
        self.detection_tab = DetectionTab(experiments)
        self.local_exp_tab = LocalExpTab(experiments)

    def title(self):
        """Return the HTML header (logo link plus heading) shown above the tabs."""
        return """
<div style="text-align: center;">
<a href="https://openxaiproject.github.io/pnpxai/">
<img src="file/static/XAI-Top-PnP.png" width="167" height="100">
</a>
<h1> Plug and Play XAI Platform for Image Classification </h1>
</div>
"""

    def launch(self, **kwargs):
        """Build and return the gr.Blocks demo; the caller launches it.

        ``kwargs`` is accepted but currently unused.
        """
        with gr.Blocks(title=self.name) as demo:
            # Serve files next to this module (e.g. the logo) as static assets.
            here = os.path.dirname(os.path.abspath(__file__))
            gr.set_static_paths(here)
            gr.HTML(self.title())
            for tab in (self.overview_tab, self.detection_tab, self.local_exp_tab):
                tab.show()
        return demo
import os |
import torch |
import numpy as np |
from torch.utils.data import DataLoader |
from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image |
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def target_visualizer(x):
    """Map a class-index tensor to its label via the module-level ``dataset``.

    ``dataset`` is looked up at call time, i.e. whichever dataset was loaded
    last. Both experiments below load theirs with get_imagenet_dataset, so
    they presumably share one label map.
    """
    return dataset.dataset.idx_to_label(x.item())


def make_input_visualizer(transform):
    """Return a callable that undoes *this* transform's normalization.

    mean/std are captured here on purpose. A lambda that referenced the global
    ``transform`` would late-bind, so every experiment would denormalize with
    whichever transform was assigned last.
    """
    mean, std = transform.mean, transform.std
    return lambda x: denormalize_image(x, mean, std)


experiments = {}

model, transform = get_torchvision_model('resnet18')
dataset = get_imagenet_dataset(transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
experiment1 = AutoExplanationForImageClassification(
    model=model.to(device),
    data=loader,
    input_extractor=lambda batch: batch[0].to(device),
    label_extractor=lambda batch: batch[-1].to(device),
    target_extractor=lambda outputs: outputs.argmax(-1).to(device),
    channel_dim=1
)
experiments['experiment1'] = {
    'name': 'ResNet18',
    'experiment': experiment1,
    'input_visualizer': make_input_visualizer(transform),
    'target_visualizer': target_visualizer,
}

model, transform = get_torchvision_model('vit_b_16')
dataset = get_imagenet_dataset(transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
experiment2 = AutoExplanationForImageClassification(
    model=model.to(device),
    data=loader,
    input_extractor=lambda batch: batch[0].to(device),
    label_extractor=lambda batch: batch[-1].to(device),
    target_extractor=lambda outputs: outputs.argmax(-1).to(device),
    channel_dim=1
)
experiments['experiment2'] = {
    'name': 'ViT-B_16',
    'experiment': experiment2,
    'input_visualizer': make_input_visualizer(transform),
    'target_visualizer': target_visualizer,
}

app = ImageClsApp(experiments)
demo = app.launch()
demo.launch(favicon_path="static/XAI-Top-PnP.svg")