cocktailpeanut commited on
Commit
ddbdec3
·
1 Parent(s): c63fe57
Files changed (1) hide show
  1. demo_gradio.py +3 -2
demo_gradio.py CHANGED
@@ -48,14 +48,15 @@ controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
48
  #torch.cuda.empty_cache()
49
  device = devicetorch.get(torch)
50
  devicetorch.empty_cache(torch)
 
51
 
52
  # load SDXL pipeline
53
- controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device)
54
  pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
55
  base_model_path,
56
  controlnet=controlnet,
57
  use_safetensors=True,
58
- torch_dtype=torch.float16,
59
  add_watermarker=False,
60
  ).to(device)
61
  pipe.unet = register_cross_attention_hook(pipe.unet)
 
48
  #torch.cuda.empty_cache()
49
  device = devicetorch.get(torch)
50
  devicetorch.empty_cache(torch)
51
+ dtype = devicetorch.dtype(torch, "float16")
52
 
53
  # load SDXL pipeline
54
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype).to(device)
55
  pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
56
  base_model_path,
57
  controlnet=controlnet,
58
  use_safetensors=True,
59
+ torch_dtype=dtype,
60
  add_watermarker=False,
61
  ).to(device)
62
  pipe.unet = register_cross_attention_hook(pipe.unet)