ductai199x committed (verified)
Commit 3afc463 · 1 Parent(s): db390dc

Upload model

Files changed (4)
  1. config.json +17 -0
  2. configuration.py +72 -0
  3. model.safetensors +3 -0
  4. modeling_fsg.py +510 -0
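The sketch below is not part of the commit; it shows one way to pull these four files locally with huggingface_hub. The repository id is an assumption taken from the repo_url declared later in modeling_fsg.py and may differ from the repo this commit actually belongs to.

```python
# Hypothetical download sketch -- repo id assumed from modeling_fsg.py's repo_url.
from huggingface_hub import snapshot_download

local_dir = snapshot_download("ductai199x/forensic-similarity-graph")
# local_dir should now contain config.json, configuration.py, model.safetensors, modeling_fsg.py
print(local_dir)
```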
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "comparenet_config": {
+     "hidden_dim": 2048,
+     "output_dim": 64
+   },
+   "fast_sim_mode": true,
+   "fe_config": {
+     "is_constrained": false,
+     "num_classes": 0,
+     "num_filters": 6,
+     "patch_size": 128,
+     "variant": "p128"
+   },
+   "loc_threshold": 0.3,
+   "need_input_255": true,
+   "stride_ratio": 0.5
+ }
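For orientation (not part of the upload): modeling_fsg.py below turns stride_ratio into a pixel stride relative to patch_size, so this config scans 128×128 patches with a 64-pixel step (50% overlap), and loc_threshold = 0.3 is the cutoff applied to the smoothed localization map. A minimal sketch of that arithmetic:

```python
# Reproduces the stride computation from FsgModel.__init__ (see modeling_fsg.py below).
patch_size = 128          # fe_config.patch_size
stride_ratio = 0.5        # stride_ratio
stride = int(patch_size * stride_ratio)
print(stride)  # 64 -> adjacent patches overlap by half a patch
```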
configuration.py ADDED
@@ -0,0 +1,72 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class FeConfig:
+     def __init__(
+         self,
+         patch_size: int = 128,
+         variant: str = "p128",
+         num_classes: int = 0,
+         num_filters: int = 6,
+         is_constrained: bool = False,
+     ):
+         self.patch_size = patch_size
+         self.variant = variant
+         self.num_classes = num_classes
+         self.num_filters = num_filters
+         self.is_constrained = is_constrained
+
+     def to_dict(self):
+         return {
+             "patch_size": self.patch_size,
+             "variant": self.variant,
+             "num_classes": self.num_classes,
+             "num_filters": self.num_filters,
+             "is_constrained": self.is_constrained,
+         }
+
+
+ class CompareNetConfig:
+     def __init__(
+         self,
+         hidden_dim: int = 2048,
+         output_dim: int = 64,
+     ):
+         self.hidden_dim = hidden_dim
+         self.output_dim = output_dim
+
+     def to_dict(self):
+         return {
+             "hidden_dim": self.hidden_dim,
+             "output_dim": self.output_dim,
+         }
+
+
+ class FsgConfig(PretrainedConfig):
+     def __init__(
+         self,
+         fe_config=None,
+         comparenet_config=None,
+         fast_sim_mode: bool = True,
+         loc_threshold: float = 0.3,
+         stride_ratio: float = 0.5,
+         need_input_255: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.fe_config = FeConfig() if fe_config is None else FeConfig(**fe_config)
+         self.comparenet_config = CompareNetConfig() if comparenet_config is None else CompareNetConfig(**comparenet_config)
+         self.fast_sim_mode = fast_sim_mode
+         self.loc_threshold = loc_threshold
+         self.stride_ratio = stride_ratio
+         self.need_input_255 = need_input_255
+
+     def to_dict(self):
+         return {
+             "fe_config": self.fe_config.to_dict(),
+             "comparenet_config": self.comparenet_config.to_dict(),
+             "fast_sim_mode": self.fast_sim_mode,
+             "loc_threshold": self.loc_threshold,
+             "stride_ratio": self.stride_ratio,
+             "need_input_255": self.need_input_255,
+         }
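A minimal sketch (not part of the upload) showing that the uploaded config.json holds exactly the keyword arguments FsgConfig expects, assuming configuration.py and config.json sit in the working directory:

```python
import json

from configuration import FsgConfig  # the file added above

with open("config.json") as f:
    cfg_dict = json.load(f)

config = FsgConfig(**cfg_dict)
assert config.fe_config.patch_size == 128
assert config.comparenet_config.output_dim == 64

# to_dict() reproduces the nested structure that was serialized to config.json.
print(config.to_dict()["fe_config"]["variant"])  # "p128"
```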
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b281e2f076cf288875a51cee670c58a080a5695453a42e297bb17528c1ed99e
+ size 4974180
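The entry above is a Git LFS pointer; the ~5 MB of weights live in LFS storage. A hedged inspection sketch using the safetensors library, assuming the real file has been fetched (e.g. via git lfs pull or snapshot_download):

```python
from safetensors import safe_open

# List the tensors in the checkpoint without loading the whole file into memory.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
```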
modeling_fsg.py ADDED
@@ -0,0 +1,510 @@
+ import math
+ import re
+ import warnings
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.distributions import Normal
+ from transformers import PreTrainedModel
+ from huggingface_hub import PyTorchModelHubMixin
+ from numba import jit
+ from .configuration import FsgConfig
+ from typing import Literal, Type, Union, List
+
+
+ def batch_fn(iterable, n=1):
+     l = len(iterable)
+     for ndx in range(0, l, n):
+         yield iterable[ndx : min(ndx + n, l)]
+
+
+ def gaussian_kernel_1d(sigma: float, num_sigmas: float = 3.0) -> torch.Tensor:
+     radius = math.ceil(num_sigmas * sigma)
+     support = torch.arange(-radius, radius + 1, dtype=torch.float)
+     kernel = Normal(loc=0, scale=sigma).log_prob(support).exp_()
+     # Ensure kernel weights sum to 1, so that image brightness is not altered
+     return kernel.mul_(1 / kernel.sum())
+
+
+ def gaussian_filter_2d(img: torch.Tensor, sigma: float) -> torch.Tensor:
+     kernel_1d = gaussian_kernel_1d(sigma).to(img.device)  # Create 1D Gaussian kernel
+     padding = len(kernel_1d) // 2  # Ensure that image size does not change
+     img = img[None, None, ...]  # Need 4D data for ``conv2d()``
+     # Convolve along columns and rows
+     img = F.conv2d(img, weight=kernel_1d.view(1, 1, -1, 1), padding=(padding, 0))
+     img = F.conv2d(img, weight=kernel_1d.view(1, 1, 1, -1), padding=(0, padding))
+     return img.squeeze()  # Make 2D again
+
+
+ class BaseModel(nn.Module):
+     def __init__(
+         self,
+         patch_size: int,
+         num_classes: int = 0,
+         **kwargs,
+     ):
+         super().__init__()
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+
+ class ConstrainedConv(nn.Module):
+     def __init__(self, input_chan=3, num_filters=6, is_constrained=True):
+         super().__init__()
+         self.kernel_size = 5
+         self.input_chan = input_chan
+         self.num_filters = num_filters
+         self.is_constrained = is_constrained
+         weight = torch.empty(num_filters, input_chan, self.kernel_size, self.kernel_size)
+         nn.init.xavier_normal_(weight, gain=1/3)
+         self.weight = nn.Parameter(weight, requires_grad=True)
+         self.one_middle = torch.zeros(self.kernel_size * self.kernel_size)
+         self.one_middle[12] = 1
+         self.one_middle = nn.Parameter(self.one_middle, requires_grad=False)
+
+     def forward(self, x):
+         w = self.weight
+         if self.is_constrained:
+             w = w.view(-1, self.kernel_size * self.kernel_size)
+             w = w - w.mean(1)[..., None] + 1 / (self.kernel_size * self.kernel_size - 1)
+             w = w - (w + 1) * self.one_middle
+             w = w.view(self.num_filters, self.input_chan, self.kernel_size, self.kernel_size)
+         x = nn.functional.conv2d(x, w, padding="valid")
+         x = nn.functional.pad(x, (2, 3, 2, 3))
+         return x
+
+
+ class ConvBlock(torch.nn.Module):
+     def __init__(
+         self,
+         in_chans,
+         out_chans,
+         kernel_size,
+         stride,
+         padding,
+         activation: Literal["tanh", "relu"],
+     ):
+         super().__init__()
+         assert activation.lower() in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
+         self.conv = torch.nn.Conv2d(
+             in_chans,
+             out_chans,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+         )
+         self.bn = torch.nn.BatchNorm2d(out_chans)
+         self.act = torch.nn.Tanh() if activation.lower() == "tanh" else torch.nn.ReLU()
+         self.maxpool = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)
+
+     def forward(self, x):
+         return self.maxpool(self.act(self.bn(self.conv(x))))
+
+
+ class DenseBlock(torch.nn.Module):
+     def __init__(
+         self,
+         in_chans,
+         out_chans,
+         activation: Literal["tanh", "relu"],
+     ):
+         super().__init__()
+         assert activation.lower() in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
+         self.fc = torch.nn.Linear(in_chans, out_chans)
+         self.act = torch.nn.Tanh() if activation.lower() == "tanh" else torch.nn.ReLU()
+
+     def forward(self, x):
+         return self.act(self.fc(x))
+
+
+ class MISLNet(BaseModel):
+     arch = {
+         "p256": [
+             ("conv1", -1, 96, 7, 2, "valid", "tanh"),
+             ("conv2", 96, 64, 5, 1, "same", "tanh"),
+             ("conv3", 64, 64, 5, 1, "same", "tanh"),
+             ("conv4", 64, 128, 1, 1, "same", "tanh"),
+             ("fc1", 6 * 6 * 128, 200, "tanh"),
+             ("fc2", 200, 200, "tanh"),
+         ],
+         "p256_3fc_256e": [
+             ("conv1", -1, 96, 7, 2, "valid", "tanh"),
+             ("conv2", 96, 64, 5, 1, "same", "tanh"),
+             ("conv3", 64, 64, 5, 1, "same", "tanh"),
+             ("conv4", 64, 128, 1, 1, "same", "tanh"),
+             ("fc1", 6 * 6 * 128, 1024, "tanh"),
+             ("fc2", 1024, 512, "tanh"),
+             ("fc3", 512, 256, "tanh"),
+         ],
+         "p128": [
+             ("conv1", -1, 96, 7, 2, "valid", "tanh"),
+             ("conv2", 96, 64, 5, 1, "same", "tanh"),
+             ("conv3", 64, 64, 5, 1, "same", "tanh"),
+             ("conv4", 64, 128, 1, 1, "same", "tanh"),
+             ("fc1", 2 * 2 * 128, 200, "tanh"),
+             ("fc2", 200, 200, "tanh"),
+         ],
+         "p96": [
+             ("conv1", -1, 96, 7, 2, "valid", "tanh"),
+             ("conv2", 96, 64, 5, 1, "same", "tanh"),
+             ("conv3", 64, 64, 5, 1, "same", "tanh"),
+             ("conv4", 64, 128, 1, 1, "same", "tanh"),
+             ("fc1", 8 * 4 * 64, 200, "tanh"),
+             ("fc2", 200, 200, "tanh"),
+         ],
+         "p64": [
+             ("conv1", -1, 96, 7, 2, "valid", "tanh"),
+             ("conv2", 96, 64, 5, 1, "same", "tanh"),
+             ("conv3", 64, 64, 5, 1, "same", "tanh"),
+             ("conv4", 64, 128, 1, 1, "same", "tanh"),
+             ("fc1", 2 * 4 * 64, 200, "tanh"),
+             ("fc2", 200, 200, "tanh"),
+         ],
+     }
+
+     def __init__(
+         self,
+         patch_size: int,
+         variant: str,
+         num_classes=0,
+         num_filters=6,
+         is_constrained=True,
+         **kwargs,
+     ):
+         super().__init__(patch_size, num_classes)
+         self.variant = variant
+         self.chosen_arch = self.arch[variant]
+         self.num_filters = num_filters
+
+         self.constrained_conv = ConstrainedConv(num_filters=num_filters, is_constrained=is_constrained)
+
+         self.conv_blocks = []
+         self.fc_blocks = []
+         for block in self.chosen_arch:
+             if block[0].startswith("conv"):
+                 self.conv_blocks.append(
+                     ConvBlock(
+                         in_chans=(num_filters if block[1] == -1 else block[1]),
+                         out_chans=block[2],
+                         kernel_size=block[3],
+                         stride=block[4],
+                         padding=block[5],
+                         activation=block[6],
+                     )
+                 )
+             elif block[0].startswith("fc"):
+                 self.fc_blocks.append(
+                     DenseBlock(
+                         in_chans=block[1],
+                         out_chans=block[2],
+                         activation=block[3],
+                     )
+                 )
+
+         self.conv_blocks = nn.Sequential(*self.conv_blocks)
+         self.fc_blocks = nn.Sequential(*self.fc_blocks)
+
+         self.register_buffer("flatten_index_permutation", torch.tensor([0, 1, 2, 3], dtype=torch.long))
+
+         if self.num_classes > 0:
+             self.output = nn.Linear(self.chosen_arch[-1][2], self.num_classes)
+
+     def forward(self, x):
+         x = self.constrained_conv(x)
+         x = self.conv_blocks(x)
+         x = x.permute(*self.flatten_index_permutation)
+         x = x.flatten(1, -1)
+         x = self.fc_blocks(x)
+         if self.num_classes > 0:
+             x = self.output(x)
+         return x
+
+     def load_state_dict(self, state_dict, strict=True, assign=False):
+         if "flatten_index_permutation" not in state_dict:
+             super().load_state_dict(state_dict, False, assign)
+         else:
+             super().load_state_dict(state_dict, strict, assign)
+
+
+ class CompareNet(nn.Module):
+     def __init__(self, input_dim, hidden_dim=2048, output_dim=64):
+         super().__init__()
+         self.fc1 = DenseBlock(input_dim, hidden_dim, "relu")
+         self.fc2 = DenseBlock(hidden_dim * 3, output_dim, "relu")
+         self.fc3 = nn.Linear(output_dim, 2)
+
+     def forward(self, x1, x2):
+         x1 = self.fc1(x1)
+         x2 = self.fc1(x2)
+         x = torch.cat((x1, x1 * x2, x2), dim=1)
+         x = self.fc2(x)
+         x = self.fc3(x)
+         return x
+
+
+ class FSM(nn.Module):
+     """
+     FSM (Forensic Similarity Metric) is a neural network module that computes the similarity between two input images using a feature extraction module and a comparison network module.
+
+     Args:
+         fe_config (dict): Configuration for the feature extraction module.
+         comparenet_config (dict): Configuration for the comparison network module.
+         fe_ckpt (str): Path to the checkpoint file for the feature extraction module.
+         **kwargs: Additional keyword arguments.
+     """
+
+     def __init__(
+         self,
+         fe_config,
+         comparenet_config,
+         fe_ckpt=None,
+         **kwargs,
+     ):
+         super().__init__()
+         fe_config["num_classes"] = 0  # to make fe without final classification layer
+         self.fe: MISLNet = self.load_module_from_ckpt(MISLNet, fe_ckpt, "", **fe_config)
+         self.patch_size = self.fe.patch_size
+         comparenet_config["input_dim"] = self.fe.fc_blocks[-1].fc.out_features
+         self.comparenet = CompareNet(**comparenet_config)
+         self.freeze_fe = True  # keep the feature extractor frozen during forward passes (read by forward_fe)
+
+     def load_module_state_dict(self, module: nn.Module, state_dict, module_name=""):
+         curr_model_state_dict = module.state_dict()
+         curr_model_keys_status = {k: False for k in curr_model_state_dict.keys()}
+         outstanding_keys = []
+         for ckpt_layer_name, ckpt_layer_weights in state_dict.items():
+             if module_name not in ckpt_layer_name:
+                 continue
+             ckpt_matches = re.findall(r"(?=(?:^|\.)((?:\w+\.)*\w+)$)", ckpt_layer_name)[::-1]
+             model_layer_name_match = list(set(ckpt_matches).intersection(set(curr_model_state_dict.keys())))
+             # print(ckpt_layer_name, model_layer_name_match)
+             if len(model_layer_name_match) == 0:
+                 outstanding_keys.append(ckpt_layer_name)
+             else:
+                 model_layer_name = model_layer_name_match[0]
+                 assert (
+                     curr_model_state_dict[model_layer_name].shape == ckpt_layer_weights.shape
+                 ), f"Ckpt layer '{ckpt_layer_name}' shape {ckpt_layer_weights.shape} does not match model layer '{model_layer_name}' shape {curr_model_state_dict[model_layer_name].shape}"
+                 curr_model_state_dict[model_layer_name] = ckpt_layer_weights
+                 curr_model_keys_status[model_layer_name] = True
+
+         if all(curr_model_keys_status.values()):
+             print(f"Success! All necessary keys for module '{module.__class__.__name__}' are loaded!")
+         else:
+             not_loaded_keys = [k for k, v in curr_model_keys_status.items() if not v]
+             print(f"Warning! Some keys are not loaded! Not loaded keys are:\n{not_loaded_keys}")
+             if len(outstanding_keys) > 0:
+                 print(f"Outstanding keys are: {outstanding_keys}")
+         module.load_state_dict(curr_model_state_dict, strict=False)
+
+     def load_module_from_ckpt(
+         self,
+         module_class: Type[nn.Module],
+         ckpt_path: Union[None, str],
+         module_name: str,
+         *args,
+         **kwargs,
+     ) -> nn.Module:
+         module = module_class(*args, **kwargs)
+
+         if ckpt_path is not None:
+             ckpt = torch.load(ckpt_path, map_location="cpu")
+             ckpt_state_dict = ckpt["state_dict"]
+             self.load_module_state_dict(module, ckpt_state_dict, module_name=module_name)
+         return module
+
+     def load_state_dict(self, state_dict, strict=True, assign=False):
+         try:
+             super().load_state_dict(state_dict, strict=strict, assign=assign)
+         except Exception as e:
+             print(f"Error loading state dict using normal method: {e}")
+             print("Trying to load state dict manually...")
+             # self.load_module_state_dict(self.fe, state_dict, module_name="fe")
+             # self.load_module_state_dict(self.comparenet, state_dict, module_name="comparenet")
+             self.load_module_state_dict(self, state_dict, module_name="")
+             print("State dict loaded successfully!")
+
+     def forward_fe(self, x):
+         if self.freeze_fe:
+             self.fe.eval()
+             with torch.no_grad():
+                 return self.fe(x)
+         else:
+             self.fe.train()
+             return self.fe(x)
+
+     def forward(self, x1, x2):
+         x1 = self.forward_fe(x1)
+         x2 = self.forward_fe(x2)
+         return self.comparenet(x1, x2)
+
+
+ class FsgModel(
+     PreTrainedModel,
+     PyTorchModelHubMixin,
+     repo_url="ductai199x/forensic-similarity-graph",
+     pipeline_tag="image-manipulation-detection-localization",
+     license="cc-by-nc-nd-4.0",
+ ):
+     """
+     Forensic Similarity Graph (FSG) algorithm.
+     https://ieeexplore.ieee.org/abstract/document/9113265
+
+     This class is designed to create a graph-based representation of forensic similarity between different patches of an image, allowing for the detection of manipulated regions.
+
+     Args:
+         stride_ratio (float): The ratio of the stride to the patch size, determining the overlap between patches. The lower the value, the higher the overlap.
+         fast_sim_mode (bool): If True, the algorithm uses a faster method to compute similarity scores, potentially at the cost of accuracy.
+         loc_threshold (float): The threshold for determining the location of interest in the similarity graph. Values above this threshold are considered significant.
+         is_high_sim (bool): If True, higher similarity scores indicate higher similarity. If False, lower scores indicate higher similarity.
+         need_input_255 (bool): If True, input images are expected to be scaled to [0, 255]. If False, images are expected to be in [0, 1].
+         **kwargs: Additional keyword arguments passed to the superclass initializer.
+
+     Example Usage:
+         ```python
+         import torch
+         import matplotlib.pyplot as plt
+         from torchvision.io import read_image, ImageReadMode
+         from model import FSG
+
+         ckpt_path = "path/to/ckpt.pth"
+         model = FSG.load_from_checkpoint(ckpt_path, map_location="cpu", stride_ratio=0.5, fast_sim_mode=False, loc_threshold=0.37, is_high_sim=False, need_input_255=False)
+         model.eval()
+
+         img_path = "path/to/image.jpg"
+         image = read_image(img_path, mode=ImageReadMode.RGB).float() / 255
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         with torch.no_grad():
+             img_preds, loc_preds = model(image[None, ...].to(device))
+
+         plt.imshow(loc_preds.cpu()[0])
+         plt.colorbar()
+         plt.show()
+         ```
+     """
+
+     def __init__(self, config: FsgConfig, **kwargs):
+         super().__init__(config)
+         self.patch_size = config.fe_config.patch_size
+         self.stride = int(self.patch_size * config.stride_ratio)
+         self.fast_sim_mode = config.fast_sim_mode
+         self.loc_threshold = config.loc_threshold
+         self.is_high_sim = True
+         self.need_input_255 = config.need_input_255
+         self.model = FSM(fe_config=config.fe_config.to_dict(), comparenet_config=config.comparenet_config.to_dict())
+
+         warnings.filterwarnings("ignore")
+
+     def get_batched_patches(self, x: torch.Tensor):
+         B, C, H, W = x.shape
+         # split images into batches of patches: B x C x H x W -> B x (NumPatchHeight x NumPatchWidth) x C x PatchSize x PatchSize
+         batched_patches = (
+             x.unfold(2, self.patch_size, self.stride)
+             .unfold(3, self.patch_size, self.stride)
+             .permute(0, 2, 3, 1, 4, 5)
+         )
+         batched_patches = batched_patches.contiguous().view(B, -1, C, self.patch_size, self.patch_size)
+         return batched_patches
+
+     def get_patches_single(self, x: torch.Tensor):
+         C, H, W = x.shape
+         patches = (
+             x.unfold(1, self.patch_size, self.stride)
+             .unfold(2, self.patch_size, self.stride)
+             .permute(1, 2, 0, 3, 4)
+         )
+         patches = patches.contiguous().view(-1, C, self.patch_size, self.patch_size)
+         return patches
+
+     @jit(forceobj=True)
+     def get_features(self, image_patches: torch.Tensor):
+         patches_features = []
+         for batch in list(batch_fn(image_patches, 256)):
+             batch = batch.float()
+             feats = self.model.fe(batch).detach()
+             patches_features.append(feats)
+         patches_features = torch.vstack(patches_features)
+         return patches_features
+
+     @jit(forceobj=True)
+     def get_sim_scores(self, patch_pairs):
+         patches_sim_scores = []
+         for batch in list(batch_fn(patch_pairs, 4096)):
+             batch = batch.permute(1, 0, 2).float()
+             scores = self.model.comparenet(*batch).detach()
+             scores = torch.nn.functional.softmax(scores, dim=1)
+             patches_sim_scores.append(scores)
+         patches_sim_scores = torch.vstack(patches_sim_scores)
+         return patches_sim_scores
+
+     def forward_single(self, patches: torch.Tensor):
+         P, C, H, W = patches.shape
+         features = self.get_features(patches)
+         sim_mat = torch.zeros(P, P, device=patches.device)
+         if self.fast_sim_mode:
+             upper_tri_idx = torch.triu_indices(P, P, 1).T
+             patch_pairs = features[upper_tri_idx]
+         else:
+             patch_cart_prod = torch.cartesian_prod(torch.arange(P), torch.arange(P))
+             patch_pairs = features[patch_cart_prod]
+         sim_scores = self.get_sim_scores(patch_pairs).detach()
+         if self.fast_sim_mode:
+             sim_mat[upper_tri_idx[:, 0], upper_tri_idx[:, 1]] = sim_scores[:, 1]
+             sim_mat += sim_mat.clone().T
+         else:
+             sim_mat = sim_scores[:, 1].view(P, P)
+             sim_mat = 0.5 * (sim_mat + sim_mat.T)
+         if not self.is_high_sim:
+             sim_mat = 1 - sim_mat
+         sim_mat.fill_diagonal_(0.0)
+         degree_mat = torch.diag(sim_mat.sum(axis=1))
+         laplacian_mat = degree_mat - sim_mat
+         degree_sym_mat = torch.diag(sim_mat.sum(axis=1) ** -0.5)
+         laplacian_sym_mat = (degree_sym_mat @ laplacian_mat) @ degree_sym_mat
+         eigvals, eigvecs = torch.linalg.eigh(laplacian_sym_mat.cpu())
+         spectral_gap = eigvals[1] - eigvals[0]
+         img_pred = 1 - spectral_gap
+         eigvec = eigvecs[:, 1]
+         patch_pred = (eigvec > 0).int()
+         return img_pred.detach(), patch_pred.detach()
+
+     def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]):
+         if isinstance(x, torch.Tensor) and len(x.shape) == 3:
+             x = [x]
+
+         img_preds = []
+         loc_preds = []
+         for img in x:
+             C, H, W = img.shape
+             if self.need_input_255 and img.max() <= 1:
+                 img = img * 255
+             # get the (x, y) coordinates of the top left of each patch in the image
+             x_inds = torch.arange(W).unfold(0, self.patch_size, self.stride)[:, 0]
+             y_inds = torch.arange(H).unfold(0, self.patch_size, self.stride)[:, 0]
+             xy_inds = torch.tensor([(ii, jj) for jj in y_inds for ii in x_inds]).to(img.device)
+
+             patches = self.get_patches_single(img)
+             img_pred, patch_pred = self.forward_single(patches)
+             loc_pred = self.patch_to_pixel_pred(patch_pred, xy_inds)
+             loc_pred = F.interpolate(loc_pred[None, None, ...], size=(H, W), mode="nearest").squeeze()
+             img_preds.append(img_pred)
+             loc_preds.append(loc_pred)
+         return img_preds, loc_preds
+
+     def patch_to_pixel_pred(self, patch_pred, xy_inds):
+         W, H = torch.max(xy_inds, dim=0).values + self.patch_size
+         pixel_pred = torch.zeros((H, W)).to(patch_pred.device)
+         coverage_map = torch.zeros((H, W)).to(patch_pred.device)
+         for (x, y), pred in zip(xy_inds, patch_pred):
+             pixel_pred[y : y + self.patch_size, x : x + self.patch_size] += pred
+             coverage_map[y : y + self.patch_size, x : x + self.patch_size] += 1
+         # perform gaussian smoothing
+         pixel_pred = gaussian_filter_2d(pixel_pred, sigma=32)
+         coverage_map = gaussian_filter_2d(coverage_map, sigma=32)
+         pixel_pred /= coverage_map + 1e-8
+         pixel_pred /= pixel_pred.max() + 1e-8
+         if pixel_pred.sum() > pixel_pred.numel() * 0.5:
+             pixel_pred = 1 - pixel_pred
+         pixel_pred = (pixel_pred > self.loc_threshold).float()
+         return pixel_pred
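A minimal end-to-end sketch (not part of the upload) of running the uploaded model locally. Because modeling_fsg.py imports its config with a relative import, the sketch assumes the four committed files have been placed in a local package directory named `fsg_repo/` (with an empty `__init__.py`); the exact Hub loading path (e.g. trust_remote_code) may differ.

```python
import torch
from torchvision.io import ImageReadMode, read_image

# `fsg_repo` is a hypothetical package directory holding the four committed files.
from fsg_repo.configuration import FsgConfig
from fsg_repo.modeling_fsg import FsgModel

config = FsgConfig.from_pretrained("fsg_repo")               # reads config.json
model = FsgModel.from_pretrained("fsg_repo", config=config)  # loads model.safetensors
model.eval()

# need_input_255 is true in config.json; forward() rescales [0, 1] inputs to [0, 255] itself.
image = read_image("path/to/image.jpg", mode=ImageReadMode.RGB).float() / 255

with torch.no_grad():
    img_preds, loc_preds = model([image])  # per-image score and per-pixel localization map

print(img_preds[0].item(), loc_preds[0].shape)
```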