Noename committed on
Commit f6f903c · 1 Parent(s): c941cb1
Files changed (3)
  1. app.py +69 -109
  2. app_ldm.py → app_diff.py +108 -68
  3. requirements.txt +1 -0
app.py CHANGED
@@ -11,137 +11,98 @@ import torch
11
 
12
  from pytorch_lightning import seed_everything
13
  from annotator.util import resize_image, HWC3
14
- from torch.nn.functional import threshold, normalize, interpolate
15
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
16
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
17
- from einops import rearrange, repeat
18
 
19
  import argparse
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- # parser = argparse.ArgumentParser()
24
- # parser.add_argument('--pretrained_model', type=str, default='runwayml/stable-diffusion-v1-5')
25
- # parser.add_argument('--controlnet', type=str, default='controlnet')
26
- # parser.add_argument('--precision', type=str, default='fp32')
27
- # args = parser.parse_args()
28
- # pretrained_model = args.pretrained_model
29
- pretrained_model = 'runwayml/stable-diffusion-v1-5'
30
- controlnet = 'checkpoint-36000/controlnet'
31
- precision = 'bf16'
32
-
33
- # Check for different hardware architectures
34
- if torch.cuda.is_available():
35
- device = "cuda"
36
- # Check for xformers
37
- try:
38
- import xformers
39
-
40
- enable_xformers = True
41
- except ImportError:
42
- enable_xformers = False
43
- elif torch.backends.mps.is_available():
44
- device = "mps"
45
- else:
46
- device = "cpu"
47
-
48
- print(f"Using device: {device}")
49
-
50
- # Load models
51
- if precision == 'fp32':
52
- torch_dtype = torch.float32
53
- elif precision == 'fp16':
54
- torch_dtype = torch.float16
55
- elif precision == 'bf16':
56
- torch_dtype = torch.bfloat16
57
- else:
58
- raise ValueError(f"Invalid precision: {precision}")
59
-
60
- controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch_dtype)
61
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
62
- pretrained_model, controlnet=controlnet, torch_dtype=torch_dtype
63
- )
64
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
65
- pipe = pipe.to(device)
66
-
67
- # Apply optimizations based on hardware
68
- if device == "cuda":
69
- pipe = pipe.to(device)
70
- if enable_xformers:
71
- pipe.enable_xformers_memory_efficient_attention()
72
- print("xformers optimization enabled")
73
- elif device == "mps":
74
- pipe = pipe.to(device)
75
- pipe.enable_attention_slicing()
76
- print("Attention slicing enabled for Apple Silicon")
77
- else:
78
- # CPU-specific optimizations
79
- pipe = pipe.to(device)
80
- # pipe.enable_sequential_cpu_offload()
81
- # pipe.enable_attention_slicing()
82
 
83
  feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
84
  segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
85
 
86
 
87
  def LGB_TO_RGB(gray_image, rgb_image):
88
- # gray_image [H, W, 3]
89
  # rgb_image [H, W, 3]
90
 
91
- print("gray_image shape: ", gray_image.shape)
92
- print("rgb_image shape: ", rgb_image.shape)
93
-
94
- gray_image = cv2.cvtColor(gray_image, cv2.COLOR_RGB2GRAY)
95
  lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
96
- lab_image[:, :, 0] = gray_image[:, :]
97
 
98
  return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
99
 
100
 
101
- @torch.inference_mode()
102
- def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength,
103
- guidance_scale, seed, eta, threshold, save_memory=False):
104
  with torch.no_grad():
105
  img = resize_image(input_image, image_resolution)
106
  H, W, C = img.shape
107
  print("img shape: ", img.shape)
108
  if C == 3:
109
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
110
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
111
- control = torch.from_numpy(img).to(device).float()
112
- control = control / 255.0
113
- control = rearrange(control, 'h w c -> 1 c h w')
114
- # control = repeat(control, 'b c h w -> b c h w', b=num_samples)
115
- # control = rearrange(control, 'b h w c -> b c h w')
116
 
117
- if a_prompt:
118
- prompt = prompt + ', ' + a_prompt
 
 
119
 
120
  if seed == -1:
121
  seed = random.randint(0, 65535)
122
  seed_everything(seed)
123
 
124
- generator = torch.Generator(device=device).manual_seed(seed)
125
- # Generate images
126
- output = pipe(
127
- num_images_per_prompt=num_samples,
128
- prompt=prompt,
129
- image=control.to(device),
130
- negative_prompt=n_prompt,
131
- num_inference_steps=ddim_steps,
132
- guidance_scale=guidance_scale,
133
- generator=generator,
134
- eta=eta,
135
- strength=strength,
136
- output_type='np',
 
 
137
 
138
- ).images
 
139
 
140
- # output = einops.rearrange(output, 'b c h w -> b h w c')
141
- output = (output * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
142
 
143
- results = [output[i] for i in range(num_samples)]
144
- results = [LGB_TO_RGB(img, result) for result in results]
145
 
146
  # convert each image in results into a mask
147
  masks = []
@@ -152,7 +113,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
152
  logits = logits.squeeze(0)
153
  thresholded = torch.zeros_like(logits)
154
  thresholded[logits > threshold] = 1
155
- mask = thresholded[1:, :, :].sum(dim=0)
156
  mask = mask.unsqueeze(0).unsqueeze(0)
157
  mask = interpolate(mask, size=(H, W), mode='bilinear')
158
  mask = mask.detach().numpy()
@@ -162,12 +123,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
162
 
163
  # for each image in results, use the mask: where the mask is 0, fall back to img, i.e. the grayscale image
164
  # convert img into a 3-channel RGB image
165
- final = [img * (1 - mask[:, :, None]) + result * mask[:, :, None] for result, mask in zip(results, masks)]
 
166
 
167
  # mask to 255 img
168
 
169
  mask_img = [mask * 255 for mask in masks]
170
- return [img] + results + mask_img + final
171
 
172
 
173
  block = gr.Blocks().queue()
@@ -180,15 +142,14 @@ with block:
180
  prompt = gr.Textbox(label="Prompt")
181
  run_button = gr.Button(value="Run")
182
  with gr.Accordion("Advanced options", open=False):
183
- num_samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, visible=False)
184
- # num_samples = 1
185
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
186
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
187
- # guess_mode = gr.Checkbox(label='Guess Mode', value=False)
188
- ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
189
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
190
- threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
191
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
192
  eta = gr.Number(label="eta (DDIM)", value=0.0)
193
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
194
  n_prompt = gr.Textbox(label="Negative Prompt",
@@ -196,9 +157,8 @@ with block:
196
  with gr.Column():
197
  # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
198
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
199
- ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed,
200
- eta, threshold]
201
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=4)
202
 
203
  block.queue(max_size=100)
204
  block.launch(share=True)
 
11
 
12
  from pytorch_lightning import seed_everything
13
  from annotator.util import resize_image, HWC3
14
+ from cldm.model import create_model, load_state_dict
15
+ from cldm.ddim_hacked import DDIMSampler
16
+
17
+ import torch.nn as nn
18
+ from torch.nn.functional import threshold, normalize,interpolate
19
+ from torch.utils.data import Dataset
20
+ from torch.optim import Adam
21
+ from torch.utils.data import Dataset
22
+ from torchvision import transforms
23
+ from torch.utils.data import DataLoader
24
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
 
25
 
26
  import argparse
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
+ parseargs = argparse.ArgumentParser()
31
+ parseargs.add_argument('--model', type=str, default='control_sd15_colorize_epoch=156.ckpt')
32
+ args = parseargs.parse_args()
33
+ model_path = args.model
34
 
35
  feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
36
  segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
37
 
38
+ model = create_model('./models/control_sd15_colorize.yaml').cpu()
39
+ model.load_state_dict(load_state_dict(f"./models/{model_path}", location=device))
40
+ model = model.to(device)
41
+ ddim_sampler = DDIMSampler(model)
42
 
43
  def LGB_TO_RGB(gray_image, rgb_image):
44
+ # gray_image [H, W, 1]
45
  # rgb_image [H, W, 3]
46
 
47
  lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
48
+ lab_image[:, :, 0] = gray_image[:, :, 0]
49
 
50
  return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
51
 
52
 
53
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, threshold, save_memory=False):
54
+ # center crop image to square
55
+ # H, W, _ = input_image.shape
56
+ # if H > W:
57
+ # input_image = input_image[(H - W) // 2:(H + W) // 2, :, :]
58
+ # elif W > H:
59
+ # input_image = input_image[:, (W - H) // 2:(H + W) // 2, :]
60
+
61
  with torch.no_grad():
62
  img = resize_image(input_image, image_resolution)
63
  H, W, C = img.shape
64
  print("img shape: ", img.shape)
65
  if C == 3:
66
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
67
+ detected_map = img[:, :, None]
68
+ print("Gray image shape: ", detected_map.shape)
69
+ control = torch.from_numpy(detected_map.copy()).float().to(device)
70
+ # control = einops.rearrange(control, 'h w c -> 1 c h w')
71
+ print("Control shape: ", control.shape)
 
72
 
73
+ control = control / 255.0
74
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
75
+ print("Stacked control shape: ", control.shape)
76
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
77
 
78
  if seed == -1:
79
  seed = random.randint(0, 65535)
80
  seed_everything(seed)
81
 
82
+ if save_memory:
83
+ model.low_vram_shift(is_diffusing=False)
84
+
85
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
86
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
87
+ shape = (4, H // 8, W // 8)
88
+
89
+ if save_memory:
90
+ model.low_vram_shift(is_diffusing=True)
91
+
92
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
93
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
94
+ shape, cond, verbose=False, eta=eta,
95
+ unconditional_guidance_scale=scale,
96
+ unconditional_conditioning=un_cond)
97
 
98
+ if save_memory:
99
+ model.low_vram_shift(is_diffusing=False)
100
 
101
+ x_samples = model.decode_first_stage(samples)
102
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
103
 
104
+ results = [x_samples[i] for i in range(num_samples)]
105
+ results = [LGB_TO_RGB(detected_map, result) for result in results]
106
 
107
  # convert each image in results into a mask
108
  masks = []
 
113
  logits = logits.squeeze(0)
114
  thresholded = torch.zeros_like(logits)
115
  thresholded[logits > threshold] = 1
116
+ mask = thresholded[1: ,:, :].sum(dim=0)
117
  mask = mask.unsqueeze(0).unsqueeze(0)
118
  mask = interpolate(mask, size=(H, W), mode='bilinear')
119
  mask = mask.detach().numpy()
 
123
 
124
  # for each image in results, use the mask: where the mask is 0, fall back to img, i.e. the grayscale image
125
  # convert img into a 3-channel RGB image
126
+ gray_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # [H, W, 3]
127
+ final = [gray_img * (1 - mask[:, :, None]) + result * mask[:, :, None] for result, mask in zip(results, masks)]
128
 
129
  # mask to 255 img
130
 
131
  mask_img = [mask * 255 for mask in masks]
132
+ return [detected_map.squeeze(-1)] + results + mask_img + final
133
 
134
 
135
  block = gr.Blocks().queue()
 
142
  prompt = gr.Textbox(label="Prompt")
143
  run_button = gr.Button(value="Run")
144
  with gr.Accordion("Advanced options", open=False):
145
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
 
146
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
147
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
148
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
149
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
150
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
151
+ threshold = gr.Slider(label="segmentation threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
152
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
153
  eta = gr.Number(label="eta (DDIM)", value=0.0)
154
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
155
  n_prompt = gr.Textbox(label="Negative Prompt",
 
157
  with gr.Column():
158
  # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
159
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
160
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, threshold]
161
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=2)
 
162
 
163
  block.queue(max_size=100)
164
  block.launch(share=True)
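
Both versions of app.py colorize in LAB space: LGB_TO_RGB takes the a/b (chroma) channels from the generated image and keeps the L (luminance) channel from the grayscale input, so the output preserves the original structure. Below is a minimal, standalone sketch of that recombination; the sample arrays are hypothetical placeholders, not part of the commit.

    import cv2
    import numpy as np

    # Hypothetical inputs: a grayscale source [H, W] and a generated color image [H, W, 3].
    gray = np.random.randint(0, 256, (512, 512), dtype=np.uint8)
    rgb = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)

    # Same L-channel swap as LGB_TO_RGB: keep the original luminance,
    # take the a/b color channels from the generated image.
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = gray
    recolored = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)  # [H, W, 3] uint8
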
app_ldm.py → app_diff.py RENAMED
@@ -11,98 +11,137 @@ import torch
11
 
12
  from pytorch_lightning import seed_everything
13
  from annotator.util import resize_image, HWC3
14
- from cldm.model import create_model, load_state_dict
15
- from cldm.ddim_hacked import DDIMSampler
16
-
17
- import torch.nn as nn
18
- from torch.nn.functional import threshold, normalize,interpolate
19
- from torch.utils.data import Dataset
20
- from torch.optim import Adam
21
- from torch.utils.data import Dataset
22
- from torchvision import transforms
23
- from torch.utils.data import DataLoader
24
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
 
25
 
26
  import argparse
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
- parseargs = argparse.ArgumentParser()
31
- parseargs.add_argument('--model', type=str, default='control_sd15_colorize_epoch=156.ckpt')
32
- args = parseargs.parse_args()
33
- model_path = args.model
34
 
35
  feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
36
  segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
37
 
38
- model = create_model('./models/control_sd15_colorize.yaml').cpu()
39
- model.load_state_dict(load_state_dict(f"./models/{model_path}", location=device))
40
- model = model.to(device)
41
- ddim_sampler = DDIMSampler(model)
42
 
43
  def LGB_TO_RGB(gray_image, rgb_image):
44
- # gray_image [H, W, 1]
45
  # rgb_image [H, W, 3]
46
 
47
  lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
48
- lab_image[:, :, 0] = gray_image[:, :, 0]
49
 
50
  return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
51
 
52
 
53
- def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, threshold, save_memory=False):
54
- # center crop image to square
55
- # H, W, _ = input_image.shape
56
- # if H > W:
57
- # input_image = input_image[(H - W) // 2:(H + W) // 2, :, :]
58
- # elif W > H:
59
- # input_image = input_image[:, (W - H) // 2:(H + W) // 2, :]
60
-
61
  with torch.no_grad():
62
  img = resize_image(input_image, image_resolution)
63
  H, W, C = img.shape
64
  print("img shape: ", img.shape)
65
  if C == 3:
66
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
67
- detected_map = img[:, :, None]
68
- print("Gray image shape: ", detected_map.shape)
69
- control = torch.from_numpy(detected_map.copy()).float().to(device)
70
- # control = einops.rearrange(control, 'h w c -> 1 c h w')
71
- print("Control shape: ", control.shape)
72
-
73
  control = control / 255.0
74
- control = torch.stack([control for _ in range(num_samples)], dim=0)
75
- print("Stacked control shape: ", control.shape)
76
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
 
 
77
 
78
  if seed == -1:
79
  seed = random.randint(0, 65535)
80
  seed_everything(seed)
81
 
82
- if save_memory:
83
- model.low_vram_shift(is_diffusing=False)
84
-
85
- cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
86
- un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
87
- shape = (4, H // 8, W // 8)
88
-
89
- if save_memory:
90
- model.low_vram_shift(is_diffusing=True)
91
-
92
- model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
93
- samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
94
- shape, cond, verbose=False, eta=eta,
95
- unconditional_guidance_scale=scale,
96
- unconditional_conditioning=un_cond)
97
 
98
- if save_memory:
99
- model.low_vram_shift(is_diffusing=False)
100
 
101
- x_samples = model.decode_first_stage(samples)
102
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
103
 
104
- results = [x_samples[i] for i in range(num_samples)]
105
- results = [LGB_TO_RGB(detected_map, result) for result in results]
106
 
107
  # convert each image in results into a mask
108
  masks = []
@@ -113,7 +152,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
113
  logits = logits.squeeze(0)
114
  thresholded = torch.zeros_like(logits)
115
  thresholded[logits > threshold] = 1
116
- mask = thresholded[1: ,:, :].sum(dim=0)
117
  mask = mask.unsqueeze(0).unsqueeze(0)
118
  mask = interpolate(mask, size=(H, W), mode='bilinear')
119
  mask = mask.detach().numpy()
@@ -123,13 +162,12 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
123
 
124
  # for each image in results, use the mask: where the mask is 0, fall back to img, i.e. the grayscale image
125
  # convert img into a 3-channel RGB image
126
- gray_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # [H, W, 3]
127
- final = [gray_img * (1 - mask[:, :, None]) + result * mask[:, :, None] for result, mask in zip(results, masks)]
128
 
129
  # mask to 255 img
130
 
131
  mask_img = [mask * 255 for mask in masks]
132
- return [detected_map.squeeze(-1)] + results + mask_img + final
133
 
134
 
135
  block = gr.Blocks().queue()
@@ -142,14 +180,15 @@ with block:
142
  prompt = gr.Textbox(label="Prompt")
143
  run_button = gr.Button(value="Run")
144
  with gr.Accordion("Advanced options", open=False):
145
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
 
146
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
147
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
148
- guess_mode = gr.Checkbox(label='Guess Mode', value=False)
149
- ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
150
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
151
- threshold = gr.Slider(label="segmentation threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
152
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
153
  eta = gr.Number(label="eta (DDIM)", value=0.0)
154
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
155
  n_prompt = gr.Textbox(label="Negative Prompt",
@@ -157,7 +196,8 @@ with block:
157
  with gr.Column():
158
  # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
159
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
160
- ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, threshold]
 
161
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=4)
162
 
163
  block.queue(max_size=100)
 
11
 
12
  from pytorch_lightning import seed_everything
13
  from annotator.util import resize_image, HWC3
14
+ from torch.nn.functional import threshold, normalize, interpolate
15
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
16
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
17
+ from einops import rearrange, repeat
18
 
19
  import argparse
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+ # parser = argparse.ArgumentParser()
24
+ # parser.add_argument('--pretrained_model', type=str, default='runwayml/stable-diffusion-v1-5')
25
+ # parser.add_argument('--controlnet', type=str, default='controlnet')
26
+ # parser.add_argument('--precision', type=str, default='fp32')
27
+ # args = parser.parse_args()
28
+ # pretrained_model = args.pretrained_model
29
+ pretrained_model = 'runwayml/stable-diffusion-v1-5'
30
+ controlnet = 'checkpoint-36000/controlnet'
31
+ precision = 'bf16'
32
+
33
+ # Check for different hardware architectures
34
+ if torch.cuda.is_available():
35
+ device = "cuda"
36
+ # Check for xformers
37
+ try:
38
+ import xformers
39
+
40
+ enable_xformers = True
41
+ except ImportError:
42
+ enable_xformers = False
43
+ elif torch.backends.mps.is_available():
44
+ device = "mps"
45
+ else:
46
+ device = "cpu"
47
+
48
+ print(f"Using device: {device}")
49
+
50
+ # Load models
51
+ if precision == 'fp32':
52
+ torch_dtype = torch.float32
53
+ elif precision == 'fp16':
54
+ torch_dtype = torch.float16
55
+ elif precision == 'bf16':
56
+ torch_dtype = torch.bfloat16
57
+ else:
58
+ raise ValueError(f"Invalid precision: {precision}")
59
+
60
+ controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch_dtype)
61
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
62
+ pretrained_model, controlnet=controlnet, torch_dtype=torch_dtype
63
+ )
64
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
65
+ pipe = pipe.to(device)
66
+
67
+ # Apply optimizations based on hardware
68
+ if device == "cuda":
69
+ pipe = pipe.to(device)
70
+ if enable_xformers:
71
+ pipe.enable_xformers_memory_efficient_attention()
72
+ print("xformers optimization enabled")
73
+ elif device == "mps":
74
+ pipe = pipe.to(device)
75
+ pipe.enable_attention_slicing()
76
+ print("Attention slicing enabled for Apple Silicon")
77
+ else:
78
+ # CPU-specific optimizations
79
+ pipe = pipe.to(device)
80
+ # pipe.enable_sequential_cpu_offload()
81
+ # pipe.enable_attention_slicing()
82
 
83
  feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
84
  segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
85
 
 
 
 
 
86
 
87
  def LGB_TO_RGB(gray_image, rgb_image):
88
+ # gray_image [H, W, 3]
89
  # rgb_image [H, W, 3]
90
 
91
+ print("gray_image shape: ", gray_image.shape)
92
+ print("rgb_image shape: ", rgb_image.shape)
93
+
94
+ gray_image = cv2.cvtColor(gray_image, cv2.COLOR_RGB2GRAY)
95
  lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
96
+ lab_image[:, :, 0] = gray_image[:, :]
97
 
98
  return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
99
 
100
 
101
+ @torch.inference_mode()
102
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength,
103
+ guidance_scale, seed, eta, threshold, save_memory=False):
104
  with torch.no_grad():
105
  img = resize_image(input_image, image_resolution)
106
  H, W, C = img.shape
107
  print("img shape: ", img.shape)
108
  if C == 3:
109
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
110
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
111
+ control = torch.from_numpy(img).to(device).float()
112
  control = control / 255.0
113
+ control = rearrange(control, 'h w c -> 1 c h w')
114
+ # control = repeat(control, 'b c h w -> b c h w', b=num_samples)
115
+ # control = rearrange(control, 'b h w c -> b c h w')
116
+
117
+ if a_prompt:
118
+ prompt = prompt + ', ' + a_prompt
119
 
120
  if seed == -1:
121
  seed = random.randint(0, 65535)
122
  seed_everything(seed)
123
 
124
+ generator = torch.Generator(device=device).manual_seed(seed)
125
+ # Generate images
126
+ output = pipe(
127
+ num_images_per_prompt=num_samples,
128
+ prompt=prompt,
129
+ image=control.to(device),
130
+ negative_prompt=n_prompt,
131
+ num_inference_steps=ddim_steps,
132
+ guidance_scale=guidance_scale,
133
+ generator=generator,
134
+ eta=eta,
135
+ strength=strength,
136
+ output_type='np',
 
 
137
 
138
+ ).images
 
139
 
140
+ # output = einops.rearrange(output, 'b c h w -> b h w c')
141
+ output = (output * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
142
 
143
+ results = [output[i] for i in range(num_samples)]
144
+ results = [LGB_TO_RGB(img, result) for result in results]
145
 
146
  # convert each image in results into a mask
147
  masks = []
 
152
  logits = logits.squeeze(0)
153
  thresholded = torch.zeros_like(logits)
154
  thresholded[logits > threshold] = 1
155
+ mask = thresholded[1:, :, :].sum(dim=0)
156
  mask = mask.unsqueeze(0).unsqueeze(0)
157
  mask = interpolate(mask, size=(H, W), mode='bilinear')
158
  mask = mask.detach().numpy()
 
162
 
163
  # for each image in results, use the mask: where the mask is 0, fall back to img, i.e. the grayscale image
164
  # convert img into a 3-channel RGB image
165
+ final = [img * (1 - mask[:, :, None]) + result * mask[:, :, None] for result, mask in zip(results, masks)]
 
166
 
167
  # mask to 255 img
168
 
169
  mask_img = [mask * 255 for mask in masks]
170
+ return [img] + results + mask_img + final
171
 
172
 
173
  block = gr.Blocks().queue()
 
180
  prompt = gr.Textbox(label="Prompt")
181
  run_button = gr.Button(value="Run")
182
  with gr.Accordion("Advanced options", open=False):
183
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, visible=False)
184
+ # num_samples = 1
185
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
186
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
187
+ # guess_mode = gr.Checkbox(label='Guess Mode', value=False)
188
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
189
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
190
+ threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
191
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
192
  eta = gr.Number(label="eta (DDIM)", value=0.0)
193
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
194
  n_prompt = gr.Textbox(label="Negative Prompt",
 
196
  with gr.Column():
197
  # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
198
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
199
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed,
200
+ eta, threshold]
201
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=4)
202
 
203
  block.queue(max_size=100)
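
Both apps then restrict the colorization to people: each generated result is run through the matei-dorian/segformer-b5-finetuned-human-parsing Segformer, the per-class logits are thresholded, the non-background classes are merged into one mask, and the mask is resized to the working resolution before blending with the grayscale image. The sketch below isolates that masking step under stated assumptions: the input array is a hypothetical placeholder and the final clamp to [0, 1] is an assumption about how the summed class masks are normalized.

    import numpy as np
    import torch
    from torch.nn.functional import interpolate
    from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

    feature_extractor = SegformerFeatureExtractor.from_pretrained(
        "matei-dorian/segformer-b5-finetuned-human-parsing")
    segmodel = SegformerForSemanticSegmentation.from_pretrained(
        "matei-dorian/segformer-b5-finetuned-human-parsing")

    image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # hypothetical RGB result
    H, W, _ = image.shape
    threshold = 0.5

    with torch.no_grad():
        inputs = feature_extractor(images=image, return_tensors="pt")
        logits = segmodel(**inputs).logits.squeeze(0)  # [num_classes, h, w] at reduced resolution

    # Binarize each class map, drop class 0 (background), merge the rest into one person mask.
    thresholded = torch.zeros_like(logits)
    thresholded[logits > threshold] = 1
    mask = thresholded[1:, :, :].sum(dim=0)

    # Upsample back to the working resolution and keep values in [0, 1] for blending.
    mask = interpolate(mask[None, None], size=(H, W), mode="bilinear")
    mask = mask.squeeze().clamp(0, 1).numpy()  # [H, W]
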
requirements.txt CHANGED
@@ -2,6 +2,7 @@ einops
2
  gradio
3
  numpy
4
  torch
 
5
  pytorch-lightning
6
  diffusers
7
  transformers
 
2
  gradio
3
  numpy
4
  torch
5
+ torchvision
6
  pytorch-lightning
7
  diffusers
8
  transformers