# ReNoise-Inversion / src/enums_utils.py
# Uploaded by garibida ("Upload Files"), commit d65c9b3, 7.06 kB.
import torch
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image
from src.eunms import Model_Type, Scheduler_Type
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from src.lcm_scheduler import MyLCMScheduler
from src.ddpm_scheduler import MyDDPMScheduler
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.sd_inversion_pipeline import SDDDIMPipeline
def scheduler_type_to_class(scheduler_type):
    """Return the scheduler class associated with ``scheduler_type``.

    The result is a class, not an instance; callers (e.g. ``get_pipes``)
    build instances from it with ``from_config``.

    Args:
        scheduler_type: a ``Scheduler_Type`` member.

    Returns:
        One of ``DDIMScheduler``, ``MyEulerAncestralDiscreteScheduler``,
        ``MyLCMScheduler`` or ``MyDDPMScheduler``.

    Raises:
        ValueError: if ``scheduler_type`` matches none of the known members.
    """
    # Ordered (member, class) pairs; compared with ``==`` in this order.
    candidates = (
        (Scheduler_Type.DDIM, DDIMScheduler),
        (Scheduler_Type.EULER, MyEulerAncestralDiscreteScheduler),
        (Scheduler_Type.LCM, MyLCMScheduler),
        (Scheduler_Type.DDPM, MyDDPMScheduler),
    )
    for kind, scheduler_cls in candidates:
        if scheduler_type == kind:
            return scheduler_cls
    raise ValueError("Unknown scheduler type")
def model_type_to_class(model_type):
    """Return the ``(inference_pipeline_cls, inversion_pipeline_cls)`` pair.

    The first class is the diffusers img2img pipeline used for generation.
    The second is this project's inversion pipeline for the same model
    family.

    Args:
        model_type: a ``Model_Type`` member.

    Raises:
        ValueError: if ``model_type`` is not a known member.
    """
    # Plain SDXL uses the dedicated SDXL img2img pipeline.
    if model_type == Model_Type.SDXL:
        return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline
    # SDXL derivatives go through the auto-selecting img2img pipeline.
    if model_type in (Model_Type.SDXL_Turbo, Model_Type.LCM_SDXL):
        return AutoPipelineForImage2Image, SDXLDDIMPipeline
    # Every SD 1.x / 2.x variant shares the same pair of pipelines.
    if model_type in (
        Model_Type.SD15,
        Model_Type.SD14,
        Model_Type.SD21,
        Model_Type.SD21_Turbo,
    ):
        return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
    raise ValueError("Unknown model type")
def model_type_to_model_name(model_type):
    """Return the Hugging Face Hub repository id to load for ``model_type``.

    ``LCM_SDXL`` deliberately maps to the SDXL base checkpoint. The LCM
    behaviour comes from a LoRA that ``get_pipes`` loads afterwards.

    Raises:
        ValueError: if ``model_type`` is not a known member.
    """
    repo_ids = (
        (Model_Type.SDXL, "stabilityai/stable-diffusion-xl-base-1.0"),
        (Model_Type.SDXL_Turbo, "stabilityai/sdxl-turbo"),
        (Model_Type.LCM_SDXL, "stabilityai/stable-diffusion-xl-base-1.0"),
        (Model_Type.SD15, "runwayml/stable-diffusion-v1-5"),
        (Model_Type.SD14, "CompVis/stable-diffusion-v1-4"),
        (Model_Type.SD21, "stabilityai/stable-diffusion-2-1"),
        (Model_Type.SD21_Turbo, "stabilityai/sd-turbo"),
    )
    for kind, repo_id in repo_ids:
        if model_type == kind:
            return repo_id
    raise ValueError("Unknown model type")
def model_type_to_size(model_type):
    """Return the square working resolution ``(size, size)`` for ``model_type``.

    Raises:
        ValueError: if ``model_type`` is not a known member.
    """
    if model_type == Model_Type.SDXL:
        return (1024, 1024)
    if model_type == Model_Type.LCM_SDXL:
        # TODO: check -- carried over from the original. NOTE(review): this
        # value is unverified; the checkpoint is the SDXL base model.
        return (768, 768)
    if model_type in (
        Model_Type.SDXL_Turbo,
        Model_Type.SD15,
        Model_Type.SD14,
        Model_Type.SD21,
        Model_Type.SD21_Turbo,
    ):
        return (512, 512)
    raise ValueError("Unknown model type")
def is_float16(model_type):
    """Tell whether ``model_type`` is loaded in half precision.

    ``_get_pipes`` uses this flag to choose ``torch.float16`` and the
    ``"fp16"`` weight variant. SDXL-based models return True; SD 1.x/2.x
    models return False.

    Raises:
        ValueError: if ``model_type`` is not a known member.
    """
    half_precision = (Model_Type.SDXL, Model_Type.SDXL_Turbo, Model_Type.LCM_SDXL)
    full_precision = (
        Model_Type.SD15,
        Model_Type.SD14,
        Model_Type.SD21,
        Model_Type.SD21_Turbo,
    )
    if model_type in half_precision:
        return True
    if model_type in full_precision:
        return False
    raise ValueError("Unknown model type")
def is_sd(model_type):
    """Tell whether ``model_type`` belongs to the SD 1.x / 2.x family.

    Returns False for the SDXL-based members. ``get_pipes`` uses this flag to
    decide whether to neutralise the schedulers' ``add_noise``.

    Raises:
        ValueError: if ``model_type`` is not a known member.
    """
    sdxl_family = (Model_Type.SDXL, Model_Type.SDXL_Turbo, Model_Type.LCM_SDXL)
    sd_family = (
        Model_Type.SD15,
        Model_Type.SD14,
        Model_Type.SD21,
        Model_Type.SD21_Turbo,
    )
    if model_type in sdxl_family:
        return False
    if model_type in sd_family:
        return True
    raise ValueError("Unknown model type")
def _get_pipes(model_type, device):
    """Load the inversion and inference pipelines for ``model_type``.

    Both pipelines are loaded from the same Hub repository with safetensors
    and no safety checker, then moved to ``device``. Half-precision models
    (see ``is_float16``) additionally request ``torch.float16`` and the
    ``"fp16"`` weight variant.

    Returns:
        ``(pipe_inversion, pipe_inference)``.
    """
    repo_id = model_type_to_model_name(model_type)
    inference_cls, inversion_cls = model_type_to_class(model_type)

    # Arguments shared by both from_pretrained calls.
    load_kwargs = {"use_safetensors": True, "safety_checker": None}
    if is_float16(model_type):
        load_kwargs["torch_dtype"] = torch.float16
        load_kwargs["variant"] = "fp16"

    # Inversion pipeline is loaded first, then the inference pipeline.
    inversion = inversion_cls.from_pretrained(repo_id, **load_kwargs).to(device)
    inference = inference_cls.from_pretrained(repo_id, **load_kwargs).to(device)
    return inversion, inference
def get_pipes(model_type, scheduler_type, device="cuda"):
    """Build the ``(pipe_inversion, pipe_inference)`` pair with a custom scheduler.

    Steps:
      1. Load both pipelines via ``_get_pipes``.
      2. Replace their schedulers with ``scheduler_type``'s class, and attach
         an extra ``scheduler_inference`` to the inversion pipeline.
      3. For SD 1.x/2.x models, make every scheduler's ``add_noise`` return
         the initial latents unchanged.
      4. For ``LCM_SDXL``, load the LCM LoRA into both pipelines (unfused).
    """
    scheduler_cls = scheduler_type_to_class(scheduler_type)
    pipe_inversion, pipe_inference = _get_pipes(model_type, device)

    # Order matters: scheduler_inference is built from pipe_inference's
    # scheduler config *after* that scheduler has already been replaced.
    pipe_inference.scheduler = scheduler_cls.from_config(pipe_inference.scheduler.config)
    pipe_inversion.scheduler = scheduler_cls.from_config(pipe_inversion.scheduler.config)
    pipe_inversion.scheduler_inference = scheduler_cls.from_config(pipe_inference.scheduler.config)

    if is_sd(model_type):
        # Identity replacement: leaves init_latents untouched and ignores
        # noise/timestep. NOTE(review): presumably so the SD img2img pipeline
        # does not re-noise latents that are already noisy -- confirm.
        def passthrough_add_noise(init_latents, noise, timestep):
            return init_latents

        for sched in (
            pipe_inference.scheduler,
            pipe_inversion.scheduler,
            pipe_inversion.scheduler_inference,
        ):
            sched.add_noise = passthrough_add_noise

    if model_type == Model_Type.LCM_SDXL:
        lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
        # The LoRA weights are loaded but intentionally not fused.
        for pipe in (pipe_inversion, pipe_inference):
            pipe.load_lora_weights(lcm_lora_id)

    return pipe_inversion, pipe_inference