hysts HF staff commited on
Commit
701868f
·
1 Parent(s): 1c9931d
Files changed (1) hide show
  1. model.py +60 -95
model.py CHANGED
@@ -111,31 +111,6 @@ class Model:
111
  generator=generator,
112
  image=control_image).images
113
 
114
- def process(
115
- self,
116
- task_name: str,
117
- prompt: str,
118
- additional_prompt: str,
119
- negative_prompt: str,
120
- control_image: PIL.Image.Image,
121
- vis_control_image: PIL.Image.Image,
122
- num_samples: int,
123
- num_steps: int,
124
- guidance_scale: float,
125
- seed: int,
126
- ) -> list[PIL.Image.Image]:
127
- self.load_controlnet_weight(task_name)
128
- results = self.run_pipe(
129
- prompt=self.get_prompt(prompt, additional_prompt),
130
- negative_prompt=negative_prompt,
131
- control_image=control_image,
132
- num_images=num_samples,
133
- num_steps=num_steps,
134
- guidance_scale=guidance_scale,
135
- seed=seed,
136
- )
137
- return [vis_control_image] + results
138
-
139
  @staticmethod
140
  def preprocess_canny(
141
  input_image: np.ndarray,
@@ -157,7 +132,7 @@ class Model:
157
  prompt: str,
158
  additional_prompt: str,
159
  negative_prompt: str,
160
- num_samples: int,
161
  image_resolution: int,
162
  num_steps: int,
163
  guidance_scale: float,
@@ -171,18 +146,17 @@ class Model:
171
  low_threshold=low_threshold,
172
  high_threshold=high_threshold,
173
  )
174
- return self.process(
175
- task_name='canny',
176
- prompt=prompt,
177
- additional_prompt=additional_prompt,
178
  negative_prompt=negative_prompt,
179
  control_image=control_image,
180
- vis_control_image=vis_control_image,
181
- num_samples=num_samples,
182
  num_steps=num_steps,
183
  guidance_scale=guidance_scale,
184
  seed=seed,
185
  )
 
186
 
187
  @staticmethod
188
  def preprocess_hough(
@@ -215,7 +189,7 @@ class Model:
215
  prompt: str,
216
  additional_prompt: str,
217
  negative_prompt: str,
218
- num_samples: int,
219
  image_resolution: int,
220
  detect_resolution: int,
221
  num_steps: int,
@@ -231,18 +205,17 @@ class Model:
231
  value_threshold=value_threshold,
232
  distance_threshold=distance_threshold,
233
  )
234
- return self.process(
235
- task_name='hough',
236
- prompt=prompt,
237
- additional_prompt=additional_prompt,
238
  negative_prompt=negative_prompt,
239
  control_image=control_image,
240
- vis_control_image=vis_control_image,
241
- num_samples=num_samples,
242
  num_steps=num_steps,
243
  guidance_scale=guidance_scale,
244
  seed=seed,
245
  )
 
246
 
247
  @staticmethod
248
  def preprocess_hed(
@@ -267,7 +240,7 @@ class Model:
267
  prompt: str,
268
  additional_prompt: str,
269
  negative_prompt: str,
270
- num_samples: int,
271
  image_resolution: int,
272
  detect_resolution: int,
273
  num_steps: int,
@@ -279,18 +252,17 @@ class Model:
279
  image_resolution=image_resolution,
280
  detect_resolution=detect_resolution,
281
  )
282
- return self.process(
283
- task_name='hed',
284
- prompt=prompt,
285
- additional_prompt=additional_prompt,
286
  negative_prompt=negative_prompt,
287
  control_image=control_image,
288
- vis_control_image=vis_control_image,
289
- num_samples=num_samples,
290
  num_steps=num_steps,
291
  guidance_scale=guidance_scale,
292
  seed=seed,
293
  )
 
294
 
295
  @staticmethod
296
  def preprocess_scribble(
@@ -311,7 +283,7 @@ class Model:
311
  prompt: str,
312
  additional_prompt: str,
313
  negative_prompt: str,
314
- num_samples: int,
315
  image_resolution: int,
316
  num_steps: int,
317
  guidance_scale: float,
@@ -321,18 +293,17 @@ class Model:
321
  input_image=input_image,
322
  image_resolution=image_resolution,
323
  )
324
- return self.process(
325
- task_name='scribble',
326
- prompt=prompt,
327
- additional_prompt=additional_prompt,
328
  negative_prompt=negative_prompt,
329
  control_image=control_image,
330
- vis_control_image=vis_control_image,
331
- num_samples=num_samples,
332
  num_steps=num_steps,
333
  guidance_scale=guidance_scale,
334
  seed=seed,
335
  )
 
336
 
337
  @staticmethod
338
  def preprocess_scribble_interactive(
@@ -354,7 +325,7 @@ class Model:
354
  prompt: str,
355
  additional_prompt: str,
356
  negative_prompt: str,
357
- num_samples: int,
358
  image_resolution: int,
359
  num_steps: int,
360
  guidance_scale: float,
@@ -364,18 +335,17 @@ class Model:
364
  input_image=input_image,
365
  image_resolution=image_resolution,
366
  )
367
- return self.process(
368
- task_name='scribble',
369
- prompt=prompt,
370
- additional_prompt=additional_prompt,
371
  negative_prompt=negative_prompt,
372
  control_image=control_image,
373
- vis_control_image=vis_control_image,
374
- num_samples=num_samples,
375
  num_steps=num_steps,
376
  guidance_scale=guidance_scale,
377
  seed=seed,
378
  )
 
379
 
380
  @staticmethod
381
  def preprocess_fake_scribble(
@@ -408,7 +378,7 @@ class Model:
408
  prompt: str,
409
  additional_prompt: str,
410
  negative_prompt: str,
411
- num_samples: int,
412
  image_resolution: int,
413
  detect_resolution: int,
414
  num_steps: int,
@@ -420,18 +390,17 @@ class Model:
420
  image_resolution=image_resolution,
421
  detect_resolution=detect_resolution,
422
  )
423
- return self.process(
424
- task_name='scribble',
425
- prompt=prompt,
426
- additional_prompt=additional_prompt,
427
  negative_prompt=negative_prompt,
428
  control_image=control_image,
429
- vis_control_image=vis_control_image,
430
- num_samples=num_samples,
431
  num_steps=num_steps,
432
  guidance_scale=guidance_scale,
433
  seed=seed,
434
  )
 
435
 
436
  @staticmethod
437
  def preprocess_pose(
@@ -462,7 +431,7 @@ class Model:
462
  prompt: str,
463
  additional_prompt: str,
464
  negative_prompt: str,
465
- num_samples: int,
466
  image_resolution: int,
467
  detect_resolution: int,
468
  num_steps: int,
@@ -476,18 +445,17 @@ class Model:
476
  detect_resolution=detect_resolution,
477
  is_pose_image=is_pose_image,
478
  )
479
- return self.process(
480
- task_name='pose',
481
- prompt=prompt,
482
- additional_prompt=additional_prompt,
483
  negative_prompt=negative_prompt,
484
  control_image=control_image,
485
- vis_control_image=vis_control_image,
486
- num_samples=num_samples,
487
  num_steps=num_steps,
488
  guidance_scale=guidance_scale,
489
  seed=seed,
490
  )
 
491
 
492
  @staticmethod
493
  def preprocess_seg(
@@ -516,7 +484,7 @@ class Model:
516
  prompt: str,
517
  additional_prompt: str,
518
  negative_prompt: str,
519
- num_samples: int,
520
  image_resolution: int,
521
  detect_resolution: int,
522
  num_steps: int,
@@ -530,18 +498,17 @@ class Model:
530
  detect_resolution=detect_resolution,
531
  is_segmentation_map=is_segmentation_map,
532
  )
533
- return self.process(
534
- task_name='seg',
535
- prompt=prompt,
536
- additional_prompt=additional_prompt,
537
  negative_prompt=negative_prompt,
538
  control_image=control_image,
539
- vis_control_image=vis_control_image,
540
- num_samples=num_samples,
541
  num_steps=num_steps,
542
  guidance_scale=guidance_scale,
543
  seed=seed,
544
  )
 
545
 
546
  @staticmethod
547
  def preprocess_depth(
@@ -571,7 +538,7 @@ class Model:
571
  prompt: str,
572
  additional_prompt: str,
573
  negative_prompt: str,
574
- num_samples: int,
575
  image_resolution: int,
576
  detect_resolution: int,
577
  num_steps: int,
@@ -585,18 +552,17 @@ class Model:
585
  detect_resolution=detect_resolution,
586
  is_depth_image=is_depth_image,
587
  )
588
- return self.process(
589
- task_name='depth',
590
- prompt=prompt,
591
- additional_prompt=additional_prompt,
592
  negative_prompt=negative_prompt,
593
  control_image=control_image,
594
- vis_control_image=vis_control_image,
595
- num_samples=num_samples,
596
  num_steps=num_steps,
597
  guidance_scale=guidance_scale,
598
  seed=seed,
599
  )
 
600
 
601
  @staticmethod
602
  def preprocess_normal(
@@ -628,7 +594,7 @@ class Model:
628
  prompt: str,
629
  additional_prompt: str,
630
  negative_prompt: str,
631
- num_samples: int,
632
  image_resolution: int,
633
  detect_resolution: int,
634
  num_steps: int,
@@ -644,15 +610,14 @@ class Model:
644
  bg_threshold=bg_threshold,
645
  is_normal_image=is_normal_image,
646
  )
647
- return self.process(
648
- task_name='normal',
649
- prompt=prompt,
650
- additional_prompt=additional_prompt,
651
  negative_prompt=negative_prompt,
652
  control_image=control_image,
653
- vis_control_image=vis_control_image,
654
- num_samples=num_samples,
655
  num_steps=num_steps,
656
  guidance_scale=guidance_scale,
657
  seed=seed,
658
  )
 
 
111
  generator=generator,
112
  image=control_image).images
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  @staticmethod
115
  def preprocess_canny(
116
  input_image: np.ndarray,
 
132
  prompt: str,
133
  additional_prompt: str,
134
  negative_prompt: str,
135
+ num_images: int,
136
  image_resolution: int,
137
  num_steps: int,
138
  guidance_scale: float,
 
146
  low_threshold=low_threshold,
147
  high_threshold=high_threshold,
148
  )
149
+ self.load_controlnet_weight('canny')
150
+ results = self.run_pipe(
151
+ prompt=self.get_prompt(prompt, additional_prompt),
 
152
  negative_prompt=negative_prompt,
153
  control_image=control_image,
154
+ num_images=num_images,
 
155
  num_steps=num_steps,
156
  guidance_scale=guidance_scale,
157
  seed=seed,
158
  )
159
+ return [vis_control_image] + results
160
 
161
  @staticmethod
162
  def preprocess_hough(
 
189
  prompt: str,
190
  additional_prompt: str,
191
  negative_prompt: str,
192
+ num_images: int,
193
  image_resolution: int,
194
  detect_resolution: int,
195
  num_steps: int,
 
205
  value_threshold=value_threshold,
206
  distance_threshold=distance_threshold,
207
  )
208
+ self.load_controlnet_weight('hough')
209
+ results = self.run_pipe(
210
+ prompt=self.get_prompt(prompt, additional_prompt),
 
211
  negative_prompt=negative_prompt,
212
  control_image=control_image,
213
+ num_images=num_images,
 
214
  num_steps=num_steps,
215
  guidance_scale=guidance_scale,
216
  seed=seed,
217
  )
218
+ return [vis_control_image] + results
219
 
220
  @staticmethod
221
  def preprocess_hed(
 
240
  prompt: str,
241
  additional_prompt: str,
242
  negative_prompt: str,
243
+ num_images: int,
244
  image_resolution: int,
245
  detect_resolution: int,
246
  num_steps: int,
 
252
  image_resolution=image_resolution,
253
  detect_resolution=detect_resolution,
254
  )
255
+ self.load_controlnet_weight('hed')
256
+ results = self.run_pipe(
257
+ prompt=self.get_prompt(prompt, additional_prompt),
 
258
  negative_prompt=negative_prompt,
259
  control_image=control_image,
260
+ num_images=num_images,
 
261
  num_steps=num_steps,
262
  guidance_scale=guidance_scale,
263
  seed=seed,
264
  )
265
+ return [vis_control_image] + results
266
 
267
  @staticmethod
268
  def preprocess_scribble(
 
283
  prompt: str,
284
  additional_prompt: str,
285
  negative_prompt: str,
286
+ num_images: int,
287
  image_resolution: int,
288
  num_steps: int,
289
  guidance_scale: float,
 
293
  input_image=input_image,
294
  image_resolution=image_resolution,
295
  )
296
+ self.load_controlnet_weight('scribble')
297
+ results = self.run_pipe(
298
+ prompt=self.get_prompt(prompt, additional_prompt),
 
299
  negative_prompt=negative_prompt,
300
  control_image=control_image,
301
+ num_images=num_images,
 
302
  num_steps=num_steps,
303
  guidance_scale=guidance_scale,
304
  seed=seed,
305
  )
306
+ return [vis_control_image] + results
307
 
308
  @staticmethod
309
  def preprocess_scribble_interactive(
 
325
  prompt: str,
326
  additional_prompt: str,
327
  negative_prompt: str,
328
+ num_images: int,
329
  image_resolution: int,
330
  num_steps: int,
331
  guidance_scale: float,
 
335
  input_image=input_image,
336
  image_resolution=image_resolution,
337
  )
338
+ self.load_controlnet_weight('scribble')
339
+ results = self.run_pipe(
340
+ prompt=self.get_prompt(prompt, additional_prompt),
 
341
  negative_prompt=negative_prompt,
342
  control_image=control_image,
343
+ num_images=num_images,
 
344
  num_steps=num_steps,
345
  guidance_scale=guidance_scale,
346
  seed=seed,
347
  )
348
+ return [vis_control_image] + results
349
 
350
  @staticmethod
351
  def preprocess_fake_scribble(
 
378
  prompt: str,
379
  additional_prompt: str,
380
  negative_prompt: str,
381
+ num_images: int,
382
  image_resolution: int,
383
  detect_resolution: int,
384
  num_steps: int,
 
390
  image_resolution=image_resolution,
391
  detect_resolution=detect_resolution,
392
  )
393
+ self.load_controlnet_weight('scribble')
394
+ results = self.run_pipe(
395
+ prompt=self.get_prompt(prompt, additional_prompt),
 
396
  negative_prompt=negative_prompt,
397
  control_image=control_image,
398
+ num_images=num_images,
 
399
  num_steps=num_steps,
400
  guidance_scale=guidance_scale,
401
  seed=seed,
402
  )
403
+ return [vis_control_image] + results
404
 
405
  @staticmethod
406
  def preprocess_pose(
 
431
  prompt: str,
432
  additional_prompt: str,
433
  negative_prompt: str,
434
+ num_images: int,
435
  image_resolution: int,
436
  detect_resolution: int,
437
  num_steps: int,
 
445
  detect_resolution=detect_resolution,
446
  is_pose_image=is_pose_image,
447
  )
448
+ self.load_controlnet_weight('pose')
449
+ results = self.run_pipe(
450
+ prompt=self.get_prompt(prompt, additional_prompt),
 
451
  negative_prompt=negative_prompt,
452
  control_image=control_image,
453
+ num_images=num_images,
 
454
  num_steps=num_steps,
455
  guidance_scale=guidance_scale,
456
  seed=seed,
457
  )
458
+ return [vis_control_image] + results
459
 
460
  @staticmethod
461
  def preprocess_seg(
 
484
  prompt: str,
485
  additional_prompt: str,
486
  negative_prompt: str,
487
+ num_images: int,
488
  image_resolution: int,
489
  detect_resolution: int,
490
  num_steps: int,
 
498
  detect_resolution=detect_resolution,
499
  is_segmentation_map=is_segmentation_map,
500
  )
501
+ self.load_controlnet_weight('seg')
502
+ results = self.run_pipe(
503
+ prompt=self.get_prompt(prompt, additional_prompt),
 
504
  negative_prompt=negative_prompt,
505
  control_image=control_image,
506
+ num_images=num_images,
 
507
  num_steps=num_steps,
508
  guidance_scale=guidance_scale,
509
  seed=seed,
510
  )
511
+ return [vis_control_image] + results
512
 
513
  @staticmethod
514
  def preprocess_depth(
 
538
  prompt: str,
539
  additional_prompt: str,
540
  negative_prompt: str,
541
+ num_images: int,
542
  image_resolution: int,
543
  detect_resolution: int,
544
  num_steps: int,
 
552
  detect_resolution=detect_resolution,
553
  is_depth_image=is_depth_image,
554
  )
555
+ self.load_controlnet_weight('depth')
556
+ results = self.run_pipe(
557
+ prompt=self.get_prompt(prompt, additional_prompt),
 
558
  negative_prompt=negative_prompt,
559
  control_image=control_image,
560
+ num_images=num_images,
 
561
  num_steps=num_steps,
562
  guidance_scale=guidance_scale,
563
  seed=seed,
564
  )
565
+ return [vis_control_image] + results
566
 
567
  @staticmethod
568
  def preprocess_normal(
 
594
  prompt: str,
595
  additional_prompt: str,
596
  negative_prompt: str,
597
+ num_images: int,
598
  image_resolution: int,
599
  detect_resolution: int,
600
  num_steps: int,
 
610
  bg_threshold=bg_threshold,
611
  is_normal_image=is_normal_image,
612
  )
613
+ self.load_controlnet_weight('normal')
614
+ results = self.run_pipe(
615
+ prompt=self.get_prompt(prompt, additional_prompt),
 
616
  negative_prompt=negative_prompt,
617
  control_image=control_image,
618
+ num_images=num_images,
 
619
  num_steps=num_steps,
620
  guidance_scale=guidance_scale,
621
  seed=seed,
622
  )
623
+ return [vis_control_image] + results