In [None]:
import glob
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModel
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
from torchmetrics.classification import BinaryF1Score, BinaryAveragePrecision
from tqdm.auto import tqdm

In [None]:
# Load the pretrained forensic-similarity-graph model from the Hugging Face Hub.
# NOTE: trust_remote_code=True allows Python code shipped in that Hub repository to be
# executed locally, so only keep it enabled for a repository you trust.
model = AutoModel.from_pretrained(
    "ductai199x/forensic-similarity-graph",
    trust_remote_code=True,
)

In [None]:
# Run on the GPU when PyTorch reports a usable CUDA device; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Switch the model to evaluation mode first, then move it onto the selected device.
# Both calls are re-assigned so `model` always refers to whatever each step returns.
model = model.eval()
model = model.to(device)

In [None]:
# Collect the example images and their ground-truth masks.
# "??" matches exactly two characters, so "splicing-NN-gt.png" files are NOT picked up by
# the first pattern; sorting both lists is what lines each image up with its mask.
image_paths = sorted(glob.glob("example_images/splicing-??.png"))
gt_paths = sorted(glob.glob("example_images/splicing-??-gt.png"))

# zip() silently stops at the shorter list, and a single missing file would shift every
# later pairing by one. Verify that each mask really belongs to its image so a bad pairing
# fails loudly here instead of quietly corrupting the metrics computed later.
mismatched_pairs = [
    (image_path, gt_path)
    for image_path, gt_path in zip(image_paths, gt_paths)
    if gt_path[: -len("-gt.png")] != image_path[: -len(".png")]
]
if len(image_paths) != len(gt_paths) or mismatched_pairs:
    raise ValueError(
        f"Images and ground-truth masks do not pair up: {len(image_paths)} image(s), "
        f"{len(gt_paths)} mask(s), mismatched pairs: {mismatched_pairs}"
    )

image_vs_gt_paths = list(zip(image_paths, gt_paths))

In [None]:
# Per-image results, kept around for the visualisation cell further down.
imgs, gts = [], []
img_preds, loc_preds = [], []

# Gradients are never needed for evaluation, so autograd tracking is disabled throughout.
with torch.no_grad():
    # Pixel-level metrics, accumulated over every image and read out after the loop.
    f1 = BinaryF1Score()
    mAP = BinaryAveragePrecision()

    for image_path, gt_path in tqdm(image_vs_gt_paths):
        # RGB image as a float tensor scaled by 1/255.
        rgb = Image.open(image_path).convert("RGB")
        image = pil_to_tensor(rgb).float() / 255

        # Grayscale mask -> binary labels: values below 0.9 become 1, everything else 0.
        # NOTE(review): presumably the darker pixels mark the spliced region — confirm
        # against the dataset's mask convention.
        mask = pil_to_tensor(Image.open(gt_path).convert("L")).float() / 255
        gt = (mask < 0.9).int()

        # The model is fed a batch of one; index 0 pulls that single result back out and
        # .cpu() moves it to host memory for the metrics and for plotting.
        batch = image.unsqueeze(0).to(device)
        img_pred, loc_pred = model(batch)
        img_pred = img_pred[0].cpu()
        loc_pred = loc_pred[0].cpu()

        # A leading axis is re-added to the localisation map before it is compared with
        # the mask. NOTE(review): assumes this makes its shape match `gt` — confirm.
        loc_pred_batched = loc_pred[None, ...]
        f1.update(loc_pred_batched, gt)
        mAP.update(loc_pred_batched, gt)

        imgs.append(image)
        gts.append(gt)
        img_preds.append(img_pred)
        loc_preds.append(loc_pred)

In [None]:
# Read out the metrics accumulated over all images as plain Python numbers.
# Ending the cell with the tuple makes the notebook display (F1, average precision).
f1_value = f1.compute().item()
ap_value = mAP.compute().item()
f1_value, ap_value

In [None]:
# Visualise the results. Each image takes two adjacent panels: the input on the left and,
# on the right, its ground-truth mask with the predicted localisation map drawn
# semi-transparently on top.
images_per_row = 4
col = images_per_row * 2
# Ceiling division over the images actually collected; keep at least one row because
# plt.subplots rejects a grid with zero rows.
row = max(1, -(-len(imgs) // images_per_row))

# squeeze=False keeps `axs` two-dimensional even when there is only one row. With the
# default squeezing a one-row grid comes back as a 1-D array, and row/column indexing
# then fails for four or fewer images.
fig, axs = plt.subplots(row, col, squeeze=False)
fig.set_size_inches(3 * col, 3 * row)

for i, (img, gt, loc_pred) in enumerate(zip(imgs, gts, loc_preds)):
    r, c = divmod(i, images_per_row)

    # Left panel: the input image.
    axs[r, 2 * c].imshow(to_pil_image(img))

    # Right panel: ground-truth mask underneath, prediction overlaid at 50% opacity.
    overlay_ax = axs[r, 2 * c + 1]
    overlay_ax.imshow(to_pil_image(gt.float()))
    overlay_ax.imshow(loc_pred, alpha=0.5, cmap="coolwarm")

# Hide ticks and frames on every panel, including unused ones in a partly filled last row.
for ax in axs.flat:
    ax.axis("off")