# K-Sort-Arena: model/models/huggingface_models.py
# Last change: commit b6dc501 ("Add Video") by ksort
from diffusers import DiffusionPipeline
from diffusers import AutoPipelineForText2Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import torch
def load_huggingface_model(model_name, model_type):
    """Load a diffusers text-to-image pipeline by its short arena name.

    Args:
        model_name: One of ``"SD-turbo"``, ``"SDXL-turbo"`` or
            ``"Stable-cascade"``.
        model_type: Currently unused by this function; accepted so callers can
            pass a task type (e.g. ``"text2image"``) uniformly.

    Returns:
        For ``"SD-turbo"`` / ``"SDXL-turbo"``: a single float16 pipeline that
        has already been moved to ``"cuda"``.
        For ``"Stable-cascade"``: a two-element list ``[prior, decoder]``.
        These pipelines are NOT moved to a device here; the caller decides
        placement (for example with ``enable_model_cpu_offload()``).

    Raises:
        NotImplementedError: If ``model_name`` is not one of the names above.
    """
    if model_name == "SD-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = pipe.to("cuda")
    elif model_name == "SDXL-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = pipe.to("cuda")
    elif model_name == "Stable-cascade":
        # Both stages load the "bf16" weight variant, but the prior runs in
        # bfloat16 while the decoder runs in float16. Callers should therefore
        # cast the prior's image embeddings to float16 before decoding.
        prior = StableCascadePriorPipeline.from_pretrained(
            "stabilityai/stable-cascade-prior",
            variant="bf16",
            torch_dtype=torch.bfloat16,
        )
        decoder = StableCascadeDecoderPipeline.from_pretrained(
            "stabilityai/stable-cascade",
            variant="bf16",
            torch_dtype=torch.float16,
        )
        pipe = [prior, decoder]
    else:
        # Name the offending model so callers can see what was rejected.
        raise NotImplementedError(
            f"Unsupported Hugging Face model: {model_name!r}"
        )
    return pipe
if __name__ == "__main__":
    # Smoke test: generate one image with the two-stage Stable Cascade
    # pipeline and write it to disk.
    prompt = "draw a tiger"
    prior, decoder = load_huggingface_model("Stable-cascade", "text2image")

    # Stage 1: the prior turns the prompt into image embeddings.
    prior.enable_model_cpu_offload()
    prior_output = prior(
        prompt=prompt,
        height=512,
        width=512,
        negative_prompt="",
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )

    # Stage 2: the decoder turns the embeddings into a PIL image.
    decoder.enable_model_cpu_offload()
    result = decoder(
        # The prior is loaded in bfloat16 and the decoder in float16, so the
        # embeddings are cast to the decoder's dtype first.
        image_embeddings=prior_output.image_embeddings.to(torch.float16),
        prompt=prompt,
        negative_prompt="",
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images[0]

    # Save the image so the smoke test produces observable output.
    output_path = "tiger.png"
    result.save(output_path)
    print(f"Saved generated image to {output_path}")