{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Splicing localization demo\n",
    "\n",
    "Loads the `ductai199x/forensic-similarity-graph` model, runs it on every image matched by `example_images/splicing-??.png`, scores the model's second output (`loc_pred`) against the matching `-gt.png` mask with binary F1 and average precision, and plots each image next to its mask with the prediction overlaid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from transformers import AutoModel\n",
    "from torchvision.transforms.functional import to_pil_image, pil_to_tensor\n",
    "from torchmetrics.classification import BinaryF1Score, BinaryAveragePrecision\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: trust_remote_code=True runs Python code downloaded from the model repository;\n",
    "# only keep it enabled for a repository you trust.\n",
    "model = AutoModel.from_pretrained(\"ductai199x/forensic-similarity-graph\", trust_remote_code=True)\n",
    "# eval() switches the model to inference mode before it is moved to the chosen device.\n",
    "model = model.eval().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorting both listings pairs each image with its mask by position.\n",
    "image_paths = sorted(glob.glob(\"example_images/splicing-??.png\"))\n",
    "gt_paths = sorted(glob.glob(\"example_images/splicing-??-gt.png\"))\n",
    "\n",
    "# zip() stops at the shorter list, so a missing image or mask would silently drop or\n",
    "# mispair samples; fail loudly instead.\n",
    "assert image_paths, \"no images matched example_images/splicing-??.png\"\n",
    "assert len(image_paths) == len(gt_paths), f\"{len(image_paths)} images but {len(gt_paths)} ground-truth masks\"\n",
    "image_vs_gt_paths = list(zip(image_paths, gt_paths))\n",
    "assert all(\n",
    "    mask_name == image_name[: -len(\".png\")] + \"-gt.png\" for image_name, mask_name in image_vs_gt_paths\n",
    "), \"image and ground-truth file names do not pair up\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask pixels darker than this (on a 0-1 scale) become the positive class.\n",
    "GT_THRESHOLD = 0.9\n",
    "\n",
    "imgs, gts, img_preds, loc_preds = [], [], [], []\n",
    "f1, mAP = BinaryF1Score(), BinaryAveragePrecision()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for image_path, gt_path in tqdm(image_vs_gt_paths):\n",
    "        # Context managers close the file handles that a bare Image.open() would leave open.\n",
    "        with Image.open(image_path) as image_file:\n",
    "            image = pil_to_tensor(image_file.convert(\"RGB\")).float() / 255\n",
    "        with Image.open(gt_path) as gt_file:\n",
    "            gt = ((pil_to_tensor(gt_file.convert(\"L\")).float() / 255) < GT_THRESHOLD).int()\n",
    "\n",
    "        # The model is fed a batch of one; [0] picks that single sample's outputs.\n",
    "        img_pred, loc_pred = model(image.unsqueeze(0).to(device))\n",
    "        img_pred, loc_pred = img_pred[0].cpu(), loc_pred[0].cpu()\n",
    "\n",
    "        # loc_pred gets a leading axis before scoring, presumably to line up with the\n",
    "        # (1, H, W) mask produced by pil_to_tensor. TODO confirm the model's output shape.\n",
    "        f1.update(loc_pred[None, ...], gt)\n",
    "        mAP.update(loc_pred[None, ...], gt)\n",
    "\n",
    "        img_preds.append(img_pred)\n",
    "        loc_preds.append(loc_pred)\n",
    "        imgs.append(image)\n",
    "        gts.append(gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (F1, average precision) accumulated over every image processed above.\n",
    "f1.compute().item(), mAP.compute().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGES_PER_ROW = 4\n",
    "\n",
    "# Each sample takes two panels: the image, then the mask with the prediction on top.\n",
    "col = IMAGES_PER_ROW * 2\n",
    "row = -(-len(image_vs_gt_paths) // IMAGES_PER_ROW)  # ceiling division\n",
    "\n",
    "# squeeze=False keeps axs two-dimensional even for a single row; without it a run with\n",
    "# four or fewer images yields a 1-D array and the axs[r][c] lookups below fail.\n",
    "fig, axs = plt.subplots(row, col, squeeze=False)\n",
    "fig.set_size_inches(3 * col, 3 * row)\n",
    "\n",
    "for i, (img, gt, loc_pred) in enumerate(zip(imgs, gts, loc_preds)):\n",
    "    grid_row, slot = divmod(i, IMAGES_PER_ROW)\n",
    "    axs[grid_row][slot * 2].imshow(to_pil_image(img))\n",
    "    overlay_ax = axs[grid_row][slot * 2 + 1]\n",
    "    overlay_ax.imshow(to_pil_image(gt.float()))\n",
    "    overlay_ax.imshow(loc_pred, alpha=0.5, cmap=\"coolwarm\")\n",
    "\n",
    "# Hide the axes everywhere, including panels left empty in the last row.\n",
    "for ax in axs.flat:\n",
    "    ax.axis(\"off\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyt_tf2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}