In [1]:
from datasets import load_dataset

data = load_dataset("dataset")

  from .autonotebook import tqdm as notebook_tqdm
Resolving data files: 100%|██████████| 25/25 [00:00<00:00, 203606.99it/s]
Resolving data files: 100%|██████████| 26/26 [00:00<00:00, 203076.17it/s]


In [2]:
# !pip install Pillow

In [3]:
data['train'][0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x2048>,
 'label': 0}

In [4]:
labels = data["train"].features["label"].names
labels

['ai_gen', 'human']

In [5]:
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [6]:
label2id

{'ai_gen': '0', 'human': '1'}

In [7]:
from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [8]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [9]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

In [10]:
data = data.with_transform(transforms)

In [11]:
data

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 18000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 20715
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 13354
    })
})

In [12]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [13]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    "umm-maybe/AI-image-detector",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

  return self.fget.__get__(instance, owner)()


In [14]:
import evaluate

accuracy = evaluate.load("accuracy")

import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [15]:
training_args = TrainingArguments(
    output_dir="ai_detector",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
0,0.0347,0.013469,0.99648
1,0.0009,0.02113,0.994234
2,0.0241,0.010735,0.997529


Non-default generation parameters: {'max_length': 128}
Non-default generation parameters: {'max_length': 128}
Non-default generation parameters: {'max_length': 128}


TrainOutput(global_step=843, training_loss=0.034790725539037115, metrics={'train_runtime': 2594.7053, 'train_samples_per_second': 20.812, 'train_steps_per_second': 0.325, 'total_flos': 4.2268994172435825e+18, 'train_loss': 0.034790725539037115, 'epoch': 3.0})