lnyan commited on
Commit
f5c3a0a
·
1 Parent(s): 09c675d

Update files

Browse files
Files changed (7) hide show
  1. app.py +1033 -427
  2. canvas.py +650 -548
  3. index.html +411 -214
  4. perlin2d.py +44 -44
  5. postprocess.py +249 -0
  6. process.py +395 -0
  7. utils.py +263 -151
app.py CHANGED
@@ -1,427 +1,1033 @@
1
- import io
2
- import base64
3
- import os
4
-
5
- import numpy as np
6
- import torch
7
- from torch import autocast
8
- from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
9
- from PIL import Image
10
- from PIL import ImageOps
11
- import gradio as gr
12
- import base64
13
- import skimage
14
- import skimage.measure
15
- from utils import *
16
-
17
- try:
18
- cuda_available = torch.cuda.is_available()
19
- except:
20
- cuda_available = False
21
- finally:
22
- if cuda_available:
23
- device = "cuda"
24
- else:
25
- device = "cpu"
26
-
27
- if device != "cuda":
28
- import contextlib
29
- autocast = contextlib.nullcontext
30
-
31
- def load_html():
32
- body, canvaspy = "", ""
33
- with open("index.html", encoding="utf8") as f:
34
- body = f.read()
35
- with open("canvas.py", encoding="utf8") as f:
36
- canvaspy = f.read()
37
- body = body.replace("- paths:\n", "")
38
- body = body.replace(" - ./canvas.py\n", "")
39
- body = body.replace("from canvas import InfCanvas", canvaspy)
40
- return body
41
-
42
-
43
- def test(x):
44
- x = load_html()
45
- return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
46
- display-capture; encrypted-media;" sandbox="allow-modals allow-forms
47
- allow-scripts allow-same-origin allow-popups
48
- allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
49
- allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
50
-
51
-
52
- DEBUG_MODE = False
53
-
54
- try:
55
- SAMPLING_MODE = Image.Resampling.LANCZOS
56
- except Exception as e:
57
- SAMPLING_MODE = Image.LANCZOS
58
-
59
- try:
60
- contain_func = ImageOps.contain
61
- except Exception as e:
62
-
63
- def contain_func(image, size, method=SAMPLING_MODE):
64
- # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
65
- im_ratio = image.width / image.height
66
- dest_ratio = size[0] / size[1]
67
- if im_ratio != dest_ratio:
68
- if im_ratio > dest_ratio:
69
- new_height = int(image.height / image.width * size[0])
70
- if new_height != size[1]:
71
- size = (size[0], new_height)
72
- else:
73
- new_width = int(image.width / image.height * size[1])
74
- if new_width != size[0]:
75
- size = (new_width, size[1])
76
- return image.resize(size, resample=method)
77
-
78
-
79
- PAINT_SELECTION = "✥"
80
- IMAGE_SELECTION = "🖼️"
81
- BRUSH_SELECTION = "🖌️"
82
- blocks = gr.Blocks()
83
- model = {}
84
- model["width"] = 1500
85
- model["height"] = 600
86
- model["sel_size"] = 256
87
-
88
- def get_token():
89
- token = ""
90
- token = os.environ.get("hftoken", token)
91
- return token
92
-
93
-
94
- def save_token(token):
95
- return
96
-
97
-
98
- def get_model(token=""):
99
- if "text2img" not in model:
100
- if device=="cuda":
101
- text2img = StableDiffusionPipeline.from_pretrained(
102
- "CompVis/stable-diffusion-v1-4",
103
- revision="fp16",
104
- torch_dtype=torch.float16,
105
- use_auth_token=token,
106
- ).to(device)
107
- else:
108
- text2img = StableDiffusionPipeline.from_pretrained(
109
- "CompVis/stable-diffusion-v1-4",
110
- use_auth_token=token,
111
- ).to(device)
112
- model["safety_checker"] = text2img.safety_checker
113
- inpaint = StableDiffusionInpaintPipeline(
114
- vae=text2img.vae,
115
- text_encoder=text2img.text_encoder,
116
- tokenizer=text2img.tokenizer,
117
- unet=text2img.unet,
118
- scheduler=text2img.scheduler,
119
- safety_checker=text2img.safety_checker,
120
- feature_extractor=text2img.feature_extractor,
121
- ).to(device)
122
- save_token(token)
123
- try:
124
- total_memory = torch.cuda.get_device_properties(0).total_memory // (
125
- 1024 ** 3
126
- )
127
- if total_memory <= 5:
128
- inpaint.enable_attention_slicing()
129
- except:
130
- pass
131
- model["text2img"] = text2img
132
- model["inpaint"] = inpaint
133
- return model["text2img"], model["inpaint"]
134
-
135
-
136
- def run_outpaint(
137
- sel_buffer_str,
138
- prompt_text,
139
- strength,
140
- guidance,
141
- step,
142
- resize_check,
143
- fill_mode,
144
- enable_safety,
145
- state,
146
- ):
147
- base64_str = "base64"
148
- if not cuda_available:
149
- data = base64.b64decode(str(sel_buffer_str))
150
- pil = Image.open(io.BytesIO(data))
151
- sel_buffer = np.array(pil)
152
- sel_buffer[:, :, 3]=255
153
- sel_buffer[:, :, 0]=255
154
- out_pil = Image.fromarray(sel_buffer)
155
- out_buffer = io.BytesIO()
156
- out_pil.save(out_buffer, format="PNG")
157
- out_buffer.seek(0)
158
- base64_bytes = base64.b64encode(out_buffer.read())
159
- base64_str = base64_bytes.decode("ascii")
160
- return (
161
- gr.update(label=str(state + 1), value=base64_str,),
162
- gr.update(label="Prompt"),
163
- state + 1,
164
- )
165
- if True:
166
- text2img, inpaint = get_model()
167
- if enable_safety:
168
- text2img.safety_checker = model["safety_checker"]
169
- inpaint.safety_checker = model["safety_checker"]
170
- else:
171
- text2img.safety_checker = lambda images, **kwargs: (images, False)
172
- inpaint.safety_checker = lambda images, **kwargs: (images, False)
173
- data = base64.b64decode(str(sel_buffer_str))
174
- pil = Image.open(io.BytesIO(data))
175
- # base.output.clear_output()
176
- # base.read_selection_from_buffer()
177
- sel_buffer = np.array(pil)
178
- img = sel_buffer[:, :, 0:3]
179
- mask = sel_buffer[:, :, -1]
180
- process_size = 512 if resize_check else model["sel_size"]
181
- if mask.sum() > 0:
182
- img, mask = functbl[fill_mode](img, mask)
183
- init_image = Image.fromarray(img)
184
- mask = 255 - mask
185
- mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
186
- mask = mask.repeat(8, axis=0).repeat(8, axis=1)
187
- mask_image = Image.fromarray(mask)
188
- # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
189
- with autocast("cuda"):
190
- images = inpaint(
191
- prompt=prompt_text,
192
- init_image=init_image.resize(
193
- (process_size, process_size), resample=SAMPLING_MODE
194
- ),
195
- mask_image=mask_image.resize((process_size, process_size)),
196
- strength=strength,
197
- num_inference_steps=step,
198
- guidance_scale=guidance,
199
- )["sample"]
200
- else:
201
- with autocast("cuda"):
202
- images = text2img(
203
- prompt=prompt_text, height=process_size, width=process_size,
204
- )["sample"]
205
- out = sel_buffer.copy()
206
- out[:, :, 0:3] = np.array(
207
- images[0].resize(
208
- (model["sel_size"], model["sel_size"]), resample=SAMPLING_MODE,
209
- )
210
- )
211
- out[:, :, -1] = 255
212
- out_pil = Image.fromarray(out)
213
- out_buffer = io.BytesIO()
214
- out_pil.save(out_buffer, format="PNG")
215
- out_buffer.seek(0)
216
- base64_bytes = base64.b64encode(out_buffer.read())
217
- base64_str = base64_bytes.decode("ascii")
218
- return (
219
- gr.update(label=str(state + 1), value=base64_str,),
220
- gr.update(label="Prompt"),
221
- state + 1,
222
- )
223
-
224
-
225
- def load_js(name):
226
- if name in ["export", "commit", "undo"]:
227
- return f"""
228
- function (x)
229
- {{
230
- let frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
231
- frame.postMessage(["click","{name}"], "*");
232
- return x;
233
- }}
234
- """
235
- ret = ""
236
- with open(f"./js/{name}.js", "r") as f:
237
- ret = f.read()
238
- return ret
239
-
240
-
241
- upload_button_js = load_js("upload")
242
- outpaint_button_js = load_js("outpaint")
243
- proceed_button_js = load_js("proceed")
244
- mode_js = load_js("mode")
245
- setup_button_js = load_js("setup")
246
- if not cuda_available:
247
- get_model = lambda x:x
248
- get_model(get_token())
249
-
250
- with blocks as demo:
251
- # title
252
- title = gr.Markdown(
253
- """
254
- **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
255
- """
256
- )
257
- # frame
258
- frame = gr.HTML(test(2), visible=True)
259
- # setup
260
- # with gr.Row():
261
- # token = gr.Textbox(
262
- # label="Huggingface token",
263
- # value="",
264
- # placeholder="Input your token here",
265
- # )
266
- # canvas_width = gr.Number(
267
- # label="Canvas width", value=1024, precision=0, elem_id="canvas_width"
268
- # )
269
- # canvas_height = gr.Number(
270
- # label="Canvas height", value=600, precision=0, elem_id="canvas_height"
271
- # )
272
- # selection_size = gr.Number(
273
- # label="Selection box size", value=256, precision=0, elem_id="selection_size"
274
- # )
275
- # setup_button = gr.Button("Start (may take a while)", variant="primary")
276
- with gr.Row():
277
- with gr.Column(scale=3, min_width=270):
278
- # canvas control
279
- canvas_control = gr.Radio(
280
- label="Control",
281
- choices=[PAINT_SELECTION, IMAGE_SELECTION, BRUSH_SELECTION],
282
- value=PAINT_SELECTION,
283
- elem_id="control",
284
- )
285
- with gr.Box():
286
- with gr.Group():
287
- run_button = gr.Button(value="Outpaint")
288
- export_button = gr.Button(value="Export")
289
- commit_button = gr.Button(value="✓")
290
- retry_button = gr.Button(value="")
291
- undo_button = gr.Button(value="↶")
292
- with gr.Column(scale=3, min_width=270):
293
- sd_prompt = gr.Textbox(
294
- label="Prompt", placeholder="input your prompt here", lines=4
295
- )
296
- with gr.Column(scale=2, min_width=150):
297
- with gr.Box():
298
- sd_resize = gr.Checkbox(label="Resize input to 515x512", value=True)
299
- safety_check = gr.Checkbox(label="Enable Safety Checker", value=True)
300
- sd_strength = gr.Slider(
301
- label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
302
- )
303
- with gr.Column(scale=1, min_width=150):
304
- sd_step = gr.Number(label="Step", value=50, precision=0)
305
- sd_guidance = gr.Number(label="Guidance", value=7.5)
306
- with gr.Row():
307
- with gr.Column(scale=4, min_width=600):
308
- init_mode = gr.Radio(
309
- label="Init mode",
310
- choices=[
311
- "patchmatch",
312
- "edge_pad",
313
- "cv2_ns",
314
- "cv2_telea",
315
- "gaussian",
316
- "perlin",
317
- ],
318
- value="patchmatch",
319
- type="value",
320
- )
321
-
322
- proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
323
- # sd pipeline parameters
324
- with gr.Accordion("Upload image", open=False):
325
- image_box = gr.Image(image_mode="RGBA", source="upload", type="pil")
326
- upload_button = gr.Button(
327
- "Upload"
328
- )
329
- model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
330
- model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
331
- upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
332
- model_output_state = gr.State(value=0)
333
- upload_output_state = gr.State(value=0)
334
- # canvas_state = gr.State({"width":1024,"height":600,"selection_size":384})
335
-
336
- def upload_func(image, state):
337
- pil = image.convert("RGBA")
338
- w, h = pil.size
339
- if w > model["width"] - 100 or h > model["height"] - 100:
340
- pil = contain_func(pil, (model["width"] - 100, model["height"] - 100))
341
- out_buffer = io.BytesIO()
342
- pil.save(out_buffer, format="PNG")
343
- out_buffer.seek(0)
344
- base64_bytes = base64.b64encode(out_buffer.read())
345
- base64_str = base64_bytes.decode("ascii")
346
- return (
347
- gr.update(label=str(state + 1), value=base64_str),
348
- state + 1,
349
- )
350
-
351
- upload_button.click(
352
- fn=upload_func,
353
- inputs=[image_box, upload_output_state],
354
- outputs=[upload_output, upload_output_state],
355
- _js=upload_button_js,
356
- queue=False
357
- )
358
-
359
- def setup_func(token_val, width, height, size):
360
- model["width"] = width
361
- model["height"] = height
362
- model["sel_size"] = size
363
- try:
364
- get_model(token_val)
365
- except Exception as e:
366
- return {token: gr.update(value="Invalid token!")}
367
- return {
368
- token: gr.update(visible=False),
369
- canvas_width: gr.update(visible=False),
370
- canvas_height: gr.update(visible=False),
371
- selection_size: gr.update(visible=False),
372
- setup_button: gr.update(visible=False),
373
- frame: gr.update(visible=True),
374
- upload_button: gr.update(value="Upload"),
375
- }
376
-
377
- # setup_button.click(
378
- # fn=setup_func,
379
- # inputs=[token, canvas_width, canvas_height, selection_size],
380
- # outputs=[
381
- # token,
382
- # canvas_width,
383
- # canvas_height,
384
- # selection_size,
385
- # setup_button,
386
- # frame,
387
- # upload_button,
388
- # ],
389
- # _js=setup_button_js,
390
- # )
391
- run_button.click(
392
- fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
393
- )
394
- retry_button.click(
395
- fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
396
- )
397
- proceed_button.click(
398
- fn=run_outpaint,
399
- inputs=[
400
- model_input,
401
- sd_prompt,
402
- sd_strength,
403
- sd_guidance,
404
- sd_step,
405
- sd_resize,
406
- init_mode,
407
- safety_check,
408
- model_output_state,
409
- ],
410
- outputs=[model_output, sd_prompt, model_output_state],
411
- _js=proceed_button_js,
412
- )
413
- export_button.click(
414
- fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("export")
415
- )
416
- commit_button.click(
417
- fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("commit")
418
- )
419
- undo_button.click(
420
- fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("undo")
421
- )
422
- canvas_control.change(
423
- fn=None, inputs=[canvas_control], outputs=[canvas_control], _js=mode_js,
424
- )
425
-
426
- demo.launch()
427
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import os
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import autocast
9
+ import diffusers
10
+ from diffusers.configuration_utils import FrozenDict
11
+ from diffusers import (
12
+ StableDiffusionPipeline,
13
+ StableDiffusionInpaintPipeline,
14
+ StableDiffusionImg2ImgPipeline,
15
+ StableDiffusionInpaintPipelineLegacy,
16
+ DDIMScheduler,
17
+ LMSDiscreteScheduler,
18
+ )
19
+ from PIL import Image
20
+ from PIL import ImageOps
21
+ import gradio as gr
22
+ import base64
23
+ import skimage
24
+ import skimage.measure
25
+ import yaml
26
+ import json
27
+ from enum import Enum
28
+
29
# Work relative to this file's directory so that the relative paths opened
# below (index.html, canvas.py, config.yaml, ./embeddings) resolve regardless
# of where the app is launched from.
try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except (NameError, OSError):
    # Best effort: __file__ may be undefined or the directory unreachable;
    # in that case keep the current working directory.
    pass
35
+
36
+ from utils import *
37
+
38
def _leading_int(piece):
    """Return the integer formed by the leading digits of *piece* (0 if none)."""
    digits = ""
    for ch in piece:
        if not ch.isdigit():
            break
        digits += ch
    return int(digits) if digits else 0


# Compare the version numerically: as plain strings "0.10.0" sorts below
# "0.6.0", which would wrongly reject newer diffusers releases.
assert tuple(
    _leading_int(piece) for piece in diffusers.__version__.split(".")[:3]
) >= (0, 6, 0), "Please upgrade diffusers to 0.6.0"
39
+
40
+ USE_NEW_DIFFUSERS = True
41
+ RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
42
+
43
+
44
class ModelChoice(Enum):
    """Closed set of model setups, each identified by a string value.

    NOTE(review): the values look like the labels offered in the UI's model
    selector -- confirm against the code that consumes this enum.
    """

    # Dedicated inpainting model.
    INPAINTING = "stablediffusion-inpainting"
    # Inpainting model combined with img2img from v1.5.
    INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5"
    # Plain Stable Diffusion checkpoints.
    MODEL_1_5 = "stablediffusion-v1.5"
    MODEL_1_4 = "stablediffusion-v1.4"
49
+
50
+
51
# Optional dependency: use sd_grpcserver's unified pipeline when it can be
# imported, otherwise fall back to the stock diffusers inpainting pipeline.
try:
    from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
except Exception:
    # Any import-time failure selects the fallback; unlike a bare ``except``
    # this no longer swallows KeyboardInterrupt / SystemExit.
    UnifiedPipeline = StableDiffusionInpaintPipeline
55
+
56
+ # sys.path.append("./glid_3_xl_stable")
57
+
58
+ USE_GLID = False
59
+ # try:
60
+ # from glid3xlmodel import GlidModel
61
+ # except:
62
+ # USE_GLID = False
63
+
64
# Probe CUDA; a failing probe is treated the same as "no CUDA".
try:
    cuda_available = torch.cuda.is_available()
except Exception:
    cuda_available = False

# Pick the compute device: MPS (or CPU) on macOS, CUDA when available,
# otherwise CPU.
if sys.platform == "darwin":
    # Guard the lookup: torch builds without an MPS backend have no
    # ``torch.backends.mps`` attribute, and touching it would raise.
    _mps_backend = getattr(torch.backends, "mps", None)
    if _mps_backend is not None and _mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
elif cuda_available:
    device = "cuda"
else:
    device = "cpu"

if device != "cuda":
    import contextlib

    # Off CUDA, ``with autocast("cuda"):`` becomes a no-op context manager.
    autocast = contextlib.nullcontext
80
+
81
+ with open("config.yaml", "r") as yaml_in:
82
+ yaml_object = yaml.safe_load(yaml_in)
83
+ config_json = json.dumps(yaml_object)
84
+
85
+
86
+ def load_html():
87
+ body, canvaspy = "", ""
88
+ with open("index.html", encoding="utf8") as f:
89
+ body = f.read()
90
+ with open("canvas.py", encoding="utf8") as f:
91
+ canvaspy = f.read()
92
+ body = body.replace("- paths:\n", "")
93
+ body = body.replace(" - ./canvas.py\n", "")
94
+ body = body.replace("from canvas import InfCanvas", canvaspy)
95
+ return body
96
+
97
+
98
+ def test(x):
99
+ x = load_html()
100
+ return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
101
+ display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
102
+ allow-scripts allow-same-origin allow-popups
103
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
104
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
105
+
106
+
107
+ DEBUG_MODE = False
108
+
109
+ try:
110
+ SAMPLING_MODE = Image.Resampling.LANCZOS
111
+ except Exception as e:
112
+ SAMPLING_MODE = Image.LANCZOS
113
+
114
try:
    contain_func = ImageOps.contain
except AttributeError:
    # This Pillow build has no ImageOps.contain; provide a local equivalent.

    def contain_func(image, size, method=SAMPLING_MODE):
        """Resize *image* to fit inside *size* while keeping its aspect ratio.

        Args:
            image: object exposing ``width``, ``height`` and ``resize``.
            size: ``(max_width, max_height)`` bounding box.
            method: resampling filter forwarded to ``image.resize``.

        Returns:
            The result of ``image.resize`` for the fitted size.
        """
        # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
        box_width, box_height = size[0], size[1]
        im_ratio = image.width / image.height
        dest_ratio = box_width / box_height
        if im_ratio > dest_ratio:
            # Image is relatively wider than the box: width limits, shrink height.
            new_height = int(image.height / image.width * box_width)
            if new_height != box_height:
                size = (box_width, new_height)
        elif im_ratio < dest_ratio:
            # Image is relatively taller than the box: height limits, shrink width.
            new_width = int(image.width / image.height * box_height)
            if new_width != box_width:
                size = (new_width, box_height)
        return image.resize(size, resample=method)
132
+
133
+
134
import argparse

# Command-line interface. Options are registered in a fixed order so the
# generated --help output keeps the same layout.
parser = argparse.ArgumentParser(description="stablediffusion-infinity")
parser.add_argument("--port", dest="server_port", type=int, help="listen port")
parser.add_argument("--host", dest="server_name", type=str, help="host")
for _flag, _help in (
    ("--share", "share this app?"),
    ("--debug", "debug mode"),
    ("--fp32", "using full precision"),
    ("--encrypt", "using https?"),
):
    parser.add_argument(_flag, action="store_true", help=_help)
for _option, _help in (
    ("--ssl_keyfile", "path to ssl_keyfile"),
    ("--ssl_certfile", "path to ssl_certfile"),
    ("--ssl_keyfile_password", "ssl_keyfile_password"),
):
    parser.add_argument(_option, type=str, help=_help)
parser.add_argument(
    "--auth",
    nargs=2,
    metavar=("username", "password"),
    help="use username password",
)
parser.add_argument(
    "--remote_model",
    default="",
    type=str,
    help="use a model (e.g. dreambooth fined) from huggingface hub",
)
parser.add_argument(
    "--local_model",
    default="",
    type=str,
    help="use a model stored on your PC",
)

# Run as a script: read the real command line (parse_args(None) uses
# sys.argv). Imported as a module: behave as if only --debug had been given.
args = parser.parse_args(None if __name__ == "__main__" else ["--debug"])
if args.auth is not None:
    # argparse yields a list for nargs=2; store it as a tuple.
    args.auth = tuple(args.auth)
166
+
167
+ model = {}
168
+
169
+
170
def get_token():
    """Return the access token, or ``""`` when none is configured.

    The ``hftoken`` environment variable takes precedence; otherwise the
    contents of the ``.token`` file in the current directory are used.
    Surrounding whitespace is stripped, so a trailing newline in the file
    does not end up in the token.
    """
    token = ""
    try:
        with open(".token", "r") as f:
            token = f.read()
    except FileNotFoundError:
        # No saved token yet; fall through to the environment lookup.
        pass
    token = os.environ.get("hftoken", token)
    return token.strip()
177
+
178
+
179
def save_token(token):
    """Write *token* to ``.token`` in the current directory, replacing its contents."""
    with open(".token", "w") as token_file:
        token_file.write(token)
182
+
183
+
184
def prepare_scheduler(scheduler):
    """Force ``steps_offset`` to 1 in *scheduler*'s config and return the scheduler.

    A scheduler whose config has no ``steps_offset``, or already has it set
    to 1, is returned untouched. Otherwise a patched copy of the config
    replaces the scheduler's ``_internal_dict``.
    """
    config = scheduler.config
    # A missing attribute is treated like the wanted value: nothing to patch.
    if getattr(config, "steps_offset", 1) == 1:
        return scheduler
    patched = dict(config)
    patched["steps_offset"] = 1
    scheduler._internal_dict = FrozenDict(patched)
    return scheduler
190
+
191
+
192
def my_resize(width, height):
    """Return the ``(width, height)`` to process a selection at.

    Sizes with both sides >= 512, or with a longer side >= 608, are returned
    unchanged. A square below 512 becomes 512x512. Anything else is scaled
    by a factor chosen from the shorter side, and each result is rounded
    down to a multiple of 8.
    """
    if width >= 512 and height >= 512:
        return width, height
    if width == height:
        return 512, 512
    short_side, long_side = sorted((width, height))
    if long_side >= 608:
        return width, height
    # (exclusive upper bound on the shorter side, scale factor); first match wins.
    scale_steps = (
        (290, 2),
        (330, 1.75),
        (384, 1.375),
        (400, 1.25),
        (450, 1.125),
    )
    factor = next(
        (scale for bound, scale in scale_steps if short_side < bound), 1
    )
    new_width = int(factor * width) // 8 * 8
    new_height = int(factor * height) // 8 * 8
    return new_width, new_height
213
+
214
+
215
def load_learned_embed_in_clip(
    learned_embeds_path, text_encoder, tokenizer, token=None
):
    """Register a learned embedding as a new token of the text encoder.

    Args:
        learned_embeds_path: path of a ``torch.save``-d mapping whose first
            key is the trained token and whose value is its embedding.
        text_encoder: model exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings(n)``.
        tokenizer: tokenizer exposing ``add_tokens``, ``__len__`` and
            ``convert_tokens_to_ids``.
        token: token string to register; defaults to the trained token
            stored in the file.

    Raises:
        ValueError: if the tokenizer already contains the token.
    """
    # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
    # NOTE(security): torch.load unpickles the file; only load embedding
    # files that come from a trusted source.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # cast to dtype of text_encoder; Tensor.to returns a new tensor, so the
    # result must be kept (the call on its own has no effect)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # add the token in tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
243
+
244
+
245
+ scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None}
246
+
247
+
248
+ class StableDiffusionInpaint:
249
+ def __init__(
250
+ self, token: str = "", model_name: str = "", model_path: str = "", **kwargs,
251
+ ):
252
+ self.token = token
253
+ original_checkpoint = False
254
+ if model_path and os.path.exists(model_path):
255
+ if model_path.endswith(".ckpt"):
256
+ original_checkpoint = True
257
+ elif model_path.endswith(".json"):
258
+ model_name = os.path.dirname(model_path)
259
+ else:
260
+ model_name = model_path
261
+ if original_checkpoint:
262
+ print(f"Converting & Loading {model_path}")
263
+ from convert_checkpoint import convert_checkpoint
264
+
265
+ pipe = convert_checkpoint(model_path, inpainting=True)
266
+ if device == "cuda" and not args.fp32:
267
+ pipe.to(torch.float16)
268
+ inpaint = StableDiffusionInpaintPipeline(
269
+ vae=pipe.vae,
270
+ text_encoder=pipe.text_encoder,
271
+ tokenizer=pipe.tokenizer,
272
+ unet=pipe.unet,
273
+ scheduler=pipe.scheduler,
274
+ safety_checker=pipe.safety_checker,
275
+ feature_extractor=pipe.feature_extractor,
276
+ )
277
+ else:
278
+ print(f"Loading {model_name}")
279
+ if device == "cuda" and not args.fp32:
280
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
281
+ model_name,
282
+ revision="fp16",
283
+ torch_dtype=torch.float16,
284
+ use_auth_token=token,
285
+ )
286
+ else:
287
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
288
+ model_name, use_auth_token=token,
289
+ )
290
+ if os.path.exists("./embeddings"):
291
+ print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
292
+ for item in os.listdir("./embeddings"):
293
+ if item.endswith(".bin"):
294
+ load_learned_embed_in_clip(
295
+ os.path.join("./embeddings", item),
296
+ inpaint.text_encoder,
297
+ inpaint.tokenizer,
298
+ )
299
+ inpaint.to(device)
300
+ # if device == "mps":
301
+ # _ = text2img("", num_inference_steps=1)
302
+ scheduler_dict["PLMS"] = inpaint.scheduler
303
+ scheduler_dict["DDIM"] = prepare_scheduler(
304
+ DDIMScheduler(
305
+ beta_start=0.00085,
306
+ beta_end=0.012,
307
+ beta_schedule="scaled_linear",
308
+ clip_sample=False,
309
+ set_alpha_to_one=False,
310
+ )
311
+ )
312
+ scheduler_dict["K-LMS"] = prepare_scheduler(
313
+ LMSDiscreteScheduler(
314
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
315
+ )
316
+ )
317
+ self.safety_checker = inpaint.safety_checker
318
+ save_token(token)
319
+ try:
320
+ total_memory = torch.cuda.get_device_properties(0).total_memory // (
321
+ 1024 ** 3
322
+ )
323
+ if total_memory <= 5:
324
+ inpaint.enable_attention_slicing()
325
+ except:
326
+ pass
327
+ self.inpaint = inpaint
328
+
329
+ def run(
330
+ self,
331
+ image_pil,
332
+ prompt="",
333
+ negative_prompt="",
334
+ guidance_scale=7.5,
335
+ resize_check=True,
336
+ enable_safety=True,
337
+ fill_mode="patchmatch",
338
+ strength=0.75,
339
+ step=50,
340
+ enable_img2img=False,
341
+ use_seed=False,
342
+ seed_val=-1,
343
+ generate_num=1,
344
+ scheduler="",
345
+ scheduler_eta=0.0,
346
+ **kwargs,
347
+ ):
348
+ inpaint = self.inpaint
349
+ selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
350
+ for item in [inpaint]:
351
+ item.scheduler = selected_scheduler
352
+ if enable_safety:
353
+ item.safety_checker = self.safety_checker
354
+ else:
355
+ item.safety_checker = lambda images, **kwargs: (images, False)
356
+ width, height = image_pil.size
357
+ sel_buffer = np.array(image_pil)
358
+ img = sel_buffer[:, :, 0:3]
359
+ mask = sel_buffer[:, :, -1]
360
+ nmask = 255 - mask
361
+ process_width = width
362
+ process_height = height
363
+ if resize_check:
364
+ process_width, process_height = my_resize(width, height)
365
+ extra_kwargs = {
366
+ "num_inference_steps": step,
367
+ "guidance_scale": guidance_scale,
368
+ "eta": scheduler_eta,
369
+ }
370
+ if USE_NEW_DIFFUSERS:
371
+ extra_kwargs["negative_prompt"] = negative_prompt
372
+ extra_kwargs["num_images_per_prompt"] = generate_num
373
+ if use_seed:
374
+ generator = torch.Generator(inpaint.device).manual_seed(seed_val)
375
+ extra_kwargs["generator"] = generator
376
+ if True:
377
+ img, mask = functbl[fill_mode](img, mask)
378
+ mask = 255 - mask
379
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
380
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
381
+ extra_kwargs["strength"] = strength
382
+ inpaint_func = inpaint
383
+ init_image = Image.fromarray(img)
384
+ mask_image = Image.fromarray(mask)
385
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
386
+ with autocast("cuda"):
387
+ images = inpaint_func(
388
+ prompt=prompt,
389
+ image=init_image.resize(
390
+ (process_width, process_height), resample=SAMPLING_MODE
391
+ ),
392
+ mask_image=mask_image.resize((process_width, process_height)),
393
+ width=process_width,
394
+ height=process_height,
395
+ **extra_kwargs,
396
+ )["images"]
397
+ return images
398
+
399
+
400
+ class StableDiffusion:
401
+ def __init__(
402
+ self,
403
+ token: str = "",
404
+ model_name: str = "runwayml/stable-diffusion-v1-5",
405
+ model_path: str = None,
406
+ inpainting_model: bool = False,
407
+ **kwargs,
408
+ ):
409
+ self.token = token
410
+ original_checkpoint = False
411
+ if model_path and os.path.exists(model_path):
412
+ if model_path.endswith(".ckpt"):
413
+ original_checkpoint = True
414
+ elif model_path.endswith(".json"):
415
+ model_name = os.path.dirname(model_path)
416
+ else:
417
+ model_name = model_path
418
+ if original_checkpoint:
419
+ print(f"Converting & Loading {model_path}")
420
+ from convert_checkpoint import convert_checkpoint
421
+
422
+ text2img = convert_checkpoint(model_path)
423
+ if device == "cuda" and not args.fp32:
424
+ text2img.to(torch.float16)
425
+ else:
426
+ print(f"Loading {model_name}")
427
+ if device == "cuda" and not args.fp32:
428
+ text2img = StableDiffusionPipeline.from_pretrained(
429
+ model_name,
430
+ revision="fp16",
431
+ torch_dtype=torch.float16,
432
+ use_auth_token=token,
433
+ )
434
+ else:
435
+ text2img = StableDiffusionPipeline.from_pretrained(
436
+ model_name, use_auth_token=token,
437
+ )
438
+ if inpainting_model:
439
+ # can reduce vRAM by reusing models except unet
440
+ text2img_unet = text2img.unet
441
+ del text2img.vae
442
+ del text2img.text_encoder
443
+ del text2img.tokenizer
444
+ del text2img.scheduler
445
+ del text2img.safety_checker
446
+ del text2img.feature_extractor
447
+ import gc
448
+
449
+ gc.collect()
450
+ if device == "cuda" and not args.fp32:
451
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
452
+ "runwayml/stable-diffusion-inpainting",
453
+ revision="fp16",
454
+ torch_dtype=torch.float16,
455
+ use_auth_token=token,
456
+ ).to(device)
457
+ else:
458
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
459
+ "runwayml/stable-diffusion-inpainting", use_auth_token=token,
460
+ ).to(device)
461
+ text2img_unet.to(device)
462
+ text2img = StableDiffusionPipeline(
463
+ vae=inpaint.vae,
464
+ text_encoder=inpaint.text_encoder,
465
+ tokenizer=inpaint.tokenizer,
466
+ unet=text2img_unet,
467
+ scheduler=inpaint.scheduler,
468
+ safety_checker=inpaint.safety_checker,
469
+ feature_extractor=inpaint.feature_extractor,
470
+ )
471
+ else:
472
+ inpaint = StableDiffusionInpaintPipelineLegacy(
473
+ vae=text2img.vae,
474
+ text_encoder=text2img.text_encoder,
475
+ tokenizer=text2img.tokenizer,
476
+ unet=text2img.unet,
477
+ scheduler=text2img.scheduler,
478
+ safety_checker=text2img.safety_checker,
479
+ feature_extractor=text2img.feature_extractor,
480
+ ).to(device)
481
+ text_encoder = text2img.text_encoder
482
+ tokenizer = text2img.tokenizer
483
+ if os.path.exists("./embeddings"):
484
+ for item in os.listdir("./embeddings"):
485
+ if item.endswith(".bin"):
486
+ load_learned_embed_in_clip(
487
+ os.path.join("./embeddings", item),
488
+ text2img.text_encoder,
489
+ text2img.tokenizer,
490
+ )
491
+ text2img.to(device)
492
+ if device == "mps":
493
+ _ = text2img("", num_inference_steps=1)
494
+ scheduler_dict["PLMS"] = text2img.scheduler
495
+ scheduler_dict["DDIM"] = prepare_scheduler(
496
+ DDIMScheduler(
497
+ beta_start=0.00085,
498
+ beta_end=0.012,
499
+ beta_schedule="scaled_linear",
500
+ clip_sample=False,
501
+ set_alpha_to_one=False,
502
+ )
503
+ )
504
+ scheduler_dict["K-LMS"] = prepare_scheduler(
505
+ LMSDiscreteScheduler(
506
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
507
+ )
508
+ )
509
+ self.safety_checker = text2img.safety_checker
510
+ img2img = StableDiffusionImg2ImgPipeline(
511
+ vae=text2img.vae,
512
+ text_encoder=text2img.text_encoder,
513
+ tokenizer=text2img.tokenizer,
514
+ unet=text2img.unet,
515
+ scheduler=text2img.scheduler,
516
+ safety_checker=text2img.safety_checker,
517
+ feature_extractor=text2img.feature_extractor,
518
+ ).to(device)
519
+ save_token(token)
520
+ try:
521
+ total_memory = torch.cuda.get_device_properties(0).total_memory // (
522
+ 1024 ** 3
523
+ )
524
+ if total_memory <= 5:
525
+ inpaint.enable_attention_slicing()
526
+ except:
527
+ pass
528
+ self.text2img = text2img
529
+ self.inpaint = inpaint
530
+ self.img2img = img2img
531
+ self.unified = UnifiedPipeline(
532
+ vae=text2img.vae,
533
+ text_encoder=text2img.text_encoder,
534
+ tokenizer=text2img.tokenizer,
535
+ unet=text2img.unet,
536
+ scheduler=text2img.scheduler,
537
+ safety_checker=text2img.safety_checker,
538
+ feature_extractor=text2img.feature_extractor,
539
+ ).to(device)
540
+ self.inpainting_model = inpainting_model
541
+
542
def run(
    self,
    image_pil,
    prompt="",
    negative_prompt="",
    guidance_scale=7.5,
    resize_check=True,
    enable_safety=True,
    fill_mode="patchmatch",
    strength=0.75,
    step=50,
    enable_img2img=False,
    use_seed=False,
    seed_val=-1,
    generate_num=1,
    scheduler="",
    scheduler_eta=0.0,
    **kwargs,
):
    """Generate images for the current selection and return the pipeline output.

    The last channel of ``image_pil`` is read as the mask (the code assumes
    an RGBA image - TODO confirm against the caller) and selects the pipeline:

    * mask fully 255 and ``enable_img2img`` set -> img2img,
    * mask has any non-zero pixel -> inpainting (``unified`` pipeline for the
      ``g_diffuser`` fill mode on a non-inpainting model, ``inpaint`` otherwise),
    * mask entirely zero -> plain text2img.

    Args:
        image_pil: PIL image holding the selection; channels 0-2 are the
            colour data, the last channel is the mask.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt; only forwarded when
            ``USE_NEW_DIFFUSERS`` is truthy.
        guidance_scale: Forwarded as ``guidance_scale``.
        resize_check: When true, the processing size comes from ``my_resize``.
        enable_safety: When false, the safety checker is replaced by a
            pass-through callable.
        fill_mode: Key into ``functbl`` choosing the pre-fill function.
        strength: Strength for img2img and for regular inpainting.
        step: Number of inference steps (capped at 150 when ``RUN_IN_SPACE``).
        enable_img2img: Allow the img2img branch for fully masked input.
        use_seed: When true, a seeded ``torch.Generator`` is passed.
        seed_val: Seed used when ``use_seed`` is true.
        generate_num: Images per prompt; only forwarded when
            ``USE_NEW_DIFFUSERS`` is truthy.
        scheduler: Key into ``scheduler_dict``; unknown keys fall back to "PLMS".
        scheduler_eta: Forwarded as ``eta``.
        **kwargs: Accepted and ignored (callers pass ``width``/``height``).

    Returns:
        The ``"images"`` entry of the selected pipeline's output.
    """
    text2img, inpaint, img2img, unified = (
        self.text2img,
        self.inpaint,
        self.img2img,
        self.unified,
    )
    selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
    for item in [text2img, inpaint, img2img, unified]:
        item.scheduler = selected_scheduler
        if enable_safety:
            item.safety_checker = self.safety_checker
        else:
            # Pass-through stand-in: returns the images untouched.
            item.safety_checker = lambda images, **kwargs: (images, False)
    if RUN_IN_SPACE:
        # Cap (not raise) the work done per request on the shared Space;
        # the previous max() forced at least 150 steps.
        step = min(150, step)
        image_pil = contain_func(image_pil, (1024, 1024))
    width, height = image_pil.size
    sel_buffer = np.array(image_pil)
    img = sel_buffer[:, :, 0:3]
    mask = sel_buffer[:, :, -1]
    nmask = 255 - mask
    process_width = width
    process_height = height
    if resize_check:
        process_width, process_height = my_resize(width, height)
    extra_kwargs = {
        "num_inference_steps": step,
        "guidance_scale": guidance_scale,
        "eta": scheduler_eta,
    }
    if RUN_IN_SPACE:
        # Budget of four 512x512 images' worth of pixels per call, at least
        # one image; min() limits the request instead of inflating it.
        batch_cap = max(1, int(4 * 512 * 512 // process_width // process_height))
        generate_num = min(batch_cap, generate_num)
    if USE_NEW_DIFFUSERS:
        extra_kwargs["negative_prompt"] = negative_prompt
        extra_kwargs["num_images_per_prompt"] = generate_num
    if use_seed:
        generator = torch.Generator(text2img.device).manual_seed(seed_val)
        extra_kwargs["generator"] = generator
    if nmask.sum() < 1 and enable_img2img:
        # Every mask value is 255: nothing to fill, so run img2img.
        init_image = Image.fromarray(img)
        with autocast("cuda"):
            images = img2img(
                prompt=prompt,
                init_image=init_image.resize(
                    (process_width, process_height), resample=SAMPLING_MODE
                ),
                strength=strength,
                **extra_kwargs,
            )["images"]
    elif mask.sum() > 0:
        if fill_mode == "g_diffuser" and not self.inpainting_model:
            mask = 255 - mask
            mask = mask[:, :, np.newaxis].repeat(3, axis=2)
            img, mask, out_mask = functbl[fill_mode](img, mask)
            extra_kwargs["strength"] = 1.0
            extra_kwargs["out_mask"] = Image.fromarray(out_mask)
            inpaint_func = unified
        else:
            img, mask = functbl[fill_mode](img, mask)
            mask = 255 - mask
            # Max-pool the mask over 8x8 blocks and blow it back up, so the
            # mask is constant within each 8x8 cell.
            mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
            mask = mask.repeat(8, axis=0).repeat(8, axis=1)
            extra_kwargs["strength"] = strength
            inpaint_func = inpaint
        init_image = Image.fromarray(img)
        mask_image = Image.fromarray(mask)
        # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
        with autocast("cuda"):
            input_image = init_image.resize(
                (process_width, process_height), resample=SAMPLING_MODE
            )
            # Both `init_image` and `image` are passed; presumably to cover
            # pipelines that use either keyword - TODO confirm.
            images = inpaint_func(
                prompt=prompt,
                init_image=input_image,
                image=input_image,
                width=process_width,
                height=process_height,
                mask_image=mask_image.resize((process_width, process_height)),
                **extra_kwargs,
            )["images"]
    else:
        # Mask is entirely zero: generate from the prompt alone.
        with autocast("cuda"):
            images = text2img(
                prompt=prompt,
                # Fixed: height/width were passed swapped.
                height=process_height,
                width=process_width,
                **extra_kwargs,
            )["images"]
    return images
652
+
653
+
654
def get_model(token="", model_choice="", model_path=""):
    """Return the cached model wrapper, building it on first use.

    The wrapper is stored under ``model["model"]``; once present, all
    arguments are ignored and the cached object is returned.

    Args:
        token: Hugging Face token forwarded to the wrapper constructor.
        model_choice: Either a ``ModelChoice`` member or its string value.
        model_path: Local model path; overridden by ``args.local_model``
            when that option is set.

    Returns:
        The object stored in ``model["model"]``.
    """
    # Callers pass both enum members and their string values; compare on the
    # value so a member does not silently fall through to the default branch.
    model_choice = getattr(model_choice, "value", model_choice)
    if "model" not in model:
        model_name = ""
        if args.local_model:
            print(f"Using local_model: {args.local_model}")
            model_path = args.local_model
        elif args.remote_model:
            print(f"Using remote_model: {args.remote_model}")
            model_name = args.remote_model
        if model_choice == ModelChoice.INPAINTING.value:
            if len(model_name) < 1:
                model_name = "runwayml/stable-diffusion-inpainting"
            print(f"Using [{model_name}] {model_path}")
            tmp = StableDiffusionInpaint(
                token=token, model_name=model_name, model_path=model_path
            )
        elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
            print(
                f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM"
            )
            tmp = StableDiffusion(token=token, inpainting_model=True)
        else:
            if len(model_name) < 1:
                model_name = (
                    "runwayml/stable-diffusion-v1-5"
                    if model_choice == ModelChoice.MODEL_1_5.value
                    else "CompVis/stable-diffusion-v1-4"
                )
            tmp = StableDiffusion(
                token=token, model_name=model_name, model_path=model_path
            )
        model["model"] = tmp
    return model["model"]
687
+
688
+
689
def run_outpaint(
    sel_buffer_str,
    prompt_text,
    negative_prompt_text,
    strength,
    guidance,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    use_correction,
    enable_img2img,
    use_seed,
    seed_val,
    generate_num,
    scheduler,
    scheduler_eta,
    state,
):
    """Gradio callback: run the model on a base64 selection and encode results.

    ``sel_buffer_str`` is base64-decoded and opened as an image, handed to the
    cached model's ``run``, and every returned image is post-processed by
    ``correction_func.run``, resized back to the selection size, written over
    the colour channels of a copy of the selection with the last channel set
    to 255, and PNG/base64 encoded.

    Returns:
        A 3-tuple: an update carrying the comma-joined base64 PNGs (labelled
        with ``state + 1``), an update for the prompt box, and ``state + 1``.
    """
    selection = Image.open(io.BytesIO(base64.b64decode(str(sel_buffer_str))))
    width, height = selection.size
    sel_buffer = np.array(selection)
    cur_model = get_model()
    images = cur_model.run(
        image_pil=selection,
        prompt=prompt_text,
        negative_prompt=negative_prompt_text,
        guidance_scale=guidance,
        strength=strength,
        step=step,
        resize_check=resize_check,
        fill_mode=fill_mode,
        enable_safety=enable_safety,
        use_seed=use_seed,
        seed_val=seed_val,
        generate_num=generate_num,
        scheduler=scheduler,
        scheduler_eta=scheduler_eta,
        enable_img2img=enable_img2img,
        width=width,
        height=height,
    )
    # img2img output is always corrected in border mode.
    correction_mode = "border_mode" if enable_img2img else use_correction

    def encode(generated):
        """Correct, resize and composite one result; return it as base64 PNG."""
        corrected = correction_func.run(
            selection.resize(generated.size), generated, mode=correction_mode
        )
        fitted = corrected.resize((width, height), resample=SAMPLING_MODE)
        composite = sel_buffer.copy()
        composite[:, :, 0:3] = np.array(fitted)
        composite[:, :, -1] = 255
        png = io.BytesIO()
        Image.fromarray(composite).save(png, format="PNG")
        return base64.b64encode(png.getvalue()).decode("ascii")

    encoded = [encode(image) for image in images]
    next_state = state + 1
    return (
        gr.update(label=str(next_state), value=",".join(encoded)),
        gr.update(label="Prompt"),
        next_state,
    )
753
+
754
+
755
def load_js(name):
    """Return the JavaScript source associated with ``name``.

    For "export", "commit" and "undo" an inline function is generated that
    looks up the element with id ``name`` inside the ``#sdinfframe`` iframe,
    clicks it and returns its argument unchanged. Any other name is read
    from ``./js/<name>.js``.

    Args:
        name: Button id or script base name.

    Returns:
        The JavaScript source as a string.

    Raises:
        OSError: If ``./js/<name>.js`` cannot be opened.
    """
    if name in ["export", "commit", "undo"]:
        return f"""
function (x)
{{
    let app=document.querySelector("gradio-app");
    app=app.shadowRoot??app;
    let frame=app.querySelector("#sdinfframe").contentWindow.document;
    let button=frame.querySelector("#{name}");
    button.click();
    return x;
}}
"""
    # Explicit encoding, as load_html does, so the result does not depend on
    # the platform's default locale.
    with open(f"./js/{name}.js", "r", encoding="utf8") as f:
        return f.read()
772
+
773
+
774
# JavaScript hooks attached to the "Proceed" and "Setup" buttons below.
proceed_button_js = load_js("proceed")
setup_button_js = load_js("setup")

if RUN_IN_SPACE:
    # On a Space there is no setup step, so the model is loaded at import
    # time. get_model compares against ModelChoice values, so pass the value
    # (as the UI radio does) rather than the enum member.
    get_model(
        token=os.environ.get("hftoken", ""),
        model_choice=ModelChoice.INPAINTING_IMG2IMG.value,
    )
779
+
780
# Root Gradio container for the app, with a small CSS override.
blocks = gr.Blocks(
    title="StableDiffusion-Infinity",
    css="""
    .tabs {
    margin-top: 0rem;
    margin-bottom: 0rem;
    }
    #markdown {
    min-height: 0rem;
    }
    """,
)
# Initial value of the "Custom Model Path" box; filled from CLI args below.
model_path_input_val = ""
with blocks as demo:
    # title
    title = gr.Markdown(
        """
    **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
    """,
        elem_id="markdown",
    )
    # frame
    # The canvas iframe is only visible up-front on a Space; otherwise
    # setup_func reveals it after the model has been loaded.
    frame = gr.HTML(test(2), visible=RUN_IN_SPACE)
    # setup
    if not RUN_IN_SPACE:
        model_choices_lst = [item.value for item in ModelChoice]
        if args.local_model:
            model_path_input_val = args.local_model
            # model_choices_lst.insert(0, "local_model")
        elif args.remote_model:
            model_path_input_val = args.remote_model
            # model_choices_lst.insert(0, "remote_model")
        with gr.Row(elem_id="setup_row"):
            with gr.Column(scale=4, min_width=350):
                token = gr.Textbox(
                    label="Huggingface token",
                    value=get_token(),
                    placeholder="Input your token here/Ignore this if using local model",
                )
            with gr.Column(scale=3, min_width=320):
                model_selection = gr.Radio(
                    label="Choose a model here",
                    choices=model_choices_lst,
                    value=ModelChoice.INPAINTING.value,
                )
            with gr.Column(scale=1, min_width=100):
                canvas_width = gr.Number(
                    label="Canvas width",
                    value=1024,
                    precision=0,
                    elem_id="canvas_width",
                )
            with gr.Column(scale=1, min_width=100):
                canvas_height = gr.Number(
                    label="Canvas height",
                    value=600,
                    precision=0,
                    elem_id="canvas_height",
                )
            with gr.Column(scale=1, min_width=100):
                selection_size = gr.Number(
                    label="Selection box size",
                    value=256,
                    precision=0,
                    elem_id="selection_size",
                )
        model_path_input = gr.Textbox(
            value=model_path_input_val,
            label="Custom Model Path",
            placeholder="Ignore this if you are not using Docker",
            elem_id="model_path_input",
        )
        setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            # Value is passed to run_outpaint as `fill_mode`.
            init_mode = gr.Radio(
                label="Init Mode",
                choices=[
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "perlin",
                    "gaussian",
                    "g_diffuser",
                ],
                value="patchmatch",
                type="value",
            )
            # Value is passed to run_outpaint as `use_correction`.
            postprocess_check = gr.Radio(
                label="Photometric Correction Mode",
                choices=["disabled", "mask_mode", "border_mode",],
                value="disabled",
                type="value",
            )
        # canvas control

        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here!", lines=2
            )
            sd_negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="input your negative prompt here!",
                lines=2,
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Group():
                with gr.Row():
                    sd_generate_num = gr.Number(
                        label="Sample number", value=1, precision=0
                    )
                    sd_strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.75,
                        step=0.01,
                    )
                with gr.Row():
                    sd_scheduler = gr.Dropdown(
                        list(scheduler_dict.keys()), label="Scheduler", value="PLMS"
                    )
                    sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
        with gr.Column(scale=1, min_width=80):
            sd_step = gr.Number(label="Step", value=50, precision=0)
            sd_guidance = gr.Number(label="Guidance", value=7.5)

    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    # The two hidden HTML blocks below embed an <img> with an unloadable src
    # so that its `onerror` attribute carries the JS read from js/xss.js and
    # js/keyboard.js (presumably as a way to run script on page load -
    # TODO confirm against the Gradio version in use).
    xss_js = load_js("xss").replace("\n", " ")
    xss_html = gr.HTML(
        value=f"""
    <img src='hts://not.exist' onerror='{xss_js}'>""",
        visible=False,
    )
    xss_keyboard_js = load_js("keyboard").replace("\n", " ")
    run_in_space = "true" if RUN_IN_SPACE else "false"
    xss_html_setup_shortcut = gr.HTML(
        value=f"""
    <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
        visible=False,
    )
    # sd pipeline parameters
    # These components are created hidden (visible=False or DEBUG_MODE) and
    # serve as inputs/outputs of the event handlers wired below.
    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
    safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
    upload_button = gr.Button(
        "Before uploading the image you need to setup the canvas first", visible=False
    )
    sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
    sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
    if not RUN_IN_SPACE:

        def setup_func(token_val, width, height, size, model_choice, model_path):
            """Load the chosen model, then swap the setup widgets for the canvas.

            On failure the exception is printed and its text is written into
            the token box; on success the setup controls are hidden and the
            canvas frame is shown. `width`, `height` and `size` are not used
            here.
            """
            try:
                get_model(token_val, model_choice, model_path=model_path)
            except Exception as e:
                print(e)
                return {token: gr.update(value=str(e))}
            return {
                token: gr.update(visible=False),
                canvas_width: gr.update(visible=False),
                canvas_height: gr.update(visible=False),
                selection_size: gr.update(visible=False),
                setup_button: gr.update(visible=False),
                frame: gr.update(visible=True),
                upload_button: gr.update(value="Upload Image"),
                model_selection: gr.update(visible=False),
                model_path_input: gr.update(visible=False),
            }

        setup_button.click(
            fn=setup_func,
            inputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                model_selection,
                model_path_input,
            ],
            outputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                setup_button,
                frame,
                upload_button,
                model_selection,
                model_path_input,
            ],
            _js=setup_button_js,
        )

    # The order of `inputs` must match run_outpaint's positional parameters.
    proceed_event = proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_negative_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            postprocess_check,
            sd_img2img,
            sd_use_seed,
            sd_seed_val,
            sd_generate_num,
            sd_scheduler,
            sd_scheduler_eta,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        _js=proceed_button_js,
    )
    # cancel button can also remove error overlay
    cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
1007
+
1008
+
1009
# Options always passed to launch(), on top of the CLI-provided ones.
launch_extra_kwargs = {
    "show_error": True,
    # "favicon_path": ""
}
# Forward every CLI option that was actually set, except the ones that are
# consumed by this script rather than by launch().
launch_kwargs = vars(args)
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
launch_kwargs.pop("remote_model", None)
launch_kwargs.pop("local_model", None)
launch_kwargs.pop("fp32", None)
launch_kwargs.update(launch_extra_kwargs)
try:
    import google.colab

    # The import succeeding means we are running inside Colab.
    launch_kwargs["debug"] = True
except ImportError:
    # Not on Colab; only a missing module is expected here, so other
    # errors are no longer silently swallowed.
    pass

if RUN_IN_SPACE:
    demo.launch()
elif args.debug:
    # Debug runs listen on all interfaces.
    launch_kwargs["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_kwargs)
else:
    demo.queue().launch(**launch_kwargs)
1033
+
canvas.py CHANGED
@@ -1,548 +1,650 @@
1
- import base64
2
- import io
3
- import numpy as np
4
- from PIL import Image
5
- from pyodide import to_js, create_proxy
6
- from js import (
7
- console,
8
- document,
9
- devicePixelRatio,
10
- ImageData,
11
- Uint8ClampedArray,
12
- CanvasRenderingContext2D as Context2d,
13
- requestAnimationFrame,
14
- )
15
-
16
- PAINT_SELECTION = "✥"
17
- IMAGE_SELECTION = "🖼️"
18
- BRUSH_SELECTION = "🖌️"
19
- NOP_MODE = 0
20
- PAINT_MODE = 1
21
- IMAGE_MODE = 2
22
- BRUSH_MODE = 3
23
-
24
-
25
- def hold_canvas():
26
- pass
27
-
28
-
29
- def prepare_canvas(width, height, canvas) -> Context2d:
30
- ctx = canvas.getContext("2d")
31
-
32
- canvas.style.width = f"{width}px"
33
- canvas.style.height = f"{height}px"
34
-
35
- canvas.width = width
36
- canvas.height = height
37
-
38
- ctx.clearRect(0, 0, width, height)
39
-
40
- return ctx
41
-
42
-
43
- # class MultiCanvas:
44
- # def __init__(self,layer,width=800, height=600) -> None:
45
- # pass
46
- def multi_canvas(layer, width=800, height=600):
47
- lst = [
48
- CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
49
- for i in range(layer)
50
- ]
51
- return lst
52
-
53
-
54
- class CanvasProxy:
55
- def __init__(self, canvas, width=800, height=600) -> None:
56
- self.canvas = canvas
57
- self.ctx = prepare_canvas(width, height, canvas)
58
- self.width = width
59
- self.height = height
60
-
61
- def clear_rect(self, x, y, w, h):
62
- self.ctx.clearRect(x, y, w, h)
63
-
64
- def clear(self,):
65
- self.clear_rect(0, 0, self.width, self.height)
66
-
67
- def stroke_rect(self, x, y, w, h):
68
- self.ctx.strokeRect(x, y, w, h)
69
-
70
- def fill_rect(self, x, y, w, h):
71
- self.ctx.fillRect(x, y, w, h)
72
-
73
- def put_image_data(self, image, x, y):
74
- data = Uint8ClampedArray.new(to_js(image.tobytes()))
75
- height, width, _ = image.shape
76
- image_data = ImageData.new(data, width, height)
77
- self.ctx.putImageData(image_data, x, y)
78
-
79
- @property
80
- def stroke_style(self):
81
- return self.ctx.strokeStyle
82
-
83
- @stroke_style.setter
84
- def stroke_style(self, value):
85
- self.ctx.strokeStyle = value
86
-
87
- @property
88
- def fill_style(self):
89
- return self.ctx.strokeStyle
90
-
91
- @fill_style.setter
92
- def fill_style(self, value):
93
- self.ctx.fillStyle = value
94
-
95
-
96
- # RGBA for masking
97
- class InfCanvas:
98
- def __init__(
99
- self,
100
- width,
101
- height,
102
- selection_size=256,
103
- grid_size=32,
104
- patch_size=4096,
105
- test_mode=False,
106
- ) -> None:
107
- assert selection_size < min(height, width)
108
- self.width = width
109
- self.height = height
110
- self.canvas = multi_canvas(5, width=width, height=height)
111
- # self.canvas = Canvas(width=width, height=height)
112
- self.view_pos = [0, 0]
113
- self.cursor = [
114
- width // 2 - selection_size // 2,
115
- height // 2 - selection_size // 2,
116
- ]
117
- self.data = {}
118
- self.grid_size = grid_size
119
- self.selection_size = selection_size
120
- self.patch_size = patch_size
121
- # note that for image data, the height comes before width
122
- self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
123
- self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
124
- self.sel_buffer_bak = np.zeros(
125
- (selection_size, selection_size, 4), dtype=np.uint8
126
- )
127
- self.sel_dirty = False
128
- self.buffer_dirty = False
129
- self.mouse_pos = [-1, -1]
130
- self.mouse_state = 0
131
- # self.output = widgets.Output()
132
- self.test_mode = test_mode
133
- self.buffer_updated = False
134
- self.image_move_freq = 1
135
- self.show_brush = False
136
- # inpaint pipeline from diffuser
137
-
138
- def setup_mouse(self):
139
- self.image_move_cnt = 0
140
-
141
- def get_mouse_mode():
142
- mode = document.querySelector("#mode").value
143
- if mode == PAINT_SELECTION:
144
- return PAINT_MODE
145
- elif mode == IMAGE_SELECTION:
146
- return IMAGE_MODE
147
- return BRUSH_MODE
148
-
149
- def get_event_pos(event):
150
- canvas = self.canvas[-1].canvas
151
- rect = canvas.getBoundingClientRect()
152
- x = (canvas.width * (event.clientX - rect.left)) / rect.width
153
- y = (canvas.height * (event.clientY - rect.top)) / rect.height
154
- return x, y
155
-
156
- def handle_mouse_down(event):
157
- self.mouse_state = get_mouse_mode()
158
-
159
- def handle_mouse_out(event):
160
- last_state = self.mouse_state
161
- self.mouse_state = NOP_MODE
162
- self.image_move_cnt = 0
163
- if last_state == IMAGE_MODE:
164
- if True:
165
- self.clear_background()
166
- self.draw_buffer()
167
- self.canvas[2].clear()
168
- self.draw_selection_box()
169
- if self.show_brush:
170
- self.canvas[-2].clear()
171
- self.show_brush = False
172
-
173
- def handle_mouse_up(event):
174
- last_state = self.mouse_state
175
- self.mouse_state = NOP_MODE
176
- self.image_move_cnt = 0
177
- if last_state == IMAGE_MODE:
178
- if True:
179
- self.clear_background()
180
- self.draw_buffer()
181
- self.canvas[2].clear()
182
- self.draw_selection_box()
183
-
184
- async def handle_mouse_move(event):
185
- x, y = get_event_pos(event)
186
- x0, y0 = self.mouse_pos
187
- xo = x - x0
188
- yo = y - y0
189
- if self.mouse_state == PAINT_MODE:
190
- self.update_cursor(int(xo), int(yo))
191
- if True:
192
- # self.clear_background()
193
- # console.log(self.buffer_updated)
194
- if self.buffer_updated:
195
- self.draw_buffer()
196
- self.buffer_updated = False
197
- self.draw_selection_box()
198
- elif self.mouse_state == IMAGE_MODE:
199
- self.image_move_cnt += 1
200
- self.update_view_pos(int(xo), int(yo))
201
- if self.image_move_cnt == self.image_move_freq:
202
- if True:
203
- self.clear_background()
204
- self.draw_buffer()
205
- self.canvas[2].clear()
206
- self.draw_selection_box()
207
- self.image_move_cnt = 0
208
- elif self.mouse_state == BRUSH_MODE:
209
- if self.sel_dirty:
210
- self.write_selection_to_buffer()
211
- self.canvas[2].clear()
212
- self.buffer_dirty=True
213
- bx0,by0=int(x)-self.grid_size//2,int(y)-self.grid_size//2
214
- bx1,by1=bx0+self.grid_size,by0+self.grid_size
215
- bx0,by0=max(0,bx0),max(0,by0)
216
- bx1,by1=min(self.width,bx1),min(self.height,by1)
217
- self.buffer[by0:by1,bx0:bx1,:]*=0
218
- self.draw_buffer()
219
- self.draw_selection_box()
220
-
221
- mode = document.querySelector("#mode").value
222
- if mode == BRUSH_SELECTION:
223
- self.canvas[-2].clear()
224
- self.canvas[-2].fill_style = "#ffffff"
225
- self.canvas[-2].fill_rect(x-self.grid_size//2,y-self.grid_size//2,self.grid_size,self.grid_size)
226
- self.canvas[-2].stroke_rect(x-self.grid_size//2,y-self.grid_size//2,self.grid_size,self.grid_size)
227
- self.show_brush = True
228
- elif self.show_brush:
229
- self.canvas[-2].clear()
230
- self.show_brush = False
231
- self.mouse_pos[0] = x
232
- self.mouse_pos[1] = y
233
-
234
- self.canvas[-1].canvas.addEventListener(
235
- "mousedown", create_proxy(handle_mouse_down)
236
- )
237
- self.canvas[-1].canvas.addEventListener(
238
- "mousemove", create_proxy(handle_mouse_move)
239
- )
240
- self.canvas[-1].canvas.addEventListener(
241
- "mouseup", create_proxy(handle_mouse_up)
242
- )
243
- self.canvas[-1].canvas.addEventListener(
244
- "mouseout", create_proxy(handle_mouse_out)
245
- )
246
-
247
- def setup_widgets(self):
248
- self.mode_button = widgets.ToggleButtons(
249
- options=[PAINT_SELECTION, IMAGE_SELECTION],
250
- disabled=False,
251
- button_style="",
252
- style={"button_width": "50px", "font_weight": "bold"},
253
- tooltips=["Outpaint region", "Image"],
254
- )
255
- self.test_button = widgets.ToggleButtons(
256
- options=["r", "g", "b"],
257
- disabled=False,
258
- style={"button_width": "50px", "font_weight": "bold", "font_size": "36px"},
259
- )
260
- self.text_input = widgets.Textarea(
261
- value="",
262
- placeholder="input your prompt here",
263
- description="Prompt:",
264
- disabled=False,
265
- )
266
- self.run_button = widgets.Button(
267
- description="Outpaint",
268
- tooltip="Run outpainting",
269
- icon="pen",
270
- button_style="primary",
271
- )
272
- self.export_button = widgets.Button(
273
- description="Export",
274
- tooltip="Export the image",
275
- icon="save",
276
- button_style="success",
277
- )
278
- self.fill_button = widgets.ToggleButtons(
279
- description="Init mode:",
280
- options=[
281
- "patchmatch",
282
- "edge_pad",
283
- "cv2_ns",
284
- "cv2_telea",
285
- "gaussian",
286
- "perlin",
287
- ],
288
- disabled=False,
289
- button_style="",
290
- style={"button_width": "80px", "font_weight": "bold"},
291
- )
292
-
293
- if self.test_mode:
294
-
295
- def test_button_clicked(btn):
296
- # lst.append(tuple(base.cursor))
297
- with self.output:
298
- val = self.test_button.value
299
- if val == "r":
300
- self.fill_selection(
301
- np.tile(
302
- np.array([255, 0, 0, 255], dtype=np.uint8),
303
- (self.selection_size, self.selection_size, 1),
304
- )
305
- )
306
- if val == "g":
307
- self.fill_selection(
308
- np.tile(
309
- np.array([0, 255, 0, 255], dtype=np.uint8),
310
- (self.selection_size, self.selection_size, 1),
311
- )
312
- )
313
- if val == "b":
314
- self.fill_selection(
315
- np.tile(
316
- np.array([0, 0, 255, 255], dtype=np.uint8),
317
- (self.selection_size, self.selection_size, 1),
318
- )
319
- )
320
- if True:
321
- self.clear_background()
322
- self.draw_buffer()
323
- self.draw_selection_box()
324
-
325
- self.run_button.on_click(test_button_clicked)
326
-
327
- def display(self):
328
- if True:
329
- self.clear_background()
330
- self.draw_buffer()
331
- self.draw_selection_box()
332
- if self.test_mode:
333
- return [
334
- self.test_button,
335
- self.mode_button,
336
- self.canvas,
337
- widgets.HBox([self.run_button, self.text_input]),
338
- self.output,
339
- ]
340
- return [
341
- self.fill_button,
342
- self.canvas,
343
- widgets.HBox(
344
- [self.mode_button, self.run_button, self.export_button, self.text_input]
345
- ),
346
- self.output,
347
- ]
348
-
349
- def clear_background(self):
350
- # fake transparent background
351
- h, w, step = self.height, self.width, self.grid_size
352
- stride = step * 2
353
- x0, y0 = self.view_pos
354
- x0 = (-x0) % stride
355
- y0 = (-y0) % stride
356
- # self.canvas.clear()
357
- self.canvas[0].fill_style = "#ffffff"
358
- self.canvas[0].fill_rect(0, 0, w, h)
359
- self.canvas[0].fill_style = "#aaaaaa"
360
- for y in range(y0 - stride, h + step, step):
361
- start = (x0 - stride) if y // step % 2 == 0 else (x0 - step)
362
- for x in range(start, w + step, stride):
363
- self.canvas[0].fill_rect(x, y, step, step)
364
- self.canvas[0].stroke_rect(0, 0, w, h)
365
-
366
- def update_view_pos(self, xo, yo):
367
- if abs(xo) + abs(yo) == 0:
368
- return
369
- if self.sel_dirty:
370
- self.write_selection_to_buffer()
371
- if self.buffer_dirty:
372
- self.buffer2data()
373
- self.view_pos[0] -= xo
374
- self.view_pos[1] -= yo
375
- self.data2buffer()
376
- # self.read_selection_from_buffer()
377
-
378
- def update_cursor(self, xo, yo):
379
- if abs(xo) + abs(yo) == 0:
380
- return
381
- if self.sel_dirty:
382
- self.write_selection_to_buffer()
383
- self.cursor[0] += xo
384
- self.cursor[1] += yo
385
- self.cursor[0] = max(min(self.width - self.selection_size, self.cursor[0]), 0)
386
- self.cursor[1] = max(min(self.height - self.selection_size, self.cursor[1]), 0)
387
- # self.read_selection_from_buffer()
388
-
389
- def data2buffer(self):
390
- x, y = self.view_pos
391
- h, w = self.height, self.width
392
- # fill four parts
393
- for i in range(4):
394
- pos_src, pos_dst, data = self.select(x, y, i)
395
- xs0, xs1 = pos_src[0]
396
- ys0, ys1 = pos_src[1]
397
- xd0, xd1 = pos_dst[0]
398
- yd0, yd1 = pos_dst[1]
399
- self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
400
-
401
- def buffer2data(self):
402
- x, y = self.view_pos
403
- h, w = self.height, self.width
404
- # fill four parts
405
- for i in range(4):
406
- pos_src, pos_dst, data = self.select(x, y, i)
407
- xs0, xs1 = pos_src[0]
408
- ys0, ys1 = pos_src[1]
409
- xd0, xd1 = pos_dst[0]
410
- yd0, yd1 = pos_dst[1]
411
- data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
412
- self.buffer_dirty = False
413
-
414
- def select(self, x, y, idx):
415
- w, h = self.width, self.height
416
- lst = [(0, 0), (0, h), (w, 0), (w, h)]
417
- if idx == 0:
418
- x0, y0 = x % self.patch_size, y % self.patch_size
419
- x1 = min(x0 + w, self.patch_size)
420
- y1 = min(y0 + h, self.patch_size)
421
- elif idx == 1:
422
- y += h
423
- x0, y0 = x % self.patch_size, y % self.patch_size
424
- x1 = min(x0 + w, self.patch_size)
425
- y1 = max(y0 - h, 0)
426
- elif idx == 2:
427
- x += w
428
- x0, y0 = x % self.patch_size, y % self.patch_size
429
- x1 = max(x0 - w, 0)
430
- y1 = min(y0 + h, self.patch_size)
431
- else:
432
- x += w
433
- y += h
434
- x0, y0 = x % self.patch_size, y % self.patch_size
435
- x1 = max(x0 - w, 0)
436
- y1 = max(y0 - h, 0)
437
- xi, yi = x // self.patch_size, y // self.patch_size
438
- cur = self.data.setdefault(
439
- (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
440
- )
441
- x0_img, y0_img = lst[idx]
442
- x1_img = x0_img + x1 - x0
443
- y1_img = y0_img + y1 - y0
444
- sort = lambda a, b: ((a, b) if a < b else (b, a))
445
- return (
446
- (sort(x0, x1), sort(y0, y1)),
447
- (sort(x0_img, x1_img), sort(y0_img, y1_img)),
448
- cur,
449
- )
450
-
451
- def draw_buffer(self):
452
- self.canvas[1].clear()
453
- self.canvas[1].put_image_data(self.buffer, 0, 0)
454
-
455
- def fill_selection(self, img):
456
- self.sel_buffer = img
457
- self.sel_dirty = True
458
-
459
- def draw_selection_box(self):
460
- x0, y0 = self.cursor
461
- size = self.selection_size
462
- if self.sel_dirty:
463
- self.canvas[2].clear()
464
- self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
465
- self.canvas[-1].clear()
466
- self.canvas[-1].stroke_style = "#0a0a0a"
467
- self.canvas[-1].stroke_rect(x0, y0, size, size)
468
- self.canvas[-1].stroke_style = "#ffffff"
469
- self.canvas[-1].stroke_rect(x0 - 1, y0 - 1, size + 2, size + 2)
470
- self.canvas[-1].stroke_style = "#000000"
471
- self.canvas[-1].stroke_rect(x0 - 2, y0 - 2, size + 4, size + 4)
472
-
473
- def write_selection_to_buffer(self):
474
- x0, y0 = self.cursor
475
- x1, y1 = x0 + self.selection_size, y0 + self.selection_size
476
- self.buffer[y0:y1, x0:x1] = self.sel_buffer
477
- self.sel_dirty = False
478
- self.sel_buffer = self.sel_buffer_bak.copy()
479
- self.buffer_dirty = True
480
- self.buffer_updated = True
481
- # self.canvas[2].clear()
482
-
483
- def read_selection_from_buffer(self):
484
- x0, y0 = self.cursor
485
- x1, y1 = x0 + self.selection_size, y0 + self.selection_size
486
- self.sel_buffer = self.buffer[y0:y1, x0:x1]
487
- self.sel_dirty = False
488
-
489
- def base64_to_numpy(self, base64_str):
490
- try:
491
- data = base64.b64decode(str(base64_str))
492
- pil = Image.open(io.BytesIO(data))
493
- arr = np.array(pil)
494
- ret = arr
495
- except:
496
- ret = np.tile(
497
- np.array([255, 0, 0, 255], dtype=np.uint8),
498
- (self.selection_size, self.selection_size, 1),
499
- )
500
- return ret
501
-
502
- def numpy_to_base64(self, arr):
503
- out_pil = Image.fromarray(arr)
504
- out_buffer = io.BytesIO()
505
- out_pil.save(out_buffer, format="PNG")
506
- out_buffer.seek(0)
507
- base64_bytes = base64.b64encode(out_buffer.read())
508
- base64_str = base64_bytes.decode("ascii")
509
- return base64_str
510
-
511
def export(self):
    """Return everything drawn so far as one RGBA array, cropped to content.

    Pending selection/buffer changes are flushed into the patch store
    first. Non-empty patches are stitched into a mosaic which is then
    cropped to the bounding box of pixels with non-zero alpha. When nothing
    has been drawn, a transparent ``selection_size`` square is returned.
    """
    if self.sel_dirty:
        self.write_selection_to_buffer()
    if self.buffer_dirty:
        self.buffer2data()
    # Only patches that actually contain pixels take part in the mosaic.
    filled = {key: patch for key, patch in self.data.items() if patch.any()}
    if not filled:
        return np.zeros(
            (self.selection_size, self.selection_size, 4), dtype=np.uint8
        )
    # Bound the mosaic by the patches in use instead of always including
    # patch (0, 0); the final crop is the same but far less memory is used.
    xs = [xi for xi, _ in filled]
    ys = [yi for _, yi in filled]
    xmin, ymin = min(xs), min(ys)
    xn = max(xs) - xmin + 1
    yn = max(ys) - ymin + 1
    size = self.patch_size
    image = np.zeros((yn * size, xn * size, 4), dtype=np.uint8)
    for (xi, yi), patch in filled.items():
        y0 = (yi - ymin) * size
        x0 = (xi - xmin) * size
        image[y0 : y0 + size, x0 : x0 + size] = patch
    ylst, xlst = image[:, :, -1].nonzero()
    if len(ylst) == 0:
        # Colour data without any alpha counts as empty.
        return np.zeros(
            (self.selection_size, self.selection_size, 4), dtype=np.uint8
        )
    return image[ylst.min() : ylst.max() + 1, xlst.min() : xlst.max() + 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pyodide import to_js, create_proxy
7
+ import gc
8
+ from js import (
9
+ console,
10
+ document,
11
+ devicePixelRatio,
12
+ ImageData,
13
+ Uint8ClampedArray,
14
+ CanvasRenderingContext2D as Context2d,
15
+ requestAnimationFrame,
16
+ update_overlay,
17
+ setup_overlay,
18
+ window
19
+ )
20
+
21
# Values of the "#mode" input that the mouse handlers compare against.
PAINT_SELECTION, IMAGE_SELECTION, BRUSH_SELECTION = "selection", "canvas", "eraser"
# Internal mouse states stored in InfCanvas.mouse_state.
NOP_MODE, PAINT_MODE, IMAGE_MODE, BRUSH_MODE = range(4)
28
+
29
+
30
def hold_canvas():
    """Do nothing; placeholder that always returns ``None``."""
    return None
32
+
33
+
34
def prepare_canvas(width, height, canvas) -> Context2d:
    """Size a canvas element to ``width`` x ``height`` and return its 2D context.

    Both the CSS size and the backing-store size are set, then the whole
    surface is cleared.
    """
    context = canvas.getContext("2d")

    css = canvas.style
    css.width, css.height = f"{width}px", f"{height}px"
    canvas.width, canvas.height = width, height

    context.clearRect(0, 0, width, height)
    return context
46
+
47
+
48
+ # class MultiCanvas:
49
+ # def __init__(self,layer,width=800, height=600) -> None:
50
+ # pass
51
def multi_canvas(layer, width=800, height=600):
    """Wrap the elements ``#canvas0`` .. ``#canvas{layer-1}`` in CanvasProxy objects.

    Returns the proxies as a list ordered by layer index.
    """
    proxies = []
    for index in range(layer):
        element = document.querySelector(f"#canvas{index}")
        proxies.append(CanvasProxy(element, width, height))
    return proxies
57
+
58
+
59
class CanvasProxy:
    """Thin Python wrapper around an HTML canvas element and its 2D context."""

    def __init__(self, canvas, width=800, height=600) -> None:
        """Resize ``canvas`` to ``width`` x ``height`` and keep its 2D context."""
        self.canvas = canvas
        self.ctx = prepare_canvas(width, height, canvas)
        self.width = width
        self.height = height

    def clear_rect(self, x, y, w, h):
        """Clear the rectangle ``(x, y, w, h)``."""
        self.ctx.clearRect(x, y, w, h)

    def clear(self):
        """Clear the whole canvas using the element's current size."""
        self.clear_rect(0, 0, self.canvas.width, self.canvas.height)

    def stroke_rect(self, x, y, w, h):
        """Outline the rectangle ``(x, y, w, h)`` with the current stroke style."""
        self.ctx.strokeRect(x, y, w, h)

    def fill_rect(self, x, y, w, h):
        """Fill the rectangle ``(x, y, w, h)`` with the current fill style."""
        self.ctx.fillRect(x, y, w, h)

    def put_image_data(self, image, x, y):
        """Blit a ``(height, width, channels)`` array onto the canvas at ``(x, y)``.

        The array bytes are handed to ``ImageData`` as-is.
        NOTE(review): presumably ``image`` is a uint8 RGBA array; confirm against callers.
        """
        data = Uint8ClampedArray.new(to_js(image.tobytes()))
        height, width, _ = image.shape
        image_data = ImageData.new(data, width, height)
        self.ctx.putImageData(image_data, x, y)
        del image_data

    # def draw_image(self,canvas, x, y, w, h):
    #     self.ctx.drawImage(canvas,x,y,w,h)
    def draw_image(self, canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
        """Copy a source rectangle of ``canvas`` into a destination rectangle here."""
        self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)

    @property
    def stroke_style(self):
        """Current stroke style of the underlying context."""
        return self.ctx.strokeStyle

    @stroke_style.setter
    def stroke_style(self, value):
        self.ctx.strokeStyle = value

    @property
    def fill_style(self):
        """Current fill style of the underlying context."""
        # Previously returned ``strokeStyle``, so reads did not match writes.
        return self.ctx.fillStyle

    @fill_style.setter
    def fill_style(self, value):
        self.ctx.fillStyle = value
105
+
106
+
107
+ # RGBA for masking
108
class InfCanvas:
    """Scrollable, zoomable RGBA canvas backed by a sparse grid of patches.

    ``data`` maps ``(xi, yi)`` patch indices to ``patch_size`` square RGBA
    arrays. ``buffer`` holds the pixels of the current view (top-left corner
    at ``view_pos``), and ``sel_buffer`` holds a pending image for the
    selection box located at ``cursor`` (view coordinates).
    """

    def __init__(
        self,
        width,
        height,
        selection_size=256,
        grid_size=64,
        patch_size=4096,
        test_mode=False,
    ) -> None:
        """Create the canvas layers and an empty view centred on patch (0, 0)."""
        assert selection_size < min(height, width)
        self.width = width
        self.height = height
        self.display_width = width
        self.display_height = height
        self.canvas = multi_canvas(5, width=width, height=height)
        setup_overlay(width, height)
        # place at center
        self.view_pos = [patch_size // 2 - width // 2, patch_size // 2 - height // 2]
        self.cursor = [
            width // 2 - selection_size // 2,
            height // 2 - selection_size // 2,
        ]
        self.data = {}
        self.grid_size = grid_size
        self.selection_size_w = selection_size
        self.selection_size_h = selection_size
        self.patch_size = patch_size
        # note that for image data, the height comes before width
        self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
        self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
        self.sel_buffer_bak = np.zeros(
            (selection_size, selection_size, 4), dtype=np.uint8
        )
        self.sel_dirty = False
        self.buffer_dirty = False
        self.mouse_pos = [-1, -1]
        self.mouse_state = 0
        # self.output = widgets.Output()
        self.test_mode = test_mode
        self.buffer_updated = False
        self.image_move_freq = 1
        self.show_brush = False
        self.scale = 1.0
        self.eraser_size = 32
        # Read by the mouse handlers; initialised here so they never hit a
        # missing attribute (e.g. when image_move_freq is raised above 1).
        self.image_move_cnt = 0
        self.cached_view_pos = tuple(self.view_pos)

    def reset_large_buffer(self):
        """Shrink layer 2 back to the view size, show it again and clear it."""
        layer = self.canvas[2]
        layer.canvas.width = self.width
        layer.canvas.height = self.height
        # layer.canvas.style.width=f"{self.display_width}px"
        # layer.canvas.style.height=f"{self.display_height}px"
        layer.canvas.style.display = "block"
        layer.clear()

    def draw_eraser(self, x, y):
        """Draw the square eraser cursor centred on ``(x, y)`` on layer -2."""
        half = self.eraser_size // 2
        layer = self.canvas[-2]
        layer.clear()
        layer.fill_style = "#ffffff"
        layer.fill_rect(x - half, y - half, self.eraser_size, self.eraser_size)
        layer.stroke_rect(x - half, y - half, self.eraser_size, self.eraser_size)

    def use_eraser(self, x, y):
        """Zero out the eraser-sized square of the view buffer around ``(x, y)``."""
        if self.sel_dirty:
            self.write_selection_to_buffer()
            self.draw_buffer()
            self.canvas[2].clear()
        self.buffer_dirty = True
        bx0, by0 = int(x) - self.eraser_size // 2, int(y) - self.eraser_size // 2
        bx1, by1 = bx0 + self.eraser_size, by0 + self.eraser_size
        # Clip the erased square to the view.
        bx0, by0 = max(0, bx0), max(0, by0)
        bx1, by1 = min(self.width, bx1), min(self.height, by1)
        self.buffer[by0:by1, bx0:bx1, :] *= 0
        self.draw_buffer()
        self.draw_selection_box()

    def setup_mouse(self):
        """Attach mouse and wheel listeners to the top canvas layer."""
        self.image_move_cnt = 0

        def get_mouse_mode():
            mode = document.querySelector("#mode").value
            if mode == PAINT_SELECTION:
                return PAINT_MODE
            elif mode == IMAGE_SELECTION:
                return IMAGE_MODE
            return BRUSH_MODE

        def get_event_pos(event):
            # Convert client coordinates into canvas pixel coordinates.
            canvas = self.canvas[-1].canvas
            rect = canvas.getBoundingClientRect()
            x = (canvas.width * (event.clientX - rect.left)) / rect.width
            y = (canvas.height * (event.clientY - rect.top)) / rect.height
            return x, y

        def finish_drag():
            # Shared by mouse-up and mouse-out: leave the current mode and,
            # after panning, rebuild the view from the patch store.
            last_state = self.mouse_state
            self.mouse_state = NOP_MODE
            self.image_move_cnt = 0
            if last_state == IMAGE_MODE:
                self.update_view_pos(0, 0)
                self.clear_background()
                self.draw_buffer()
                self.reset_large_buffer()
                self.draw_selection_box()
                gc.collect()

        def handle_mouse_down(event):
            self.mouse_state = get_mouse_mode()
            if self.mouse_state == BRUSH_MODE:
                x, y = get_event_pos(event)
                self.use_eraser(x, y)

        def handle_mouse_out(event):
            finish_drag()
            if self.show_brush:
                self.canvas[-2].clear()
                self.show_brush = False

        def handle_mouse_up(event):
            finish_drag()

        async def handle_mouse_move(event):
            x, y = get_event_pos(event)
            x0, y0 = self.mouse_pos
            xo = x - x0
            yo = y - y0
            if self.mouse_state == PAINT_MODE:
                self.update_cursor(int(xo), int(yo))
                # self.clear_background()
                # console.log(self.buffer_updated)
                if self.buffer_updated:
                    self.draw_buffer()
                    self.buffer_updated = False
                self.draw_selection_box()
            elif self.mouse_state == IMAGE_MODE:
                self.image_move_cnt += 1
                if self.image_move_cnt == self.image_move_freq:
                    # First step of a pan: render a 2x-sized snapshot around
                    # the view into hidden layer 2, to be blitted from later.
                    self.draw_buffer()
                    self.canvas[2].clear()
                    self.draw_selection_box()
                    self.update_view_pos(int(xo), int(yo))
                    self.cached_view_pos = tuple(self.view_pos)
                    self.canvas[2].canvas.style.display = "none"
                    large_buffer = self.data2array(
                        self.view_pos[0] - self.width // 2,
                        self.view_pos[1] - self.height // 2,
                        min(self.width * 2, self.patch_size * 2),
                        min(self.height * 2, self.patch_size * 2),
                    )
                    self.canvas[2].canvas.width = 2 * self.width
                    self.canvas[2].canvas.height = 2 * self.height
                    # self.canvas[2].canvas.style.width=""
                    # self.canvas[2].canvas.style.height=""
                    self.canvas[2].put_image_data(large_buffer, 0, 0)
                else:
                    # Later steps: move the view without rebuilding the
                    # buffer and copy the visible window from the snapshot.
                    self.update_view_pos(int(xo), int(yo), False)
                    self.canvas[1].clear()
                    self.canvas[1].draw_image(
                        self.canvas[2].canvas,
                        self.width // 2 + (self.view_pos[0] - self.cached_view_pos[0]),
                        self.height // 2 + (self.view_pos[1] - self.cached_view_pos[1]),
                        self.width,
                        self.height,
                        0,
                        0,
                        self.width,
                        self.height,
                    )
                self.clear_background()
                # self.image_move_cnt = 0
            elif self.mouse_state == BRUSH_MODE:
                self.use_eraser(x, y)

            mode = document.querySelector("#mode").value
            if mode == BRUSH_SELECTION:
                self.draw_eraser(x, y)
                self.show_brush = True
            elif self.show_brush:
                self.canvas[-2].clear()
                self.show_brush = False
            self.mouse_pos[0] = x
            self.mouse_pos[1] = y

        async def handle_mouse_wheel(event):
            x, y = get_event_pos(event)
            self.mouse_pos[0] = x
            self.mouse_pos[1] = y
            console.log(to_js(self.mouse_pos))
            # Zooming itself is requested through a window message.
            if event.deltaY > 10:
                window.postMessage(to_js(["click", "zoom_out", self.mouse_pos[0], self.mouse_pos[1]]), "*")
            elif event.deltaY < -10:
                window.postMessage(to_js(["click", "zoom_in", self.mouse_pos[0], self.mouse_pos[1]]), "*")
            return False

        top = self.canvas[-1].canvas
        top.addEventListener("mousedown", create_proxy(handle_mouse_down))
        top.addEventListener("mousemove", create_proxy(handle_mouse_move))
        top.addEventListener("mouseup", create_proxy(handle_mouse_up))
        top.addEventListener("mouseout", create_proxy(handle_mouse_out))
        top.addEventListener("wheel", create_proxy(handle_mouse_wheel), False)

    def clear_background(self):
        """Paint the checkerboard on layer 0, aligned with ``view_pos``."""
        # fake transparent background
        h, w, step = self.height, self.width, self.grid_size
        stride = step * 2
        x0, y0 = self.view_pos
        x0 = (-x0) % stride
        y0 = (-y0) % stride
        if y0 >= step:
            val0, val1 = stride, step
        else:
            val0, val1 = step, stride
        # self.canvas.clear()
        self.canvas[0].fill_style = "#ffffff"
        self.canvas[0].fill_rect(0, 0, w, h)
        self.canvas[0].fill_style = "#aaaaaa"
        for y in range(y0 - stride, h + step, step):
            start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
            for x in range(start, w + step, stride):
                self.canvas[0].fill_rect(x, y, step, step)
        self.canvas[0].stroke_rect(0, 0, w, h)

    def _clamp_cursor(self):
        """Keep the selection box inside the view without moving it otherwise."""
        self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
        self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)

    def refine_selection(self):
        """Round the selection size down to multiples of 8 and re-clamp the cursor."""
        self.selection_size_h = min(self.selection_size_h // 8 * 8, self.height)
        self.selection_size_w = min(self.selection_size_w // 8 * 8, self.width)
        if self.sel_dirty:
            self.write_selection_to_buffer()
        # Clamp only: the former update_cursor(1, 0) call also pushed the
        # selection one pixel to the right on every call.
        self._clamp_cursor()

    def update_scale(self, scale, mx=-1, my=-1):
        """Zoom the view to ``scale`` times the display size.

        Does nothing (after flushing pending edits) when the scaled view
        would reach ``2 * patch_size - 128`` pixels or would not be larger
        than the selection box. When ``mx``/``my`` are non-negative the view
        position is shifted by the change in the scaled mouse position.
        """
        self.sync_to_data()
        scaled_width = int(self.display_width * scale)
        scaled_height = int(self.display_height * scale)
        if max(scaled_height, scaled_width) >= self.patch_size * 2 - 128:
            return
        if scaled_height <= self.selection_size_h or scaled_width <= self.selection_size_w:
            return
        if mx >= 0 and my >= 0:
            scaled_mx = mx / self.scale * scale
            scaled_my = my / self.scale * scale
            self.view_pos[0] += int(mx - scaled_mx)
            self.view_pos[1] += int(my - scaled_my)
        self.scale = scale
        for item in self.canvas:
            item.canvas.width = scaled_width
            item.canvas.height = scaled_height
            item.clear()
        update_overlay(scaled_width, scaled_height)
        self.width = scaled_width
        self.height = scaled_height
        self.data2buffer()
        self.clear_background()
        self.draw_buffer()
        # Clamp only; see refine_selection for why the cursor is not nudged.
        self._clamp_cursor()
        self.draw_selection_box()

    def update_view_pos(self, xo, yo, update=True):
        """Pan the view by ``(-xo, -yo)`` after flushing pending edits.

        With ``update`` true the view buffer is rebuilt from the patch store.
        """
        # if abs(xo) + abs(yo) == 0:
        #     return
        if self.sel_dirty:
            self.write_selection_to_buffer()
        if self.buffer_dirty:
            self.buffer2data()
        self.view_pos[0] -= xo
        self.view_pos[1] -= yo
        if update:
            self.data2buffer()
        # self.read_selection_from_buffer()

    def update_cursor(self, xo, yo):
        """Move the selection box by ``(xo, yo)``, clamped to the view.

        A zero offset is a no-op. A pending selection is committed first.
        """
        if abs(xo) + abs(yo) == 0:
            return
        if self.sel_dirty:
            self.write_selection_to_buffer()
        self.cursor[0] += xo
        self.cursor[1] += yo
        self._clamp_cursor()
        # self.read_selection_from_buffer()

    def data2buffer(self):
        """Rebuild the view buffer from the patches under the current view."""
        x, y = self.view_pos
        if self.height != self.buffer.shape[0] or self.width != self.buffer.shape[1]:
            self.buffer = np.zeros((self.height, self.width, 4), dtype=np.uint8)
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]

    def data2array(self, x, y, w, h):
        """Return a new ``(h, w, 4)`` array read from the patches at ``(x, y)``."""
        ret = np.zeros((h, w, 4), dtype=np.uint8)
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i, w, h)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
        return ret

    def buffer2data(self):
        """Write the view buffer back into the patches and clear ``buffer_dirty``."""
        x, y = self.view_pos
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
        self.buffer_dirty = False

    def select(self, x, y, idx, width=0, height=0):
        """Map one corner of a rectangle onto the patch that contains it.

        ``idx`` picks the corner: 0 top-left, 1 bottom-left, 2 top-right,
        3 bottom-right. The rectangle is ``width`` x ``height`` at
        ``(x, y)``; when ``width`` is 0 the view size is used. The patch is
        created on demand.

        Returns ``((xs0, xs1), (ys0, ys1))`` inside the patch, the matching
        ``((xd0, xd1), (yd0, yd1))`` inside the rectangle, and the patch
        array itself.
        """
        if width == 0:
            w, h = self.width, self.height
        else:
            w, h = width, height
        lst = [(0, 0), (0, h), (w, 0), (w, h)]
        if idx == 0:
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = min(x0 + w, self.patch_size)
            y1 = min(y0 + h, self.patch_size)
        elif idx == 1:
            y += h
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = min(x0 + w, self.patch_size)
            y1 = max(y0 - h, 0)
        elif idx == 2:
            x += w
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = max(x0 - w, 0)
            y1 = min(y0 + h, self.patch_size)
        else:
            x += w
            y += h
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = max(x0 - w, 0)
            y1 = max(y0 - h, 0)
        xi, yi = x // self.patch_size, y // self.patch_size
        cur = self.data.setdefault(
            (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
        )
        x0_img, y0_img = lst[idx]
        x1_img = x0_img + x1 - x0
        y1_img = y0_img + y1 - y0

        def ordered(a, b):
            # Ranges growing backwards from a corner come out reversed.
            return (a, b) if a < b else (b, a)

        return (
            (ordered(x0, x1), ordered(y0, y1)),
            (ordered(x0_img, x1_img), ordered(y0_img, y1_img)),
            cur,
        )

    def draw_buffer(self):
        """Redraw layer 1 from the view buffer."""
        self.canvas[1].clear()
        self.canvas[1].put_image_data(self.buffer, 0, 0)

    def fill_selection(self, img):
        """Store ``img`` as the pending selection content and mark it dirty."""
        self.sel_buffer = img
        self.sel_dirty = True

    def draw_selection_box(self):
        """Draw the pending selection (layer 2) and the box outline (top layer)."""
        x0, y0 = self.cursor
        w, h = self.selection_size_w, self.selection_size_h
        if self.sel_dirty:
            self.canvas[2].clear()
            self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
        top = self.canvas[-1]
        top.clear()
        top.stroke_style = "#0a0a0a"
        top.stroke_rect(x0, y0, w, h)
        top.stroke_style = "#ffffff"
        # Spread the outlines further apart when zoomed above 1.0.
        offset = round(self.scale) if self.scale > 1.0 else 1
        top.stroke_rect(x0 - offset, y0 - offset, w + offset * 2, h + offset * 2)
        top.stroke_style = "#000000"
        top.stroke_rect(x0 - offset * 2, y0 - offset * 2, w + offset * 4, h + offset * 4)

    def write_selection_to_buffer(self):
        """Commit the pending selection into the view buffer at the cursor."""
        x0, y0 = self.cursor
        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
        self.buffer[y0:y1, x0:x1] = self.sel_buffer
        self.sel_dirty = False
        self.sel_buffer = np.zeros(
            (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
        )
        self.buffer_dirty = True
        self.buffer_updated = True
        # self.canvas[2].clear()

    def read_selection_from_buffer(self):
        """Copy the pixels under the selection box into ``sel_buffer``."""
        x0, y0 = self.cursor
        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
        # Copy: a bare slice is a view, so in-place edits of sel_buffer
        # (e.g. ``reset``) would write through into the view buffer.
        self.sel_buffer = self.buffer[y0:y1, x0:x1].copy()
        self.sel_dirty = False

    def base64_to_numpy(self, base64_str):
        """Decode a base64-encoded image into an array.

        On any ordinary decoding error an opaque red tile of the selection
        size is returned so the failure is visible on the canvas.
        """
        try:
            data = base64.b64decode(str(base64_str))
            pil = Image.open(io.BytesIO(data))
            ret = np.array(pil)
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt/SystemExit propagate.
            ret = np.tile(
                np.array([255, 0, 0, 255], dtype=np.uint8),
                (self.selection_size_h, self.selection_size_w, 1),
            )
        return ret

    def numpy_to_base64(self, arr):
        """Encode an image array as a PNG and return it as an ASCII base64 string."""
        out_pil = Image.fromarray(arr)
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        out_buffer.seek(0)
        base64_bytes = base64.b64encode(out_buffer.read())
        return base64_bytes.decode("ascii")

    def sync_to_data(self):
        """Flush a pending selection and a dirty view buffer into the patches."""
        if self.sel_dirty:
            self.write_selection_to_buffer()
            self.canvas[2].clear()
            self.draw_buffer()
        if self.buffer_dirty:
            self.buffer2data()

    def sync_to_buffer(self):
        """Flush a pending selection into the view buffer and redraw it."""
        if self.sel_dirty:
            self.canvas[2].clear()
            self.write_selection_to_buffer()
            self.draw_buffer()

    def resize(self, width, height, scale=None, **kwargs):
        """Change the display size, then re-apply ``scale`` (1 when omitted)."""
        self.display_width = width
        self.display_height = height
        for canvas in self.canvas:
            prepare_canvas(width=width, height=height, canvas=canvas.canvas)
        setup_overlay(width, height)
        if scale is None:
            scale = 1
        self.update_scale(scale)

    def save(self):
        """Serialise the canvas state and all non-empty patches to a JSON string."""
        self.sync_to_data()
        state = {}
        state["width"] = self.display_width
        state["height"] = self.display_height
        state["selection_width"] = self.selection_size_w
        state["selection_height"] = self.selection_size_h
        state["view_pos"] = self.view_pos[:]
        state["cursor"] = self.cursor[:]
        state["scale"] = self.scale
        # Patches are stored as base64 PNGs under "xi,yi" keys.
        state["data"] = {
            f"{key[0]},{key[1]}": self.numpy_to_base64(patch)
            for key, patch in list(self.data.items())
            if patch.any()
        }
        return json.dumps(state)

    def load(self, state_json):
        """Replace the current canvas with the state produced by ``save``."""
        self.reset()
        state = json.loads(state_json)
        self.display_width = state["width"]
        self.display_height = state["height"]
        self.selection_size_w = state["selection_width"]
        self.selection_size_h = state["selection_height"]
        self.view_pos = state["view_pos"][:]
        self.cursor = state["cursor"][:]
        self.scale = state["scale"]
        self.resize(state["width"], state["height"], scale=state["scale"])
        for k, v in state["data"].items():
            key = tuple(map(int, k.split(",")))
            self.data[key] = self.base64_to_numpy(v)
        self.data2buffer()
        self.display()

    def display(self):
        """Redraw background, view buffer and selection box."""
        self.clear_background()
        self.draw_buffer()
        self.draw_selection_box()

    def reset(self):
        """Drop all patches, blank the buffers and clear layers 1 to n-2."""
        self.data.clear()
        self.buffer *= 0
        self.buffer_dirty = False
        self.buffer_updated = False
        self.sel_buffer *= 0
        self.sel_dirty = False
        self.view_pos = [0, 0]
        self.clear_background()
        for i in range(1, len(self.canvas) - 1):
            self.canvas[i].clear()

    def export(self):
        """Return everything drawn so far as one RGBA array, cropped to content.

        Non-empty patches are stitched into a mosaic which is cropped to the
        bounding box of pixels with non-zero alpha. When nothing has been
        drawn a transparent array of the selection size is returned.
        """
        self.sync_to_data()
        # Only patches that actually contain pixels take part in the mosaic.
        filled = {key: patch for key, patch in self.data.items() if patch.any()}
        if not filled:
            return np.zeros(
                (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
            )
        # Bound the mosaic by the patches in use instead of always including
        # patch (0, 0); the final crop is the same but far less memory is used.
        xs = [xi for xi, _ in filled]
        ys = [yi for _, yi in filled]
        xmin, ymin = min(xs), min(ys)
        xn = max(xs) - xmin + 1
        yn = max(ys) - ymin + 1
        size = self.patch_size
        image = np.zeros((yn * size, xn * size, 4), dtype=np.uint8)
        for (xi, yi), patch in filled.items():
            y0 = (yi - ymin) * size
            x0 = (xi - xmin) * size
            image[y0 : y0 + size, x0 : x0 + size] = patch
        ylst, xlst = image[:, :, -1].nonzero()
        if len(ylst) == 0:
            # Colour data without any alpha counts as empty.
            return np.zeros(
                (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
            )
        return image[ylst.min() : ylst.max() + 1, xlst.min() : xlst.max() + 1]
index.html CHANGED
@@ -1,214 +1,411 @@
1
- <html>
2
- <head>
3
- <title>Stablediffusion Infinity</title>
4
- <meta charset="utf-8">
5
- <link rel="icon" type="image/x-icon" href="./favicon.png">
6
- <link rel="stylesheet" href="https://pyscript.net/alpha/pyscript.css" />
7
- <script defer src="https://pyscript.net/alpha/pyscript.js"></script>
8
- <style>
9
- #container {
10
- position: relative;
11
- margin:auto;
12
- }
13
- #container > canvas {
14
- position: absolute;
15
- top: 0;
16
- left: 0;
17
- }
18
- .control {
19
- display: none;
20
- }
21
- </style>
22
-
23
- </head>
24
- <body>
25
- <div>
26
- <button type="button" class="control" id="export">Export</button>
27
- <button type="button" class="control" id="outpaint">Outpaint</button>
28
- <button type="button" class="control" id="undo">Undo</button>
29
- <button type="button" class="control" id="commit">Commit</button>
30
- <button type="button" class="control" id="transfer">Transfer</button>
31
- <button type="button" class="control" id="upload">Upload</button>
32
- <button type="button" class="control" id="draw">Draw</button>
33
- <input type="text" id="mode" value="✥" class="control">
34
- <input type="text" id="setup" value="0" class="control">
35
- <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
36
- <fieldset class="control">
37
- <div>
38
- <input type="radio" id="mode0" name="mode" value="0" checked>
39
- <label for="mode0">SelBox</label>
40
- </div>
41
- <div>
42
- <input type="radio" id="mode1" name="mode" value="1">
43
- <label for="mode1">Image</label>
44
- </div>
45
- <div>
46
- <input type="radio" id="mode2" name="mode" value="2">
47
- <label for="mode2">Brush</label>
48
- </div>
49
- </fieldset>
50
- </div>
51
- <div>
52
- <div id = "container">
53
- <canvas id = "canvas0"></canvas>
54
- <canvas id = "canvas1"></canvas>
55
- <canvas id = "canvas2"></canvas>
56
- <canvas id = "canvas3"></canvas>
57
- <canvas id = "canvas4"></canvas>
58
- </div>
59
- </div>
60
- <py-env>
61
- - numpy
62
- - Pillow
63
- - paths:
64
- - ./canvas.py
65
- </py-env>
66
-
67
- <py-script>
68
- from pyodide import to_js, create_proxy
69
- from PIL import Image
70
- import io
71
- import time
72
- import base64
73
- import numpy as np
74
- from js import (
75
- console,
76
- document,
77
- parent,
78
- devicePixelRatio,
79
- ImageData,
80
- Uint8ClampedArray,
81
- CanvasRenderingContext2D as Context2d,
82
- requestAnimationFrame,
83
- window
84
- )
85
-
86
-
87
- from canvas import InfCanvas
88
-
89
-
90
-
91
- base_lst = [None]
92
async def draw_canvas() -> None:
    """Create the default 1500x600 canvas, draw it and publish it in ``base_lst``."""
    width, height = 1500, 600
    board = InfCanvas(width, height)
    document.querySelector("#container").style.width = f"{width}px"
    board.setup_mouse()
    for draw in (board.clear_background, board.draw_buffer, board.draw_selection_box):
        draw()
    base_lst[0] = board
102
-
103
async def draw_canvas_func(event):
    """Rebuild the canvas using the sizes entered in the parent Gradio app.

    Width, height and selection size are read from the parent document's
    ``gradio-app`` inputs; if they cannot be read, 1024x768 with a 384px
    selection is used.
    """
    try:
        app = parent.document.querySelector("gradio-app")
        width = app.querySelector("#canvas_width input").value
        height = app.querySelector("#canvas_height input").value
        selection_size = app.querySelector("#selection_size input").value
    except Exception:
        # Best-effort fallback; a bare ``except`` would also swallow
        # KeyboardInterrupt / SystemExit.
        width = 1024
        height = 768
        selection_size = 384
    document.querySelector("#container").style.width = f"{width}px"
    canvas = InfCanvas(int(width), int(height), selection_size=int(selection_size))
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    base_lst[0] = canvas
119
-
120
async def export_func(event):
    """Export the canvas as a PNG and trigger a browser download of it."""
    board = base_lst[0]
    encoded = board.numpy_to_base64(board.export())
    stamp = time.strftime("%Y%m%d_%H%M%S")
    console.log(f"Canvas saved to outpaint_{stamp}.png")
    # A temporary anchor element is clicked to start the download.
    anchor = document.createElement("a")
    anchor.download = f"outpaint_{stamp}.png"
    anchor.href = "data:image/png;base64," + encoded
    anchor.click()
130
-
131
async def outpaint_func(event):
    """Show the base64 image in ``event.data[1]`` as the pending selection."""
    board = base_lst[0]
    pixels = board.base64_to_numpy(event.data[1])
    board.fill_selection(pixels)
    board.draw_selection_box()
137
-
138
async def undo_func(event):
    """Discard the pending selection, if there is one."""
    board = base_lst[0]
    if not board.sel_dirty:
        return
    board.canvas[2].clear()
    board.sel_buffer = board.sel_buffer_bak.copy()
    board.sel_dirty = False
144
-
145
async def commit_func(event):
    """Write the pending selection into the canvas buffer, if there is one."""
    board = base_lst[0]
    if not board.sel_dirty:
        return
    board.write_selection_to_buffer()
149
-
150
async def transfer_func(event):
    """Post the pixels under the selection box to the parent window as base64."""
    board = base_lst[0]
    board.read_selection_from_buffer()
    encoded = board.numpy_to_base64(board.sel_buffer)
    message = to_js(["transfer", str(encoded)])
    parent.postMessage(message, "*")
156
-
157
async def upload_func(event):
    """Replace the view buffer with the image in ``event.data[1]``, centred.

    A pending selection is discarded. The first three channels and the last
    channel of the image are copied separately.
    """
    board = base_lst[0]
    picture = board.base64_to_numpy(event.data[1])
    rows, cols, _ = picture.shape
    top = (board.height - rows) // 2
    left = (board.width - cols) // 2
    if board.sel_dirty:
        board.canvas[2].clear()
        board.sel_buffer = board.sel_buffer_bak.copy()
        board.sel_dirty = False
    board.buffer_dirty = True
    board.buffer *= 0
    board.buffer[top:top + rows, left:left + cols, 0:3] = picture[:, :, 0:3]
    board.buffer[top:top + rows, left:left + cols, -1] = picture[:, :, -1]
    board.draw_buffer()
173
-
174
-
175
-
176
# Wire the hidden control buttons to their handlers (registration order kept).
for element_id, handler in (
    ("export", export_func),
    ("undo", undo_func),
    ("commit", commit_func),
    ("outpaint", outpaint_func),
    ("upload", upload_func),
    ("transfer", transfer_func),
    ("draw", draw_canvas_func),
):
    document.querySelector(f"#{element_id}").addEventListener("click", create_proxy(handler))
184
-
185
async def setup_func():
    """Set the ``#setup`` input's value to "1"."""
    flag = document.querySelector("#setup")
    flag.value = "1"
187
-
188
async def message_func(event):
    """Dispatch a window message to the matching handler.

    ``event.data[0]`` names the message kind; for "click" and "mode"
    messages ``event.data[1]`` carries the action name or the new mode.
    Unknown messages are ignored.
    """
    kind = event.data[0]
    if kind == "click":
        action = event.data[1]
        if action == "export":
            await export_func(event)
        elif action == "commit":
            await commit_func(event)
        elif action == "undo":
            await undo_func(event)
        return
    if kind == "upload":
        await upload_func(event)
    elif kind == "outpaint":
        await outpaint_func(event)
    elif kind == "mode":
        document.querySelector("#mode").value = event.data[1]
    elif kind == "transfer":
        await transfer_func(event)
204
-
205
- window.addEventListener("message",create_proxy(message_func))
206
- import asyncio
207
-
208
- _ = await asyncio.gather(
209
- setup_func(),draw_canvas()
210
- )
211
- </py-script>
212
-
213
- </body>
214
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <html>
2
+ <head>
3
+ <title>Stablediffusion Infinity</title>
4
+ <meta charset="utf-8">
5
+ <link rel="icon" type="image/x-icon" href="./favicon.png">
6
+
7
+ <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/css/w2ui.min.css">
8
+ <script type="text/javascript" src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/js/w2ui.min.js"></script>
9
+ <link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
10
+ <script src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/js/fabric.min.js"></script>
11
+ <script defer src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/js/toolbar.js"></script>
12
+
13
+ <link rel="stylesheet" href="https://pyscript.net/alpha/pyscript.css" />
14
+ <script defer src="https://pyscript.net/alpha/pyscript.js"></script>
15
+
16
+ <style>
17
+ #container {
18
+ position: relative;
19
+ margin:auto;
20
+ display: block;
21
+ }
22
+ #container > canvas {
23
+ position: absolute;
24
+ top: 0;
25
+ left: 0;
26
+ }
27
+ .control {
28
+ display: none;
29
+ }
30
+ </style>
31
+
32
+ </head>
33
+ <body>
34
+ <div>
35
+ <button type="button" class="control" id="export">Export</button>
36
+ <button type="button" class="control" id="outpaint">Outpaint</button>
37
+ <button type="button" class="control" id="undo">Undo</button>
38
+ <button type="button" class="control" id="commit">Commit</button>
39
+ <button type="button" class="control" id="transfer">Transfer</button>
40
+ <button type="button" class="control" id="upload">Upload</button>
41
+ <button type="button" class="control" id="draw">Draw</button>
42
+ <input type="text" id="mode" value="selection" class="control">
43
+ <input type="text" id="setup" value="0" class="control">
44
+ <input type="text" id="upload_content" value="0" class="control">
45
+ <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
46
+ <fieldset class="control">
47
+ <div>
48
+ <input type="radio" id="mode0" name="mode" value="0" checked>
49
+ <label for="mode0">SelBox</label>
50
+ </div>
51
+ <div>
52
+ <input type="radio" id="mode1" name="mode" value="1">
53
+ <label for="mode1">Image</label>
54
+ </div>
55
+ <div>
56
+ <input type="radio" id="mode2" name="mode" value="2">
57
+ <label for="mode2">Brush</label>
58
+ </div>
59
+ </fieldset>
60
+ </div>
61
+ <div id = "outer_container">
62
+ <div id = "container">
63
+ <canvas id = "canvas0"></canvas>
64
+ <canvas id = "canvas1"></canvas>
65
+ <canvas id = "canvas2"></canvas>
66
+ <canvas id = "canvas3"></canvas>
67
+ <canvas id = "canvas4"></canvas>
68
+ <div id="overlay_container" style="pointer-events: none">
69
+ <canvas id = "overlay_canvas" width="1" height="1"></canvas>
70
+ </div>
71
+ </div>
72
+ <input type="file" name="file" id="upload_file" accept="image/*" hidden>
73
+ <input type="file" name="state" id="upload_state" accept=".sdinf" hidden>
74
+ <div style="position: relative;">
75
+ <div id="toolbar" style></div>
76
+ </div>
77
+ </div>
78
+ <py-env>
79
+ - numpy
80
+ - Pillow
81
+ - paths:
82
+ - ./canvas.py
83
+ </py-env>
84
+
85
+ <py-script>
86
+ from pyodide import to_js, create_proxy
87
+ from PIL import Image
88
+ import io
89
+ import time
90
+ import base64
91
+ import numpy as np
92
+ from js import (
93
+ console,
94
+ document,
95
+ parent,
96
+ devicePixelRatio,
97
+ ImageData,
98
+ Uint8ClampedArray,
99
+ CanvasRenderingContext2D as Context2d,
100
+ requestAnimationFrame,
101
+ window,
102
+ encodeURIComponent,
103
+ w2ui,
104
+ update_eraser,
105
+ update_scale,
106
+ adjust_selection,
107
+ update_count,
108
+ enable_result_lst,
109
+ setup_shortcut,
110
+ )
111
+
112
+
113
+ from canvas import InfCanvas
114
+
115
+
116
+
117
# Single-slot holder for the active InfCanvas: the draw handlers store the
# canvas in base_lst[0] and every other handler reads it from there.
base_lst = [None]
118
async def draw_canvas() -> None:
    """Create a 1024x600 canvas, paint it once and publish it in ``base_lst[0]``."""
    width, height = 1024, 600
    canvas = InfCanvas(width, height)
    # The eraser slider is bounded by the shorter side of the selection box.
    shortest_side = min(canvas.selection_size_h, canvas.selection_size_w)
    update_eraser(canvas.eraser_size, shortest_side)
    container_style = document.querySelector("#container").style
    container_style.height = f"{height}px"
    container_style.width = f"{width}px"
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    base_lst[0] = canvas
130
+
131
async def draw_canvas_func(event):
    """Build the canvas using the sizes chosen in the parent Gradio page.

    Width, height and selection size are read from the ``#canvas_width``,
    ``#canvas_height`` and ``#selection_size`` inputs of the parent
    ``gradio-app`` element. If that lookup fails for any reason the defaults
    1024 x 768 with a 384 selection are used. The new canvas replaces
    ``base_lst[0]``.

    Args:
        event: The triggering event; not inspected.
    """
    try:
        app = parent.document.querySelector("gradio-app")
        if app.shadowRoot:
            app = app.shadowRoot
        width = app.querySelector("#canvas_width input").value
        height = app.querySelector("#canvas_height input").value
        selection_size = app.querySelector("#selection_size input").value
    except Exception:
        # Narrower than a bare ``except`` so task cancellation and interpreter
        # exit still propagate; any lookup failure falls back to defaults.
        width = 1024
        height = 768
        selection_size = 384
    container = document.querySelector("#container")
    container.style.width = f"{width}px"
    container.style.height = f"{height}px"
    canvas = InfCanvas(int(width), int(height), selection_size=int(selection_size))
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    base_lst[0] = canvas
151
+
152
async def export_func(event):
    """Export the canvas as a PNG and trigger a browser download.

    Args:
        event: Message event; ``event.data[2]``, when present and non-blank,
            is used as the file name (without extension). Otherwise the name
            is ``outpaint_<timestamp>``.
    """
    base = base_lst[0]
    arr = base.export()
    base.draw_buffer()
    base.canvas[2].clear()
    base64_str = base.numpy_to_base64(arr)
    time_str = time.strftime("%Y%m%d_%H%M%S")
    # Normalise the requested name the same way save_func does, and fall back
    # to the timestamped default when it is missing or only whitespace.
    requested = ""
    if len(event.data) > 2 and event.data[2]:
        requested = str(event.data[2]).strip()
    filename = requested or f"outpaint_{time_str}"
    link = document.createElement("a")
    link.download = f"{filename}.png"
    link.href = "data:image/png;base64," + base64_str
    link.click()
    console.log(f"Canvas saved to {filename}.png")
170
+
171
# Two-slot state shared by the result handlers:
#   [0] list of base64-encoded candidate images, or None when there is none
#   [1] running index of the candidate currently shown (used modulo the list length)
img_candidate_lst=[None,0]
172
+
173
async def outpaint_func(event):
    """Load generated candidates, or step between them, and show one in the selection.

    A two-element ``event.data`` reloads the comma-separated base64 list from
    the parent page's ``#output textarea`` and resets the index; otherwise
    ``event.data[2]`` of ``"next"``/``"prev"`` moves the index by one.
    """
    canvas = base_lst[0]
    if len(event.data) == 2:
        gradio_root = parent.document.querySelector("gradio-app")
        if gradio_root.shadowRoot:
            gradio_root = gradio_root.shadowRoot
        raw_value = gradio_root.querySelector("#output textarea").value
        img_candidate_lst[0] = raw_value.split(",")
        img_candidate_lst[1] = 0
    elif event.data[2] == "next":
        img_candidate_lst[1] += 1
    elif event.data[2] == "prev":
        img_candidate_lst[1] -= 1
    enable_result_lst()
    candidates = img_candidate_lst[0]
    if candidates is None:
        return
    # The index grows without bound; wrap it onto the candidate list.
    total = len(candidates)
    position = img_candidate_lst[1] % total
    update_count(position + 1, total)
    pixels = canvas.base64_to_numpy(candidates[position])
    canvas.fill_selection(pixels)
    canvas.draw_selection_box()
196
+
197
async def undo_func(event):
    """Discard the pending selection content and forget the candidate list."""
    canvas = base_lst[0]
    img_candidate_lst[0] = None
    if not canvas.sel_dirty:
        return
    # Replace the selection buffer with a zero-filled RGBA array of the same size.
    blank_shape = (canvas.selection_size_h, canvas.selection_size_w, 4)
    canvas.sel_buffer = np.zeros(blank_shape, dtype=np.uint8)
    canvas.sel_dirty = False
    canvas.canvas[2].clear()
204
+
205
async def commit_func(event):
    """Write the pending selection into the main buffer and forget the candidate list."""
    canvas = base_lst[0]
    img_candidate_lst[0] = None
    if not canvas.sel_dirty:
        return
    canvas.write_selection_to_buffer()
    canvas.draw_buffer()
    canvas.canvas[2].clear()
212
+
213
async def transfer_func(event):
    """Send the current selection to the parent Gradio page and start processing.

    The selection is read from the buffer, base64-encoded, written into the
    parent's ``#input textarea``, and the ``#proceed`` element is clicked.
    """
    canvas = base_lst[0]
    canvas.read_selection_from_buffer()
    encoded_selection = canvas.numpy_to_base64(canvas.sel_buffer)
    gradio_root = parent.document.querySelector("gradio-app")
    if gradio_root.shadowRoot:
        gradio_root = gradio_root.shadowRoot
    gradio_root.querySelector("#input textarea").value = encoded_selection
    gradio_root.querySelector("#proceed").click()
223
+
224
async def upload_func(event):
    """Composite the image stored in ``#upload_content`` onto the canvas buffer.

    The field holds base64 text, optionally preceded by a comma-terminated
    prefix; only the part after the last comma is decoded.
    """
    canvas = base_lst[0]
    payload = document.querySelector("#upload_content").value.split(",")[-1]
    arr = canvas.base64_to_numpy(payload)
    h, w, _ = canvas.buffer.shape
    canvas.sync_to_buffer()
    canvas.buffer_dirty = True
    # Zero every channel of the buffer wherever the upload has non-zero alpha,
    # so the in-place addition below replaces those pixels instead of summing.
    covered = arr[:, :, 3:4].repeat(4, axis=2) > 0
    canvas.buffer[covered] = 0
    # NOTE(review): h and w come from the buffer itself, so this slice is the
    # whole buffer; an upload of a different size is not cropped here - confirm.
    canvas.buffer[0:h, 0:w, :] += arr
    canvas.draw_buffer()
241
+
242
async def setup_shortcut_func(event):
    """Forward the shortcut payload in ``event.data[1]`` to the JS ``setup_shortcut``."""
    shortcut_payload = event.data[1]
    setup_shortcut(shortcut_payload)
244
+
245
+
246
# Wire each hidden control button to its handler, in the original order.
for _selector, _handler in (
    ("#export", export_func),
    ("#undo", undo_func),
    ("#commit", commit_func),
    ("#outpaint", outpaint_func),
    ("#upload", upload_func),
    ("#transfer", transfer_func),
    ("#draw", draw_canvas_func),
):
    document.querySelector(_selector).addEventListener("click", create_proxy(_handler))
254
+
255
async def setup_func():
    """Set the hidden ``#setup`` field to ``"1"``."""
    setup_flag = document.querySelector("#setup")
    setup_flag.value = "1"
257
+
258
async def reset_func(event):
    """Reset the active canvas."""
    base_lst[0].reset()
261
+
262
async def load_func(event):
    """Load canvas state from the payload carried in ``event.data[1]``."""
    saved_state = event.data[1]
    base_lst[0].load(saved_state)
265
+
266
async def save_func(event):
    """Serialise the canvas state and trigger a ``.sdinf`` download.

    Args:
        event: Message event; ``event.data[2]``, when present and non-blank,
            is used as the file name (without extension). Otherwise the name
            is ``outpaint_<timestamp>``.
    """
    base = base_lst[0]
    json_str = base.save()
    time_str = time.strftime("%Y%m%d_%H%M%S")
    # A name made only of whitespace used to strip down to "" and produce a
    # file called ".sdinf"; treat it like a missing name instead.
    requested = ""
    if len(event.data) > 2 and event.data[2]:
        requested = str(event.data[2]).strip()
    filename = requested or f"outpaint_{time_str}"
    link = document.createElement("a")
    link.download = f"{filename}.sdinf"
    link.href = "data:text/json;charset=utf-8," + encodeURIComponent(json_str)
    link.click()
279
+
280
async def prev_result_func(event):
    """Reset the active canvas.

    NOTE(review): despite its name this only calls ``reset()``, exactly like
    ``reset_func``; stepping between results is done in ``outpaint_func``.
    Confirm whether this handler is still needed.
    """
    base_lst[0].reset()
283
+
284
async def next_result_func(event):
    """Reset the active canvas.

    NOTE(review): despite its name this only calls ``reset()``, exactly like
    ``reset_func``; stepping between results is done in ``outpaint_func``.
    Confirm whether this handler is still needed.
    """
    base_lst[0].reset()
287
+
288
async def zoom_in_func(event):
    """Lower the canvas scale by 0.1 and refresh the scale label.

    Nothing happens when the current scale is below 0.2. When ``event.data``
    has more than two entries, entries 2 and 3 are passed to
    ``update_scale`` as integers.
    """
    canvas = base_lst[0]
    target = canvas.scale
    if not target >= 0.2:
        return
    target -= 0.1
    if len(event.data) > 2:
        canvas.update_scale(target, int(event.data[2]), int(event.data[3]))
    else:
        canvas.update_scale(target)
    # Re-read the scale: the canvas decides the value actually applied.
    scale = canvas.scale
    update_scale(f"{canvas.width}x{canvas.height} ({round(100/scale)}%)")
299
+
300
async def zoom_out_func(event):
    """Raise the canvas scale by 0.1 and refresh the scale label.

    Nothing happens once the current scale has reached 10. When
    ``event.data`` has more than two entries, entries 2 and 3 are passed to
    ``update_scale`` as integers.
    """
    canvas = base_lst[0]
    target = canvas.scale
    if not target < 10:
        return
    target += 0.1
    console.log(len(event.data))
    if len(event.data) > 2:
        canvas.update_scale(target, int(event.data[2]), int(event.data[3]))
    else:
        canvas.update_scale(target)
    # Re-read the scale: the canvas decides the value actually applied.
    scale = canvas.scale
    update_scale(f"{canvas.width}x{canvas.height} ({round(100/scale)}%)")
312
+
313
async def sync_func(event):
    """Sync the canvas into its buffer and clear canvas layer 2."""
    canvas = base_lst[0]
    canvas.sync_to_buffer()
    layer = canvas.canvas[2]
    layer.clear()
317
+
318
async def eraser_size_func(event):
    """Set the eraser size from ``event.data[1]``, clamped to a usable range.

    The size is capped at the shorter side of the selection box and then
    raised to at least 8.
    """
    canvas = base_lst[0]
    requested = int(event.data[1])
    upper_bound = min(canvas.selection_size_h, canvas.selection_size_w)
    canvas.eraser_size = max(8, min(requested, upper_bound))
323
+
324
async def resize_selection_func(event):
    """Handle the three phases of resizing the selection box.

    * more than 3 entries in ``event.data``: apply cursor ``[1], [2]`` and
      size ``[3], [4]`` (each size rounded down to a multiple of 8), then redraw;
    * exactly 3 entries: just redraw the selection box;
    * otherwise: clear the top canvas layer and hand the current geometry to
      the JS ``adjust_selection``.
    """
    canvas = base_lst[0]
    cursor = canvas.cursor
    field_count = len(event.data)
    if field_count > 3:
        console.log(event.data)
        canvas.cursor[0] = int(event.data[1])
        canvas.cursor[1] = int(event.data[2])
        canvas.selection_size_w = int(event.data[3]) // 8 * 8
        canvas.selection_size_h = int(event.data[4]) // 8 * 8
        canvas.refine_selection()
        canvas.draw_selection_box()
        return
    if field_count > 2:
        canvas.draw_selection_box()
        return
    canvas.canvas[-1].clear()
    adjust_selection(cursor[0], cursor[1], canvas.selection_size_w, canvas.selection_size_h)
340
+
341
async def eraser_func(event):
    """Show the eraser at the mouse position in eraser mode, otherwise clear its layer."""
    canvas = base_lst[0]
    if event.data[1] == "eraser":
        mouse_x, mouse_y = canvas.mouse_pos
        canvas.draw_eraser(mouse_x, mouse_y)
    else:
        canvas.canvas[-2].clear()
348
+
349
async def resize_func(event):
    """Resize the canvas to ``event.data[1]`` x ``event.data[2]``.

    Requests with either side below 256 are ignored. If the selection box
    would no longer fit the shorter canvas side it is reset to 256 x 256
    before resizing.
    """
    canvas = base_lst[0]
    new_width = int(event.data[1])
    new_height = int(event.data[2])
    if not (new_width >= 256 and new_height >= 256):
        return
    longest_selection = max(canvas.selection_size_h, canvas.selection_size_w)
    if longest_selection > min(new_width, new_height):
        canvas.selection_size_h = 256
        canvas.selection_size_w = 256
    canvas.resize(new_width, new_height)
358
+
359
async def message_func(event):
    """Dispatch a ``postMessage`` event to the matching handler.

    ``event.data[0]`` names the message kind; for ``"click"`` and ``"mode"``
    messages ``event.data[1]`` refines it. Unknown kinds are ignored.
    """
    kind = event.data[0]

    if kind == "click":
        click_handlers = {
            "clear": reset_func,
            "save": save_func,
            "export": export_func,
            "accept": commit_func,
            "cancel": undo_func,
            "zoom_in": zoom_in_func,
            "zoom_out": zoom_out_func,
        }
        handler = click_handlers.get(event.data[1])
        if handler is not None:
            await handler(event)
        return

    if kind == "mode":
        # Leaving selection mode first syncs the canvas into the buffer.
        if event.data[1] != "selection":
            await sync_func(event)
        await eraser_func(event)
        document.querySelector("#mode").value = event.data[1]
        return

    handlers = {
        "sync": sync_func,
        "load": load_func,
        "upload": upload_func,
        "outpaint": outpaint_func,
        "transfer": transfer_func,
        "setup": draw_canvas_func,
        "eraser_size": eraser_size_func,
        "resize_selection": resize_selection_func,
        "shortcut": setup_shortcut_func,
        "resize": resize_func,
    }
    handler = handlers.get(kind)
    if handler is not None:
        await handler(event)
400
+
401
# Every "message" event on this window is routed through message_func.
window.addEventListener("message",create_proxy(message_func))

import asyncio

# Top-level await: run setup_func, which sets the hidden #setup field to "1".
_ = await asyncio.gather(
    setup_func()
)
408
+ </py-script>
409
+
410
+ </body>
411
+ </html>
perlin2d.py CHANGED
@@ -1,45 +1,45 @@
1
- import numpy as np
2
-
3
- ##########
4
- # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
5
def perlin(x, y, seed=0):
    """Evaluate 2-D Perlin noise at the given coordinates.

    Args:
        x: Array of x coordinates. ``gradient`` indexes its hash argument with
            two axes, so a 2-D grid (e.g. from ``np.meshgrid``) is expected.
        y: Array of y coordinates with the same shape as ``x``.
        seed: Seed for the permutation table; equal seeds give equal noise.

    Returns:
        Array of noise values, one per input coordinate.
    """
    # permutation table, shuffled with a private generator: this yields the
    # same table as np.random.seed(seed) + np.random.shuffle(p) did, without
    # reseeding the process-wide NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)
    p = np.arange(256, dtype=int)
    rng.shuffle(p)
    # doubled so that p[p[xi] + yi + 1] stays in range
    p = np.stack([p, p]).flatten()
    # coordinates of the top-left
    xi, yi = x.astype(int), y.astype(int)
    # internal coordinates
    xf, yf = x - xi, y - yi
    # fade factors
    u, v = fade(xf), fade(yf)
    # noise components at the four cell corners
    n00 = gradient(p[p[xi] + yi], xf, yf)
    n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
    n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
    n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
    # combine noises: blend along x on both rows, then along y
    x1 = lerp(n00, n10, u)
    x2 = lerp(n01, n11, u)
    return lerp(x1, x2, v)
26
-
27
-
28
def lerp(a, b, x):
    """Linearly interpolate from ``a`` (at ``x == 0``) to ``b`` (at ``x == 1``)."""
    span = b - a
    return a + x * span
31
-
32
-
33
def fade(t):
    """Perlin's quintic ease curve, 6t^5 - 15t^4 + 10t^3."""
    quintic = 6 * t ** 5
    quartic = 15 * t ** 4
    cubic = 10 * t ** 3
    return quintic - quartic + cubic
36
-
37
-
38
def gradient(h, x, y):
    """Pick a gradient vector from the hash ``h`` and dot it with ``(x, y)``.

    Args:
        h: Integer hash value(s); ``h % 4`` selects one of four axis-aligned
            unit vectors.
        x: x offset(s) from the cell corner, broadcastable against ``h``.
        y: y offset(s) from the cell corner, broadcastable against ``h``.

    Returns:
        The dot product of the selected gradient with ``(x, y)``.
    """
    vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
    g = vectors[h % 4]
    # Ellipsis instead of two explicit axes: identical for 2-D grids, and
    # also valid for scalar, 1-D or higher-rank hashes.
    return g[..., 0] * x + g[..., 1] * y
43
-
44
-
45
  ##########
 
1
+ import numpy as np
2
+
3
+ ##########
4
+ # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
5
def perlin(x, y, seed=0):
    """Evaluate 2-D Perlin noise at the given coordinates.

    Args:
        x: Array of x coordinates. ``gradient`` indexes its hash argument with
            two axes, so a 2-D grid (e.g. from ``np.meshgrid``) is expected.
        y: Array of y coordinates with the same shape as ``x``.
        seed: Seed for the permutation table; equal seeds give equal noise.

    Returns:
        Array of noise values, one per input coordinate.
    """
    # permutation table, shuffled with a private generator: this yields the
    # same table as np.random.seed(seed) + np.random.shuffle(p) did, without
    # reseeding the process-wide NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)
    p = np.arange(256, dtype=int)
    rng.shuffle(p)
    # doubled so that p[p[xi] + yi + 1] stays in range
    p = np.stack([p, p]).flatten()
    # coordinates of the top-left
    xi, yi = x.astype(int), y.astype(int)
    # internal coordinates
    xf, yf = x - xi, y - yi
    # fade factors
    u, v = fade(xf), fade(yf)
    # noise components at the four cell corners
    n00 = gradient(p[p[xi] + yi], xf, yf)
    n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
    n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
    n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
    # combine noises: blend along x on both rows, then along y
    x1 = lerp(n00, n10, u)
    x2 = lerp(n01, n11, u)
    return lerp(x1, x2, v)
26
+
27
+
28
def lerp(a, b, x):
    """Linearly interpolate from ``a`` (at ``x == 0``) to ``b`` (at ``x == 1``)."""
    span = b - a
    return a + x * span
31
+
32
+
33
def fade(t):
    """Perlin's quintic ease curve, 6t^5 - 15t^4 + 10t^3."""
    quintic = 6 * t ** 5
    quartic = 15 * t ** 4
    cubic = 10 * t ** 3
    return quintic - quartic + cubic
36
+
37
+
38
def gradient(h, x, y):
    """Pick a gradient vector from the hash ``h`` and dot it with ``(x, y)``.

    Args:
        h: Integer hash value(s); ``h % 4`` selects one of four axis-aligned
            unit vectors.
        x: x offset(s) from the cell corner, broadcastable against ``h``.
        y: y offset(s) from the cell corner, broadcastable against ``h``.

    Returns:
        The dot product of the selected gradient with ``(x, y)``.
    """
    vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
    g = vectors[h % 4]
    # Ellipsis instead of two explicit axes: identical for 2-D grids, and
    # also valid for scalar, 1-D or higher-rank hashes.
    return g[..., 0] * x + g[..., 1] * y
43
+
44
+
45
  ##########
postprocess.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import time
27
+ import argparse
28
+ import os
29
+ import fpie
30
+ from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
31
+ from fpie.io import read_images, write_image
32
+ from process import BaseProcessor, EquProcessor, GridProcessor
33
+
34
+ from PIL import Image
35
+ import numpy as np
36
+ import skimage
37
+ import skimage.measure
38
+ import scipy
39
+ import scipy.signal
40
+
41
+
42
class PhotometricCorrection:
    """Blend an inpainted image with its source using a Poisson image editing solver.

    The solver is a ``GridProcessor`` configured from the argument parser
    built by :meth:`get_parser`, parsed with a fixed option list.
    """

    def __init__(self,quite=False):
        """Parse the fixed solver options and create the grid processor.

        Args:
            quite: Stored as ``self.quite``; nothing in this class reads it.
        """
        self.get_parser("cli")
        # "-s/-t/-o a" are placeholders; images are passed to run() directly.
        args = self.parser.parse_args(
            ["--method", "grid", "-g", "src", "-s", "a", "-t", "a", "-o", "a"]
        )
        # --mpi-sync-interval is only registered when an MPI backend is listed.
        args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
        self.backend = args.backend
        self.args = args
        self.quite = quite
        proc: BaseProcessor
        proc = GridProcessor(
            args.gradient,
            args.backend,
            args.cpu,
            args.mpi_sync_interval,
            args.block_size,
            args.grid_x,
            args.grid_y,
        )
        print(
            f"[PIE]Successfully initialize PIE {args.method} solver "
            f"with {args.backend} backend"
        )
        self.proc = proc

    def run(self, original_image, inpainted_image, mode="mask_mode"):
        """Blend ``inpainted_image`` into ``original_image``.

        Args:
            original_image: Image convertible with ``np.array``; its last
                channel is read as the mask of known pixels (non-zero = known).
            inpainted_image: Image convertible with ``np.array``; its first
                three channels supply the generated pixels.
            mode: ``"disabled"`` returns ``inpainted_image`` untouched;
                ``"mask_mode"`` solves over the unknown region grown to 8x8
                blocks; any other value solves over the area inside a border.

        Returns:
            ``inpainted_image`` unchanged when the mode is ``"disabled"`` or
            the last channel of ``original_image`` is entirely zero; otherwise
            a new ``PIL.Image`` built from the solver result.
        """
        print("[PIE] start")
        if mode == "disabled":
            return inpainted_image
        input_arr = np.array(original_image)
        if input_arr[:, :, -1].sum() < 1:
            # Nothing is known in the original, so there is nothing to blend with.
            return inpainted_image
        output_arr = np.array(inpainted_image)
        # Invert: non-zero now marks the pixels that were generated.
        mask = 255 - input_arr[:, :, -1]
        if mask.sum() < 1 and mode == "mask_mode":
            # No generated pixel at all: fall through to the border-based mask.
            mode = ""
        if mode == "mask_mode":
            # Grow the mask to whole 8x8 blocks.
            mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
            mask = mask.repeat(8, axis=0).repeat(8, axis=1)
        else:
            mask[8:-9, 8:-9] = 255
        mask = mask[:, :, np.newaxis].repeat(3, axis=2)
        nmask = mask.copy()
        # Composite: generated pixels where the mask is set, original elsewhere.
        output_arr2 = output_arr[:, :, 0:3].copy()
        input_arr2 = input_arr[:, :, 0:3].copy()
        output_arr2[nmask < 128] = 0
        input_arr2[nmask >= 128] = 0
        output_arr2 += input_arr2
        src = output_arr2[:, :, 0:3]
        tgt = src.copy()
        proc = self.proc
        args = self.args
        if proc.root:
            proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
        proc.sync()
        # Bound up front so the return below is defined on every path.
        result = tgt
        started = time.time()
        # p == 0 means "run all iterations in one step"; kept local so the
        # shared argument namespace is not rewritten on every call.
        step_size = args.n if args.p == 0 else args.p

        for i in range(0, args.n, step_size):
            if proc.root:
                result, err = proc.step(step_size)  # type: ignore
                print(f"[PIE] Iter {i + step_size}, abs_err {err}")
            else:
                proc.step(step_size)

        if proc.root:
            dt = time.time() - started
            print(f"[PIE] Time elapsed: {dt:.4f}s")
        # make sure consistent with dummy process
        return Image.fromarray(result)

    def get_parser(self, gen_type: str) -> argparse.ArgumentParser:
        """Build the solver argument parser and store it as ``self.parser``.

        Args:
            gen_type: ``"cli"`` adds the mask, position and ``-p`` options;
                ``"gui"`` leaves them out and excludes the ``mpi`` backend.

        Returns:
            The parser that was stored on ``self.parser``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-v", "--version", action="store_true", help="show the version and exit"
        )
        parser.add_argument(
            "--check-backend", action="store_true", help="print all available backends"
        )
        # Work on a copy so the module-level backend list is never mutated.
        backends = list(ALL_BACKEND)
        if gen_type == "gui" and "mpi" in backends:
            # gui doesn't support MPI backend
            backends.remove("mpi")
        parser.add_argument(
            "-b",
            "--backend",
            type=str,
            choices=backends,
            default=DEFAULT_BACKEND,
            help="backend choice",
        )
        parser.add_argument(
            "-c",
            "--cpu",
            type=int,
            default=CPU_COUNT,
            help="number of CPU used",
        )
        parser.add_argument(
            "-z",
            "--block-size",
            type=int,
            default=1024,
            help="cuda block size (only for equ solver)",
        )
        parser.add_argument(
            "--method",
            type=str,
            choices=["equ", "grid"],
            default="equ",
            help="how to parallelize computation",
        )
        parser.add_argument("-s", "--source", type=str, help="source image filename")
        if gen_type == "cli":
            parser.add_argument(
                "-m",
                "--mask",
                type=str,
                help="mask image filename (default is to use the whole source image)",
                default="",
            )
        parser.add_argument("-t", "--target", type=str, help="target image filename")
        parser.add_argument("-o", "--output", type=str, help="output image filename")
        if gen_type == "cli":
            parser.add_argument(
                "-h0", type=int, help="mask position (height) on source image", default=0
            )
            parser.add_argument(
                "-w0", type=int, help="mask position (width) on source image", default=0
            )
            parser.add_argument(
                "-h1", type=int, help="mask position (height) on target image", default=0
            )
            parser.add_argument(
                "-w1", type=int, help="mask position (width) on target image", default=0
            )
        parser.add_argument(
            "-g",
            "--gradient",
            type=str,
            choices=["max", "src", "avg"],
            default="max",
            help="how to calculate gradient for PIE",
        )
        parser.add_argument(
            "-n",
            type=int,
            help="how many iteration would you perfer, the more the better",
            default=5000,
        )
        if gen_type == "cli":
            parser.add_argument(
                "-p", type=int, help="output result every P iteration", default=0
            )
        if "mpi" in backends:
            parser.add_argument(
                "--mpi-sync-interval",
                type=int,
                help="MPI sync iteration interval",
                default=100,
            )
        parser.add_argument(
            "--grid-x", type=int, help="x axis stride for grid solver", default=8
        )
        parser.add_argument(
            "--grid-y", type=int, help="y axis stride for grid solver", default=8
        )
        self.parser = parser
        return parser
213
+
214
if __name__ =="__main__":
    # Subprocess mode: read one request per line on stdin, answer on stdout.
    # A request is "<original b64>,<inpainted b64>,<mode>"; the reply is the
    # corrected image as one line of base64-encoded PNG.
    import sys
    import io
    import base64
    from PIL import Image

    def base64_to_pil(base64_str):
        """Decode base64 text into a PIL image."""
        data = base64.b64decode(str(base64_str))
        return Image.open(io.BytesIO(data))

    def pil_to_base64(out_pil):
        """Encode a PIL image as PNG and return it as ASCII base64 text."""
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        return base64.b64encode(out_buffer.getvalue()).decode("ascii")

    correction_func=PhotometricCorrection(quite=True)
    while True:
        buffer = sys.stdin.readline()
        print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ")
        if len(buffer)==0:
            # EOF: the parent closed the pipe.
            break
        # Text-mode stdin yields str. The old bytes branch replied with
        # bytes on the text-mode stdout, which raises TypeError; decode once
        # here and always reply with str.
        if isinstance(buffer, bytes):
            buffer = buffer.decode("ascii")
        lst = buffer.strip().split(",")
        img0 = base64_to_pil(lst[0])
        img1 = base64_to_pil(lst[1])
        ret = correction_func.run(img0, img1, mode=lst[2])
        ret_base64 = pil_to_base64(ret)
        sys.stdout.write(f"{ret_base64}\n")
        sys.stdout.flush()
process.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+ import os
26
+ from abc import ABC, abstractmethod
27
+ from typing import Any, Optional, Tuple
28
+
29
+ import numpy as np
30
+
31
+ from fpie import np_solver
32
+
33
+ import scipy
34
+ import scipy.signal
35
+
36
CPU_COUNT = os.cpu_count() or 1
DEFAULT_BACKEND = "numpy"
ALL_BACKEND = ["numpy"]

# Optional backends: each import that succeeds registers its backend and
# becomes the new default, so the last successful block below wins.
try:
    from fpie import numba_solver
    ALL_BACKEND += ["numba"]
    DEFAULT_BACKEND = "numba"
except ImportError:
    numba_solver = None  # type: ignore

try:
    from fpie import taichi_solver
    ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
    DEFAULT_BACKEND = "taichi-cpu"
except ImportError:
    taichi_solver = None  # type: ignore

# The gcc, openmp and mpi backends are disabled (import blocks kept below for
# reference). Their module handles are still bound, to None, because the
# processor constructors test them; left unbound they raise NameError instead
# of reporting an invalid backend.
core_gcc = None
core_openmp = None
core_mpi = None
MPI = None  # type: ignore

# try:
#     from fpie import core_gcc  # type: ignore
#     DEFAULT_BACKEND = "gcc"
#     ALL_BACKEND.append("gcc")
# except ImportError:
#     core_gcc = None

# try:
#     from fpie import core_openmp  # type: ignore
#     DEFAULT_BACKEND = "openmp"
#     ALL_BACKEND.append("openmp")
# except ImportError:
#     core_openmp = None

# try:
#     from mpi4py import MPI

#     from fpie import core_mpi  # type: ignore
#     ALL_BACKEND.append("mpi")
# except ImportError:
#     MPI = None  # type: ignore
#     core_mpi = None

try:
    from fpie import core_cuda  # type: ignore
    DEFAULT_BACKEND = "cuda"
    ALL_BACKEND.append("cuda")
except ImportError:
    core_cuda = None
83
+
84
+
85
class BaseProcessor(ABC):
    """API definition for processor class."""

    def __init__(
        self, gradient: str, rank: int, backend: str, core: Optional[Any]
    ):
        """Store the solver configuration.

        Args:
            gradient: Gradient mixing mode used by :meth:`mixgrad`
                (``"src"``, ``"avg"``, anything else means "max").
            rank: Process rank; rank 0 is the root.
            backend: Backend name, used for error reporting.
            core: Backend solver object, or ``None`` if it could not be created.

        Raises:
            AssertionError: If ``core`` is ``None``.
        """
        if core is None:
            error_msg = {
                "numpy":
                    "Please run `pip install numpy`.",
                "numba":
                    "Please run `pip install numba`.",
                "gcc":
                    "Please install cmake and gcc in your operating system.",
                "openmp":
                    "Please make sure your gcc is compatible with `-fopenmp` option.",
                "mpi":
                    "Please install MPI and run `pip install mpi4py`.",
                "cuda":
                    "Please make sure nvcc and cuda-related libraries are available.",
                "taichi":
                    "Please run `pip install taichi`.",
            }
            # .get(): an unrecognised backend name must still reach the
            # AssertionError below rather than fail with KeyError here.
            hint = error_msg.get(backend.split("-")[0])
            if hint is not None:
                print(hint)

            raise AssertionError(f"Invalid backend {backend}.")

        self.gradient = gradient
        self.rank = rank
        self.backend = backend
        self.core = core
        self.root = rank == 0

    def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Combine source gradient ``a`` and target gradient ``b``.

        Returns ``a`` for mode ``"src"``, the mean for ``"avg"``, and otherwise
        the element-wise value with the larger magnitude (``a`` on ties).
        The inputs are not modified.
        """
        if self.gradient == "src":
            return a
        if self.gradient == "avg":
            return (a + b) / 2
        # mix gradient, see Equ. 12 in PIE paper
        return np.where(np.abs(a) < np.abs(b), b, a)

    @abstractmethod
    def reset(
        self,
        src: np.ndarray,
        mask: np.ndarray,
        tgt: np.ndarray,
        mask_on_src: Tuple[int, int],
        mask_on_tgt: Tuple[int, int],
    ) -> int:
        """Load a new source/mask/target problem into the solver."""
        pass

    def sync(self) -> None:
        """Delegate to the backend solver's ``sync``."""
        self.core.sync()

    @abstractmethod
    def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Run ``iteration`` solver iterations."""
        pass
145
+
146
+
147
+ class EquProcessor(BaseProcessor):
148
+ """PIE Jacobi equation processor."""
149
+
150
+ def __init__(
151
+ self,
152
+ gradient: str = "max",
153
+ backend: str = DEFAULT_BACKEND,
154
+ n_cpu: int = CPU_COUNT,
155
+ min_interval: int = 100,
156
+ block_size: int = 1024,
157
+ ):
158
+ core: Optional[Any] = None
159
+ rank = 0
160
+
161
+ if backend == "numpy":
162
+ core = np_solver.EquSolver()
163
+ elif backend == "numba" and numba_solver is not None:
164
+ core = numba_solver.EquSolver()
165
+ elif backend == "gcc":
166
+ core = core_gcc.EquSolver()
167
+ elif backend == "openmp" and core_openmp is not None:
168
+ core = core_openmp.EquSolver(n_cpu)
169
+ elif backend == "mpi" and core_mpi is not None:
170
+ core = core_mpi.EquSolver(min_interval)
171
+ rank = MPI.COMM_WORLD.Get_rank()
172
+ elif backend == "cuda" and core_cuda is not None:
173
+ core = core_cuda.EquSolver(block_size)
174
+ elif backend.startswith("taichi") and taichi_solver is not None:
175
+ core = taichi_solver.EquSolver(backend, n_cpu, block_size)
176
+
177
+ super().__init__(gradient, rank, backend, core)
178
+
179
+ def mask2index(
180
+ self, mask: np.ndarray
181
+ ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
182
+ x, y = np.nonzero(mask)
183
+ max_id = x.shape[0] + 1
184
+ index = np.zeros((max_id, 3))
185
+ ids = self.core.partition(mask)
186
+ ids[mask == 0] = 0 # reserve id=0 for constant
187
+ index = ids[x, y].argsort()
188
+ return ids, max_id, x[index], y[index]
189
+
190
+ def reset(
191
+ self,
192
+ src: np.ndarray,
193
+ mask: np.ndarray,
194
+ tgt: np.ndarray,
195
+ mask_on_src: Tuple[int, int],
196
+ mask_on_tgt: Tuple[int, int],
197
+ ) -> int:
198
+ assert self.root
199
+ # check validity
200
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
201
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
202
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
203
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
204
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
205
+
206
+ if len(mask.shape) == 3:
207
+ mask = mask.mean(-1)
208
+ mask = (mask >= 128).astype(np.int32)
209
+
210
+ # zero-out edge
211
+ mask[0] = 0
212
+ mask[-1] = 0
213
+ mask[:, 0] = 0
214
+ mask[:, -1] = 0
215
+
216
+ x, y = np.nonzero(mask)
217
+ x0, x1 = x.min() - 1, x.max() + 2
218
+ y0, y1 = y.min() - 1, y.max() + 2
219
+ mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
220
+ mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
221
+ mask = mask[x0:x1, y0:y1]
222
+ ids, max_id, index_x, index_y = self.mask2index(mask)
223
+
224
+ src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
225
+ tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]
226
+
227
+ src_C = src[src_x, src_y].astype(np.float32)
228
+ src_U = src[src_x - 1, src_y].astype(np.float32)
229
+ src_D = src[src_x + 1, src_y].astype(np.float32)
230
+ src_L = src[src_x, src_y - 1].astype(np.float32)
231
+ src_R = src[src_x, src_y + 1].astype(np.float32)
232
+ tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
233
+ tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
234
+ tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
235
+ tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
236
+ tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)
237
+
238
+ grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
239
+ + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
240
+ + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
241
+ + self.mixgrad(src_C - src_D, tgt_C - tgt_D)
242
+
243
+ A = np.zeros((max_id, 4), np.int32)
244
+ X = np.zeros((max_id, 3), np.float32)
245
+ B = np.zeros((max_id, 3), np.float32)
246
+
247
+ X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
248
+ # four-way
249
+ A[1:, 0] = ids[index_x - 1, index_y]
250
+ A[1:, 1] = ids[index_x + 1, index_y]
251
+ A[1:, 2] = ids[index_x, index_y - 1]
252
+ A[1:, 3] = ids[index_x, index_y + 1]
253
+ B[1:] = grad
254
+ m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
255
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
256
+ m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
257
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
258
+ m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
259
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
260
+ m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
261
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]
262
+
263
+ self.tgt = tgt.copy()
264
+ self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
265
+ self.core.reset(max_id, A, X, B)
266
+ return max_id
267
+
268
def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Run the backend solver and write its solution into the target image.

    Args:
        iteration: Forwarded unchanged to ``self.core.step``.

    Returns:
        ``(self.tgt, err)`` on the root rank, ``None`` on every other rank.
    """
    outcome = self.core.step(iteration)
    if not self.root:
        return None
    solved, err = outcome
    # Row 0 of the solver output is not a pixel (``reset`` fills the
    # unknowns from index 1 onward), so it is skipped when writing back.
    self.tgt[self.tgt_index] = solved[1:]
    return self.tgt, err
275
+
276
+
277
class GridProcessor(BaseProcessor):
    """PIE grid processor.

    Works on the dense bounding box of the mask: ``reset`` crops the mask,
    the target and the guidance gradient to that box and hands them to the
    backend solver; ``step`` pastes the solver output back into a copy of
    the target image.
    """

    def __init__(
        self,
        gradient: str = "max",
        backend: str = DEFAULT_BACKEND,
        n_cpu: int = CPU_COUNT,
        min_interval: int = 100,
        block_size: int = 1024,
        grid_x: int = 8,
        grid_y: int = 8,
    ):
        """Pick the grid solver implementation that matches ``backend``.

        Args:
            gradient: Gradient mixing mode, forwarded to the base class.
            backend: Solver backend name.
            n_cpu: Forwarded to the openmp and taichi solvers.
            min_interval: Forwarded to the mpi solver.
            block_size: Forwarded to the taichi solver.
            grid_x: Forwarded to the gcc, openmp, cuda and taichi solvers.
            grid_y: Forwarded to the gcc, openmp, cuda and taichi solvers.

        ``core`` stays ``None`` when the backend name matches no branch or
        its optional module is ``None``; it is handed to the base class
        unchanged in that case.
        """
        core: Optional[Any] = None
        rank = 0

        if backend == "numpy":
            core = np_solver.GridSolver()
        elif backend == "numba" and numba_solver is not None:
            core = numba_solver.GridSolver()
        elif backend == "gcc":
            core = core_gcc.GridSolver(grid_x, grid_y)
        elif backend == "openmp" and core_openmp is not None:
            core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
        elif backend == "mpi" and core_mpi is not None:
            core = core_mpi.GridSolver(min_interval)
            rank = MPI.COMM_WORLD.Get_rank()
        elif backend == "cuda" and core_cuda is not None:
            core = core_cuda.GridSolver(grid_x, grid_y)
        elif backend.startswith("taichi") and taichi_solver is not None:
            core = taichi_solver.GridSolver(
                grid_x, grid_y, backend, n_cpu, block_size
            )

        super().__init__(gradient, rank, backend, core)

    def reset(
        self,
        src: np.ndarray,
        mask: np.ndarray,
        tgt: np.ndarray,
        mask_on_src: Tuple[int, int],
        mask_on_tgt: Tuple[int, int],
    ) -> int:
        """Prepare the solver for a new source / mask / target triple.

        Args:
            src: Source image; indexed as ``src[row, col]``.
            mask: Blend mask. A 3-D mask is averaged over its last axis;
                values ``>= 128`` select pixels to solve for.
            tgt: Target image; indexed as ``tgt[row, col]``.
            mask_on_src: Offset of the mask origin inside ``src``.
            mask_on_tgt: Offset of the mask origin inside ``tgt``.

        Returns:
            Number of cells in the cropped solver grid.

        Raises:
            ValueError: If no mask pixel is selected once the one-pixel
                border of the mask has been cleared.
        """
        assert self.root
        # check validity
        # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
        # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
        # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
        # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
        # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]

        if len(mask.shape) == 3:
            mask = mask.mean(-1)
        mask = (mask >= 128).astype(np.int32)

        # zero-out edge so every selected pixel has all four neighbours
        mask[0] = 0
        mask[-1] = 0
        mask[:, 0] = 0
        mask[:, -1] = 0

        x, y = np.nonzero(mask)
        if x.size == 0:
            # min()/max() on an empty array would raise an opaque ValueError.
            raise ValueError("mask selects no interior pixel, nothing to blend")
        # Bounding box of the selection, grown by one pixel on every side.
        x0, x1 = x.min() - 1, x.max() + 2
        y0, y1 = y.min() - 1, y.max() + 2
        mask = mask[x0:x1, y0:y1]
        max_id = int(np.prod(mask.shape))

        src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
                       mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
        tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
                       mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)

        # Sum, per pixel, the mixed differences towards its four neighbours.
        grad = np.zeros([*mask.shape, 3], np.float32)
        grad[1:] += self.mixgrad(
            src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
        )
        grad[:-1] += self.mixgrad(
            src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
        )
        grad[:, 1:] += self.mixgrad(
            src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
        )
        grad[:, :-1] += self.mixgrad(
            src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
        )

        grad[mask == 0] = 0

        # Also drop the gradient on selected pixels that touch an unselected
        # pixel in their 3x3 neighbourhood (a full neighbourhood sums to 9).
        kernel = [[1] * 3 for _ in range(3)]
        neighbours = scipy.signal.convolve2d(
            mask, kernel, mode="same", boundary="fill", fillvalue=1
        )
        grad[(mask > 0) & (neighbours != 9)] = 0

        # Remember where the crop sits inside the target for step().
        self.x0 = mask_on_tgt[0] + x0
        self.x1 = mask_on_tgt[0] + x1
        self.y0 = mask_on_tgt[1] + y0
        self.y1 = mask_on_tgt[1] + y1
        self.tgt = tgt.copy()
        self.core.reset(max_id, mask, tgt_crop, grad)
        return max_id

    def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Run the backend solver and paste its output into the target copy.

        Returns:
            ``(self.tgt, err)`` on the root rank, ``None`` on other ranks.
        """
        result = self.core.step(iteration)
        if self.root:
            tgt, err = result
            self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
            return self.tgt, err
        return None
utils.py CHANGED
@@ -1,151 +1,263 @@
1
- from PIL import Image
2
- from PIL import ImageFilter
3
- import cv2
4
- import numpy as np
5
- import scipy
6
- import scipy.signal
7
- from scipy.spatial import cKDTree
8
-
9
- import os
10
- from perlin2d import *
11
-
12
# Optional PatchMatch binding: only attempted off Windows, first from the
# PyPatchMatch package and then as a top-level module.
# NOTE(review): the fallback ``import patch_match`` is not guarded, so a
# failure there propagates out of this module's import.
patch_match_compiled = True
if os.name != "nt":
    try:
        from PyPatchMatch import patch_match
    except Exception:
        import patch_match

# The name is still unbound when no import ran (Windows).
if "patch_match" not in globals():
    print("patch_match compiling failed")
    patch_match_compiled = False
24
-
25
-
26
-
27
-
28
def edge_pad(img, mask, mode=1):
    """Fill the pixels where ``mask`` is zero by extending the known content.

    Args:
        img: Image array indexed as ``img[row, col]``. The fallback branch
            pads three axes, so it needs a ``(H, W, C)`` array.
        mask: 2-D array with the same height/width as ``img``; non-zero
            marks pixels whose content is known.
        mode: ``0`` copies the nearest known pixel into every empty pixel
            (``img`` is modified in place); ``1`` grows the content outward
            one 4-connected ring at a time, averaging the contributions
            that reach a pixel during the same ring, and returns a new
            ``uint8`` array; any other value edge-pads the bounding box of
            the known region back to the full mask size.

    Returns:
        ``(img, mask)``; ``mask`` is returned unchanged.
    """
    if mode == 0:
        known = mask.copy()
        known[known > 0] = 1
        holes = np.stack((1 - known).nonzero(), axis=0).transpose()
        sources = np.stack(known.nonzero(), axis=0).transpose()
        _, nearest = cKDTree(sources).query(holes, 1)
        picked = sources[nearest]
        # Holes and sources are disjoint, so one vectorised copy gives the
        # same result as assigning pixel by pixel.
        img[holes[:, 0], holes[:, 1]] = img[picked[:, 0], picked[:, 1]]
    elif mode == 1:
        record = {}
        kernel = [[1] * 3 for _ in range(3)]
        nmask = mask.copy()
        nmask[nmask > 0] = 1
        res = scipy.signal.convolve2d(
            nmask, kernel, mode="same", boundary="fill", fillvalue=1
        )
        # Keep only known pixels whose 3x3 neighbourhood is not fully known:
        # these border pixels seed the flood.
        res[nmask < 1] = 0
        res[res == 9] = 0
        res[res > 0] = 1
        ylst, xlst = res.nonzero()
        queue = list(zip(ylst, xlst))
        # bfs here
        cnt = res.astype(np.float32)
        acc = img.astype(np.float32)
        step = 1
        h = acc.shape[0]
        w = acc.shape[1]
        offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        while queue:
            target = []
            for y, x in queue:
                val = acc[y][x]
                for yo, xo in offset:
                    yn = y + yo
                    xn = x + xo
                    if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
                        # Only pixels first reached during this ring (or not
                        # reached yet) accept further contributions.
                        if record.get((yn, xn), step) == step:
                            # Running mean; cnt starts at 0 for empty pixels,
                            # so the first contribution replaces the value.
                            acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
                            cnt[yn][xn] += 1
                            acc[yn][xn] /= cnt[yn][xn]
                            if (yn, xn) not in record:
                                record[(yn, xn)] = step
                                target.append((yn, xn))
            step += 1
            queue = target
        img = acc.astype(np.uint8)
    else:
        nmask = mask.copy()
        ylst, xlst = nmask.nonzero()
        yt, xt = ylst.min(), xlst.min()
        yb, xb = ylst.max(), xlst.max()
        content = img[yt : yb + 1, xt : xb + 1]
        img = np.pad(
            content,
            ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
            mode="edge",
        )
    return img, mask
90
-
91
-
92
def perlin_noise(img, mask):
    """Replace the pixels where ``mask`` is zero with Perlin noise.

    Args:
        img: ``(H, W, 3)`` image array.
        mask: 2-D array; non-zero marks pixels to keep from ``img``.

    Returns:
        ``(img, mask)`` where ``img`` is a new array and ``mask`` is the
        unchanged input.
    """
    # NOTE(review): both grid axes use mask.shape[0] samples, so the noise
    # field is square; a non-square mask would not broadcast below.
    lin = np.linspace(0, 5, mask.shape[0], endpoint=False)
    x, y = np.meshgrid(lin, lin)
    # perlin() output is shifted/scaled by (v + 1) * 0.5 * 255 before the cast.
    noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for _ in range(3)]
    noise = np.stack(noise, axis=-1)
    nmask = mask.copy()
    nmask[mask > 0] = 1
    # nmask is 0/1, so this selects img where known and noise elsewhere.
    img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
    return img, mask
110
-
111
-
112
def gaussian_noise(img, mask):
    """Replace the pixels where ``mask`` is zero with Gaussian noise.

    Args:
        img: ``(H, W, 3)`` image array.
        mask: 2-D array; non-zero marks pixels to keep from ``img``.

    Returns:
        ``(img, mask)`` where ``img`` is a new array and ``mask`` is the
        unchanged input.
    """
    noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
    # (n + 1) / 2 * 255 leaves [0, 255] whenever |n| > 1; clip first because
    # casting out-of-range floats to uint8 is not well defined.
    noise = np.clip((noise + 1) / 2 * 255, 0, 255).astype(np.uint8)
    nmask = mask.copy()
    nmask[mask > 0] = 1
    # nmask is 0/1, so this selects img where known and noise elsewhere.
    img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
    return img, mask
120
-
121
-
122
def cv2_telea(img, mask):
    """Inpaint with OpenCV's TELEA method (radius 5).

    ``255 - mask`` is passed as the inpaint mask, so the pixels where
    ``mask`` is 0 are the ones filled (assumes a 0/255 uint8 mask - TODO
    confirm against callers).

    Returns:
        ``(inpainted image, mask)``.
    """
    hole_mask = 255 - mask
    return cv2.inpaint(img, hole_mask, 5, cv2.INPAINT_TELEA), mask
125
-
126
-
127
def cv2_ns(img, mask):
    """Inpaint with OpenCV's Navier-Stokes method (radius 5).

    ``255 - mask`` is passed as the inpaint mask, so the pixels where
    ``mask`` is 0 are the ones filled (assumes a 0/255 uint8 mask - TODO
    confirm against callers).

    Returns:
        ``(inpainted image, mask)``.
    """
    hole_mask = 255 - mask
    return cv2.inpaint(img, hole_mask, 5, cv2.INPAINT_NS), mask
130
-
131
-
132
def patch_match_func(img, mask):
    """Inpaint with PatchMatch using 3x3 patches.

    ``255 - mask`` is passed as the PatchMatch mask (assumes a 0/255 uint8
    mask - TODO confirm against callers).

    Returns:
        ``(inpainted image, mask)``.
    """
    hole_mask = 255 - mask
    return patch_match.inpaint(img, mask=hole_mask, patch_size=3), mask
135
-
136
-
137
def mean_fill(img, mask):
    """Overwrite, in place, the pixels where ``mask < 1`` with the image mean.

    The mean is taken over the whole image (empty pixels included), first
    along the rows and then along the columns.

    Returns:
        ``(img, mask)``; ``img`` is the same array that was passed in.
    """
    column_means = img.mean(axis=0)
    fill_value = column_means.mean(axis=0)
    empty = mask < 1
    img[empty] = fill_value
    return img, mask
141
-
142
-
143
# Lookup table from fill-mode name to the function that implements it.
# Every entry takes (img, mask) and returns (img, mask).
functbl = {
    "gaussian": gaussian_noise,
    "perlin": perlin_noise,
    "edge_pad": edge_pad,
    # Falls back to edge_pad on Windows or when patch_match failed to import.
    "patchmatch": patch_match_func if (os.name != "nt" and patch_match_compiled) else edge_pad,
    "cv2_ns": cv2_ns,
    "cv2_telea": cv2_telea,
    "mean_fill": mean_fill,
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from PIL import ImageFilter
3
+ import cv2
4
+ import numpy as np
5
+ import scipy
6
+ import scipy.signal
7
+ from scipy.spatial import cKDTree
8
+
9
+ import os
10
+ from perlin2d import *
11
+
12
# Optional PatchMatch binding: try the PyPatchMatch package first, then a
# top-level patch_match module. When neither imports, the name stays unbound
# and the flag tells functbl to use edge_pad instead.
patch_match_compiled = True

try:
    from PyPatchMatch import patch_match
except Exception:
    try:
        import patch_match
    except Exception:
        print("patch_match compiling failed, will fall back to edge_pad")
        patch_match_compiled = False
27
+
28
+
29
+
30
+
31
def edge_pad(img, mask, mode=1):
    """Fill the pixels where ``mask`` is zero by extending the known content.

    Args:
        img: Image array indexed as ``img[row, col]``. The fallback branch
            pads three axes, so it needs a ``(H, W, C)`` array.
        mask: 2-D array with the same height/width as ``img``; non-zero
            marks pixels whose content is known.
        mode: ``0`` copies the nearest known pixel into every empty pixel
            (``img`` is modified in place); ``1`` grows the content outward
            one 4-connected ring at a time, averaging the contributions
            that reach a pixel during the same ring, and returns a new
            ``uint8`` array; any other value edge-pads the bounding box of
            the known region back to the full mask size.

    Returns:
        ``(img, mask)``; ``mask`` is returned unchanged.
    """
    if mode == 0:
        known = mask.copy()
        known[known > 0] = 1
        holes = np.stack((1 - known).nonzero(), axis=0).transpose()
        sources = np.stack(known.nonzero(), axis=0).transpose()
        _, nearest = cKDTree(sources).query(holes, 1)
        picked = sources[nearest]
        # Holes and sources are disjoint, so one vectorised copy gives the
        # same result as assigning pixel by pixel.
        img[holes[:, 0], holes[:, 1]] = img[picked[:, 0], picked[:, 1]]
    elif mode == 1:
        record = {}
        kernel = [[1] * 3 for _ in range(3)]
        nmask = mask.copy()
        nmask[nmask > 0] = 1
        res = scipy.signal.convolve2d(
            nmask, kernel, mode="same", boundary="fill", fillvalue=1
        )
        # Keep only known pixels whose 3x3 neighbourhood is not fully known:
        # these border pixels seed the flood.
        res[nmask < 1] = 0
        res[res == 9] = 0
        res[res > 0] = 1
        ylst, xlst = res.nonzero()
        queue = list(zip(ylst, xlst))
        # bfs here
        cnt = res.astype(np.float32)
        acc = img.astype(np.float32)
        step = 1
        h = acc.shape[0]
        w = acc.shape[1]
        offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        while queue:
            target = []
            for y, x in queue:
                val = acc[y][x]
                for yo, xo in offset:
                    yn = y + yo
                    xn = x + xo
                    if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
                        # Only pixels first reached during this ring (or not
                        # reached yet) accept further contributions.
                        if record.get((yn, xn), step) == step:
                            # Running mean; cnt starts at 0 for empty pixels,
                            # so the first contribution replaces the value.
                            acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
                            cnt[yn][xn] += 1
                            acc[yn][xn] /= cnt[yn][xn]
                            if (yn, xn) not in record:
                                record[(yn, xn)] = step
                                target.append((yn, xn))
            step += 1
            queue = target
        img = acc.astype(np.uint8)
    else:
        nmask = mask.copy()
        ylst, xlst = nmask.nonzero()
        yt, xt = ylst.min(), xlst.min()
        yb, xb = ylst.max(), xlst.max()
        content = img[yt : yb + 1, xt : xb + 1]
        img = np.pad(
            content,
            ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
            mode="edge",
        )
    return img, mask
93
+
94
+
95
def perlin_noise(img, mask):
    """Replace the pixels where ``mask`` is zero with Perlin noise.

    Args:
        img: ``(H, W, 3)`` image array.
        mask: 2-D array; non-zero marks pixels to keep from ``img``.

    Returns:
        ``(img, mask)`` where ``img`` is a new array and ``mask`` is the
        unchanged input.
    """
    # NOTE(review): both grid axes use mask.shape[0] samples, so the noise
    # field is square; a non-square mask would not broadcast below.
    lin = np.linspace(0, 5, mask.shape[0], endpoint=False)
    x, y = np.meshgrid(lin, lin)
    # perlin() output is shifted/scaled by (v + 1) * 0.5 * 255 before the cast.
    noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for _ in range(3)]
    noise = np.stack(noise, axis=-1)
    nmask = mask.copy()
    nmask[mask > 0] = 1
    # nmask is 0/1, so this selects img where known and noise elsewhere.
    img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
    return img, mask
113
+
114
+
115
def gaussian_noise(img, mask):
    """Replace the pixels where ``mask`` is zero with Gaussian noise.

    Args:
        img: ``(H, W, 3)`` image array.
        mask: 2-D array; non-zero marks pixels to keep from ``img``.

    Returns:
        ``(img, mask)`` where ``img`` is a new array and ``mask`` is the
        unchanged input.
    """
    noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
    # (n + 1) / 2 * 255 leaves [0, 255] whenever |n| > 1; clip first because
    # casting out-of-range floats to uint8 is not well defined.
    noise = np.clip((noise + 1) / 2 * 255, 0, 255).astype(np.uint8)
    nmask = mask.copy()
    nmask[mask > 0] = 1
    # nmask is 0/1, so this selects img where known and noise elsewhere.
    img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
    return img, mask
123
+
124
+
125
def cv2_telea(img, mask):
    """Inpaint with OpenCV's TELEA method (radius 5).

    ``255 - mask`` is passed as the inpaint mask, so the pixels where
    ``mask`` is 0 are the ones filled (assumes a 0/255 uint8 mask - TODO
    confirm against callers).

    Returns:
        ``(inpainted image, mask)``.
    """
    hole_mask = 255 - mask
    return cv2.inpaint(img, hole_mask, 5, cv2.INPAINT_TELEA), mask
128
+
129
+
130
def cv2_ns(img, mask):
    """Inpaint with OpenCV's Navier-Stokes method (radius 5).

    ``255 - mask`` is passed as the inpaint mask, so the pixels where
    ``mask`` is 0 are the ones filled (assumes a 0/255 uint8 mask - TODO
    confirm against callers).

    Returns:
        ``(inpainted image, mask)``.
    """
    hole_mask = 255 - mask
    return cv2.inpaint(img, hole_mask, 5, cv2.INPAINT_NS), mask
133
+
134
+
135
def patch_match_func(img, mask):
    """Inpaint with PatchMatch using 3x3 patches.

    ``255 - mask`` is passed as the PatchMatch mask (assumes a 0/255 uint8
    mask - TODO confirm against callers).

    Returns:
        ``(inpainted image, mask)``.
    """
    hole_mask = 255 - mask
    return patch_match.inpaint(img, mask=hole_mask, patch_size=3), mask
138
+
139
+
140
def mean_fill(img, mask):
    """Overwrite, in place, the pixels where ``mask < 1`` with the image mean.

    The mean is taken over the whole image (empty pixels included), first
    along the rows and then along the columns.

    Returns:
        ``(img, mask)``; ``img`` is the same array that was passed in.
    """
    column_means = img.mean(axis=0)
    fill_value = column_means.mean(axis=0)
    empty = mask < 1
    img[empty] = fill_value
    return img, mask
144
+
145
+ """
146
+ Apache-2.0 license
147
+ https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py
148
+ https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2
149
+ _handleImageAdjustment
150
+ """
151
from sd_grpcserver.sdgrpcserver import images
import torch
from math import sqrt


def handleImageAdjustment(array, adjustments):
    """Apply a sequence of adjustments to an image array.

    Args:
        array: Image array accepted by ``PIL.Image.fromarray``.
        adjustments: Iterable of positional lists whose first item names the
            operation:

            * ``["blur", sigma, direction]`` - Gaussian blur. With direction
              ``"UP"`` / ``"DOWN"`` the blur is applied in 256 small steps
              and clamped against the original after each one, so values
              only grow / only shrink. The direction may be omitted.
            * ``["invert"]``
            * ``["levels", a, b, c, d]`` - forwarded to ``images.levels``.
            * ``["channels", r, g, b, a]`` - forwarded to
              ``images.channelmap``.
            * ``["crop", top, left, height, width]``
            * ``["rescale", ...]`` - not implemented.

            Entries with any other name are ignored.

    Returns:
        The adjusted image as a NumPy array.

    Raises:
        NotImplementedError: For a ``"rescale"`` adjustment.
    """
    tensor = images.fromPIL(Image.fromarray(array))
    for adjustment in adjustments:
        which = adjustment[0]

        if which == "blur":
            sigma = adjustment[1]
            direction = adjustment[2] if len(adjustment) > 2 else None

            if direction == "DOWN" or direction == "UP":
                orig = tensor
                repeatCount = 256
                # Spread the blur over repeatCount passes; dividing by the
                # square root keeps the combined sigma equal to the request.
                sigma /= sqrt(repeatCount)

                for _ in range(repeatCount):
                    tensor = images.gaussianblur(tensor, sigma)
                    if direction == "DOWN":
                        tensor = torch.minimum(tensor, orig)
                    else:
                        tensor = torch.maximum(tensor, orig)
            else:
                # Adjustments are plain lists here, so use the positional
                # sigma instead of attribute access.
                tensor = images.gaussianblur(tensor, sigma)
        elif which == "invert":
            tensor = images.invert(tensor)
        elif which == "levels":
            tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
        elif which == "channels":
            # NOTE(review): positional layout [name, r, g, b, a] assumed by
            # analogy with "levels" - confirm against callers.
            tensor = images.channelmap(tensor, [adjustment[1], adjustment[2], adjustment[3], adjustment[4]])
        elif which == "rescale":
            raise NotImplementedError("Rescale")
        elif which == "crop":
            # NOTE(review): positional layout [name, top, left, height,
            # width] assumed by analogy with "levels" - confirm.
            tensor = images.crop(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
    return np.array(images.toPIL(tensor)[0])
188
+
189
def g_diffuser(img, mask):
    """Build the grown masks used by the g-diffuser fill mode.

    Unlike the other fill functions this returns three values.

    Returns:
        ``(img, mask, out_mask)``: ``img`` is untouched, ``mask`` is the
        input mask after one adjustment pass and ``out_mask`` is that
        result after a second pass.
    """
    # The handler's branch is named "levels"; the former "level" tag matched
    # nothing, so the levels step was silently skipped.
    adjustments = [["blur", 32, "UP"], ["levels", 0, 0.05, 0, 1]]
    mask = handleImageAdjustment(mask, adjustments)
    # NOTE(review): out_mask is derived from the already-adjusted mask, so
    # it is adjusted twice - confirm this is intended.
    out_mask = handleImageAdjustment(mask, adjustments)
    return img, mask, out_mask
194
def dummy_fill(img, mask):
    """Return ``(img, mask)`` untouched; a fill mode that does nothing."""
    untouched = (img, mask)
    return untouched
196
# Lookup table from fill-mode name to the function that implements it.
# Entries take (img, mask) and return (img, mask), except "g_diffuser",
# which returns (img, mask, out_mask).
functbl = {
    "gaussian": gaussian_noise,
    "perlin": perlin_noise,
    "edge_pad": edge_pad,
    # Falls back to edge_pad when patch_match could not be imported.
    "patchmatch": patch_match_func if patch_match_compiled else edge_pad,
    "cv2_ns": cv2_ns,
    "cv2_telea": cv2_telea,
    "g_diffuser": g_diffuser,
    # Leaves the image and mask untouched.
    "g_diffuser_lib": dummy_fill,
}
206
+
207
# Photometric correction is optional: when postprocess cannot be imported or
# instantiated, a no-op stand-in with the same attributes is used instead.
try:
    from postprocess import PhotometricCorrection
    correction_func = PhotometricCorrection()
except Exception as e:
    print(e, "so PhotometricCorrection is disabled")

    class DummyCorrection:
        """No-op replacement that returns the inpainted image unchanged."""

        def __init__(self):
            # An empty backend name never matches the "taichi" check below.
            self.backend = ""

        def run(self, a, b, *args, **kwargs):
            """Return ``b``; extra arguments such as ``mode`` are accepted
            positionally or by keyword and ignored."""
            return b

    correction_func = DummyCorrection()
219
+
220
# With a taichi backend the correction runs in a child process
# (postprocess.py) and images are exchanged as base64 PNG over its pipes.
if "taichi" in correction_func.backend:
    import sys
    import io
    import base64
    from PIL import Image
    from subprocess import Popen, PIPE, STDOUT

    def base64_to_pil(base64_str):
        """Decode a base64 string into a PIL image."""
        data = base64.b64decode(str(base64_str))
        return Image.open(io.BytesIO(data))

    def pil_to_base64(out_pil):
        """Encode a PIL image as the base64 ASCII string of its PNG bytes."""
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        return base64.b64encode(out_buffer.getvalue()).decode("ascii")

    class SubprocessCorrection:
        """Proxy that forwards correction requests to a postprocess.py child.

        One request is a single line ``<input b64>,<inpainted b64>,<mode>``;
        the reply is a single line holding the corrected image as base64.
        """

        def __init__(self):
            self.backend = correction_func.backend
            self.child = self._spawn()

        @staticmethod
        def _spawn():
            """Start the worker with the interpreter running this process."""
            return Popen(
                [sys.executable or "python", "postprocess.py"],
                stdin=PIPE,
                stdout=PIPE,
                stderr=STDOUT,
            )

        def run(self, img_input, img_inpainted, mode):
            """Return the corrected image, or ``img_inpainted`` on failure.

            Args:
                img_input: PIL image before inpainting.
                img_inpainted: PIL image after inpainting.
                mode: Correction mode; ``"disabled"`` skips the worker.
            """
            if mode == "disabled":
                return img_inpainted
            base64_str_input = pil_to_base64(img_input)
            base64_str_inpainted = pil_to_base64(img_inpainted)
            try:
                # poll() is None while the child runs; any return code,
                # including 0, means it has exited and must be restarted.
                if self.child.poll() is not None:
                    self.child = self._spawn()
                self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
                self.child.stdin.flush()
                out = self.child.stdout.readline()
                base64_str = out.decode().strip()
                # Lines starting with "[" are echoed and skipped (presumably
                # log output from the worker) until the payload arrives.
                while base64_str and base64_str[0] == "[":
                    print(base64_str)
                    out = self.child.stdout.readline()
                    base64_str = out.decode().strip()
                ret = base64_to_pil(base64_str)
            except Exception:
                print("[PIE] not working, photometric correction is disabled")
                ret = img_inpainted
            return ret

    correction_func = SubprocessCorrection()