# forensic-similarity-graph / modeling_fsg.py
import math
import re
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from transformers import PreTrainedModel
from huggingface_hub import PyTorchModelHubMixin
from numba import jit
from .configuration import FsgConfig
from typing import Literal, Type, Union, List
def batch_fn(iterable, n=1):
    """Yield successive chunks of size ``n`` from ``iterable``."""
    length = len(iterable)
    for ndx in range(0, length, n):
        yield iterable[ndx : min(ndx + n, length)]
def gaussian_kernel_1d(sigma: float, num_sigmas: float = 3.0) -> torch.Tensor:
radius = math.ceil(num_sigmas * sigma)
support = torch.arange(-radius, radius + 1, dtype=torch.float)
kernel = Normal(loc=0, scale=sigma).log_prob(support).exp_()
# Ensure kernel weights sum to 1, so that image brightness is not altered
return kernel.mul_(1 / kernel.sum())
def gaussian_filter_2d(img: torch.Tensor, sigma: float) -> torch.Tensor:
kernel_1d = gaussian_kernel_1d(sigma).to(img.device) # Create 1D Gaussian kernel
padding = len(kernel_1d) // 2 # Ensure that image size does not change
img = img[None, None, ...] # Need 4D data for ``conv2d()``
# Convolve along columns and rows
img = F.conv2d(img, weight=kernel_1d.view(1, 1, -1, 1), padding=(padding, 0))
img = F.conv2d(img, weight=kernel_1d.view(1, 1, 1, -1), padding=(0, padding))
return img.squeeze() # Make 2D again
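# Illustrative usage of the helper above (a sketch, not executed on import):
#     img = torch.rand(64, 64)
#     blurred = gaussian_filter_2d(img, sigma=2.0)   # same 64 x 64 size, separable Gaussian blur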
class BaseModel(nn.Module):
def __init__(
self,
patch_size: int,
num_classes: int = 0,
**kwargs,
):
super().__init__()
self.patch_size = patch_size
self.num_classes = num_classes
class ConstrainedConv(nn.Module):
def __init__(self, input_chan=3, num_filters=6, is_constrained=True):
super().__init__()
self.kernel_size = 5
self.input_chan = input_chan
self.num_filters = num_filters
self.is_constrained = is_constrained
weight = torch.empty(num_filters, input_chan, self.kernel_size, self.kernel_size)
nn.init.xavier_normal_(weight, gain=1/3)
self.weight = nn.Parameter(weight, requires_grad=True)
        # Indicator vector with a 1 at the center tap of the flattened kernel (index 12 for 5x5)
        one_middle = torch.zeros(self.kernel_size * self.kernel_size)
        one_middle[(self.kernel_size * self.kernel_size) // 2] = 1
        self.one_middle = nn.Parameter(one_middle, requires_grad=False)
    def forward(self, x):
        w = self.weight
        if self.is_constrained:
            # Re-center each filter and force its center tap to -1 so the filter behaves
            # roughly like a prediction-error (residual) filter (constrained convolution)
            w = w.view(-1, self.kernel_size * self.kernel_size)
            w = w - w.mean(1)[..., None] + 1 / (self.kernel_size * self.kernel_size - 1)
            w = w - (w + 1) * self.one_middle
            w = w.view(self.num_filters, self.input_chan, self.kernel_size, self.kernel_size)
        x = nn.functional.conv2d(x, w, padding="valid")
        # Asymmetric padding restores a spatial size compatible with the downstream blocks
        x = nn.functional.pad(x, (2, 3, 2, 3))
        return x
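# A minimal sketch (not part of the model) of what the constrained projection above does:
# after the update, every filter's center tap equals exactly -1 while the remaining taps
# keep their re-centered values.
#     conv = ConstrainedConv(input_chan=3, num_filters=6)
#     w = conv.weight.view(-1, 25)
#     w = w - w.mean(1, keepdim=True) + 1 / 24
#     w = w - (w + 1) * conv.one_middle
#     assert torch.allclose(w[:, 12], torch.full((w.shape[0],), -1.0))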
class ConvBlock(torch.nn.Module):
def __init__(
self,
in_chans,
out_chans,
kernel_size,
stride,
padding,
activation: Literal["tanh", "relu"],
):
super().__init__()
assert activation.lower() in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
self.conv = torch.nn.Conv2d(
in_chans,
out_chans,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
self.bn = torch.nn.BatchNorm2d(out_chans)
self.act = torch.nn.Tanh() if activation.lower() == "tanh" else torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)
def forward(self, x):
return self.maxpool(self.act(self.bn(self.conv(x))))
class DenseBlock(torch.nn.Module):
def __init__(
self,
in_chans,
out_chans,
activation: Literal["tanh", "relu"],
):
super().__init__()
assert activation.lower() in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
self.fc = torch.nn.Linear(in_chans, out_chans)
self.act = torch.nn.Tanh() if activation.lower() == "tanh" else torch.nn.ReLU()
def forward(self, x):
return self.act(self.fc(x))
class MISLNet(BaseModel):
arch = {
"p256": [
("conv1", -1, 96, 7, 2, "valid", "tanh"),
("conv2", 96, 64, 5, 1, "same", "tanh"),
("conv3", 64, 64, 5, 1, "same", "tanh"),
("conv4", 64, 128, 1, 1, "same", "tanh"),
("fc1", 6 * 6 * 128, 200, "tanh"),
("fc2", 200, 200, "tanh"),
],
"p256_3fc_256e": [
("conv1", -1, 96, 7, 2, "valid", "tanh"),
("conv2", 96, 64, 5, 1, "same", "tanh"),
("conv3", 64, 64, 5, 1, "same", "tanh"),
("conv4", 64, 128, 1, 1, "same", "tanh"),
("fc1", 6 * 6 * 128, 1024, "tanh"),
("fc2", 1024, 512, "tanh"),
("fc3", 512, 256, "tanh"),
],
"p128": [
("conv1", -1, 96, 7, 2, "valid", "tanh"),
("conv2", 96, 64, 5, 1, "same", "tanh"),
("conv3", 64, 64, 5, 1, "same", "tanh"),
("conv4", 64, 128, 1, 1, "same", "tanh"),
("fc1", 2 * 2 * 128, 200, "tanh"),
("fc2", 200, 200, "tanh"),
],
"p96": [
("conv1", -1, 96, 7, 2, "valid", "tanh"),
("conv2", 96, 64, 5, 1, "same", "tanh"),
("conv3", 64, 64, 5, 1, "same", "tanh"),
("conv4", 64, 128, 1, 1, "same", "tanh"),
("fc1", 8 * 4 * 64, 200, "tanh"),
("fc2", 200, 200, "tanh"),
],
"p64": [
("conv1", -1, 96, 7, 2, "valid", "tanh"),
("conv2", 96, 64, 5, 1, "same", "tanh"),
("conv3", 64, 64, 5, 1, "same", "tanh"),
("conv4", 64, 128, 1, 1, "same", "tanh"),
("fc1", 2 * 4 * 64, 200, "tanh"),
("fc2", 200, 200, "tanh"),
],
}
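    # Each conv entry above is (name, in_chans, out_chans, kernel_size, stride, padding, activation),
    # where in_chans == -1 means "use the constrained-conv filter count"; each fc entry is
    # (name, in_features, out_features, activation).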
def __init__(
self,
patch_size: int,
variant: str,
num_classes=0,
num_filters=6,
is_constrained=True,
**kwargs,
):
super().__init__(patch_size, num_classes)
self.variant = variant
self.chosen_arch = self.arch[variant]
self.num_filters = num_filters
self.constrained_conv = ConstrainedConv(num_filters=num_filters, is_constrained=is_constrained)
self.conv_blocks = []
self.fc_blocks = []
for block in self.chosen_arch:
if block[0].startswith("conv"):
self.conv_blocks.append(
ConvBlock(
in_chans=(num_filters if block[1] == -1 else block[1]),
out_chans=block[2],
kernel_size=block[3],
stride=block[4],
padding=block[5],
activation=block[6],
)
)
elif block[0].startswith("fc"):
self.fc_blocks.append(
DenseBlock(
in_chans=block[1],
out_chans=block[2],
activation=block[3],
)
)
self.conv_blocks = nn.Sequential(*self.conv_blocks)
self.fc_blocks = nn.Sequential(*self.fc_blocks)
self.register_buffer("flatten_index_permutation", torch.tensor([0, 1, 2, 3], dtype=torch.long))
if self.num_classes > 0:
self.output = nn.Linear(self.chosen_arch[-1][2], self.num_classes)
def forward(self, x):
x = self.constrained_conv(x)
x = self.conv_blocks(x)
x = x.permute(*self.flatten_index_permutation)
x = x.flatten(1, -1)
x = self.fc_blocks(x)
if self.num_classes > 0:
x = self.output(x)
return x
    def load_state_dict(self, state_dict, strict=True, assign=False):
        # Some checkpoints do not contain the `flatten_index_permutation` buffer,
        # so fall back to non-strict loading when it is missing.
        if "flatten_index_permutation" not in state_dict:
            super().load_state_dict(state_dict, False, assign)
        else:
            super().load_state_dict(state_dict, strict, assign)
class CompareNet(nn.Module):
def __init__(self, input_dim, hidden_dim=2048, output_dim=64):
super().__init__()
self.fc1 = DenseBlock(input_dim, hidden_dim, "relu")
self.fc2 = DenseBlock(hidden_dim * 3, output_dim, "relu")
self.fc3 = nn.Linear(output_dim, 2)
def forward(self, x1, x2):
x1 = self.fc1(x1)
x2 = self.fc1(x2)
x = torch.cat((x1, x1 * x2, x2), dim=1)
x = self.fc2(x)
x = self.fc3(x)
return x
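# Shape sketch (illustrative, assuming 200-dim features as in the "p256" variant):
#     comparenet = CompareNet(input_dim=200)
#     f1, f2 = torch.randn(8, 200), torch.randn(8, 200)
#     logits = comparenet(f1, f2)   # (8, 2): concat(x1, x1 * x2, x2) fused down to 2 logits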
class FSM(nn.Module):
"""
FSM (Forensic Similarity Metric) is a neural network module that computes the similarity between two input images using a feature extraction module and a comparison network module.
Args:
fe_config (dict): Configuration for the feature extraction module.
comparenet_config (dict): Configuration for the comparison network module.
fe_ckpt (str): Path to the checkpoint file for the feature extraction module.
**kwargs: Additional keyword arguments.
"""
def __init__(
self,
fe_config,
comparenet_config,
fe_ckpt=None,
**kwargs,
):
super().__init__()
fe_config["num_classes"] = 0 # to make fe without final classification layer
self.fe: MISLNet = self.load_module_from_ckpt(MISLNet, fe_ckpt, "", **fe_config)
self.patch_size = self.fe.patch_size
comparenet_config["input_dim"] = self.fe.fc_blocks[-1].fc.out_features
self.comparenet = CompareNet(**comparenet_config)
        self.freeze_fe = True  # the feature extractor is frozen (eval + no_grad) by default
def load_module_state_dict(self, module: nn.Module, state_dict, module_name=""):
curr_model_state_dict = module.state_dict()
curr_model_keys_status = {k: False for k in curr_model_state_dict.keys()}
outstanding_keys = []
for ckpt_layer_name, ckpt_layer_weights in state_dict.items():
if module_name not in ckpt_layer_name:
continue
ckpt_matches = re.findall(r"(?=(?:^|\.)((?:\w+\.)*\w+)$)", ckpt_layer_name)[::-1]
model_layer_name_match = list(set(ckpt_matches).intersection(set(curr_model_state_dict.keys())))
# print(ckpt_layer_name, model_layer_name_match)
if len(model_layer_name_match) == 0:
outstanding_keys.append(ckpt_layer_name)
else:
model_layer_name = model_layer_name_match[0]
assert (
curr_model_state_dict[model_layer_name].shape == ckpt_layer_weights.shape
), f"Ckpt layer '{ckpt_layer_name}' shape {ckpt_layer_weights.shape} does not match model layer '{model_layer_name}' shape {curr_model_state_dict[model_layer_name].shape}"
curr_model_state_dict[model_layer_name] = ckpt_layer_weights
curr_model_keys_status[model_layer_name] = True
if all(curr_model_keys_status.values()):
print(f"Success! All necessary keys for module '{module.__class__.__name__}' are loaded!")
        else:
            not_loaded_keys = [k for k, v in curr_model_keys_status.items() if not v]
            print(f"Warning! Some keys were not loaded:\n{not_loaded_keys}")
if len(outstanding_keys) > 0:
print(f"Outstanding keys are: {outstanding_keys}")
module.load_state_dict(curr_model_state_dict, strict=False)
def load_module_from_ckpt(
self,
module_class: Type[nn.Module],
ckpt_path: Union[None, str],
module_name: str,
*args,
**kwargs,
) -> nn.Module:
module = module_class(*args, **kwargs)
if ckpt_path is not None:
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt_state_dict = ckpt["state_dict"]
self.load_module_state_dict(module, ckpt_state_dict, module_name=module_name)
return module
def load_state_dict(self, state_dict, strict=True, assign=False):
try:
super().load_state_dict(state_dict, strict=strict, assign=assign)
except Exception as e:
print(f"Error loading state dict using normal method: {e}")
print("Trying to load state dict manually...")
# self.load_module_state_dict(self.fe, state_dict, module_name="fe")
# self.load_module_state_dict(self.comparenet, state_dict, module_name="comparenet")
self.load_module_state_dict(self, state_dict, module_name="")
print("State dict loaded successfully!")
def forward_fe(self, x):
if self.freeze_fe:
self.fe.eval()
with torch.no_grad():
return self.fe(x)
else:
self.fe.train()
return self.fe(x)
def forward(self, x1, x2):
x1 = self.forward_fe(x1)
x2 = self.forward_fe(x2)
return self.comparenet(x1, x2)
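# Illustrative use of FSM on two batches of patches (a sketch; the minimal configs below are assumed):
#     fsm = FSM(fe_config={"patch_size": 256, "variant": "p256_3fc_256e"}, comparenet_config={})
#     p1, p2 = torch.rand(4, 3, 256, 256), torch.rand(4, 3, 256, 256)
#     logits = fsm(p1, p2)                 # (4, 2)
#     sim = logits.softmax(dim=1)[:, 1]    # probability that the two patches share a forensic source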
class FsgModel(
PreTrainedModel,
PyTorchModelHubMixin,
repo_url="ductai199x/forensic-similarity-graph",
pipeline_tag="image-manipulation-detection-localization",
license="cc-by-nc-nd-4.0",
):
"""
Forensic Similarity Graph (FSG) algorithm.
https://ieeexplore.ieee.org/abstract/document/9113265
This class is designed to create a graph-based representation of forensic similarity between different patches of an image, allowing for the detection of manipulated regions.
"""
config_class = FsgConfig
def __init__(self, config: FsgConfig, **kwargs):
super().__init__(config)
self.patch_size = config.fe_config.patch_size
self.stride = int(self.patch_size * config.stride_ratio)
self.fast_sim_mode = config.fast_sim_mode
self.loc_threshold = config.loc_threshold
self.is_high_sim = True
self.need_input_255 = config.need_input_255
self.model = FSM(fe_config=config.fe_config.to_dict(), comparenet_config=config.comparenet_config.to_dict())
warnings.filterwarnings("ignore")
def get_batched_patches(self, x: torch.Tensor):
B, C, H, W = x.shape
# split images into batches of patches: B x C x H x W -> B x (NumPatchHeight x NumPatchWidth) x C x PatchSize x PatchSize
batched_patches = (
x.unfold(2, self.patch_size, self.stride)
.unfold(3, self.patch_size, self.stride)
.permute(0, 2, 3, 1, 4, 5)
)
batched_patches = batched_patches.contiguous().view(B, -1, C, self.patch_size, self.patch_size)
return batched_patches
def get_patches_single(self, x: torch.Tensor):
C, H, W = x.shape
patches = (
x.unfold(1, self.patch_size, self.stride)
.unfold(2, self.patch_size, self.stride)
.permute(1, 2, 0, 3, 4)
)
patches = patches.contiguous().view(-1, C, self.patch_size, self.patch_size)
return patches
@jit(forceobj=True)
    def get_features(self, image_patches: torch.Tensor):
        # Extract forensic features for all patches in mini-batches to keep memory bounded
        patches_features = []
for batch in list(batch_fn(image_patches, 256)):
batch = batch.float()
feats = self.model.fe(batch).detach()
patches_features.append(feats)
patches_features = torch.vstack(patches_features)
return patches_features
@jit(forceobj=True)
    def get_sim_scores(self, patch_pairs):
        # Score each pair of patch features with the comparison network, in mini-batches
        patches_sim_scores = []
for batch in list(batch_fn(patch_pairs, 4096)):
batch = batch.permute(1, 0, 2).float()
scores = self.model.comparenet(*batch).detach()
scores = torch.nn.functional.softmax(scores, dim=1)
patches_sim_scores.append(scores)
patches_sim_scores = torch.vstack(patches_sim_scores)
return patches_sim_scores
    def forward_single(self, patches: torch.Tensor):
        P, C, H, W = patches.shape
        features = self.get_features(patches)
        # Pairwise similarity matrix over all P patches
        sim_mat = torch.zeros(P, P, device=patches.device)
if self.fast_sim_mode:
upper_tri_idx = torch.triu_indices(P, P, 1).T
patch_pairs = features[upper_tri_idx]
else:
patch_cart_prod = torch.cartesian_prod(torch.arange(P), torch.arange(P))
patch_pairs = features[patch_cart_prod]
sim_scores = self.get_sim_scores(patch_pairs).detach()
if self.fast_sim_mode:
sim_mat[upper_tri_idx[:, 0], upper_tri_idx[:, 1]] = sim_scores[:, 1]
sim_mat += sim_mat.clone().T
else:
sim_mat = sim_scores[:, 1].view(P, P)
sim_mat = 0.5 * (sim_mat + sim_mat.T)
        if not self.is_high_sim:
            sim_mat = 1 - sim_mat
        sim_mat.fill_diagonal_(0.0)
        # Spectral analysis of the patch-similarity graph: a small spectral gap of the
        # normalized Laplacian means the graph splits into weakly connected communities,
        # i.e. the image likely contains patches from more than one forensic source.
        degree_mat = torch.diag(sim_mat.sum(dim=1))
        laplacian_mat = degree_mat - sim_mat
        degree_sym_mat = torch.diag(sim_mat.sum(dim=1) ** -0.5)
        laplacian_sym_mat = (degree_sym_mat @ laplacian_mat) @ degree_sym_mat
        eigvals, eigvecs = torch.linalg.eigh(laplacian_sym_mat.cpu())
        spectral_gap = eigvals[1] - eigvals[0]
        img_pred = 1 - spectral_gap
        # The sign of the second eigenvector (Fiedler vector) partitions the patches into two groups
        eigvec = eigvecs[:, 1]
        patch_pred = (eigvec > 0).int()
        return img_pred.detach(), patch_pred.detach()
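    # Toy illustration of the spectral step above (assumed numbers, not model output): for two
    # internally-similar but mutually-dissimilar groups of patches,
    #     sim = torch.tensor([[0., .9, .1, .1],
    #                         [.9, 0., .1, .1],
    #                         [.1, .1, 0., .9],
    #                         [.1, .1, .9, 0.]])
    #     D = torch.diag(sim.sum(1)); L = D - sim
    #     Ds = torch.diag(sim.sum(1) ** -0.5)
    #     evals, evecs = torch.linalg.eigh(Ds @ L @ Ds)
    # the spectral gap evals[1] - evals[0] is much smaller than for a uniformly high-similarity
    # matrix, and the sign of evecs[:, 1] separates the two groups of patches.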
def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]):
if isinstance(x, torch.Tensor) and len(x.shape) == 3:
x = [x]
img_preds = []
loc_preds = []
for img in x:
C, H, W = img.shape
if self.need_input_255 and img.max() <= 1:
img = img * 255
# get the (x, y) coordinates of the top left of each patch in the image
x_inds = torch.arange(W).unfold(0, self.patch_size, self.stride)[:, 0]
y_inds = torch.arange(H).unfold(0, self.patch_size, self.stride)[:, 0]
xy_inds = torch.tensor([(ii, jj) for jj in y_inds for ii in x_inds]).to(img.device)
patches = self.get_patches_single(img)
img_pred, patch_pred = self.forward_single(patches)
loc_pred = self.patch_to_pixel_pred(patch_pred, xy_inds)
loc_pred = F.interpolate(loc_pred[None, None, ...], size=(H, W), mode="nearest").squeeze()
img_preds.append(img_pred)
loc_preds.append(loc_pred)
return img_preds, loc_preds
def patch_to_pixel_pred(self, patch_pred, xy_inds):
W, H = torch.max(xy_inds, dim=0).values + self.patch_size
pixel_pred = torch.zeros((H, W)).to(patch_pred.device)
coverage_map = torch.zeros((H, W)).to(patch_pred.device)
for (x, y), pred in zip(xy_inds, patch_pred):
pixel_pred[y : y + self.patch_size, x : x + self.patch_size] += pred
coverage_map[y : y + self.patch_size, x : x + self.patch_size] += 1
        # Smooth the vote map and the coverage map, then normalize votes by coverage
        pixel_pred = gaussian_filter_2d(pixel_pred, sigma=32)
        coverage_map = gaussian_filter_2d(coverage_map, sigma=32)
        pixel_pred /= coverage_map + 1e-8
        pixel_pred /= pixel_pred.max() + 1e-8
        # The spectral partition is sign-ambiguous; assume the manipulated region covers less
        # than half of the image and flip the mask otherwise
        if pixel_pred.sum() > pixel_pred.numel() * 0.5:
            pixel_pred = 1 - pixel_pred
        pixel_pred = (pixel_pred > self.loc_threshold).float()
return pixel_pred
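# Example end-to-end usage (a sketch; the Hub repo id is taken from `repo_url` above and the
# input shown is random, so the predictions are meaningless):
#     model = FsgModel.from_pretrained("ductai199x/forensic-similarity-graph")
#     model.eval()
#     img = torch.rand(3, 1080, 1920)          # C x H x W, values in [0, 1]
#     img_preds, loc_preds = model(img)
#     img_preds[0]   # scalar detection score (1 - spectral gap); higher -> more likely manipulated
#     loc_preds[0]   # H x W binary localization mask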