yujiepan/FLUX.1-dev-tiny-random

This pipeline is intended for debugging only. It is adapted from black-forest-labs/FLUX.1-dev, keeping the same pipeline structure but using much smaller model sizes and randomly initialized parameters.

Usage

import torch
from diffusers import FluxPipeline

# Load the tiny debugging checkpoint in bfloat16.
pipe = FluxPipeline.from_pretrained(
    "yujiepan/FLUX.1-dev-tiny-random",
    torch_dtype=torch.bfloat16,
)
# Offloading the model to CPU saves some VRAM; drop this call if the GPU
# has enough memory.
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
# Seeded CPU generator, passed to the pipeline call below.
generator = torch.Generator("cpu").manual_seed(0)
output = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=generator,
)
image = output.images[0]
# Uncomment to write the result to disk:
# image.save("flux-dev.png")

Code

import importlib

import torch
import transformers

import diffusers
import rich


def get_original_model_configs(
    pipeline_cls: type[diffusers.FluxPipeline],
    pipeline_id: str
):
    """Collect the config of every model component of a pipeline.

    The pipeline-level config is read with ``pipeline_cls.load_config`` and
    each entry whose key does not start with ``"_"`` is resolved to a class
    by importing the module path given in the entry.

    Args:
        pipeline_cls: Pipeline class used to load the pipeline-level config.
        pipeline_id: Identifier or path passed to the ``load_config`` /
            ``from_pretrained`` calls.

    Returns:
        A dict keyed by component subfolder. Components that subclass
        ``transformers.PreTrainedModel`` map to the object returned by
        their ``config_class.from_pretrained``; components that subclass
        both ``diffusers.ModelMixin`` and ``diffusers.ConfigMixin`` map to
        whatever their ``load_config`` returns. Scheduler and tokenizer
        subfolders are not included.

    Raises:
        NotImplementedError: If a component is neither a recognised model
            class nor one of the known scheduler/tokenizer subfolders.
    """
    # Subfolders that carry no model config and are deliberately skipped.
    skipped_subfolders = ['scheduler', 'tokenizer', 'tokenizer_2', 'tokenizer_3']

    pipeline_index: dict[str, list[str]] = pipeline_cls.load_config(pipeline_id)
    collected = {}

    for name, import_path in pipeline_index.items():
        # Keys with a leading underscore are skipped before any import.
        if name.startswith("_"):
            continue

        # The last element is the class name, everything before it is the
        # dotted module path.
        module_name = ".".join(import_path[:-1])
        component_cls = getattr(importlib.import_module(module_name), import_path[-1])

        if issubclass(component_cls, transformers.PreTrainedModel):
            collected[name] = component_cls.config_class.from_pretrained(
                pipeline_id, subfolder=name)
            continue

        is_diffusers_model = (
            issubclass(component_cls, diffusers.ModelMixin)
            and issubclass(component_cls, diffusers.ConfigMixin)
        )
        if is_diffusers_model:
            collected[name] = component_cls.load_config(pipeline_id, subfolder=name)
            continue

        if name not in skipped_subfolders:
            raise NotImplementedError(f"unknown {name}: {import_path}")

    return collected


def load_pipeline(pipeline_cls: type[diffusers.DiffusionPipeline], pipeline_id: str, model_configs: dict[str, dict]):
    """Assemble a pipeline whose models are built from the given configs.

    Every component listed in the pipeline-level config (keys starting with
    ``"_"`` are skipped) is resolved to its class and created as follows:

    * ``transformers.PreTrainedModel`` subclasses are constructed directly
      from ``model_configs[subfolder]`` (no pretrained weights are loaded).
    * ``diffusers`` models (``ModelMixin`` + ``ConfigMixin``) are built with
      ``cls.from_config(model_configs[subfolder])``.
    * Tokenizers and schedulers are loaded from ``pipeline_id`` with
      ``from_pretrained``.

    Args:
        pipeline_cls: Pipeline class to instantiate.
        pipeline_id: Identifier or path used to read the pipeline config and
            to load tokenizers and schedulers.
        model_configs: Per-subfolder model configs, as produced by
            ``get_original_model_configs`` (values are config objects for
            ``transformers`` models and config mappings for ``diffusers``
            models).

    Returns:
        An instance of ``pipeline_cls`` built from the created components.

    Raises:
        NotImplementedError: If a component class matches none of the
            supported kinds.
        KeyError: If ``model_configs`` has no entry for a model component.
    """
    pipeline_config: dict[str, list[str]] = pipeline_cls.load_config(pipeline_id)
    components = {}
    for subfolder, import_strings in pipeline_config.items():
        # Keys with a leading underscore are skipped before any import.
        if subfolder.startswith("_"):
            continue
        # Last element is the class name; the rest is the dotted module path.
        module = importlib.import_module(".".join(import_strings[:-1]))
        cls = getattr(module, import_strings[-1])
        print("Loading:", ".".join(import_strings))
        if issubclass(cls, transformers.PreTrainedModel):
            component = cls(model_configs[subfolder])
        elif issubclass(cls, transformers.PreTrainedTokenizerBase):
            component = cls.from_pretrained(pipeline_id, subfolder=subfolder)
        elif issubclass(cls, diffusers.ModelMixin) and issubclass(cls, diffusers.ConfigMixin):
            component = cls.from_config(model_configs[subfolder])
        elif issubclass(cls, diffusers.SchedulerMixin) and issubclass(cls, diffusers.ConfigMixin):
            component = cls.from_pretrained(pipeline_id, subfolder=subfolder)
        else:
            # Raising a bare string is itself a TypeError in Python 3; raise
            # a real exception so the message reaches the user, matching
            # get_original_model_configs.
            raise NotImplementedError(f"unknown {subfolder}: {import_strings}")
        components[subfolder] = component
        # Dump the architecture of any component whose class name contains
        # "transformer" so the shrunken model can be inspected.
        if 'transformer' in component.__class__.__name__.lower():
            print(component)
    return pipeline_cls(**components)


def get_pipeline():
    """Create the tiny FLUX pipeline used for debugging.

    Seeds torch, fetches the model configs of
    ``black-forest-labs/FLUX.1-dev``, shrinks them, and builds a
    ``diffusers.FluxPipeline`` from the shrunken configs via
    ``load_pipeline``.

    Returns:
        The pipeline object returned by ``load_pipeline``.
    """
    torch.manual_seed(42)
    pipeline_id = "black-forest-labs/FLUX.1-dev"
    pipeline_cls = diffusers.FluxPipeline
    model_configs = get_original_model_configs(pipeline_cls, pipeline_id)

    hidden = 8

    # Text-encoder configs are objects: overrides are set as attributes.
    attribute_overrides = {
        "text_encoder": {
            "hidden_size": hidden,
            "intermediate_size": hidden * 2,
            "num_attention_heads": 2,
            "num_hidden_layers": 2,
            "projection_dim": hidden,
        },
        "text_encoder_2": {
            "d_model": hidden,
            "d_ff": hidden * 2,
            "d_kv": hidden // 2,
            "num_heads": 2,
            "num_layers": 2,
        },
    }
    for component, overrides in attribute_overrides.items():
        for field, value in overrides.items():
            setattr(model_configs[component], field, value)

    # Transformer and VAE configs are mappings: overrides are set by key.
    item_overrides = {
        "transformer": {
            "num_layers": 2,
            "num_single_layers": 4,
            "num_attention_heads": 2,
            "attention_head_dim": hidden,
            "pooled_projection_dim": hidden,
            "joint_attention_dim": hidden,
            "axes_dims_rope": (4, 2, 2),
        },
        "vae": {
            "layers_per_block": 1,
            "block_out_channels": [hidden] * 4,
            "norm_num_groups": 2,
            "latent_channels": 16,
        },
    }
    for component, overrides in item_overrides.items():
        for key, value in overrides.items():
            model_configs[component][key] = value

    return load_pipeline(pipeline_cls, pipeline_id, model_configs)


# Build the tiny pipeline and convert it to bfloat16 before saving.
pipe = get_pipeline().to(torch.bfloat16)

from pathlib import Path
save_folder = '/tmp/yujiepan/FLUX.1-dev-tiny-random'
Path(save_folder).mkdir(parents=True, exist_ok=True)
pipe.save_pretrained(save_folder)

# Reload the saved pipeline from disk and run it once end to end.
pipe = diffusers.FluxPipeline.from_pretrained(
    save_folder,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
sampling_kwargs = dict(
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
)
image = pipe(prompt, **sampling_kwargs).images[0]

# Print the configs of the saved pipeline for inspection.
configs = get_original_model_configs(diffusers.FluxPipeline, save_folder)
rich.print(configs)

# The hub repo id is the save path without its '/tmp/' prefix.
repo_id = save_folder.removeprefix('/tmp/')
pipe.push_to_hub(repo_id)
Downloads last month
11
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Collection including yujiepan/FLUX.1-dev-tiny-random