Spaces:
Running
Running
Ming Li
commited on
Commit
·
c763397
1
Parent(s):
e8acbbf
rm zero gpu on run_pipe func
Browse files
model.py
CHANGED
@@ -113,6 +113,8 @@ class Model:
|
|
113 |
guidance_scale: float,
|
114 |
seed: int,
|
115 |
) -> list[PIL.Image.Image]:
|
|
|
|
|
116 |
generator = torch.Generator().manual_seed(seed)
|
117 |
return self.pipe(
|
118 |
prompt=prompt,
|
@@ -125,7 +127,7 @@ class Model:
|
|
125 |
).images
|
126 |
|
127 |
@torch.inference_mode()
|
128 |
-
@spaces.GPU()
|
129 |
def process_canny(
|
130 |
self,
|
131 |
image: np.ndarray,
|
@@ -170,141 +172,7 @@ class Model:
|
|
170 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
171 |
|
172 |
@torch.inference_mode()
|
173 |
-
|
174 |
-
self,
|
175 |
-
image: np.ndarray,
|
176 |
-
prompt: str,
|
177 |
-
additional_prompt: str,
|
178 |
-
negative_prompt: str,
|
179 |
-
num_images: int,
|
180 |
-
image_resolution: int,
|
181 |
-
preprocess_resolution: int,
|
182 |
-
num_steps: int,
|
183 |
-
guidance_scale: float,
|
184 |
-
seed: int,
|
185 |
-
value_threshold: float,
|
186 |
-
distance_threshold: float,
|
187 |
-
) -> list[PIL.Image.Image]:
|
188 |
-
if image is None:
|
189 |
-
raise ValueError
|
190 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
191 |
-
raise ValueError
|
192 |
-
if num_images > MAX_NUM_IMAGES:
|
193 |
-
raise ValueError
|
194 |
-
|
195 |
-
self.preprocessor.load("MLSD")
|
196 |
-
control_image = self.preprocessor(
|
197 |
-
image=image,
|
198 |
-
image_resolution=image_resolution,
|
199 |
-
detect_resolution=preprocess_resolution,
|
200 |
-
thr_v=value_threshold,
|
201 |
-
thr_d=distance_threshold,
|
202 |
-
)
|
203 |
-
self.load_controlnet_weight("MLSD")
|
204 |
-
results = self.run_pipe(
|
205 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
206 |
-
negative_prompt=negative_prompt,
|
207 |
-
control_image=control_image,
|
208 |
-
num_images=num_images,
|
209 |
-
num_steps=num_steps,
|
210 |
-
guidance_scale=guidance_scale,
|
211 |
-
seed=seed,
|
212 |
-
)
|
213 |
-
return [control_image] + results
|
214 |
-
|
215 |
-
@torch.inference_mode()
|
216 |
-
def process_scribble(
|
217 |
-
self,
|
218 |
-
image: np.ndarray,
|
219 |
-
prompt: str,
|
220 |
-
additional_prompt: str,
|
221 |
-
negative_prompt: str,
|
222 |
-
num_images: int,
|
223 |
-
image_resolution: int,
|
224 |
-
preprocess_resolution: int,
|
225 |
-
num_steps: int,
|
226 |
-
guidance_scale: float,
|
227 |
-
seed: int,
|
228 |
-
preprocessor_name: str,
|
229 |
-
) -> list[PIL.Image.Image]:
|
230 |
-
if image is None:
|
231 |
-
raise ValueError
|
232 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
233 |
-
raise ValueError
|
234 |
-
if num_images > MAX_NUM_IMAGES:
|
235 |
-
raise ValueError
|
236 |
-
|
237 |
-
if preprocessor_name == "None":
|
238 |
-
image = HWC3(image)
|
239 |
-
image = resize_image(image, resolution=image_resolution)
|
240 |
-
control_image = PIL.Image.fromarray(image)
|
241 |
-
elif preprocessor_name == "HED":
|
242 |
-
self.preprocessor.load(preprocessor_name)
|
243 |
-
control_image = self.preprocessor(
|
244 |
-
image=image,
|
245 |
-
image_resolution=image_resolution,
|
246 |
-
detect_resolution=preprocess_resolution,
|
247 |
-
scribble=False,
|
248 |
-
)
|
249 |
-
elif preprocessor_name == "PidiNet":
|
250 |
-
self.preprocessor.load(preprocessor_name)
|
251 |
-
control_image = self.preprocessor(
|
252 |
-
image=image,
|
253 |
-
image_resolution=image_resolution,
|
254 |
-
detect_resolution=preprocess_resolution,
|
255 |
-
safe=False,
|
256 |
-
)
|
257 |
-
self.load_controlnet_weight("scribble")
|
258 |
-
results = self.run_pipe(
|
259 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
260 |
-
negative_prompt=negative_prompt,
|
261 |
-
control_image=control_image,
|
262 |
-
num_images=num_images,
|
263 |
-
num_steps=num_steps,
|
264 |
-
guidance_scale=guidance_scale,
|
265 |
-
seed=seed,
|
266 |
-
)
|
267 |
-
return [control_image] + results
|
268 |
-
|
269 |
-
@torch.inference_mode()
|
270 |
-
def process_scribble_interactive(
|
271 |
-
self,
|
272 |
-
image_and_mask: dict[str, np.ndarray],
|
273 |
-
prompt: str,
|
274 |
-
additional_prompt: str,
|
275 |
-
negative_prompt: str,
|
276 |
-
num_images: int,
|
277 |
-
image_resolution: int,
|
278 |
-
num_steps: int,
|
279 |
-
guidance_scale: float,
|
280 |
-
seed: int,
|
281 |
-
) -> list[PIL.Image.Image]:
|
282 |
-
if image_and_mask is None:
|
283 |
-
raise ValueError
|
284 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
285 |
-
raise ValueError
|
286 |
-
if num_images > MAX_NUM_IMAGES:
|
287 |
-
raise ValueError
|
288 |
-
|
289 |
-
image = image_and_mask["mask"]
|
290 |
-
image = HWC3(image)
|
291 |
-
image = resize_image(image, resolution=image_resolution)
|
292 |
-
control_image = PIL.Image.fromarray(image)
|
293 |
-
|
294 |
-
self.load_controlnet_weight("scribble")
|
295 |
-
results = self.run_pipe(
|
296 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
297 |
-
negative_prompt=negative_prompt,
|
298 |
-
control_image=control_image,
|
299 |
-
num_images=num_images,
|
300 |
-
num_steps=num_steps,
|
301 |
-
guidance_scale=guidance_scale,
|
302 |
-
seed=seed,
|
303 |
-
)
|
304 |
-
return [control_image] + results
|
305 |
-
|
306 |
-
@torch.inference_mode()
|
307 |
-
@spaces.GPU()
|
308 |
def process_softedge(
|
309 |
self,
|
310 |
image: np.ndarray,
|
@@ -371,53 +239,7 @@ class Model:
|
|
371 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
372 |
|
373 |
@torch.inference_mode()
|
374 |
-
|
375 |
-
self,
|
376 |
-
image: np.ndarray,
|
377 |
-
prompt: str,
|
378 |
-
additional_prompt: str,
|
379 |
-
negative_prompt: str,
|
380 |
-
num_images: int,
|
381 |
-
image_resolution: int,
|
382 |
-
preprocess_resolution: int,
|
383 |
-
num_steps: int,
|
384 |
-
guidance_scale: float,
|
385 |
-
seed: int,
|
386 |
-
preprocessor_name: str,
|
387 |
-
) -> list[PIL.Image.Image]:
|
388 |
-
if image is None:
|
389 |
-
raise ValueError
|
390 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
391 |
-
raise ValueError
|
392 |
-
if num_images > MAX_NUM_IMAGES:
|
393 |
-
raise ValueError
|
394 |
-
|
395 |
-
if preprocessor_name == "None":
|
396 |
-
image = HWC3(image)
|
397 |
-
image = resize_image(image, resolution=image_resolution)
|
398 |
-
control_image = PIL.Image.fromarray(image)
|
399 |
-
else:
|
400 |
-
self.preprocessor.load("Openpose")
|
401 |
-
control_image = self.preprocessor(
|
402 |
-
image=image,
|
403 |
-
image_resolution=image_resolution,
|
404 |
-
detect_resolution=preprocess_resolution,
|
405 |
-
hand_and_face=True,
|
406 |
-
)
|
407 |
-
self.load_controlnet_weight("Openpose")
|
408 |
-
results = self.run_pipe(
|
409 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
410 |
-
negative_prompt=negative_prompt,
|
411 |
-
control_image=control_image,
|
412 |
-
num_images=num_images,
|
413 |
-
num_steps=num_steps,
|
414 |
-
guidance_scale=guidance_scale,
|
415 |
-
seed=seed,
|
416 |
-
)
|
417 |
-
return [control_image] + results
|
418 |
-
|
419 |
-
@torch.inference_mode()
|
420 |
-
@spaces.GPU()
|
421 |
def process_segmentation(
|
422 |
self,
|
423 |
image: np.ndarray,
|
@@ -471,7 +293,7 @@ class Model:
|
|
471 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
472 |
|
473 |
@torch.inference_mode()
|
474 |
-
@spaces.GPU()
|
475 |
def process_depth(
|
476 |
self,
|
477 |
image: np.ndarray,
|
@@ -524,52 +346,7 @@ class Model:
|
|
524 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
525 |
|
526 |
@torch.inference_mode()
|
527 |
-
|
528 |
-
self,
|
529 |
-
image: np.ndarray,
|
530 |
-
prompt: str,
|
531 |
-
additional_prompt: str,
|
532 |
-
negative_prompt: str,
|
533 |
-
num_images: int,
|
534 |
-
image_resolution: int,
|
535 |
-
preprocess_resolution: int,
|
536 |
-
num_steps: int,
|
537 |
-
guidance_scale: float,
|
538 |
-
seed: int,
|
539 |
-
preprocessor_name: str,
|
540 |
-
) -> list[PIL.Image.Image]:
|
541 |
-
if image is None:
|
542 |
-
raise ValueError
|
543 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
544 |
-
raise ValueError
|
545 |
-
if num_images > MAX_NUM_IMAGES:
|
546 |
-
raise ValueError
|
547 |
-
|
548 |
-
if preprocessor_name == "None":
|
549 |
-
image = HWC3(image)
|
550 |
-
image = resize_image(image, resolution=image_resolution)
|
551 |
-
control_image = PIL.Image.fromarray(image)
|
552 |
-
else:
|
553 |
-
self.preprocessor.load("NormalBae")
|
554 |
-
control_image = self.preprocessor(
|
555 |
-
image=image,
|
556 |
-
image_resolution=image_resolution,
|
557 |
-
detect_resolution=preprocess_resolution,
|
558 |
-
)
|
559 |
-
self.load_controlnet_weight("NormalBae")
|
560 |
-
results = self.run_pipe(
|
561 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
562 |
-
negative_prompt=negative_prompt,
|
563 |
-
control_image=control_image,
|
564 |
-
num_images=num_images,
|
565 |
-
num_steps=num_steps,
|
566 |
-
guidance_scale=guidance_scale,
|
567 |
-
seed=seed,
|
568 |
-
)
|
569 |
-
return [control_image] + results
|
570 |
-
|
571 |
-
@torch.inference_mode()
|
572 |
-
@spaces.GPU()
|
573 |
def process_lineart(
|
574 |
self,
|
575 |
image: np.ndarray,
|
@@ -638,81 +415,3 @@ class Model:
|
|
638 |
conditions_of_generated_imgs = [PIL.Image.fromarray((255 - np.array(x)).astype(np.uint8)) for x in conditions_of_generated_imgs]
|
639 |
|
640 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
641 |
-
|
642 |
-
@torch.inference_mode()
|
643 |
-
def process_shuffle(
|
644 |
-
self,
|
645 |
-
image: np.ndarray,
|
646 |
-
prompt: str,
|
647 |
-
additional_prompt: str,
|
648 |
-
negative_prompt: str,
|
649 |
-
num_images: int,
|
650 |
-
image_resolution: int,
|
651 |
-
num_steps: int,
|
652 |
-
guidance_scale: float,
|
653 |
-
seed: int,
|
654 |
-
preprocessor_name: str,
|
655 |
-
) -> list[PIL.Image.Image]:
|
656 |
-
if image is None:
|
657 |
-
raise ValueError
|
658 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
659 |
-
raise ValueError
|
660 |
-
if num_images > MAX_NUM_IMAGES:
|
661 |
-
raise ValueError
|
662 |
-
|
663 |
-
if preprocessor_name == "None":
|
664 |
-
image = HWC3(image)
|
665 |
-
image = resize_image(image, resolution=image_resolution)
|
666 |
-
control_image = PIL.Image.fromarray(image)
|
667 |
-
else:
|
668 |
-
self.preprocessor.load(preprocessor_name)
|
669 |
-
control_image = self.preprocessor(
|
670 |
-
image=image,
|
671 |
-
image_resolution=image_resolution,
|
672 |
-
)
|
673 |
-
self.load_controlnet_weight("shuffle")
|
674 |
-
results = self.run_pipe(
|
675 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
676 |
-
negative_prompt=negative_prompt,
|
677 |
-
control_image=control_image,
|
678 |
-
num_images=num_images,
|
679 |
-
num_steps=num_steps,
|
680 |
-
guidance_scale=guidance_scale,
|
681 |
-
seed=seed,
|
682 |
-
)
|
683 |
-
return [control_image] + results
|
684 |
-
|
685 |
-
@torch.inference_mode()
|
686 |
-
def process_ip2p(
|
687 |
-
self,
|
688 |
-
image: np.ndarray,
|
689 |
-
prompt: str,
|
690 |
-
additional_prompt: str,
|
691 |
-
negative_prompt: str,
|
692 |
-
num_images: int,
|
693 |
-
image_resolution: int,
|
694 |
-
num_steps: int,
|
695 |
-
guidance_scale: float,
|
696 |
-
seed: int,
|
697 |
-
) -> list[PIL.Image.Image]:
|
698 |
-
if image is None:
|
699 |
-
raise ValueError
|
700 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
701 |
-
raise ValueError
|
702 |
-
if num_images > MAX_NUM_IMAGES:
|
703 |
-
raise ValueError
|
704 |
-
|
705 |
-
image = HWC3(image)
|
706 |
-
image = resize_image(image, resolution=image_resolution)
|
707 |
-
control_image = PIL.Image.fromarray(image)
|
708 |
-
self.load_controlnet_weight("ip2p")
|
709 |
-
results = self.run_pipe(
|
710 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
711 |
-
negative_prompt=negative_prompt,
|
712 |
-
control_image=control_image,
|
713 |
-
num_images=num_images,
|
714 |
-
num_steps=num_steps,
|
715 |
-
guidance_scale=guidance_scale,
|
716 |
-
seed=seed,
|
717 |
-
)
|
718 |
-
return [control_image] + results
|
|
|
113 |
guidance_scale: float,
|
114 |
seed: int,
|
115 |
) -> list[PIL.Image.Image]:
|
116 |
+
self.pipe.to(self.device)
|
117 |
+
self.pipe.controlnet.to(self.device)
|
118 |
generator = torch.Generator().manual_seed(seed)
|
119 |
return self.pipe(
|
120 |
prompt=prompt,
|
|
|
127 |
).images
|
128 |
|
129 |
@torch.inference_mode()
|
130 |
+
@spaces.GPU(enable_queue=True)
|
131 |
def process_canny(
|
132 |
self,
|
133 |
image: np.ndarray,
|
|
|
172 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
173 |
|
174 |
@torch.inference_mode()
|
175 |
+
@spaces.GPU(enable_queue=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def process_softedge(
|
177 |
self,
|
178 |
image: np.ndarray,
|
|
|
239 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
240 |
|
241 |
@torch.inference_mode()
|
242 |
+
@spaces.GPU(enable_queue=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
def process_segmentation(
|
244 |
self,
|
245 |
image: np.ndarray,
|
|
|
293 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
294 |
|
295 |
@torch.inference_mode()
|
296 |
+
@spaces.GPU(enable_queue=True)
|
297 |
def process_depth(
|
298 |
self,
|
299 |
image: np.ndarray,
|
|
|
346 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
347 |
|
348 |
@torch.inference_mode()
|
349 |
+
@spaces.GPU(enable_queue=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
def process_lineart(
|
351 |
self,
|
352 |
image: np.ndarray,
|
|
|
415 |
conditions_of_generated_imgs = [PIL.Image.fromarray((255 - np.array(x)).astype(np.uint8)) for x in conditions_of_generated_imgs]
|
416 |
|
417 |
return [control_image] * num_images + results + conditions_of_generated_imgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|