from diffusers import DiffusionPipeline
from diffusers import AutoPipelineForText2Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers import StableDiffusionPipeline
import torch
import os


def load_huggingface_model(model_name, model_type):
    """Load a text-to-image pipeline by name.

    Returns a single pipeline, except for "Stable-cascade", which returns a
    [prior, decoder] pair that must be run in two stages.
    """
    if model_name == "SD-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = pipe.to("cuda")
    elif model_name == "SDXL-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = pipe.to("cuda")
    elif model_name == "Stable-cascade":
        # Stable Cascade is a two-stage model: the prior produces image
        # embeddings, which the decoder turns into pixels.
        prior = StableCascadePriorPipeline.from_pretrained(
            "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
        )
        decoder = StableCascadeDecoderPipeline.from_pretrained(
            "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
        )
        pipe = [prior, decoder]
    elif model_name == "ReCo":
        # Prefer the locally cached ReCo checkpoint; fall back to the base
        # SD 1.4 weights if the cache directory is missing.
        local_path = (
            "/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/"
            "snapshots/11a062da5a0a84501047cb19e113f520eb610415"
        )
        path = local_path if os.path.isdir(local_path) else "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")
    else:
        raise NotImplementedError(f"Unknown model: {model_name}")

    # CPU fallback, kept for reference:
    # if model_name == "SD-turbo":
    #     pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
    # elif model_name == "SDXL-turbo":
    #     pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    # else:
    #     raise NotImplementedError
    # pipe = pipe.to("cpu")

    return pipe


if __name__ == "__main__":
    # Other models that can be smoke-tested with load_huggingface_model:
    # for name in ["SD-turbo", "SDXL-turbo"]:
    #     pipe = load_huggingface_model(name, "text2image")
    # for name in ["IF-I-XL-v1.0"]:
    #     pipe = load_huggingface_model(name, "text2image")
    # pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)

    prompt = "draw a tiger"
    prior, decoder = load_huggingface_model("Stable-cascade", "text2image")

    # Stage 1: the prior generates image embeddings from the text prompt.
    prior.enable_model_cpu_offload()
    prior_output = prior(
        prompt=prompt,
        height=512,
        width=512,
        negative_prompt="",
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )

    # Stage 2: the decoder turns the embeddings into a PIL image.
    decoder.enable_model_cpu_offload()
    result = decoder(
        image_embeddings=prior_output.image_embeddings.to(torch.float16),
        prompt=prompt,
        negative_prompt="",
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images[0]
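    # Usage sketch (not part of the original script): persist the generated
    # image and show how a single-pipeline model such as "SD-turbo" would be
    # driven through load_huggingface_model. The prompt and output filenames
    # are illustrative assumptions.
    result.save("stable_cascade_tiger.png")

    # SD-turbo is distilled for single-step sampling without classifier-free
    # guidance, so one inference step and guidance_scale=0.0 is typical.
    # turbo_pipe = load_huggingface_model("SD-turbo", "text2image")
    # turbo_image = turbo_pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
    # turbo_image.save("sd_turbo_tiger.png")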