# Source: Control_Ability_Arena/model/models/huggingface_models.py
# (author: Bbmyy, commit c92c0ec "first commit", 2.9 kB)
from diffusers import DiffusionPipeline
from diffusers import AutoPipelineForText2Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers import StableDiffusionPipeline
import torch
import os
def load_huggingface_model(model_name, model_type):
    """Load the pretrained diffusers pipeline selected by ``model_name``.

    Args:
        model_name: One of ``"SD-turbo"``, ``"SDXL-turbo"``,
            ``"Stable-cascade"`` or ``"ReCo"``.
        model_type: Currently unused by this function.

    Returns:
        For ``"SD-turbo"``, ``"SDXL-turbo"`` and ``"ReCo"``: a single
        pipeline object that has been moved to ``"cuda"``.
        For ``"Stable-cascade"``: the list ``[prior, decoder]``; neither
        pipeline is moved to a device here, so the caller decides placement.

    Raises:
        NotImplementedError: If ``model_name`` is not a supported name.
    """
    if model_name == "SD-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = pipe.to("cuda")
    elif model_name == "SDXL-turbo":
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
        )
        pipe = pipe.to("cuda")
    elif model_name == "Stable-cascade":
        # Two-stage model: the prior is loaded in bfloat16, the decoder in
        # float16 (both from the "bf16" weight variant).
        prior = StableCascadePriorPipeline.from_pretrained(
            "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
        )
        decoder = StableCascadeDecoderPipeline.from_pretrained(
            "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
        )
        pipe = [prior, decoder]
    elif model_name == "ReCo":
        # Prefer a machine-specific local snapshot directory when it exists;
        # otherwise fall back to a hub model id.
        # NOTE(review): the fallback id is the base SD v1.4 repo, not a ReCo
        # checkpoint -- confirm this is intended.
        local_snapshot = '/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415'
        path = local_snapshot if os.path.isdir(local_snapshot) else "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")
    else:
        # Name the rejected value so the caller can see what went wrong.
        raise NotImplementedError(f"Unsupported model_name: {model_name!r}")
    return pipe
if __name__ == "__main__":
    # Manual smoke run: generate one image from a text prompt with the
    # two-stage "Stable-cascade" pipelines returned by load_huggingface_model.
    text_prompt = 'draw a tiger'
    prior_pipe, decoder_pipe = load_huggingface_model('Stable-cascade', "text2image")

    # Stage 1: run the prior to obtain image embeddings for the prompt.
    prior_pipe.enable_model_cpu_offload()
    prior_kwargs = {
        "prompt": text_prompt,
        "height": 512,
        "width": 512,
        "negative_prompt": '',
        "guidance_scale": 4.0,
        "num_images_per_prompt": 1,
        "num_inference_steps": 20,
    }
    prior_result = prior_pipe(**prior_kwargs)

    # Stage 2: decode the embeddings (cast to float16) into a PIL image.
    decoder_pipe.enable_model_cpu_offload()
    decoder_kwargs = {
        "image_embeddings": prior_result.image_embeddings.to(torch.float16),
        "prompt": text_prompt,
        "negative_prompt": '',
        "guidance_scale": 0.0,
        "output_type": "pil",
        "num_inference_steps": 10,
    }
    decoded = decoder_pipe(**decoder_kwargs)
    # Only the first generated image is kept; it is not saved or displayed.
    image = decoded.images[0]