fffiloni commited on
Commit
27a9419
·
verified ·
1 Parent(s): af19061

Do not assign gpu directly

Browse files
Files changed (1) hide show
  1. models/utils.py +8 -8
models/utils.py CHANGED
@@ -36,7 +36,7 @@ def get_model(
36
  cache_dir=cache_dir,
37
  memsave=memsave,
38
  )
39
- pipe = pipe.to(device, dtype)
40
  elif model_name == "sdxl-turbo":
41
  vae = AutoencoderKL.from_pretrained(
42
  "madebyollin/sdxl-vae-fp16-fix",
@@ -55,7 +55,7 @@ def get_model(
55
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
56
  pipe.scheduler.config, timestep_spacing="trailing"
57
  )
58
- pipe = pipe.to(device, dtype)
59
  elif model_name == "pixart":
60
  pipe = RewardPixartPipeline.from_pretrained(
61
  "PixArt-alpha/PixArt-XL-2-1024-MS",
@@ -80,7 +80,7 @@ def get_model(
80
  pipe.transformer.eval()
81
  freeze_params(pipe.transformer.parameters())
82
  pipe.transformer.enable_gradient_checkpointing()
83
- pipe = pipe.to(device)
84
  elif model_name == "hyper-sd":
85
  base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
86
  repo_name = "ByteDance/Hyper-SD"
@@ -108,20 +108,20 @@ def get_model(
108
  pipe.scheduler = LCMScheduler.from_config(
109
  pipe.scheduler.config, cache_dir=cache_dir
110
  )
111
- pipe = pipe.to(device, dtype)
112
  # upcast vae
113
  pipe.vae = pipe.vae.to(dtype=torch.float32)
114
  elif model_name == "flux":
115
  pipe = RewardFluxPipeline.from_pretrained(
116
  "black-forest-labs/FLUX.1-schnell",
117
- torch_dtype=torch.bfloat16,
118
  cache_dir=cache_dir,
119
  )
120
- pipe.to(device, dtype)
121
  else:
122
  raise ValueError(f"Unknown model name: {model_name}")
123
- if enable_sequential_cpu_offload:
124
- pipe.enable_sequential_cpu_offload()
125
  return pipe
126
 
127
 
 
36
  cache_dir=cache_dir,
37
  memsave=memsave,
38
  )
39
+ #pipe = pipe.to(device, dtype)
40
  elif model_name == "sdxl-turbo":
41
  vae = AutoencoderKL.from_pretrained(
42
  "madebyollin/sdxl-vae-fp16-fix",
 
55
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
56
  pipe.scheduler.config, timestep_spacing="trailing"
57
  )
58
+ #pipe = pipe.to(device, dtype)
59
  elif model_name == "pixart":
60
  pipe = RewardPixartPipeline.from_pretrained(
61
  "PixArt-alpha/PixArt-XL-2-1024-MS",
 
80
  pipe.transformer.eval()
81
  freeze_params(pipe.transformer.parameters())
82
  pipe.transformer.enable_gradient_checkpointing()
83
+ #pipe = pipe.to(device)
84
  elif model_name == "hyper-sd":
85
  base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
86
  repo_name = "ByteDance/Hyper-SD"
 
108
  pipe.scheduler = LCMScheduler.from_config(
109
  pipe.scheduler.config, cache_dir=cache_dir
110
  )
111
+ #pipe = pipe.to(device, dtype)
112
  # upcast vae
113
  pipe.vae = pipe.vae.to(dtype=torch.float32)
114
  elif model_name == "flux":
115
  pipe = RewardFluxPipeline.from_pretrained(
116
  "black-forest-labs/FLUX.1-schnell",
117
+ torch_dtype=torch.float16,
118
  cache_dir=cache_dir,
119
  )
120
+ #pipe.to(device, dtype)
121
  else:
122
  raise ValueError(f"Unknown model name: {model_name}")
123
+ #if enable_sequential_cpu_offload:
124
+ # pipe.enable_sequential_cpu_offload()
125
  return pipe
126
 
127