# model_soups/utils.py
# Source: SaraPieri, commit 626ec32 ("First soup!"), 5.46 kB
import torch
import math
import time
import numpy as np
class ModelWrapper(torch.nn.Module):
    """Backbone image encoder followed by a linear classification head.

    Args:
        model: Backbone module. ``forward`` only calls its ``encode_image``
            method. If it has a ``transformer`` attribute, that attribute is
            deleted.
        feature_dim: Size of the last dimension of ``model.encode_image(...)``.
        num_classes: Number of logits produced by the head.
        normalize: If True, L2-normalize features along the last dimension
            before the head.
        initial_weights: Optional head weight of shape
            ``(num_classes, feature_dim)``; it is cloned before use. When
            omitted, a Kaiming-uniform init with ``a=sqrt(5)`` is used. The
            head bias always starts at zero.
        checkpoint_path: Optional path to a state dict saved from a
            ``ModelWrapper`` (it must contain ``classification_head.weight``
            and ``classification_head.bias``). Its head entries are discarded,
            so ``num_classes`` may differ; the remaining weights are loaded
            non-strictly into this wrapper.

    Raises:
        KeyError: if the checkpoint lacks the ``classification_head.*`` keys.
    """

    def __init__(self, model, feature_dim, num_classes, normalize=False,
                 initial_weights=None, checkpoint_path=None):
        super().__init__()
        self.model = model
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize
        if initial_weights is None:
            initial_weights = torch.zeros_like(self.classification_head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
        self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
        self.classification_head.bias = torch.nn.Parameter(
            torch.zeros_like(self.classification_head.bias))

        # Note: modified. Get rid of the language part (the backbone's
        # `transformer` submodule); forward() only uses encode_image.
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')

        if checkpoint_path:
            print("Loading checkpoint", checkpoint_path)
            # Load onto CPU so a checkpoint saved on GPU also opens on other
            # devices; load_state_dict copies the values into our parameters.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # Drop the head so a checkpoint with a different class count loads.
            checkpoint.pop('classification_head.weight')
            checkpoint.pop('classification_head.bias')
            # The checkpoint is wrapper-level (it carried classification_head.*),
            # so its backbone keys are named relative to this wrapper and must
            # be loaded through `self`, not the inner backbone. strict=False
            # tolerates the removed head and the deleted `transformer` weights.
            self.load_state_dict(checkpoint, strict=False)

    def forward(self, images, return_features=False):
        """Return head logits for ``images``.

        Features are ``self.model.encode_image(images)``, L2-normalized along
        the last dimension when ``self.normalize`` is set. With
        ``return_features=True`` returns ``(logits, features)``, where
        ``features`` are exactly what was fed to the head.
        """
        features = self.model.encode_image(images)
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        if return_features:
            return logits, features
        return logits
def get_model_from_sd(state_dict, base_model):
    """Build a ``ModelWrapper`` from ``state_dict`` and wrap it in DataParallel.

    The head size is taken from ``state_dict['classification_head.weight']``:
    dimension 0 is the class count and dimension 1 the feature size. Features
    are normalized (``normalize=True``). All parameters are cast to float32
    before the state dict is loaded. The model is moved to CUDA and wrapped in
    ``torch.nn.DataParallel`` over every visible GPU.
    """
    head_weight = state_dict['classification_head.weight']
    feature_dim = head_weight.shape[1]
    num_classes = head_weight.shape[0]

    wrapper = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    # Cast first: load_state_dict copies into existing tensors, keeping their dtype.
    for param in wrapper.parameters():
        param.data = param.data.float()
    wrapper.load_state_dict(state_dict)

    wrapper = wrapper.cuda()
    gpu_ids = list(range(torch.cuda.device_count()))
    return torch.nn.DataParallel(wrapper, device_ids=gpu_ids)
def maybe_dictionarize_batch(batch):
    """Return ``batch`` as a dict keyed by field name.

    A dict is returned as-is (the same object). Otherwise ``batch`` is indexed
    positionally: two elements map to ``images``/``labels``; three elements
    add ``metadata``.

    Raises:
        ValueError: if a non-dict batch has a length other than 2 or 3.
    """
    if isinstance(batch, dict):
        return batch
    if len(batch) not in (2, 3):
        raise ValueError(f'Unexpected number of elements: {len(batch)}')
    field_names = ('images', 'labels', 'metadata')[:len(batch)]
    return {name: batch[position] for position, name in enumerate(field_names)}
def test_model_on_dataset(model, dataset, device='cuda'):
    """Evaluate ``model`` on ``dataset`` and return ``correct / n``.

    Iterates ``dataset.test_loader`` (or ``dataset.train_loader`` for an
    ``ImageNet2p`` dataset, whose held-out minival split lives there) under
    ``torch.no_grad()``. Progress is printed every 20 batches. Optional dataset
    hooks are used when present:

    * ``project_logits(logits, device)`` remaps logits before scoring.
    * ``project_labels(labels, device)`` remaps labels before scoring.
    * ``accuracy(logits, labels, image_paths, None)`` returns a per-batch
      ``(correct_count, total)`` pair. Without it, top-1 argmax accuracy is
      used.

    Args:
        model: Module mapping an image batch to logits, or to a list whose
            first element is the logits. It is put in eval mode.
        dataset: Object exposing the loaders and optional hooks above.
        device: Device that inputs and labels are moved to and that is
            passed to the projection hooks. Defaults to ``'cuda'`` (the
            previously hard-coded value).

    Returns:
        Accumulated ``correct / n``.

    Raises:
        ZeroDivisionError: if nothing was counted (e.g. an empty loader).
    """
    model.eval()
    correct, n = 0., 0.
    end = time.time()

    loader = dataset.test_loader
    if type(dataset).__name__ == 'ImageNet2p':
        loader = dataset.train_loader
        # assert to make sure the imagenet held-out minival logic is consistent across machines.
        # tested on a few machines but if this fails for you please submit an issue and we will resolve.
        assert dataset.train_dataset[dataset.sampler.indices[1000]]['image_paths'].endswith('n01675722_4108.JPEG')

    # The dataset's optional hooks do not change during evaluation; look them up once.
    projection_fn = getattr(dataset, 'project_logits', None)
    has_label_projection = hasattr(dataset, 'project_labels')
    has_custom_accuracy = hasattr(dataset, 'accuracy')

    with torch.no_grad():
        for i, batch in enumerate(loader):
            batch = maybe_dictionarize_batch(batch)
            inputs = batch['images'].to(device)
            y = batch['labels'].to(device)
            data_time = time.time() - end
            # Read per batch: a batch without paths gets None instead of a stale
            # value from an earlier batch (or an UnboundLocalError on the first).
            image_paths = batch.get('image_paths')

            logits = model(inputs)
            if projection_fn is not None:
                logits = projection_fn(logits, device)
            if has_label_projection:
                y = dataset.project_labels(y, device)
            if isinstance(logits, list):
                logits = logits[0]

            if has_custom_accuracy:
                acc1, num_total = dataset.accuracy(logits, y, image_paths, None)
                correct += acc1
                n += num_total
            else:
                pred = logits.argmax(dim=1, keepdim=True).to(device)
                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            batch_time = time.time() - end
            end = time.time()
            if i % 20 == 0:
                percent_complete = 100.0 * i / len(loader)
                print(
                    f"[{percent_complete:.0f}% {i}/{len(loader)}]\t"
                    f"Acc: {100 * (correct/n):.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                )

    return correct / n
def assign_learning_rate(param_group, new_lr):
    """Set the ``"lr"`` entry of a single optimizer param group in place."""
    param_group.update(lr=new_lr)
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Create a step -> LR setter with linear warmup and then cosine decay.

    The returned function takes an integer ``step`` and, for each param
    group, writes an ``"lr"`` in place (it returns None):

    * ``step < warmup_length``: linear warmup via ``_warmup_lr``.
    * otherwise: ``0.5 * (1 + cos(pi * e / es)) * base_lr`` with
      ``e = step - warmup_length`` and ``es = steps - warmup_length``, which
      reaches 0 at ``step == steps``.

    Args:
        optimizer: Object whose ``param_groups`` is a list of dicts.
        base_lrs: A list with one base LR per param group, or any non-list
            value shared by every group.
        warmup_length: Number of warmup steps.
        steps: Total number of scheduled steps.
    """
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs] * len(optimizer.param_groups)
    assert len(base_lrs) == len(optimizer.param_groups)

    def _cosine_phase_lr(base_lr, step):
        # Progress through the post-warmup phase, mapped onto half a cosine.
        elapsed = step - warmup_length
        decay_steps = steps - warmup_length
        return 0.5 * (1 + np.cos(np.pi * elapsed / decay_steps)) * base_lr

    def _lr_adjuster(step):
        for group, group_base_lr in zip(optimizer.param_groups, base_lrs):
            new_lr = (
                _warmup_lr(group_base_lr, warmup_length, step)
                if step < warmup_length
                else _cosine_phase_lr(group_base_lr, step)
            )
            assign_learning_rate(group, new_lr)

    return _lr_adjuster