"""Fine-tune a CLIP model with a linear classification head to distinguish
human-made photos from AI-generated ones."""

import argparse
import os
import time

import torch
from torchvision import transforms, datasets
from tqdm import tqdm

import clip

from utils import ModelWrapper, maybe_dictionarize_batch, cosine_lr
from zeroshot import zeroshot_classifier
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser('~/data'),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--model-location",
        type=str,
        default=os.path.expanduser('~/ssd/checkpoints/soups'),
        help="Where to save the model checkpoints.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
        help="Batch size for training and evaluation.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=8,
        help="Number of dataloader worker processes.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=8,
        help="Number of fine-tuning epochs.",
    )
    parser.add_argument(
        "--warmup-length",
        type=int,
        default=500,
        help="Number of warmup steps for the cosine learning-rate schedule.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-5,
        help="Peak learning rate.",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.1,
        help="Weight decay for AdamW.",
    )
    parser.add_argument(
        "--model",
        default='ViT-B/32',
        help='Model to use -- you can try another, like ViT-L/14.',
    )
    parser.add_argument(
        "--name",
        default='finetune_cp',
        help='Filename for the checkpoints.',
    )
    parser.add_argument(
        "--timm-aug",
        action="store_true",
        default=False,
        help="Use timm-style augmentation (currently unused in this script).",
    )
    parser.add_argument(
        "--checkpoint-path",
        default=None,
        help='Optional checkpoint to initialize the model from.',
    )
    return parser.parse_args()
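# Example invocation (the script filename is a placeholder):
#   python finetune.py --data-location ~/data --model ViT-B/32 --epochs 8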

if __name__ == '__main__':
    args = parse_arguments()
    DEVICE = 'cuda'

    # Prompt template used to build the zero-shot classification head.
    template = [lambda x: f"a photo generated by {x}."]

    base_model, preprocess = clip.load(args.model, DEVICE, jit=False)
    train_transforms = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # Evaluation should be deterministic, so resize and center-crop
    # instead of applying random augmentations at test time.
    test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    train_data = datasets.ImageFolder(os.path.join(args.data_location, 'train'), transform=train_transforms)
    test_data = datasets.ImageFolder(os.path.join(args.data_location, 'test'), transform=test_transforms)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                               num_workers=args.workers, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size,
                                              num_workers=args.workers)
    # Initialize the classification head from CLIP's zero-shot text embeddings.
    clf = zeroshot_classifier(base_model, ['humans', 'AI'], template, DEVICE)
    NUM_CLASSES = 2
    feature_dim = base_model.visual.output_dim

    model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True,
                         initial_weights=clf, checkpoint_path=args.checkpoint_path)
    # CLIP weights load in fp16 on CUDA; cast to fp32 for stable fine-tuning.
    for p in model.parameters():
        p.data = p.data.float()

    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model, device_ids=devices)

    model_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)

    num_batches = len(train_loader)
    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)

    loss_fn = torch.nn.CrossEntropyLoss()

    # Save the zero-shot initialization before any fine-tuning.
    model_path = os.path.join(args.model_location, f'{args.name}.pt')
    print('Saving model to', model_path)
    torch.save(model.module.state_dict(), model_path)
    best_accuracy = 0.0

    for epoch in range(args.epochs):
        model.train()
        end = time.time()
        for i, batch in enumerate(train_loader):
            step = i + epoch * num_batches
            scheduler(step)
            optimizer.zero_grad()
            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
            data_time = time.time() - end

            logits = model(inputs)
            loss = loss_fn(logits, labels)

            loss.backward()

            # Clip gradient norm to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            if i % 20 == 0:
                percent_complete = 100.0 * i / num_batches
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{num_batches}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                    flush=True,
                )
        # Evaluate on the held-out split after each epoch.
        model.eval()

        with torch.no_grad():
            print('*' * 80)
            print('Starting eval')
            correct, count = 0.0, 0.0
            pbar = tqdm(test_loader)
            for batch in pbar:
                batch = maybe_dictionarize_batch(batch)
                inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)

                logits = model(inputs)
                loss = loss_fn(logits, labels)

                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()
                count += len(logits)
                pbar.set_description(
                    f"Val loss: {loss.item():.4f} Acc: {100 * correct / count:.2f}")
            top1 = correct / count
        print(f'Val acc at epoch {epoch}: {100 * top1:.2f}')

        # Keep only the best-performing checkpoint seen so far.
        curr_acc = 100 * top1
        if curr_acc > best_accuracy:
            print(f'Current acc: {curr_acc}, Best acc: {best_accuracy}')
            best_accuracy = curr_acc
            model_path = os.path.join(args.model_location, f'{args.name}.pt')
            print('Saving model to', model_path)
            torch.save(model.module.state_dict(), model_path)
        else:
            print('Not saving the model')
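# The best checkpoint can later be loaded back for further fine-tuning or
# evaluation via --checkpoint-path, e.g.:
#   python finetune.py --checkpoint-path ~/ssd/checkpoints/soups/finetune_cp.pt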