Adapter committed on
Commit
2aa90ca
·
1 Parent(s): 08102b2

new functions

Browse files
app.py CHANGED
@@ -8,14 +8,14 @@ os.system('mim install mmcv-full==1.7.0')
8
 
9
  from demo.model import Model_all
10
  import gradio as gr
11
- from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose
12
  import torch
13
  import subprocess
14
  import shlex
15
  from huggingface_hub import hf_hub_url
16
 
17
  urls = {
18
- 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth'],
19
  'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
  'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
21
  }
@@ -44,37 +44,43 @@ for url in urls_mmpose:
44
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
45
  model = Model_all(device)
46
 
47
- DESCRIPTION = '''# T2I-Adapter (Sketch & Keypose & Segmentation & Depth)
48
- [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
49
 
50
- This gradio demo is for a simple experience of T2I-Adapter:
51
- - Keypose/Sketch to Image Generation
52
- - Image to Image Generation
53
- - Support the base model of Stable Diffusion v1.4 and Anything 4.0
 
54
  '''
55
 
56
  with gr.Blocks(css='style.css') as demo:
57
  gr.Markdown(DESCRIPTION)
58
 
59
- gr.HTML("""
60
- <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
61
  <br/>
62
  <a href="https://huggingface.co/spaces/Adapter/T2I-Adapter?duplicate=true">
63
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
64
  <p/>""")
65
 
66
  with gr.Tabs():
 
 
67
  with gr.TabItem('Keypose'):
68
  create_demo_keypose(model.process_keypose)
69
  with gr.TabItem('Sketch'):
70
  create_demo_sketch(model.process_sketch)
71
  with gr.TabItem('Draw'):
72
  create_demo_draw(model.process_draw)
73
- with gr.TabItem('Segmentation'):
74
- create_demo_seg(model.process_seg)
75
  with gr.TabItem('Depth'):
76
  create_demo_depth(model.process_depth)
77
- with gr.TabItem('Multi-adapters (Depth & Keypose)'):
78
  create_demo_depth_keypose(model.process_depth_keypose)
79
-
 
 
 
 
 
 
 
80
  demo.queue().launch(debug=True, server_name='0.0.0.0')
 
8
 
9
  from demo.model import Model_all
10
  import gradio as gr
11
+ from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose, create_demo_color, create_demo_color_sketch, create_demo_openpose, create_demo_style_sketch
12
  import torch
13
  import subprocess
14
  import shlex
15
  from huggingface_hub import hf_hub_url
16
 
17
  urls = {
18
+ 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth', 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth','third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth"],
19
  'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
  'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
21
  }
 
44
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
45
  model = Model_all(device)
46
 
47
+ DESCRIPTION = '''# T2I-Adapter
 
48
 
49
+ Gradio demo for **T2I-Adapter**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter), [[Paper]](https://arxiv.org/abs/2302.08453).
50
+
51
+ It also supports **multiple adapters** in the following tabs showing **"A adapter + B adapter"**.
52
+
53
+ If T2I-Adapter is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends 😊
54
  '''
55
 
56
  with gr.Blocks(css='style.css') as demo:
57
  gr.Markdown(DESCRIPTION)
58
 
59
+ gr.HTML("""<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
 
60
  <br/>
61
  <a href="https://huggingface.co/spaces/Adapter/T2I-Adapter?duplicate=true">
62
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
63
  <p/>""")
64
 
65
  with gr.Tabs():
66
+ with gr.TabItem('Openpose'):
67
+ create_demo_openpose(model.process_openpose)
68
  with gr.TabItem('Keypose'):
69
  create_demo_keypose(model.process_keypose)
70
  with gr.TabItem('Sketch'):
71
  create_demo_sketch(model.process_sketch)
72
  with gr.TabItem('Draw'):
73
  create_demo_draw(model.process_draw)
 
 
74
  with gr.TabItem('Depth'):
75
  create_demo_depth(model.process_depth)
76
+ with gr.TabItem('Depth + Keypose'):
77
  create_demo_depth_keypose(model.process_depth_keypose)
78
+ with gr.TabItem('Color'):
79
+ create_demo_color(model.process_color)
80
+ with gr.TabItem('Color + Sketch'):
81
+ create_demo_color_sketch(model.process_color_sketch)
82
+ with gr.TabItem('Style + Sketch'):
83
+ create_demo_style_sketch(model.process_style_sketch)
84
+ with gr.TabItem('Segmentation'):
85
+ create_demo_seg(model.process_seg)
86
  demo.queue().launch(debug=True, server_name='0.0.0.0')
demo/demos.py CHANGED
@@ -18,11 +18,6 @@ def create_demo_keypose(process):
18
  with gr.Blocks() as demo:
19
  with gr.Row():
20
  gr.Markdown('## T2I-Adapter (Keypose)')
21
- # with gr.Row():
22
- # with gr.Column():
23
- # gr.Textbox(value="Hello Memory")
24
- # with gr.Column():
25
- # gr.JSON(get_system_memory, every=1)
26
  with gr.Row():
27
  with gr.Column():
28
  input_img = gr.Image(source='upload', type="numpy")
@@ -44,6 +39,31 @@ def create_demo_keypose(process):
44
  run_button.click(fn=process, inputs=ips, outputs=[result])
45
  return demo
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def create_demo_sketch(process):
48
  with gr.Blocks() as demo:
49
  with gr.Row():
@@ -70,6 +90,91 @@ def create_demo_sketch(process):
70
  run_button.click(fn=process, inputs=ips, outputs=[result])
71
  return demo
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def create_demo_seg(process):
74
  with gr.Blocks() as demo:
75
  with gr.Row():
 
18
  with gr.Blocks() as demo:
19
  with gr.Row():
20
  gr.Markdown('## T2I-Adapter (Keypose)')
 
 
 
 
 
21
  with gr.Row():
22
  with gr.Column():
23
  input_img = gr.Image(source='upload', type="numpy")
 
39
  run_button.click(fn=process, inputs=ips, outputs=[result])
40
  return demo
41
 
42
def create_demo_openpose(process):
    """Assemble the Gradio widgets for the single-adapter Openpose tab.

    Args:
        process: Callback bound to the Run button. It receives the widget
            values in this order: input image, input type, prompt, negative
            prompt, positive prompt, fix-sampling flag, guidance scale,
            controlling strength, base model name. Its return value is sent
            to the output gallery.

    Returns:
        The ``gr.Blocks`` container that holds the tab.
    """
    # NOTE(review): the row/column nesting below was reconstructed from a
    # diff view that had lost its indentation -- confirm against the repo.
    negative_default = 'ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face'
    positive_default = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed'
    checkpoint_choices = ['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt']

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## T2I-Adapter (Openpose)')
        with gr.Row():
            # Left column: every user input. Right column: the result gallery.
            with gr.Column():
                image_box = gr.Image(source='upload', type="numpy")
                prompt_box = gr.Textbox(label="Prompt")
                negative_box = gr.Textbox(label="Negative Prompt", value=negative_default)
                positive_box = gr.Textbox(label="Positive Prompt", value=positive_default)
                with gr.Row():
                    input_kind = gr.inputs.Radio(
                        ['Openpose', 'Image'],
                        type="value",
                        default='Image',
                        label='Input Types\n (You can input an image or a openpose map)',
                    )
                    fixed_seed = gr.inputs.Radio(
                        ['True', 'False'],
                        type="value",
                        default='False',
                        label='Fix Sampling\n (Fix the random seed to produce a fixed output)',
                    )
                run_button = gr.Button(label="Run")
                strength_slider = gr.Slider(
                    label="Controling Strength (The guidance strength of the openpose to the result)",
                    minimum=0,
                    maximum=1,
                    value=1,
                    step=0.1,
                )
                guidance_slider = gr.Slider(
                    label="Guidance Scale (Classifier free guidance)",
                    minimum=0.1,
                    maximum=30.0,
                    value=7.5,
                    step=0.1,
                )
                checkpoint_radio = gr.inputs.Radio(
                    checkpoint_choices,
                    type="value",
                    default='sd-v1-4.ckpt',
                    label='The base model you want to use',
                )
            with gr.Column():
                gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

        # The order here must match the positional parameters of `process`.
        callback_inputs = [
            image_box,
            input_kind,
            prompt_box,
            negative_box,
            positive_box,
            fixed_seed,
            guidance_slider,
            strength_slider,
            checkpoint_radio,
        ]
        run_button.click(fn=process, inputs=callback_inputs, outputs=[gallery])
    return demo
66
+
67
  def create_demo_sketch(process):
68
  with gr.Blocks() as demo:
69
  with gr.Row():
 
90
  run_button.click(fn=process, inputs=ips, outputs=[result])
91
  return demo
92
 
93
def create_demo_color_sketch(process):
    """Build the Gradio UI for the combined Color + Sketch adapter tab.

    Args:
        process: Callback bound to the Run button. It receives the widget
            values in this order: sketch image, color image, sketch input
            type, color input type, sketch weight, color weight, sketch
            background color, prompt, negative prompt, positive prompt,
            fix-sampling flag, guidance scale, controlling strength, base
            model name. Its return value is sent to the output gallery.

    Returns:
        The ``gr.Blocks`` container that holds the tab.
    """
    # NOTE(review): the row/column nesting below was reconstructed from a
    # diff view that had lost its indentation -- confirm against the repo.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## T2I-Adapter (Color + Sketch)')
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
                    input_img_color = gr.Image(source='upload', type="numpy", label='Color guidance')
                prompt = gr.Textbox(label="Prompt")
                neg_prompt = gr.Textbox(label="Negative Prompt",
                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
                pos_prompt = gr.Textbox(label="Positive Prompt",
                                        value='crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
                type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
                with gr.Row():
                    type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
                    color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
                with gr.Row():
                    # This slider feeds `w_sketch`; the previous label said
                    # "Depth guidance weight", copied from the depth tab.
                    w_sketch = gr.Slider(label="Sketch guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
                    w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1.2, step=0.1)
                run_button = gr.Button(label="Run")
                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
            with gr.Column():
                # Three gallery columns: the callback returns the edge map,
                # the color map and the generated sample.
                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
        # The order here must match the positional parameters of `process`.
        ips = [input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
        run_button.click(fn=process, inputs=ips, outputs=[result])
    return demo
124
+
125
def create_demo_style_sketch(process):
    """Assemble the Gradio widgets for the combined Style + Sketch tab.

    Args:
        process: Callback bound to the Run button. It receives the widget
            values in this order: sketch image, style image, sketch input
            type, sketch background color, prompt, negative prompt, positive
            prompt, fix-sampling flag, guidance scale, controlling strength,
            base model name. Its return value is sent to the output gallery.

    Returns:
        The ``gr.Blocks`` container that holds the tab.
    """
    # NOTE(review): the row/column nesting below was reconstructed from a
    # diff view that had lost its indentation -- confirm against the repo.
    negative_default = 'ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face'
    positive_default = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed'

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## T2I-Adapter (Style + Sketch)')
        with gr.Row():
            with gr.Column():
                # The two guidance images sit side by side.
                with gr.Row():
                    sketch_box = gr.Image(source='upload', type="numpy", label='Sketch guidance')
                    style_box = gr.Image(source='upload', type="numpy", label='Style guidance')
                prompt_box = gr.Textbox(label="Prompt")
                negative_box = gr.Textbox(label="Negative Prompt", value=negative_default)
                positive_box = gr.Textbox(label="Positive Prompt", value=positive_default)
                with gr.Row():
                    sketch_kind = gr.inputs.Radio(
                        ['Sketch', 'Image'],
                        type="value",
                        default='Image',
                        label='Input Types of Sketch\n (You can input an image or a sketch)',
                    )
                    background = gr.inputs.Radio(
                        ['White', 'Black'],
                        type="value",
                        default='Black',
                        label='Color of the sketch background\n (Only work for sketch input)',
                    )
                run_button = gr.Button(label="Run")
                strength_slider = gr.Slider(
                    label="Controling Strength (The guidance strength of the sketch to the result)",
                    minimum=0,
                    maximum=1,
                    value=1,
                    step=0.1,
                )
                guidance_slider = gr.Slider(
                    label="Guidance Scale (Classifier free guidance)",
                    minimum=0.1,
                    maximum=30.0,
                    value=7.5,
                    step=0.1,
                )
                fixed_seed = gr.inputs.Radio(
                    ['True', 'False'],
                    type="value",
                    default='False',
                    label='Fix Sampling\n (Fix the random seed)',
                )
                checkpoint_radio = gr.inputs.Radio(
                    ['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'],
                    type="value",
                    default='sd-v1-4.ckpt',
                    label='The base model you want to use',
                )
            with gr.Column():
                gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

        # The order here must match the positional parameters of `process`.
        callback_inputs = [
            sketch_box,
            style_box,
            sketch_kind,
            background,
            prompt_box,
            negative_box,
            positive_box,
            fixed_seed,
            guidance_slider,
            strength_slider,
            checkpoint_radio,
        ]
        run_button.click(fn=process, inputs=callback_inputs, outputs=[gallery])
    return demo
152
+
153
def create_demo_color(process):
    """Build the Gradio UI for the single-adapter Color tab.

    Args:
        process: Callback bound to the Run button. It receives the widget
            values in this order: color image, prompt, negative prompt,
            positive prompt, color weight, color input type, fix-sampling
            flag, guidance scale, controlling strength, base model name.
            Its return value is sent to the output gallery.

    Returns:
        The ``gr.Blocks`` container that holds the tab.
    """
    # NOTE(review): the row/column nesting below was reconstructed from a
    # diff view that had lost its indentation -- confirm against the repo.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## T2I-Adapter (Color)')
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(source='upload', type="numpy", label='Color guidance')
                prompt = gr.Textbox(label="Prompt")
                neg_prompt = gr.Textbox(label="Negative Prompt",
                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
                pos_prompt = gr.Textbox(label="Positive Prompt",
                                        value='crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
                type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
                w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1, step=0.1)
                run_button = gr.Button(label="Run")
                # This tab only has a color adapter; the previous label said
                # "of the sketch", copied from the sketch tab.
                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the color to the result)", minimum=0, maximum=1, value=1, step=0.1)
                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
            with gr.Column():
                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
        # The order here must match the positional parameters of `process`.
        ips = [input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model]
        run_button.click(fn=process, inputs=ips, outputs=[result])
    return demo
177
+
178
  def create_demo_seg(process):
179
  with gr.Blocks() as demo:
180
  with gr.Row():
demo/model.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  from basicsr.utils import img2tensor, tensor2img
3
  from pytorch_lightning import seed_everything
4
  from ldm.models.diffusion.plms import PLMSSampler
5
- from ldm.modules.encoders.adapter import Adapter
6
  from ldm.util import instantiate_from_config
7
  from ldm.modules.structure_condition.model_edge import pidinet
8
  from ldm.modules.structure_condition.model_seg import seger, Colorize
@@ -16,6 +16,8 @@ import os
16
  import cv2
17
  import numpy as np
18
  import torch.nn.functional as F
 
 
19
 
20
 
21
  def preprocessing(image, device):
@@ -151,9 +153,9 @@ class Model_all:
151
  self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
152
  use_conv=False).to(device)
153
  self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
154
- self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
155
 
156
  # depth part
 
157
  self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
158
  use_conv=False).to(device)
159
  self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
@@ -162,6 +164,23 @@ class Model_all:
162
  self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
163
  use_conv=False).to(device)
164
  self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  ## mmpose
166
  det_config = 'models/faster_rcnn_r50_fpn_coco.py'
167
  det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
@@ -257,7 +276,202 @@ class Model_all:
257
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
258
 
259
  return [im_edge, x_samples_ddim]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  @torch.no_grad()
262
  def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
263
  con_strength, base_model):
@@ -638,6 +852,67 @@ class Model_all:
638
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
639
 
640
  return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
 
643
  if __name__ == '__main__':
 
2
  from basicsr.utils import img2tensor, tensor2img
3
  from pytorch_lightning import seed_everything
4
  from ldm.models.diffusion.plms import PLMSSampler
5
+ from ldm.modules.encoders.adapter import Adapter, Adapter_light, StyleAdapter
6
  from ldm.util import instantiate_from_config
7
  from ldm.modules.structure_condition.model_edge import pidinet
8
  from ldm.modules.structure_condition.model_seg import seger, Colorize
 
16
  import cv2
17
  import numpy as np
18
  import torch.nn.functional as F
19
+ from transformers import CLIPProcessor, CLIPVisionModel
20
+ from PIL import Image
21
 
22
 
23
  def preprocessing(image, device):
 
153
  self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
154
  use_conv=False).to(device)
155
  self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
 
156
 
157
  # depth part
158
+ self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
159
  self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
160
  use_conv=False).to(device)
161
  self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
 
164
  self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
165
  use_conv=False).to(device)
166
  self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
167
+
168
+ # openpose part
169
+ self.model_openpose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
170
+ use_conv=False).to(device)
171
+ self.model_openpose.load_state_dict(torch.load("models/t2iadapter_openpose_sd14v1.pth", map_location=device))
172
+
173
+ # color part
174
+ self.model_color = Adapter_light(cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=4).to(device)
175
+ self.model_color.load_state_dict(torch.load("models/t2iadapter_color_sd14v1_small.pth", map_location=device))
176
+
177
+ # style part
178
+ self.model_style = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(device)
179
+ self.model_style.load_state_dict(torch.load("models/t2iadapter_style_sd14v1.pth", map_location=device))
180
+ self.clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
181
+ self.clip_vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').to(device)
182
+
183
+ device = 'cpu'
184
  ## mmpose
185
  det_config = 'models/faster_rcnn_r50_fpn_coco.py'
186
  det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
 
276
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
277
 
278
  return [im_edge, x_samples_ddim]
279
+
280
@torch.no_grad()
def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
    """Generate an image guided by a sketch adapter and a color adapter.

    Args:
        input_img_sketch: Sketch guidance image as a numpy array.
        input_img_color: Color guidance image as a numpy array.
        type_in: ``'Sketch'`` if ``input_img_sketch`` is already a sketch,
            ``'Image'`` to extract edges from it with ``self.model_edge``.
        type_in_color: ``'Image'`` to pixelate ``input_img_color`` into a
            coarse color map; any other value only resizes it.
        w_sketch: Multiplier applied to the sketch adapter features.
        w_color: Multiplier applied to the color adapter features.
        color_back: ``'White'`` inverts a sketch input before use.
        prompt: Text prompt; ``pos_prompt`` is appended after ``', '``.
        neg_prompt: Text used for the unconditional conditioning.
        pos_prompt: Extra positive text appended to ``prompt``.
        fix_sample: The string ``'True'`` seeds all RNGs with 42.
        scale: Classifier-free guidance scale passed to the sampler.
        con_strength: Value in [0, 1]; turned into ``int((1 - x) * 50)``
            before being handed to the sampler.
        base_model: Checkpoint file name under ``models/``.

    Returns:
        ``[im_edge, im_color, x_samples_ddim]``: the edge map, the color
        map fed to the adapter, and the generated sample as a uint8 array.

    Raises:
        ValueError: If an input image is missing or ``type_in`` is not
            ``'Sketch'`` or ``'Image'``.
    """
    # Fail early with a clear message instead of an opaque cv2 error or an
    # unbound-variable crash further down.
    if input_img_sketch is None or input_img_color is None:
        raise ValueError('Both a sketch guidance image and a color guidance image are required.')
    if type_in not in ('Sketch', 'Image'):
        raise ValueError(f"Unsupported sketch input type: {type_in!r} (expected 'Sketch' or 'Image').")

    if self.current_base != base_model:
        ckpt = os.path.join("models", base_model)
        # Load onto the device the model actually runs on; a hard-coded
        # "cuda" fails on CPU-only hosts.
        pl_sd = torch.load(ckpt, map_location=self.device)
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
        self.base_model.load_state_dict(sd, strict=False)
        self.current_base = base_model
        if 'anything' in base_model.lower():
            self.load_vae()

    con_strength = int((1 - con_strength) * 50)
    if fix_sample == 'True':
        seed_everything(42)
    im = cv2.resize(input_img_sketch, (512, 512))

    if type_in == 'Sketch':
        if color_back == 'White':
            im = 255 - im
        im_edge = im.copy()
        # Keep only the first channel and binarise it at 0.5.
        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
        im = (im > 0.5).float()
    else:  # type_in == 'Image'
        im = img2tensor(im).unsqueeze(0) / 255.
        im = self.model_edge(im.to(self.device))[-1]
        im = (im > 0.5).float()
        im_edge = tensor2img(im)

    if type_in_color == 'Image':
        # Shrink to 8x8 and blow back up with nearest-neighbour so the
        # adapter sees a coarse block color map.
        input_img_color = cv2.resize(input_img_color, (512 // 64, 512 // 64), interpolation=cv2.INTER_CUBIC)
        input_img_color = cv2.resize(input_img_color, (512, 512), interpolation=cv2.INTER_NEAREST)
    else:
        input_img_color = cv2.resize(input_img_color, (512, 512))
    im_color = input_img_color.copy()
    im_color_tensor = img2tensor(input_img_color, bgr2rgb=False).unsqueeze(0) / 255.

    # extract condition features
    c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
    nc = self.base_model.get_learned_conditioning([neg_prompt])
    features_adapter_sketch = self.model_sketch(im.to(self.device))
    features_adapter_color = self.model_color(im_color_tensor.to(self.device))
    # Weighted sum of the two adapters' feature maps, level by level.
    features_adapter = [fs * w_sketch + fc * w_color
                        for fs, fc in zip(features_adapter_sketch, features_adapter_color)]
    shape = [4, 64, 64]

    # sampling
    samples_ddim, _ = self.sampler.sample(S=50,
                                          conditioning=c,
                                          batch_size=1,
                                          shape=shape,
                                          verbose=False,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=nc,
                                          eta=0.0,
                                          x_T=None,
                                          features_adapter1=features_adapter,
                                          mode='sketch',
                                          con_strength=con_strength)

    # Decode latents, map [-1, 1] to [0, 255] and convert to HWC uint8.
    x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples_ddim = x_samples_ddim.to('cpu')
    x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
    x_samples_ddim = 255. * x_samples_ddim
    x_samples_ddim = x_samples_ddim.astype(np.uint8)

    return [im_edge, im_color, x_samples_ddim]
350
+
351
@torch.no_grad()
def process_style_sketch(self, input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
    """Generate an image guided by a sketch adapter and a style adapter.

    Args:
        input_img_sketch: Sketch guidance image as a numpy array.
        input_img_style: Style reference image as a numpy array; it is
            encoded with the CLIP vision model and ``self.model_style``.
        type_in: ``'Sketch'`` if ``input_img_sketch`` is already a sketch,
            ``'Image'`` to extract edges from it with ``self.model_edge``.
        color_back: ``'White'`` inverts a sketch input before use.
        prompt: Text prompt; ``pos_prompt`` is appended after ``', '``.
        neg_prompt: Text used for the unconditional conditioning.
        pos_prompt: Extra positive text appended to ``prompt``.
        fix_sample: The string ``'True'`` seeds all RNGs with 42.
        scale: Classifier-free guidance scale passed to the sampler.
        con_strength: Value in [0, 1]; turned into ``int((1 - x) * 50)``
            before being handed to the sampler.
        base_model: Checkpoint file name under ``models/``.

    Returns:
        ``[im_edge, x_samples_ddim]``: the edge map and the generated
        sample as a uint8 array.

    Raises:
        ValueError: If an input image is missing or ``type_in`` is not
            ``'Sketch'`` or ``'Image'``.
    """
    # Fail early with a clear message instead of an opaque cv2/PIL error or
    # an unbound-variable crash further down.
    if input_img_sketch is None or input_img_style is None:
        raise ValueError('Both a sketch guidance image and a style guidance image are required.')
    if type_in not in ('Sketch', 'Image'):
        raise ValueError(f"Unsupported sketch input type: {type_in!r} (expected 'Sketch' or 'Image').")

    if self.current_base != base_model:
        ckpt = os.path.join("models", base_model)
        # Load onto the device the model actually runs on; a hard-coded
        # "cuda" fails on CPU-only hosts.
        pl_sd = torch.load(ckpt, map_location=self.device)
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
        self.base_model.load_state_dict(sd, strict=False)
        self.current_base = base_model
        if 'anything' in base_model.lower():
            self.load_vae()

    con_strength = int((1 - con_strength) * 50)
    if fix_sample == 'True':
        seed_everything(42)
    im = cv2.resize(input_img_sketch, (512, 512))

    if type_in == 'Sketch':
        if color_back == 'White':
            im = 255 - im
        im_edge = im.copy()
        # Keep only the first channel and binarise it at 0.5.
        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
        im = (im > 0.5).float()
    else:  # type_in == 'Image'
        im = img2tensor(im).unsqueeze(0) / 255.
        im = self.model_edge(im.to(self.device))[-1]
        im = (im > 0.5).float()
        im_edge = tensor2img(im)

    # Encode the style reference: CLIP vision tokens -> style adapter.
    style = Image.fromarray(input_img_style)
    style_for_clip = self.clip_processor(images=style, return_tensors="pt")['pixel_values']
    style_feat = self.clip_vision_model(style_for_clip.to(self.device))['last_hidden_state']
    style_feat = self.model_style(style_feat)

    # extract condition features
    c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
    nc = self.base_model.get_learned_conditioning([neg_prompt])
    features_adapter = self.model_sketch(im.to(self.device))
    shape = [4, 64, 64]

    # sampling
    samples_ddim, _ = self.sampler.sample(S=50,
                                          conditioning=c,
                                          batch_size=1,
                                          shape=shape,
                                          verbose=False,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=nc,
                                          eta=0.0,
                                          x_T=None,
                                          features_adapter1=features_adapter,
                                          mode='style',
                                          con_strength=con_strength,
                                          style_feature=style_feat)

    # Decode latents, map [-1, 1] to [0, 255] and convert to HWC uint8.
    x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples_ddim = x_samples_ddim.to('cpu')
    x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
    x_samples_ddim = 255. * x_samples_ddim
    x_samples_ddim = x_samples_ddim.astype(np.uint8)

    return [im_edge, x_samples_ddim]
418
+
419
@torch.no_grad()
def process_color(self, input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model):
    """Generate an image guided by the color adapter alone.

    Args:
        input_img: Color guidance image as a numpy array.
        prompt: Text prompt; ``pos_prompt`` is appended after ``', '``.
        neg_prompt: Text used for the unconditional conditioning.
        pos_prompt: Extra positive text appended to ``prompt``.
        w_color: Multiplier applied to the color adapter features.
        type_in_color: ``'Image'`` to pixelate ``input_img`` into a coarse
            color map; any other value only resizes it.
        fix_sample: The string ``'True'`` seeds all RNGs with 42.
        scale: Classifier-free guidance scale passed to the sampler.
        con_strength: Value in [0, 1]; turned into ``int((1 - x) * 50)``
            before being handed to the sampler.
        base_model: Checkpoint file name under ``models/``.

    Returns:
        ``[im_color, x_samples_ddim]``: the color map fed to the adapter
        and the generated sample as a uint8 array.

    Raises:
        ValueError: If ``input_img`` is missing.
    """
    # Fail early with a clear message instead of an opaque cv2 error.
    if input_img is None:
        raise ValueError('A color guidance image is required.')

    if self.current_base != base_model:
        ckpt = os.path.join("models", base_model)
        # Load onto the device the model actually runs on; a hard-coded
        # "cuda" fails on CPU-only hosts.
        pl_sd = torch.load(ckpt, map_location=self.device)
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
        self.base_model.load_state_dict(sd, strict=False)
        self.current_base = base_model
        if 'anything' in base_model.lower():
            self.load_vae()

    con_strength = int((1 - con_strength) * 50)
    if fix_sample == 'True':
        seed_everything(42)
    if type_in_color == 'Image':
        # Shrink to 8x8 and blow back up with nearest-neighbour so the
        # adapter sees a coarse block color map.
        input_img = cv2.resize(input_img, (512 // 64, 512 // 64), interpolation=cv2.INTER_CUBIC)
        input_img = cv2.resize(input_img, (512, 512), interpolation=cv2.INTER_NEAREST)
    else:
        input_img = cv2.resize(input_img, (512, 512))

    im_color = input_img.copy()
    im = img2tensor(input_img, bgr2rgb=False).unsqueeze(0) / 255.

    # extract condition features
    c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
    nc = self.base_model.get_learned_conditioning([neg_prompt])
    features_adapter = self.model_color(im.to(self.device))
    features_adapter = [fi * w_color for fi in features_adapter]
    shape = [4, 64, 64]

    # sampling
    samples_ddim, _ = self.sampler.sample(S=50,
                                          conditioning=c,
                                          batch_size=1,
                                          shape=shape,
                                          verbose=False,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=nc,
                                          eta=0.0,
                                          x_T=None,
                                          features_adapter1=features_adapter,
                                          mode='sketch',
                                          con_strength=con_strength)

    # Decode latents, map [-1, 1] to [0, 255] and convert to HWC uint8.
    x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples_ddim = x_samples_ddim.to('cpu')
    x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
    x_samples_ddim = 255. * x_samples_ddim
    x_samples_ddim = x_samples_ddim.astype(np.uint8)

    return [im_color, x_samples_ddim]
474
+
475
  @torch.no_grad()
476
  def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
477
  con_strength, base_model):
 
852
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
853
 
854
  return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
855
+
856
@torch.no_grad()
def process_openpose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
                     base_model):
    """Generate an image guided by the openpose adapter.

    Args:
        input_img: Input image as a numpy array: either a photo or an
            already-rendered openpose map, depending on ``type_in``.
        type_in: ``'Openpose'`` to use ``input_img`` as the pose map,
            ``'Image'`` to run ``OpenposeInference`` on it first.
        prompt: Text prompt; ``pos_prompt`` is appended after ``', '``.
        neg_prompt: Text used for the unconditional conditioning.
        pos_prompt: Extra positive text appended to ``prompt``.
        fix_sample: The string ``'True'`` seeds all RNGs with 42.
        scale: Classifier-free guidance scale passed to the sampler.
        con_strength: Value in [0, 1]; turned into ``int((1 - x) * 50)``
            before being handed to the sampler.
        base_model: Checkpoint file name under ``models/``.

    Returns:
        ``[pose_map, x_samples_ddim]``: the pose map (channel order flipped
        back for display) and the generated sample, both uint8 arrays.

    Raises:
        ValueError: If ``input_img`` is missing or ``type_in`` is not
            ``'Openpose'`` or ``'Image'``.
    """
    # Fail early with a clear message instead of an opaque cv2 error or an
    # unbound `im_pose` further down.
    if input_img is None:
        raise ValueError('An input image or openpose map is required.')
    if type_in not in ('Openpose', 'Image'):
        raise ValueError(f"Unsupported input type: {type_in!r} (expected 'Openpose' or 'Image').")

    if self.current_base != base_model:
        ckpt = os.path.join("models", base_model)
        # Load onto the device the model actually runs on; a hard-coded
        # "cuda" fails on CPU-only hosts.
        pl_sd = torch.load(ckpt, map_location=self.device)
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
        self.base_model.load_state_dict(sd, strict=False)
        self.current_base = base_model
        if 'anything' in base_model.lower():
            self.load_vae()

    con_strength = int((1 - con_strength) * 50)
    if fix_sample == 'True':
        seed_everything(42)
    im = cv2.resize(input_img, (512, 512))

    if type_in == 'Openpose':
        im_pose = im.copy()[:, :, ::-1]
    else:  # type_in == 'Image'
        # Imported lazily so the openpose dependency is only needed when a
        # pose actually has to be detected.
        from ldm.modules.structure_condition.openpose.api import OpenposeInference
        # NOTE(review): the detector is rebuilt on every request; consider
        # caching it on the instance if construction proves expensive.
        model = OpenposeInference()
        keypose = model(im)
        im_pose = keypose.copy()[:, :, ::-1]

    # extract condition features
    c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
    nc = self.base_model.get_learned_conditioning([neg_prompt])
    pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
    pose = pose.unsqueeze(0)
    features_adapter = self.model_openpose(pose.to(self.device))

    shape = [4, 64, 64]

    # sampling
    samples_ddim, _ = self.sampler.sample(S=50,
                                          conditioning=c,
                                          batch_size=1,
                                          shape=shape,
                                          verbose=False,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=nc,
                                          eta=0.0,
                                          x_T=None,
                                          features_adapter1=features_adapter,
                                          mode='sketch',
                                          con_strength=con_strength)

    # Decode latents, map [-1, 1] to [0, 255] and convert to HWC uint8.
    x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples_ddim = x_samples_ddim.to('cpu')
    x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
    x_samples_ddim = 255. * x_samples_ddim
    x_samples_ddim = x_samples_ddim.astype(np.uint8)

    return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
916
 
917
 
918
  if __name__ == '__main__':
ldm/models/diffusion/plms.py CHANGED
@@ -79,6 +79,7 @@ class PLMSSampler(object):
79
  features_adapter2=None,
80
  mode = 'sketch',
81
  con_strength=30,
 
82
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
83
  **kwargs
84
  ):
@@ -115,7 +116,8 @@ class PLMSSampler(object):
115
  features_adapter1=copy.deepcopy(features_adapter1),
116
  features_adapter2=copy.deepcopy(features_adapter2),
117
  mode = mode,
118
- con_strength = con_strength
 
119
  )
120
  return samples, intermediates
121
 
@@ -125,7 +127,7 @@ class PLMSSampler(object):
125
  callback=None, timesteps=None, quantize_denoised=False,
126
  mask=None, x0=None, img_callback=None, log_every_t=100,
127
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
- unconditional_guidance_scale=1., unconditional_conditioning=None,features_adapter1=None, features_adapter2=None, mode='sketch', con_strength=30):
129
  device = self.model.betas.device
130
  b = shape[0]
131
  if x_T is None:
@@ -161,6 +163,16 @@ class PLMSSampler(object):
161
  features_adapter = None
162
  else:
163
  features_adapter = features_adapter1
 
 
 
 
 
 
 
 
 
 
164
  elif mode == 'mul':
165
  features_adapter = [a1i*0.5 + a2i for a1i, a2i in zip(features_adapter1, features_adapter2)]
166
  else:
 
79
  features_adapter2=None,
80
  mode = 'sketch',
81
  con_strength=30,
82
+ style_feature=None,
83
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
84
  **kwargs
85
  ):
 
116
  features_adapter1=copy.deepcopy(features_adapter1),
117
  features_adapter2=copy.deepcopy(features_adapter2),
118
  mode = mode,
119
+ con_strength = con_strength,
120
+ style_feature=style_feature
121
  )
122
  return samples, intermediates
123
 
 
127
  callback=None, timesteps=None, quantize_denoised=False,
128
  mask=None, x0=None, img_callback=None, log_every_t=100,
129
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
130
+ unconditional_guidance_scale=1., unconditional_conditioning=None,features_adapter1=None, features_adapter2=None, mode='sketch', con_strength=30, style_feature=None):
131
  device = self.model.betas.device
132
  b = shape[0]
133
  if x_T is None:
 
163
  features_adapter = None
164
  else:
165
  features_adapter = features_adapter1
166
+ elif mode == 'style':
167
+ if index<con_strength:
168
+ features_adapter = None
169
+ else:
170
+ features_adapter = features_adapter1
171
+
172
+ if index>25:
173
+ cond = torch.cat([cond, style_feature], dim=1)
174
+ unconditional_conditioning = torch.cat(
175
+ [unconditional_conditioning, unconditional_conditioning[:, -8:, :]], dim=1)
176
  elif mode == 'mul':
177
  features_adapter = [a1i*0.5 + a2i for a1i, a2i in zip(features_adapter1, features_adapter2)]
178
  else:
ldm/modules/attention.py CHANGED
@@ -4,10 +4,22 @@ import torch
4
  import torch.nn.functional as F
5
  from torch import nn, einsum
6
  from einops import rearrange, repeat
 
7
 
8
  from ldm.modules.diffusionmodules.util import checkpoint
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  def exists(val):
12
  return val is not None
13
 
@@ -77,25 +89,6 @@ def Normalize(in_channels):
77
  return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
 
79
 
80
- class LinearAttention(nn.Module):
81
- def __init__(self, dim, heads=4, dim_head=32):
82
- super().__init__()
83
- self.heads = heads
84
- hidden_dim = dim_head * heads
85
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
-
88
- def forward(self, x):
89
- b, c, h, w = x.shape
90
- qkv = self.to_qkv(x)
91
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
- k = k.softmax(dim=-1)
93
- context = torch.einsum('bhdn,bhen->bhde', k, v)
94
- out = torch.einsum('bhde,bhdn->bhen', context, q)
95
- out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
- return self.to_out(out)
97
-
98
-
99
  class SpatialSelfAttention(nn.Module):
100
  def __init__(self, in_channels):
101
  super().__init__()
@@ -177,7 +170,15 @@ class CrossAttention(nn.Module):
177
 
178
  q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179
 
180
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
 
 
 
 
 
 
 
181
 
182
  if exists(mask):
183
  mask = rearrange(mask, 'b ... -> b (...)')
@@ -186,20 +187,79 @@ class CrossAttention(nn.Module):
186
  sim.masked_fill_(~mask, max_neg_value)
187
 
188
  # attention, what we cannot get enough of
189
- attn = sim.softmax(dim=-1)
190
 
191
- out = einsum('b i j, b j d -> b i d', attn, v)
192
  out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193
  return self.to_out(out)
194
 
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  class BasicTransformerBlock(nn.Module):
197
- def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
 
 
 
 
 
198
  super().__init__()
199
- self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
 
 
 
 
 
200
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201
- self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203
  self.norm1 = nn.LayerNorm(dim)
204
  self.norm2 = nn.LayerNorm(dim)
205
  self.norm3 = nn.LayerNorm(dim)
@@ -209,7 +269,7 @@ class BasicTransformerBlock(nn.Module):
209
  return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210
 
211
  def _forward(self, x, context=None):
212
- x = self.attn1(self.norm1(x)) + x
213
  x = self.attn2(self.norm2(x), context=context) + x
214
  x = self.ff(self.norm3(x)) + x
215
  return x
@@ -222,40 +282,59 @@ class SpatialTransformer(nn.Module):
222
  and reshape to b, t, d.
223
  Then apply standard transformer action.
224
  Finally, reshape to image
 
225
  """
226
  def __init__(self, in_channels, n_heads, d_head,
227
- depth=1, dropout=0., context_dim=None):
 
 
228
  super().__init__()
 
 
229
  self.in_channels = in_channels
230
  inner_dim = n_heads * d_head
231
  self.norm = Normalize(in_channels)
232
-
233
- self.proj_in = nn.Conv2d(in_channels,
234
- inner_dim,
235
- kernel_size=1,
236
- stride=1,
237
- padding=0)
 
 
238
 
239
  self.transformer_blocks = nn.ModuleList(
240
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
 
241
  for d in range(depth)]
242
  )
243
-
244
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
245
- in_channels,
246
- kernel_size=1,
247
- stride=1,
248
- padding=0))
 
 
 
249
 
250
  def forward(self, x, context=None):
251
  # note: if no context is given, cross-attention defaults to self-attention
 
 
252
  b, c, h, w = x.shape
253
  x_in = x
254
  x = self.norm(x)
255
- x = self.proj_in(x)
256
- x = rearrange(x, 'b c h w -> b (h w) c')
257
- for block in self.transformer_blocks:
258
- x = block(x, context=context)
259
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260
- x = self.proj_out(x)
 
 
 
 
 
 
261
  return x + x_in
 
4
  import torch.nn.functional as F
5
  from torch import nn, einsum
6
  from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
 
9
  from ldm.modules.diffusionmodules.util import checkpoint
10
 
11
 
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILBLE = False
18
+
19
+ # CrossAttn precision handling
20
+ import os
21
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
+
23
def exists(val):
    """Return True when ``val`` is not ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing;
    only ``None`` is treated as "missing".
    """
    return val is not None
25
 
 
89
  return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  class SpatialSelfAttention(nn.Module):
93
  def __init__(self, in_channels):
94
  super().__init__()
 
170
 
171
  q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
172
 
173
+ # force cast to fp32 to avoid overflowing
174
+ if _ATTN_PRECISION =="fp32":
175
+ with torch.autocast(enabled=False, device_type = 'cuda'):
176
+ q, k = q.float(), k.float()
177
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
+ else:
179
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180
+
181
+ del q, k
182
 
183
  if exists(mask):
184
  mask = rearrange(mask, 'b ... -> b (...)')
 
187
  sim.masked_fill_(~mask, max_neg_value)
188
 
189
  # attention, what we cannot get enough of
190
+ sim = sim.softmax(dim=-1)
191
 
192
+ out = einsum('b i j, b j d -> b i d', sim, v)
193
  out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
194
  return self.to_out(out)
195
 
196
 
197
class MemoryEfficientCrossAttention(nn.Module):
    """Multi-head (cross-)attention evaluated with the xformers kernel.

    Adapted from
    https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223

    Args:
        query_dim: feature size of the query input ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # Forwarded as ``op=`` to xformers; None lets xformers choose the kernel.
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Raises:
            NotImplementedError: if ``mask`` is given; masking is not supported here.
        """
        # Reject the unsupported argument up front instead of after the
        # attention kernel has already run.
        if exists(mask):
            raise NotImplementedError("attention masks are not supported by MemoryEfficientCrossAttention")

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b = q.shape[0]

        def split_heads(t):
            # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
            n = t.shape[1]
            return (
                t.reshape(b, n, self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * self.heads, n, self.dim_head)
                .contiguous()
            )

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
        n_out = out.shape[1]
        out = (
            out.reshape(b, self.heads, n_out, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, n_out, self.heads * self.dim_head)
        )
        return self.to_out(out)
244
+
245
+
246
  class BasicTransformerBlock(nn.Module):
247
+ ATTENTION_MODES = {
248
+ "softmax": CrossAttention, # vanilla attention
249
+ "softmax-xformers": MemoryEfficientCrossAttention
250
+ }
251
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252
+ disable_self_attn=False):
253
  super().__init__()
254
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
255
+ assert attn_mode in self.ATTENTION_MODES
256
+ attn_cls = self.ATTENTION_MODES[attn_mode]
257
+ self.disable_self_attn = disable_self_attn
258
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
263
  self.norm1 = nn.LayerNorm(dim)
264
  self.norm2 = nn.LayerNorm(dim)
265
  self.norm3 = nn.LayerNorm(dim)
 
269
  return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
270
 
271
def _forward(self, x, context=None):
    """Apply attn1, attn2 and the feed-forward layer, each with a residual connection.

    ``attn1`` only receives ``context`` when ``self.disable_self_attn`` is
    set; otherwise it is called with ``context=None``. ``attn2`` always
    receives ``context``.
    """
    first_context = context if self.disable_self_attn else None
    x = self.attn1(self.norm1(x), context=first_context) + x
    x = self.attn2(self.norm2(x), context=context) + x
    x = self.ff(self.norm3(x)) + x
    return x
 
282
  and reshape to b, t, d.
283
  Then apply standard transformer action.
284
  Finally, reshape to image
285
+ NEW: use_linear for more efficiency instead of the 1x1 convs
286
  """
287
def __init__(self, in_channels, n_heads, d_head,
             depth=1, dropout=0., context_dim=None,
             disable_self_attn=False, use_linear=False,
             use_checkpoint=True):
    """Set up the projections and ``depth`` transformer blocks.

    Args:
        in_channels: channel count of the image-like input.
        n_heads: attention heads per block.
        d_head: feature size per head (``inner_dim = n_heads * d_head``).
        depth: number of ``BasicTransformerBlock`` layers.
        dropout: dropout probability passed to each block.
        context_dim: context feature size. A list is indexed per block;
            a scalar or ``None`` is shared by every block.
        disable_self_attn: forwarded to each block.
        use_linear: use ``nn.Linear`` projections instead of 1x1 convs.
        use_checkpoint: forwarded to each block as ``checkpoint``.
    """
    super().__init__()
    # Always end up with one entry per block. Previously ``None`` hit
    # ``None[d]`` (TypeError) and a scalar with depth > 1 ran out of entries.
    if not isinstance(context_dim, list):
        context_dim = [context_dim] * depth
    self.in_channels = in_channels
    inner_dim = n_heads * d_head
    self.norm = Normalize(in_channels)
    if not use_linear:
        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
    else:
        self.proj_in = nn.Linear(in_channels, inner_dim)

    self.transformer_blocks = nn.ModuleList(
        [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                               disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
         for d in range(depth)]
    )
    if not use_linear:
        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))
    else:
        # Map inner_dim back to in_channels, mirroring the conv branch. The
        # arguments were swapped before, which only worked when the two sizes
        # were equal (the weight shape is unchanged in that case).
        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
    self.use_linear = use_linear
320
 
321
def forward(self, x, context=None):
    """Run the transformer blocks over the flattened spatial grid and add the input back.

    Args:
        x: tensor unpacked as ``(b, c, h, w)``.
        context: a single context (shared by every block) or a list with one
            context per block.
    """
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
    if self.use_linear:
        x = self.proj_in(x)
    for i, block in enumerate(self.transformer_blocks):
        # A lone context is reused for every block; indexing it with ``i``
        # raised IndexError as soon as depth exceeded 1.
        block_context = context[0] if len(context) == 1 else context[i]
        x = block(x, context=block_context)
    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)
    return x + x_in
ldm/modules/diffusionmodules/model.py CHANGED
@@ -4,9 +4,17 @@ import torch
4
  import torch.nn as nn
5
  import numpy as np
6
  from einops import rearrange
 
7
 
8
- from ldm.util import instantiate_from_config
9
- from ldm.modules.attention import LinearAttention
 
 
 
 
 
 
 
10
 
11
 
12
  def get_timestep_embedding(timesteps, embedding_dim):
@@ -141,12 +149,6 @@ class ResnetBlock(nn.Module):
141
  return x+h
142
 
143
 
144
- class LinAttnBlock(LinearAttention):
145
- """to match AttnBlock usage"""
146
- def __init__(self, in_channels):
147
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
-
149
-
150
  class AttnBlock(nn.Module):
151
  def __init__(self, in_channels):
152
  super().__init__()
@@ -174,7 +176,6 @@ class AttnBlock(nn.Module):
174
  stride=1,
175
  padding=0)
176
 
177
-
178
  def forward(self, x):
179
  h_ = x
180
  h_ = self.norm(h_)
@@ -201,16 +202,99 @@ class AttnBlock(nn.Module):
201
 
202
  return x+h_
203
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- def make_attn(in_channels, attn_type="vanilla"):
206
- assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
  if attn_type == "vanilla":
 
209
  return AttnBlock(in_channels)
 
 
 
 
 
 
210
  elif attn_type == "none":
211
  return nn.Identity(in_channels)
212
  else:
213
- return LinAttnBlock(in_channels)
214
 
215
 
216
  class Model(nn.Module):
@@ -766,70 +850,3 @@ class Resize(nn.Module):
766
  else:
767
  x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
  return x
769
-
770
- class FirstStagePostProcessor(nn.Module):
771
-
772
- def __init__(self, ch_mult:list, in_channels,
773
- pretrained_model:nn.Module=None,
774
- reshape=False,
775
- n_channels=None,
776
- dropout=0.,
777
- pretrained_config=None):
778
- super().__init__()
779
- if pretrained_config is None:
780
- assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
- self.pretrained_model = pretrained_model
782
- else:
783
- assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
- self.instantiate_pretrained(pretrained_config)
785
-
786
- self.do_reshape = reshape
787
-
788
- if n_channels is None:
789
- n_channels = self.pretrained_model.encoder.ch
790
-
791
- self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
- self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
- stride=1,padding=1)
794
-
795
- blocks = []
796
- downs = []
797
- ch_in = n_channels
798
- for m in ch_mult:
799
- blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
- ch_in = m * n_channels
801
- downs.append(Downsample(ch_in, with_conv=False))
802
-
803
- self.model = nn.ModuleList(blocks)
804
- self.downsampler = nn.ModuleList(downs)
805
-
806
-
807
- def instantiate_pretrained(self, config):
808
- model = instantiate_from_config(config)
809
- self.pretrained_model = model.eval()
810
- # self.pretrained_model.train = False
811
- for param in self.pretrained_model.parameters():
812
- param.requires_grad = False
813
-
814
-
815
- @torch.no_grad()
816
- def encode_with_pretrained(self,x):
817
- c = self.pretrained_model.encode(x)
818
- if isinstance(c, DiagonalGaussianDistribution):
819
- c = c.mode()
820
- return c
821
-
822
- def forward(self,x):
823
- z_fs = self.encode_with_pretrained(x)
824
- z = self.proj_norm(z_fs)
825
- z = self.proj(z)
826
- z = nonlinearity(z)
827
-
828
- for submodel, downmodel in zip(self.model,self.downsampler):
829
- z = submodel(z,temb=None)
830
- z = downmodel(z)
831
-
832
- if self.do_reshape:
833
- z = rearrange(z,'b c h w -> b (h w) c')
834
- return z
835
-
 
4
  import torch.nn as nn
5
  import numpy as np
6
  from einops import rearrange
7
+ from typing import Optional, Any
8
 
9
+ from ldm.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+ XFORMERS_IS_AVAILBLE = True
15
+ except:
16
+ XFORMERS_IS_AVAILBLE = False
17
+ print("No module 'xformers'. Proceeding without it.")
18
 
19
 
20
  def get_timestep_embedding(timesteps, embedding_dim):
 
149
  return x+h
150
 
151
 
 
 
 
 
 
 
152
  class AttnBlock(nn.Module):
153
  def __init__(self, in_channels):
154
  super().__init__()
 
176
  stride=1,
177
  padding=0)
178
 
 
179
  def forward(self, x):
180
  h_ = x
181
  h_ = self.norm(h_)
 
202
 
203
  return x+h_
204
 
205
class MemoryEfficientAttnBlock(nn.Module):
    """
        Uses xformers efficient implementation,
        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
        Note: this is a single-head self-attention operation
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # q / k / v / output projections are all 1x1 convolutions.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        # Forwarded as ``op=`` to xformers; None lets xformers choose the kernel.
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        """Self-attend over the spatial positions of ``x`` and add the result to ``x``."""
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        # Flatten the grid into a token axis: (B, C, H, W) -> (B, H*W, C).
        # With a single head no further head split/merge is required; the
        # former reshape/permute round trip was an identity.
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c').contiguous() for t in (q, k, v))

        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out
269
+
270
+
271
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    """Let ``MemoryEfficientCrossAttention`` consume and return ``(b, c, h, w)`` feature maps."""

    def forward(self, x, context=None, mask=None):
        """Attend over the flattened spatial grid of ``x`` and add the result to ``x``."""
        b, c, h, w = x.shape
        # Keep ``x`` in its 4-D layout for the residual. The old code
        # overwrote it with the flattened tokens and then added a
        # (b, h*w, c) tensor to the (b, c, h, w) attention output.
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        out = super().forward(tokens, context=context, mask=mask)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
        return x + out
278
+
279
+
280
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """Build the attention module selected by ``attn_type``.

    Args:
        in_channels: channel count handed to the attention module.
        attn_type: one of ``"vanilla"``, ``"vanilla-xformers"``,
            ``"memory-efficient-cross-attn"``, ``"linear"`` or ``"none"``.
            ``"vanilla"`` is upgraded to ``"vanilla-xformers"`` when xformers
            was imported successfully.
        attn_kwargs: extra keyword arguments for the
            ``"memory-efficient-cross-attn"`` module; must be ``None`` for
            plain ``"vanilla"``.

    Raises:
        NotImplementedError: for accepted types with no implementation here
            (currently ``"linear"``).
    """
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        # This branch used to test the builtin ``type`` instead of
        # ``attn_type`` and so could never be taken. Work on a copy so the
        # caller's dict is not mutated and ``None`` is tolerated.
        kwargs = dict(attn_kwargs or {})
        kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**kwargs)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise NotImplementedError(f"attn_type '{attn_type}' is not implemented")
298
 
299
 
300
  class Model(nn.Module):
 
850
  else:
851
  x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
852
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/modules/encoders/adapter.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
 
5
 
6
  def conv_nd(dims, *args, **kwargs):
7
  """
@@ -121,3 +122,130 @@ class Adapter(nn.Module):
121
  features.append(x)
122
 
123
  return features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
5
+ from collections import OrderedDict
6
 
7
  def conv_nd(dims, *args, **kwargs):
8
  """
 
122
  features.append(x)
123
 
124
  return features
125
+
126
+
127
class ResnetBlock_light(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, added back onto the input.

    Both convolutions keep the channel count (``in_c``) and use stride 1
    with padding 1, so the output has the same shape as the input.
    """

    def __init__(self, in_c):
        super().__init__()
        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)

    def forward(self, x):
        """Return ``block2(relu(block1(x))) + x``."""
        h = self.block1(x)
        h = self.act(h)
        h = self.block2(h)

        return h + x
140
+
141
+
142
class extractor(nn.Module):
    """One feature-extraction stage: optional downsample, 1x1 conv in, residual body, 1x1 conv out.

    Args:
        in_c: channels of the incoming feature map.
        inter_c: channels used inside the residual body.
        out_c: channels of the returned feature map.
        nums_rb: number of ``ResnetBlock_light`` layers in the body.
        down: when truthy, the input is first passed through
            ``Downsample(in_c, use_conv=False)``.
    """

    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
        super().__init__()
        self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
        self.body = nn.Sequential(*[ResnetBlock_light(inter_c) for _ in range(nums_rb)])
        self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
        self.down = down
        if self.down:
            self.down_opt = Downsample(in_c, use_conv=False)

    def forward(self, x):
        """Return the stage's output feature map for ``x``."""
        if self.down:
            x = self.down_opt(x)
        x = self.in_conv(x)
        x = self.body(x)
        x = self.out_conv(x)

        return x
163
+
164
+
165
class Adapter_light(nn.Module):
    """Lightweight adapter producing one feature map per entry of ``channels``.

    The input is pixel-unshuffled by a factor of 8 and then passed through a
    chain of ``extractor`` stages; every stage after the first downsamples
    its input. ``forward`` returns the outputs of all stages in order.

    Args:
        channels: output channel count of each stage.
        nums_rb: residual blocks per stage.
        cin: channel count after the pixel-unshuffle (``PixelUnshuffle(8)``
            multiplies the input channel count by 64).
    """

    def __init__(self, channels=(320, 640, 1280, 1280), nums_rb=3, cin=64):
        super(Adapter_light, self).__init__()
        self.unshuffle = nn.PixelUnshuffle(8)
        # Store a private copy: the default is immutable and a caller's list
        # is never aliased.
        self.channels = list(channels)
        self.nums_rb = nums_rb
        stages = []
        stage_in = cin
        for idx, stage_out in enumerate(self.channels):
            stages.append(extractor(in_c=stage_in, inter_c=stage_out // 4, out_c=stage_out,
                                    nums_rb=nums_rb, down=idx > 0))
            stage_in = stage_out
        self.body = nn.ModuleList(stages)

    def forward(self, x):
        """Return the list of per-stage feature maps for ``x``."""
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        for stage in self.body:
            x = stage(x)
            features.append(x)

        return features
189
+
190
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
194
+
195
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each with a residual connection.

    Args:
        d_model: feature size of the tokens.
        n_head: number of attention heads.
        attn_mask: optional mask passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attend over ``x`` and return the attention output only."""
        # Convert the mask for this call only; rebinding ``self.attn_mask`` on
        # every forward pass mutated module state as a side effect.
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
216
+
217
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        # Normalise in float32, then cast back to the caller's dtype.
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
224
+
225
class StyleAdapter(nn.Module):
    """Condense a token sequence into ``num_token`` style tokens of size ``context_dim``.

    Learned style tokens are appended to the input sequence, the whole
    sequence is passed through ``n_layes`` residual attention blocks, and
    only the appended tokens are kept, normalised and projected.

    Args:
        width: feature size of the input tokens and of the transformer.
        context_dim: feature size of the returned style tokens.
        num_head: attention heads per block.
        n_layes: number of residual attention blocks.
        num_token: number of learned style tokens.
    """

    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
        super().__init__()

        scale = width ** -0.5
        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
        self.num_token = num_token
        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
        self.ln_post = LayerNorm(width)
        self.ln_pre = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))

    def forward(self, x):
        """Return the ``(N, num_token, context_dim)`` style tokens for ``x``."""
        # x shape [N, HW+1, C]
        # Broadcast the learned tokens over the batch with ``expand``. Adding a
        # float32 zeros tensor allocated a throwaway buffer and promoted
        # half-precision parameters to float32.
        style_embedding = self.style_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat([x, style_embedding], dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer_layes(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the appended style tokens.
        x = self.ln_post(x[:, -self.num_token:, :])
        x = x @ self.proj

        return x
ldm/modules/structure_condition/openpose/__init__.py ADDED
File without changes
ldm/modules/structure_condition/openpose/api.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch.nn as nn
5
+
6
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
7
+
8
+ import cv2
9
+ import torch
10
+
11
+ from . import util
12
+ from .body import Body
13
+
14
+ remote_model_path = "https://drive.google.com/file/d/1EULkcH_hhSU28qVc1jSJpCh2hGOrzpjK/view?usp=share_link"
15
+
16
+
17
class OpenposeInference(nn.Module):
    """Estimate body poses on an image and render them on a black canvas."""

    def __init__(self):
        super().__init__()
        body_modelpath = os.path.join('models', "body_pose_model.pth")

        if not os.path.exists(body_modelpath):
            from basicsr.utils.download_util import load_file_from_url
            # Save under the exact name that is loaded below; by default the
            # file would be named after the last URL path component.
            # NOTE(review): remote_model_path looks like a Google Drive
            # share-page link rather than a direct download - confirm that
            # the fetched file really is the checkpoint.
            load_file_from_url(remote_model_path, model_dir='models',
                               file_name=os.path.basename(body_modelpath))

        self.body_estimation = Body(body_modelpath)

    def forward(self, x):
        """Return the rendered pose image for ``x``.

        ``x`` is indexed as a 3-D array; its last axis is reversed before
        estimation and the drawn canvas is converted with
        ``cv2.COLOR_RGB2BGR`` before being returned.
        """
        # Reverse the channel axis; copy so the result is a contiguous array.
        x = x[:, :, ::-1].copy()
        with torch.no_grad():
            candidate, subset = self.body_estimation(x)
            canvas = np.zeros_like(x)
            canvas = util.draw_bodypose(canvas, candidate, subset)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
        return canvas
ldm/modules/structure_condition/openpose/body.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import time
3
+
4
+ import cv2
5
+ import matplotlib
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch
9
+ from scipy.ndimage.filters import gaussian_filter
10
+ from torchvision import transforms
11
+
12
+ from . import util
13
+ from .model import bodypose_model
14
+
15
+
16
+ class Body(object):
17
+
18
def __init__(self, model_path):
    """Build the body-pose network and load its weights from ``model_path``."""
    self.model = bodypose_model()
    if torch.cuda.is_available():
        self.model = self.model.cuda()
        print('cuda')
    # Deserialize onto the CPU so a checkpoint saved from a GPU still loads
    # on a CPU-only host; ``load_state_dict`` then copies the values into the
    # model's parameters on whichever device they live.
    # NOTE(review): assumes util.transfer only remaps state-dict keys - confirm.
    weights = torch.load(model_path, map_location='cpu')
    model_dict = util.transfer(self.model, weights)
    self.model.load_state_dict(model_dict)
    self.model.eval()
26
+
27
+ def __call__(self, oriImg):
28
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
29
+ scale_search = [0.5]
30
+ boxsize = 368
31
+ stride = 8
32
+ padValue = 128
33
+ thre1 = 0.1
34
+ thre2 = 0.05
35
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
36
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
37
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
38
+
39
+ for m in range(len(multiplier)):
40
+ scale = multiplier[m]
41
+ imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
42
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
43
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
44
+ im = np.ascontiguousarray(im)
45
+
46
+ data = torch.from_numpy(im).float()
47
+ if torch.cuda.is_available():
48
+ data = data.cuda()
49
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
50
+ with torch.no_grad():
51
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
52
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
53
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
54
+
55
+ # extract outputs, resize, and remove padding
56
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
57
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
58
+ heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
59
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
60
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
61
+
62
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
63
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
64
+ paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
65
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
66
+ paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
67
+
68
+ heatmap_avg += heatmap_avg + heatmap / len(multiplier)
69
+ paf_avg += +paf / len(multiplier)
70
+
71
+ all_peaks = []
72
+ peak_counter = 0
73
+
74
+ for part in range(18):
75
+ map_ori = heatmap_avg[:, :, part]
76
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
77
+
78
+ map_left = np.zeros(one_heatmap.shape)
79
+ map_left[1:, :] = one_heatmap[:-1, :]
80
+ map_right = np.zeros(one_heatmap.shape)
81
+ map_right[:-1, :] = one_heatmap[1:, :]
82
+ map_up = np.zeros(one_heatmap.shape)
83
+ map_up[:, 1:] = one_heatmap[:, :-1]
84
+ map_down = np.zeros(one_heatmap.shape)
85
+ map_down[:, :-1] = one_heatmap[:, 1:]
86
+
87
+ peaks_binary = np.logical_and.reduce((one_heatmap >= map_left, one_heatmap >= map_right,
88
+ one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
89
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
90
+ peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
91
+ peak_id = range(peak_counter, peak_counter + len(peaks))
92
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i], ) for i in range(len(peak_id))]
93
+
94
+ all_peaks.append(peaks_with_score_and_id)
95
+ peak_counter += len(peaks)
96
+
97
+ # find connection in the specified sequence, center 29 is in the position 15
98
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
99
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
100
+ [1, 16], [16, 18], [3, 17], [6, 18]]
101
+ # the middle joints heatmap correpondence
102
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
103
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
104
+ [55, 56], [37, 38], [45, 46]]
105
+
106
+ connection_all = []
107
+ special_k = []
108
+ mid_num = 10
109
+
110
+ for k in range(len(mapIdx)):
111
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
112
+ candA = all_peaks[limbSeq[k][0] - 1]
113
+ candB = all_peaks[limbSeq[k][1] - 1]
114
+ nA = len(candA)
115
+ nB = len(candB)
116
+ indexA, indexB = limbSeq[k]
117
+ if (nA != 0 and nB != 0):
118
+ connection_candidate = []
119
+ for i in range(nA):
120
+ for j in range(nB):
121
+ vec = np.subtract(candB[j][:2], candA[i][:2])
122
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
123
+ norm = max(0.001, norm)
124
+ vec = np.divide(vec, norm)
125
+
126
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
127
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
128
+
129
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
130
+ for I in range(len(startend))])
131
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
132
+ for I in range(len(startend))])
133
+
134
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
135
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
136
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
137
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
138
+ criterion2 = score_with_dist_prior > 0
139
+ if criterion1 and criterion2:
140
+ connection_candidate.append(
141
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
142
+
143
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
144
+ connection = np.zeros((0, 5))
145
+ for c in range(len(connection_candidate)):
146
+ i, j, s = connection_candidate[c][0:3]
147
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
148
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
149
+ if (len(connection) >= min(nA, nB)):
150
+ break
151
+
152
+ connection_all.append(connection)
153
+ else:
154
+ special_k.append(k)
155
+ connection_all.append([])
156
+
157
+ # last number in each row is the total parts number of that person
158
+ # the second last number in each row is the score of the overall configuration
159
+ subset = -1 * np.ones((0, 20))
160
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
161
+
162
+ for k in range(len(mapIdx)):
163
+ if k not in special_k:
164
+ partAs = connection_all[k][:, 0]
165
+ partBs = connection_all[k][:, 1]
166
+ indexA, indexB = np.array(limbSeq[k]) - 1
167
+
168
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
169
+ found = 0
170
+ subset_idx = [-1, -1]
171
+ for j in range(len(subset)): # 1:size(subset,1):
172
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
173
+ subset_idx[found] = j
174
+ found += 1
175
+
176
+ if found == 1:
177
+ j = subset_idx[0]
178
+ if subset[j][indexB] != partBs[i]:
179
+ subset[j][indexB] = partBs[i]
180
+ subset[j][-1] += 1
181
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
182
+ elif found == 2: # if found 2 and disjoint, merge them
183
+ j1, j2 = subset_idx
184
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
185
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
186
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
187
+ subset[j1][-2:] += subset[j2][-2:]
188
+ subset[j1][-2] += connection_all[k][i][2]
189
+ subset = np.delete(subset, j2, 0)
190
+ else: # as like found == 1
191
+ subset[j1][indexB] = partBs[i]
192
+ subset[j1][-1] += 1
193
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
194
+
195
+ # if find no partA in the subset, create a new subset
196
+ elif not found and k < 17:
197
+ row = -1 * np.ones(20)
198
+ row[indexA] = partAs[i]
199
+ row[indexB] = partBs[i]
200
+ row[-1] = 2
201
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
202
+ subset = np.vstack([subset, row])
203
+ # delete some rows of subset which has few parts occur
204
+ deleteIdx = []
205
+ for i in range(len(subset)):
206
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
207
+ deleteIdx.append(i)
208
+ subset = np.delete(subset, deleteIdx, axis=0)
209
+
210
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
211
+ # candidate: x, y, score, id
212
+ return candidate, subset
213
+
214
+
215
if __name__ == "__main__":
    # Manual smoke test: estimate body poses on one image and show the overlay.
    estimator = Body('../model/body_pose_model.pth')

    image_path = '/group/30042/liangbinxie/Projects/mmpose/test_data/twitter/1.png'
    image_bgr = cv2.imread(image_path)  # OpenCV returns channels in B, G, R order
    candidate, subset = estimator(image_bgr)
    print(candidate, subset)

    overlay = util.draw_bodypose(image_bgr, candidate, subset)
    # Reverse the channel axis so matplotlib receives R, G, B.
    plt.imshow(overlay[:, :, [2, 1, 0]])
    plt.show()
ldm/modules/structure_condition/openpose/hand.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import json
3
+ import numpy as np
4
+ import math
5
+ import time
6
+ from scipy.ndimage.filters import gaussian_filter
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib
9
+ import torch
10
+ from skimage.measure import label
11
+
12
+ from .model import handpose_model
13
+ from . import util
14
+
15
class Hand(object):
    """Hand keypoint estimator built on ``handpose_model``.

    Calling an instance on an image returns one ``[x, y]`` location for each
    of the first 21 heatmap channels produced by the network.
    """

    def __init__(self, model_path):
        """Create the network and load the weights stored at ``model_path``.

        The network is moved to CUDA when ``torch.cuda.is_available()``.
        """
        net = handpose_model()
        if torch.cuda.is_available():
            net = net.cuda()
            print('cuda')
        self.model = net
        # Checkpoint keys lack the leading sub-module name; util.transfer
        # re-keys them to match this model's state dict.
        state = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(state)
        self.model.eval()

    def __call__(self, oriImg):
        """Estimate hand keypoints in ``oriImg``.

        Args:
            oriImg: image array indexed as (row, column, channel).
                NOTE(review): presumably a BGR uint8 crop of one hand, as in
                the ``__main__`` demo - confirm against callers.

        Returns:
            ``np.ndarray`` with one ``[x, y]`` row per keypoint (21 rows). A
            keypoint whose smoothed heatmap never exceeds the threshold is
            reported as ``[0, 0]``.
        """
        scale_search = [0.5, 1.0, 1.5, 2.0]
        boxsize = 368
        stride = 8
        padValue = 128
        thre = 0.05

        img_h, img_w = oriImg.shape[0], oriImg.shape[1]
        multiplier = [s * boxsize / img_h for s in scale_search]
        heatmap_avg = np.zeros((img_h, img_w, 22))

        for scale in multiplier:
            resized = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
            padded, pad = util.padRightDownCorner(resized, stride, padValue)
            # (H, W, C) -> (1, C, H, W), then shift/scale the pixel values.
            batch = np.transpose(np.float32(padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            batch = np.ascontiguousarray(batch)

            data = torch.from_numpy(batch).float()
            if torch.cuda.is_available():
                data = data.cuda()
            with torch.no_grad():
                output = self.model(data).cpu().numpy()

            # Upsample by the stride, crop away the padding, then resize back
            # to the original image size.
            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))
            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            heatmap = heatmap[:padded.shape[0] - pad[2], :padded.shape[1] - pad[3], :]
            heatmap = cv2.resize(heatmap, (img_w, img_h), interpolation=cv2.INTER_CUBIC)

            heatmap_avg += heatmap / len(multiplier)

        all_peaks = []
        for part in range(21):
            map_ori = heatmap_avg[:, :, part]
            smoothed = gaussian_filter(map_ori, sigma=3)
            binary = np.ascontiguousarray(smoothed > thre, dtype=np.uint8)
            if np.sum(binary) == 0:
                # Everything is below the threshold: no detection for this part.
                all_peaks.append([0, 0])
                continue

            # Keep only the connected region with the largest total response.
            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
            region_mass = [np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]
            max_index = np.argmax(region_mass) + 1
            label_img[label_img != max_index] = 0
            # map_ori is a view, so this also zeroes heatmap_avg outside the region.
            map_ori[label_img == 0] = 0

            y, x = util.npmax(map_ori)
            all_peaks.append([x, y])
        return np.array(all_peaks)
76
+
77
if __name__ == "__main__":
    # Manual smoke test: estimate the keypoints of one hand image and display
    # them with OpenCV.
    hand_estimation = Hand('../model/hand_pose_model.pth')

    test_image = '../images/hand.jpg'
    oriImg = cv2.imread(test_image)  # B,G,R order
    peaks = hand_estimation(oriImg)
    # util.draw_handpose iterates over a collection of hands, each one an
    # array of keypoints, so the single detection must be wrapped in a list;
    # passing the bare (21, 2) array makes it index individual points as if
    # they were whole hands and fail.
    canvas = util.draw_handpose(oriImg, [peaks], True)
    cv2.imshow('', canvas)
    cv2.waitKey(0)
ldm/modules/structure_condition/openpose/model.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from collections import OrderedDict
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
def make_layers(block, no_relu_layers):
    """Build an ``nn.Sequential`` from an ordered layer specification.

    Args:
        block: mapping of layer name -> parameter list, iterated in order.
            Names containing ``'pool'`` describe a max-pool as
            ``[kernel_size, stride, padding]``; every other name describes a
            convolution as
            ``[in_channels, out_channels, kernel_size, stride, padding]``.
        no_relu_layers: names of convolutions that must NOT be followed by a
            ReLU.

    Returns:
        ``nn.Sequential`` whose sub-modules keep the names from ``block``;
        each added ReLU is named ``'relu_' + <conv name>``.
    """
    named_layers = []
    for name, spec in block.items():
        if 'pool' in name:
            pool = nn.MaxPool2d(kernel_size=spec[0], stride=spec[1], padding=spec[2])
            named_layers.append((name, pool))
            continue

        conv = nn.Conv2d(in_channels=spec[0], out_channels=spec[1],
                         kernel_size=spec[2], stride=spec[3], padding=spec[4])
        named_layers.append((name, conv))
        if name not in no_relu_layers:
            named_layers.append(('relu_' + name, nn.ReLU(inplace=True)))

    return nn.Sequential(OrderedDict(named_layers))
23
+
24
class bodypose_model(nn.Module):
    """Two-branch, six-stage body pose network.

    A shared feature extractor (``model0``) feeds two branches. Branch ``L1``
    ends in 38 channels and branch ``L2`` in 19 channels at every stage
    (``Body.__call__`` reads them as part-affinity fields and keypoint
    heatmaps respectively). Stages 2-6 each take the previous stage's two
    outputs concatenated with the shared features.

    Sub-module and layer names are kept exactly as in the original so that
    existing checkpoints still load.
    """

    def __init__(self):
        super(bodypose_model, self).__init__()

        # The last convolution of each branch in each stage emits raw
        # predictions, so no ReLU follows it. The original list repeated
        # 'Mconv7_stage6_L1' and omitted 'Mconv7_stage6_L2', which put a ReLU
        # on the final L2 output only.
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
        blocks = {}

        # Shared feature extractor: three 2x poolings, ending in 128 channels.
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])

        # Stage 1
        block1_1 = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])

        block1_2 = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6: the input has 185 = 38 (L1) + 19 (L2) + 128 (shared
        # features) channels.
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):
        """Run all six stages and return the stage-6 ``(L1, L2)`` outputs.

        Args:
            x: input batch with 3 channels on dimension 1.

        Returns:
            Tuple ``(out6_1, out6_2)``: the 38-channel L1 output and the
            19-channel L2 output of the last stage.
        """
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2
142
+
143
class handpose_model(nn.Module):
    """Six-stage hand pose network producing 22 output channels.

    ``model1_0`` extracts 128-channel features, ``model1_1`` turns them into
    the stage-1 prediction, and ``model2`` .. ``model6`` each refine the
    previous prediction concatenated with those features.

    Sub-module and layer names match the original so checkpoints still load.
    """

    def __init__(self):
        super(handpose_model, self).__init__()

        # The final convolution of every stage has no ReLU after it.
        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',
                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']

        # Stage 1: feature extractor followed by the first prediction head.
        feature_spec = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3', [512, 512, 3, 1, 1]),
            ('conv4_4', [512, 512, 3, 1, 1]),
            ('conv5_1', [512, 512, 3, 1, 1]),
            ('conv5_2', [512, 512, 3, 1, 1]),
            ('conv5_3_CPM', [512, 128, 3, 1, 1]),
        ])
        head_spec = OrderedDict([
            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
            ('conv6_2_CPM', [512, 22, 1, 1, 0]),
        ])

        specs = {'block1_0': feature_spec, 'block1_1': head_spec}

        # Stages 2-6: the input has 150 = 22 (previous prediction) + 128
        # (features) channels.
        for stage in range(2, 7):
            specs['block%d' % stage] = OrderedDict(
                [('Mconv1_stage%d' % stage, [150, 128, 7, 1, 3])]
                + [('Mconv%d_stage%d' % (n, stage), [128, 128, 7, 1, 3]) for n in range(2, 6)]
                + [('Mconv6_stage%d' % stage, [128, 128, 1, 1, 0]),
                   ('Mconv7_stage%d' % stage, [128, 22, 1, 1, 0])]
            )

        # Build in insertion order (same order as the spec dict was filled).
        built = {key: make_layers(spec, no_relu_layers) for key, spec in specs.items()}

        self.model1_0 = built['block1_0']
        self.model1_1 = built['block1_1']
        self.model2 = built['block2']
        self.model3 = built['block3']
        self.model4 = built['block4']
        self.model5 = built['block5']
        self.model6 = built['block6']

    def forward(self, x):
        """Return the stage-6 prediction (22 channels) for input batch ``x``."""
        features = self.model1_0(x)
        prediction = self.model1_1(features)
        # Every later stage sees the previous prediction stacked on the features.
        for stage in (self.model2, self.model3, self.model4, self.model5, self.model6):
            prediction = stage(torch.cat([prediction, features], 1))
        return prediction
218
+
219
+
ldm/modules/structure_condition/openpose/util.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import matplotlib
5
+ import numpy as np
6
+
7
+
8
def padRightDownCorner(img, stride, padValue):
    """Pad the bottom and right of ``img`` up to multiples of ``stride``.

    Args:
        img: array indexed as (row, column, channel).
        stride: positive integer the padded height and width must divide by.
        padValue: constant written into the added rows and columns.

    Returns:
        Tuple ``(img_padded, pad)``. ``img_padded`` is a new array and
        ``pad`` is ``[up, left, down, right]``, the number of rows/columns
        added on each side (``up`` and ``left`` are always 0).
    """
    h = img.shape[0]
    w = img.shape[1]

    # (-n) % stride is 0 when n is already a multiple, else stride - n % stride.
    pad = [0, 0, (-h) % stride, (-w) % stride]  # up, left, down, right

    # Use the LAST row/column as the shape template for the fill. The original
    # sliced [-2:-1], which is empty for a 1-pixel-high/wide image and
    # silently skipped the padding. "* 0 + padValue" keeps the original
    # dtype-promotion behaviour.
    pad_down = np.tile(img[-1:, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -1:, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad
29
+
30
+
31
+ # transfer caffe model to pytorch which will match the layer name
32
def transfer(model, model_weights):
    """Re-key ``model_weights`` to match ``model.state_dict()``.

    For every key of the model's state dict, the value is looked up in
    ``model_weights`` under that key with its first dot-separated component
    removed (e.g. ``'model0.conv1_1.weight'`` -> ``'conv1_1.weight'``).

    Raises:
        KeyError: if a stripped name is missing from ``model_weights``.
    """
    return {
        full_name: model_weights[full_name.partition('.')[2]]
        for full_name in model.state_dict()
    }
37
+
38
+
39
+ # draw the body keypoints and limbs
40
def draw_bodypose(canvas, candidate, subset):
    """Draw detected body keypoints and limbs onto ``canvas``.

    Args:
        canvas: image to draw on. Keypoint circles are drawn in place; each
            limb is blended in and produces a new array.
        candidate: array whose rows start with ``x, y`` for each detected
            keypoint.
        subset: array with one row per person; entries 0-17 are row indices
            into ``candidate`` or ``-1`` when the part is missing.

    Returns:
        The image with the skeleton drawn (``canvas`` itself if no limb was
        drawn).
    """
    stickwidth = 4
    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    # Keypoints: one filled circle per detected joint.
    for joint in range(18):
        for person in subset:
            index = int(person[joint])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            cv2.circle(canvas, (int(x), int(y)), 4, colors[joint], thickness=-1)

    # Limbs: only the first 17 entries of limbSeq are drawn.
    for limb in range(17):
        endpoints = np.array(limbSeq[limb]) - 1  # limbSeq is 1-based
        for person in subset:
            index = person[endpoints]
            if -1 in index:
                continue
            layer = canvas.copy()
            xs = candidate[index.astype(int), 0]
            ys = candidate[index.astype(int), 1]
            center_y = np.mean(ys)
            center_x = np.mean(xs)
            length = ((ys[0] - ys[1]) ** 2 + (xs[0] - xs[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
            # A thin filled ellipse between the two joints stands in for the limb.
            polygon = cv2.ellipse2Poly((int(center_x), int(center_y)), (int(length / 2), stickwidth),
                                       int(angle), 0, 360, 1)
            cv2.fillConvexPoly(layer, polygon, colors[limb])
            canvas = cv2.addWeighted(canvas, 0.4, layer, 0.6, 0)
    return canvas
74
+
75
+
76
+ # images drawn with OpenCV do not look very good.
77
def draw_handpose(canvas, all_hand_peaks, show_number=False):
    """Draw hand skeletons onto ``canvas`` in place and return it.

    Args:
        canvas: image to draw on.
        all_hand_peaks: iterable of hands; each hand is an array with one
            ``x, y`` row per keypoint (indices 0-20 are used by the edges).
        show_number: when true, label each keypoint with its index.

    Returns:
        ``canvas``.
    """
    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
    edge_count = float(len(edges))

    for peaks in all_hand_peaks:
        for edge_id, edge in enumerate(edges):
            # An edge is drawn only when neither endpoint has a zero coordinate.
            endpoint_ok = np.all(peaks[edge], axis=1)
            if np.sum(endpoint_ok == 0) != 0:
                continue
            x1, y1 = peaks[edge[0]]
            x2, y2 = peaks[edge[1]]
            # Each edge gets its own hue around the color wheel.
            color = matplotlib.colors.hsv_to_rgb([edge_id / edge_count, 1.0, 1.0]) * 255
            cv2.line(canvas, (x1, y1), (x2, y2), color, thickness=2)

        for number, keypoint in enumerate(peaks):
            x, y = keypoint
            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
            if show_number:
                cv2.putText(canvas, str(number), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
    return canvas
97
+
98
+
99
+ # detect hand according to body pose keypoints
100
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
101
def handDetect(candidate, subset, oriImg):
    """Propose square hand boxes from body keypoints.

    For every arm whose shoulder, elbow and wrist were all detected, the box
    centre is extrapolated past the wrist along the forearm and the box size
    is derived from the forearm and upper-arm lengths (following OpenPose's
    handDetector.cpp).

    Args:
        candidate: array whose rows start with ``x, y`` for each keypoint.
        subset: array with one row per person holding indices into
            ``candidate`` (``-1`` = not detected). Slots 5, 6, 7 are read as
            the left shoulder/elbow/wrist and 2, 3, 4 as the right ones.
        oriImg: image; only its height and width are used, for clipping.

    Returns:
        List of ``[x, y, w, is_left]``: ``x, y`` is the top-left corner,
        ``w`` the side of the square box, ``is_left`` is True for a left
        hand. Boxes narrower than 20 pixels are dropped.
    """
    ratioWristElbow = 0.33
    image_height, image_width = oriImg.shape[0:2]
    # (shoulder, elbow, wrist) slots per arm; the left arm is handled first.
    arms = (([5, 6, 7], True), ([2, 3, 4], False))

    detect_result = []
    for person in subset.astype(int):
        for slots, is_left in arms:
            joints = person[slots]
            if np.sum(joints == -1) != 0:
                continue  # at least one of the three joints is missing
            shoulder, elbow, wrist = joints
            x1, y1 = candidate[shoulder][:2]
            x2, y2 = candidate[elbow][:2]
            x3, y3 = candidate[wrist][:2]

            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow)
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)

            # (x, y) is the centre so far; move it to the top-left corner.
            x -= width / 2
            y -= width / 2  # width == height

            # Clip the box to the image while keeping it square.
            if x < 0:
                x = 0
            if y < 0:
                y = 0
            width1 = image_width - x if x + width > image_width else width
            width2 = image_height - y if y + width > image_height else width
            width = min(width1, width2)

            # Keep only boxes that are at least 20 pixels wide.
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    return detect_result
163
+
164
+
165
+ # get max index of 2d array
166
def npmax(array):
    """Return ``(row, column)`` of the maximum of a 2-D array.

    When the maximum occurs more than once, the first row containing it and
    the first column within that row are returned.
    """
    best_col_per_row = array.argmax(1)
    row = array.max(1).argmax()
    return row, best_col_per_row[row]
172
+
173
+
174
def HWC3(x):
    """Convert a uint8 image to 3 channels in (H, W, C) layout.

    * 2-D input is treated as a single channel.
    * 1 channel is replicated three times.
    * 3 channels are returned unchanged (same object).
    * 4 channels are alpha-blended onto a white background using the fourth
      channel as alpha.

    Asserts that the dtype is ``np.uint8`` and that the channel count is
    1, 3 or 4.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)

    if channels == 1:
        return np.concatenate([x] * 3, axis=2)
    if channels == 3:
        return x
    if channels == 4:
        rgb = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        # Composite over white: alpha 0 gives 255, alpha 1 gives the color.
        blended = rgb * alpha + 255.0 * (1.0 - alpha)
        return blended.clip(0, 255).astype(np.uint8)
191
+
192
+
193
def resize_image(input_image, resolution):
    """Resize so the shorter side is about ``resolution``, snapped to 64.

    The image is scaled by ``resolution / min(H, W)`` and each target side is
    then rounded to the nearest multiple of 64. Lanczos interpolation is used
    when enlarging, area interpolation otherwise.

    Args:
        input_image: 3-D image array (height, width, channels).
        resolution: desired length of the shorter side before rounding.

    Returns:
        The resized image.
    """
    height, width, _ = input_image.shape
    height = float(height)
    width = float(width)
    k = float(resolution) / min(height, width)

    target_h = int(np.round(height * k / 64.0)) * 64
    target_w = int(np.round(width * k / 64.0)) * 64

    interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(input_image, (target_w, target_h), interpolation=interpolation)