Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,061 Bytes
d65c9b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import torch
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image
from src.eunms import Model_Type, Scheduler_Type
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from src.lcm_scheduler import MyLCMScheduler
from src.ddpm_scheduler import MyDDPMScheduler
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.sd_inversion_pipeline import SDDDIMPipeline
def scheduler_type_to_class(scheduler_type):
if scheduler_type == Scheduler_Type.DDIM:
return DDIMScheduler
elif scheduler_type == Scheduler_Type.EULER:
return MyEulerAncestralDiscreteScheduler
elif scheduler_type == Scheduler_Type.LCM:
return MyLCMScheduler
elif scheduler_type == Scheduler_Type.DDPM:
return MyDDPMScheduler
else:
raise ValueError("Unknown scheduler type")
def model_type_to_class(model_type):
if model_type == Model_Type.SDXL:
return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline
elif model_type == Model_Type.SDXL_Turbo:
return AutoPipelineForImage2Image, SDXLDDIMPipeline
elif model_type == Model_Type.LCM_SDXL:
return AutoPipelineForImage2Image, SDXLDDIMPipeline
elif model_type == Model_Type.SD15:
return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
elif model_type == Model_Type.SD14:
return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
elif model_type == Model_Type.SD21:
return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
elif model_type == Model_Type.SD21_Turbo:
return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
else:
raise ValueError("Unknown model type")
def model_type_to_model_name(model_type):
if model_type == Model_Type.SDXL:
return "stabilityai/stable-diffusion-xl-base-1.0"
elif model_type == Model_Type.SDXL_Turbo:
return "stabilityai/sdxl-turbo"
elif model_type == Model_Type.LCM_SDXL:
return "stabilityai/stable-diffusion-xl-base-1.0"
elif model_type == Model_Type.SD15:
return "runwayml/stable-diffusion-v1-5"
elif model_type == Model_Type.SD14:
return "CompVis/stable-diffusion-v1-4"
elif model_type == Model_Type.SD21:
return "stabilityai/stable-diffusion-2-1"
elif model_type == Model_Type.SD21_Turbo:
return "stabilityai/sd-turbo"
else:
raise ValueError("Unknown model type")
def model_type_to_size(model_type):
if model_type == Model_Type.SDXL:
return (1024, 1024)
elif model_type == Model_Type.SDXL_Turbo:
return (512, 512)
elif model_type == Model_Type.LCM_SDXL:
return (768, 768) #TODO: check
elif model_type == Model_Type.SD15:
return (512, 512)
elif model_type == Model_Type.SD14:
return (512, 512)
elif model_type == Model_Type.SD21:
return (512, 512)
elif model_type == Model_Type.SD21_Turbo:
return (512, 512)
else:
raise ValueError("Unknown model type")
def is_float16(model_type):
if model_type == Model_Type.SDXL:
return True
elif model_type == Model_Type.SDXL_Turbo:
return True
elif model_type == Model_Type.LCM_SDXL:
return True
elif model_type == Model_Type.SD15:
return False
elif model_type == Model_Type.SD14:
return False
elif model_type == Model_Type.SD21:
return False
elif model_type == Model_Type.SD21_Turbo:
return False
else:
raise ValueError("Unknown model type")
def is_sd(model_type):
if model_type == Model_Type.SDXL:
return False
elif model_type == Model_Type.SDXL_Turbo:
return False
elif model_type == Model_Type.LCM_SDXL:
return False
elif model_type == Model_Type.SD15:
return True
elif model_type == Model_Type.SD14:
return True
elif model_type == Model_Type.SD21:
return True
elif model_type == Model_Type.SD21_Turbo:
return True
else:
raise ValueError("Unknown model type")
def _get_pipes(model_type, device):
model_name = model_type_to_model_name(model_type)
pipeline_inf, pipeline_inv = model_type_to_class(model_type)
if is_float16(model_type):
pipe_inversion = pipeline_inv.from_pretrained(
model_name,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
safety_checker = None
).to(device)
pipe_inference = pipeline_inf.from_pretrained(
model_name,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
safety_checker = None
).to(device)
else:
pipe_inversion = pipeline_inv.from_pretrained(
model_name,
use_safetensors=True,
safety_checker = None
).to(device)
pipe_inference = pipeline_inf.from_pretrained(
model_name,
use_safetensors=True,
safety_checker = None
).to(device)
return pipe_inversion, pipe_inference
def get_pipes(model_type, scheduler_type, device="cuda"):
# model_name = model_type_to_model_name(model_type)
# pipeline_inf, pipeline_inv = model_type_to_class(model_type)
scheduler_class = scheduler_type_to_class(scheduler_type)
pipe_inversion, pipe_inference = _get_pipes(model_type, device)
# pipe_inversion = pipeline_inv.from_pretrained(
# model_name,
# # torch_dtype=torch.float16,
# use_safetensors=True,
# # variant="fp16",
# safety_checker = None
# ).to("cuda")
# pipe_inference = pipeline_inf.from_pretrained(
# model_name,
# # torch_dtype=torch.float16,
# use_safetensors=True,
# # variant="fp16",
# safety_checker = None
# ).to("cuda")
pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
if is_sd(model_type):
pipe_inference.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
pipe_inversion.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
pipe_inversion.scheduler_inference.add_noise = lambda init_latents, noise, timestep: init_latents
if model_type == Model_Type.LCM_SDXL:
adapter_id = "latent-consistency/lcm-lora-sdxl"
# load and fuse lcm lora
pipe_inversion.load_lora_weights(adapter_id)
# pipe_inversion.fuse_lora()
pipe_inference.load_lora_weights(adapter_id)
# pipe_inference.fuse_lora()
return pipe_inversion, pipe_inference |