lixiang46 commited on
Commit
6c91ee7
Β·
1 Parent(s): 01bb574
annotator/__init__.py ADDED
File without changes
annotator/canny/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ class CannyDetector:
5
+ def __call__(self, img, low_threshold, high_threshold):
6
+ return cv2.Canny(img, low_threshold, high_threshold)
annotator/midas/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
annotator/midas/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Midas Depth Estimation
2
+ # From https://github.com/isl-org/MiDaS
3
+ # MIT LICENSE
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+
9
+ from einops import rearrange
10
+ from .api import MiDaSInference
11
+
12
+
13
+ class MidasDetector:
14
+ def __init__(self):
15
+ self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
16
+ self.rng = np.random.RandomState(0)
17
+
18
+ def __call__(self, input_image):
19
+ assert input_image.ndim == 3
20
+ image_depth = input_image
21
+ with torch.no_grad():
22
+ image_depth = torch.from_numpy(image_depth).float().cuda()
23
+ image_depth = image_depth / 127.5 - 1.0
24
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
25
+ depth = self.model(image_depth)[0]
26
+
27
+ depth -= torch.min(depth)
28
+ depth /= torch.max(depth)
29
+ depth = depth.cpu().numpy()
30
+ depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
31
+
32
+ return depth_image
33
+
34
+
35
+
annotator/midas/api.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/isl-org/MiDaS
2
+
3
+ import cv2
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision.transforms import Compose
8
+
9
+ from .midas.dpt_depth import DPTDepthModel
10
+ from .midas.midas_net import MidasNet
11
+ from .midas.midas_net_custom import MidasNet_small
12
+ from .midas.transforms import Resize, NormalizeImage, PrepareForNet
13
+ from annotator.util import annotator_ckpts_path
14
+
15
+
16
+ ISL_PATHS = {
17
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
18
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
19
+ "midas_v21": "",
20
+ "midas_v21_small": "",
21
+ }
22
+
23
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt"
24
+
25
+
26
+ def disabled_train(self, mode=True):
27
+ """Overwrite model.train with this function to make sure train/eval mode
28
+ does not change anymore."""
29
+ return self
30
+
31
+
32
+ def load_midas_transform(model_type):
33
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
34
+ # load transform only
35
+ if model_type == "dpt_large": # DPT-Large
36
+ net_w, net_h = 384, 384
37
+ resize_mode = "minimal"
38
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
39
+
40
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
41
+ net_w, net_h = 384, 384
42
+ resize_mode = "minimal"
43
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
44
+
45
+ elif model_type == "midas_v21":
46
+ net_w, net_h = 384, 384
47
+ resize_mode = "upper_bound"
48
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
49
+
50
+ elif model_type == "midas_v21_small":
51
+ net_w, net_h = 256, 256
52
+ resize_mode = "upper_bound"
53
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54
+
55
+ else:
56
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
57
+
58
+ transform = Compose(
59
+ [
60
+ Resize(
61
+ net_w,
62
+ net_h,
63
+ resize_target=None,
64
+ keep_aspect_ratio=True,
65
+ ensure_multiple_of=32,
66
+ resize_method=resize_mode,
67
+ image_interpolation_method=cv2.INTER_CUBIC,
68
+ ),
69
+ normalization,
70
+ PrepareForNet(),
71
+ ]
72
+ )
73
+
74
+ return transform
75
+
76
+
77
+ def load_model(model_type):
78
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
79
+ # load network
80
+ model_path = ISL_PATHS[model_type]
81
+ if model_type == "dpt_large": # DPT-Large
82
+ model = DPTDepthModel(
83
+ path=model_path,
84
+ backbone="vitl16_384",
85
+ non_negative=True,
86
+ )
87
+ net_w, net_h = 384, 384
88
+ resize_mode = "minimal"
89
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
90
+
91
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
92
+ if not os.path.exists(model_path):
93
+ from basicsr.utils.download_util import load_file_from_url
94
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
95
+
96
+ model = DPTDepthModel(
97
+ path=model_path,
98
+ backbone="vitb_rn50_384",
99
+ non_negative=True,
100
+ )
101
+ net_w, net_h = 384, 384
102
+ resize_mode = "minimal"
103
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
104
+
105
+ elif model_type == "midas_v21":
106
+ model = MidasNet(model_path, non_negative=True)
107
+ net_w, net_h = 384, 384
108
+ resize_mode = "upper_bound"
109
+ normalization = NormalizeImage(
110
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
111
+ )
112
+
113
+ elif model_type == "midas_v21_small":
114
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
115
+ non_negative=True, blocks={'expand': True})
116
+ net_w, net_h = 256, 256
117
+ resize_mode = "upper_bound"
118
+ normalization = NormalizeImage(
119
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
120
+ )
121
+
122
+ else:
123
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
124
+ assert False
125
+
126
+ transform = Compose(
127
+ [
128
+ Resize(
129
+ net_w,
130
+ net_h,
131
+ resize_target=None,
132
+ keep_aspect_ratio=True,
133
+ ensure_multiple_of=32,
134
+ resize_method=resize_mode,
135
+ image_interpolation_method=cv2.INTER_CUBIC,
136
+ ),
137
+ normalization,
138
+ PrepareForNet(),
139
+ ]
140
+ )
141
+
142
+ return model.eval(), transform
143
+
144
+
145
+ class MiDaSInference(nn.Module):
146
+ MODEL_TYPES_TORCH_HUB = [
147
+ "DPT_Large",
148
+ "DPT_Hybrid",
149
+ "MiDaS_small"
150
+ ]
151
+ MODEL_TYPES_ISL = [
152
+ "dpt_large",
153
+ "dpt_hybrid",
154
+ "midas_v21",
155
+ "midas_v21_small",
156
+ ]
157
+
158
+ def __init__(self, model_type):
159
+ super().__init__()
160
+ assert (model_type in self.MODEL_TYPES_ISL)
161
+ model, _ = load_model(model_type)
162
+ self.model = model
163
+ self.model.train = disabled_train
164
+
165
+ def forward(self, x):
166
+ with torch.no_grad():
167
+ prediction = self.model(x)
168
+ return prediction
169
+
annotator/midas/midas/__init__.py ADDED
File without changes
annotator/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
annotator/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
annotator/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ ):
36
+
37
+ super(DPT, self).__init__()
38
+
39
+ self.channels_last = channels_last
40
+
41
+ hooks = {
42
+ "vitb_rn50_384": [0, 1, 8, 11],
43
+ "vitb16_384": [2, 5, 8, 11],
44
+ "vitl16_384": [5, 11, 17, 23],
45
+ }
46
+
47
+ # Instantiate backbone and reassemble blocks
48
+ self.pretrained, self.scratch = _make_encoder(
49
+ backbone,
50
+ features,
51
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
52
+ groups=1,
53
+ expand=False,
54
+ exportable=False,
55
+ hooks=hooks[backbone],
56
+ use_readout=readout,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+
67
+ def forward(self, x):
68
+ if self.channels_last == True:
69
+ x.contiguous(memory_format=torch.channels_last)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
+
73
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
74
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
75
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
76
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
77
+
78
+ path_4 = self.scratch.refinenet4(layer_4_rn)
79
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
+
83
+ out = self.scratch.output_conv(path_1)
84
+
85
+ return out
86
+
87
+
88
+ class DPTDepthModel(DPT):
89
+ def __init__(self, path=None, non_negative=True, **kwargs):
90
+ features = kwargs["features"] if "features" in kwargs else 256
91
+
92
+ head = nn.Sequential(
93
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
+ nn.ReLU(True),
97
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
+ nn.ReLU(True) if non_negative else nn.Identity(),
99
+ nn.Identity(),
100
+ )
101
+
102
+ super().__init__(head, **kwargs)
103
+
104
+ if path is not None:
105
+ self.load(path)
106
+
107
+ def forward(self, x):
108
+ return super().forward(x).squeeze(dim=1)
109
+
annotator/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
annotator/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
annotator/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
annotator/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ assert (
177
+ False
178
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+ pretrained.model = model
356
+
357
+ if use_vit_only == True:
358
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
+ else:
361
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
+ get_activation("1")
363
+ )
364
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
+ get_activation("2")
366
+ )
367
+
368
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
+
371
+ pretrained.activations = activations
372
+
373
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
+
375
+ if use_vit_only == True:
376
+ pretrained.act_postprocess1 = nn.Sequential(
377
+ readout_oper[0],
378
+ Transpose(1, 2),
379
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
+ nn.Conv2d(
381
+ in_channels=vit_features,
382
+ out_channels=features[0],
383
+ kernel_size=1,
384
+ stride=1,
385
+ padding=0,
386
+ ),
387
+ nn.ConvTranspose2d(
388
+ in_channels=features[0],
389
+ out_channels=features[0],
390
+ kernel_size=4,
391
+ stride=4,
392
+ padding=0,
393
+ bias=True,
394
+ dilation=1,
395
+ groups=1,
396
+ ),
397
+ )
398
+
399
+ pretrained.act_postprocess2 = nn.Sequential(
400
+ readout_oper[1],
401
+ Transpose(1, 2),
402
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
+ nn.Conv2d(
404
+ in_channels=vit_features,
405
+ out_channels=features[1],
406
+ kernel_size=1,
407
+ stride=1,
408
+ padding=0,
409
+ ),
410
+ nn.ConvTranspose2d(
411
+ in_channels=features[1],
412
+ out_channels=features[1],
413
+ kernel_size=2,
414
+ stride=2,
415
+ padding=0,
416
+ bias=True,
417
+ dilation=1,
418
+ groups=1,
419
+ ),
420
+ )
421
+ else:
422
+ pretrained.act_postprocess1 = nn.Sequential(
423
+ nn.Identity(), nn.Identity(), nn.Identity()
424
+ )
425
+ pretrained.act_postprocess2 = nn.Sequential(
426
+ nn.Identity(), nn.Identity(), nn.Identity()
427
+ )
428
+
429
+ pretrained.act_postprocess3 = nn.Sequential(
430
+ readout_oper[2],
431
+ Transpose(1, 2),
432
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
+ nn.Conv2d(
434
+ in_channels=vit_features,
435
+ out_channels=features[2],
436
+ kernel_size=1,
437
+ stride=1,
438
+ padding=0,
439
+ ),
440
+ )
441
+
442
+ pretrained.act_postprocess4 = nn.Sequential(
443
+ readout_oper[3],
444
+ Transpose(1, 2),
445
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
+ nn.Conv2d(
447
+ in_channels=vit_features,
448
+ out_channels=features[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=features[3],
455
+ out_channels=features[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ pretrained.model.start_index = start_index
463
+ pretrained.model.patch_size = [16, 16]
464
+
465
+ # We inject this function into the VisionTransformer instances so that
466
+ # we can use it with interpolated position embeddings without modifying the library source.
467
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
+
469
+ # We inject this function into the VisionTransformer instances so that
470
+ # we can use it with interpolated position embeddings without modifying the library source.
471
+ pretrained.model._resize_pos_embed = types.MethodType(
472
+ _resize_pos_embed, pretrained.model
473
+ )
474
+
475
+ return pretrained
476
+
477
+
478
+ def _make_pretrained_vitb_rn50_384(
479
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
+ ):
481
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
+
483
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
484
+ return _make_vit_b_rn50_backbone(
485
+ model,
486
+ features=[256, 512, 768, 768],
487
+ size=[384, 384],
488
+ hooks=hooks,
489
+ use_vit_only=use_vit_only,
490
+ use_readout=use_readout,
491
+ )
annotator/midas/utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for monoDepth."""
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def read_pfm(path):
10
+ """Read pfm file.
11
+
12
+ Args:
13
+ path (str): path to file
14
+
15
+ Returns:
16
+ tuple: (data, scale)
17
+ """
18
+ with open(path, "rb") as file:
19
+
20
+ color = None
21
+ width = None
22
+ height = None
23
+ scale = None
24
+ endian = None
25
+
26
+ header = file.readline().rstrip()
27
+ if header.decode("ascii") == "PF":
28
+ color = True
29
+ elif header.decode("ascii") == "Pf":
30
+ color = False
31
+ else:
32
+ raise Exception("Not a PFM file: " + path)
33
+
34
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
+ if dim_match:
36
+ width, height = list(map(int, dim_match.groups()))
37
+ else:
38
+ raise Exception("Malformed PFM header.")
39
+
40
+ scale = float(file.readline().decode("ascii").rstrip())
41
+ if scale < 0:
42
+ # little-endian
43
+ endian = "<"
44
+ scale = -scale
45
+ else:
46
+ # big-endian
47
+ endian = ">"
48
+
49
+ data = np.fromfile(file, endian + "f")
50
+ shape = (height, width, 3) if color else (height, width)
51
+
52
+ data = np.reshape(data, shape)
53
+ data = np.flipud(data)
54
+
55
+ return data, scale
56
+
57
+
58
+ def write_pfm(path, image, scale=1):
59
+ """Write pfm file.
60
+
61
+ Args:
62
+ path (str): pathto file
63
+ image (array): data
64
+ scale (int, optional): Scale. Defaults to 1.
65
+ """
66
+
67
+ with open(path, "wb") as file:
68
+ color = None
69
+
70
+ if image.dtype.name != "float32":
71
+ raise Exception("Image dtype must be float32.")
72
+
73
+ image = np.flipud(image)
74
+
75
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
+ color = True
77
+ elif (
78
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
+ ): # greyscale
80
+ color = False
81
+ else:
82
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
+
84
+ file.write("PF\n" if color else "Pf\n".encode())
85
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
+
87
+ endian = image.dtype.byteorder
88
+
89
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
90
+ scale = -scale
91
+
92
+ file.write("%f\n".encode() % scale)
93
+
94
+ image.tofile(file)
95
+
96
+
97
+ def read_image(path):
98
+ """Read image and output RGB image (0-1).
99
+
100
+ Args:
101
+ path (str): path to file
102
+
103
+ Returns:
104
+ array: RGB image (0-1)
105
+ """
106
+ img = cv2.imread(path)
107
+
108
+ if img.ndim == 2:
109
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
+
111
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
+
113
+ return img
114
+
115
+
116
+ def resize_image(img):
117
+ """Resize image and make it fit for network.
118
+
119
+ Args:
120
+ img (array): image
121
+
122
+ Returns:
123
+ tensor: data ready for network
124
+ """
125
+ height_orig = img.shape[0]
126
+ width_orig = img.shape[1]
127
+
128
+ if width_orig > height_orig:
129
+ scale = width_orig / 384
130
+ else:
131
+ scale = height_orig / 384
132
+
133
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
+
136
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
+
138
+ img_resized = (
139
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
+ )
141
+ img_resized = img_resized.unsqueeze(0)
142
+
143
+ return img_resized
144
+
145
+
146
+ def resize_depth(depth, width, height):
147
+ """Resize depth map and bring to CPU (numpy).
148
+
149
+ Args:
150
+ depth (tensor): depth
151
+ width (int): image width
152
+ height (int): image height
153
+
154
+ Returns:
155
+ array: processed depth
156
+ """
157
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
+
159
+ depth_resized = cv2.resize(
160
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
+ )
162
+
163
+ return depth_resized
164
+
165
+ def write_depth(path, depth, bits=1):
166
+ """Write depth map to pfm and png file.
167
+
168
+ Args:
169
+ path (str): filepath without extension
170
+ depth (array): depth
171
+ """
172
+ write_pfm(path + ".pfm", depth.astype(np.float32))
173
+
174
+ depth_min = depth.min()
175
+ depth_max = depth.max()
176
+
177
+ max_val = (2**(8*bits))-1
178
+
179
+ if depth_max - depth_min > np.finfo("float").eps:
180
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
+ else:
182
+ out = np.zeros(depth.shape, dtype=depth.type)
183
+
184
+ if bits == 1:
185
+ cv2.imwrite(path + ".png", out.astype("uint8"))
186
+ elif bits == 2:
187
+ cv2.imwrite(path + ".png", out.astype("uint16"))
188
+
189
+ return
annotator/util.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+ import PIL
7
+
8
+ annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
9
+
10
+ def HWC3(x):
11
+ assert x.dtype == np.uint8
12
+ if x.ndim == 2:
13
+ x = x[:, :, None]
14
+ assert x.ndim == 3
15
+ H, W, C = x.shape
16
+ assert C == 1 or C == 3 or C == 4
17
+ if C == 3:
18
+ return x
19
+ if C == 1:
20
+ return np.concatenate([x, x, x], axis=2)
21
+ if C == 4:
22
+ color = x[:, :, 0:3].astype(np.float32)
23
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
24
+ y = color * alpha + 255.0 * (1.0 - alpha)
25
+ y = y.clip(0, 255).astype(np.uint8)
26
+ return y
27
+
28
+
29
+ def resize_image(input_image, resolution, short = False, interpolation=None):
30
+ if isinstance(input_image,PIL.Image.Image):
31
+ mode = 'pil'
32
+ W,H = input_image.size
33
+
34
+ elif isinstance(input_image,np.ndarray):
35
+ mode = 'cv2'
36
+ H, W, _ = input_image.shape
37
+
38
+ H = float(H)
39
+ W = float(W)
40
+ if short:
41
+ k = float(resolution) / min(H, W) # k>1 ζ”Ύε€§οΌŒ k<1 缩小
42
+ else:
43
+ k = float(resolution) / max(H, W) # k>1 ζ”Ύε€§οΌŒ k<1 缩小
44
+ H *= k
45
+ W *= k
46
+ H = int(np.round(H / 64.0)) * 64
47
+ W = int(np.round(W / 64.0)) * 64
48
+
49
+ if mode == 'cv2':
50
+ if interpolation is None:
51
+ interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
52
+ img = cv2.resize(input_image, (W, H), interpolation=interpolation)
53
+
54
+ elif mode == 'pil':
55
+ if interpolation is None:
56
+ interpolation = PIL.Image.LANCZOS if k > 1 else PIL.Image.BILINEAR
57
+ img = input_image.resize((W, H), resample=interpolation)
58
+
59
+ return img
60
+
61
+ # def resize_image(input_image, resolution):
62
+ # H, W, C = input_image.shape
63
+ # H = float(H)
64
+ # W = float(W)
65
+ # k = float(resolution) / min(H, W)
66
+ # H *= k
67
+ # W *= k
68
+ # H = int(np.round(H / 64.0)) * 64
69
+ # W = int(np.round(W / 64.0)) * 64
70
+ # img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
71
+ # return img
72
+
73
+
74
+ def nms(x, t, s):
75
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
76
+
77
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
78
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
79
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
80
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
81
+
82
+ y = np.zeros_like(x)
83
+
84
+ for f in [f1, f2, f3, f4]:
85
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
86
+
87
+ z = np.zeros_like(y, dtype=np.uint8)
88
+ z[y > t] = 255
89
+ return z
90
+
91
+
92
+ def make_noise_disk(H, W, C, F):
93
+ noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
94
+ noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
95
+ noise = noise[F: F + H, F: F + W]
96
+ noise -= np.min(noise)
97
+ noise /= np.max(noise)
98
+ if C == 1:
99
+ noise = noise[:, :, None]
100
+ return noise
101
+
102
+
103
+ def min_max_norm(x):
104
+ x -= np.min(x)
105
+ x /= np.maximum(np.max(x), 1e-5)
106
+ return x
107
+
108
+
109
+ def safe_step(x, step=2):
110
+ y = x.astype(np.float32) * float(step + 1)
111
+ y = y.astype(np.int32).astype(np.float32) / float(step)
112
+ return y
113
+
114
+
115
+ def img2mask(img, H, W, low=10, high=90):
116
+ assert img.ndim == 3 or img.ndim == 2
117
+ assert img.dtype == np.uint8
118
+
119
+ if img.ndim == 3:
120
+ y = img[:, :, random.randrange(0, img.shape[2])]
121
+ else:
122
+ y = img
123
+
124
+ y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
125
+
126
+ if random.uniform(0, 1) < 0.5:
127
+ y = 255 - y
128
+
129
+ return y < np.percentile(y, random.randrange(low, high))
app.py CHANGED
@@ -1,104 +1,118 @@
1
  import spaces
2
  import random
3
  import torch
 
 
 
4
  from huggingface_hub import snapshot_download
5
- from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
6
- from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
 
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
- from kolors.models import unet_2d_condition
10
- from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
11
- import gradio as gr
12
- import numpy as np
 
 
 
 
13
 
14
  device = "cuda"
15
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
16
- ckpt_IPA_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
 
17
 
18
  text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
19
  tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
20
  vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
21
  scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
22
- unet_t2i = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
23
- unet_i2i = unet_2d_condition.UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
24
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_IPA_dir}/image_encoder',ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
25
- ip_img_size = 336
26
- clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
27
 
28
- pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
29
  vae=vae,
30
- text_encoder=text_encoder,
31
- tokenizer=tokenizer,
32
- unet=unet_t2i,
33
- scheduler=scheduler,
 
34
  force_zeros_for_empty_prompt=False
35
- ).to(device)
36
 
37
- pipe_i2i = pipeline_stable_diffusion_xl_chatglm_256_ipadapter.StableDiffusionXLPipeline(
38
  vae=vae,
 
39
  text_encoder=text_encoder,
40
  tokenizer=tokenizer,
41
- unet=unet_i2i,
42
  scheduler=scheduler,
43
- image_encoder=image_encoder,
44
- feature_extractor=clip_image_processor,
45
  force_zeros_for_empty_prompt=False
46
- ).to(device)
47
 
48
- if hasattr(pipe_i2i.unet, 'encoder_hid_proj'):
49
- pipe_i2i.unet.text_encoder_hid_proj = pipe_i2i.unet.encoder_hid_proj
50
-
51
- pipe_i2i.load_ip_adapter(f'{ckpt_IPA_dir}' , subfolder="", weight_name=["ip_adapter_plus_general.bin"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  MAX_SEED = np.iinfo(np.int32).max
54
  MAX_IMAGE_SIZE = 1024
55
 
56
  @spaces.GPU
57
  def infer(prompt,
58
- ip_adapter_image = None,
59
- ip_adapter_scale = 0.5,
60
  negative_prompt = "",
61
  seed = 0,
62
- randomize_seed = False,
63
- width = 1024,
64
- height = 1024,
65
- guidance_scale = 5.0,
66
- num_inference_steps = 25
67
- ):
 
68
  if randomize_seed:
69
  seed = random.randint(0, MAX_SEED)
70
  generator = torch.Generator().manual_seed(seed)
71
-
72
- if ip_adapter_image is None:
73
- pipe_t2i.to(device)
74
- image = pipe_t2i(
75
- prompt = prompt,
76
- negative_prompt = negative_prompt,
77
- guidance_scale = guidance_scale,
78
- num_inference_steps = num_inference_steps,
79
- width = width,
80
- height = height,
81
- generator = generator
82
- ).images[0]
83
- return image
84
  else:
85
- pipe_i2i.to(device)
86
- image_encoder.to(device)
87
- pipe_i2i.image_encoder = image_encoder
88
- pipe_i2i.set_ip_adapter_scale([ip_adapter_scale])
89
- image = pipe_i2i(
90
- prompt=prompt ,
91
- ip_adapter_image=[ip_adapter_image],
92
- negative_prompt=negative_prompt,
93
- height=height,
94
- width=width,
95
- num_inference_steps=num_inference_steps,
96
- guidance_scale=guidance_scale,
97
- num_images_per_prompt=1,
98
- generator=generator
99
- ).images[0]
100
- return image
101
-
102
  examples = [
103
 
104
  ]
@@ -130,12 +144,19 @@ with gr.Blocks(css=css) as Kolors:
130
  lines=2
131
  )
132
  with gr.Row():
133
- ip_adapter_image = gr.Image(label="Image Prompt (optional)", type="pil")
 
 
 
 
 
 
134
  with gr.Accordion("Advanced Settings", open=False):
135
  negative_prompt = gr.Textbox(
136
  label="Negative prompt",
137
  placeholder="Enter a negative prompt",
138
  visible=True,
 
139
  )
140
  seed = gr.Slider(
141
  label="Seed",
@@ -145,62 +166,61 @@ with gr.Blocks(css=css) as Kolors:
145
  value=0,
146
  )
147
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
148
- with gr.Row():
149
- width = gr.Slider(
150
- label="Width",
151
- minimum=256,
152
- maximum=MAX_IMAGE_SIZE,
153
- step=32,
154
- value=1024,
155
- )
156
- height = gr.Slider(
157
- label="Height",
158
- minimum=256,
159
- maximum=MAX_IMAGE_SIZE,
160
- step=32,
161
- value=1024,
162
- )
163
  with gr.Row():
164
  guidance_scale = gr.Slider(
165
  label="Guidance scale",
166
  minimum=0.0,
167
  maximum=10.0,
168
  step=0.1,
169
- value=5.0,
170
  )
171
  num_inference_steps = gr.Slider(
172
  label="Number of inference steps",
173
  minimum=10,
174
  maximum=50,
175
  step=1,
176
- value=25,
177
  )
178
  with gr.Row():
179
- ip_adapter_scale = gr.Slider(
180
- label="Image influence scale",
181
- info="Use 1 for creating variations",
182
  minimum=0.0,
183
  maximum=1.0,
184
- step=0.05,
185
- value=0.5,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  )
187
  with gr.Row():
188
  run_button = gr.Button("Run")
189
 
190
  with gr.Column(elem_id="col-right"):
191
- result = gr.Image(label="Result", show_label=False)
192
 
193
  with gr.Row():
194
  gr.Examples(
195
  fn = infer,
196
  examples = examples,
197
- inputs = [prompt, ip_adapter_image, ip_adapter_scale],
198
  outputs = [result]
199
  )
200
 
201
  run_button.click(
202
  fn = infer,
203
- inputs = [prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
204
  outputs = [result]
205
  )
206
 
 
1
  import spaces
2
  import random
3
  import torch
4
+ import cv2
5
+ import gradio as gr
6
+ import numpy as np
7
  from huggingface_hub import snapshot_download
8
+ from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
9
+ from diffusers.utils import load_image
10
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
11
  from kolors.models.modeling_chatglm import ChatGLMModel
12
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
13
+ from kolors.models.controlnet import ControlNetModel
14
+ from diffusers import AutoencoderKL
15
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
16
+ from diffusers import EulerDiscreteScheduler
17
+ from PIL import Image
18
+ from annotator.midas import MidasDetector
19
+ from annotator.util import resize_image, HWC3
20
+
21
 
22
  device = "cuda"
23
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
24
+ ckpt_dir_depth = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Depth")
25
+ ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")
26
 
27
  text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
28
  tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
29
  vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
30
  scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
31
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
32
+ controlnet_depth = ControlNetModel.from_pretrained(f"{ckpt_dir_depth}", revision=None).half().to(device)
33
+ controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)
 
 
34
 
35
+ pipe_depth = StableDiffusionXLControlNetImg2ImgPipeline(
36
  vae=vae,
37
+ controlnet = controlnet_depth,
38
+ text_encoder=text_encoder,
39
+ tokenizer=tokenizer,
40
+ unet=unet,
41
+ scheduler=scheduler,
42
  force_zeros_for_empty_prompt=False
43
+ )
44
 
45
+ pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
46
  vae=vae,
47
+ controlnet = controlnet_canny,
48
  text_encoder=text_encoder,
49
  tokenizer=tokenizer,
50
+ unet=unet,
51
  scheduler=scheduler,
 
 
52
  force_zeros_for_empty_prompt=False
53
+ )
54
 
55
+ @spaces.GPU
56
+ def process_canny_condition(image, canny_threods=[100,200]):
57
+ np_image = image.copy()
58
+ np_image = cv2.Canny(np_image, canny_threods[0], canny_threods[1])
59
+ np_image = np_image[:, :, None]
60
+ np_image = np.concatenate([np_image, np_image, np_image], axis=2)
61
+ np_image = HWC3(np_image)
62
+ return Image.fromarray(np_image)
63
+
64
+ model_midas = MidasDetector()
65
+
66
+ @spaces.GPU
67
+ def process_depth_condition_midas(img, res = 1024):
68
+ h,w,_ = img.shape
69
+ img = resize_image(HWC3(img), res)
70
+ result = HWC3(model_midas(img))
71
+ result = cv2.resize(result, (w,h))
72
+ return Image.fromarray(result)
73
 
74
  MAX_SEED = np.iinfo(np.int32).max
75
  MAX_IMAGE_SIZE = 1024
76
 
77
  @spaces.GPU
78
  def infer(prompt,
79
+ image = None,
80
+ controlnet_type = "Depth",
81
  negative_prompt = "",
82
  seed = 0,
83
+ randomize_seed = False,
84
+ guidance_scale = 6.0,
85
+ num_inference_steps = 50,
86
+ controlnet_conditioning_scale = 0.7,
87
+ control_guidance_end = 0.9,
88
+ strength = 1.0
89
+ ):
90
  if randomize_seed:
91
  seed = random.randint(0, MAX_SEED)
92
  generator = torch.Generator().manual_seed(seed)
93
+ init_image = resize_image(image, MAX_IMAGE_SIZE)
94
+ if controlnet_type == "Depth":
95
+ pipe = pipe_depth.to("cuda")
96
+ condi_img = process_depth_condition_midas( np.array(init_image), MAX_IMAGE_SIZE)
97
+ elif controlnet_type == "Canny":
98
+ pipe = pipe_canny.to("cuda")
99
+ condi_img = process_canny_condition(np.array(init_image))
 
 
 
 
 
 
100
  else:
101
+ return None
102
+ image = pipe(
103
+ prompt= prompt ,
104
+ image = init_image,
105
+ controlnet_conditioning_scale = controlnet_conditioning_scale,
106
+ control_guidance_end = control_guidance_end,
107
+ strength= strength ,
108
+ control_image = condi_img,
109
+ negative_prompt= negative_prompt ,
110
+ num_inference_steps= num_inference_steps,
111
+ guidance_scale= guidance_scale,
112
+ num_images_per_prompt=1,
113
+ generator=generator,
114
+ ).images[0]
115
+ return [condi_img, image]
 
 
116
  examples = [
117
 
118
  ]
 
144
  lines=2
145
  )
146
  with gr.Row():
147
+ controlnet_type = gr.Dropdown(
148
+ ["Depth", "Canny"],
149
+ label = "Controlnet",
150
+ value="Depth"
151
+ )
152
+ with gr.Row():
153
+ image = gr.Image(label="Image", type="pil")
154
  with gr.Accordion("Advanced Settings", open=False):
155
  negative_prompt = gr.Textbox(
156
  label="Negative prompt",
157
  placeholder="Enter a negative prompt",
158
  visible=True,
159
+ value="nsfwοΌŒθ„Έιƒ¨ι˜΄ε½±οΌŒδ½Žεˆ†θΎ¨ηŽ‡οΌŒjpegδΌͺε½±γ€ζ¨‘η³Šγ€η³Ÿη³•οΌŒι»‘θ„ΈοΌŒιœ“θ™Ήη―"
160
  )
161
  seed = gr.Slider(
162
  label="Seed",
 
166
  value=0,
167
  )
168
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  with gr.Row():
170
  guidance_scale = gr.Slider(
171
  label="Guidance scale",
172
  minimum=0.0,
173
  maximum=10.0,
174
  step=0.1,
175
+ value=6.0,
176
  )
177
  num_inference_steps = gr.Slider(
178
  label="Number of inference steps",
179
  minimum=10,
180
  maximum=50,
181
  step=1,
182
+ value=30,
183
  )
184
  with gr.Row():
185
+ controlnet_conditioning_scale = gr.Slider(
186
+ label="Controlnet Conditioning Scale",
 
187
  minimum=0.0,
188
  maximum=1.0,
189
+ step=0.1,
190
+ value=0.7,
191
+ )
192
+ control_guidance_end = gr.Slider(
193
+ label="Control Guidance End",
194
+ minimum=0.0,
195
+ maximum=1.0,
196
+ step=0.1,
197
+ value=0.9,
198
+ )
199
+ with gr.Row():
200
+ strength = gr.Slider(
201
+ label="Strength",
202
+ minimum=0.0,
203
+ maximum=1.0,
204
+ step=0.1,
205
+ value=1.0,
206
  )
207
  with gr.Row():
208
  run_button = gr.Button("Run")
209
 
210
  with gr.Column(elem_id="col-right"):
211
+ result = gr.Gallery(label="Result", show_label=False, columns=2)
212
 
213
  with gr.Row():
214
  gr.Examples(
215
  fn = infer,
216
  examples = examples,
217
+ inputs = [prompt, image, controlnet_type],
218
  outputs = [result]
219
  )
220
 
221
  run_button.click(
222
  fn = infer,
223
+ inputs = [prompt, image, controlnet_type, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
224
  outputs = [result]
225
  )
226
 
kolors/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (135 Bytes)
 
kolors/models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (142 Bytes)
 
kolors/models/__pycache__/configuration_chatglm.cpython-38.pyc DELETED
Binary file (1.6 kB)
 
kolors/models/__pycache__/modeling_chatglm.cpython-38.pyc DELETED
Binary file (33.5 kB)
 
kolors/models/__pycache__/tokenization_chatglm.cpython-38.pyc DELETED
Binary file (11.5 kB)
 
kolors/models/__pycache__/unet_2d_condition.cpython-38.pyc DELETED
Binary file (40.3 kB)
 
kolors/models/controlnet.py ADDED
@@ -0,0 +1,887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+
34
+ try:
35
+ from diffusers.unets.unet_2d_blocks import (
36
+ CrossAttnDownBlock2D,
37
+ DownBlock2D,
38
+ UNetMidBlock2D,
39
+ UNetMidBlock2DCrossAttn,
40
+ get_down_block,
41
+ )
42
+ from diffusers.unets.unet_2d_condition import UNet2DConditionModel
43
+ except:
44
+ from diffusers.models.unets.unet_2d_blocks import (
45
+ CrossAttnDownBlock2D,
46
+ DownBlock2D,
47
+ UNetMidBlock2D,
48
+ UNetMidBlock2DCrossAttn,
49
+ get_down_block,
50
+ )
51
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
52
+
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ @dataclass
59
+ class ControlNetOutput(BaseOutput):
60
+ """
61
+ The output of [`ControlNetModel`].
62
+
63
+ Args:
64
+ down_block_res_samples (`tuple[torch.Tensor]`):
65
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
66
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
67
+ used to condition the original UNet's downsampling activations.
68
+ mid_down_block_re_sample (`torch.Tensor`):
69
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
70
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
71
+ Output can be used to condition the original UNet's middle block activation.
72
+ """
73
+
74
+ down_block_res_samples: Tuple[torch.Tensor]
75
+ mid_block_res_sample: torch.Tensor
76
+
77
+
78
+ class ControlNetConditioningEmbedding(nn.Module):
79
+ """
80
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
81
+ [11] to convert the entire dataset of 512 Γ— 512 images into smaller 64 Γ— 64 β€œlatent images” for stabilized
82
+ training. This requires ControlNets to convert image-based conditions to 64 Γ— 64 feature space to match the
83
+ convolution size. We use a tiny network E(Β·) of four convolution layers with 4 Γ— 4 kernels and 2 Γ— 2 strides
84
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
85
+ model) to encode image-space conditions ... into feature maps ..."
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ conditioning_embedding_channels: int,
91
+ conditioning_channels: int = 3,
92
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
93
+ ):
94
+ super().__init__()
95
+
96
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
97
+
98
+ self.blocks = nn.ModuleList([])
99
+
100
+ for i in range(len(block_out_channels) - 1):
101
+ channel_in = block_out_channels[i]
102
+ channel_out = block_out_channels[i + 1]
103
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
104
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
105
+
106
+ self.conv_out = zero_module(
107
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
108
+ )
109
+
110
+ def forward(self, conditioning):
111
+ embedding = self.conv_in(conditioning)
112
+ embedding = F.silu(embedding)
113
+
114
+ for block in self.blocks:
115
+ embedding = block(embedding)
116
+ embedding = F.silu(embedding)
117
+
118
+ embedding = self.conv_out(embedding)
119
+
120
+ return embedding
121
+
122
+
123
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
124
+ """
125
+ A ControlNet model.
126
+
127
+ Args:
128
+ in_channels (`int`, defaults to 4):
129
+ The number of channels in the input sample.
130
+ flip_sin_to_cos (`bool`, defaults to `True`):
131
+ Whether to flip the sin to cos in the time embedding.
132
+ freq_shift (`int`, defaults to 0):
133
+ The frequency shift to apply to the time embedding.
134
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
135
+ The tuple of downsample blocks to use.
136
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
137
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
138
+ The tuple of output channels for each block.
139
+ layers_per_block (`int`, defaults to 2):
140
+ The number of layers per block.
141
+ downsample_padding (`int`, defaults to 1):
142
+ The padding to use for the downsampling convolution.
143
+ mid_block_scale_factor (`float`, defaults to 1):
144
+ The scale factor to use for the mid block.
145
+ act_fn (`str`, defaults to "silu"):
146
+ The activation function to use.
147
+ norm_num_groups (`int`, *optional*, defaults to 32):
148
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
149
+ in post-processing.
150
+ norm_eps (`float`, defaults to 1e-5):
151
+ The epsilon to use for the normalization.
152
+ cross_attention_dim (`int`, defaults to 1280):
153
+ The dimension of the cross attention features.
154
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
155
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
156
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
157
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
158
+ encoder_hid_dim (`int`, *optional*, defaults to None):
159
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
160
+ dimension to `cross_attention_dim`.
161
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
162
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
163
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
164
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
165
+ The dimension of the attention heads.
166
+ use_linear_projection (`bool`, defaults to `False`):
167
+ class_embed_type (`str`, *optional*, defaults to `None`):
168
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
169
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
170
+ addition_embed_type (`str`, *optional*, defaults to `None`):
171
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
172
+ "text". "text" will use the `TextTimeEmbedding` layer.
173
+ num_class_embeds (`int`, *optional*, defaults to 0):
174
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
175
+ class conditioning with `class_embed_type` equal to `None`.
176
+ upcast_attention (`bool`, defaults to `False`):
177
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
178
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
179
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
180
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
181
+ `class_embed_type="projection"`.
182
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
183
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
184
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
185
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
186
+ global_pool_conditions (`bool`, defaults to `False`):
187
+ TODO(Patrick) - unused parameter.
188
+ addition_embed_type_num_heads (`int`, defaults to 64):
189
+ The number of heads to use for the `TextTimeEmbedding` layer.
190
+ """
191
+
192
+ _supports_gradient_checkpointing = True
193
+
194
+ @register_to_config
195
+ def __init__(
196
+ self,
197
+ in_channels: int = 4,
198
+ conditioning_channels: int = 3,
199
+ flip_sin_to_cos: bool = True,
200
+ freq_shift: int = 0,
201
+ down_block_types: Tuple[str, ...] = (
202
+ "CrossAttnDownBlock2D",
203
+ "CrossAttnDownBlock2D",
204
+ "CrossAttnDownBlock2D",
205
+ "DownBlock2D",
206
+ ),
207
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
208
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
209
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
210
+ layers_per_block: int = 2,
211
+ downsample_padding: int = 1,
212
+ mid_block_scale_factor: float = 1,
213
+ act_fn: str = "silu",
214
+ norm_num_groups: Optional[int] = 32,
215
+ norm_eps: float = 1e-5,
216
+ cross_attention_dim: int = 1280,
217
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
218
+ encoder_hid_dim: Optional[int] = None,
219
+ encoder_hid_dim_type: Optional[str] = None,
220
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
221
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
222
+ use_linear_projection: bool = False,
223
+ class_embed_type: Optional[str] = None,
224
+ addition_embed_type: Optional[str] = None,
225
+ addition_time_embed_dim: Optional[int] = None,
226
+ num_class_embeds: Optional[int] = None,
227
+ upcast_attention: bool = False,
228
+ resnet_time_scale_shift: str = "default",
229
+ projection_class_embeddings_input_dim: Optional[int] = None,
230
+ controlnet_conditioning_channel_order: str = "rgb",
231
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
232
+ global_pool_conditions: bool = False,
233
+ addition_embed_type_num_heads: int = 64,
234
+ ):
235
+ super().__init__()
236
+
237
+ # If `num_attention_heads` is not defined (which is the case for most models)
238
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
239
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
240
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
241
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
242
+ # which is why we correct for the naming here.
243
+ num_attention_heads = num_attention_heads or attention_head_dim
244
+
245
+ # Check inputs
246
+ if len(block_out_channels) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if isinstance(transformer_layers_per_block, int):
262
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
263
+
264
+ # input
265
+ conv_in_kernel = 3
266
+ conv_in_padding = (conv_in_kernel - 1) // 2
267
+ self.conv_in = nn.Conv2d(
268
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
269
+ )
270
+
271
+ # time
272
+ time_embed_dim = block_out_channels[0] * 4
273
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
274
+ timestep_input_dim = block_out_channels[0]
275
+ self.time_embedding = TimestepEmbedding(
276
+ timestep_input_dim,
277
+ time_embed_dim,
278
+ act_fn=act_fn,
279
+ )
280
+
281
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
282
+ encoder_hid_dim_type = "text_proj"
283
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
284
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
285
+
286
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
287
+ raise ValueError(
288
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
289
+ )
290
+
291
+ if encoder_hid_dim_type == "text_proj":
292
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
293
+ elif encoder_hid_dim_type == "text_image_proj":
294
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
295
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
296
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
297
+ self.encoder_hid_proj = TextImageProjection(
298
+ text_embed_dim=encoder_hid_dim,
299
+ image_embed_dim=cross_attention_dim,
300
+ cross_attention_dim=cross_attention_dim,
301
+ )
302
+
303
+ elif encoder_hid_dim_type is not None:
304
+ raise ValueError(
305
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
306
+ )
307
+ else:
308
+ self.encoder_hid_proj = None
309
+
310
+ # class embedding
311
+ if class_embed_type is None and num_class_embeds is not None:
312
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
313
+ elif class_embed_type == "timestep":
314
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
315
+ elif class_embed_type == "identity":
316
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
317
+ elif class_embed_type == "projection":
318
+ if projection_class_embeddings_input_dim is None:
319
+ raise ValueError(
320
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
321
+ )
322
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
323
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
324
+ # 2. it projects from an arbitrary input dimension.
325
+ #
326
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
327
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
328
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
329
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
330
+ else:
331
+ self.class_embedding = None
332
+
333
+ if addition_embed_type == "text":
334
+ if encoder_hid_dim is not None:
335
+ text_time_embedding_from_dim = encoder_hid_dim
336
+ else:
337
+ text_time_embedding_from_dim = cross_attention_dim
338
+
339
+ self.add_embedding = TextTimeEmbedding(
340
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
341
+ )
342
+ elif addition_embed_type == "text_image":
343
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
344
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
345
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
346
+ self.add_embedding = TextImageTimeEmbedding(
347
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
348
+ )
349
+ elif addition_embed_type == "text_time":
350
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
351
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
352
+
353
+ elif addition_embed_type is not None:
354
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
355
+
356
+ # control net conditioning embedding
357
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
358
+ conditioning_embedding_channels=block_out_channels[0],
359
+ block_out_channels=conditioning_embedding_out_channels,
360
+ conditioning_channels=conditioning_channels,
361
+ )
362
+
363
+ self.down_blocks = nn.ModuleList([])
364
+ self.controlnet_down_blocks = nn.ModuleList([])
365
+
366
+ if isinstance(only_cross_attention, bool):
367
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
368
+
369
+ if isinstance(attention_head_dim, int):
370
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
371
+
372
+ if isinstance(num_attention_heads, int):
373
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
374
+
375
+ # down
376
+ output_channel = block_out_channels[0]
377
+
378
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
379
+ controlnet_block = zero_module(controlnet_block)
380
+ self.controlnet_down_blocks.append(controlnet_block)
381
+
382
+ for i, down_block_type in enumerate(down_block_types):
383
+ input_channel = output_channel
384
+ output_channel = block_out_channels[i]
385
+ is_final_block = i == len(block_out_channels) - 1
386
+
387
+ down_block = get_down_block(
388
+ down_block_type,
389
+ num_layers=layers_per_block,
390
+ transformer_layers_per_block=transformer_layers_per_block[i],
391
+ in_channels=input_channel,
392
+ out_channels=output_channel,
393
+ temb_channels=time_embed_dim,
394
+ add_downsample=not is_final_block,
395
+ resnet_eps=norm_eps,
396
+ resnet_act_fn=act_fn,
397
+ resnet_groups=norm_num_groups,
398
+ cross_attention_dim=cross_attention_dim,
399
+ num_attention_heads=num_attention_heads[i],
400
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
401
+ downsample_padding=downsample_padding,
402
+ use_linear_projection=use_linear_projection,
403
+ only_cross_attention=only_cross_attention[i],
404
+ upcast_attention=upcast_attention,
405
+ resnet_time_scale_shift=resnet_time_scale_shift,
406
+ )
407
+ self.down_blocks.append(down_block)
408
+
409
+ for _ in range(layers_per_block):
410
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
411
+ controlnet_block = zero_module(controlnet_block)
412
+ self.controlnet_down_blocks.append(controlnet_block)
413
+
414
+ if not is_final_block:
415
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
416
+ controlnet_block = zero_module(controlnet_block)
417
+ self.controlnet_down_blocks.append(controlnet_block)
418
+
419
+ # mid
420
+ mid_block_channel = block_out_channels[-1]
421
+
422
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
423
+ controlnet_block = zero_module(controlnet_block)
424
+ self.controlnet_mid_block = controlnet_block
425
+
426
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
427
+ self.mid_block = UNetMidBlock2DCrossAttn(
428
+ transformer_layers_per_block=transformer_layers_per_block[-1],
429
+ in_channels=mid_block_channel,
430
+ temb_channels=time_embed_dim,
431
+ resnet_eps=norm_eps,
432
+ resnet_act_fn=act_fn,
433
+ output_scale_factor=mid_block_scale_factor,
434
+ resnet_time_scale_shift=resnet_time_scale_shift,
435
+ cross_attention_dim=cross_attention_dim,
436
+ num_attention_heads=num_attention_heads[-1],
437
+ resnet_groups=norm_num_groups,
438
+ use_linear_projection=use_linear_projection,
439
+ upcast_attention=upcast_attention,
440
+ )
441
+ elif mid_block_type == "UNetMidBlock2D":
442
+ self.mid_block = UNetMidBlock2D(
443
+ in_channels=block_out_channels[-1],
444
+ temb_channels=time_embed_dim,
445
+ num_layers=0,
446
+ resnet_eps=norm_eps,
447
+ resnet_act_fn=act_fn,
448
+ output_scale_factor=mid_block_scale_factor,
449
+ resnet_groups=norm_num_groups,
450
+ resnet_time_scale_shift=resnet_time_scale_shift,
451
+ add_attention=False,
452
+ )
453
+ else:
454
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
455
+
456
+ @classmethod
457
+ def from_unet(
458
+ cls,
459
+ unet: UNet2DConditionModel,
460
+ controlnet_conditioning_channel_order: str = "rgb",
461
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
462
+ load_weights_from_unet: bool = True,
463
+ conditioning_channels: int = 3,
464
+ ):
465
+ r"""
466
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
467
+
468
+ Parameters:
469
+ unet (`UNet2DConditionModel`):
470
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
471
+ where applicable.
472
+ """
473
+ transformer_layers_per_block = (
474
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
475
+ )
476
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
477
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
478
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
479
+ addition_time_embed_dim = (
480
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
481
+ )
482
+
483
+ controlnet = cls(
484
+ encoder_hid_dim=encoder_hid_dim,
485
+ encoder_hid_dim_type=encoder_hid_dim_type,
486
+ addition_embed_type=addition_embed_type,
487
+ addition_time_embed_dim=addition_time_embed_dim,
488
+ transformer_layers_per_block=transformer_layers_per_block,
489
+ in_channels=unet.config.in_channels,
490
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
491
+ freq_shift=unet.config.freq_shift,
492
+ down_block_types=unet.config.down_block_types,
493
+ only_cross_attention=unet.config.only_cross_attention,
494
+ block_out_channels=unet.config.block_out_channels,
495
+ layers_per_block=unet.config.layers_per_block,
496
+ downsample_padding=unet.config.downsample_padding,
497
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
498
+ act_fn=unet.config.act_fn,
499
+ norm_num_groups=unet.config.norm_num_groups,
500
+ norm_eps=unet.config.norm_eps,
501
+ cross_attention_dim=unet.config.cross_attention_dim,
502
+ attention_head_dim=unet.config.attention_head_dim,
503
+ num_attention_heads=unet.config.num_attention_heads,
504
+ use_linear_projection=unet.config.use_linear_projection,
505
+ class_embed_type=unet.config.class_embed_type,
506
+ num_class_embeds=unet.config.num_class_embeds,
507
+ upcast_attention=unet.config.upcast_attention,
508
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
509
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
510
+ mid_block_type=unet.config.mid_block_type,
511
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
512
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
513
+ conditioning_channels=conditioning_channels,
514
+ )
515
+
516
+ if load_weights_from_unet:
517
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
518
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
519
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
520
+
521
+ if controlnet.class_embedding:
522
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
523
+
524
+ if hasattr(controlnet, "add_embedding"):
525
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
526
+
527
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
528
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
529
+
530
+ return controlnet
531
+
532
+ @property
533
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
534
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
535
+ r"""
536
+ Returns:
537
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
538
+ indexed by its weight name.
539
+ """
540
+ # set recursively
541
+ processors = {}
542
+
543
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
544
+ if hasattr(module, "get_processor"):
545
+ processors[f"{name}.processor"] = module.get_processor()
546
+
547
+ for sub_name, child in module.named_children():
548
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
549
+
550
+ return processors
551
+
552
+ for name, module in self.named_children():
553
+ fn_recursive_add_processors(name, module, processors)
554
+
555
+ return processors
556
+
557
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
558
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
559
+ r"""
560
+ Sets the attention processor to use to compute attention.
561
+
562
+ Parameters:
563
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
564
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
565
+ for **all** `Attention` layers.
566
+
567
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
568
+ processor. This is strongly recommended when setting trainable attention processors.
569
+
570
+ """
571
+ count = len(self.attn_processors.keys())
572
+
573
+ if isinstance(processor, dict) and len(processor) != count:
574
+ raise ValueError(
575
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
576
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
577
+ )
578
+
579
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
580
+ if hasattr(module, "set_processor"):
581
+ if not isinstance(processor, dict):
582
+ module.set_processor(processor)
583
+ else:
584
+ module.set_processor(processor.pop(f"{name}.processor"))
585
+
586
+ for sub_name, child in module.named_children():
587
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
588
+
589
+ for name, module in self.named_children():
590
+ fn_recursive_attn_processor(name, module, processor)
591
+
592
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
593
+ def set_default_attn_processor(self):
594
+ """
595
+ Disables custom attention processors and sets the default attention implementation.
596
+ """
597
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
598
+ processor = AttnAddedKVProcessor()
599
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
600
+ processor = AttnProcessor()
601
+ else:
602
+ raise ValueError(
603
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
604
+ )
605
+
606
+ self.set_attn_processor(processor)
607
+
608
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
609
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
610
+ r"""
611
+ Enable sliced attention computation.
612
+
613
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
614
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
615
+
616
+ Args:
617
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
618
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
619
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
620
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
621
+ must be a multiple of `slice_size`.
622
+ """
623
+ sliceable_head_dims = []
624
+
625
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
626
+ if hasattr(module, "set_attention_slice"):
627
+ sliceable_head_dims.append(module.sliceable_head_dim)
628
+
629
+ for child in module.children():
630
+ fn_recursive_retrieve_sliceable_dims(child)
631
+
632
+ # retrieve number of attention layers
633
+ for module in self.children():
634
+ fn_recursive_retrieve_sliceable_dims(module)
635
+
636
+ num_sliceable_layers = len(sliceable_head_dims)
637
+
638
+ if slice_size == "auto":
639
+ # half the attention head size is usually a good trade-off between
640
+ # speed and memory
641
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
642
+ elif slice_size == "max":
643
+ # make smallest slice possible
644
+ slice_size = num_sliceable_layers * [1]
645
+
646
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
647
+
648
+ if len(slice_size) != len(sliceable_head_dims):
649
+ raise ValueError(
650
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
651
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
652
+ )
653
+
654
+ for i in range(len(slice_size)):
655
+ size = slice_size[i]
656
+ dim = sliceable_head_dims[i]
657
+ if size is not None and size > dim:
658
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
659
+
660
+ # Recursively walk through all the children.
661
+ # Any children which exposes the set_attention_slice method
662
+ # gets the message
663
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
664
+ if hasattr(module, "set_attention_slice"):
665
+ module.set_attention_slice(slice_size.pop())
666
+
667
+ for child in module.children():
668
+ fn_recursive_set_attention_slice(child, slice_size)
669
+
670
+ reversed_slice_size = list(reversed(slice_size))
671
+ for module in self.children():
672
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
673
+
674
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
675
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
676
+ module.gradient_checkpointing = value
677
+
678
+ def forward(
679
+ self,
680
+ sample: torch.Tensor,
681
+ timestep: Union[torch.Tensor, float, int],
682
+ encoder_hidden_states: torch.Tensor,
683
+ controlnet_cond: torch.Tensor,
684
+ conditioning_scale: float = 1.0,
685
+ class_labels: Optional[torch.Tensor] = None,
686
+ timestep_cond: Optional[torch.Tensor] = None,
687
+ attention_mask: Optional[torch.Tensor] = None,
688
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
689
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
690
+ guess_mode: bool = False,
691
+ return_dict: bool = True,
692
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
693
+ """
694
+ The [`ControlNetModel`] forward method.
695
+
696
+ Args:
697
+ sample (`torch.Tensor`):
698
+ The noisy input tensor.
699
+ timestep (`Union[torch.Tensor, float, int]`):
700
+ The number of timesteps to denoise an input.
701
+ encoder_hidden_states (`torch.Tensor`):
702
+ The encoder hidden states.
703
+ controlnet_cond (`torch.Tensor`):
704
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
705
+ conditioning_scale (`float`, defaults to `1.0`):
706
+ The scale factor for ControlNet outputs.
707
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
708
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
709
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
710
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
711
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
712
+ embeddings.
713
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
714
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
715
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
716
+ negative values to the attention scores corresponding to "discard" tokens.
717
+ added_cond_kwargs (`dict`):
718
+ Additional conditions for the Stable Diffusion XL UNet.
719
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
720
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
721
+ guess_mode (`bool`, defaults to `False`):
722
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
723
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
724
+ return_dict (`bool`, defaults to `True`):
725
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
726
+
727
+ Returns:
728
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
729
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
730
+ returned where the first element is the sample tensor.
731
+ """
732
+ # check channel order
733
+ channel_order = self.config.controlnet_conditioning_channel_order
734
+
735
+ if channel_order == "rgb":
736
+ # in rgb order by default
737
+ ...
738
+ elif channel_order == "bgr":
739
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
740
+ else:
741
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
742
+
743
+ # prepare attention_mask
744
+ if attention_mask is not None:
745
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
746
+ attention_mask = attention_mask.unsqueeze(1)
747
+
748
+ #Todo
749
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
750
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
751
+
752
+ # 1. time
753
+ timesteps = timestep
754
+ if not torch.is_tensor(timesteps):
755
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
756
+ # This would be a good case for the `match` statement (Python 3.10+)
757
+ is_mps = sample.device.type == "mps"
758
+ if isinstance(timestep, float):
759
+ dtype = torch.float32 if is_mps else torch.float64
760
+ else:
761
+ dtype = torch.int32 if is_mps else torch.int64
762
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
763
+ elif len(timesteps.shape) == 0:
764
+ timesteps = timesteps[None].to(sample.device)
765
+
766
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
767
+ timesteps = timesteps.expand(sample.shape[0])
768
+
769
+ t_emb = self.time_proj(timesteps)
770
+
771
+ # timesteps does not contain any weights and will always return f32 tensors
772
+ # but time_embedding might actually be running in fp16. so we need to cast here.
773
+ # there might be better ways to encapsulate this.
774
+ t_emb = t_emb.to(dtype=sample.dtype)
775
+
776
+ emb = self.time_embedding(t_emb, timestep_cond)
777
+ aug_emb = None
778
+
779
+ if self.class_embedding is not None:
780
+ if class_labels is None:
781
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
782
+
783
+ if self.config.class_embed_type == "timestep":
784
+ class_labels = self.time_proj(class_labels)
785
+
786
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
787
+ emb = emb + class_emb
788
+
789
+ if self.config.addition_embed_type is not None:
790
+ if self.config.addition_embed_type == "text":
791
+ aug_emb = self.add_embedding(encoder_hidden_states)
792
+
793
+ elif self.config.addition_embed_type == "text_time":
794
+ if "text_embeds" not in added_cond_kwargs:
795
+ raise ValueError(
796
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
797
+ )
798
+ text_embeds = added_cond_kwargs.get("text_embeds")
799
+ if "time_ids" not in added_cond_kwargs:
800
+ raise ValueError(
801
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
802
+ )
803
+ time_ids = added_cond_kwargs.get("time_ids")
804
+ time_embeds = self.add_time_proj(time_ids.flatten())
805
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
806
+
807
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
808
+ add_embeds = add_embeds.to(emb.dtype)
809
+ aug_emb = self.add_embedding(add_embeds)
810
+
811
+ emb = emb + aug_emb if aug_emb is not None else emb
812
+
813
+ # 2. pre-process
814
+ sample = self.conv_in(sample)
815
+
816
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
817
+ sample = sample + controlnet_cond
818
+
819
+ # 3. down
820
+ down_block_res_samples = (sample,)
821
+ for downsample_block in self.down_blocks:
822
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
823
+ sample, res_samples = downsample_block(
824
+ hidden_states=sample,
825
+ temb=emb,
826
+ encoder_hidden_states=encoder_hidden_states,
827
+ attention_mask=attention_mask,
828
+ cross_attention_kwargs=cross_attention_kwargs,
829
+ )
830
+ else:
831
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
832
+
833
+ down_block_res_samples += res_samples
834
+
835
+ # 4. mid
836
+ if self.mid_block is not None:
837
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
838
+ sample = self.mid_block(
839
+ sample,
840
+ emb,
841
+ encoder_hidden_states=encoder_hidden_states,
842
+ attention_mask=attention_mask,
843
+ cross_attention_kwargs=cross_attention_kwargs,
844
+ )
845
+ else:
846
+ sample = self.mid_block(sample, emb)
847
+
848
+ # 5. Control net blocks
849
+
850
+ controlnet_down_block_res_samples = ()
851
+
852
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
853
+ down_block_res_sample = controlnet_block(down_block_res_sample)
854
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
855
+
856
+ down_block_res_samples = controlnet_down_block_res_samples
857
+
858
+ mid_block_res_sample = self.controlnet_mid_block(sample)
859
+
860
+ # 6. scaling
861
+ if guess_mode and not self.config.global_pool_conditions:
862
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
863
+ scales = scales * conditioning_scale
864
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
865
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
866
+ else:
867
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
868
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
869
+
870
+ if self.config.global_pool_conditions:
871
+ down_block_res_samples = [
872
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
873
+ ]
874
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
875
+
876
+ if not return_dict:
877
+ return (down_block_res_samples, mid_block_res_sample)
878
+
879
+ return ControlNetOutput(
880
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
881
+ )
882
+
883
+
884
+ def zero_module(module):
885
+ for p in module.parameters():
886
+ nn.init.zeros_(p)
887
+ return module
kolors/pipelines/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (145 Bytes)
 
kolors/pipelines/__pycache__/pipeline_stable_diffusion_xl_chatglm_256.cpython-38.pyc DELETED
Binary file (28.2 kB)
 
kolors/pipelines/__pycache__/pipeline_stable_diffusion_xl_chatglm_256_ipadapter.cpython-38.pyc DELETED
Binary file (30.3 kB)
 
kolors/pipelines/pipeline_controlnet_xl_kolors_img2img.py ADDED
@@ -0,0 +1,1365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import (
24
+ CLIPImageProcessor,
25
+ CLIPTextModel,
26
+ CLIPTextModelWithProjection,
27
+ CLIPTokenizer,
28
+ CLIPVisionModelWithProjection,
29
+ )
30
+
31
+ from diffusers.utils.import_utils import is_invisible_watermark_available
32
+
33
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
34
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
35
+ from diffusers.loaders import (
36
+ FromSingleFileMixin,
37
+ IPAdapterMixin,
38
+ StableDiffusionXLLoraLoaderMixin,
39
+ TextualInversionLoaderMixin,
40
+ )
41
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
42
+ from diffusers.models.attention_processor import (
43
+ AttnProcessor2_0,
44
+ XFormersAttnProcessor,
45
+ )
46
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
47
+ from diffusers.schedulers import KarrasDiffusionSchedulers
48
+ from diffusers.utils import (
49
+ USE_PEFT_BACKEND,
50
+ deprecate,
51
+ logging,
52
+ replace_example_docstring,
53
+ scale_lora_layers,
54
+ unscale_lora_layers,
55
+ )
56
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
57
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
58
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
59
+ from diffusers.pipelines.controlnet import MultiControlNetModel
60
+
61
+ from ..models.controlnet import ControlNetModel
62
+
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+
68
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
69
+ def retrieve_latents(
70
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
71
+ ):
72
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
73
+ return encoder_output.latent_dist.sample(generator)
74
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
75
+ return encoder_output.latent_dist.mode()
76
+ elif hasattr(encoder_output, "latents"):
77
+ return encoder_output.latents
78
+ else:
79
+ raise AttributeError("Could not access latents of provided encoder_output")
80
+
81
+
82
+ class StableDiffusionXLControlNetImg2ImgPipeline(
83
+ DiffusionPipeline,
84
+ StableDiffusionMixin,
85
+ TextualInversionLoaderMixin,
86
+ StableDiffusionXLLoraLoaderMixin,
87
+ FromSingleFileMixin,
88
+ IPAdapterMixin,
89
+ ):
90
+ r"""
91
+ Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
92
+
93
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
94
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
95
+
96
+ The pipeline also inherits the following loading methods:
97
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
98
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
99
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
100
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
101
+
102
+ Args:
103
+ vae ([`AutoencoderKL`]):
104
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
105
+ text_encoder ([`CLIPTextModel`]):
106
+ Frozen text-encoder. Stable Diffusion uses the text portion of
107
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
108
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
109
+ tokenizer (`CLIPTokenizer`):
110
+ Tokenizer of class
111
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
112
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
113
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
114
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
115
+ as a list, the outputs from each ControlNet are added together to create one combined additional
116
+ conditioning.
117
+ scheduler ([`SchedulerMixin`]):
118
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
119
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
120
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
121
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
122
+ config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
123
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
124
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
125
+ `stabilityai/stable-diffusion-xl-base-1-0`.
126
+ add_watermarker (`bool`, *optional*):
127
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
128
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
129
+ watermarker will be used.
130
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
131
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
132
+ """
133
+
134
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
135
+ _optional_components = [
136
+ "tokenizer",
137
+ "text_encoder",
138
+ "feature_extractor",
139
+ "image_encoder",
140
+ ]
141
+ _callback_tensor_inputs = [
142
+ "latents",
143
+ "prompt_embeds",
144
+ "negative_prompt_embeds",
145
+ "add_text_embeds",
146
+ "add_time_ids",
147
+ "negative_pooled_prompt_embeds",
148
+ "add_neg_time_ids",
149
+ ]
150
+
151
+ def __init__(
152
+ self,
153
+ vae: AutoencoderKL,
154
+ text_encoder: CLIPTextModel,
155
+ tokenizer: CLIPTokenizer,
156
+ unet: UNet2DConditionModel,
157
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
158
+ scheduler: KarrasDiffusionSchedulers,
159
+ requires_aesthetics_score: bool = False,
160
+ force_zeros_for_empty_prompt: bool = True,
161
+ feature_extractor: CLIPImageProcessor = None,
162
+ image_encoder: CLIPVisionModelWithProjection = None,
163
+ ):
164
+ super().__init__()
165
+
166
+ if isinstance(controlnet, (list, tuple)):
167
+ controlnet = MultiControlNetModel(controlnet)
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ tokenizer=tokenizer,
173
+ unet=unet,
174
+ controlnet=controlnet,
175
+ scheduler=scheduler,
176
+ feature_extractor=feature_extractor,
177
+ image_encoder=image_encoder,
178
+ )
179
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
181
+ self.control_image_processor = VaeImageProcessor(
182
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
183
+ )
184
+
185
+ self.watermark = None
186
+
187
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
188
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
189
+
190
+
191
+ def encode_prompt(
192
+ self,
193
+ prompt,
194
+ device: Optional[torch.device] = None,
195
+ num_images_per_prompt: int = 1,
196
+ do_classifier_free_guidance: bool = True,
197
+ negative_prompt=None,
198
+ prompt_embeds: Optional[torch.FloatTensor] = None,
199
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
200
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
201
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
202
+ lora_scale: Optional[float] = None,
203
+ ):
204
+ r"""
205
+ Encodes the prompt into text encoder hidden states.
206
+
207
+ Args:
208
+ prompt (`str` or `List[str]`, *optional*):
209
+ prompt to be encoded
210
+ device: (`torch.device`):
211
+ torch device
212
+ num_images_per_prompt (`int`):
213
+ number of images that should be generated per prompt
214
+ do_classifier_free_guidance (`bool`):
215
+ whether to use classifier free guidance or not
216
+ negative_prompt (`str` or `List[str]`, *optional*):
217
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
218
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
219
+ less than `1`).
220
+ prompt_embeds (`torch.FloatTensor`, *optional*):
221
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
222
+ provided, text embeddings will be generated from `prompt` input argument.
223
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
224
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
225
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
226
+ argument.
227
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
228
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
229
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
230
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
231
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
232
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
233
+ input argument.
234
+ lora_scale (`float`, *optional*):
235
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
236
+ """
237
+ # from IPython import embed; embed(); exit()
238
+ device = device or self._execution_device
239
+
240
+ # set lora scale so that monkey patched LoRA
241
+ # function of text encoder can correctly access it
242
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
243
+ self._lora_scale = lora_scale
244
+
245
+ if prompt is not None and isinstance(prompt, str):
246
+ batch_size = 1
247
+ elif prompt is not None and isinstance(prompt, list):
248
+ batch_size = len(prompt)
249
+ else:
250
+ batch_size = prompt_embeds.shape[0]
251
+
252
+ # Define tokenizers and text encoders
253
+ tokenizers = [self.tokenizer]
254
+ text_encoders = [self.text_encoder]
255
+
256
+ if prompt_embeds is None:
257
+ # textual inversion: procecss multi-vector tokens if necessary
258
+ prompt_embeds_list = []
259
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
260
+ if isinstance(self, TextualInversionLoaderMixin):
261
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
262
+
263
+ text_inputs = tokenizer(
264
+ prompt,
265
+ padding="max_length",
266
+ max_length=256,
267
+ truncation=True,
268
+ return_tensors="pt",
269
+ ).to('cuda')
270
+ output = text_encoder(
271
+ input_ids=text_inputs['input_ids'] ,
272
+ attention_mask=text_inputs['attention_mask'],
273
+ position_ids=text_inputs['position_ids'],
274
+ output_hidden_states=True)
275
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
276
+ pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
277
+ bs_embed, seq_len, _ = prompt_embeds.shape
278
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
279
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
280
+
281
+ prompt_embeds_list.append(prompt_embeds)
282
+
283
+ # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
284
+ prompt_embeds = prompt_embeds_list[0]
285
+
286
+ # get unconditional embeddings for classifier free guidance
287
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
288
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
289
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
290
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
291
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
292
+ # negative_prompt = negative_prompt or ""
293
+ uncond_tokens: List[str]
294
+ if negative_prompt is None:
295
+ uncond_tokens = [""] * batch_size
296
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
297
+ raise TypeError(
298
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
299
+ f" {type(prompt)}."
300
+ )
301
+ elif isinstance(negative_prompt, str):
302
+ uncond_tokens = [negative_prompt]
303
+ elif batch_size != len(negative_prompt):
304
+ raise ValueError(
305
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307
+ " the batch size of `prompt`."
308
+ )
309
+ else:
310
+ uncond_tokens = negative_prompt
311
+
312
+ negative_prompt_embeds_list = []
313
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
314
+ # textual inversion: procecss multi-vector tokens if necessary
315
+ if isinstance(self, TextualInversionLoaderMixin):
316
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
317
+
318
+ max_length = prompt_embeds.shape[1]
319
+ uncond_input = tokenizer(
320
+ uncond_tokens,
321
+ padding="max_length",
322
+ max_length=max_length,
323
+ truncation=True,
324
+ return_tensors="pt",
325
+ ).to('cuda')
326
+ output = text_encoder(
327
+ input_ids=uncond_input['input_ids'] ,
328
+ attention_mask=uncond_input['attention_mask'],
329
+ position_ids=uncond_input['position_ids'],
330
+ output_hidden_states=True)
331
+ negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
332
+ negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
333
+
334
+ if do_classifier_free_guidance:
335
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
336
+ seq_len = negative_prompt_embeds.shape[1]
337
+
338
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
339
+
340
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
341
+ negative_prompt_embeds = negative_prompt_embeds.view(
342
+ batch_size * num_images_per_prompt, seq_len, -1
343
+ )
344
+
345
+ # For classifier free guidance, we need to do two forward passes.
346
+ # Here we concatenate the unconditional and text embeddings into a single batch
347
+ # to avoid doing two forward passes
348
+
349
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
350
+
351
+ # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
352
+ negative_prompt_embeds = negative_prompt_embeds_list[0]
353
+
354
+ bs_embed = pooled_prompt_embeds.shape[0]
355
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
356
+ bs_embed * num_images_per_prompt, -1
357
+ )
358
+ if do_classifier_free_guidance:
359
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
360
+ bs_embed * num_images_per_prompt, -1
361
+ )
362
+
363
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
364
+
365
+
366
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
367
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
368
+ dtype = next(self.image_encoder.parameters()).dtype
369
+
370
+ if not isinstance(image, torch.Tensor):
371
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
372
+
373
+ image = image.to(device=device, dtype=dtype)
374
+ if output_hidden_states:
375
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
376
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
377
+ uncond_image_enc_hidden_states = self.image_encoder(
378
+ torch.zeros_like(image), output_hidden_states=True
379
+ ).hidden_states[-2]
380
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
381
+ num_images_per_prompt, dim=0
382
+ )
383
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
384
+ else:
385
+ image_embeds = self.image_encoder(image).image_embeds
386
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
387
+ uncond_image_embeds = torch.zeros_like(image_embeds)
388
+
389
+ return image_embeds, uncond_image_embeds
390
+
391
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
392
+ def prepare_ip_adapter_image_embeds(
393
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
394
+ ):
395
+ image_embeds = []
396
+ if do_classifier_free_guidance:
397
+ negative_image_embeds = []
398
+ if ip_adapter_image_embeds is None:
399
+ if not isinstance(ip_adapter_image, list):
400
+ ip_adapter_image = [ip_adapter_image]
401
+
402
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
403
+ raise ValueError(
404
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
405
+ )
406
+
407
+ for single_ip_adapter_image, image_proj_layer in zip(
408
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
409
+ ):
410
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
411
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
412
+ single_ip_adapter_image, device, 1, output_hidden_state
413
+ )
414
+
415
+ image_embeds.append(single_image_embeds[None, :])
416
+ if do_classifier_free_guidance:
417
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
418
+ else:
419
+ for single_image_embeds in ip_adapter_image_embeds:
420
+ if do_classifier_free_guidance:
421
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
422
+ negative_image_embeds.append(single_negative_image_embeds)
423
+ image_embeds.append(single_image_embeds)
424
+
425
+ ip_adapter_image_embeds = []
426
+ for i, single_image_embeds in enumerate(image_embeds):
427
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
428
+ if do_classifier_free_guidance:
429
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
430
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
431
+
432
+ single_image_embeds = single_image_embeds.to(device=device)
433
+ ip_adapter_image_embeds.append(single_image_embeds)
434
+
435
+ return ip_adapter_image_embeds
436
+
437
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
438
+ def prepare_extra_step_kwargs(self, generator, eta):
439
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
440
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
441
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
442
+ # and should be between [0, 1]
443
+
444
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
445
+ extra_step_kwargs = {}
446
+ if accepts_eta:
447
+ extra_step_kwargs["eta"] = eta
448
+
449
+ # check if the scheduler accepts generator
450
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
451
+ if accepts_generator:
452
+ extra_step_kwargs["generator"] = generator
453
+ return extra_step_kwargs
454
+
455
+ def check_inputs(
456
+ self,
457
+ prompt,
458
+ image,
459
+ strength,
460
+ num_inference_steps,
461
+ callback_steps,
462
+ negative_prompt=None,
463
+ prompt_embeds=None,
464
+ negative_prompt_embeds=None,
465
+ pooled_prompt_embeds=None,
466
+ negative_pooled_prompt_embeds=None,
467
+ ip_adapter_image=None,
468
+ ip_adapter_image_embeds=None,
469
+ controlnet_conditioning_scale=1.0,
470
+ control_guidance_start=0.0,
471
+ control_guidance_end=1.0,
472
+ callback_on_step_end_tensor_inputs=None,
473
+ ):
474
+ if strength < 0 or strength > 1:
475
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
476
+ if num_inference_steps is None:
477
+ raise ValueError("`num_inference_steps` cannot be None.")
478
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
479
+ raise ValueError(
480
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
481
+ f" {type(num_inference_steps)}."
482
+ )
483
+
484
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
485
+ raise ValueError(
486
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
487
+ f" {type(callback_steps)}."
488
+ )
489
+
490
+ if callback_on_step_end_tensor_inputs is not None and not all(
491
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
492
+ ):
493
+ raise ValueError(
494
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
495
+ )
496
+
497
+ if prompt is not None and prompt_embeds is not None:
498
+ raise ValueError(
499
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
500
+ " only forward one of the two."
501
+ )
502
+ elif prompt is None and prompt_embeds is None:
503
+ raise ValueError(
504
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
505
+ )
506
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
507
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
508
+
509
+ if negative_prompt is not None and negative_prompt_embeds is not None:
510
+ raise ValueError(
511
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
512
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
513
+ )
514
+
515
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
516
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
517
+ raise ValueError(
518
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
519
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
520
+ f" {negative_prompt_embeds.shape}."
521
+ )
522
+
523
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
524
+ raise ValueError(
525
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
526
+ )
527
+
528
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
529
+ raise ValueError(
530
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
531
+ )
532
+
533
+ # `prompt` needs more sophisticated handling when there are multiple
534
+ # conditionings.
535
+ if isinstance(self.controlnet, MultiControlNetModel):
536
+ if isinstance(prompt, list):
537
+ logger.warning(
538
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
539
+ " prompts. The conditionings will be fixed across the prompts."
540
+ )
541
+
542
+ # Check `image`
543
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
544
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
545
+ )
546
+ if (
547
+ isinstance(self.controlnet, ControlNetModel)
548
+ or is_compiled
549
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
550
+ ):
551
+ self.check_image(image, prompt, prompt_embeds)
552
+ elif (
553
+ isinstance(self.controlnet, MultiControlNetModel)
554
+ or is_compiled
555
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
556
+ ):
557
+ if not isinstance(image, list):
558
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
559
+
560
+ # When `image` is a nested list:
561
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
562
+ elif any(isinstance(i, list) for i in image):
563
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
564
+ elif len(image) != len(self.controlnet.nets):
565
+ raise ValueError(
566
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
567
+ )
568
+
569
+ for image_ in image:
570
+ self.check_image(image_, prompt, prompt_embeds)
571
+ else:
572
+ assert False
573
+
574
+ # Check `controlnet_conditioning_scale`
575
+ if (
576
+ isinstance(self.controlnet, ControlNetModel)
577
+ or is_compiled
578
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
579
+ ):
580
+ if not isinstance(controlnet_conditioning_scale, float):
581
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
582
+ elif (
583
+ isinstance(self.controlnet, MultiControlNetModel)
584
+ or is_compiled
585
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
586
+ ):
587
+ if isinstance(controlnet_conditioning_scale, list):
588
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
589
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
590
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
591
+ self.controlnet.nets
592
+ ):
593
+ raise ValueError(
594
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
595
+ " the same length as the number of controlnets"
596
+ )
597
+ else:
598
+ assert False
599
+
600
+ if not isinstance(control_guidance_start, (tuple, list)):
601
+ control_guidance_start = [control_guidance_start]
602
+
603
+ if not isinstance(control_guidance_end, (tuple, list)):
604
+ control_guidance_end = [control_guidance_end]
605
+
606
+ if len(control_guidance_start) != len(control_guidance_end):
607
+ raise ValueError(
608
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
609
+ )
610
+
611
+ if isinstance(self.controlnet, MultiControlNetModel):
612
+ if len(control_guidance_start) != len(self.controlnet.nets):
613
+ raise ValueError(
614
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
615
+ )
616
+
617
+ for start, end in zip(control_guidance_start, control_guidance_end):
618
+ if start >= end:
619
+ raise ValueError(
620
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
621
+ )
622
+ if start < 0.0:
623
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
624
+ if end > 1.0:
625
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
626
+
627
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
628
+ raise ValueError(
629
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
630
+ )
631
+
632
+ if ip_adapter_image_embeds is not None:
633
+ if not isinstance(ip_adapter_image_embeds, list):
634
+ raise ValueError(
635
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
636
+ )
637
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
638
+ raise ValueError(
639
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
640
+ )
641
+
642
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
643
+ def check_image(self, image, prompt, prompt_embeds):
644
+ image_is_pil = isinstance(image, PIL.Image.Image)
645
+ image_is_tensor = isinstance(image, torch.Tensor)
646
+ image_is_np = isinstance(image, np.ndarray)
647
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
648
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
649
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
650
+
651
+ if (
652
+ not image_is_pil
653
+ and not image_is_tensor
654
+ and not image_is_np
655
+ and not image_is_pil_list
656
+ and not image_is_tensor_list
657
+ and not image_is_np_list
658
+ ):
659
+ raise TypeError(
660
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
661
+ )
662
+
663
+ if image_is_pil:
664
+ image_batch_size = 1
665
+ else:
666
+ image_batch_size = len(image)
667
+
668
+ if prompt is not None and isinstance(prompt, str):
669
+ prompt_batch_size = 1
670
+ elif prompt is not None and isinstance(prompt, list):
671
+ prompt_batch_size = len(prompt)
672
+ elif prompt_embeds is not None:
673
+ prompt_batch_size = prompt_embeds.shape[0]
674
+
675
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
676
+ raise ValueError(
677
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
678
+ )
679
+
680
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
681
+ def prepare_control_image(
682
+ self,
683
+ image,
684
+ width,
685
+ height,
686
+ batch_size,
687
+ num_images_per_prompt,
688
+ device,
689
+ dtype,
690
+ do_classifier_free_guidance=False,
691
+ guess_mode=False,
692
+ ):
693
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
694
+ image_batch_size = image.shape[0]
695
+
696
+ if image_batch_size == 1:
697
+ repeat_by = batch_size
698
+ else:
699
+ # image batch size is the same as prompt batch size
700
+ repeat_by = num_images_per_prompt
701
+
702
+ image = image.repeat_interleave(repeat_by, dim=0)
703
+
704
+ image = image.to(device=device, dtype=dtype)
705
+
706
+ if do_classifier_free_guidance and not guess_mode:
707
+ image = torch.cat([image] * 2)
708
+
709
+ return image
710
+
711
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
712
+ def get_timesteps(self, num_inference_steps, strength, device):
713
+ # get the original timestep using init_timestep
714
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
715
+
716
+ t_start = max(num_inference_steps - init_timestep, 0)
717
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
718
+ if hasattr(self.scheduler, "set_begin_index"):
719
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
720
+
721
+ return timesteps, num_inference_steps - t_start
722
+
723
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
724
+ def prepare_latents(
725
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
726
+ ):
727
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
728
+ raise ValueError(
729
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
730
+ )
731
+
732
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
733
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
734
+ torch.cuda.empty_cache()
735
+
736
+ image = image.to(device=device, dtype=dtype)
737
+
738
+ batch_size = batch_size * num_images_per_prompt
739
+
740
+ if image.shape[1] == 4:
741
+ init_latents = image
742
+
743
+ else:
744
+ # make sure the VAE is in float32 mode, as it overflows in float16
745
+ if self.vae.config.force_upcast:
746
+ image = image.float()
747
+ self.vae.to(dtype=torch.float32)
748
+
749
+ if isinstance(generator, list) and len(generator) != batch_size:
750
+ raise ValueError(
751
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
752
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
753
+ )
754
+
755
+ elif isinstance(generator, list):
756
+ init_latents = [
757
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
758
+ for i in range(batch_size)
759
+ ]
760
+ init_latents = torch.cat(init_latents, dim=0)
761
+ else:
762
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
763
+
764
+ if self.vae.config.force_upcast:
765
+ self.vae.to(dtype)
766
+
767
+ init_latents = init_latents.to(dtype)
768
+
769
+ init_latents = self.vae.config.scaling_factor * init_latents
770
+
771
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
772
+ # expand init_latents for batch_size
773
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
774
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
775
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
776
+ raise ValueError(
777
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
778
+ )
779
+ else:
780
+ init_latents = torch.cat([init_latents], dim=0)
781
+
782
+ if add_noise:
783
+ shape = init_latents.shape
784
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
785
+ # get latents
786
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
787
+
788
+ latents = init_latents
789
+
790
+ return latents
791
+
792
+
793
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
794
+ def prepare_latents_t2i(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
795
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
796
+ if isinstance(generator, list) and len(generator) != batch_size:
797
+ raise ValueError(
798
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
799
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
800
+ )
801
+
802
+ if latents is None:
803
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
804
+ else:
805
+ latents = latents.to(device)
806
+
807
+ # scale the initial noise by the standard deviation required by the scheduler
808
+ latents = latents * self.scheduler.init_noise_sigma
809
+ return latents
810
+
811
+
812
+
813
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
814
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
815
+
816
+ passed_add_embed_dim = (
817
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
818
+ )
819
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
820
+
821
+ if expected_add_embed_dim != passed_add_embed_dim:
822
+ raise ValueError(
823
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
824
+ )
825
+
826
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
827
+ return add_time_ids
828
+
829
+
830
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
831
+ def upcast_vae(self):
832
+ dtype = self.vae.dtype
833
+ self.vae.to(dtype=torch.float32)
834
+ use_torch_2_0_or_xformers = isinstance(
835
+ self.vae.decoder.mid_block.attentions[0].processor,
836
+ (
837
+ AttnProcessor2_0,
838
+ XFormersAttnProcessor,
839
+ ),
840
+ )
841
+ # if xformers or torch_2_0 is used attention block does not need
842
+ # to be in float32 which can save lots of memory
843
+ if use_torch_2_0_or_xformers:
844
+ self.vae.post_quant_conv.to(dtype)
845
+ self.vae.decoder.conv_in.to(dtype)
846
+ self.vae.decoder.mid_block.to(dtype)
847
+
848
+ @property
849
+ def guidance_scale(self):
850
+ return self._guidance_scale
851
+
852
+ @property
853
+ def clip_skip(self):
854
+ return self._clip_skip
855
+
856
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
857
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
858
+ # corresponds to doing no classifier free guidance.
859
+ @property
860
+ def do_classifier_free_guidance(self):
861
+ return self._guidance_scale > 1
862
+
863
+ @property
864
+ def cross_attention_kwargs(self):
865
+ return self._cross_attention_kwargs
866
+
867
+ @property
868
+ def num_timesteps(self):
869
+ return self._num_timesteps
870
+
871
+ @torch.no_grad()
872
+ def __call__(
873
+ self,
874
+ prompt: Union[str, List[str]] = None,
875
+ image: PipelineImageInput = None,
876
+ control_image: PipelineImageInput = None,
877
+ height: Optional[int] = None,
878
+ width: Optional[int] = None,
879
+ strength: float = 0.8,
880
+ num_inference_steps: int = 50,
881
+ guidance_scale: float = 5.0,
882
+ negative_prompt: Optional[Union[str, List[str]]] = None,
883
+ num_images_per_prompt: Optional[int] = 1,
884
+ eta: float = 0.0,
885
+ guess_mode: bool = False,
886
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
887
+ latents: Optional[torch.Tensor] = None,
888
+ prompt_embeds: Optional[torch.Tensor] = None,
889
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
890
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
891
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
892
+ ip_adapter_image: Optional[PipelineImageInput] = None,
893
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
894
+ output_type: Optional[str] = "pil",
895
+ return_dict: bool = True,
896
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
897
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
898
+ control_guidance_start: Union[float, List[float]] = 0.0,
899
+ control_guidance_end: Union[float, List[float]] = 1.0,
900
+ original_size: Tuple[int, int] = None,
901
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
902
+ target_size: Tuple[int, int] = None,
903
+ clip_skip: Optional[int] = None,
904
+ callback_on_step_end: Optional[
905
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
906
+ ] = None,
907
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
908
+ **kwargs,
909
+ ):
910
+ r"""
911
+ Function invoked when calling the pipeline for generation.
912
+
913
+ Args:
914
+ prompt (`str` or `List[str]`, *optional*):
915
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
916
+ instead.
917
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
918
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
919
+ The initial image will be used as the starting point for the image generation process. Can also accept
920
+ image latents as `image`, if passing latents directly, it will not be encoded again.
921
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
922
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
923
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
924
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
925
+ be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
926
+ and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
927
+ init, images must be passed as a list such that each element of the list can be correctly batched for
928
+ input to a single controlnet.
929
+ height (`int`, *optional*, defaults to the size of control_image):
930
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
931
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
932
+ and checkpoints that are not specifically fine-tuned on low resolutions.
933
+ width (`int`, *optional*, defaults to the size of control_image):
934
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
935
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
936
+ and checkpoints that are not specifically fine-tuned on low resolutions.
937
+ strength (`float`, *optional*, defaults to 0.8):
938
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
939
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
940
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
941
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
942
+ essentially ignores `image`.
943
+ num_inference_steps (`int`, *optional*, defaults to 50):
944
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
945
+ expense of slower inference.
946
+ guidance_scale (`float`, *optional*, defaults to 7.5):
947
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
948
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
949
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
950
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
951
+ usually at the expense of lower image quality.
952
+ negative_prompt (`str` or `List[str]`, *optional*):
953
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
954
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
955
+ less than `1`).
956
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
957
+ The number of images to generate per prompt.
958
+ eta (`float`, *optional*, defaults to 0.0):
959
+ Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
960
+ [`schedulers.DDIMScheduler`], will be ignored for others.
961
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
962
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
963
+ to make generation deterministic.
964
+ latents (`torch.Tensor`, *optional*):
965
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
966
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
967
+ tensor will ge generated by sampling using the supplied random `generator`.
968
+ prompt_embeds (`torch.Tensor`, *optional*):
969
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
970
+ provided, text embeddings will be generated from `prompt` input argument.
971
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
972
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
973
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
974
+ argument.
975
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
976
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
977
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
978
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
979
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
980
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
981
+ input argument.
982
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
983
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
984
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
985
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
986
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
987
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
988
+ output_type (`str`, *optional*, defaults to `"pil"`):
989
+ The output format of the generate image. Choose between
990
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
991
+ return_dict (`bool`, *optional*, defaults to `True`):
992
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
993
+ plain tuple.
994
+ cross_attention_kwargs (`dict`, *optional*):
995
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
996
+ `self.processor` in
997
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
998
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
999
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1000
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1001
+ corresponding scale as a list.
1002
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1003
+ The percentage of total steps at which the controlnet starts applying.
1004
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1005
+ The percentage of total steps at which the controlnet stops applying.
1006
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1007
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1008
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1009
+ explained in section 2.2 of
1010
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1011
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1012
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1013
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1014
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1015
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1016
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1017
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1018
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1019
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1020
+ clip_skip (`int`, *optional*):
1021
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1022
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1023
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1024
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1025
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1026
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1027
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1028
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1029
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1030
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1031
+ `._callback_tensor_inputs` attribute of your pipeline class.
1032
+
1033
+ Examples:
1034
+
1035
+ Returns:
1036
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1037
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
1038
+ containing the output images.
1039
+ """
1040
+
1041
+ callback = kwargs.pop("callback", None)
1042
+ callback_steps = kwargs.pop("callback_steps", None)
1043
+
1044
+ if callback is not None:
1045
+ deprecate(
1046
+ "callback",
1047
+ "1.0.0",
1048
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1049
+ )
1050
+ if callback_steps is not None:
1051
+ deprecate(
1052
+ "callback_steps",
1053
+ "1.0.0",
1054
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1055
+ )
1056
+
1057
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1058
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1059
+
1060
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1061
+
1062
+ # align format for control guidance
1063
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1064
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1065
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1066
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1067
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1068
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1069
+ control_guidance_start, control_guidance_end = (
1070
+ mult * [control_guidance_start],
1071
+ mult * [control_guidance_end],
1072
+ )
1073
+
1074
+ # from IPython import embed; embed()
1075
+ # 1. Check inputs. Raise error if not correct
1076
+ self.check_inputs(
1077
+ prompt,
1078
+ control_image,
1079
+ strength,
1080
+ num_inference_steps,
1081
+ callback_steps,
1082
+ negative_prompt,
1083
+ prompt_embeds,
1084
+ negative_prompt_embeds,
1085
+ pooled_prompt_embeds,
1086
+ negative_pooled_prompt_embeds,
1087
+ ip_adapter_image,
1088
+ ip_adapter_image_embeds,
1089
+ controlnet_conditioning_scale,
1090
+ control_guidance_start,
1091
+ control_guidance_end,
1092
+ callback_on_step_end_tensor_inputs,
1093
+ )
1094
+
1095
+ self._guidance_scale = guidance_scale
1096
+ self._clip_skip = clip_skip
1097
+ self._cross_attention_kwargs = cross_attention_kwargs
1098
+
1099
+ # 2. Define call parameters
1100
+ if prompt is not None and isinstance(prompt, str):
1101
+ batch_size = 1
1102
+ elif prompt is not None and isinstance(prompt, list):
1103
+ batch_size = len(prompt)
1104
+ else:
1105
+ batch_size = prompt_embeds.shape[0]
1106
+
1107
+ device = self._execution_device
1108
+
1109
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1110
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1111
+
1112
+ # 3.1. Encode input prompt
1113
+ text_encoder_lora_scale = (
1114
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1115
+ )
1116
+ (
1117
+ prompt_embeds,
1118
+ negative_prompt_embeds,
1119
+ pooled_prompt_embeds,
1120
+ negative_pooled_prompt_embeds,
1121
+ ) = self.encode_prompt(
1122
+ prompt,
1123
+ device,
1124
+ num_images_per_prompt,
1125
+ self.do_classifier_free_guidance,
1126
+ negative_prompt,
1127
+ prompt_embeds=prompt_embeds,
1128
+ negative_prompt_embeds=negative_prompt_embeds,
1129
+ pooled_prompt_embeds=pooled_prompt_embeds,
1130
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1131
+ lora_scale=text_encoder_lora_scale,
1132
+ )
1133
+
1134
+ # 3.2 Encode ip_adapter_image
1135
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1136
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1137
+ ip_adapter_image,
1138
+ ip_adapter_image_embeds,
1139
+ device,
1140
+ batch_size * num_images_per_prompt,
1141
+ self.do_classifier_free_guidance,
1142
+ )
1143
+
1144
+ # 4. Prepare image and controlnet_conditioning_image
1145
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
1146
+
1147
+ if isinstance(controlnet, ControlNetModel):
1148
+ control_image = self.prepare_control_image(
1149
+ image=control_image,
1150
+ width=width,
1151
+ height=height,
1152
+ batch_size=batch_size * num_images_per_prompt,
1153
+ num_images_per_prompt=num_images_per_prompt,
1154
+ device=device,
1155
+ dtype=controlnet.dtype,
1156
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1157
+ guess_mode=guess_mode,
1158
+ )
1159
+ height, width = control_image.shape[-2:]
1160
+ elif isinstance(controlnet, MultiControlNetModel):
1161
+ control_images = []
1162
+
1163
+ for control_image_ in control_image:
1164
+ control_image_ = self.prepare_control_image(
1165
+ image=control_image_,
1166
+ width=width,
1167
+ height=height,
1168
+ batch_size=batch_size * num_images_per_prompt,
1169
+ num_images_per_prompt=num_images_per_prompt,
1170
+ device=device,
1171
+ dtype=controlnet.dtype,
1172
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1173
+ guess_mode=guess_mode,
1174
+ )
1175
+
1176
+ control_images.append(control_image_)
1177
+
1178
+ control_image = control_images
1179
+ height, width = control_image[0].shape[-2:]
1180
+ else:
1181
+ assert False
1182
+
1183
+ # 5. Prepare timesteps
1184
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1185
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1186
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1187
+ self._num_timesteps = len(timesteps)
1188
+
1189
+ # 6. Prepare latent variables
1190
+
1191
+ num_channels_latents = self.unet.config.in_channels
1192
+ if latents is None:
1193
+ if strength >= 1.0:
1194
+ latents = self.prepare_latents_t2i(
1195
+ batch_size * num_images_per_prompt,
1196
+ num_channels_latents,
1197
+ height,
1198
+ width,
1199
+ prompt_embeds.dtype,
1200
+ device,
1201
+ generator,
1202
+ latents,
1203
+ )
1204
+ else:
1205
+ latents = self.prepare_latents(
1206
+ image,
1207
+ latent_timestep,
1208
+ batch_size,
1209
+ num_images_per_prompt,
1210
+ prompt_embeds.dtype,
1211
+ device,
1212
+ generator,
1213
+ True,
1214
+ )
1215
+
1216
+
1217
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1218
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1219
+
1220
+ # 7.1 Create tensor stating which controlnets to keep
1221
+ controlnet_keep = []
1222
+ for i in range(len(timesteps)):
1223
+ keeps = [
1224
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1225
+ for s, e in zip(control_guidance_start, control_guidance_end)
1226
+ ]
1227
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1228
+
1229
+ # 7.2 Prepare added time ids & embeddings
1230
+ if isinstance(control_image, list):
1231
+ original_size = original_size or control_image[0].shape[-2:]
1232
+ else:
1233
+ original_size = original_size or control_image.shape[-2:]
1234
+ target_size = target_size or (height, width)
1235
+
1236
+ # 7. Prepare added time ids & embeddings
1237
+ add_text_embeds = pooled_prompt_embeds
1238
+ add_time_ids = self._get_add_time_ids(
1239
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
1240
+ )
1241
+
1242
+ if self.do_classifier_free_guidance:
1243
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1244
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1245
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
1246
+
1247
+ prompt_embeds = prompt_embeds.to(device)
1248
+ add_text_embeds = add_text_embeds.to(device)
1249
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1250
+
1251
+ # 8. Denoising loop
1252
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1253
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1254
+ for i, t in enumerate(timesteps):
1255
+ # expand the latents if we are doing classifier free guidance
1256
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1257
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1258
+
1259
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1260
+
1261
+ # controlnet(s) inference
1262
+ if guess_mode and self.do_classifier_free_guidance:
1263
+ # Infer ControlNet only for the conditional batch.
1264
+ control_model_input = latents
1265
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1266
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1267
+ controlnet_added_cond_kwargs = {
1268
+ "text_embeds": add_text_embeds.chunk(2)[1],
1269
+ "time_ids": add_time_ids.chunk(2)[1],
1270
+ }
1271
+ else:
1272
+ control_model_input = latent_model_input
1273
+ controlnet_prompt_embeds = prompt_embeds
1274
+ controlnet_added_cond_kwargs = added_cond_kwargs
1275
+
1276
+ if isinstance(controlnet_keep[i], list):
1277
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1278
+ else:
1279
+ controlnet_cond_scale = controlnet_conditioning_scale
1280
+ if isinstance(controlnet_cond_scale, list):
1281
+ controlnet_cond_scale = controlnet_cond_scale[0]
1282
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1283
+
1284
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1285
+ control_model_input,
1286
+ t,
1287
+ encoder_hidden_states=controlnet_prompt_embeds,
1288
+ controlnet_cond=control_image,
1289
+ conditioning_scale=cond_scale,
1290
+ guess_mode=guess_mode,
1291
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1292
+ return_dict=False,
1293
+ )
1294
+
1295
+ if guess_mode and self.do_classifier_free_guidance:
1296
+ # Infered ControlNet only for the conditional batch.
1297
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1298
+ # add 0 to the unconditional batch to keep it unchanged.
1299
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1300
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1301
+
1302
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1303
+ added_cond_kwargs["image_embeds"] = image_embeds
1304
+
1305
+ # predict the noise residual
1306
+ noise_pred = self.unet(
1307
+ latent_model_input,
1308
+ t,
1309
+ encoder_hidden_states=prompt_embeds,
1310
+ cross_attention_kwargs=self.cross_attention_kwargs,
1311
+ down_block_additional_residuals=down_block_res_samples,
1312
+ mid_block_additional_residual=mid_block_res_sample,
1313
+ added_cond_kwargs=added_cond_kwargs,
1314
+ return_dict=False,
1315
+ )[0]
1316
+
1317
+ # perform guidance
1318
+ if self.do_classifier_free_guidance:
1319
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1320
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1321
+
1322
+ # compute the previous noisy sample x_t -> x_t-1
1323
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1324
+
1325
+ # call the callback, if provided
1326
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1327
+ progress_bar.update()
1328
+ if callback is not None and i % callback_steps == 0:
1329
+ step_idx = i // getattr(self.scheduler, "order", 1)
1330
+ callback(step_idx, t, latents)
1331
+
1332
+ # If we do sequential model offloading, let's offload unet and controlnet
1333
+ # manually for max memory savings
1334
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1335
+ self.unet.to("cpu")
1336
+ self.controlnet.to("cpu")
1337
+ torch.cuda.empty_cache()
1338
+
1339
+ if not output_type == "latent":
1340
+ # make sure the VAE is in float32 mode, as it overflows in float16
1341
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1342
+
1343
+ if needs_upcasting:
1344
+ self.upcast_vae()
1345
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1346
+
1347
+ latents = latents / self.vae.config.scaling_factor
1348
+ image = self.vae.decode(latents, return_dict=False)[0]
1349
+
1350
+ # cast back to fp16 if needed
1351
+ if needs_upcasting:
1352
+ self.vae.to(dtype=torch.float16)
1353
+ else:
1354
+ image = latents
1355
+ return StableDiffusionXLPipelineOutput(images=image)
1356
+
1357
+ image = self.image_processor.postprocess(image, output_type=output_type)
1358
+
1359
+ # Offload all models
1360
+ self.maybe_free_model_hooks()
1361
+
1362
+ if not return_dict:
1363
+ return (image,)
1364
+
1365
+ return StableDiffusionXLPipelineOutput(images=image)
kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256_inpainting.py ADDED
@@ -0,0 +1,1790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from transformers import (
22
+ CLIPImageProcessor,
23
+ CLIPTextModel,
24
+ CLIPTextModelWithProjection,
25
+ CLIPTokenizer,
26
+ CLIPVisionModelWithProjection,
27
+ )
28
+
29
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
30
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
31
+ from diffusers.loaders import (
32
+ FromSingleFileMixin,
33
+ IPAdapterMixin,
34
+ StableDiffusionXLLoraLoaderMixin,
35
+ TextualInversionLoaderMixin,
36
+ )
37
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
38
+ from diffusers.models.attention_processor import (
39
+ AttnProcessor2_0,
40
+ LoRAAttnProcessor2_0,
41
+ LoRAXFormersAttnProcessor,
42
+ XFormersAttnProcessor,
43
+ )
44
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
45
+ from diffusers.schedulers import KarrasDiffusionSchedulers
46
+ from diffusers.utils import (
47
+ USE_PEFT_BACKEND,
48
+ deprecate,
49
+ is_invisible_watermark_available,
50
+ is_torch_xla_available,
51
+ logging,
52
+ replace_example_docstring,
53
+ scale_lora_layers,
54
+ unscale_lora_layers,
55
+ )
56
+ from diffusers.utils.torch_utils import randn_tensor
57
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
58
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
59
+
60
+
61
+ if is_invisible_watermark_available():
62
+ from .watermark import StableDiffusionXLWatermarker
63
+
64
+ if is_torch_xla_available():
65
+ import torch_xla.core.xla_model as xm
66
+
67
+ XLA_AVAILABLE = True
68
+ else:
69
+ XLA_AVAILABLE = False
70
+
71
+
72
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
+
74
+
75
+ EXAMPLE_DOC_STRING = """
76
+ Examples:
77
+ ```py
78
+ >>> import torch
79
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
80
+ >>> from diffusers.utils import load_image
81
+
82
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
83
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
84
+ ... torch_dtype=torch.float16,
85
+ ... variant="fp16",
86
+ ... use_safetensors=True,
87
+ ... )
88
+ >>> pipe.to("cuda")
89
+
90
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
91
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
92
+
93
+ >>> init_image = load_image(img_url).convert("RGB")
94
+ >>> mask_image = load_image(mask_url).convert("RGB")
95
+
96
+ >>> prompt = "A majestic tiger sitting on a bench"
97
+ >>> image = pipe(
98
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
99
+ ... ).images[0]
100
+ ```
101
+ """
102
+
103
+
104
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
105
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
106
+ """
107
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
108
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
109
+ """
110
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
111
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
112
+ # rescale the results from guidance (fixes overexposure)
113
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
114
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
115
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
116
+ return noise_cfg
117
+
118
+
119
+ def mask_pil_to_torch(mask, height, width):
120
+ # preprocess mask
121
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
122
+ mask = [mask]
123
+
124
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
125
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
126
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
127
+ mask = mask.astype(np.float32) / 255.0
128
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
129
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
130
+
131
+ mask = torch.from_numpy(mask)
132
+ return mask
133
+
134
+
135
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
136
+ """
137
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
138
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
139
+ ``image`` and ``1`` for the ``mask``.
140
+
141
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
142
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
143
+
144
+ Args:
145
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
146
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
147
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
148
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
149
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
150
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
151
+
152
+
153
+ Raises:
154
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
155
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
156
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
157
+ (ot the other way around).
158
+
159
+ Returns:
160
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
161
+ dimensions: ``batch x channels x height x width``.
162
+ """
163
+
164
+ # checkpoint. TOD(Yiyi) - need to clean this up later
165
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
166
+ deprecate(
167
+ "prepare_mask_and_masked_image",
168
+ "0.30.0",
169
+ deprecation_message,
170
+ )
171
+ if image is None:
172
+ raise ValueError("`image` input cannot be undefined.")
173
+
174
+ if mask is None:
175
+ raise ValueError("`mask_image` input cannot be undefined.")
176
+
177
+ if isinstance(image, torch.Tensor):
178
+ if not isinstance(mask, torch.Tensor):
179
+ mask = mask_pil_to_torch(mask, height, width)
180
+
181
+ if image.ndim == 3:
182
+ image = image.unsqueeze(0)
183
+
184
+ # Batch and add channel dim for single mask
185
+ if mask.ndim == 2:
186
+ mask = mask.unsqueeze(0).unsqueeze(0)
187
+
188
+ # Batch single mask or add channel dim
189
+ if mask.ndim == 3:
190
+ # Single batched mask, no channel dim or single mask not batched but channel dim
191
+ if mask.shape[0] == 1:
192
+ mask = mask.unsqueeze(0)
193
+
194
+ # Batched masks no channel dim
195
+ else:
196
+ mask = mask.unsqueeze(1)
197
+
198
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
199
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
200
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
201
+
202
+ # Check image is in [-1, 1]
203
+ # if image.min() < -1 or image.max() > 1:
204
+ # raise ValueError("Image should be in [-1, 1] range")
205
+
206
+ # Check mask is in [0, 1]
207
+ if mask.min() < 0 or mask.max() > 1:
208
+ raise ValueError("Mask should be in [0, 1] range")
209
+
210
+ # Binarize mask
211
+ mask[mask < 0.5] = 0
212
+ mask[mask >= 0.5] = 1
213
+
214
+ # Image as float32
215
+ image = image.to(dtype=torch.float32)
216
+ elif isinstance(mask, torch.Tensor):
217
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
218
+ else:
219
+ # preprocess image
220
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
221
+ image = [image]
222
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
223
+ # resize all images w.r.t passed height an width
224
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
225
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
226
+ image = np.concatenate(image, axis=0)
227
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
228
+ image = np.concatenate([i[None, :] for i in image], axis=0)
229
+
230
+ image = image.transpose(0, 3, 1, 2)
231
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
232
+
233
+ mask = mask_pil_to_torch(mask, height, width)
234
+ mask[mask < 0.5] = 0
235
+ mask[mask >= 0.5] = 1
236
+
237
+ if image.shape[1] == 4:
238
+ # images are in latent space and thus can't
239
+ # be masked set masked_image to None
240
+ # we assume that the checkpoint is not an inpainting
241
+ # checkpoint. TOD(Yiyi) - need to clean this up later
242
+ masked_image = None
243
+ else:
244
+ masked_image = image * (mask < 0.5)
245
+
246
+ # n.b. ensure backwards compatibility as old function does not return image
247
+ if return_image:
248
+ return mask, masked_image, image
249
+
250
+ return mask, masked_image
251
+
252
+
253
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
254
+ def retrieve_latents(
255
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
256
+ ):
257
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
258
+ return encoder_output.latent_dist.sample(generator)
259
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
260
+ return encoder_output.latent_dist.mode()
261
+ elif hasattr(encoder_output, "latents"):
262
+ return encoder_output.latents
263
+ else:
264
+ raise AttributeError("Could not access latents of provided encoder_output")
265
+
266
+
267
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
268
+ def retrieve_timesteps(
269
+ scheduler,
270
+ num_inference_steps: Optional[int] = None,
271
+ device: Optional[Union[str, torch.device]] = None,
272
+ timesteps: Optional[List[int]] = None,
273
+ sigmas: Optional[List[float]] = None,
274
+ **kwargs,
275
+ ):
276
+ """
277
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
278
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
279
+
280
+ Args:
281
+ scheduler (`SchedulerMixin`):
282
+ The scheduler to get timesteps from.
283
+ num_inference_steps (`int`):
284
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
285
+ must be `None`.
286
+ device (`str` or `torch.device`, *optional*):
287
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
288
+ timesteps (`List[int]`, *optional*):
289
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
290
+ `num_inference_steps` and `sigmas` must be `None`.
291
+ sigmas (`List[float]`, *optional*):
292
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
293
+ `num_inference_steps` and `timesteps` must be `None`.
294
+
295
+ Returns:
296
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
297
+ second element is the number of inference steps.
298
+ """
299
+ if timesteps is not None and sigmas is not None:
300
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
301
+ if timesteps is not None:
302
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
303
+ if not accepts_timesteps:
304
+ raise ValueError(
305
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
306
+ f" timestep schedules. Please check whether you are using the correct scheduler."
307
+ )
308
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
309
+ timesteps = scheduler.timesteps
310
+ num_inference_steps = len(timesteps)
311
+ elif sigmas is not None:
312
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
313
+ if not accept_sigmas:
314
+ raise ValueError(
315
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
316
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
317
+ )
318
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
319
+ timesteps = scheduler.timesteps
320
+ num_inference_steps = len(timesteps)
321
+ else:
322
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
323
+ timesteps = scheduler.timesteps
324
+ return timesteps, num_inference_steps
325
+
326
+
327
+ class StableDiffusionXLInpaintPipeline(
328
+ DiffusionPipeline,
329
+ StableDiffusionMixin,
330
+ TextualInversionLoaderMixin,
331
+ StableDiffusionXLLoraLoaderMixin,
332
+ FromSingleFileMixin,
333
+ IPAdapterMixin,
334
+ ):
335
+ r"""
336
+ Pipeline for text-to-image generation using Stable Diffusion XL.
337
+
338
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
339
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
340
+
341
+ The pipeline also inherits the following loading methods:
342
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
343
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
344
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
345
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
346
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
347
+
348
+ Args:
349
+ vae ([`AutoencoderKL`]):
350
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
351
+ text_encoder ([`CLIPTextModel`]):
352
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
353
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
354
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
355
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
356
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
357
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
358
+ specifically the
359
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
360
+ variant.
361
+ tokenizer (`CLIPTokenizer`):
362
+ Tokenizer of class
363
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
364
+ tokenizer_2 (`CLIPTokenizer`):
365
+ Second Tokenizer of class
366
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
367
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
368
+ scheduler ([`SchedulerMixin`]):
369
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
370
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
371
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
372
+ Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
373
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
374
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
375
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
376
+ `stabilityai/stable-diffusion-xl-base-1-0`.
377
+ add_watermarker (`bool`, *optional*):
378
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
379
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
380
+ watermarker will be used.
381
+ """
382
+
383
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
384
+
385
+ _optional_components = [
386
+ "tokenizer",
387
+ "tokenizer_2",
388
+ "text_encoder",
389
+ "text_encoder_2",
390
+ "image_encoder",
391
+ "feature_extractor",
392
+ ]
393
+ _callback_tensor_inputs = [
394
+ "latents",
395
+ "prompt_embeds",
396
+ "negative_prompt_embeds",
397
+ "add_text_embeds",
398
+ "add_time_ids",
399
+ "negative_pooled_prompt_embeds",
400
+ "add_neg_time_ids",
401
+ "mask",
402
+ "masked_image_latents",
403
+ ]
404
+
405
+ def __init__(
406
+ self,
407
+ vae: AutoencoderKL,
408
+ text_encoder: CLIPTextModel,
409
+ tokenizer: CLIPTokenizer,
410
+ unet: UNet2DConditionModel,
411
+ scheduler: KarrasDiffusionSchedulers,
412
+ tokenizer_2: CLIPTokenizer = None,
413
+ text_encoder_2: CLIPTextModelWithProjection = None,
414
+ image_encoder: CLIPVisionModelWithProjection = None,
415
+ feature_extractor: CLIPImageProcessor = None,
416
+ requires_aesthetics_score: bool = False,
417
+ force_zeros_for_empty_prompt: bool = True,
418
+ add_watermarker: Optional[bool] = None,
419
+ ):
420
+ super().__init__()
421
+
422
+ self.register_modules(
423
+ vae=vae,
424
+ text_encoder=text_encoder,
425
+ text_encoder_2=text_encoder_2,
426
+ tokenizer=tokenizer,
427
+ tokenizer_2=tokenizer_2,
428
+ unet=unet,
429
+ image_encoder=image_encoder,
430
+ feature_extractor=feature_extractor,
431
+ scheduler=scheduler,
432
+ )
433
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
434
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
435
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
436
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
437
+ self.mask_processor = VaeImageProcessor(
438
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
439
+ )
440
+
441
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
442
+
443
+ if add_watermarker:
444
+ self.watermark = StableDiffusionXLWatermarker()
445
+ else:
446
+ self.watermark = None
447
+
448
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
449
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
450
+ dtype = next(self.image_encoder.parameters()).dtype
451
+
452
+ if not isinstance(image, torch.Tensor):
453
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
454
+
455
+ image = image.to(device=device, dtype=dtype)
456
+ if output_hidden_states:
457
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
458
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
459
+ uncond_image_enc_hidden_states = self.image_encoder(
460
+ torch.zeros_like(image), output_hidden_states=True
461
+ ).hidden_states[-2]
462
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
463
+ num_images_per_prompt, dim=0
464
+ )
465
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
466
+ else:
467
+ image_embeds = self.image_encoder(image).image_embeds
468
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
469
+ uncond_image_embeds = torch.zeros_like(image_embeds)
470
+
471
+ return image_embeds, uncond_image_embeds
472
+
473
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
474
+ def prepare_ip_adapter_image_embeds(
475
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
476
+ ):
477
+ if ip_adapter_image_embeds is None:
478
+ if not isinstance(ip_adapter_image, list):
479
+ ip_adapter_image = [ip_adapter_image]
480
+
481
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
482
+ raise ValueError(
483
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
484
+ )
485
+
486
+ image_embeds = []
487
+ for single_ip_adapter_image, image_proj_layer in zip(
488
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
489
+ ):
490
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
491
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
492
+ single_ip_adapter_image, device, 1, output_hidden_state
493
+ )
494
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
495
+ single_negative_image_embeds = torch.stack(
496
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
497
+ )
498
+
499
+ if do_classifier_free_guidance:
500
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
501
+ single_image_embeds = single_image_embeds.to(device)
502
+
503
+ image_embeds.append(single_image_embeds)
504
+ else:
505
+ repeat_dims = [1]
506
+ image_embeds = []
507
+ for single_image_embeds in ip_adapter_image_embeds:
508
+ if do_classifier_free_guidance:
509
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
510
+ single_image_embeds = single_image_embeds.repeat(
511
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
512
+ )
513
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
514
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
515
+ )
516
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
517
+ else:
518
+ single_image_embeds = single_image_embeds.repeat(
519
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
520
+ )
521
+ image_embeds.append(single_image_embeds)
522
+
523
+ return image_embeds
524
+
525
+ def encode_prompt(
526
+ self,
527
+ prompt,
528
+ device: Optional[torch.device] = None,
529
+ num_images_per_prompt: int = 1,
530
+ do_classifier_free_guidance: bool = True,
531
+ negative_prompt=None,
532
+ prompt_embeds: Optional[torch.FloatTensor] = None,
533
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
534
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
535
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
536
+ lora_scale: Optional[float] = None,
537
+ ):
538
+ r"""
539
+ Encodes the prompt into text encoder hidden states.
540
+
541
+ Args:
542
+ prompt (`str` or `List[str]`, *optional*):
543
+ prompt to be encoded
544
+ device: (`torch.device`):
545
+ torch device
546
+ num_images_per_prompt (`int`):
547
+ number of images that should be generated per prompt
548
+ do_classifier_free_guidance (`bool`):
549
+ whether to use classifier free guidance or not
550
+ negative_prompt (`str` or `List[str]`, *optional*):
551
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
552
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
553
+ less than `1`).
554
+ prompt_embeds (`torch.FloatTensor`, *optional*):
555
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
556
+ provided, text embeddings will be generated from `prompt` input argument.
557
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
558
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
559
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
560
+ argument.
561
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
562
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
563
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
564
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
565
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
566
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
567
+ input argument.
568
+ lora_scale (`float`, *optional*):
569
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
570
+ """
571
+ # from IPython import embed; embed(); exit()
572
+ device = device or self._execution_device
573
+
574
+ # set lora scale so that monkey patched LoRA
575
+ # function of text encoder can correctly access it
576
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
577
+ self._lora_scale = lora_scale
578
+
579
+ if prompt is not None and isinstance(prompt, str):
580
+ batch_size = 1
581
+ elif prompt is not None and isinstance(prompt, list):
582
+ batch_size = len(prompt)
583
+ else:
584
+ batch_size = prompt_embeds.shape[0]
585
+
586
+ # Define tokenizers and text encoders
587
+ tokenizers = [self.tokenizer]
588
+ text_encoders = [self.text_encoder]
589
+
590
+ if prompt_embeds is None:
591
+ # textual inversion: procecss multi-vector tokens if necessary
592
+ prompt_embeds_list = []
593
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
594
+ if isinstance(self, TextualInversionLoaderMixin):
595
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
596
+
597
+ text_inputs = tokenizer(
598
+ prompt,
599
+ padding="max_length",
600
+ max_length=256,
601
+ truncation=True,
602
+ return_tensors="pt",
603
+ ).to('cuda')
604
+ output = text_encoder(
605
+ input_ids=text_inputs['input_ids'] ,
606
+ attention_mask=text_inputs['attention_mask'],
607
+ position_ids=text_inputs['position_ids'],
608
+ output_hidden_states=True)
609
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
610
+ text_proj = output.hidden_states[-1][-1, :, :].clone()
611
+ bs_embed, seq_len, _ = prompt_embeds.shape
612
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
613
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
614
+ prompt_embeds_list.append(prompt_embeds)
615
+
616
+ prompt_embeds = prompt_embeds_list[0]
617
+
618
+ # get unconditional embeddings for classifier free guidance
619
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
620
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
621
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
622
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
623
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
624
+ # negative_prompt = negative_prompt or ""
625
+ uncond_tokens: List[str]
626
+ if negative_prompt is None:
627
+ uncond_tokens = [""] * batch_size
628
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
629
+ raise TypeError(
630
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
631
+ f" {type(prompt)}."
632
+ )
633
+ elif isinstance(negative_prompt, str):
634
+ uncond_tokens = [negative_prompt]
635
+ elif batch_size != len(negative_prompt):
636
+ raise ValueError(
637
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
638
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
639
+ " the batch size of `prompt`."
640
+ )
641
+ else:
642
+ uncond_tokens = negative_prompt
643
+
644
+ negative_prompt_embeds_list = []
645
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
646
+ # textual inversion: procecss multi-vector tokens if necessary
647
+ if isinstance(self, TextualInversionLoaderMixin):
648
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
649
+
650
+ max_length = prompt_embeds.shape[1]
651
+ uncond_input = tokenizer(
652
+ uncond_tokens,
653
+ padding="max_length",
654
+ max_length=max_length,
655
+ truncation=True,
656
+ return_tensors="pt",
657
+ ).to('cuda')
658
+ output = text_encoder(
659
+ input_ids=uncond_input['input_ids'] ,
660
+ attention_mask=uncond_input['attention_mask'],
661
+ position_ids=uncond_input['position_ids'],
662
+ output_hidden_states=True)
663
+ negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
664
+ negative_text_proj = output.hidden_states[-1][-1, :, :].clone()
665
+
666
+ if do_classifier_free_guidance:
667
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
668
+ seq_len = negative_prompt_embeds.shape[1]
669
+
670
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
671
+
672
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
673
+ negative_prompt_embeds = negative_prompt_embeds.view(
674
+ batch_size * num_images_per_prompt, seq_len, -1
675
+ )
676
+
677
+ # For classifier free guidance, we need to do two forward passes.
678
+ # Here we concatenate the unconditional and text embeddings into a single batch
679
+ # to avoid doing two forward passes
680
+
681
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
682
+
683
+ negative_prompt_embeds = negative_prompt_embeds_list[0]
684
+
685
+ bs_embed = text_proj.shape[0]
686
+ text_proj = text_proj.repeat(1, num_images_per_prompt).view(
687
+ bs_embed * num_images_per_prompt, -1
688
+ )
689
+ negative_text_proj = negative_text_proj.repeat(1, num_images_per_prompt).view(
690
+ bs_embed * num_images_per_prompt, -1
691
+ )
692
+
693
+ return prompt_embeds, negative_prompt_embeds, text_proj, negative_text_proj
694
+
695
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
696
+ def prepare_extra_step_kwargs(self, generator, eta):
697
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
698
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
699
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
700
+ # and should be between [0, 1]
701
+
702
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
703
+ extra_step_kwargs = {}
704
+ if accepts_eta:
705
+ extra_step_kwargs["eta"] = eta
706
+
707
+ # check if the scheduler accepts generator
708
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
709
+ if accepts_generator:
710
+ extra_step_kwargs["generator"] = generator
711
+ return extra_step_kwargs
712
+
713
+ def check_inputs(
714
+ self,
715
+ prompt,
716
+ prompt_2,
717
+ image,
718
+ mask_image,
719
+ height,
720
+ width,
721
+ strength,
722
+ callback_steps,
723
+ output_type,
724
+ negative_prompt=None,
725
+ negative_prompt_2=None,
726
+ prompt_embeds=None,
727
+ negative_prompt_embeds=None,
728
+ ip_adapter_image=None,
729
+ ip_adapter_image_embeds=None,
730
+ callback_on_step_end_tensor_inputs=None,
731
+ padding_mask_crop=None,
732
+ ):
733
+ if strength < 0 or strength > 1:
734
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
735
+
736
+ if height % 8 != 0 or width % 8 != 0:
737
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
738
+
739
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
740
+ raise ValueError(
741
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
742
+ f" {type(callback_steps)}."
743
+ )
744
+
745
+ if callback_on_step_end_tensor_inputs is not None and not all(
746
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
747
+ ):
748
+ raise ValueError(
749
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
750
+ )
751
+
752
+ if prompt is not None and prompt_embeds is not None:
753
+ raise ValueError(
754
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
755
+ " only forward one of the two."
756
+ )
757
+ elif prompt_2 is not None and prompt_embeds is not None:
758
+ raise ValueError(
759
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
760
+ " only forward one of the two."
761
+ )
762
+ elif prompt is None and prompt_embeds is None:
763
+ raise ValueError(
764
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
765
+ )
766
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
767
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
768
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
769
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
770
+
771
+ if negative_prompt is not None and negative_prompt_embeds is not None:
772
+ raise ValueError(
773
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
774
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
775
+ )
776
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
777
+ raise ValueError(
778
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
779
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
780
+ )
781
+
782
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
783
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
784
+ raise ValueError(
785
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
786
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
787
+ f" {negative_prompt_embeds.shape}."
788
+ )
789
+ if padding_mask_crop is not None:
790
+ if not isinstance(image, PIL.Image.Image):
791
+ raise ValueError(
792
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
793
+ )
794
+ if not isinstance(mask_image, PIL.Image.Image):
795
+ raise ValueError(
796
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
797
+ f" {type(mask_image)}."
798
+ )
799
+ if output_type != "pil":
800
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
801
+
802
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
803
+ raise ValueError(
804
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
805
+ )
806
+
807
+ if ip_adapter_image_embeds is not None:
808
+ if not isinstance(ip_adapter_image_embeds, list):
809
+ raise ValueError(
810
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
811
+ )
812
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
813
+ raise ValueError(
814
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
815
+ )
816
+
817
+ def prepare_latents(
818
+ self,
819
+ batch_size,
820
+ num_channels_latents,
821
+ height,
822
+ width,
823
+ dtype,
824
+ device,
825
+ generator,
826
+ latents=None,
827
+ image=None,
828
+ timestep=None,
829
+ is_strength_max=True,
830
+ add_noise=True,
831
+ return_noise=False,
832
+ return_image_latents=False,
833
+ ):
834
+ shape = (
835
+ batch_size,
836
+ num_channels_latents,
837
+ int(height) // self.vae_scale_factor,
838
+ int(width) // self.vae_scale_factor,
839
+ )
840
+ if isinstance(generator, list) and len(generator) != batch_size:
841
+ raise ValueError(
842
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
843
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
844
+ )
845
+
846
+ if (image is None or timestep is None) and not is_strength_max:
847
+ raise ValueError(
848
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
849
+ "However, either the image or the noise timestep has not been provided."
850
+ )
851
+
852
+ if image.shape[1] == 4:
853
+ image_latents = image.to(device=device, dtype=dtype)
854
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
855
+ elif return_image_latents or (latents is None and not is_strength_max):
856
+ image = image.to(device=device, dtype=dtype)
857
+ image_latents = self._encode_vae_image(image=image, generator=generator)
858
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
859
+
860
+ if latents is None and add_noise:
861
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
862
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
863
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
864
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
865
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
866
+ elif add_noise:
867
+ noise = latents.to(device)
868
+ latents = noise * self.scheduler.init_noise_sigma
869
+ else:
870
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
871
+ latents = image_latents.to(device)
872
+
873
+ outputs = (latents,)
874
+
875
+ if return_noise:
876
+ outputs += (noise,)
877
+
878
+ if return_image_latents:
879
+ outputs += (image_latents,)
880
+
881
+ return outputs
882
+
883
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
884
+ dtype = image.dtype
885
+ if self.vae.config.force_upcast:
886
+ image = image.float()
887
+ self.vae.to(dtype=torch.float32)
888
+
889
+ if isinstance(generator, list):
890
+ image_latents = [
891
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
892
+ for i in range(image.shape[0])
893
+ ]
894
+ image_latents = torch.cat(image_latents, dim=0)
895
+ else:
896
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
897
+
898
+ if self.vae.config.force_upcast:
899
+ self.vae.to(dtype)
900
+
901
+ image_latents = image_latents.to(dtype)
902
+ image_latents = self.vae.config.scaling_factor * image_latents
903
+
904
+ return image_latents
905
+
906
+ def prepare_mask_latents(
907
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
908
+ ):
909
+ # resize the mask to latents shape as we concatenate the mask to the latents
910
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
911
+ # and half precision
912
+ mask = torch.nn.functional.interpolate(
913
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
914
+ )
915
+ mask = mask.to(device=device, dtype=dtype)
916
+
917
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
918
+ if mask.shape[0] < batch_size:
919
+ if not batch_size % mask.shape[0] == 0:
920
+ raise ValueError(
921
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
922
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
923
+ " of masks that you pass is divisible by the total requested batch size."
924
+ )
925
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
926
+
927
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
928
+
929
+ if masked_image is not None and masked_image.shape[1] == 4:
930
+ masked_image_latents = masked_image
931
+ else:
932
+ masked_image_latents = None
933
+
934
+ if masked_image is not None:
935
+ if masked_image_latents is None:
936
+ masked_image = masked_image.to(device=device, dtype=dtype)
937
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
938
+
939
+ if masked_image_latents.shape[0] < batch_size:
940
+ if not batch_size % masked_image_latents.shape[0] == 0:
941
+ raise ValueError(
942
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
943
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
944
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
945
+ )
946
+ masked_image_latents = masked_image_latents.repeat(
947
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
948
+ )
949
+
950
+ masked_image_latents = (
951
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
952
+ )
953
+
954
+ # aligning device to prevent device errors when concating it with the latent model input
955
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
956
+
957
+ return mask, masked_image_latents
958
+
959
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
960
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
961
+ # get the original timestep using init_timestep
962
+ if denoising_start is None:
963
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
964
+ t_start = max(num_inference_steps - init_timestep, 0)
965
+ else:
966
+ t_start = 0
967
+
968
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
969
+
970
+ # Strength is irrelevant if we directly request a timestep to start at;
971
+ # that is, strength is determined by the denoising_start instead.
972
+ if denoising_start is not None:
973
+ discrete_timestep_cutoff = int(
974
+ round(
975
+ self.scheduler.config.num_train_timesteps
976
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
977
+ )
978
+ )
979
+
980
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
981
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
982
+ # if the scheduler is a 2nd order scheduler we might have to do +1
983
+ # because `num_inference_steps` might be even given that every timestep
984
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
985
+ # mean that we cut the timesteps in the middle of the denoising step
986
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
987
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
988
+ num_inference_steps = num_inference_steps + 1
989
+
990
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
991
+ timesteps = timesteps[-num_inference_steps:]
992
+ return timesteps, num_inference_steps
993
+
994
+ return timesteps, num_inference_steps - t_start
995
+
996
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
997
+ def _get_add_time_ids(
998
+ self,
999
+ original_size,
1000
+ crops_coords_top_left,
1001
+ target_size,
1002
+ aesthetic_score,
1003
+ negative_aesthetic_score,
1004
+ negative_original_size,
1005
+ negative_crops_coords_top_left,
1006
+ negative_target_size,
1007
+ dtype,
1008
+ text_encoder_projection_dim=None,
1009
+ ):
1010
+ if self.config.requires_aesthetics_score:
1011
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
1012
+ add_neg_time_ids = list(
1013
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
1014
+ )
1015
+ else:
1016
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1017
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
1018
+
1019
+ passed_add_embed_dim = (
1020
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
1021
+ )
1022
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1023
+
1024
+ if (
1025
+ expected_add_embed_dim > passed_add_embed_dim
1026
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
1027
+ ):
1028
+ raise ValueError(
1029
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
1030
+ )
1031
+ elif (
1032
+ expected_add_embed_dim < passed_add_embed_dim
1033
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
1034
+ ):
1035
+ raise ValueError(
1036
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
1037
+ )
1038
+ elif expected_add_embed_dim != passed_add_embed_dim:
1039
+ raise ValueError(
1040
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1041
+ )
1042
+
1043
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1044
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
1045
+
1046
+ return add_time_ids, add_neg_time_ids
1047
+
1048
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1049
+ def upcast_vae(self):
1050
+ dtype = self.vae.dtype
1051
+ self.vae.to(dtype=torch.float32)
1052
+ use_torch_2_0_or_xformers = isinstance(
1053
+ self.vae.decoder.mid_block.attentions[0].processor,
1054
+ (
1055
+ AttnProcessor2_0,
1056
+ XFormersAttnProcessor,
1057
+ LoRAXFormersAttnProcessor,
1058
+ LoRAAttnProcessor2_0,
1059
+ ),
1060
+ )
1061
+ # if xformers or torch_2_0 is used attention block does not need
1062
+ # to be in float32 which can save lots of memory
1063
+ if use_torch_2_0_or_xformers:
1064
+ self.vae.post_quant_conv.to(dtype)
1065
+ self.vae.decoder.conv_in.to(dtype)
1066
+ self.vae.decoder.mid_block.to(dtype)
1067
+
1068
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1069
+ def get_guidance_scale_embedding(
1070
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1071
+ ) -> torch.Tensor:
1072
+ """
1073
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1074
+
1075
+ Args:
1076
+ w (`torch.Tensor`):
1077
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1078
+ embedding_dim (`int`, *optional*, defaults to 512):
1079
+ Dimension of the embeddings to generate.
1080
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1081
+ Data type of the generated embeddings.
1082
+
1083
+ Returns:
1084
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1085
+ """
1086
+ assert len(w.shape) == 1
1087
+ w = w * 1000.0
1088
+
1089
+ half_dim = embedding_dim // 2
1090
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1091
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1092
+ emb = w.to(dtype)[:, None] * emb[None, :]
1093
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1094
+ if embedding_dim % 2 == 1: # zero pad
1095
+ emb = torch.nn.functional.pad(emb, (0, 1))
1096
+ assert emb.shape == (w.shape[0], embedding_dim)
1097
+ return emb
1098
+
1099
+ @property
1100
+ def guidance_scale(self):
1101
+ return self._guidance_scale
1102
+
1103
+ @property
1104
+ def guidance_rescale(self):
1105
+ return self._guidance_rescale
1106
+
1107
+ @property
1108
+ def clip_skip(self):
1109
+ return self._clip_skip
1110
+
1111
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1112
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1113
+ # corresponds to doing no classifier free guidance.
1114
+ @property
1115
+ def do_classifier_free_guidance(self):
1116
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1117
+
1118
+ @property
1119
+ def cross_attention_kwargs(self):
1120
+ return self._cross_attention_kwargs
1121
+
1122
+ @property
1123
+ def denoising_end(self):
1124
+ return self._denoising_end
1125
+
1126
+ @property
1127
+ def denoising_start(self):
1128
+ return self._denoising_start
1129
+
1130
+ @property
1131
+ def num_timesteps(self):
1132
+ return self._num_timesteps
1133
+
1134
+ @property
1135
+ def interrupt(self):
1136
+ return self._interrupt
1137
+
1138
+ @torch.no_grad()
1139
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1140
+ def __call__(
1141
+ self,
1142
+ prompt: Union[str, List[str]] = None,
1143
+ prompt_2: Optional[Union[str, List[str]]] = None,
1144
+ image: PipelineImageInput = None,
1145
+ mask_image: PipelineImageInput = None,
1146
+ masked_image_latents: torch.Tensor = None,
1147
+ height: Optional[int] = None,
1148
+ width: Optional[int] = None,
1149
+ padding_mask_crop: Optional[int] = None,
1150
+ strength: float = 0.9999,
1151
+ num_inference_steps: int = 50,
1152
+ timesteps: List[int] = None,
1153
+ sigmas: List[float] = None,
1154
+ denoising_start: Optional[float] = None,
1155
+ denoising_end: Optional[float] = None,
1156
+ guidance_scale: float = 7.5,
1157
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1158
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1159
+ num_images_per_prompt: Optional[int] = 1,
1160
+ eta: float = 0.0,
1161
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1162
+ latents: Optional[torch.Tensor] = None,
1163
+ prompt_embeds: Optional[torch.Tensor] = None,
1164
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1165
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1166
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1167
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1168
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1169
+ output_type: Optional[str] = "pil",
1170
+ return_dict: bool = True,
1171
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1172
+ guidance_rescale: float = 0.0,
1173
+ original_size: Tuple[int, int] = None,
1174
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1175
+ target_size: Tuple[int, int] = None,
1176
+ negative_original_size: Optional[Tuple[int, int]] = None,
1177
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1178
+ negative_target_size: Optional[Tuple[int, int]] = None,
1179
+ aesthetic_score: float = 6.0,
1180
+ negative_aesthetic_score: float = 2.5,
1181
+ clip_skip: Optional[int] = None,
1182
+ callback_on_step_end: Optional[
1183
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1184
+ ] = None,
1185
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1186
+ **kwargs,
1187
+ ):
1188
+ r"""
1189
+ Function invoked when calling the pipeline for generation.
1190
+
1191
+ Args:
1192
+ prompt (`str` or `List[str]`, *optional*):
1193
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1194
+ instead.
1195
+ prompt_2 (`str` or `List[str]`, *optional*):
1196
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1197
+ used in both text-encoders
1198
+ image (`PIL.Image.Image`):
1199
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
1200
+ be masked out with `mask_image` and repainted according to `prompt`.
1201
+ mask_image (`PIL.Image.Image`):
1202
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1203
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1204
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1205
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
1206
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1207
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1208
+ Anything below 512 pixels won't work well for
1209
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1210
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1211
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1212
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1213
+ Anything below 512 pixels won't work well for
1214
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1215
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1216
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
1217
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1218
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1219
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
1220
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1221
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1222
+ the image is large and contain information irrelevant for inpainting, such as background.
1223
+ strength (`float`, *optional*, defaults to 0.9999):
1224
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1225
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1226
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1227
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1228
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1229
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
1230
+ integer, the value of `strength` will be ignored.
1231
+ num_inference_steps (`int`, *optional*, defaults to 50):
1232
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1233
+ expense of slower inference.
1234
+ timesteps (`List[int]`, *optional*):
1235
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1236
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1237
+ passed will be used. Must be in descending order.
1238
+ sigmas (`List[float]`, *optional*):
1239
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1240
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1241
+ will be used.
1242
+ denoising_start (`float`, *optional*):
1243
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1244
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1245
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1246
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1247
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1248
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1249
+ denoising_end (`float`, *optional*):
1250
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1251
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1252
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
1253
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
1254
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
1255
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1256
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1257
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1258
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1259
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1260
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1261
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1262
+ usually at the expense of lower image quality.
1263
+ negative_prompt (`str` or `List[str]`, *optional*):
1264
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1265
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1266
+ less than `1`).
1267
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1268
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1269
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1270
+ prompt_embeds (`torch.Tensor`, *optional*):
1271
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1272
+ provided, text embeddings will be generated from `prompt` input argument.
1273
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1274
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1275
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1276
+ argument.
1277
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1278
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1279
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1280
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1281
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1282
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1283
+ input argument.
1284
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1285
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1286
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1287
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1288
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1289
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1290
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1291
+ The number of images to generate per prompt.
1292
+ eta (`float`, *optional*, defaults to 0.0):
1293
+ Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1294
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1295
+ generator (`torch.Generator`, *optional*):
1296
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1297
+ to make generation deterministic.
1298
+ latents (`torch.Tensor`, *optional*):
1299
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1300
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1301
+ tensor will ge generated by sampling using the supplied random `generator`.
1302
+ output_type (`str`, *optional*, defaults to `"pil"`):
1303
+ The output format of the generate image. Choose between
1304
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1305
+ return_dict (`bool`, *optional*, defaults to `True`):
1306
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1307
+ plain tuple.
1308
+ cross_attention_kwargs (`dict`, *optional*):
1309
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1310
+ `self.processor` in
1311
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1312
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1313
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1314
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1315
+ explained in section 2.2 of
1316
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1317
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1318
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1319
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1320
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1321
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1322
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1323
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1324
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1325
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1326
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1327
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1328
+ micro-conditioning as explained in section 2.2 of
1329
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1330
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1331
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1332
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1333
+ micro-conditioning as explained in section 2.2 of
1334
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1335
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1336
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1337
+ To negatively condition the generation process based on a target image resolution. It should be as same
1338
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1339
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1340
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1341
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
1342
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1343
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1344
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1345
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1346
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1347
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1348
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
1349
+ clip_skip (`int`, *optional*):
1350
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1351
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1352
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1353
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1354
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1355
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1356
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1357
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1358
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1359
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1360
+ `._callback_tensor_inputs` attribute of your pipeline class.
1361
+
1362
+ Examples:
1363
+
1364
+ Returns:
1365
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
1366
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1367
+ `tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
1368
+ """
1369
+
1370
+ callback = kwargs.pop("callback", None)
1371
+ callback_steps = kwargs.pop("callback_steps", None)
1372
+
1373
+ if callback is not None:
1374
+ deprecate(
1375
+ "callback",
1376
+ "1.0.0",
1377
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1378
+ )
1379
+ if callback_steps is not None:
1380
+ deprecate(
1381
+ "callback_steps",
1382
+ "1.0.0",
1383
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1384
+ )
1385
+
1386
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1387
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1388
+
1389
+ # 0. Default height and width to unet
1390
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1391
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1392
+
1393
+ # 1. Check inputs
1394
+ self.check_inputs(
1395
+ prompt,
1396
+ prompt_2,
1397
+ image,
1398
+ mask_image,
1399
+ height,
1400
+ width,
1401
+ strength,
1402
+ callback_steps,
1403
+ output_type,
1404
+ negative_prompt,
1405
+ negative_prompt_2,
1406
+ prompt_embeds,
1407
+ negative_prompt_embeds,
1408
+ ip_adapter_image,
1409
+ ip_adapter_image_embeds,
1410
+ callback_on_step_end_tensor_inputs,
1411
+ padding_mask_crop,
1412
+ )
1413
+
1414
+ self._guidance_scale = guidance_scale
1415
+ self._guidance_rescale = guidance_rescale
1416
+ self._clip_skip = clip_skip
1417
+ self._cross_attention_kwargs = cross_attention_kwargs
1418
+ self._denoising_end = denoising_end
1419
+ self._denoising_start = denoising_start
1420
+ self._interrupt = False
1421
+
1422
+ # 2. Define call parameters
1423
+ if prompt is not None and isinstance(prompt, str):
1424
+ batch_size = 1
1425
+ elif prompt is not None and isinstance(prompt, list):
1426
+ batch_size = len(prompt)
1427
+ else:
1428
+ batch_size = prompt_embeds.shape[0]
1429
+
1430
+ device = self._execution_device
1431
+
1432
+ # 3. Encode input prompt
1433
+ text_encoder_lora_scale = (
1434
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1435
+ )
1436
+
1437
+ (
1438
+ prompt_embeds,
1439
+ negative_prompt_embeds,
1440
+ pooled_prompt_embeds,
1441
+ negative_pooled_prompt_embeds,
1442
+ ) = self.encode_prompt(
1443
+ prompt=prompt,
1444
+ device=device,
1445
+ num_images_per_prompt=num_images_per_prompt,
1446
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1447
+ negative_prompt=negative_prompt,
1448
+ prompt_embeds=prompt_embeds,
1449
+ negative_prompt_embeds=negative_prompt_embeds,
1450
+ pooled_prompt_embeds=pooled_prompt_embeds,
1451
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1452
+ lora_scale=text_encoder_lora_scale,
1453
+ )
1454
+
1455
+ # 4. set timesteps
1456
+ def denoising_value_valid(dnv):
1457
+ return isinstance(dnv, float) and 0 < dnv < 1
1458
+
1459
+ timesteps, num_inference_steps = retrieve_timesteps(
1460
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1461
+ )
1462
+ timesteps, num_inference_steps = self.get_timesteps(
1463
+ num_inference_steps,
1464
+ strength,
1465
+ device,
1466
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
1467
+ )
1468
+ # check that number of inference steps is not < 1 - as this doesn't make sense
1469
+ if num_inference_steps < 1:
1470
+ raise ValueError(
1471
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1472
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1473
+ )
1474
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1475
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1476
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
1477
+ is_strength_max = strength == 1.0
1478
+
1479
+ # 5. Preprocess mask and image
1480
+ if padding_mask_crop is not None:
1481
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1482
+ resize_mode = "fill"
1483
+ else:
1484
+ crops_coords = None
1485
+ resize_mode = "default"
1486
+
1487
+ original_image = image
1488
+ init_image = self.image_processor.preprocess(
1489
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1490
+ )
1491
+ init_image = init_image.to(dtype=torch.float32)
1492
+
1493
+ mask = self.mask_processor.preprocess(
1494
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1495
+ )
1496
+
1497
+ if masked_image_latents is not None:
1498
+ masked_image = masked_image_latents
1499
+ elif init_image.shape[1] == 4:
1500
+ # if images are in latent space, we can't mask it
1501
+ masked_image = None
1502
+ else:
1503
+ masked_image = init_image * (mask < 0.5)
1504
+
1505
+ # 6. Prepare latent variables
1506
+ num_channels_latents = self.vae.config.latent_channels
1507
+ num_channels_unet = self.unet.config.in_channels
1508
+ return_image_latents = num_channels_unet == 4
1509
+
1510
+ add_noise = True if self.denoising_start is None else False
1511
+ latents_outputs = self.prepare_latents(
1512
+ batch_size * num_images_per_prompt,
1513
+ num_channels_latents,
1514
+ height,
1515
+ width,
1516
+ prompt_embeds.dtype,
1517
+ device,
1518
+ generator,
1519
+ latents,
1520
+ image=init_image,
1521
+ timestep=latent_timestep,
1522
+ is_strength_max=is_strength_max,
1523
+ add_noise=add_noise,
1524
+ return_noise=True,
1525
+ return_image_latents=return_image_latents,
1526
+ )
1527
+
1528
+ if return_image_latents:
1529
+ latents, noise, image_latents = latents_outputs
1530
+ else:
1531
+ latents, noise = latents_outputs
1532
+
1533
+ # 7. Prepare mask latent variables
1534
+ mask, masked_image_latents = self.prepare_mask_latents(
1535
+ mask,
1536
+ masked_image,
1537
+ batch_size * num_images_per_prompt,
1538
+ height,
1539
+ width,
1540
+ prompt_embeds.dtype,
1541
+ device,
1542
+ generator,
1543
+ self.do_classifier_free_guidance,
1544
+ )
1545
+
1546
+ # 8. Check that sizes of mask, masked image and latents match
1547
+ if num_channels_unet == 9:
1548
+ # default case for runwayml/stable-diffusion-inpainting
1549
+ num_channels_mask = mask.shape[1]
1550
+ num_channels_masked_image = masked_image_latents.shape[1]
1551
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1552
+ raise ValueError(
1553
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1554
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1555
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1556
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1557
+ " `pipeline.unet` or your `mask_image` or `image` input."
1558
+ )
1559
+ elif num_channels_unet != 4:
1560
+ raise ValueError(
1561
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1562
+ )
1563
+ # 8.1 Prepare extra step kwargs.
1564
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1565
+
1566
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1567
+ height, width = latents.shape[-2:]
1568
+ height = height * self.vae_scale_factor
1569
+ width = width * self.vae_scale_factor
1570
+
1571
+ original_size = original_size or (height, width)
1572
+ target_size = target_size or (height, width)
1573
+
1574
+ # 10. Prepare added time ids & embeddings
1575
+ if negative_original_size is None:
1576
+ negative_original_size = original_size
1577
+ if negative_target_size is None:
1578
+ negative_target_size = target_size
1579
+
1580
+ add_text_embeds = pooled_prompt_embeds
1581
+ if self.text_encoder_2 is None:
1582
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1583
+ else:
1584
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1585
+
1586
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1587
+ original_size,
1588
+ crops_coords_top_left,
1589
+ target_size,
1590
+ aesthetic_score,
1591
+ negative_aesthetic_score,
1592
+ negative_original_size,
1593
+ negative_crops_coords_top_left,
1594
+ negative_target_size,
1595
+ dtype=prompt_embeds.dtype,
1596
+ text_encoder_projection_dim=text_encoder_projection_dim,
1597
+ )
1598
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1599
+
1600
+ if self.do_classifier_free_guidance:
1601
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1602
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1603
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1604
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1605
+
1606
+ prompt_embeds = prompt_embeds.to(device)
1607
+ add_text_embeds = add_text_embeds.to(device)
1608
+ add_time_ids = add_time_ids.to(device)
1609
+
1610
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1611
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1612
+ ip_adapter_image,
1613
+ ip_adapter_image_embeds,
1614
+ device,
1615
+ batch_size * num_images_per_prompt,
1616
+ self.do_classifier_free_guidance,
1617
+ )
1618
+
1619
+
1620
+ # 11. Denoising loop
1621
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1622
+
1623
+ if (
1624
+ self.denoising_end is not None
1625
+ and self.denoising_start is not None
1626
+ and denoising_value_valid(self.denoising_end)
1627
+ and denoising_value_valid(self.denoising_start)
1628
+ and self.denoising_start >= self.denoising_end
1629
+ ):
1630
+ raise ValueError(
1631
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
1632
+ + f" {self.denoising_end} when using type float."
1633
+ )
1634
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
1635
+ discrete_timestep_cutoff = int(
1636
+ round(
1637
+ self.scheduler.config.num_train_timesteps
1638
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1639
+ )
1640
+ )
1641
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1642
+ timesteps = timesteps[:num_inference_steps]
1643
+
1644
+ # 11.1 Optionally get Guidance Scale Embedding
1645
+ timestep_cond = None
1646
+ if self.unet.config.time_cond_proj_dim is not None:
1647
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1648
+ timestep_cond = self.get_guidance_scale_embedding(
1649
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1650
+ ).to(device=device, dtype=latents.dtype)
1651
+
1652
+ self._num_timesteps = len(timesteps)
1653
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1654
+ for i, t in enumerate(timesteps):
1655
+ if self.interrupt:
1656
+ continue
1657
+ # expand the latents if we are doing classifier free guidance
1658
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1659
+
1660
+ # concat latents, mask, masked_image_latents in the channel dimension
1661
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1662
+
1663
+ if num_channels_unet == 9:
1664
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1665
+
1666
+ # predict the noise residual
1667
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1668
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1669
+ added_cond_kwargs["image_embeds"] = image_embeds
1670
+ noise_pred = self.unet(
1671
+ latent_model_input,
1672
+ t,
1673
+ encoder_hidden_states=prompt_embeds,
1674
+ timestep_cond=timestep_cond,
1675
+ cross_attention_kwargs=self.cross_attention_kwargs,
1676
+ added_cond_kwargs=added_cond_kwargs,
1677
+ return_dict=False,
1678
+ )[0]
1679
+
1680
+ # perform guidance
1681
+ if self.do_classifier_free_guidance:
1682
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1683
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1684
+
1685
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1686
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1687
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1688
+
1689
+ # compute the previous noisy sample x_t -> x_t-1
1690
+ latents_dtype = latents.dtype
1691
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1692
+ if latents.dtype != latents_dtype:
1693
+ if torch.backends.mps.is_available():
1694
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1695
+ latents = latents.to(latents_dtype)
1696
+
1697
+ if num_channels_unet == 4:
1698
+ init_latents_proper = image_latents
1699
+ if self.do_classifier_free_guidance:
1700
+ init_mask, _ = mask.chunk(2)
1701
+ else:
1702
+ init_mask = mask
1703
+
1704
+ if i < len(timesteps) - 1:
1705
+ noise_timestep = timesteps[i + 1]
1706
+ init_latents_proper = self.scheduler.add_noise(
1707
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1708
+ )
1709
+
1710
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1711
+
1712
+ if callback_on_step_end is not None:
1713
+ callback_kwargs = {}
1714
+ for k in callback_on_step_end_tensor_inputs:
1715
+ callback_kwargs[k] = locals()[k]
1716
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1717
+
1718
+ latents = callback_outputs.pop("latents", latents)
1719
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1720
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1721
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1722
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1723
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1724
+ )
1725
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1726
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1727
+ mask = callback_outputs.pop("mask", mask)
1728
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
1729
+
1730
+ # call the callback, if provided
1731
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1732
+ progress_bar.update()
1733
+ if callback is not None and i % callback_steps == 0:
1734
+ step_idx = i // getattr(self.scheduler, "order", 1)
1735
+ callback(step_idx, t, latents)
1736
+
1737
+ if XLA_AVAILABLE:
1738
+ xm.mark_step()
1739
+
1740
+ if not output_type == "latent":
1741
+ # make sure the VAE is in float32 mode, as it overflows in float16
1742
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1743
+
1744
+ if needs_upcasting:
1745
+ self.upcast_vae()
1746
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1747
+ elif latents.dtype != self.vae.dtype:
1748
+ if torch.backends.mps.is_available():
1749
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1750
+ self.vae = self.vae.to(latents.dtype)
1751
+
1752
+ # unscale/denormalize the latents
1753
+ # denormalize with the mean and std if available and not None
1754
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1755
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1756
+ if has_latents_mean and has_latents_std:
1757
+ latents_mean = (
1758
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1759
+ )
1760
+ latents_std = (
1761
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1762
+ )
1763
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1764
+ else:
1765
+ latents = latents / self.vae.config.scaling_factor
1766
+
1767
+ image = self.vae.decode(latents, return_dict=False)[0]
1768
+
1769
+ # cast back to fp16 if needed
1770
+ if needs_upcasting:
1771
+ self.vae.to(dtype=torch.float16)
1772
+ else:
1773
+ return StableDiffusionXLPipelineOutput(images=latents)
1774
+
1775
+ # apply watermark if available
1776
+ if self.watermark is not None:
1777
+ image = self.watermark.apply_watermark(image)
1778
+
1779
+ image = self.image_processor.postprocess(image, output_type=output_type)
1780
+
1781
+ if padding_mask_crop is not None:
1782
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1783
+
1784
+ # Offload all models
1785
+ self.maybe_free_model_hooks()
1786
+
1787
+ if not return_dict:
1788
+ return (image,)
1789
+
1790
+ return StableDiffusionXLPipelineOutput(images=image)