# Spaces: Running on Zero
import torch | |
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image | |
from src.eunms import Model_Type, Scheduler_Type | |
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler | |
from src.lcm_scheduler import MyLCMScheduler | |
from src.ddpm_scheduler import MyDDPMScheduler | |
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline | |
from src.sd_inversion_pipeline import SDDDIMPipeline | |
def scheduler_type_to_class(scheduler_type):
    """Map a ``Scheduler_Type`` member to the scheduler class implementing it.

    Args:
        scheduler_type: a member of ``Scheduler_Type``.

    Returns:
        The scheduler class itself (not an instance).

    Raises:
        ValueError: if ``scheduler_type`` matches none of the handled members.
    """
    # Checked in order with ``==``, exactly like an if/elif chain would.
    candidates = (
        (Scheduler_Type.DDIM, DDIMScheduler),
        (Scheduler_Type.EULER, MyEulerAncestralDiscreteScheduler),
        (Scheduler_Type.LCM, MyLCMScheduler),
        (Scheduler_Type.DDPM, MyDDPMScheduler),
    )
    for known_type, scheduler_cls in candidates:
        if scheduler_type == known_type:
            return scheduler_cls
    raise ValueError("Unknown scheduler type")
def model_type_to_class(model_type):
    """Return the pipeline classes used for ``model_type``.

    Args:
        model_type: a member of ``Model_Type``.

    Returns:
        A 2-tuple ``(inference_pipeline_cls, inversion_pipeline_cls)``.
        SDXL-family models use ``SDXLDDIMPipeline`` for inversion; SD 1.x/2.x
        models use ``SDDDIMPipeline``.

    Raises:
        ValueError: if ``model_type`` matches none of the handled members.
    """
    dispatch = (
        (Model_Type.SDXL, (StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline)),
        (Model_Type.SDXL_Turbo, (AutoPipelineForImage2Image, SDXLDDIMPipeline)),
        (Model_Type.LCM_SDXL, (AutoPipelineForImage2Image, SDXLDDIMPipeline)),
        (Model_Type.SD15, (StableDiffusionImg2ImgPipeline, SDDDIMPipeline)),
        (Model_Type.SD14, (StableDiffusionImg2ImgPipeline, SDDDIMPipeline)),
        (Model_Type.SD21, (StableDiffusionImg2ImgPipeline, SDDDIMPipeline)),
        (Model_Type.SD21_Turbo, (StableDiffusionImg2ImgPipeline, SDDDIMPipeline)),
    )
    for candidate, pipeline_pair in dispatch:
        if model_type == candidate:
            return pipeline_pair
    raise ValueError("Unknown model type")
def model_type_to_model_name(model_type):
    """Return the Hugging Face Hub repository id to load for ``model_type``.

    ``LCM_SDXL`` intentionally shares the SDXL base checkpoint; its LCM
    behaviour comes from an LCM LoRA adapter loaded separately in
    ``get_pipes``.

    Args:
        model_type: a member of ``Model_Type``.

    Returns:
        The repo id string passed to ``from_pretrained``.

    Raises:
        ValueError: if ``model_type`` matches none of the handled members.
    """
    if model_type == Model_Type.SDXL:
        return "stabilityai/stable-diffusion-xl-base-1.0"
    elif model_type == Model_Type.SDXL_Turbo:
        return "stabilityai/sdxl-turbo"
    elif model_type == Model_Type.LCM_SDXL:
        return "stabilityai/stable-diffusion-xl-base-1.0"
    elif model_type == Model_Type.SD15:
        # The original "runwayml/stable-diffusion-v1-5" repo was removed from
        # the Hub, so loading it fails; this is the maintained mirror of the
        # same v1.5 weights.
        return "stable-diffusion-v1-5/stable-diffusion-v1-5"
    elif model_type == Model_Type.SD14:
        return "CompVis/stable-diffusion-v1-4"
    elif model_type == Model_Type.SD21:
        return "stabilityai/stable-diffusion-2-1"
    elif model_type == Model_Type.SD21_Turbo:
        return "stabilityai/sd-turbo"
    else:
        raise ValueError("Unknown model type")
def model_type_to_size(model_type):
    """Return the image size, in pixels, associated with ``model_type``.

    Every entry is a square 2-tuple ``(side, side)``.

    Args:
        model_type: a member of ``Model_Type``.

    Raises:
        ValueError: if ``model_type`` matches none of the handled members.
    """
    resolutions = (
        (Model_Type.SDXL, (1024, 1024)),
        (Model_Type.SDXL_Turbo, (512, 512)),
        (Model_Type.LCM_SDXL, (768, 768)),  # TODO: check
        (Model_Type.SD15, (512, 512)),
        (Model_Type.SD14, (512, 512)),
        (Model_Type.SD21, (512, 512)),
        (Model_Type.SD21_Turbo, (512, 512)),
    )
    for candidate, size in resolutions:
        if model_type == candidate:
            return size
    raise ValueError("Unknown model type")
def is_float16(model_type):
    """Return whether ``model_type`` is loaded with float16 weights.

    True for the SDXL family (SDXL, SDXL_Turbo, LCM_SDXL); False for the
    SD 1.x/2.x family.

    Args:
        model_type: a member of ``Model_Type``.

    Raises:
        ValueError: if ``model_type`` matches none of the handled members.
    """
    half_precision_flags = (
        (Model_Type.SDXL, True),
        (Model_Type.SDXL_Turbo, True),
        (Model_Type.LCM_SDXL, True),
        (Model_Type.SD15, False),
        (Model_Type.SD14, False),
        (Model_Type.SD21, False),
        (Model_Type.SD21_Turbo, False),
    )
    for candidate, uses_fp16 in half_precision_flags:
        if model_type == candidate:
            return uses_fp16
    raise ValueError("Unknown model type")
def is_sd(model_type):
    """Return whether ``model_type`` belongs to the SD 1.x/2.x family.

    False for the SDXL family (SDXL, SDXL_Turbo, LCM_SDXL).

    Args:
        model_type: a member of ``Model_Type``.

    Raises:
        ValueError: if ``model_type`` matches none of the handled members.
    """
    sd_family_flags = (
        (Model_Type.SDXL, False),
        (Model_Type.SDXL_Turbo, False),
        (Model_Type.LCM_SDXL, False),
        (Model_Type.SD15, True),
        (Model_Type.SD14, True),
        (Model_Type.SD21, True),
        (Model_Type.SD21_Turbo, True),
    )
    # The flags are always bools, so None unambiguously means "no match".
    found = next(
        (flag for candidate, flag in sd_family_flags if model_type == candidate),
        None,
    )
    if found is None:
        raise ValueError("Unknown model type")
    return found
def _get_pipes(model_type, device):
    """Load the inversion and inference pipelines for ``model_type``.

    Both pipelines come from the same Hub checkpoint, each via its own
    ``from_pretrained`` call, and are moved to ``device``.

    Args:
        model_type: a member of ``Model_Type``.
        device: target device passed to ``.to()`` (e.g. ``"cuda"``).

    Returns:
        A 2-tuple ``(pipe_inversion, pipe_inference)``.

    Raises:
        ValueError: propagated from the model-type helpers for unknown types.
    """
    model_name = model_type_to_model_name(model_type)
    pipeline_inf, pipeline_inv = model_type_to_class(model_type)

    # Shared by both loads so the two pipelines can never drift apart.
    load_kwargs = {"use_safetensors": True, "safety_checker": None}
    if is_float16(model_type):
        # Fetch the fp16 variant files from the Hub and keep weights in fp16.
        load_kwargs.update(torch_dtype=torch.float16, variant="fp16")

    # Inversion pipeline is loaded first, matching the original ordering.
    pipe_inversion = pipeline_inv.from_pretrained(model_name, **load_kwargs).to(device)
    pipe_inference = pipeline_inf.from_pretrained(model_name, **load_kwargs).to(device)
    return pipe_inversion, pipe_inference
def get_pipes(model_type, scheduler_type, device="cuda"):
    """Build the inversion and inference pipelines with the requested scheduler.

    Args:
        model_type: a member of ``Model_Type`` selecting checkpoint and classes.
        scheduler_type: a member of ``Scheduler_Type`` selecting the scheduler.
        device: device the pipelines are moved to (default ``"cuda"``).

    Returns:
        A 2-tuple ``(pipe_inversion, pipe_inference)``. ``pipe_inversion``
        additionally carries a ``scheduler_inference`` attribute: a second
        scheduler instance built from the inference pipeline's config.

    Raises:
        ValueError: for an unknown ``scheduler_type`` or ``model_type``.
    """
    # Resolve the scheduler before loading weights so a bad type fails fast.
    scheduler_class = scheduler_type_to_class(scheduler_type)
    pipe_inversion, pipe_inference = _get_pipes(model_type, device)

    pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
    pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
    pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)

    if is_sd(model_type):
        def _passthrough_add_noise(init_latents, noise, timestep):
            # Returns the latents untouched, ignoring noise and timestep.
            # NOTE(review): presumably so img2img starts from the supplied
            # latents instead of re-noising them -- confirm with the pipelines.
            return init_latents

        pipe_inference.scheduler.add_noise = _passthrough_add_noise
        pipe_inversion.scheduler.add_noise = _passthrough_add_noise
        pipe_inversion.scheduler_inference.add_noise = _passthrough_add_noise

    if model_type == Model_Type.LCM_SDXL:
        # Load the LCM-LoRA adapter on both pipelines. fuse_lora() is not
        # called, so the adapter stays unfused.
        adapter_id = "latent-consistency/lcm-lora-sdxl"
        pipe_inversion.load_lora_weights(adapter_id)
        pipe_inference.load_lora_weights(adapter_id)

    return pipe_inversion, pipe_inference