---
license: mit
language:
- en
tags:
- text-to-image
- stable-diffusion
- diffusers
---
# APTP: Adaptive Prompt-Tailored Pruning of T2I Diffusion Models
[![arXiv](https://img.shields.io/badge/Paper-arXiv-red?style=for-the-badge)](https://arxiv.org/abs/2406.12042)
[![GitHub](https://img.shields.io/badge/GitHub-Code-success?style=for-the-badge&logo=GitHub)](https://github.com/rezashkv/diffusion_pruning)
This is the implementation of the paper ["Not All Prompts Are Made Equal: Prompt-based Pruning of Text-to-Image Diffusion Models"](https://arxiv.org/abs/2406.12042).
## Abstract
Text-to-image (T2I) diffusion models have demonstrated impressive image generation capabilities. Still, their computational intensity prohibits
resource-constrained organizations from deploying T2I models after fine-tuning them on their internal target data. While pruning
techniques offer a potential solution to reduce the computational burden of T2I models, static pruning methods use the same pruned
model for all input prompts, overlooking the varying capacity requirements of different prompts. Dynamic pruning addresses this issue by utilizing
a separate sub-network for each prompt, but it prevents batch parallelism on GPUs. To overcome these limitations, we introduce
Adaptive Prompt-Tailored Pruning (APTP), a novel prompt-based pruning method designed for T2I diffusion models. Central to our approach is a
prompt router model, which learns to determine the required capacity for an input text prompt and routes it to an architecture code, given a
total desired compute budget for prompts. Each architecture code represents a specialized model tailored to the prompts assigned to it, and the
number of codes is a hyperparameter. We train the prompt router and architecture codes using contrastive learning, ensuring that similar prompts
are mapped to nearby codes. Further, we employ optimal transport to prevent the codes from collapsing into a single one. We demonstrate APTP's
effectiveness by pruning Stable Diffusion (SD) V2.1 using CC3M and COCO as target datasets. APTP outperforms the
single-model pruning baselines in terms of FID, CLIP, and CMMD scores. Our analysis of the clusters learned by APTP reveals that they
are semantically meaningful. We also show that APTP can automatically discover prompts that prior empirical work found challenging for SD (e.g., prompts for generating images that contain text) and assign them to higher-capacity codes.
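
The routing step described above can be pictured as a nearest-code lookup. The quantizer call in the inference example below (`get_cosine_sim_min_encoding_indices`) suggests that the lookup is based on cosine similarity, so this minimal sketch assumes the router picks the architecture code with the highest cosine similarity to the architecture embedding. The helper names and toy vectors are hypothetical; this is not the `pdm` implementation.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def route_to_code(arch_embedding, arch_codes):
    """Return the index of the architecture code most similar to the embedding (hypothetical helper)."""
    sims = [cosine_similarity(arch_embedding, code) for code in arch_codes]
    return max(range(len(sims)), key=sims.__getitem__)


# Toy example: three 2-D "architecture codes" and one prompt's architecture embedding.
codes = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
print(route_to_code([0.2, 0.9], codes))  # → 1
```

In the real model the embedding comes from the architecture predictor and the codes are learned, but the selection rule sketched here is the same kind of argmax over similarities.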
<p align="center">
<img src="assets/fig_1.gif" alt="APTP Overview" width="600" />
</p>
<p align="left">
<em>APTP: We prune a text-to-image diffusion model like Stable Diffusion (left) into a mixture of efficient experts (right) in a prompt-based manner. Our prompt router routes distinct types of prompts to different experts, allowing experts' architectures to be separately specialized by removing layers or channels.</em>
</p>
<p align="center">
<img src="assets/fig_2.gif" alt="APTP Pruning Scheme" width="600" />
</p>
<p align="left">
<em>APTP pruning scheme. We train the prompt router and the set of architecture codes to prune a T2I diffusion model into a mixture of experts. The prompt router consists of three modules. First, a Sentence Transformer serves as the prompt encoder and encodes the input prompt into a representation z. Then, the architecture predictor transforms z into an architecture embedding e that has the same dimensionality as the architecture codes. Finally, the router routes e to an architecture code a(i). We use optimal transport to distribute the prompts in a training batch evenly among the architecture codes. The architecture code a(i) = (u(i), v(i)) determines how the model's width and depth are pruned. We train the prompt router's parameters and the architecture codes end to end using the denoising objective of the pruned model L<sub>DDPM</sub>, the distillation loss L<sub>distill</sub> between the pruned and original models, the average resource usage R of the samples in the batch, and a contrastive objective L<sub>cont</sub> that encourages the embeddings e to preserve the semantic similarity of the representations z.</em>
</p>
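
The caption states that optimal transport spreads the prompts of a training batch evenly over the architecture codes, which prevents all prompts from collapsing onto one code. A common way to do this is entropy-regularized optimal transport solved with Sinkhorn-Knopp iterations. The sketch below assumes that formulation on toy similarity scores; the function names, `epsilon`, and `n_iters` are hypothetical choices and this is not APTP's training code.

```python
import math


def sinkhorn_balanced_assignment(scores, epsilon=0.05, n_iters=3):
    """Soft, approximately balanced assignment of B prompts to K codes (hypothetical helper).

    scores[b][k] is the similarity between prompt b and code k. The returned Q has
    rows that sum to 1 (one unit of mass per prompt) and columns that sum to
    approximately B / K (an even share of the batch per code).
    """
    B, K = len(scores), len(scores[0])
    # Entropic OT kernel exp(score / epsilon); subtracting the max avoids overflow.
    m = max(max(row) for row in scores)
    Q = [[math.exp((s - m) / epsilon) for s in row] for row in scores]
    for _ in range(n_iters):
        # Column step: rescale so every code receives B / K units of mass.
        for k in range(K):
            col = sum(Q[b][k] for b in range(B))
            for b in range(B):
                Q[b][k] *= (B / K) / col
        # Row step: rescale so every prompt distributes exactly one unit of mass.
        for b in range(B):
            row = sum(Q[b])
            Q[b] = [q / row for q in Q[b]]
    return Q


def hard_assignment(Q):
    """Pick the highest-scoring code for each prompt."""
    return [max(range(len(row)), key=row.__getitem__) for row in Q]


# Four prompts, two codes; every prompt slightly prefers code 0.
scores = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.5]]
print(hard_assignment(scores))  # greedy routing → [0, 0, 0, 0]
print(hard_assignment(sinkhorn_balanced_assignment(scores)))  # balanced → [0, 0, 1, 1]
```

Plain argmax routing sends all four prompts to code 0, whereas the balanced assignment gives each code two prompts, keeping the prompts with the strongest preference for code 0 on that code. A smaller `epsilon` makes the assignment sharper; more iterations bring the column sums closer to B / K.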
### Model Description
- **Developed by:** UMD Efficiency Group
- **Model type:** Text-to-Image Diffusion Model
- **Model Description:** APTP is a pruning scheme for text-to-image diffusion models like Stable Diffusion, resulting in a mixture of efficient experts specialized for different prompt types.
### License
APTP is released under the MIT License. Please see the [LICENSE](LICENSE) file for details.
## Training Dataset
We used the Conceptual Captions and MS-COCO 2014 datasets to train the models. Instructions for downloading and preparing these datasets are provided in the [GitHub repository](https://github.com/rezashkv/diffusion_pruning).
## File Structure
```
APTP
├── APTP-Base-CC3M
│   ├── arch0
│   ├── ...
│   └── arch15
├── APTP-Small-CC3M
│   ├── arch0
│   ├── ...
│   └── arch7
├── APTP-Base-COCO
│   ├── arch0
│   ├── ...
│   └── arch7
└── APTP-Small-COCO
    ├── arch0
    ├── ...
    └── arch7
```
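
The tree above implies a fixed naming scheme: each variant folder holds experts `arch0` to `arch{N-1}`, with N = 16 for APTP-Base-CC3M and N = 8 for the other three variants. The helper below is a hypothetical convenience function (not part of `pdm`) that builds an expert's UNet subfolder and rejects out-of-range expert ids. The `checkpoint-30000/unet` suffix is copied from the inference example below, which uses APTP-Base-CC3M; that the same suffix holds for every variant is an assumption.

```python
# Number of experts (arch* folders) per variant, read off the file tree.
NUM_EXPERTS = {
    "APTP-Base-CC3M": 16,
    "APTP-Small-CC3M": 8,
    "APTP-Base-COCO": 8,
    "APTP-Small-COCO": 8,
}


def expert_unet_subfolder(variant, expert_id, checkpoint="checkpoint-30000"):
    """Build the subfolder of one expert's pruned UNet (hypothetical helper).

    ASSUMPTION: every variant stores its UNet under <checkpoint>/unet, as
    APTP-Base-CC3M does in the inference example.
    """
    if variant not in NUM_EXPERTS:
        raise ValueError(f"unknown APTP variant: {variant!r}")
    if not 0 <= expert_id < NUM_EXPERTS[variant]:
        raise ValueError(f"{variant} has experts arch0..arch{NUM_EXPERTS[variant] - 1}")
    return f"{variant}/arch{expert_id}/{checkpoint}/unet"


print(expert_unet_subfolder("APTP-Base-CC3M", 15))
# → APTP-Base-CC3M/arch15/checkpoint-30000/unet
```

The returned string has the same form as the `subfolder` argument passed to `UNet2DConditionModelPruned.from_pretrained` in the inference example.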
## Simple Inference Example
Follow the [installation instructions](https://github.com/rezashkv/diffusion_pruning?tab=readme-ov-file#installation) to install the `pdm` package from source before running this example.
```python
import torch
from diffusers import StableDiffusionPipeline, PNDMScheduler
from transformers import AutoTokenizer, AutoModel
from pdm.models import HyperStructure, StructureVectorQuantizer, UNet2DConditionModelPruned
from pdm.utils.data_utils import get_mpnet_embeddings

prompt_encoder_model_name_or_path = "sentence-transformers/all-mpnet-base-v2"
aptp_model_name_or_path = "rezashkv/APTP"
aptp_variant = "APTP-Base-CC3M"
sd_model_name_or_path = "stabilityai/stable-diffusion-2-1"

# Prompt encoder (Sentence Transformer) used by the prompt router.
prompt_encoder = AutoModel.from_pretrained(prompt_encoder_model_name_or_path)
prompt_encoder_tokenizer = AutoTokenizer.from_pretrained(prompt_encoder_model_name_or_path)

# Architecture predictor (hypernet) and the learned architecture codes (quantizer).
hyper_net = HyperStructure.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/hypernet")
quantizer = StructureVectorQuantizer.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/quantizer")

prompts = ["a woman on a white background looks down and away from the camera the a forlorn look on her face"]

# Route the prompt: encode it, predict its architecture embedding, and select an expert.
with torch.no_grad():
    prompt_embedding = get_mpnet_embeddings(prompts, prompt_encoder, prompt_encoder_tokenizer)
    arch_embedding = hyper_net(prompt_embedding)
    expert_id = quantizer.get_cosine_sim_min_encoding_indices(arch_embedding)[0].item()

# Load the pruned UNet of the selected expert.
unet = UNet2DConditionModelPruned.from_pretrained(
    aptp_model_name_or_path,
    subfolder=f"{aptp_variant}/arch{expert_id}/checkpoint-30000/unet",
)

# Plug the pruned UNet into the standard SD 2.1 pipeline.
noise_scheduler = PNDMScheduler.from_pretrained(sd_model_name_or_path, subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained(sd_model_name_or_path, unet=unet, scheduler=noise_scheduler)
pipeline.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(43)
image = pipeline(
    prompt=prompts[0],
    guidance_scale=7.5,
    generator=generator,
    output_type="pil",
).images[0]
image.save("image.png")
```
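
The example routes a single prompt. Because every prompt routed to the same expert runs through the same pruned UNet, a list of prompts can be grouped per expert and each group generated as one batch, which is the batch parallelism that the abstract says dynamic pruning gives up. The sketch below is a hypothetical helper (not part of `pdm`); the expert ids in the usage example are made up and would in practice come from the quantizer call shown above.

```python
from collections import defaultdict


def group_prompts_by_expert(prompts, expert_ids):
    """Group prompts so each expert's UNet is loaded once and run on a whole batch.

    expert_ids[i] is the expert selected for prompts[i]. Returns
    {expert_id: [(original_index, prompt), ...]} so outputs can be put back in order.
    """
    if len(prompts) != len(expert_ids):
        raise ValueError("prompts and expert_ids must have the same length")
    groups = defaultdict(list)
    for index, (prompt, expert_id) in enumerate(zip(prompts, expert_ids)):
        groups[expert_id].append((index, prompt))
    return dict(groups)


prompts = ["a red bicycle", "a neon sign reading OPEN", "a cat on a sofa"]
expert_ids = [3, 7, 3]  # hypothetical routing results
for expert_id, batch in sorted(group_prompts_by_expert(prompts, expert_ids).items()):
    indices, batch_prompts = zip(*batch)
    print(expert_id, list(indices), list(batch_prompts))
# 3 [0, 2] ['a red bicycle', 'a cat on a sofa']
# 7 [1] ['a neon sign reading OPEN']
```

For each group, load the matching `arch{expert_id}` UNet as in the example above and pass `list(batch_prompts)` as the `prompt` argument of the pipeline, which accepts a list of strings; the stored indices map the generated images back to the original prompt order.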
## Uses
This model is designed for academic and research purposes, specifically for exploring the efficiency of text-to-image diffusion models through prompt-based pruning. Potential applications include:
1. **Research:** Researchers can use the model to study prompt-based pruning techniques and their impact on the performance and efficiency of text-to-image generation models.
2. **Education:** Educators and students can use this model as a learning tool for understanding advanced concepts in neural network pruning, diffusion models, and prompt engineering.
3. **Benchmarking:** The model can be used for benchmarking against other text-to-image generation models to assess the trade-offs between computational efficiency and output quality.
## Safety
When using these models, it is important to consider the following safety and ethical guidelines:
1. **Content Generation:** The model can generate a wide range of images based on text prompts. Users should ensure that generated content adheres to ethical guidelines and should not use the model to produce harmful, offensive, or inappropriate images.
2. **Bias and Fairness:** Like other AI models, APTP may exhibit biases present in the training data. Users should be aware of these potential biases and take steps to mitigate their impact, particularly when the model is used in sensitive or critical applications.
3. **Data Privacy:** Ensure that any data used with the model complies with data privacy regulations. Avoid using personally identifiable information (PII) or sensitive data without proper consent.
4. **Responsible Use:** Users are encouraged to use the model responsibly, considering the potential social and ethical implications of their work. This includes avoiding the generation of misleading or false information and respecting the rights and dignity of individuals depicted in generated images.
By adhering to these guidelines, users can help ensure the responsible and ethical use of the APTP model.
## Contact
In case of any questions or issues, please contact the authors of the paper:
* Reza Shirkavand
* Alireza Ganjdanesh