ductai199x
committed on
Upload model

- config.json +17 -0
- configuration.py +72 -0
- model.safetensors +3 -0
- modeling_fsg.py +510 -0
config.json
ADDED
{
  "comparenet_config": {
    "hidden_dim": 2048,
    "output_dim": 64
  },
  "fast_sim_mode": true,
  "fe_config": {
    "is_constrained": false,
    "num_classes": 0,
    "num_filters": 6,
    "patch_size": 128,
    "variant": "p128"
  },
  "loc_threshold": 0.3,
  "need_input_255": true,
  "stride_ratio": 0.5
}
configuration.py
ADDED
from transformers.configuration_utils import PretrainedConfig


class FeConfig:
    def __init__(
        self,
        patch_size: int = 128,
        variant: str = "p128",
        num_classes: int = 0,
        num_filters: int = 6,
        is_constrained: bool = False,
    ):
        self.patch_size = patch_size
        self.variant = variant
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.is_constrained = is_constrained

    def to_dict(self):
        return {
            "patch_size": self.patch_size,
            "variant": self.variant,
            "num_classes": self.num_classes,
            "num_filters": self.num_filters,
            "is_constrained": self.is_constrained,
        }


class CompareNetConfig:
    def __init__(
        self,
        hidden_dim: int = 2048,
        output_dim: int = 64,
    ):
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def to_dict(self):
        return {
            "hidden_dim": self.hidden_dim,
            "output_dim": self.output_dim,
        }


class FsgConfig(PretrainedConfig):
    def __init__(
        self,
        fe_config=None,
        comparenet_config=None,
        fast_sim_mode: bool = True,
        loc_threshold: float = 0.3,
        stride_ratio: float = 0.5,
        need_input_255: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fe_config = FeConfig() if fe_config is None else FeConfig(**fe_config)
        self.comparenet_config = CompareNetConfig() if comparenet_config is None else CompareNetConfig(**comparenet_config)
        self.fast_sim_mode = fast_sim_mode
        self.loc_threshold = loc_threshold
        self.stride_ratio = stride_ratio
        self.need_input_255 = need_input_255

    def to_dict(self):
        return {
            "fe_config": self.fe_config.to_dict(),
            "comparenet_config": self.comparenet_config.to_dict(),
            "fast_sim_mode": self.fast_sim_mode,
            "loc_threshold": self.loc_threshold,
            "stride_ratio": self.stride_ratio,
            "need_input_255": self.need_input_255,
        }
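The defaults in `FsgConfig` mirror the values stored in config.json above. Below is a minimal round-trip sketch, not part of the upload, assuming the script runs from the repository directory so that `configuration.py` is importable and config.json is on a local path:

```python
import json

from configuration import FsgConfig  # assumes the working directory is the repo checkout

# Load the uploaded config.json and rebuild the structured config object.
with open("config.json") as f:
    cfg_dict = json.load(f)

config = FsgConfig(**cfg_dict)  # nested dicts become FeConfig / CompareNetConfig instances
assert config.fe_config.patch_size == 128
assert config.to_dict()["stride_ratio"] == 0.5  # to_dict() reproduces config.json's layout
```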
model.safetensors
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:0b281e2f076cf288875a51cee670c58a080a5695453a42e297bb17528c1ed99e
size 4974180
modeling_fsg.py
ADDED
import math
import re
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from transformers import PreTrainedModel
from huggingface_hub import PyTorchModelHubMixin
from numba import jit
from .configuration import FsgConfig
from typing import Literal, Type, Union, List


def batch_fn(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx : min(ndx + n, l)]


def gaussian_kernel_1d(sigma: float, num_sigmas: float = 3.0) -> torch.Tensor:
    radius = math.ceil(num_sigmas * sigma)
    support = torch.arange(-radius, radius + 1, dtype=torch.float)
    kernel = Normal(loc=0, scale=sigma).log_prob(support).exp_()
    # Ensure kernel weights sum to 1, so that image brightness is not altered
    return kernel.mul_(1 / kernel.sum())


def gaussian_filter_2d(img: torch.Tensor, sigma: float) -> torch.Tensor:
    kernel_1d = gaussian_kernel_1d(sigma).to(img.device)  # Create 1D Gaussian kernel
    padding = len(kernel_1d) // 2  # Ensure that image size does not change
    img = img[None, None, ...]  # Need 4D data for ``conv2d()``
    # Convolve along columns and rows
    img = F.conv2d(img, weight=kernel_1d.view(1, 1, -1, 1), padding=(padding, 0))
    img = F.conv2d(img, weight=kernel_1d.view(1, 1, 1, -1), padding=(0, padding))
    return img.squeeze()  # Make 2D again

class BaseModel(nn.Module):
    def __init__(
        self,
        patch_size: int,
        num_classes: int = 0,
        **kwargs,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.num_classes = num_classes


class ConstrainedConv(nn.Module):
    def __init__(self, input_chan=3, num_filters=6, is_constrained=True):
        super().__init__()
        self.kernel_size = 5
        self.input_chan = input_chan
        self.num_filters = num_filters
        self.is_constrained = is_constrained
        weight = torch.empty(num_filters, input_chan, self.kernel_size, self.kernel_size)
        nn.init.xavier_normal_(weight, gain=1 / 3)
        self.weight = nn.Parameter(weight, requires_grad=True)
        self.one_middle = torch.zeros(self.kernel_size * self.kernel_size)
        self.one_middle[12] = 1
        self.one_middle = nn.Parameter(self.one_middle, requires_grad=False)

    def forward(self, x):
        w = self.weight
        if self.is_constrained:
            w = w.view(-1, self.kernel_size * self.kernel_size)
            w = w - w.mean(1)[..., None] + 1 / (self.kernel_size * self.kernel_size - 1)
            w = w - (w + 1) * self.one_middle
            w = w.view(self.num_filters, self.input_chan, self.kernel_size, self.kernel_size)
        x = nn.functional.conv2d(x, w, padding="valid")
        x = nn.functional.pad(x, (2, 3, 2, 3))
        return x


class ConvBlock(torch.nn.Module):
    def __init__(
        self,
        in_chans,
        out_chans,
        kernel_size,
        stride,
        padding,
        activation: Literal["tanh", "relu"],
    ):
        super().__init__()
        assert activation.lower() in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
        self.conv = torch.nn.Conv2d(
            in_chans,
            out_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = torch.nn.BatchNorm2d(out_chans)
        self.act = torch.nn.Tanh() if activation.lower() == "tanh" else torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)

    def forward(self, x):
        return self.maxpool(self.act(self.bn(self.conv(x))))


class DenseBlock(torch.nn.Module):
    def __init__(
        self,
        in_chans,
        out_chans,
        activation: Literal["tanh", "relu"],
    ):
        super().__init__()
        assert activation.lower() in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
        self.fc = torch.nn.Linear(in_chans, out_chans)
        self.act = torch.nn.Tanh() if activation.lower() == "tanh" else torch.nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))


class MISLNet(BaseModel):
    arch = {
        "p256": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 6 * 6 * 128, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
        "p256_3fc_256e": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 6 * 6 * 128, 1024, "tanh"),
            ("fc2", 1024, 512, "tanh"),
            ("fc3", 512, 256, "tanh"),
        ],
        "p128": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 2 * 2 * 128, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
        "p96": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 8 * 4 * 64, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
        "p64": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 2 * 4 * 64, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
    }

    def __init__(
        self,
        patch_size: int,
        variant: str,
        num_classes=0,
        num_filters=6,
        is_constrained=True,
        **kwargs,
    ):
        super().__init__(patch_size, num_classes)
        self.variant = variant
        self.chosen_arch = self.arch[variant]
        self.num_filters = num_filters

        self.constrained_conv = ConstrainedConv(num_filters=num_filters, is_constrained=is_constrained)

        self.conv_blocks = []
        self.fc_blocks = []
        for block in self.chosen_arch:
            if block[0].startswith("conv"):
                self.conv_blocks.append(
                    ConvBlock(
                        in_chans=(num_filters if block[1] == -1 else block[1]),
                        out_chans=block[2],
                        kernel_size=block[3],
                        stride=block[4],
                        padding=block[5],
                        activation=block[6],
                    )
                )
            elif block[0].startswith("fc"):
                self.fc_blocks.append(
                    DenseBlock(
                        in_chans=block[1],
                        out_chans=block[2],
                        activation=block[3],
                    )
                )

        self.conv_blocks = nn.Sequential(*self.conv_blocks)
        self.fc_blocks = nn.Sequential(*self.fc_blocks)

        self.register_buffer("flatten_index_permutation", torch.tensor([0, 1, 2, 3], dtype=torch.long))

        if self.num_classes > 0:
            self.output = nn.Linear(self.chosen_arch[-1][2], self.num_classes)

    def forward(self, x):
        x = self.constrained_conv(x)
        x = self.conv_blocks(x)
        x = x.permute(*self.flatten_index_permutation)
        x = x.flatten(1, -1)
        x = self.fc_blocks(x)
        if self.num_classes > 0:
            x = self.output(x)
        return x

    def load_state_dict(self, state_dict, strict=True, assign=False):
        if "flatten_index_permutation" not in state_dict:
            super().load_state_dict(state_dict, False, assign)
        else:
            super().load_state_dict(state_dict, strict, assign)


class CompareNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=2048, output_dim=64):
        super().__init__()
        self.fc1 = DenseBlock(input_dim, hidden_dim, "relu")
        self.fc2 = DenseBlock(hidden_dim * 3, output_dim, "relu")
        self.fc3 = nn.Linear(output_dim, 2)

    def forward(self, x1, x2):
        x1 = self.fc1(x1)
        x2 = self.fc1(x2)
        x = torch.cat((x1, x1 * x2, x2), dim=1)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

class FSM(nn.Module):
    """
    FSM (Forensic Similarity Metric) is a neural network module that computes the similarity between two input images using a feature extraction module and a comparison network module.

    Args:
        fe_config (dict): Configuration for the feature extraction module.
        comparenet_config (dict): Configuration for the comparison network module.
        fe_ckpt (str): Path to the checkpoint file for the feature extraction module.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        fe_config,
        comparenet_config,
        fe_ckpt=None,
        **kwargs,
    ):
        super().__init__()
        fe_config["num_classes"] = 0  # to make fe without final classification layer
        self.fe: MISLNet = self.load_module_from_ckpt(MISLNet, fe_ckpt, "", **fe_config)
        self.patch_size = self.fe.patch_size
        comparenet_config["input_dim"] = self.fe.fc_blocks[-1].fc.out_features
        self.comparenet = CompareNet(**comparenet_config)
        self.freeze_fe = True

    def load_module_state_dict(self, module: nn.Module, state_dict, module_name=""):
        curr_model_state_dict = module.state_dict()
        curr_model_keys_status = {k: False for k in curr_model_state_dict.keys()}
        outstanding_keys = []
        for ckpt_layer_name, ckpt_layer_weights in state_dict.items():
            if module_name not in ckpt_layer_name:
                continue
            ckpt_matches = re.findall(r"(?=(?:^|\.)((?:\w+\.)*\w+)$)", ckpt_layer_name)[::-1]
            model_layer_name_match = list(set(ckpt_matches).intersection(set(curr_model_state_dict.keys())))
            # print(ckpt_layer_name, model_layer_name_match)
            if len(model_layer_name_match) == 0:
                outstanding_keys.append(ckpt_layer_name)
            else:
                model_layer_name = model_layer_name_match[0]
                assert (
                    curr_model_state_dict[model_layer_name].shape == ckpt_layer_weights.shape
                ), f"Ckpt layer '{ckpt_layer_name}' shape {ckpt_layer_weights.shape} does not match model layer '{model_layer_name}' shape {curr_model_state_dict[model_layer_name].shape}"
                curr_model_state_dict[model_layer_name] = ckpt_layer_weights
                curr_model_keys_status[model_layer_name] = True

        if all(curr_model_keys_status.values()):
            print(f"Success! All necessary keys for module '{module.__class__.__name__}' are loaded!")
        else:
            not_loaded_keys = [k for k, v in curr_model_keys_status.items() if not v]
            print(f"Warning! Some keys are not loaded! Not loaded keys are:\n{not_loaded_keys}")
            if len(outstanding_keys) > 0:
                print(f"Outstanding keys are: {outstanding_keys}")
        module.load_state_dict(curr_model_state_dict, strict=False)

    def load_module_from_ckpt(
        self,
        module_class: Type[nn.Module],
        ckpt_path: Union[None, str],
        module_name: str,
        *args,
        **kwargs,
    ) -> nn.Module:
        module = module_class(*args, **kwargs)

        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path, map_location="cpu")
            ckpt_state_dict = ckpt["state_dict"]
            self.load_module_state_dict(module, ckpt_state_dict, module_name=module_name)
        return module

    def load_state_dict(self, state_dict, strict=True, assign=False):
        try:
            super().load_state_dict(state_dict, strict=strict, assign=assign)
        except Exception as e:
            print(f"Error loading state dict using normal method: {e}")
            print("Trying to load state dict manually...")
            # self.load_module_state_dict(self.fe, state_dict, module_name="fe")
            # self.load_module_state_dict(self.comparenet, state_dict, module_name="comparenet")
            self.load_module_state_dict(self, state_dict, module_name="")
            print("State dict loaded successfully!")

    def forward_fe(self, x):
        if self.freeze_fe:
            self.fe.eval()
            with torch.no_grad():
                return self.fe(x)
        else:
            self.fe.train()
            return self.fe(x)

    def forward(self, x1, x2):
        x1 = self.forward_fe(x1)
        x2 = self.forward_fe(x2)
        return self.comparenet(x1, x2)

class FsgModel(
    PreTrainedModel,
    PyTorchModelHubMixin,
    repo_url="ductai199x/forensic-similarity-graph",
    pipeline_tag="image-manipulation-detection-localization",
    license="cc-by-nc-nd-4.0",
):
    """
    Forensic Similarity Graph (FSG) algorithm.
    https://ieeexplore.ieee.org/abstract/document/9113265

    This class is designed to create a graph-based representation of forensic similarity between different patches of an image, allowing for the detection of manipulated regions.

    Args:
        stride_ratio (float): The ratio of the stride to the patch size, determining the overlap between patches. The lower the value, the higher the overlap.
        fast_sim_mode (bool): If True, the algorithm uses a faster method to compute similarity scores, potentially at the cost of accuracy.
        loc_threshold (float): The threshold for determining the location of interest in the similarity graph. Values above this threshold are considered significant.
        is_high_sim (bool): If True, higher similarity scores indicate higher similarity. If False, lower scores indicate higher similarity.
        need_input_255 (bool): If True, input images are expected to be scaled to [0, 255]. If False, images are expected to be in [0, 1].
        **kwargs: Additional keyword arguments passed to the superclass initializer.

    Example Usage:
        ```python
        import torch
        import matplotlib.pyplot as plt
        from torchvision.io import read_image, ImageReadMode
        from model import FSG

        ckpt_path = "path/to/ckpt.pth"
        model = FSG.load_from_checkpoint(ckpt_path, map_location="cpu", stride_ratio=0.5, fast_sim_mode=False, loc_threshold=0.37, is_high_sim=False, need_input_255=False)
        model.eval()

        img_path = "path/to/image.jpg"
        image = read_image(img_path, mode=ImageReadMode.RGB).float() / 255

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            img_preds, loc_preds = model(image[None, ...].to(device))

        plt.imshow(loc_preds.cpu()[0])
        plt.colorbar()
        plt.show()
        ```
    """

    def __init__(self, config: FsgConfig, **kwargs):
        super().__init__(config)
        self.patch_size = config.fe_config.patch_size
        self.stride = int(self.patch_size * config.stride_ratio)
        self.fast_sim_mode = config.fast_sim_mode
        self.loc_threshold = config.loc_threshold
        self.is_high_sim = True
        self.need_input_255 = config.need_input_255
        self.model = FSM(fe_config=config.fe_config.to_dict(), comparenet_config=config.comparenet_config.to_dict())

        warnings.filterwarnings("ignore")

    def get_batched_patches(self, x: torch.Tensor):
        B, C, H, W = x.shape
        # split images into batches of patches: B x C x H x W -> B x (NumPatchHeight x NumPatchWidth) x C x PatchSize x PatchSize
        batched_patches = (
            x.unfold(2, self.patch_size, self.stride)
            .unfold(3, self.patch_size, self.stride)
            .permute(0, 2, 3, 1, 4, 5)
        )
        batched_patches = batched_patches.contiguous().view(B, -1, C, self.patch_size, self.patch_size)
        return batched_patches

    def get_patches_single(self, x: torch.Tensor):
        C, H, W = x.shape
        patches = (
            x.unfold(1, self.patch_size, self.stride)
            .unfold(2, self.patch_size, self.stride)
            .permute(1, 2, 0, 3, 4)
        )
        patches = patches.contiguous().view(-1, C, self.patch_size, self.patch_size)
        return patches

    @jit(forceobj=True)
    def get_features(self, image_patches: torch.Tensor):
        patches_features = []
        for batch in list(batch_fn(image_patches, 256)):
            batch = batch.float()
            feats = self.model.fe(batch).detach()
            patches_features.append(feats)
        patches_features = torch.vstack(patches_features)
        return patches_features

    @jit(forceobj=True)
    def get_sim_scores(self, patch_pairs):
        patches_sim_scores = []
        for batch in list(batch_fn(patch_pairs, 4096)):
            batch = batch.permute(1, 0, 2).float()
            scores = self.model.comparenet(*batch).detach()
            scores = torch.nn.functional.softmax(scores, dim=1)
            patches_sim_scores.append(scores)
        patches_sim_scores = torch.vstack(patches_sim_scores)
        return patches_sim_scores

    def forward_single(self, patches: torch.Tensor):
        P, C, H, W = patches.shape
        features = self.get_features(patches)
        sim_mat = torch.zeros(P, P, device=patches.device)
        if self.fast_sim_mode:
            upper_tri_idx = torch.triu_indices(P, P, 1).T
            patch_pairs = features[upper_tri_idx]
        else:
            patch_cart_prod = torch.cartesian_prod(torch.arange(P), torch.arange(P))
            patch_pairs = features[patch_cart_prod]
        sim_scores = self.get_sim_scores(patch_pairs).detach()
        if self.fast_sim_mode:
            sim_mat[upper_tri_idx[:, 0], upper_tri_idx[:, 1]] = sim_scores[:, 1]
            sim_mat += sim_mat.clone().T
        else:
            sim_mat = sim_scores[:, 1].view(P, P)
            sim_mat = 0.5 * (sim_mat + sim_mat.T)
        if not self.is_high_sim:
            sim_mat = 1 - sim_mat
        sim_mat.fill_diagonal_(0.0)
        degree_mat = torch.diag(sim_mat.sum(axis=1))
        laplacian_mat = degree_mat - sim_mat
        degree_sym_mat = torch.diag(sim_mat.sum(axis=1) ** -0.5)
        laplacian_sym_mat = (degree_sym_mat @ laplacian_mat) @ degree_sym_mat
        eigvals, eigvecs = torch.linalg.eigh(laplacian_sym_mat.cpu())
        spectral_gap = eigvals[1] - eigvals[0]
        img_pred = 1 - spectral_gap
        eigvec = eigvecs[:, 1]
        patch_pred = (eigvec > 0).int()
        return img_pred.detach(), patch_pred.detach()

    def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]):
        if isinstance(x, torch.Tensor) and len(x.shape) == 3:
            x = [x]

        img_preds = []
        loc_preds = []
        for img in x:
            C, H, W = img.shape
            if self.need_input_255 and img.max() <= 1:
                img = img * 255
            # get the (x, y) coordinates of the top left of each patch in the image
            x_inds = torch.arange(W).unfold(0, self.patch_size, self.stride)[:, 0]
            y_inds = torch.arange(H).unfold(0, self.patch_size, self.stride)[:, 0]
            xy_inds = torch.tensor([(ii, jj) for jj in y_inds for ii in x_inds]).to(img.device)

            patches = self.get_patches_single(img)
            img_pred, patch_pred = self.forward_single(patches)
            loc_pred = self.patch_to_pixel_pred(patch_pred, xy_inds)
            loc_pred = F.interpolate(loc_pred[None, None, ...], size=(H, W), mode="nearest").squeeze()
            img_preds.append(img_pred)
            loc_preds.append(loc_pred)
        return img_preds, loc_preds

    def patch_to_pixel_pred(self, patch_pred, xy_inds):
        W, H = torch.max(xy_inds, dim=0).values + self.patch_size
        pixel_pred = torch.zeros((H, W)).to(patch_pred.device)
        coverage_map = torch.zeros((H, W)).to(patch_pred.device)
        for (x, y), pred in zip(xy_inds, patch_pred):
            pixel_pred[y : y + self.patch_size, x : x + self.patch_size] += pred
            coverage_map[y : y + self.patch_size, x : x + self.patch_size] += 1
        # perform gaussian smoothing
        pixel_pred = gaussian_filter_2d(pixel_pred, sigma=32)
        coverage_map = gaussian_filter_2d(coverage_map, sigma=32)
        pixel_pred /= coverage_map + 1e-8
        pixel_pred /= pixel_pred.max() + 1e-8
        if pixel_pred.sum() > pixel_pred.numel() * 0.5:
            pixel_pred = 1 - pixel_pred
        pixel_pred = (pixel_pred > self.loc_threshold).float()
        return pixel_pred
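The docstring example above calls `FSG.load_from_checkpoint`, which is not defined anywhere in this upload. For the files in this commit, a minimal local-loading sketch might look like the following. This is an illustration, not the author's documented usage: the package name `fsg_repo` is hypothetical, the directory needs an `__init__.py` so that the relative import in modeling_fsg.py resolves, and `strict=False` is used defensively when loading the weights.

```python
import json
import torch
from safetensors.torch import load_file

# Hypothetical package: put config.json, configuration.py, modeling_fsg.py, and
# model.safetensors in a directory "fsg_repo" containing an empty __init__.py.
from fsg_repo.configuration import FsgConfig
from fsg_repo.modeling_fsg import FsgModel

with open("fsg_repo/config.json") as f:
    config = FsgConfig(**json.load(f))

model = FsgModel(config)
model.load_state_dict(load_file("fsg_repo/model.safetensors"), strict=False)
model.eval()

image = torch.rand(3, 512, 512)  # dummy RGB image in [0, 1]; forward() rescales to [0, 255]
with torch.no_grad():
    img_preds, loc_preds = model(image)  # lists: one detection score and one localization map per image
print(img_preds[0].item(), loc_preds[0].shape)
```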