How to build a community pipeline
Note: this page was built from the GitHub Issue on Community Pipelines #841.
Let’s make an example! Say you want to define a pipeline that just does a single forward pass to a U-Net and then calls a scheduler only once (Note, this doesn’t make any sense from a scientific point of view, but only represents an example of how things work under the hood).
Cool! So you open your favorite IDE and start creating your pipeline 💻.
First, what model weights and configurations do we need?
We have a U-Net and a scheduler, so our pipeline should take a U-Net and a scheduler as an argument.
Also, as stated above, you’d like to be able to load weights and the scheduler config from the Hub and share your code with others, so we’ll inherit from DiffusionPipeline:
from diffusers import DiffusionPipeline
import torch


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()
Now, we must save the unet and scheduler in a config file so that you can save your pipeline with save_pretrained. Therefore, make sure you add every component that is save-able to the register_modules function:
from diffusers import DiffusionPipeline
import torch


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)
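To make the role of register_modules more concrete, here is a minimal, dependency-free sketch of the idea. It is a toy under stated assumptions, not the diffusers implementation: ToyPipeline, ToyUNet, ToyScheduler, and config_as_json are hypothetical names, and the real method records more detail. The point is only that registering a component both attaches it as an attribute and notes its class in a config that could later be written to disk.

```python
import json


class ToyPipeline:
    """Hypothetical stand-in for a pipeline base class (not diffusers code)."""

    def __init__(self):
        # Maps component name -> [top-level package, class name].
        self.config = {}

    def register_modules(self, **modules):
        for name, module in modules.items():
            cls = type(module)
            # Record where the component's class comes from ...
            self.config[name] = [cls.__module__.split(".")[0], cls.__name__]
            # ... and expose the component as an attribute, e.g. self.unet.
            setattr(self, name, module)

    def config_as_json(self):
        # A real pipeline would write something like this to a config file.
        return json.dumps(self.config, sort_keys=True)


class ToyUNet:
    pass


class ToyScheduler:
    pass


pipe = ToyPipeline()
pipe.register_modules(unet=ToyUNet(), scheduler=ToyScheduler())
```

After registration, pipe.unet and pipe.scheduler are ordinary attributes, and pipe.config holds enough information to name each component's class again when loading.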
Cool, the init is done! 🔥 Now, let’s go into the forward pass, which we recommend defining as __call__. Here you’re given all the creative freedom there is. For our amazing “one-step” pipeline, we simply create a random image and call the unet once and the scheduler once:
from diffusers import DiffusionPipeline
import torch


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self):
        # Start from pure Gaussian noise with the shape the U-Net expects.
        image = torch.randn(
            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
        )
        timestep = 1

        # One forward pass through the U-Net, then one scheduler step.
        model_output = self.unet(image, timestep).sample
        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

        return scheduler_output
Cool, that’s it! 🚀 You can now run this pipeline by passing a unet and a scheduler to the init:
from diffusers import DDPMScheduler, UNet2DModel

scheduler = DDPMScheduler()
unet = UNet2DModel()
pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
output = pipeline()
But what’s even better is that you can load pre-existing weights into the pipeline if they exactly match your pipeline structure. This is the case, for example, for https://huggingface.co/google/ddpm-cifar10-32, so we can do the following:
pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32")
output = pipeline()
We want to share this amazing pipeline with the community, so we would open a PR to add the following code under one_step_unet.py to https://github.com/huggingface/diffusers/tree/main/examples/community.
from diffusers import DiffusionPipeline
import torch


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self):
        # Start from pure Gaussian noise with the shape the U-Net expects.
        image = torch.randn(
            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
        )
        timestep = 1

        # One forward pass through the U-Net, then one scheduler step.
        model_output = self.unet(image, timestep).sample
        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

        return scheduler_output
Our amazing pipeline got merged here: #840.
Now everybody who has diffusers >= 0.4.0 installed can use our pipeline magically 🪄 as follows:
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
pipe()
Another way to upload your custom_pipeline, besides sending a PR, is uploading the code that contains it to the Hugging Face Hub, as exemplified here.
Try it out now - it works!
In general, you will want to create much more sophisticated pipelines, so we recommend looking at existing pipelines here: https://github.com/huggingface/diffusers/tree/main/examples/community.
IMPORTANT:
You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from DiffusionPipeline, as this will be automatically detected.
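The "one and only one pipeline class" rule is easier to see with a small sketch of how such automatic detection could work. This is an illustration under assumptions, not the code diffusers runs: BasePipeline stands in for DiffusionPipeline, and find_single_pipeline_class and load_module_from_file are hypothetical helpers. The sketch imports a standalone file, collects every class in it that inherits from the base class, and refuses to continue unless exactly one is found.

```python
import importlib.util
import inspect
import tempfile
from pathlib import Path


class BasePipeline:
    """Hypothetical base class standing in for DiffusionPipeline."""


def find_single_pipeline_class(module, base_class):
    """Return the one class in `module` that subclasses `base_class`."""
    candidates = [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, base_class) and obj is not base_class
    ]
    if len(candidates) != 1:
        names = [c.__name__ for c in candidates]
        raise ValueError(f"Expected exactly one pipeline class, found {names}")
    return candidates[0]


def load_module_from_file(path, base_class):
    """Import a standalone .py file, injecting the base class it inherits from."""
    spec = importlib.util.spec_from_file_location("community_module", path)
    module = importlib.util.module_from_spec(spec)
    # Injected so the toy file below needs no imports of its own.
    module.BasePipeline = base_class
    spec.loader.exec_module(module)
    return module


# A toy "community pipeline file" containing a single pipeline class.
source = "class OneStepPipeline(BasePipeline):\n    pass\n"

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "one_step.py"
    path.write_text(source)
    module = load_module_from_file(path, BasePipeline)
    detected = find_single_pipeline_class(module, BasePipeline)
```

With two pipeline classes in the same file, a loader like this has no way to know which one the author meant, which is why the sketch raises an error in that case.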
How do community pipelines work?
A community pipeline is a class that has to inherit from DiffusionPipeline and that has been added to the `examples/community` [files](https://github.com/huggingface/diffusers/tree/main/examples/community). The community can load the pipeline code via the custom_pipeline argument of DiffusionPipeline. See the docs [here](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.custom_pipeline). This means:
The model weights and configs of the pipeline should be loaded from the pretrained_model_name_or_path argument, whereas the code that powers the community pipeline is defined in a file added in examples/community.
Now, it might very well be that only some of your pipeline components’ weights can be downloaded from an official repo. The other components should then be passed directly to init, as is the case for the CLIP guidance notebook here.
The magic behind all of this is that we load the code directly from GitHub. You can check it out in more detail if you follow the functionality defined here:
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
    pipeline_class = get_class_from_dynamic_module(
        custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
    )
elif cls != DiffusionPipeline:
    pipeline_class = cls
else:
    diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
    pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
This is why a community pipeline merged to GitHub will be directly available to every diffusers installation.
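The three branches of that snippet can be mirrored in a small, self-contained sketch. Everything here is a hypothetical stand-in chosen for illustration: BasePipeline plays the role of DiffusionPipeline, REGISTRY replaces the lookup in the diffusers module, and fetch_custom_class replaces get_class_from_dynamic_module (which in the real library fetches code rather than reading a dictionary).

```python
class BasePipeline:
    """Hypothetical stand-in for DiffusionPipeline."""


class DDPMLikePipeline(BasePipeline):
    """Hypothetical built-in pipeline."""


class CommunityPipeline(BasePipeline):
    """Hypothetical community pipeline."""


# Stands in for looking up a class by name inside the library module.
REGISTRY = {"DDPMLikePipeline": DDPMLikePipeline}

# Stands in for downloading and importing a community pipeline file.
CUSTOM = {"one_step_unet": CommunityPipeline}


def fetch_custom_class(name):
    return CUSTOM[name]


def resolve_pipeline_class(cls, config_dict, custom_pipeline=None):
    # 1. An explicit custom pipeline always wins.
    if custom_pipeline is not None:
        return fetch_custom_class(custom_pipeline)
    # 2. Called on a concrete subclass: use that class directly.
    if cls is not BasePipeline:
        return cls
    # 3. Called on the base class: use the class named in the saved config.
    return REGISTRY[config_dict["_class_name"]]


config = {"_class_name": "DDPMLikePipeline"}
from_custom = resolve_pipeline_class(BasePipeline, config, custom_pipeline="one_step_unet")
from_subclass = resolve_pipeline_class(DDPMLikePipeline, config)
from_config = resolve_pipeline_class(BasePipeline, config)
```

Note that in the first case the saved config still describes the weights, while the class that runs them comes from the custom code, which matches the split between weights and code described earlier.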