ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning

This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".

Model Description

ProtoViT combines Vision Transformers with prototype-based learning to create models that are both highly accurate and interpretable. Rather than functioning as a black box, ProtoViT learns interpretable prototypes that explain its classification decisions through visual similarities.
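The sketch below illustrates the idea in PyTorch: patch-token features from a ViT backbone are compared to learned prototype vectors by cosine similarity, the best-matching patch per prototype is kept, and a linear layer turns those similarity scores into class logits. The function name, shapes, and the plain max-pooling here are illustrative assumptions, not the exact ProtoViT head described in the paper.

```python
import torch
import torch.nn.functional as F

def prototype_logits(patch_features, prototypes, class_weights):
    """Illustrative prototype scoring (hypothetical; see the paper for the real head).

    patch_features: (B, N, D) ViT patch-token embeddings
    prototypes:     (P, D) learned prototype vectors
    class_weights:  (C, P) linear layer from prototype similarities to logits
    """
    feats = F.normalize(patch_features, dim=-1)
    protos = F.normalize(prototypes, dim=-1)
    sims = torch.einsum("bnd,pd->bnp", feats, protos)  # cosine similarity per patch/prototype
    scores = sims.max(dim=1).values                    # best-matching patch per prototype
    return scores @ class_weights.t()                  # (B, C) class logits

# Toy shapes: 196 patches, 192-dim tokens (DeiT-Tiny), 2000 prototypes, 200 classes
logits = prototype_logits(torch.randn(2, 196, 192),
                          torch.randn(2000, 192),
                          torch.randn(200, 2000))
print(logits.shape)  # torch.Size([2, 200])
```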

Supported Architectures

We provide three variants of ProtoViT:

  • ProtoViT-T: Built on DeiT-Tiny backbone
  • ProtoViT-S: Built on DeiT-Small backbone
  • ProtoViT-CaiT: Built on CaiT-XXS24 backbone
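All three backbones are available in timm (version pinned under Requirements). A minimal loading sketch, assuming the standard timm model identifiers; ProtoViT's own wrapper around these backbones lives in the GitHub repository, not in timm itself:

```python
import timm

backbone_t    = timm.create_model("deit_tiny_patch16_224", pretrained=True)   # ProtoViT-T
backbone_s    = timm.create_model("deit_small_patch16_224", pretrained=True)  # ProtoViT-S
backbone_cait = timm.create_model("cait_xxs24_224", pretrained=True)          # ProtoViT-CaiT
```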

Performance

All models were trained and evaluated on the CUB-200-2011 fine-grained bird species classification dataset.

| Model Version | Backbone   | Resolution | Top-1 Accuracy | Checkpoint |
|---------------|------------|------------|----------------|------------|
| ProtoViT-T    | DeiT-Tiny  | 224×224    | 83.36%         | Download   |
| ProtoViT-S    | DeiT-Small | 224×224    | 85.30%         | Download   |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224    | 86.02%         | Download   |
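All reported numbers use 224×224 inputs. A minimal evaluation-time preprocessing sketch, assuming standard ImageNet-style normalization (check the repository for the exact transforms behind these numbers):

```python
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),              # resize shorter side
    transforms.CenterCrop(224),          # models were trained at 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# "bird.jpg" is a placeholder path for any input image
x = preprocess(Image.open("bird.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)
```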

Features

  • ๐Ÿ” Interpretable Decisions: The model performs classification with self-explainatory reasoning based on the inputโ€™s similarity to learned prototypes, the key features for each classes.
  • ๐ŸŽฏ High Accuracy: Achieves competitive performance on fine-grained classification tasks
  • ๐Ÿš€ Multiple Architectures: Supports various Vision Transformer backbones
  • ๐Ÿ“Š Analysis Tools: Comes with tools for both local and global prototype analysis

Requirements

  • Python 3.8+
  • PyTorch 1.8+
  • timm==0.4.12
  • torchvision
  • numpy
  • pillow
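Assuming a pip-based environment, the dependencies above can be installed with:

```
pip install "torch>=1.8" torchvision timm==0.4.12 numpy pillow
```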

Limitations and Bias

  • Data Bias: The models are trained only on CUB-200-2011 and may not generalize well to images outside this fine-grained bird domain.
  • Resolution Constraints: The models are trained at a resolution of 224×224; higher or lower input resolutions may degrade performance.
  • Location Misalignment: Like CNN-based prototype models, these models are not perfectly immune to prototype location misalignment under adversarial attack.

Citation

If you use this model in your research, please cite:

@article{ma2024interpretable,
  title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
  author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
  journal={arXiv preprint arXiv:2410.20722},
  year={2024}
}

Acknowledgements

This implementation builds upon several excellent open-source repositories.

License

This project is released under the MIT license.

Contact

For any questions or feedback, please:

  1. Open an issue in the GitHub repository
  2. Contact [[email protected]]