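"""Utilities for wrapping an image encoder with a linear classification head,
evaluating it on an image-classification dataset, and scheduling learning rates
with linear warmup followed by cosine decay."""
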
import math
import time

import numpy as np
import torch


class ModelWrapper(torch.nn.Module):
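    """Wrap an image encoder that exposes ``encode_image`` (e.g., a CLIP-style
    model) with a linear classification head, optionally L2-normalizing the
    features before classification."""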

    def __init__(self, model, feature_dim, num_classes, normalize=False,
                 initial_weights=None, checkpoint_path=None):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize

        if initial_weights is None:
            initial_weights = torch.zeros_like(self.classification_head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))

        self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
        self.classification_head.bias = torch.nn.Parameter(
            torch.zeros_like(self.classification_head.bias))

        # Drop the text transformer if present; only the image encoder is used here.
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')

        if checkpoint_path:
            print("Loading checkpoint", checkpoint_path)
            checkpoint = torch.load(checkpoint_path)
            # Discard any saved head weights so the freshly initialized
            # classification head above is kept.
            checkpoint.pop('classification_head.weight', None)
            checkpoint.pop('classification_head.bias', None)
            self.model.load_state_dict(checkpoint, strict=False)

    def forward(self, images, return_features=False):
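        """Return classification logits; if ``return_features`` is True, also
        return the (optionally normalized) image features."""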
        features = self.model.encode_image(images)
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        if return_features:
            return logits, features
        return logits


def get_model_from_sd(state_dict, base_model):
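    """Build a CUDA ``DataParallel`` ModelWrapper around ``base_model`` with the
    head shape inferred from ``state_dict``, then load ``state_dict`` into it."""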
    feature_dim = state_dict['classification_head.weight'].shape[1]
    num_classes = state_dict['classification_head.weight'].shape[0]
    model = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    for p in model.parameters():
        p.data = p.data.float()
    model.load_state_dict(state_dict)
    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    return torch.nn.DataParallel(model, device_ids=devices)


def maybe_dictionarize_batch(batch):
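    """Return ``batch`` as a dict with 'images', 'labels', and (optionally)
    'metadata' keys, accepting either a dict or a 2- or 3-element sequence."""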
    if isinstance(batch, dict):
        return batch
    if len(batch) == 2:
        return {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        return {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')


def test_model_on_dataset(model, dataset):
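    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy.

    Uses ``dataset.test_loader``, except for ImageNet2p, whose held-out split is
    served by ``dataset.train_loader``."""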
    model.eval()
    device = 'cuda'
    with torch.no_grad():
        top1, correct, n = 0., 0., 0.
        end = time.time()
        loader = dataset.test_loader
        if type(dataset).__name__ == 'ImageNet2p':
            loader = dataset.train_loader
            # Sanity check that the held-out ImageNet2p split uses the expected
            # sampler ordering.
            assert dataset.train_dataset.__getitem__(dataset.sampler.indices[1000])['image_paths'].endswith('n01675722_4108.JPEG')

        for i, batch in enumerate(loader):
            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].cuda(), batch['labels'].cuda()
            data_time = time.time() - end
            y = labels
            image_paths = batch['image_paths'] if 'image_paths' in batch else None

            logits = model(inputs)

            projection_fn = getattr(dataset, 'project_logits', None)
            if projection_fn is not None:
                logits = projection_fn(logits, device)

            if hasattr(dataset, 'project_labels'):
                y = dataset.project_labels(y, device)
            if isinstance(logits, list):
                logits = logits[0]

            pred = logits.argmax(dim=1, keepdim=True).to(device)
            if hasattr(dataset, 'accuracy'):
                acc1, num_total = dataset.accuracy(logits, y, image_paths, None)
                correct += acc1
                n += num_total
            else:
                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            batch_time = time.time() - end
            end = time.time()
            if i % 20 == 0:
                percent_complete = 100.0 * i / len(loader)
                print(
                    f"[{percent_complete:.0f}% {i}/{len(loader)}]\t"
                    f"Acc: {100 * (correct/n):.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                )

        top1 = correct / n
    return top1


def assign_learning_rate(param_group, new_lr):
    param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    # Linear warmup from base_lr / warmup_length up to base_lr.
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
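    """Return a ``step -> None`` adjuster that linearly warms each parameter
    group up for ``warmup_length`` steps, then follows a cosine decay over the
    remaining ``steps - warmup_length`` steps."""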
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)
    return _lr_adjuster
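

# Usage sketch for cosine_lr (the optimizer, loader, and epoch count below are
# placeholders, not part of this module):
#
#     num_steps = num_epochs * len(train_loader)
#     scheduler = cosine_lr(optimizer, base_lrs=3e-5, warmup_length=500, steps=num_steps)
#     for epoch in range(num_epochs):
#         for i, batch in enumerate(train_loader):
#             step = i + epoch * len(train_loader)
#             scheduler(step)
#             ...  # forward, backward, optimizer.step()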