Spaces:
Paused
Paused
Andranik Sargsyan
commited on
Commit
·
bfd34e9
1
Parent(s):
919cdba
add demo code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -35
- .gitignore +7 -0
- README.md +9 -7
- app.py +350 -0
- assets/.gitignore +1 -0
- assets/config/ddpm/v1.yaml +14 -0
- assets/config/ddpm/v2-upsample.yaml +24 -0
- assets/config/encoders/clip.yaml +1 -0
- assets/config/encoders/openclip.yaml +4 -0
- assets/config/unet/inpainting/v1.yaml +15 -0
- assets/config/unet/inpainting/v2.yaml +16 -0
- assets/config/unet/upsample/v2.yaml +19 -0
- assets/config/vae-upsample.yaml +16 -0
- assets/config/vae.yaml +17 -0
- assets/examples/images/a19.jpg +3 -0
- assets/examples/images/a2.jpg +3 -0
- assets/examples/images/a4.jpg +3 -0
- assets/examples/images/a40.jpg +3 -0
- assets/examples/images/a46.jpg +3 -0
- assets/examples/images/a51.jpg +3 -0
- assets/examples/images/a54.jpg +3 -0
- assets/examples/images/a65.jpg +3 -0
- assets/examples/masked/a19.png +3 -0
- assets/examples/masked/a2.png +3 -0
- assets/examples/masked/a4.png +3 -0
- assets/examples/masked/a40.png +3 -0
- assets/examples/masked/a46.png +3 -0
- assets/examples/masked/a51.png +3 -0
- assets/examples/masked/a54.png +3 -0
- assets/examples/masked/a65.png +3 -0
- assets/examples/sbs/a19.png +3 -0
- assets/examples/sbs/a2.png +3 -0
- assets/examples/sbs/a4.png +3 -0
- assets/examples/sbs/a40.png +3 -0
- assets/examples/sbs/a46.png +3 -0
- assets/examples/sbs/a51.png +3 -0
- assets/examples/sbs/a54.png +3 -0
- assets/examples/sbs/a65.png +3 -0
- lib/__init__.py +0 -0
- lib/methods/__init__.py +0 -0
- lib/methods/rasg.py +88 -0
- lib/methods/sd.py +74 -0
- lib/methods/sr.py +141 -0
- lib/models/__init__.py +1 -0
- lib/models/common.py +49 -0
- lib/models/ds_inp.py +46 -0
- lib/models/sam.py +20 -0
- lib/models/sd15_inp.py +44 -0
- lib/models/sd2_inp.py +47 -0
- lib/models/sd2_sr.py +204 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
|
|
1 |
-
*.
|
2 |
-
*.
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
|
3 |
+
.gradio/
|
4 |
+
|
5 |
+
outputs/
|
6 |
+
gradio_tmp/
|
7 |
+
__pycache__/
|
README.md
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
---
|
2 |
-
title: HD
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
|
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: HD-Painter
|
3 |
+
emoji: 🧑🎨
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.47.1
|
8 |
+
python_version: 3.9
|
9 |
+
suggested_hardware: a100-large
|
10 |
app_file: app.py
|
11 |
pinned: false
|
12 |
+
pipeline_tag: hd-painter
|
13 |
---
|
14 |
+
Paper: https://arxiv.org/abs/2312.14091
|
|
app.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import OrderedDict
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import shutil
|
6 |
+
import uuid
|
7 |
+
import torch
|
8 |
+
from pathlib import Path
|
9 |
+
from lib.utils.iimage import IImage
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from lib import models
|
13 |
+
from lib.methods import rasg, sd, sr
|
14 |
+
from lib.utils import poisson_blend, image_from_url_text
|
15 |
+
|
16 |
+
|
17 |
+
TMP_DIR = 'gradio_tmp'
|
18 |
+
if Path(TMP_DIR).exists():
|
19 |
+
shutil.rmtree(TMP_DIR)
|
20 |
+
Path(TMP_DIR).mkdir(exist_ok=True, parents=True)
|
21 |
+
|
22 |
+
os.environ['GRADIO_TEMP_DIR'] = TMP_DIR
|
23 |
+
|
24 |
+
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
|
25 |
+
|
26 |
+
negative_prompt_str = "text, bad anatomy, bad proportions, blurry, cropped, deformed, disfigured, duplicate, error, extra limbs, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed, morbid, mutated, mutilated, out of frame, ugly, worst quality"
|
27 |
+
positive_prompt_str = "Full HD, 4K, high quality, high resolution"
|
28 |
+
|
29 |
+
example_inputs = [
|
30 |
+
['assets/examples/images/a40.jpg', 'medieval castle'],
|
31 |
+
['assets/examples/images/a4.jpg', 'parrot'],
|
32 |
+
['assets/examples/images/a65.jpg', 'hoodie'],
|
33 |
+
['assets/examples/images/a54.jpg', 'salad'],
|
34 |
+
['assets/examples/images/a51.jpg', 'space helmet'],
|
35 |
+
['assets/examples/images/a46.jpg', 'teddy bear'],
|
36 |
+
['assets/examples/images/a19.jpg', 'antique greek vase'],
|
37 |
+
['assets/examples/images/a2.jpg', 'sunglasses'],
|
38 |
+
]
|
39 |
+
thumbnails = [
|
40 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDxaRlJeZGIBuB3_oGKKhd5buKaL3kJ6moPp7r6svDYFkehrv5XKyF6mj_Pqy3yV-qDQQZj_n8CMpuYH_iDy5717rPL-qpXf-prcIv2pET4LDjFInFQVLoxuurB3_7fugCUogt5ZYIGlTgSbirJkHDqN5min3riiRJLd0ZuGN-ETDDCs5e0wohdX_Wl_Kv5RAdYjZqWFGfKcmzCF-1ny6bjCab-1hcDzIaokggkl3INTG23nhSLWhNB8EdeCdfkQmoF5fROfCPe0Lsvk6RwlAr-ZQ9jaszJ355oXOz4Y-IQLfWvnyfdyxQ02anJF7DiBZbfhH4WcA6faK7Pbjo7RIt-nTN6THwTGlEmxHEzrO-1iuy3j4UfMxPB7r-RjDrn7M9KwLuWSybIJ3dgSx_DF9OqIYkHeG_TSvs1Vi-ugbq9E0K20HNXkFlhe6ty7Ee5xr0nNqhD6lVr-6wbBvI1SJUQg03KoJkduYbTTQ-ibGIwCu7J24kdlo-_d1xBWC01zeW2MfqPjpdNHUtnPAF2IuldEsWMJMHEpWjmXwYfM1D1BmIcOuRLvdNEA6IyPD6VCgLVD27MdLWxdKHgpSvTZ1beTcMM_QuHV2vMJnBT4H7faPmLBe9yMTKlIe7Yf8fGiehOrfEgXZwNyRzmbOyGKfNRKiSVuqeqJt3h1Vze0UKFzAx3rYnibzc--58atwasvdKY_dj8f0_QKR0l2vzWogqh3NdhO9m2r77ni5JXaik25QQzF-BPe70ikVgEomHa4xySlf-Gr3Z1v_HWDa4kYg2KE3P0WjqUD_cmdylbZ45TIOkGAU1qmiEcTCs_wOkOIfCC58Z9P6Lff_BRxc_lhut8hp-Pe8tfRXhITYFFRthXforXyDuqzPmWBBz2EnUHqxa1aYOo4WeQc7KTXK2kF36qAzPwm2QFDreqV3QmS02Gev8MCz2U3bQQa9H8VSB4ZGhrzNWIwG1R8YU_G1Xb4BAgjYEZ5qX6WJaQrjuD6_Zw_pDxRew8t0mjj_tjCrmoZjpxjHsgtudH4IBah5xag0bGdUSThnszYJPM5g1weldimKE63HqaQTG-IN_N51nBken4K_0-liw73ABzUiA6EJzqCQKEoT2pejNTN88N9RFXXB5ZJ9x0NvtuMcy_JsrsVArfA5b7m4OGbwF6b5wN3Ag3XNQ3d58hVJ_Hw99HNIvrjTVCVmU0-DsYIu_njIfyjdqps1cyv6_f23F0X-q4ZsbooPoNg2lc3oeFtC2K58Dgr8JsBjt33Lnmra3YG2nBh3lkIycCvrUOS69xo8Aq9z2ODklCd9soUUQVNa2XKKsMofRi8ESGwuiWYKWdSI7XAXi4dbzFQhWwMRfsSAk3KGMpnUnlOd8Jx68fiGMwTKCCsgIPgo8mFZHSaN29ipzNoACG1bd0ZFX2qxa7UfMnnRobl-AcFLzTtOYD2T2PcIKKjnxy1-gG1Ff0mgGS-BCDl9TtCscAoKcXqTYWCSY5otnfVPcjgEFw_0KEmIbmrf6Rpyr7YGJFPTQfDRah5Ro5LhBIbXhFAkyyvWQZCecGWi3lRAh-pSm2zuGl0nVdykgMqzVwiat8_lhnHMBNRg9xWyLWrgw3bWpdf64Bjgshr0V1XV7kikgRDpyPMJPDNzX1jVASiL5S3dmWDnd49tdnaKBUXGAIWmTyUWfs1bDFZZoOfrmLrnLLUT9R11lMO0EgkZ48_UU7CMJKwgq6hegh-ErsV6S6_SLa2tQYqYlSNUESO7jw1w4_T9KTkM2635QlPH-A81yvgrSsRo0lq3uPtFEHP35-dfGT63yFd=w3385-h1364',
|
41 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDw7CKs6uhv5cJnJ48zW_5UW3IjrXwWqvemhbJGzB_xNhHgNyt4NIP0xPkCK_nyGFEsbb-4iaeWnoJbbeQxU4yBnnuckhD5d-t1qgOuFke8NRVaGCDgtDimbwNY7Jkkc9pj3pMd33P7UhBg7rVueTL7hmAyhtT_wLD9B5a50VfM_T9ErEeYUgsiEYPoE-msuALr2Sl5t0jXttYNt8R2JHdXG8Y6EegiQaEdZjZbJxdwA5S175pCmrOvVVBH334NwarLF_HDr8ESIVEkYxNtLDwb0QPDZSDeci-JS2WZcvEQUmClbxuk3hjtOR0Rm6VX_wXo0-nczybNDdTiZFsBc1Fj1L_08VLYP1OIz1fbcRF7iv2luyGmgUijGaJLr0wtW_C65eXK9aU-NJ5KVt67JT-TncumCqBpvn1-msbzbcL8krOmYqiZthOcvnrIQHeW1jIUVD_zMZADmuRdpjxO-Dn3yC3uev0ve0u4b3vEHR4iiLx-4Jf5DrMvpfdHHIL76verfqhLkiz3gtZENg1jsRTcjH35AWspo3-lMJnZqMP9wzvw7ubBBP1QrblSwfflz4aAIzmh-WisQ5UWMSltQ9DwpT0UggiXe0JbbtloWkbo2_VpaaMJh7VhBQbvRbFUxCm__UPVfglTfdiNv1m2777oGwyDbv682I_qGDW5nG91D22pqo9-enRGLvCnN-STKtWTnJQ5Qod2QYRAEI-1IR9_h-UWtCyBLLpqcGKxkHaLZvjpDmTdiPhVkoG1irracCLbPJGvrrclorr1k0nqTIhvVVH8pNZdx-yCK6KFFGNAz3aSsmxGJURWEt4TQEVDLfet8iFzuLfD2Tg9gJ4kozB88G8PrfLnGopwlPO4y6GkcfJggtshQr7yuo3xnxtci1FlLhOJNkMKCg0xhL-tVDIfWsMCzSIk08XtGgitU7DK8CE7DlkA7chCuq2BHUTeFCEALF-5DLFxa-3ru30gtNsbd71sNH01S8qhTVegWM8yQwwXXhIVqoWM-e-SSUyQldgklvatev1V6iDxLh9u2nbhoxxCcVZvyKx7rUyhSEEQzXF6nYNQn5HaAmp4jW1rCoVVvVMNP_wfvY76vamAPkw22z_nylCpbKW0BQmTBZStyoB4u_lZANVkL0fbyWoxyfZHKOoDgyTKdzDWmbeiIPmr7nvF8QWYsq160pKzIKIhj9ccF2NqyXHqjaDzmqVfZ6EUg0sJ8CeiCUskbSnyxZHVeM5GgG7cj9xMncUUsdvd0WTw3g0aqstb2MAxF3ucO0EXP8dAWqU2XrJ6E0rA51Jdn16CW9KV5HMGEcB5Y6bN-Md8-FT1r_0NQq8ZOJ5nmTzg5JNSTEw-FOBwSvukEMDmbf6ADs1emocozcqk9KiYv1ii3niD58XPIhQWfBcgyYBsqbJQ5x-UqijqcihN7i3RgQMOVJem954MKU8D-DzSet2FUcGWseyzF6Sr4GJQn5g0rOXtFP2HwT2kTXy_pqad1ukQkovfQG30gbRvIRAzMmigVJskadvadasF7Lc5eo5Pm7CPPZQ-ZJUlZuowbagWm0Kz_T-PIreVNJ5WxxQ4w_HH27QqKxe_LZvHJ9y74O8oVCywVJeKQirxDc-yKuUNJ7vagJ5a-paB_cXf6RECpu7caFM6U3g43xVMtxDf-3d0r6G7zs0oPHb4uDbovgLzqooAnjNY3_RGWldXUks0pmCqFbEhLgl8JJd4Lyv5mt34WTDqPggRnraAV5AEtZ8NvGNDWk-ItblCZQ-neXYVj0z1p1rWSJiG1XaDa8ro7G9-7XJtwfJDnFZsf=w3385-h1364',
|
42 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDxK7tdzr1tx6g49uS-pyR-7BdPJfk0_wWoErGJiQPMRbggXquguwth2I35go_GFsW8TUoj17p-jtCKCi0ryGH2gTGNvZEki8xfPn1aroIxOXRt7Ucl64Lu2cyixNSBqyJUBQEOL6LzY6DlwXJ5SVxRRvJkj1Vz-yJYz4mNA12YdDFPj7UkLs_7PEDRanydfEVFhIoiMrYitDW8fGsqn53PHaB8GoxoQHAXFyf82HPfhqqgKNRUR9yz1LAN5q72ERzxg5h4apStr16aLzijuuJF7wmtltf3HdTIMwAAd8Blsgk0d88rNZzItVZMdCeHgAyvoD7JKUCTWWJbu-uXf1PRd6mIg11OzweG04c-uloZFvJdF0pdfcfocibniKt-filYASFN43KYtO7Eyzc-YUHJn1qm73eDr0RkntHy1kQaDPmlRvlZrtbDgZIEau_FsMa-BRNq_jHlTDkyRR_8dyBf__na0I45lcChJWZCOfinbQwzQfiryK13yBB1bQDGk47JM5PTIEBK2JvsOEwezA9geZWhi3oIM6N4b1YXzvWbDYbN-cQvrBd6cuXoecUL-i_Qqda2-xGByMP4BUI0J1qF3mKwzTLvzZAk1f-zCUNHpOrzF2-WsXCUoL4R_t7MOZ2OeT7a5pCAnyiFX8Vdq3x9QzrJjjTqvttS5ElXVUGi5NBVeRV0hbXI5XLMKmiu-lWV3625eosYK0FE6hmFZk9ZAeSkrLDjkh9k1XwpB2x83kcpC_tkk25G9czy6utoNhC2YHc1uVvnUjZ0DjK2d0naLVOCKOCVRcTSvN3kt7PaLASArW5vOzZX0_LcjuhHSzK63mok6XvXXZ5AOiF2fDhUNUl1W8TaDz7aIssbibuLXxjpluone6WJglRrxAkpwtIw2rwk73icupAeVG6Frx2QctrHC0vLG4PKSfrIrZRiJC52Oxpp2dAMnhj76CIwX9wwbj86uSyIuxeIviiKumwTolZjbOKhPrHY-ZOydVDn_ZUsln4mOfDZXwUl9p1CanzpLCUsZTXcq12zTbHJRPwNw6SH9srQc3cTYYsWpBu77VmS5zVkIndUPItXWmUqds_2AI8LCUSWE9NVHiCRSw-B8J8j3SkqOD95-np2cMxDxoVX4nD11CwBr2W0xWS1kqc4mZL4aHdwhKdDcSKk6EGG4kmY1Eq1RsYxc0I08TJK-_nrbWTgA4NDTjh-oJpFFiF11ZHbEksKlSWBhH-MXF-0zxmiar-EIhAQe4hQX_suDty4GXxMwzcF84JthDAvWC0tGLJV5WQGLkLvBkTODythSNoN7l5HjnzoJf2JKI5AF8W3HJJWgntD0F2iFe5K0Ik_huovMbzGjkgPZDvWzrYpA2V6VZyPI2q08axvAfaevPrCd7G8Jy_gliK3hj3qxjIvqJXBUSO2puqum-TlUaSYgbjhWUCLXKE1BH8RRDt0brLklVvNwGpHxG6Eg9vmzp00fV3qhjoJqokTYOAxcAumtfUpyJFVk3cvpgm4826o4kKUM6eXRz8ke1L34YI9Rtf7Ppk2864hIBIy5xA6ajIpI_rIXWG8ogrhEp9GMbXGkLGiLMfd4t7P50JynPc9hYWfachaTlQfQWuz1NtV3ZGdTaEe5y_SpEwWG-UQ-_MglvY8p0EYNjGc6yhu4oEx9D0RN5H4QA2e0YjmVybG05Hfm7LwEBMBvHb8GHnpHGTxL-WDlZtU8DqDmGmhv5l4npcLnuVmSE4JcEtRJ651ZYn88wD2ghTAJtcEuyFW3nzZ2Na3z0LW4Y8Y42YdWb3hmgGsJAYBIa9YC_Cmh=w3385-h1364',
|
43 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDwPG9L-g7hX4NKVZ45PuqnRkQUV3XZlMEWZxXASV2zYVytOhbhds--yBA0ZXUzxeEAbpGa3qCl8sXlu-of8ZjVVNSdSen-zgpoN0BT-R7JGmTqjT9aEdhnysy85Gqr4e4A5KPrInLLvCCkFSHolBEhu-hO1u8kEZ7aDuk8FsPvchP1SvqjnuWtY_OCYa2FpjHH_i24cITjs58nlDxNTFdQQNCnX4KLJQAzh5cKxtG_7pqoBfBzpXBvVDwoxPafbFn0X0u8oJ7V-VOh7faO0JtJrsfcdS06tvw_J29RbDdWojpUS_pRUtA8w9Z6flNlbShj4Ib8X1V6veQLrEYWi3SojO5tfntpl5bKXPGgWhsT8mykvfhg4Tq-Ti9kyPBDPQqEqf0ll9-wFHYoAWSMsmCvZ7kakuDMH1766rOk8QgqYbDr6kMuJw_OFBR2qX7DaQw9XnJqGv7J3guCj-yU5vXes4sOtZ3n1IOheXnlvJL79KhXbsLgznYBjC9uv0fDLuqKaFL34YisLY5xY_2zi_uzTc5BXmtoAFH8otrUVUTVyt6sEPaCtjyPzG2uoSq094QFY_FxagW0E6OYPqskBUwPzg_Wc2eBwazaGy4MXEoDzgvK3PId5N4MWqU232uHsUrEEaCUUKm8-KX43c0C6O2daqjwsh-bxKOIic3pHqDoAOcq-QA83qB7pPyWwGddsaOWRIdzf-QLrB55YuvTeTmOEL84m2YDtxNEHWdYcnXYlZEXAex1xMOfqkbCQGM8jSgC2di_794HKMpsNtKwYH31WoI-Pl73t4THuq9CX1pWYdjhH0ss1j3PUMJ4hELE187E2m6fhQLHNNSHRajfesIwPVP1FgP0W5o-AKLaC7o53R0XrOZ7SqYjOua5TX0RXyJE6ceZATyiZ_2tzUpc4baqRGJb88vr8dhfECh_1J7O_8ufMvYL47HWhVNqltcIGjujtIXp64XEqShlut8TkqtB1-3OqE7gNwfiP4859pPkk0q-kfIdaLCjqB6PeqNFNgcX0dfK_-weXpM-LjrgNM_aVlQEBKwTDjJteUHY3pFTrEU8FoPnPeIjbU_rGIcwA3Lf_NE0CJUhKG1gzKYWG00CCg0srvQRipJpVrW-4SkUrxtW28iq6kVRMER83sN2RnVi7ugEuZ3S-OPdgXkUglk3bIz9ehDcFfQrll0iCabIGQxrN9-7GfKSpi3j2nU0aLeZ6DxgQ7x9f9f-hUgxM97i_90SX-S_M6Lo6bA28RB515HqqXc8FN72Wsp_XlFp3oTb4cJtOI5SSUYkdxGtrn0AQWXuE0Rp1DEaUDEdVbu757FQsct4Us0jLByauEfiZcQS_lGTjzGQxud-2NMKIeIYOWAxBGz08eTWv4b2k_IsnekNnXAHzP5WCYFoNtDiGbgURj7QWz6Qh8HbTsFJCrI3mPYkTN2jMUwviTUKpeElUwqDeA_DhT5tiHa7ldtdzncNzpxH-J75asams5J_O2W5dJMN4PYUxGmVw5mhWEClFo2stMJOfPkrmaga4D1OXc_C3Utf6OWB5CBBHGNjfAekE3QWm7ibEtwC91g1pIujyCUEYVs8YiFi0RWcMsmhWG2yrghA9Hu3kERWuVT0nHHfLRx8L0_PjlQBkNijUjK_gI-C66729qLNsmdwZxym1JCFgV4xRT7Vu9EQGL3tyhbOLRWYlHAjBE6itM-DM5T7idIyWamFBb9Nt6ZFehCpslKzEHy2VtEyiRlZ2z9kH-IEIuZ2qCu3kiGL8m6yQTyboTCI7LNLYbSn497VR8h3WkcqWqlVlWOzGUikL=w3385-h1364',
|
44 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDyoheia4yOtefRVFWnA5EHolnS-xa4pPy0yL4Wb3uIr1-mSHxQ54i0wKr3lk6uhAm5qjJ9UVqwut3UqR54log5XVnaeNu_3y9Gsn44quU_HsGAY0-84HygWkr9Ld0_Dt1JEefD_f82Ijp1TXQf1CS6IbruUOnrljFOraQ2Bu-1To3p2Pk-T3tU8xVCxU3pr6zvFz9UYHFTss8_Xw70ZtLRMT4x4suHtOaSPI42VTq_T8HEnm04Ie_0Yh0Ri2_P-qsaP2ysz-Wnw4Ykbj9nc2VIDqtvCRwti36mlNheyg_8xOLD9sMNWDu3PXoRtn7aBUpw50GMCqeGUMAUMPKhJuZoTDdQK2SVHJ8QwNhYcC8mhAmRFvt85hvuT6NyZ00SeYwyj9_rux77vZThx5ioDoUAH3CQBcwgH82xahatReym7ehL3DXm3JLHDdQLbRM6xvsE0X0MsMXkuNlx8wn6HyteW_yq8fK3wQJ_3XLh-gK5YOFdvd08A7IIK6qi__-o8nvEK_hfHgMMS-O9eT9acfa2Sr0rGGNvoUlpljVOyONyQftP2nGD92R51K2Xcq9oV2sjSu8TDDel2t2DYr5gMB3FsDfgEhQEWE-O9fRLkzIZnOTTAcUDoS-b3R_kB485Ry56FzKFbz3w3tvHhwJ3W-sqvygb8LDoF3qjURWrf7Pau6UMjTSPH6FTjbVzZeWITKsSnA14xA2wj6xi9Bp3JkCOsT0qOfXrPkK7uT3H2U1M8uqFjpNTj6u3tZyF8GgueprmH13rJjYst9d_vevhpXpXIhSDvAbJuA3xG-YNr_SG8BMXpi7N3NC0VHmBXhl_wDBVUnAD6VVqqNtXzB-6NdZzjZKnxApDdi5SGp8C9kDd6bkaXUmwG__BRNVqdbchMw3H1re5t9VxiqWTelfGl6UqAX1W8RzQR21bgu1x7EAGbVsC_UpMxeDeJq9PprMF9cCRC9ziT_H2-ubctN7O9qPpADPT0nqWN28vH-9CqB4jPeBqYwi658twIpLwRFvEukajsrvb-OqmeesnT9QpCPEpL4G0HrjB9rkRX7g-T6q3kqbGfvnqgW4Q8ilOUnkEsFK3qLCIVxwp6B3t2yg8XSOsM3tnzsA1Ua8MofFKqvmwaq7QbIBcMOHa50I1fxRMs4YEVLgu89fSZxRrKSr_8oUXcRWiqSgF-pLLU37GYMrn3yXUJxxO4bXiEifeK8id5H1khO-8ZEBXzwuZBQBYLXbCkCou6enZu98tfPA-prr_NpKfSZM4e0clZWhjo1761-BmJSlIG0JrTo2N2cKhVz-WM5BZjVr1FPYOri5fIjORUL17_RbqMw5MefYN6tLPnpVrOSUmKW9bVgdFdOVpj9Wg5lZxxswAM5qK-wOzEjfdBCW0xjKzxD97zhszCKxc7Rj2uoaJzk9CeauU83LYcihkyMHn5IKhLeAou2yKEwgqXkU0LUObdUxqjavnVgVMcVYRds_j7zpmpKNT_KOV9s2jus8aptJl8sXZ_Gzw0vi7wC_AuAGfmNCsZBEFhn_b_ZgONqdaR9EKwP0hVRDNw4dZkYn9MGqeiX77I40eEbnwVbKaUZvK2Nrt7ukjkgSP6FdvZFfs-aVIUFMc6rBAknAFDFHPzFYcy9ANPDgVAlms8fO5GGuid8kgpxtjSoneUG_A9Az6JY-suY8zFr6mDtJC1LuY9ftIKW37ZaMtutqhbBX5b5w4DLnO32Uv_ZtK1nbV8E4T0KY6qHm2kZiIYuYCsHysbvNW4MODUihrpDh1UfILdMEsae6zGpCTQ7gnKlrB7QJ5Ig28y96Om=w3385-h1364',
|
45 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDxriikbWcV-sJ5xBcy8xcJHsRC9EEBmCimXzsrBiSOy-UnRwSoGhdXNqnB92vMZk1LiOn0h3KbBMgbFr0I0SmQG5TwG__bM8AvtMrXA-DGAaNTktO1JcSPb-wVr_OepLK6P1hHyGYSvcGDdF03pIFNcKxN6QLCZB6rgFLdaWd72z3Dx8eB8QLtU8P_4G4sT2oJ0hAUlz9mKz9lTxWDWl0-1ufKUctmvWfU8EjNuQQzckJA0YwvV2jDl4ZA1r42UHDss6dy6hkjmPoDZN-p2UW_3Ju4vOVAdv0Pf73demNu-L1LALuK0rq4d7OHUaAuP0bubXJAH-wsuVervwPQDEsmBwR-FdW8jfdppKxy4MC2ISf-eyLmsTYR9dLPIlKkOAHh_84vdLeWdGtxs9gES-jhrqOiW-brFtIZoKbH_nR1yLeq9IJ7Z7-GJk3PVi_Ex7gT9WJyIuySNi6s7GH6AnDFf9wfHkzyJ1qLDKMddrNi4GfEyO89S0yScxZFW7hERAH7T2_1YqkeMv48ik9dMA0RcpJK0LYA8GDD5_MQaycXUjeoCO5tvlGQUEE8Dt815Ev3xh9dlKJiKJf_cClK_kL7iBICasjcVSNxP606Zn1fBTc_hF3QymDP8Q_Xl-g9p-ufobo_x2x6nOOdfiq1q37ik_3kPZCPPCnI2OtM3EYZ_yVlFmwWbtqxp2Rz3jaWo6t1TL8PWkFZno-aqJ8YEm4ppZVyX7ne6GTFKzFpW1SKnAnjqu67LjS0DCFhFATAcYhWRmdb5cMXve72eU5DNLgKZOaM0THDa5dnQPaRmu-c-7HlB3WISebcJHj0vIDw7DLwxaCnqqLgybvqW7O-Rt82bat5Lc-jIyMjjkZOutnc3OoYTPRN5PrLZ6HXGFBTq_s5fCimxpXvlw0bzNHqOzovgP4NC6UChXwn9CxSrbLou8vuqeD4YqyjBhh5Do3l7KcGMZtYUUMUhVf7fUrZIJcLb-ZBrg6UdiPRc3h9jcubLjXcPIrzxZeqgQ2ccJRljZqq4CBoX8WU6AiEe51Hbp0C96693G7rVomhzxa8JCMCCL3sy8v8nl1tqfqQ3kE53XnvzuqNMzJCNIfZg-GehqMBmJZp-Vaup9DOLWYWnzRqYsO9pC9r37Ajhh-wkxZyV0XhXD6EecYGcXBX8TOQRF9OMV9BYRDYVnpnGfTBGyxI7lkCU3prZQL0H4JlS48k38oWj9fE1F-PnGysDXLAxYqEERP-AXIwgsPyoKYrq7mG5UzBcEn6qminlmT2wM-EnInAb7F9e2aMxFwWv17RG2fWkO28p4t9hiMEyvAQfYyWKhGui1yqPnHnytNe3BJspV0uekBUBauvvHHLF0_tgHTJwcU03KlOLto7iIOUuwLgBT3z-_diWu_w9WTiOqZppBYKHUVfQlR1JsFn0j0Tg-kk5NbRfRU5BB4nMtQlkpW3vjTrQW88SrSKNOr_vepd1F39EqxBNG4tIyR6lvSCSGefWzUxwHBoHou8MtGPvcdxInB7imHdpwCFtoVy0wGh2kXy36CyxXqV82VehVvmfc8AnNtgVWTPJ8vvs2AQ2WL7xd7lE-DCfy-la9d31t6iLxVAkq_S067Z3iyTiJg508a_BGt_UF41VW2m2dW2v8KRjgXks4Wz7kAvO7rLLhL2CMdh-W7bBP4bfVa9BfcKVKGxaa6kCQB4rXtFs2gEbWx4o6iXf6KIUgl-RAdAgYYjIsavlkyihtgQ28-1--7JjaeO61d3wLYgJ5_t_POL8OjmFJfPeO1m9HTzDJf30C5hVP348dRhrkxSTPlSo=w3385-h1364',
|
46 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDzcrCgKc1IMiriL10KOqvN79VCXca9U5g657RP5HU1zsUNVcZPyefBpRCtevRE0k6FkGxUa4yKW5ELtqaSmVRE8R2jwZPHmpd0xeBvIHYeoySmntJD3wJl_iC9Ma20qubbH3OFIxNFFLCXMJDPY9cJ2D1xFzjD5t6jQi7tJKXc0B5W9vGTyQ3dzHlZuLEfTp4D58WszOFBLPsOq8zeve0ej-2wCEMkCrT8kTwfKTnsi_GXRpK1xFRZjczbA9kSwIxZ_x1iWNWW_XY8aw5Kgn5Du-4r3rBqpHr4_fzv5ehY20T7Cq6Rf15zCBnBad1HGxLLnrJrXujKwxHELaRfqXPItfQAmuoIfL-tgK6giSSNKDDT2Ynn7AXxHz5vzD3m4ALYWql9Qpe1Jw69AxGZt3vyuSC76LPhMmNIJzVIXVMLnw55hHCLH28GVg2WxLTLyUXoX7o-PHhFjGoBbx3X58yQjyrEQPbfmY3gtuijG_vrjoz3SPcoG4eXB6d_NnEM3b84Ml9eqE8BjlACMEJriwtqLEstPBhIsJygDZKB822bY63pG2hYzYi5-bWS19NRwun-jdNraiye8D9AIqxrXaknLvbCnWbgbKDJzlRnt9Nz8GJKU__oZ1wQgzq2DOQqEoJxPwpocJxwYjnzIb-dFqhcLkR9SQVx7rAEyJ_u23WBSp4AVTw6c7sYxd597C5EFBvzOv4qYtJp1G-hBrvHYFvYD13oDtgcFDO_F0nXg7XwmCJ7aQ28Hp79dWewQiSq0nBcgWSOic1Q4feWLFEL21Dw0pFmSVF9f3C67YNA7ZXukOAnv6xUN5RNVoBl58-KDOFSQZYysT8FGYaaY8bFRvYvI-VXcBtMuAw1VYMKLKHrhg2mSvvdPCxsGH_rQWijgGUm9pvqbmyrPzYRd-Tr_i6pKH7EMcBYgcvZNMk9_B2EMj6NUWLVkHEUTDiovgFQJDrYAuiUcKZgp_MXJcTDO79qJEsiP52C83C8Vg91w4t_Q0dPklY8EjjpbMZGvsr5NC-TlQraNFszsrUA_JsNdnMW6Te2GMnXHDaHjEDCdX2kY3XC1Ltzvl3f4tHY12OMOUzMKHFWHWyBNYNKlHnwPdxYNxvHKa-9p_okvxw547oBfsXwUrRpQQVxljmLZJbbGpxkfbEW6Rg70MRKHdbEUg0h0FS7F90kuD7pR1zUFv92fPoI5BUNUn2XjQb-DZZaC1oF5VtRMgO8RSFM_Dolrx_c6ZrPLy7bllakKTDj48CNvafL3UxaD6x9FFu4dNQVHKidcVS__EZ1SYMSZoDkD2sna5Pzpjl0Qz61U4KVAz5lSCJbF1stdiwO4jMwzmAgMNV5-fnJA3kkf9dDaIzqk2diKUh-WGSOwEwwizH54Y-e2EASTqCzGK7FGoVRVr1d1PdN7wcd8MuRMXqBfonrIGmf5cSiuvOL6odbSChO2WFOKkRbBHeV1pO46uaBVeLKGjcbAALymYsv23_veIW-RQdHnaVlvB2HE7wD1afI-3LaRKCjCUD_X5QJuTXff_EQayDwtW46_0hYRI76LGXjQuOc_LUdIi2QHlO4kt2kR8eLBPm09gKl0GJrZoS9HxKS0sERHCjUvugFcvVmgt7idjFKfi0AuUz9XmqYCoSRiMYj9z612Ot5L2D00SLTOcTLV6nlx-PeBpAyiu2ia0ehvVjzLUn9ai0_XOg1bU-ab8fvzrRJIrao5ZxdRp9wF41lJpkanwvurynNJXXu0uk2SoA1_soIeLshOsONxT4DG_PitVFMYjxY8rgx7hfNBJmAJW4GZPsnx1_P5Ojip=w3385-h1364',
|
47 |
+
'https://lh3.googleusercontent.com/fife/AK0iWDziEIX1beA1lUdgcqMcnVcmRuODfH0IHpIkZW5YYZzhQRcRmYCfi9N7-vRWfcfMTuR28ZWDA5EngVjUpwIbVxRBF2DaIb_lJhd9zG3arGxRbi7CwmWdhAXeODEvniYR-IxtWB9lYpNd9hJ8wdleTP_ai10Xscy7iFeXOmFZ76dnr0r690LnZULOd7iyGv9EWZmhKhw4wtEJsqi3e1Yu4CXRsVrM5KLYKG-EWdRW_-m9H3o0G2W7KvOwDvqAewIz5zApPBHvuEE6x-XaUOuF_FuYQVhsKpfI_1Y70SCjOCphpbn51Bv8idg0tgTDn8oL8hkvSl50VqgQqjCLNxmCHlQE88xmjK_4NMI3kbIBLWfiPGCURr95dZt8eniqVi7yu8LNgkaMixAdRBCrQF_z56EsIvkozenXBdS2FiaUEh7LqIHLBcOa6ZaxV6_3t5Q83wgJTaM4cNkvH5_nCeQ9wkwKjf7zcBxFusa5LvhM-qSm4BJz3WzE1zgqTLVnDeh-EFNPMilPhevOdBuNfTY_VF8chvWS5Nwwcxlls8xSdVVqblYGw8YBlzWdi_X5PqynTKn6aWE0IiWOzA_O0hg2q1FtAHRT-PaINo4wjIbBar6fiNNwcZIeTKJrijHcpkIhnI8PHxrUtO3s0c2pfLFvuxCCRMSEfxpcwt0rz-ODEWIZkALajwE2SyFV6Qioc4fH_xWnI-jgRvzHjDf2c14vx3bXjM_gy-25mrECLQYcSWWZVINUbvKf6_YQDwwzAKL9zhMpyGa4EToTBhMSmroGi-NwIPxh8gAfdBCAh8TFAdg2aA7D3N_KpAv4Eh5bkovhCiALFYkGLch6KogZcn7NU3OX8qyn_wJ3oGO2CFmfKkMtLHqmjQpnLtM1U9BPnRELir7pNyG9bTNzs-Vz7-Hzu3vJavGeRhypl0JCoGfO08be-ee_7EnUcKSdepfd3dG39Gc1eulLgIVCRb82Ga5mAkNp_SDSa2BGI24--uOyAUwTBazpQjJ25W0wsHpLRF4obk8Tygl8Fgt2F-VPXYz1-q0x3_KZVWf-PJmKjYD6t3ICuBMoFeJtQUxp88WlSKC7KvhEZYdEaHmEabNNK7j-VTAgi0BeBaw_dTTO0tad9rXbCW9Co3Tc1YXv53oz96VURj-FAKHk_PKPRSV7-NO-BWAk1DOTq3ZDnlKUTA5-x6k4IR5HyNzW9C7rIPGzd_PRA9ddSiRxOjSiBru_P8xS0zQn6p75V48ZkoNsLPWEWCKhANJOaOB7Y01pg3wjjnftuxkp0KpokrlCZVUn2eKPmB0Oee6TP_6DVFhgM6ksqLHO-sNxpehUjWDx84znkN0MihGRgl6TK-6xnWzD9tjvIOsK0mBzk_XY3Vuvb5OEZvLzDJ5POqNHjLcAFaDtX7gsAUtEWk20qmRbpGBnHiZv2kLOUWCy6ICkc3yFv5uUMx7pxgfc_YO95ybO8-FTDG7m1yaoz-WdLV3tHao4_MfFaRXGKtV0_7xnlyXEZ3tMYwKu4hRx2lIOsL4Aff_O8-H0jmJId0llt__iOdVDkuypQWQDOKGGP9B1_gfLkV-ymEP0Bl59jQWNnAqE-jUpTeRRcUB6FkcH8XBPKL7F9N0sq-6XeOmPPpsecmm3SflF6zJ1YV8Uv6H4_9_uQLVBB8wXSvtcQuwgzYnrtpjpMQwFqSvJDhcCPGRfRCR6H7oa-T_ACYAMcICpl8felwVUOQs4O03ywLHNrZBY05hS13cj-_aYw69kw9TdetT-GbvTKC6eY5uwBTq4ytb4eeJQJc4zBlB2Dw1vKmcgIFfZ=w3385-h1364',
|
48 |
+
]
|
49 |
+
|
50 |
+
example_previews = [
|
51 |
+
[thumbnails[0], 'Prompt: medieval castle'],
|
52 |
+
[thumbnails[1], 'Prompt: parrot'],
|
53 |
+
[thumbnails[2], 'Prompt: hoodie'],
|
54 |
+
[thumbnails[3], 'Prompt: salad'],
|
55 |
+
[thumbnails[4], 'Prompt: space helmet'],
|
56 |
+
[thumbnails[5], 'Prompt: laptop'],
|
57 |
+
[thumbnails[6], 'Prompt: antique greek vase'],
|
58 |
+
[thumbnails[7], 'Prompt: sunglasses'],
|
59 |
+
]
|
60 |
+
|
61 |
+
# Load models
|
62 |
+
inpainting_models = OrderedDict([
|
63 |
+
("Dreamshaper Inpainting V8", models.ds_inp.load_model()),
|
64 |
+
("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
|
65 |
+
("Stable-Inpainting 1.5", models.sd15_inp.load_model())
|
66 |
+
])
|
67 |
+
sr_model = models.sd2_sr.load_model()
|
68 |
+
sam_predictor = models.sam.load_model()
|
69 |
+
|
70 |
+
inp_model = None
|
71 |
+
cached_inp_model_name = ''
|
72 |
+
|
73 |
+
def remove_cached_inpainting_model():
|
74 |
+
global inp_model
|
75 |
+
global cached_inp_model_name
|
76 |
+
del inp_model
|
77 |
+
inp_model = None
|
78 |
+
cached_inp_model_name = ''
|
79 |
+
torch.cuda.empty_cache()
|
80 |
+
|
81 |
+
|
82 |
+
def set_model_from_name(inp_model_name):
|
83 |
+
global cached_inp_model_name
|
84 |
+
global inp_model
|
85 |
+
|
86 |
+
if inp_model_name == cached_inp_model_name:
|
87 |
+
print (f"Activating Cached Inpaintng Model: {inp_model_name}")
|
88 |
+
return
|
89 |
+
|
90 |
+
print (f"Activating Inpaintng Model: {inp_model_name}")
|
91 |
+
inp_model = inpainting_models[inp_model_name]
|
92 |
+
cached_inp_model_name = inp_model_name
|
93 |
+
|
94 |
+
|
95 |
+
def rasg_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
|
96 |
+
guidance_scale=7.5, batch_size=4):
|
97 |
+
torch.cuda.empty_cache()
|
98 |
+
|
99 |
+
seed = int(seed)
|
100 |
+
batch_size = max(1, min(int(batch_size), 4))
|
101 |
+
|
102 |
+
image = IImage(input['image']).resize(512)
|
103 |
+
mask = IImage(input['mask']).rgb().resize(512)
|
104 |
+
|
105 |
+
method = ['rasg']
|
106 |
+
if use_painta: method.append('painta')
|
107 |
+
|
108 |
+
inpainted_images = []
|
109 |
+
blended_images = []
|
110 |
+
for i in range(batch_size):
|
111 |
+
inpainted_image = rasg.run(
|
112 |
+
ddim = inp_model,
|
113 |
+
method = '-'.join(method),
|
114 |
+
prompt = prompt,
|
115 |
+
image = image.padx(64),
|
116 |
+
mask = mask.alpha().padx(64),
|
117 |
+
seed = seed+i*1000,
|
118 |
+
eta = eta,
|
119 |
+
prefix = '{}',
|
120 |
+
negative_prompt = negative_prompt,
|
121 |
+
positive_prompt = f', {positive_prompt}',
|
122 |
+
dt = 1000 // ddim_steps,
|
123 |
+
guidance_scale = guidance_scale
|
124 |
+
).crop(image.size)
|
125 |
+
blended_image = poisson_blend(orig_img = image.data[0], fake_img = inpainted_image.data[0],
|
126 |
+
mask = mask.data[0], dilation = 12)
|
127 |
+
|
128 |
+
blended_images.append(blended_image)
|
129 |
+
inpainted_images.append(inpainted_image.numpy()[0])
|
130 |
+
|
131 |
+
return blended_images, inpainted_images
|
132 |
+
|
133 |
+
|
134 |
+
def sd_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
|
135 |
+
guidance_scale=7.5, batch_size=4):
|
136 |
+
torch.cuda.empty_cache()
|
137 |
+
|
138 |
+
seed = int(seed)
|
139 |
+
batch_size = max(1, min(int(batch_size), 4))
|
140 |
+
|
141 |
+
image = IImage(input['image']).resize(512)
|
142 |
+
mask = IImage(input['mask']).rgb().resize(512)
|
143 |
+
|
144 |
+
method = ['default']
|
145 |
+
if use_painta: method.append('painta')
|
146 |
+
|
147 |
+
inpainted_images = []
|
148 |
+
blended_images = []
|
149 |
+
for i in range(batch_size):
|
150 |
+
inpainted_image = sd.run(
|
151 |
+
ddim = inp_model,
|
152 |
+
method = '-'.join(method),
|
153 |
+
prompt = prompt,
|
154 |
+
image = image.padx(64),
|
155 |
+
mask = mask.alpha().padx(64),
|
156 |
+
seed = seed+i*1000,
|
157 |
+
eta = eta,
|
158 |
+
prefix = '{}',
|
159 |
+
negative_prompt = negative_prompt,
|
160 |
+
positive_prompt = f', {positive_prompt}',
|
161 |
+
dt = 1000 // ddim_steps,
|
162 |
+
guidance_scale = guidance_scale
|
163 |
+
).crop(image.size)
|
164 |
+
|
165 |
+
blended_image = poisson_blend(orig_img = image.data[0], fake_img = inpainted_image.data[0],
|
166 |
+
mask = mask.data[0], dilation = 12)
|
167 |
+
|
168 |
+
blended_images.append(blended_image)
|
169 |
+
inpainted_images.append(inpainted_image.numpy()[0])
|
170 |
+
|
171 |
+
return blended_images, inpainted_images
|
172 |
+
|
173 |
+
|
174 |
+
def upscale_run(
|
175 |
+
prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index,
|
176 |
+
negative_prompt='', positive_prompt=', high resolution professional photo'):
|
177 |
+
torch.cuda.empty_cache()
|
178 |
+
|
179 |
+
# Load SR model and SAM predictor
|
180 |
+
# sr_model = models.sd2_sr.load_model()
|
181 |
+
# sam_predictor = None
|
182 |
+
# if use_sam_mask:
|
183 |
+
# sam_predictor = models.sam.load_model()
|
184 |
+
|
185 |
+
seed = int(seed)
|
186 |
+
img_index = int(img_index)
|
187 |
+
|
188 |
+
img_index = 0 if img_index < 0 else img_index
|
189 |
+
img_index = len(gallery) - 1 if img_index >= len(gallery) else img_index
|
190 |
+
img_info = gallery[img_index if img_index >= 0 else 0]
|
191 |
+
inpainted_image = image_from_url_text(img_info)
|
192 |
+
lr_image = IImage(inpainted_image)
|
193 |
+
hr_image = IImage(input['image']).resize(2048)
|
194 |
+
hr_mask = IImage(input['mask']).resize(2048)
|
195 |
+
output_image = sr.run(sr_model, sam_predictor, lr_image, hr_image, hr_mask, prompt=prompt + positive_prompt,
|
196 |
+
noise_level=0, blend_trick=True, blend_output=True, negative_prompt=negative_prompt,
|
197 |
+
seed=seed, use_sam_mask=use_sam_mask)
|
198 |
+
return output_image.numpy()[0], output_image.numpy()[0]
|
199 |
+
|
200 |
+
|
201 |
+
def switch_run(use_rasg, model_name, *args):
|
202 |
+
set_model_from_name(model_name)
|
203 |
+
if use_rasg:
|
204 |
+
return rasg_run(*args)
|
205 |
+
return sd_run(*args)
|
206 |
+
|
207 |
+
|
208 |
+
with gr.Blocks(css='style.css') as demo:
|
209 |
+
gr.HTML(
|
210 |
+
"""
|
211 |
+
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
212 |
+
<h1 style="font-weight: 900; font-size: 3rem; margin-bottom: 0.5rem">
|
213 |
+
🧑🎨 HD-Painter Demo
|
214 |
+
</h1>
|
215 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
216 |
+
Hayk Manukyan<sup>1*</sup>, Andranik Sargsyan<sup>1*</sup>, Barsegh Atanyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>
|
217 |
+
and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a><sup>1,3</sup>
|
218 |
+
</h2>
|
219 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
220 |
+
<sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>Georgia Tech
|
221 |
+
</h2>
|
222 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
223 |
+
[<a href="https://arxiv.org/abs/2312.14091" style="color:blue;">arXiv</a>]
|
224 |
+
[<a href="https://github.com/Picsart-AI-Research/HD-Painter" style="color:blue;">GitHub</a>]
|
225 |
+
</h2>
|
226 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0.7rem auto; max-width: 1000px">
|
227 |
+
<b>HD-Painter</b> enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method.
|
228 |
+
</h2>
|
229 |
+
</div>
|
230 |
+
""")
|
231 |
+
|
232 |
+
if on_huggingspace:
|
233 |
+
gr.HTML("""
|
234 |
+
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
235 |
+
<br/>
|
236 |
+
<a href="https://huggingface.co/spaces/PAIR/HD-Painter?duplicate=true">
|
237 |
+
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
238 |
+
</p>""")
|
239 |
+
|
240 |
+
with open('script.js', 'r') as f:
|
241 |
+
js_str = f.read()
|
242 |
+
|
243 |
+
demo.load(_js=js_str)
|
244 |
+
|
245 |
+
with gr.Row():
|
246 |
+
with gr.Column():
|
247 |
+
model_picker = gr.Dropdown(
|
248 |
+
list(inpainting_models.keys()),
|
249 |
+
value=0,
|
250 |
+
label = "Please select a model!",
|
251 |
+
)
|
252 |
+
with gr.Column():
|
253 |
+
use_painta = gr.Checkbox(value = True, label = "Use PAIntA")
|
254 |
+
use_rasg = gr.Checkbox(value = True, label = "Use RASG")
|
255 |
+
|
256 |
+
prompt = gr.Textbox(label = "Inpainting Prompt")
|
257 |
+
with gr.Row():
|
258 |
+
with gr.Column():
|
259 |
+
input = gr.ImageMask(label = "Input Image", brush_color='#ff0000', elem_id="inputmask")
|
260 |
+
|
261 |
+
with gr.Row():
|
262 |
+
inpaint_btn = gr.Button("Inpaint", scale = 0)
|
263 |
+
|
264 |
+
with gr.Accordion('Advanced options', open=False):
|
265 |
+
guidance_scale = gr.Slider(minimum = 0, maximum = 30, value = 7.5, label = "Guidance Scale")
|
266 |
+
eta = gr.Slider(minimum = 0, maximum = 1, value = 0.1, label = "eta")
|
267 |
+
ddim_steps = gr.Slider(minimum = 10, maximum = 100, value = 50, step = 1, label = 'Number of diffusion steps')
|
268 |
+
with gr.Row():
|
269 |
+
seed = gr.Number(value = 49123, label = "Seed")
|
270 |
+
batch_size = gr.Number(value = 1, label = "Batch size", minimum=1, maximum=4)
|
271 |
+
negative_prompt = gr.Textbox(value=negative_prompt_str, label = "Negative prompt", lines=3)
|
272 |
+
positive_prompt = gr.Textbox(value=positive_prompt_str, label = "Positive prompt", lines=1)
|
273 |
+
|
274 |
+
with gr.Column():
|
275 |
+
with gr.Row():
|
276 |
+
output_gallery = gr.Gallery(
|
277 |
+
[],
|
278 |
+
columns = 4,
|
279 |
+
preview = True,
|
280 |
+
allow_preview = True,
|
281 |
+
object_fit='scale-down',
|
282 |
+
elem_id='outputgallery'
|
283 |
+
)
|
284 |
+
with gr.Row():
|
285 |
+
upscale_btn = gr.Button("Send to Inpainting-Specialized Super-Resolution (x4)", scale = 1)
|
286 |
+
with gr.Row():
|
287 |
+
use_sam_mask = gr.Checkbox(value = False, label = "Use SAM mask for background preservation (for SR only, experimental feature)")
|
288 |
+
with gr.Row():
|
289 |
+
hires_image = gr.Image(label = "Hi-res Image")
|
290 |
+
|
291 |
+
label = gr.Markdown("## High-Resolution Generation Samples (2048px large side)")
|
292 |
+
|
293 |
+
with gr.Column():
|
294 |
+
example_container = gr.Gallery(
|
295 |
+
example_previews,
|
296 |
+
columns = 4,
|
297 |
+
preview = True,
|
298 |
+
allow_preview = True,
|
299 |
+
object_fit='scale-down'
|
300 |
+
)
|
301 |
+
|
302 |
+
gr.Examples(
|
303 |
+
[
|
304 |
+
example_inputs[i] + [[example_previews[i]]]
|
305 |
+
for i in range(len(example_previews))
|
306 |
+
],
|
307 |
+
[input, prompt, example_container]
|
308 |
+
)
|
309 |
+
|
310 |
+
mock_output_gallery = gr.Gallery([], columns = 4, visible=False)
|
311 |
+
mock_hires = gr.Image(label = "__MHRO__", visible = False)
|
312 |
+
html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
|
313 |
+
|
314 |
+
inpaint_btn.click(
|
315 |
+
fn=switch_run,
|
316 |
+
inputs=[
|
317 |
+
use_rasg,
|
318 |
+
model_picker,
|
319 |
+
use_painta,
|
320 |
+
prompt,
|
321 |
+
input,
|
322 |
+
seed,
|
323 |
+
eta,
|
324 |
+
negative_prompt,
|
325 |
+
positive_prompt,
|
326 |
+
ddim_steps,
|
327 |
+
guidance_scale,
|
328 |
+
batch_size
|
329 |
+
],
|
330 |
+
outputs=[output_gallery, mock_output_gallery],
|
331 |
+
api_name="inpaint"
|
332 |
+
)
|
333 |
+
upscale_btn.click(
|
334 |
+
fn=upscale_run,
|
335 |
+
inputs=[
|
336 |
+
prompt,
|
337 |
+
input,
|
338 |
+
ddim_steps,
|
339 |
+
seed,
|
340 |
+
use_sam_mask,
|
341 |
+
mock_output_gallery,
|
342 |
+
html_info
|
343 |
+
],
|
344 |
+
outputs=[hires_image, mock_hires],
|
345 |
+
api_name="upscale",
|
346 |
+
_js="function(a, b, c, d, e, f, g){ return [a, b, c, d, e, f, selected_gallery_index()] }",
|
347 |
+
)
|
348 |
+
|
349 |
+
demo.queue()
|
350 |
+
demo.launch(share=True, allowed_paths=[TMP_DIR])
|
assets/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
models/
|
assets/config/ddpm/v1.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
linear_start: 0.00085
|
2 |
+
linear_end: 0.0120
|
3 |
+
num_timesteps_cond: 1
|
4 |
+
log_every_t: 200
|
5 |
+
timesteps: 1000
|
6 |
+
first_stage_key: "jpg"
|
7 |
+
cond_stage_key: "txt"
|
8 |
+
image_size: 64
|
9 |
+
channels: 4
|
10 |
+
cond_stage_trainable: false
|
11 |
+
conditioning_key: crossattn
|
12 |
+
monitor: val/loss_simple_ema
|
13 |
+
scale_factor: 0.18215
|
14 |
+
use_ema: False # we set this to false because this is an inference only config
|
assets/config/ddpm/v2-upsample.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
parameterization: "v"
|
2 |
+
low_scale_key: "lr"
|
3 |
+
linear_start: 0.0001
|
4 |
+
linear_end: 0.02
|
5 |
+
num_timesteps_cond: 1
|
6 |
+
log_every_t: 200
|
7 |
+
timesteps: 1000
|
8 |
+
first_stage_key: "jpg"
|
9 |
+
cond_stage_key: "txt"
|
10 |
+
image_size: 128
|
11 |
+
channels: 4
|
12 |
+
cond_stage_trainable: false
|
13 |
+
conditioning_key: "hybrid-adm"
|
14 |
+
monitor: val/loss_simple_ema
|
15 |
+
scale_factor: 0.08333
|
16 |
+
use_ema: False
|
17 |
+
|
18 |
+
low_scale_config:
|
19 |
+
target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
|
20 |
+
params:
|
21 |
+
noise_schedule_config: # image space
|
22 |
+
linear_start: 0.0001
|
23 |
+
linear_end: 0.02
|
24 |
+
max_noise_level: 350
|
assets/config/encoders/clip.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder
|
assets/config/encoders/openclip.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder
|
2 |
+
__init__:
|
3 |
+
freeze: True
|
4 |
+
layer: "penultimate"
|
assets/config/unet/inpainting/v1.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__class__: smplfusion.models.unet.UNetModel
|
2 |
+
__init__:
|
3 |
+
image_size: 32 # unused
|
4 |
+
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
5 |
+
out_channels: 4
|
6 |
+
model_channels: 320
|
7 |
+
attention_resolutions: [ 4, 2, 1 ]
|
8 |
+
num_res_blocks: 2
|
9 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
10 |
+
num_heads: 8
|
11 |
+
use_spatial_transformer: True
|
12 |
+
transformer_depth: 1
|
13 |
+
context_dim: 768
|
14 |
+
use_checkpoint: False
|
15 |
+
legacy: False
|
assets/config/unet/inpainting/v2.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__class__: smplfusion.models.unet.UNetModel
|
2 |
+
__init__:
|
3 |
+
use_checkpoint: False
|
4 |
+
image_size: 32 # unused
|
5 |
+
in_channels: 9
|
6 |
+
out_channels: 4
|
7 |
+
model_channels: 320
|
8 |
+
attention_resolutions: [ 4, 2, 1 ]
|
9 |
+
num_res_blocks: 2
|
10 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
11 |
+
num_head_channels: 64 # need to fix for flash-attn
|
12 |
+
use_spatial_transformer: True
|
13 |
+
use_linear_in_transformer: True
|
14 |
+
transformer_depth: 1
|
15 |
+
context_dim: 1024
|
16 |
+
legacy: False
|
assets/config/unet/upsample/v2.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__class__: smplfusion.models.unet.UNetModel
|
2 |
+
__init__:
|
3 |
+
use_checkpoint: False
|
4 |
+
num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
|
5 |
+
image_size: 128
|
6 |
+
in_channels: 7
|
7 |
+
out_channels: 4
|
8 |
+
model_channels: 256
|
9 |
+
attention_resolutions: [ 2,4,8]
|
10 |
+
num_res_blocks: 2
|
11 |
+
channel_mult: [ 1, 2, 2, 4]
|
12 |
+
disable_self_attentions: [True, True, True, False]
|
13 |
+
disable_middle_self_attn: False
|
14 |
+
num_heads: 8
|
15 |
+
use_spatial_transformer: True
|
16 |
+
transformer_depth: 1
|
17 |
+
context_dim: 1024
|
18 |
+
legacy: False
|
19 |
+
use_linear_in_transformer: True
|
assets/config/vae-upsample.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__class__: smplfusion.models.vae.AutoencoderKL
|
2 |
+
__init__:
|
3 |
+
embed_dim: 4
|
4 |
+
ddconfig:
|
5 |
+
double_z: True
|
6 |
+
z_channels: 4
|
7 |
+
resolution: 256
|
8 |
+
in_channels: 3
|
9 |
+
out_ch: 3
|
10 |
+
ch: 128
|
11 |
+
ch_mult: [ 1,2,4 ]
|
12 |
+
num_res_blocks: 2
|
13 |
+
attn_resolutions: [ ]
|
14 |
+
dropout: 0.0
|
15 |
+
lossconfig:
|
16 |
+
target: torch.nn.Identity
|
assets/config/vae.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__class__: smplfusion.models.vae.AutoencoderKL
|
2 |
+
__init__:
|
3 |
+
embed_dim: 4
|
4 |
+
monitor: val/rec_loss
|
5 |
+
ddconfig:
|
6 |
+
double_z: true
|
7 |
+
z_channels: 4
|
8 |
+
resolution: 256
|
9 |
+
in_channels: 3
|
10 |
+
out_ch: 3
|
11 |
+
ch: 128
|
12 |
+
ch_mult: [1,2,4,4]
|
13 |
+
num_res_blocks: 2
|
14 |
+
attn_resolutions: []
|
15 |
+
dropout: 0.0
|
16 |
+
lossconfig:
|
17 |
+
target: torch.nn.Identity
|
assets/examples/images/a19.jpg
ADDED
Git LFS Details
|
assets/examples/images/a2.jpg
ADDED
Git LFS Details
|
assets/examples/images/a4.jpg
ADDED
Git LFS Details
|
assets/examples/images/a40.jpg
ADDED
Git LFS Details
|
assets/examples/images/a46.jpg
ADDED
Git LFS Details
|
assets/examples/images/a51.jpg
ADDED
Git LFS Details
|
assets/examples/images/a54.jpg
ADDED
Git LFS Details
|
assets/examples/images/a65.jpg
ADDED
Git LFS Details
|
assets/examples/masked/a19.png
ADDED
Git LFS Details
|
assets/examples/masked/a2.png
ADDED
Git LFS Details
|
assets/examples/masked/a4.png
ADDED
Git LFS Details
|
assets/examples/masked/a40.png
ADDED
Git LFS Details
|
assets/examples/masked/a46.png
ADDED
Git LFS Details
|
assets/examples/masked/a51.png
ADDED
Git LFS Details
|
assets/examples/masked/a54.png
ADDED
Git LFS Details
|
assets/examples/masked/a65.png
ADDED
Git LFS Details
|
assets/examples/sbs/a19.png
ADDED
Git LFS Details
|
assets/examples/sbs/a2.png
ADDED
Git LFS Details
|
assets/examples/sbs/a4.png
ADDED
Git LFS Details
|
assets/examples/sbs/a40.png
ADDED
Git LFS Details
|
assets/examples/sbs/a46.png
ADDED
Git LFS Details
|
assets/examples/sbs/a51.png
ADDED
Git LFS Details
|
assets/examples/sbs/a54.png
ADDED
Git LFS Details
|
assets/examples/sbs/a65.png
ADDED
Git LFS Details
|
lib/__init__.py
ADDED
File without changes
|
lib/methods/__init__.py
ADDED
File without changes
|
lib/methods/rasg.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from lib.utils.iimage import IImage
|
3 |
+
from pytorch_lightning import seed_everything
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from lib.smplfusion import share, router, attentionpatch, transformerpatch
|
7 |
+
from lib.smplfusion.patches.attentionpatch import painta
|
8 |
+
from lib.utils import tokenize, scores
|
9 |
+
|
10 |
+
verbose = False
|
11 |
+
|
12 |
+
|
13 |
+
def init_painta(token_idx):
|
14 |
+
# Initialize painta
|
15 |
+
router.attention_forward = attentionpatch.painta.forward
|
16 |
+
router.basic_transformer_forward = transformerpatch.painta.forward
|
17 |
+
painta.painta_on = True
|
18 |
+
painta.painta_res = [16, 32]
|
19 |
+
painta.token_idx = token_idx
|
20 |
+
|
21 |
+
def init_guidance():
|
22 |
+
# Setup model for guidance only!
|
23 |
+
router.attention_forward = attentionpatch.default.forward_and_save
|
24 |
+
router.basic_transformer_forward = transformerpatch.default.forward
|
25 |
+
|
26 |
+
def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, positive_prompt, dt, guidance_scale):
|
27 |
+
# Text condition
|
28 |
+
prompt = prefix.format(prompt)
|
29 |
+
context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt])
|
30 |
+
token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index('<end_of_text>')))
|
31 |
+
token_idx += [tokenize(prompt + positive_prompt).index('<end_of_text>')]
|
32 |
+
|
33 |
+
# Initialize painta
|
34 |
+
if 'painta' in method: init_painta(token_idx)
|
35 |
+
else: init_guidance()
|
36 |
+
|
37 |
+
# Image condition
|
38 |
+
unet_condition = ddim.get_inpainting_condition(image, mask)
|
39 |
+
share.set_mask(mask)
|
40 |
+
|
41 |
+
# Starting latent
|
42 |
+
seed_everything(seed)
|
43 |
+
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
|
44 |
+
|
45 |
+
# Setup unet for guidance
|
46 |
+
ddim.unet.requires_grad_(True)
|
47 |
+
|
48 |
+
pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt)
|
49 |
+
|
50 |
+
for timestep in share.DDIMIterator(pbar):
|
51 |
+
if 'painta' in method and share.timestep <= 500: init_guidance()
|
52 |
+
|
53 |
+
zt = zt.detach()
|
54 |
+
zt.requires_grad = True
|
55 |
+
|
56 |
+
# Reset storage
|
57 |
+
share._crossattn_similarity_res16 = []
|
58 |
+
|
59 |
+
# Run the model
|
60 |
+
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
61 |
+
eps_uncond, eps = ddim.unet(
|
62 |
+
torch.cat([_zt, _zt]),
|
63 |
+
timesteps = torch.tensor([timestep, timestep]).cuda(),
|
64 |
+
context = context
|
65 |
+
).detach().chunk(2)
|
66 |
+
|
67 |
+
# Unconditional guidance
|
68 |
+
eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
69 |
+
z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
|
70 |
+
|
71 |
+
# Gradient Computation
|
72 |
+
score = scores.bce(share._crossattn_similarity_res16, share.mask16, token_idx = token_idx)
|
73 |
+
score.backward()
|
74 |
+
grad = zt.grad.detach()
|
75 |
+
ddim.unet.zero_grad() # Cleanup already
|
76 |
+
|
77 |
+
# DDIM Step
|
78 |
+
with torch.no_grad():
|
79 |
+
sigma = share.schedule.sigma(share.timestep, dt)
|
80 |
+
# Standartization
|
81 |
+
grad -= grad.mean()
|
82 |
+
grad /= grad.std()
|
83 |
+
|
84 |
+
zt = share.schedule.sqrt_alphas[share.timestep - dt] * z0 + torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - sigma ** 2) * eps + eta * sigma * grad
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
|
88 |
+
return output_image
|
lib/methods/sd.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pytorch_lightning import seed_everything
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from lib.utils.iimage import IImage
|
6 |
+
from lib.smplfusion import share, router, attentionpatch, transformerpatch
|
7 |
+
from lib.smplfusion.patches.attentionpatch import painta
|
8 |
+
from lib.utils import tokenize
|
9 |
+
|
10 |
+
verbose = False
|
11 |
+
|
12 |
+
|
13 |
+
def init_painta(token_idx):
|
14 |
+
# Initialize painta
|
15 |
+
router.attention_forward = attentionpatch.painta.forward
|
16 |
+
router.basic_transformer_forward = transformerpatch.painta.forward
|
17 |
+
painta.painta_on = True
|
18 |
+
painta.painta_res = [16, 32]
|
19 |
+
painta.token_idx = token_idx
|
20 |
+
|
21 |
+
def run(
|
22 |
+
ddim,
|
23 |
+
method,
|
24 |
+
prompt,
|
25 |
+
image,
|
26 |
+
mask,
|
27 |
+
seed,
|
28 |
+
eta,
|
29 |
+
prefix,
|
30 |
+
negative_prompt,
|
31 |
+
positive_prompt,
|
32 |
+
dt,
|
33 |
+
guidance_scale
|
34 |
+
):
|
35 |
+
# Text condition
|
36 |
+
context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt])
|
37 |
+
token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index('<end_of_text>')))
|
38 |
+
token_idx += [tokenize(prompt + positive_prompt).index('<end_of_text>')]
|
39 |
+
|
40 |
+
# Setup painta if needed
|
41 |
+
if 'painta' in method: init_painta(token_idx)
|
42 |
+
else: router.reset()
|
43 |
+
|
44 |
+
# Image condition
|
45 |
+
unet_condition = ddim.get_inpainting_condition(image, mask)
|
46 |
+
share.set_mask(mask)
|
47 |
+
|
48 |
+
# Starting latent
|
49 |
+
seed_everything(seed)
|
50 |
+
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
|
51 |
+
|
52 |
+
# Turn off gradients
|
53 |
+
ddim.unet.requires_grad_(False)
|
54 |
+
|
55 |
+
pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt)
|
56 |
+
|
57 |
+
for timestep in share.DDIMIterator(pbar):
|
58 |
+
if share.timestep <= 500: router.reset()
|
59 |
+
|
60 |
+
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
61 |
+
eps_uncond, eps = ddim.unet(
|
62 |
+
torch.cat([_zt, _zt]),
|
63 |
+
timesteps = torch.tensor([timestep, timestep]).cuda(),
|
64 |
+
context = context
|
65 |
+
).chunk(2)
|
66 |
+
|
67 |
+
eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
68 |
+
z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
|
69 |
+
zt = share.schedule.sqrt_alphas[timestep - dt] * z0 + share.schedule.sqrt_one_minus_alphas[timestep - dt] * eps
|
70 |
+
|
71 |
+
with torch.no_grad():
|
72 |
+
output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
|
73 |
+
|
74 |
+
return output_image
|
lib/methods/sr.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import partial
|
3 |
+
from glob import glob
|
4 |
+
from pathlib import Path as PythonPath
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import torchvision.transforms.functional as TvF
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import numpy as np
|
11 |
+
from inspect import isfunction
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from lib import smplfusion
|
15 |
+
from lib.smplfusion import share, router, attentionpatch, transformerpatch
|
16 |
+
from lib.utils.iimage import IImage
|
17 |
+
from lib.utils import poisson_blend
|
18 |
+
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
|
19 |
+
|
20 |
+
|
21 |
+
def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
|
22 |
+
lr_mask = hr_mask.resize(512)
|
23 |
+
|
24 |
+
x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0])
|
25 |
+
x_min = max(x_min - 1, 0)
|
26 |
+
y_min = max(y_min - 1, 0)
|
27 |
+
x_max = x_min + rect_w + 1
|
28 |
+
y_max = y_min + rect_h + 1
|
29 |
+
|
30 |
+
input_box = np.array([x_min, y_min, x_max, y_max])
|
31 |
+
|
32 |
+
sam_predictor.set_image(hr_image.resize(512).data[0])
|
33 |
+
masks, _, _ = sam_predictor.predict(
|
34 |
+
point_coords=None,
|
35 |
+
point_labels=None,
|
36 |
+
box=input_box[None, :],
|
37 |
+
multimask_output=True,
|
38 |
+
)
|
39 |
+
dilation_kernel = np.ones((13, 13))
|
40 |
+
original_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
|
41 |
+
original_object_mask = cv2.dilate(original_object_mask, dilation_kernel)
|
42 |
+
|
43 |
+
sam_predictor.set_image(lr_image.resize(512).data[0])
|
44 |
+
masks, _, _ = sam_predictor.predict(
|
45 |
+
point_coords=None,
|
46 |
+
point_labels=None,
|
47 |
+
box=input_box[None, :],
|
48 |
+
multimask_output=True,
|
49 |
+
)
|
50 |
+
dilation_kernel = np.ones((3, 3))
|
51 |
+
inpainted_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
|
52 |
+
inpainted_object_mask = cv2.dilate(inpainted_object_mask, dilation_kernel)
|
53 |
+
|
54 |
+
lr_mask_masking = ((original_object_mask + inpainted_object_mask ) > 0).astype(np.uint8)
|
55 |
+
new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis]
|
56 |
+
new_mask = IImage(new_mask).resize(2048, resample = Image.BICUBIC)
|
57 |
+
return new_mask
|
58 |
+
|
59 |
+
|
60 |
+
def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
|
61 |
+
blend_output = True, blend_trick = True, no_superres = False,
|
62 |
+
dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False, dtype=torch.bfloat16):
|
63 |
+
torch.manual_seed(seed)
|
64 |
+
|
65 |
+
router.attention_forward = attentionpatch.default.forward_xformers
|
66 |
+
router.basic_transformer_forward = transformerpatch.default.forward
|
67 |
+
|
68 |
+
if use_sam_mask:
|
69 |
+
with torch.no_grad():
|
70 |
+
hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)
|
71 |
+
|
72 |
+
orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]
|
73 |
+
hr_image = hr_image.padx(256, padding_mode='reflect')
|
74 |
+
hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
|
75 |
+
hr_mask_orig = hr_mask
|
76 |
+
lr_image = lr_image.padx(64, padding_mode='reflect')
|
77 |
+
lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).cuda()
|
78 |
+
lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
|
79 |
+
|
80 |
+
if no_superres:
|
81 |
+
output_tensor = lr_image.resize((hr_image.torch().shape[2], hr_image.torch().shape[3]), resample = Image.BICUBIC).torch().cuda()
|
82 |
+
output_tensor = (255*((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8)
|
83 |
+
output_tensor = poisson_blend(
|
84 |
+
orig_img=hr_image.data[0][:orig_h, :orig_w, :],
|
85 |
+
fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
|
86 |
+
mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
|
87 |
+
)
|
88 |
+
return IImage(output_tensor[:orig_h, :orig_w, :])
|
89 |
+
|
90 |
+
# encode hr image
|
91 |
+
with torch.no_grad():
|
92 |
+
hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype)).mean * ddim.config.scale_factor
|
93 |
+
|
94 |
+
assert hr_z0.shape[2] == lr_image.torch().shape[2]
|
95 |
+
assert hr_z0.shape[3] == lr_image.torch().shape[3]
|
96 |
+
|
97 |
+
unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype)
|
98 |
+
zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype)
|
99 |
+
|
100 |
+
with torch.no_grad():
|
101 |
+
context = ddim.encoder.encode([negative_prompt, prompt])
|
102 |
+
|
103 |
+
noise_level = torch.Tensor(1 * [noise_level]).to('cuda').long()
|
104 |
+
unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
|
105 |
+
|
106 |
+
with torch.autocast('cuda'), torch.no_grad():
|
107 |
+
zt = zT
|
108 |
+
for index,t in enumerate(range(999, 0, -dt)):
|
109 |
+
|
110 |
+
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
111 |
+
|
112 |
+
eps_uncond, eps = ddim.unet(
|
113 |
+
torch.cat([_zt, _zt]).to(dtype),
|
114 |
+
timesteps = torch.tensor([t, t]).cuda(),
|
115 |
+
context = context,
|
116 |
+
y=torch.cat([noise_level]*2)
|
117 |
+
).chunk(2)
|
118 |
+
|
119 |
+
ts = torch.full((zt.shape[0],), t, device='cuda', dtype=torch.long)
|
120 |
+
model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
121 |
+
eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
|
122 |
+
z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
|
123 |
+
|
124 |
+
if blend_trick:
|
125 |
+
z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)
|
126 |
+
|
127 |
+
zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps
|
128 |
+
|
129 |
+
with torch.no_grad():
|
130 |
+
output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)
|
131 |
+
|
132 |
+
if blend_output:
|
133 |
+
output_tensor = (255*((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)
|
134 |
+
output_tensor = poisson_blend(
|
135 |
+
orig_img=hr_image.data[0][:orig_h, :orig_w, :],
|
136 |
+
fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
|
137 |
+
mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
|
138 |
+
)
|
139 |
+
return IImage(output_tensor[:orig_h, :orig_w, :])
|
140 |
+
else:
|
141 |
+
return IImage(output_tensor[:, :, :orig_h, :orig_w])
|
lib/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import sd2_inp, ds_inp, sd15_inp, sd2_sr, sam
|
lib/models/common.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import requests
|
3 |
+
from pathlib import Path
|
4 |
+
from os.path import dirname
|
5 |
+
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
PROJECT_DIR = dirname(dirname(dirname(__file__)))
|
11 |
+
CONFIG_FOLDER = f'{PROJECT_DIR}/assets/config'
|
12 |
+
MODEL_FOLDER = f'{PROJECT_DIR}/assets/models'
|
13 |
+
|
14 |
+
|
15 |
+
def download_file(url, save_path, chunk_size=1024):
|
16 |
+
try:
|
17 |
+
save_path = Path(save_path)
|
18 |
+
if save_path.exists():
|
19 |
+
print(f'{save_path.name} exists')
|
20 |
+
return
|
21 |
+
save_path.parent.mkdir(exist_ok=True, parents=True)
|
22 |
+
resp = requests.get(url, stream=True)
|
23 |
+
total = int(resp.headers.get('content-length', 0))
|
24 |
+
with open(save_path, 'wb') as file, tqdm(
|
25 |
+
desc=save_path.name,
|
26 |
+
total=total,
|
27 |
+
unit='iB',
|
28 |
+
unit_scale=True,
|
29 |
+
unit_divisor=1024,
|
30 |
+
) as bar:
|
31 |
+
for data in resp.iter_content(chunk_size=chunk_size):
|
32 |
+
size = file.write(data)
|
33 |
+
bar.update(size)
|
34 |
+
print(f'{save_path.name} download finished')
|
35 |
+
except Exception as e:
|
36 |
+
raise Exception(f"Download failed: {e}")
|
37 |
+
|
38 |
+
|
39 |
+
def get_obj_from_str(string):
|
40 |
+
module, cls = string.rsplit(".", 1)
|
41 |
+
try:
|
42 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
43 |
+
except:
|
44 |
+
return getattr(importlib.import_module('lib.' + module, package=None), cls)
|
45 |
+
|
46 |
+
|
47 |
+
def load_obj(path):
|
48 |
+
objyaml = OmegaConf.load(path)
|
49 |
+
return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
|
lib/models/ds_inp.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from omegaconf import OmegaConf
|
3 |
+
import torch
|
4 |
+
import safetensors
|
5 |
+
import safetensors.torch
|
6 |
+
|
7 |
+
from lib.smplfusion import DDIM, share, scheduler
|
8 |
+
from .common import *
|
9 |
+
|
10 |
+
|
11 |
+
MODEL_PATH = f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors'
|
12 |
+
DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'
|
13 |
+
|
14 |
+
# pre-download
|
15 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
16 |
+
|
17 |
+
|
18 |
+
def load_model():
|
19 |
+
print ("Loading model: Dreamshaper Inpainting V8")
|
20 |
+
|
21 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
22 |
+
|
23 |
+
state_dict = safetensors.torch.load_file(MODEL_PATH)
|
24 |
+
|
25 |
+
config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
|
26 |
+
unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
|
27 |
+
vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
|
28 |
+
encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
|
29 |
+
|
30 |
+
extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
|
31 |
+
unet_state = extract(state_dict, 'model.diffusion_model')
|
32 |
+
encoder_state = extract(state_dict, 'cond_stage_model')
|
33 |
+
vae_state = extract(state_dict, 'first_stage_model')
|
34 |
+
|
35 |
+
unet.load_state_dict(unet_state)
|
36 |
+
encoder.load_state_dict(encoder_state)
|
37 |
+
vae.load_state_dict(vae_state)
|
38 |
+
|
39 |
+
unet = unet.requires_grad_(False)
|
40 |
+
encoder = encoder.requires_grad_(False)
|
41 |
+
vae = vae.requires_grad_(False)
|
42 |
+
|
43 |
+
ddim = DDIM(config, vae, encoder, unet)
|
44 |
+
share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
|
45 |
+
|
46 |
+
return ddim
|
lib/models/sam.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from segment_anything import sam_model_registry, SamPredictor
|
2 |
+
from .common import *
|
3 |
+
|
4 |
+
MODEL_PATH = f'{MODEL_FOLDER}/sam/sam_vit_h_4b8939.pth'
|
5 |
+
DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
|
6 |
+
|
7 |
+
# pre-download
|
8 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
9 |
+
|
10 |
+
|
11 |
+
def load_model():
|
12 |
+
print ("Loading model: SAM")
|
13 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
14 |
+
model_type = "vit_h"
|
15 |
+
device = "cuda"
|
16 |
+
sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
|
17 |
+
sam.to(device=device)
|
18 |
+
sam_predictor = SamPredictor(sam)
|
19 |
+
print ("SAM loaded")
|
20 |
+
return sam_predictor
|
lib/models/sd15_inp.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from omegaconf import OmegaConf
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from lib.smplfusion import DDIM, share, scheduler
|
5 |
+
from .common import *
|
6 |
+
|
7 |
+
|
8 |
+
DOWNLOAD_URL = 'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt?download=true'
|
9 |
+
MODEL_PATH = f'{MODEL_FOLDER}/sd-1-5-inpainting/sd-v1-5-inpainting.ckpt'
|
10 |
+
|
11 |
+
# pre-download
|
12 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
13 |
+
|
14 |
+
|
15 |
+
def load_model():
|
16 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
17 |
+
|
18 |
+
state_dict = torch.load(MODEL_PATH)['state_dict']
|
19 |
+
|
20 |
+
config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
|
21 |
+
|
22 |
+
print ("Loading model: Stable-Inpainting 1.5")
|
23 |
+
|
24 |
+
unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
|
25 |
+
vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
|
26 |
+
encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
|
27 |
+
|
28 |
+
extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
|
29 |
+
unet_state = extract(state_dict, 'model.diffusion_model')
|
30 |
+
encoder_state = extract(state_dict, 'cond_stage_model')
|
31 |
+
vae_state = extract(state_dict, 'first_stage_model')
|
32 |
+
|
33 |
+
unet.load_state_dict(unet_state)
|
34 |
+
encoder.load_state_dict(encoder_state)
|
35 |
+
vae.load_state_dict(vae_state)
|
36 |
+
|
37 |
+
unet = unet.requires_grad_(False)
|
38 |
+
encoder = encoder.requires_grad_(False)
|
39 |
+
vae = vae.requires_grad_(False)
|
40 |
+
|
41 |
+
ddim = DDIM(config, vae, encoder, unet)
|
42 |
+
share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
|
43 |
+
|
44 |
+
return ddim
|
lib/models/sd2_inp.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import safetensors
|
2 |
+
import safetensors.torch
|
3 |
+
import torch
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
|
6 |
+
from lib.smplfusion import DDIM, share, scheduler
|
7 |
+
from .common import *
|
8 |
+
|
9 |
+
MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-inpainting/512-inpainting-ema.safetensors'
|
10 |
+
DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true'
|
11 |
+
|
12 |
+
# pre-download
|
13 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
14 |
+
|
15 |
+
|
16 |
+
def load_model():
|
17 |
+
print ("Loading model: Stable-Inpainting 2.0")
|
18 |
+
|
19 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
20 |
+
|
21 |
+
state_dict = safetensors.torch.load_file(MODEL_PATH)
|
22 |
+
|
23 |
+
config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
|
24 |
+
|
25 |
+
unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda()
|
26 |
+
vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
|
27 |
+
encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()
|
28 |
+
ddim = DDIM(config, vae, encoder, unet)
|
29 |
+
|
30 |
+
extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
|
31 |
+
unet_state = extract(state_dict, 'model.diffusion_model')
|
32 |
+
encoder_state = extract(state_dict, 'cond_stage_model')
|
33 |
+
vae_state = extract(state_dict, 'first_stage_model')
|
34 |
+
|
35 |
+
unet.load_state_dict(unet_state)
|
36 |
+
encoder.load_state_dict(encoder_state)
|
37 |
+
vae.load_state_dict(vae_state)
|
38 |
+
|
39 |
+
unet = unet.requires_grad_(False)
|
40 |
+
encoder = encoder.requires_grad_(False)
|
41 |
+
vae = vae.requires_grad_(False)
|
42 |
+
|
43 |
+
ddim = DDIM(config, vae, encoder, unet)
|
44 |
+
share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
|
45 |
+
|
46 |
+
print('Stable-Inpainting 2.0 loaded')
|
47 |
+
return ddim
|
lib/models/sd2_sr.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import safetensors
|
7 |
+
import safetensors.torch
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from inspect import isfunction
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
|
13 |
+
from lib.smplfusion import DDIM, share, scheduler
|
14 |
+
from .common import *
|
15 |
+
|
16 |
+
|
17 |
+
DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.safetensors?download=true'
|
18 |
+
MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-upsample/x4-upscaler-ema.safetensors'
|
19 |
+
|
20 |
+
# pre-download
|
21 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
22 |
+
|
23 |
+
|
24 |
+
def exists(x):
|
25 |
+
return x is not None
|
26 |
+
|
27 |
+
|
28 |
+
def default(val, d):
|
29 |
+
if exists(val):
|
30 |
+
return val
|
31 |
+
return d() if isfunction(d) else d
|
32 |
+
|
33 |
+
|
34 |
+
def extract_into_tensor(a, t, x_shape):
|
35 |
+
b, *_ = t.shape
|
36 |
+
out = a.gather(-1, t)
|
37 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
38 |
+
|
39 |
+
|
40 |
+
def predict_eps_from_z_and_v(schedule, x_t, t, v):
|
41 |
+
return (
|
42 |
+
extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape) * v +
|
43 |
+
extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape) * x_t
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def predict_start_from_z_and_v(schedule, x_t, t, v):
|
48 |
+
return (
|
49 |
+
extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape) * x_t -
|
50 |
+
extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape) * v
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
55 |
+
if schedule == "linear":
|
56 |
+
betas = (
|
57 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
58 |
+
)
|
59 |
+
|
60 |
+
elif schedule == "cosine":
|
61 |
+
timesteps = (
|
62 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
63 |
+
)
|
64 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
65 |
+
alphas = torch.cos(alphas).pow(2)
|
66 |
+
alphas = alphas / alphas[0]
|
67 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
68 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
69 |
+
|
70 |
+
elif schedule == "sqrt_linear":
|
71 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
72 |
+
elif schedule == "sqrt":
|
73 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
74 |
+
else:
|
75 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
76 |
+
return betas.numpy()
|
77 |
+
|
78 |
+
|
79 |
+
def disabled_train(self, mode=True):
|
80 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
81 |
+
does not change anymore."""
|
82 |
+
return self
|
83 |
+
|
84 |
+
|
85 |
+
class AbstractLowScaleModel(nn.Module):
|
86 |
+
# for concatenating a downsampled image to the latent representation
|
87 |
+
def __init__(self, noise_schedule_config=None):
|
88 |
+
super(AbstractLowScaleModel, self).__init__()
|
89 |
+
if noise_schedule_config is not None:
|
90 |
+
self.register_schedule(**noise_schedule_config)
|
91 |
+
|
92 |
+
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
93 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
94 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
95 |
+
cosine_s=cosine_s)
|
96 |
+
alphas = 1. - betas
|
97 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
98 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
99 |
+
|
100 |
+
timesteps, = betas.shape
|
101 |
+
self.num_timesteps = int(timesteps)
|
102 |
+
self.linear_start = linear_start
|
103 |
+
self.linear_end = linear_end
|
104 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
105 |
+
|
106 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
107 |
+
|
108 |
+
self.register_buffer('betas', to_torch(betas))
|
109 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
110 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
111 |
+
|
112 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
113 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
114 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
115 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
116 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
117 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
118 |
+
|
119 |
+
def q_sample(self, x_start, t, noise=None):
|
120 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
121 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
122 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
return x, None
|
126 |
+
|
127 |
+
def decode(self, x):
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
132 |
+
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
|
133 |
+
super().__init__(noise_schedule_config=noise_schedule_config)
|
134 |
+
self.max_noise_level = max_noise_level
|
135 |
+
|
136 |
+
def forward(self, x, noise_level=None):
|
137 |
+
if noise_level is None:
|
138 |
+
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
139 |
+
else:
|
140 |
+
assert isinstance(noise_level, torch.Tensor)
|
141 |
+
z = self.q_sample(x, noise_level)
|
142 |
+
return z, noise_level
|
143 |
+
|
144 |
+
|
145 |
+
def get_obj_from_str(string):
|
146 |
+
module, cls = string.rsplit(".", 1)
|
147 |
+
try:
|
148 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
149 |
+
except:
|
150 |
+
return getattr(importlib.import_module('lib.' + module, package=None), cls)
|
151 |
+
def load_obj(path):
|
152 |
+
objyaml = OmegaConf.load(path)
|
153 |
+
return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
|
154 |
+
|
155 |
+
|
156 |
+
def load_model(dtype=torch.bfloat16):
|
157 |
+
print ("Loading model: SD2 superresolution...")
|
158 |
+
|
159 |
+
download_file(DOWNLOAD_URL, MODEL_PATH)
|
160 |
+
|
161 |
+
state_dict = safetensors.torch.load_file(MODEL_PATH)
|
162 |
+
|
163 |
+
config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml')
|
164 |
+
|
165 |
+
unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().cuda()
|
166 |
+
vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().cuda()
|
167 |
+
encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()
|
168 |
+
ddim = DDIM(config, vae, encoder, unet)
|
169 |
+
|
170 |
+
extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
|
171 |
+
unet_state = extract(state_dict, 'model.diffusion_model')
|
172 |
+
encoder_state = extract(state_dict, 'cond_stage_model')
|
173 |
+
vae_state = extract(state_dict, 'first_stage_model')
|
174 |
+
|
175 |
+
unet.load_state_dict(unet_state)
|
176 |
+
encoder.load_state_dict(encoder_state)
|
177 |
+
vae.load_state_dict(vae_state)
|
178 |
+
|
179 |
+
unet = unet.requires_grad_(False)
|
180 |
+
encoder = encoder.requires_grad_(False)
|
181 |
+
vae = vae.requires_grad_(False)
|
182 |
+
|
183 |
+
unet.to(dtype)
|
184 |
+
vae.to(dtype)
|
185 |
+
encoder.to(dtype)
|
186 |
+
|
187 |
+
ddim = DDIM(config, vae, encoder, unet)
|
188 |
+
|
189 |
+
params = {
|
190 |
+
'noise_schedule_config': {
|
191 |
+
'linear_start': 0.0001,
|
192 |
+
'linear_end': 0.02
|
193 |
+
},
|
194 |
+
'max_noise_level': 350
|
195 |
+
}
|
196 |
+
|
197 |
+
low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to('cuda')
|
198 |
+
low_scale_model.train = disabled_train
|
199 |
+
for param in low_scale_model.parameters():
|
200 |
+
param.requires_grad = False
|
201 |
+
|
202 |
+
ddim.low_scale_model = low_scale_model
|
203 |
+
print('SD2 superresolution loaded')
|
204 |
+
return ddim
|