Noename committed on
Commit 113edf5 · 1 Parent(s): dca6b80

ldm version

Files changed (1)
  1. app_ldm.py +164 -0
app_ldm.py ADDED
@@ -0,0 +1,164 @@
+ import os
+ from PIL import Image
+ import json
+ import random
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+ import torch.nn as nn
+ from torch.nn.functional import interpolate
+ from torch.optim import Adam
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
+
+ import argparse
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ parseargs = argparse.ArgumentParser()
+ parseargs.add_argument('--model', type=str, default='control_sd15_colorize_epoch=156.ckpt')
+ args = parseargs.parse_args()
+ model_path = args.model
+
+ # Human-parsing segmentation model, used to build person masks for the colorized outputs.
+ feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
+ segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
+
+ # ControlNet colorization model and DDIM sampler.
+ model = create_model('./models/control_sd15_colorize.yaml').cpu()
+ model.load_state_dict(load_state_dict(f"./models/{model_path}", location=device))
+ model = model.to(device)
+ ddim_sampler = DDIMSampler(model)
+
+ def LGB_TO_RGB(gray_image, rgb_image):
+     # gray_image: [H, W, 1], rgb_image: [H, W, 3]
+     # Replace the L (luminance) channel of the colorized image with the input
+     # grayscale image so the output keeps the original brightness structure.
+     lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
+     lab_image[:, :, 0] = gray_image[:, :, 0]
+     return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, threshold, save_memory=False):
+     # Center crop image to square (currently disabled):
+     # H, W, _ = input_image.shape
+     # if H > W:
+     #     input_image = input_image[(H - W) // 2:(H + W) // 2, :, :]
+     # elif W > H:
+     #     input_image = input_image[:, (W - H) // 2:(H + W) // 2, :]
+
+     with torch.no_grad():
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+         print("img shape: ", img.shape)
+         if C == 3:
+             img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)  # Gradio passes RGB images
+         detected_map = img[:, :, None]
+         print("Gray image shape: ", detected_map.shape)
+         control = torch.from_numpy(detected_map.copy()).float().to(device)
+         # control = einops.rearrange(control, 'h w c -> 1 c h w')
+         print("Control shape: ", control.shape)
+
+         control = control / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         print("Stacked control shape: ", control.shape)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12 < 0.01 but 0.826**12 > 0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+         # Re-impose the input luminance on each sample so only the color comes from the model.
+         results = [LGB_TO_RGB(detected_map, result) for result in results]
+
+         # Convert each image in results into a person mask.
+         masks = []
+         for result in results:
+             inputs = feature_extractor(images=result, return_tensors="pt")
+             outputs = segmodel(**inputs)
+             logits = outputs.logits
+             logits = logits.squeeze(0)
+             thresholded = torch.zeros_like(logits)
+             thresholded[logits > threshold] = 1
+             mask = thresholded[1:, :, :].sum(dim=0)  # every class except background
+             mask = mask.unsqueeze(0).unsqueeze(0)
+             mask = interpolate(mask, size=(H, W), mode='bilinear')
+             mask = mask.detach().numpy()
+             mask = np.squeeze(mask)
+             mask = np.where(mask > threshold, 1, 0)
+             masks.append(mask)
+
+         # For each result, keep the colorized pixels where the mask is 1 and fall back
+         # to the grayscale input where the mask is 0.
+         gray_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # [H, W, 3]
+         final = [(gray_img * (1 - mask[:, :, None]) + result * mask[:, :, None]).astype(np.uint8) for result, mask in zip(results, masks)]
+
+         # Masks as 0/255 images for display.
+         mask_img = [(mask * 255).astype(np.uint8) for mask in masks]
+         return [detected_map.squeeze(-1)] + results + mask_img + final
+
+
+ block = gr.Blocks()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Gray Image")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(sources=['upload'], type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(value="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
+                 threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, threshold]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=4)
+
+ block.queue(max_size=100)
+ block.launch(share=True)
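
For reference, a minimal sketch of exercising process() without the Gradio UI (for example with the block.launch call above commented out). The file name "input.png", the prompt strings, and the output names are hypothetical; the checkpoint and YAML paths from the script are assumed to exist locally.

import cv2

# Hypothetical smoke test: assumes the model setup in app_ldm.py has already run.
image = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
gallery = process(
    image,
    prompt="a person in a red coat",
    a_prompt="best quality, extremely detailed",
    n_prompt="lowres, bad anatomy, worst quality",
    num_samples=1,
    image_resolution=512,
    ddim_steps=20,
    guess_mode=False,
    strength=1.0,
    scale=9.0,
    seed=-1,
    eta=0.0,
    threshold=0.5,
)
# Returned list: [grayscale control map] + colorized samples + person masks (0/255) + masked composites.
for i, im in enumerate(gallery):
    out = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) if im.ndim == 3 else im
    cv2.imwrite(f"output_{i}.png", out)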