ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning
This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".
Model Description
ProtoViT combines Vision Transformers with prototype-based learning to create models that are both highly accurate and interpretable. Rather than functioning as a black box, ProtoViT learns interpretable prototypes that explain its classification decisions through visual similarities.
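As a rough illustration of this mechanism, the sketch below shows how prototype similarities can be turned into class logits. It is a simplification, not the paper's exact method (ProtoViT's adaptive prototypes are assembled from several deformable sub-patches), and every name in it (`prototype_logits`, `patch_features`, `prototypes`, `class_weights`) is hypothetical:

```python
import torch
import torch.nn.functional as F

# Simplified sketch of prototype-based classification (NOT the exact
# ProtoViT implementation): a ViT backbone yields patch embeddings,
# each prototype keeps its best patch match, and class logits are
# weighted sums of those prototype scores.
def prototype_logits(patch_features, prototypes, class_weights):
    # patch_features: (batch, num_patches, dim)
    # prototypes:     (num_prototypes, dim)
    # class_weights:  (num_classes, num_prototypes)
    f = F.normalize(patch_features, dim=-1)
    p = F.normalize(prototypes, dim=-1)
    sim = torch.einsum("bnd,pd->bnp", f, p)   # cosine similarity per patch
    proto_scores = sim.max(dim=1).values      # best match per prototype
    return proto_scores @ class_weights.t()   # (batch, num_classes)
```

Because each logit decomposes into per-prototype similarity scores, the evidence behind a prediction can be read off directly, which is the basis of the interpretability claims above.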
Supported Architectures
We provide three variants of ProtoViT (a backbone-loading sketch follows this list):
- ProtoViT-T: Built on DeiT-Tiny backbone
- ProtoViT-S: Built on DeiT-Small backbone
- ProtoViT-CaiT: Built on CaiT-XXS24 backbone
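All three backbones are standard timm models. A minimal sketch of instantiating them; the registry names below are from timm 0.4.12, the version pinned under Requirements:

```python
import timm

# timm registry names for the three ProtoViT backbones (timm 0.4.12).
backbones = {
    "ProtoViT-T":    "deit_tiny_patch16_224",
    "ProtoViT-S":    "deit_small_patch16_224",
    "ProtoViT-CaiT": "cait_xxs24_224",
}

# Load ImageNet-pretrained weights for the chosen backbone.
encoder = timm.create_model(backbones["ProtoViT-S"], pretrained=True)
```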
Performance
All models were trained and evaluated on the CUB-200-2011 fine-grained bird-species classification dataset; an evaluation sketch follows the table.
| Model Version | Backbone | Resolution | Top-1 Accuracy | Checkpoint |
|---|---|---|---|---|
| ProtoViT-T | DeiT-Tiny | 224×224 | 83.36% | Download |
| ProtoViT-S | DeiT-Small | 224×224 | 85.30% | Download |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224 | 86.02% | Download |
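A standard evaluation loop over the test split should reproduce numbers close to the table. This is only a sketch: the checkpoint filename, the ImageFolder layout of CUB-200-2011, and the ImageNet normalization statistics are all assumptions, and checkpoint loading depends on this repository's model definition:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical paths/filenames; normalization stats assume ImageNet.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_set = datasets.ImageFolder("CUB_200_2011/test", transform=preprocess)
loader = DataLoader(test_set, batch_size=64, num_workers=4)

# Assumes the checkpoint was saved as a whole model via torch.save(model).
model = torch.load("protovit.pth", map_location="cpu")
model.eval()

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        out = model(images)
        # Prototype models often return (logits, extras); keep logits only.
        logits = out[0] if isinstance(out, tuple) else out
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
print(f"Top-1 accuracy: {100 * correct / total:.2f}%")
```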
Features
- 🔍 Interpretable Decisions: The model classifies with self-explanatory reasoning, based on the input's similarity to learned prototypes that capture the key features of each class
- 🎯 High Accuracy: Achieves competitive performance on fine-grained classification tasks
- 🏗️ Multiple Architectures: Supports various Vision Transformer backbones
- 📊 Analysis Tools: Comes with tools for both local and global prototype analysis (a local-analysis sketch follows this list)
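The local-analysis idea can be sketched as ranking prototypes by their best patch match in a single image; which prototypes win, and where they match, constitutes the explanation. The helper below is hypothetical (the repository ships its own analysis scripts) and continues the simplified single-vector-prototype setting from the earlier sketch:

```python
import torch
import torch.nn.functional as F

# Hypothetical local-analysis helper: for one image, rank prototypes by
# their strongest patch similarity and report where each one matched.
def top_prototypes(patch_features, prototypes, k=5):
    # patch_features: (num_patches, dim); prototypes: (num_prototypes, dim)
    f = F.normalize(patch_features, dim=-1)
    p = F.normalize(prototypes, dim=-1)
    sim = f @ p.t()                          # (patches, prototypes)
    best_score, best_patch = sim.max(dim=0)  # best match per prototype
    scores, idx = best_score.topk(k)
    # (prototype index, similarity, patch index where it matched)
    return [(int(i), float(s), int(best_patch[i])) for i, s in zip(idx, scores)]
```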
Requirements
- Python 3.8+
- PyTorch 1.8+
- timm==0.4.12 (pinned; a version sanity check follows this list)
- torchvision
- numpy
- pillow
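Because the code is pinned to timm==0.4.12, a quick environment sanity check can save debugging time. This snippet is a convenience suggestion, not part of the repository:

```python
import timm
import torch
import torchvision

# Newer timm releases rename or restructure some ViT internals, which can
# break checkpoint loading; check the pinned version before anything else.
assert timm.__version__ == "0.4.12", f"expected timm 0.4.12, got {timm.__version__}"
print("torch:", torch.__version__, "| torchvision:", torchvision.__version__)
```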
Limitations and Bias
- Data Bias: The models are trained on CUB-200-2011 and may not generalize well to images outside this dataset.
- Resolution Constraints: The models are trained at a resolution of 224×224; higher or lower resolutions may degrade performance.
- Location Misalignment: Like CNN-based prototype models, these models are not perfectly immune to location misalignment under adversarial attacks.
Citation
If you use this model in your research, please cite:
    @article{ma2024interpretable,
      title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
      author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
      journal={arXiv preprint arXiv:2410.20722},
      year={2024}
    }
Acknowledgements
This implementation builds upon several excellent open-source repositories.
License
This project is released under the MIT License.
Contact
For any questions or feedback, please:
- Open an issue in the GitHub repository
- Contact the authors at [email protected]
Model tree for chiyum609/ProtoViT
Base model: timm/cait_xxs24_224.fb_dist_in1k