import gradio as gr

from convert import run_conversion
from hub_utils import push_to_hub, save_model_card

PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"

DESCRIPTION = """
This Space lets you convert KerasCV Stable Diffusion weights to a format compatible with [Diffusers](https://github.com/huggingface/diffusers) 🧨. This allows users to fine-tune using KerasCV and then use the fine-tuned weights in Diffusers, taking advantage of its nifty features (like [schedulers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers), [fast attention](https://huggingface.co/docs/diffusers/optimization/fp16), etc.). Specifically, the Keras weights are first converted to PyTorch and then wrapped into a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview). This pipeline is then pushed to the Hugging Face Hub, provided you have supplied `your_hf_token`.

## Notes (important)

* The Space downloads a couple of pre-trained checkpoints and runs a dummy inference. Depending on the machine type, the entire process can take anywhere from 2 to 5 minutes.
* Only Stable Diffusion (v1) is supported as of now. In particular, this checkpoint: [`"CompVis/stable-diffusion-v1-4"`](https://huggingface.co/CompVis/stable-diffusion-v1-4).
* [This Colab Notebook](https://colab.research.google.com/drive/1RYY077IQbAJldg8FkK8HSEpNILKHEwLb?usp=sharing) was used to develop the conversion utilities initially.
* Which of `text_encoder_weights` and `unet_weights` you need to provide depends on the fine-tuning task. Here are some _typical_ scenarios:
    * [DreamBooth](https://dreambooth.github.io/): Both text encoder and UNet
    * [Textual Inversion](https://textual-inversion.github.io/): Text encoder
    * [Traditional text2image fine-tuning](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image): UNet
* **If neither `text_encoder_weights` nor `unet_weights` is provided, nothing will be done.**
* For Textual Inversion, you MUST provide a valid `placeholder_token`, i.e., the text concept used during Textual Inversion.
* When providing the weights' links, ensure they're directly downloadable. Internally, the Space uses [`tf.keras.utils.get_file()`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) to retrieve the weights locally.
* If you don't provide `your_hf_token`, the converted pipeline won't be pushed.
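
After loading the converted pipeline in Diffusers, you can swap its scheduler. As a minimal sketch (the helper `swap_scheduler` is hypothetical and is not part of this Space or of Diffusers), a new scheduler can be built from the configuration of the pipeline's current one:

```python
def swap_scheduler(pipeline, scheduler_cls):
    # Hypothetical helper (not part of this Space or of Diffusers).
    # Assumes `pipeline.scheduler.config` holds the current scheduler config and that
    # `scheduler_cls` provides a `from_config()` classmethod, as Diffusers schedulers do.
    pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
    return pipeline


# Example usage (requires `diffusers`; the repo id below is a placeholder):
# from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
# pipeline = StableDiffusionPipeline.from_pretrained("<your-username>/<your-repo>")
# pipeline = swap_scheduler(pipeline, DPMSolverMultistepScheduler)
```

Building the new scheduler from the existing config should keep shared settings, such as the beta schedule and the number of training timesteps, consistent with the converted pipeline.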

Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/dreambooth/train_dreambooth_lora.py#L975) for an example of how to change the scheduler of an already initialized `StableDiffusionPipeline`.
"""


def run(hf_token, text_encoder_weights, unet_weights, placeholder_token, repo_prefix):
    # Gradio passes empty textboxes as "", so normalize them to None.
    if text_encoder_weights == "":
        text_encoder_weights = None
    if unet_weights == "":
        unet_weights = None
    if text_encoder_weights is None and unet_weights is None:
        return "❌ No fine-tuned weights provided, nothing to do."

    if placeholder_token == "":
        placeholder_token = None
    # A placeholder token is only meaningful for Textual Inversion, which needs text encoder weights.
    if placeholder_token is not None and text_encoder_weights is None:
        return "❌ Placeholder token provided but no text encoder weights were provided. Cannot proceed."

    # Convert the Keras weights, wrap them in a StableDiffusionPipeline, and save it locally.
    pipeline = run_conversion(text_encoder_weights, unet_weights, placeholder_token)
    output_path = "kerascv_sd_diffusers_pipeline"
    pipeline.save_pretrained(output_path)

    # Pass the fine-tuned weight links that were actually provided to the model card.
    weight_paths = []
    if text_encoder_weights is not None:
        weight_paths.append(text_encoder_weights)
    if unet_weights is not None:
        weight_paths.append(unet_weights)
    save_model_card(
        base_model=PRETRAINED_CKPT,
        repo_folder=output_path,
        weight_paths=weight_paths,
        placeholder_token=placeholder_token,
    )

    push_str = push_to_hub(hf_token, output_path, repo_prefix)
    return push_str


demo = gr.Interface(
    title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
    description=DESCRIPTION,
    allow_flagging="never",
    inputs=[
        gr.Text(max_lines=1, label="your_hf_token"),
        gr.Text(max_lines=1, label="text_encoder_weights"),
        gr.Text(max_lines=1, label="unet_weights"),
        gr.Text(max_lines=1, label="placeholder_token"),
        gr.Text(max_lines=1, label="output_repo_prefix"),
    ],
    outputs=[gr.Markdown(label="output")],
    fn=run,
)
demo.launch()