Ming Li commited on
Commit
c763397
·
1 Parent(s): e8acbbf

rm zero gpu on run_pipe func

Browse files
Files changed (1) hide show
  1. model.py +7 -308
model.py CHANGED
@@ -113,6 +113,8 @@ class Model:
113
  guidance_scale: float,
114
  seed: int,
115
  ) -> list[PIL.Image.Image]:
 
 
116
  generator = torch.Generator().manual_seed(seed)
117
  return self.pipe(
118
  prompt=prompt,
@@ -125,7 +127,7 @@ class Model:
125
  ).images
126
 
127
  @torch.inference_mode()
128
- @spaces.GPU()
129
  def process_canny(
130
  self,
131
  image: np.ndarray,
@@ -170,141 +172,7 @@ class Model:
170
  return [control_image] * num_images + results + conditions_of_generated_imgs
171
 
172
  @torch.inference_mode()
173
- def process_mlsd(
174
- self,
175
- image: np.ndarray,
176
- prompt: str,
177
- additional_prompt: str,
178
- negative_prompt: str,
179
- num_images: int,
180
- image_resolution: int,
181
- preprocess_resolution: int,
182
- num_steps: int,
183
- guidance_scale: float,
184
- seed: int,
185
- value_threshold: float,
186
- distance_threshold: float,
187
- ) -> list[PIL.Image.Image]:
188
- if image is None:
189
- raise ValueError
190
- if image_resolution > MAX_IMAGE_RESOLUTION:
191
- raise ValueError
192
- if num_images > MAX_NUM_IMAGES:
193
- raise ValueError
194
-
195
- self.preprocessor.load("MLSD")
196
- control_image = self.preprocessor(
197
- image=image,
198
- image_resolution=image_resolution,
199
- detect_resolution=preprocess_resolution,
200
- thr_v=value_threshold,
201
- thr_d=distance_threshold,
202
- )
203
- self.load_controlnet_weight("MLSD")
204
- results = self.run_pipe(
205
- prompt=self.get_prompt(prompt, additional_prompt),
206
- negative_prompt=negative_prompt,
207
- control_image=control_image,
208
- num_images=num_images,
209
- num_steps=num_steps,
210
- guidance_scale=guidance_scale,
211
- seed=seed,
212
- )
213
- return [control_image] + results
214
-
215
- @torch.inference_mode()
216
- def process_scribble(
217
- self,
218
- image: np.ndarray,
219
- prompt: str,
220
- additional_prompt: str,
221
- negative_prompt: str,
222
- num_images: int,
223
- image_resolution: int,
224
- preprocess_resolution: int,
225
- num_steps: int,
226
- guidance_scale: float,
227
- seed: int,
228
- preprocessor_name: str,
229
- ) -> list[PIL.Image.Image]:
230
- if image is None:
231
- raise ValueError
232
- if image_resolution > MAX_IMAGE_RESOLUTION:
233
- raise ValueError
234
- if num_images > MAX_NUM_IMAGES:
235
- raise ValueError
236
-
237
- if preprocessor_name == "None":
238
- image = HWC3(image)
239
- image = resize_image(image, resolution=image_resolution)
240
- control_image = PIL.Image.fromarray(image)
241
- elif preprocessor_name == "HED":
242
- self.preprocessor.load(preprocessor_name)
243
- control_image = self.preprocessor(
244
- image=image,
245
- image_resolution=image_resolution,
246
- detect_resolution=preprocess_resolution,
247
- scribble=False,
248
- )
249
- elif preprocessor_name == "PidiNet":
250
- self.preprocessor.load(preprocessor_name)
251
- control_image = self.preprocessor(
252
- image=image,
253
- image_resolution=image_resolution,
254
- detect_resolution=preprocess_resolution,
255
- safe=False,
256
- )
257
- self.load_controlnet_weight("scribble")
258
- results = self.run_pipe(
259
- prompt=self.get_prompt(prompt, additional_prompt),
260
- negative_prompt=negative_prompt,
261
- control_image=control_image,
262
- num_images=num_images,
263
- num_steps=num_steps,
264
- guidance_scale=guidance_scale,
265
- seed=seed,
266
- )
267
- return [control_image] + results
268
-
269
- @torch.inference_mode()
270
- def process_scribble_interactive(
271
- self,
272
- image_and_mask: dict[str, np.ndarray],
273
- prompt: str,
274
- additional_prompt: str,
275
- negative_prompt: str,
276
- num_images: int,
277
- image_resolution: int,
278
- num_steps: int,
279
- guidance_scale: float,
280
- seed: int,
281
- ) -> list[PIL.Image.Image]:
282
- if image_and_mask is None:
283
- raise ValueError
284
- if image_resolution > MAX_IMAGE_RESOLUTION:
285
- raise ValueError
286
- if num_images > MAX_NUM_IMAGES:
287
- raise ValueError
288
-
289
- image = image_and_mask["mask"]
290
- image = HWC3(image)
291
- image = resize_image(image, resolution=image_resolution)
292
- control_image = PIL.Image.fromarray(image)
293
-
294
- self.load_controlnet_weight("scribble")
295
- results = self.run_pipe(
296
- prompt=self.get_prompt(prompt, additional_prompt),
297
- negative_prompt=negative_prompt,
298
- control_image=control_image,
299
- num_images=num_images,
300
- num_steps=num_steps,
301
- guidance_scale=guidance_scale,
302
- seed=seed,
303
- )
304
- return [control_image] + results
305
-
306
- @torch.inference_mode()
307
- @spaces.GPU()
308
  def process_softedge(
309
  self,
310
  image: np.ndarray,
@@ -371,53 +239,7 @@ class Model:
371
  return [control_image] * num_images + results + conditions_of_generated_imgs
372
 
373
  @torch.inference_mode()
374
- def process_openpose(
375
- self,
376
- image: np.ndarray,
377
- prompt: str,
378
- additional_prompt: str,
379
- negative_prompt: str,
380
- num_images: int,
381
- image_resolution: int,
382
- preprocess_resolution: int,
383
- num_steps: int,
384
- guidance_scale: float,
385
- seed: int,
386
- preprocessor_name: str,
387
- ) -> list[PIL.Image.Image]:
388
- if image is None:
389
- raise ValueError
390
- if image_resolution > MAX_IMAGE_RESOLUTION:
391
- raise ValueError
392
- if num_images > MAX_NUM_IMAGES:
393
- raise ValueError
394
-
395
- if preprocessor_name == "None":
396
- image = HWC3(image)
397
- image = resize_image(image, resolution=image_resolution)
398
- control_image = PIL.Image.fromarray(image)
399
- else:
400
- self.preprocessor.load("Openpose")
401
- control_image = self.preprocessor(
402
- image=image,
403
- image_resolution=image_resolution,
404
- detect_resolution=preprocess_resolution,
405
- hand_and_face=True,
406
- )
407
- self.load_controlnet_weight("Openpose")
408
- results = self.run_pipe(
409
- prompt=self.get_prompt(prompt, additional_prompt),
410
- negative_prompt=negative_prompt,
411
- control_image=control_image,
412
- num_images=num_images,
413
- num_steps=num_steps,
414
- guidance_scale=guidance_scale,
415
- seed=seed,
416
- )
417
- return [control_image] + results
418
-
419
- @torch.inference_mode()
420
- @spaces.GPU()
421
  def process_segmentation(
422
  self,
423
  image: np.ndarray,
@@ -471,7 +293,7 @@ class Model:
471
  return [control_image] * num_images + results + conditions_of_generated_imgs
472
 
473
  @torch.inference_mode()
474
- @spaces.GPU()
475
  def process_depth(
476
  self,
477
  image: np.ndarray,
@@ -524,52 +346,7 @@ class Model:
524
  return [control_image] * num_images + results + conditions_of_generated_imgs
525
 
526
  @torch.inference_mode()
527
- def process_normal(
528
- self,
529
- image: np.ndarray,
530
- prompt: str,
531
- additional_prompt: str,
532
- negative_prompt: str,
533
- num_images: int,
534
- image_resolution: int,
535
- preprocess_resolution: int,
536
- num_steps: int,
537
- guidance_scale: float,
538
- seed: int,
539
- preprocessor_name: str,
540
- ) -> list[PIL.Image.Image]:
541
- if image is None:
542
- raise ValueError
543
- if image_resolution > MAX_IMAGE_RESOLUTION:
544
- raise ValueError
545
- if num_images > MAX_NUM_IMAGES:
546
- raise ValueError
547
-
548
- if preprocessor_name == "None":
549
- image = HWC3(image)
550
- image = resize_image(image, resolution=image_resolution)
551
- control_image = PIL.Image.fromarray(image)
552
- else:
553
- self.preprocessor.load("NormalBae")
554
- control_image = self.preprocessor(
555
- image=image,
556
- image_resolution=image_resolution,
557
- detect_resolution=preprocess_resolution,
558
- )
559
- self.load_controlnet_weight("NormalBae")
560
- results = self.run_pipe(
561
- prompt=self.get_prompt(prompt, additional_prompt),
562
- negative_prompt=negative_prompt,
563
- control_image=control_image,
564
- num_images=num_images,
565
- num_steps=num_steps,
566
- guidance_scale=guidance_scale,
567
- seed=seed,
568
- )
569
- return [control_image] + results
570
-
571
- @torch.inference_mode()
572
- @spaces.GPU()
573
  def process_lineart(
574
  self,
575
  image: np.ndarray,
@@ -638,81 +415,3 @@ class Model:
638
  conditions_of_generated_imgs = [PIL.Image.fromarray((255 - np.array(x)).astype(np.uint8)) for x in conditions_of_generated_imgs]
639
 
640
  return [control_image] * num_images + results + conditions_of_generated_imgs
641
-
642
- @torch.inference_mode()
643
- def process_shuffle(
644
- self,
645
- image: np.ndarray,
646
- prompt: str,
647
- additional_prompt: str,
648
- negative_prompt: str,
649
- num_images: int,
650
- image_resolution: int,
651
- num_steps: int,
652
- guidance_scale: float,
653
- seed: int,
654
- preprocessor_name: str,
655
- ) -> list[PIL.Image.Image]:
656
- if image is None:
657
- raise ValueError
658
- if image_resolution > MAX_IMAGE_RESOLUTION:
659
- raise ValueError
660
- if num_images > MAX_NUM_IMAGES:
661
- raise ValueError
662
-
663
- if preprocessor_name == "None":
664
- image = HWC3(image)
665
- image = resize_image(image, resolution=image_resolution)
666
- control_image = PIL.Image.fromarray(image)
667
- else:
668
- self.preprocessor.load(preprocessor_name)
669
- control_image = self.preprocessor(
670
- image=image,
671
- image_resolution=image_resolution,
672
- )
673
- self.load_controlnet_weight("shuffle")
674
- results = self.run_pipe(
675
- prompt=self.get_prompt(prompt, additional_prompt),
676
- negative_prompt=negative_prompt,
677
- control_image=control_image,
678
- num_images=num_images,
679
- num_steps=num_steps,
680
- guidance_scale=guidance_scale,
681
- seed=seed,
682
- )
683
- return [control_image] + results
684
-
685
- @torch.inference_mode()
686
- def process_ip2p(
687
- self,
688
- image: np.ndarray,
689
- prompt: str,
690
- additional_prompt: str,
691
- negative_prompt: str,
692
- num_images: int,
693
- image_resolution: int,
694
- num_steps: int,
695
- guidance_scale: float,
696
- seed: int,
697
- ) -> list[PIL.Image.Image]:
698
- if image is None:
699
- raise ValueError
700
- if image_resolution > MAX_IMAGE_RESOLUTION:
701
- raise ValueError
702
- if num_images > MAX_NUM_IMAGES:
703
- raise ValueError
704
-
705
- image = HWC3(image)
706
- image = resize_image(image, resolution=image_resolution)
707
- control_image = PIL.Image.fromarray(image)
708
- self.load_controlnet_weight("ip2p")
709
- results = self.run_pipe(
710
- prompt=self.get_prompt(prompt, additional_prompt),
711
- negative_prompt=negative_prompt,
712
- control_image=control_image,
713
- num_images=num_images,
714
- num_steps=num_steps,
715
- guidance_scale=guidance_scale,
716
- seed=seed,
717
- )
718
- return [control_image] + results
 
113
  guidance_scale: float,
114
  seed: int,
115
  ) -> list[PIL.Image.Image]:
116
+ self.pipe.to(self.device)
117
+ self.pipe.controlnet.to(self.device)
118
  generator = torch.Generator().manual_seed(seed)
119
  return self.pipe(
120
  prompt=prompt,
 
127
  ).images
128
 
129
  @torch.inference_mode()
130
+ @spaces.GPU(enable_queue=True)
131
  def process_canny(
132
  self,
133
  image: np.ndarray,
 
172
  return [control_image] * num_images + results + conditions_of_generated_imgs
173
 
174
  @torch.inference_mode()
175
+ @spaces.GPU(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def process_softedge(
177
  self,
178
  image: np.ndarray,
 
239
  return [control_image] * num_images + results + conditions_of_generated_imgs
240
 
241
  @torch.inference_mode()
242
+ @spaces.GPU(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  def process_segmentation(
244
  self,
245
  image: np.ndarray,
 
293
  return [control_image] * num_images + results + conditions_of_generated_imgs
294
 
295
  @torch.inference_mode()
296
+ @spaces.GPU(enable_queue=True)
297
  def process_depth(
298
  self,
299
  image: np.ndarray,
 
346
  return [control_image] * num_images + results + conditions_of_generated_imgs
347
 
348
  @torch.inference_mode()
349
+ @spaces.GPU(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  def process_lineart(
351
  self,
352
  image: np.ndarray,
 
415
  conditions_of_generated_imgs = [PIL.Image.fromarray((255 - np.array(x)).astype(np.uint8)) for x in conditions_of_generated_imgs]
416
 
417
  return [control_image] * num_images + results + conditions_of_generated_imgs