|
import torch.nn as nn |
|
from torchvision import models |
|
|
|
|
|
def get_head(out_size, cfg):
    """Build the projection head g() described by ``cfg``.

    The head is ``cfg.head_layers - 1`` hidden stages followed by one output
    layer. Each hidden stage is a ``Linear`` layer of width ``cfg.head_size``,
    an optional ``BatchNorm1d`` (when ``cfg.add_bn`` is truthy) and a ``ReLU``.
    The output layer is a ``Linear`` mapping to ``cfg.emb`` features.

    Args:
        out_size: Number of input features of the first ``Linear`` layer.
        cfg: Object exposing ``head_layers``, ``head_size``, ``add_bn`` and
            ``emb`` attributes.

    Returns:
        An ``nn.Sequential`` holding the layers in order.
    """
    layers = []
    width = out_size
    hidden_count = cfg.head_layers - 1
    for _ in range(hidden_count):
        stage = [nn.Linear(width, cfg.head_size)]
        if cfg.add_bn:
            stage.append(nn.BatchNorm1d(cfg.head_size))
        stage.append(nn.ReLU())
        layers.extend(stage)
        # Every stage after the first consumes the hidden width.
        width = cfg.head_size
    layers.append(nn.Linear(width, cfg.emb))
    return nn.Sequential(*layers)
|
|
|
|
|
def get_model(arch, dataset):
    """Create encoder E() by name and adapt it to ``dataset``.

    Args:
        arch: Name of a model constructor looked up on ``torchvision.models``
            with ``getattr``. The resulting model must expose an ``fc`` layer
            with an ``in_features`` attribute.
        dataset: Dataset name. Any value other than ``"imagenet"`` swaps the
            first convolution; ``"cifar10"`` and ``"cifar100"`` additionally
            remove the max-pooling layer.

    Returns:
        A tuple ``(model, out_size)``: the encoder wrapped in
        ``nn.DataParallel`` with its ``fc`` layer replaced by ``nn.Identity``,
        and the number of input features that ``fc`` layer had.

    Raises:
        AttributeError: If ``arch`` is not an attribute of
            ``torchvision.models`` or the model has no ``fc`` layer.
    """
    # Build the network without pretrained weights. The constructor default
    # already does this, and omitting the legacy ``pretrained`` keyword avoids
    # the deprecation warning emitted by newer torchvision releases.
    model = getattr(models, arch)()

    if dataset != "imagenet":
        # Replace the first convolution with a 3x3, stride-1 layer
        # (presumably for smaller input images; confirm against the
        # data pipeline).
        model.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
    if dataset in ("cifar10", "cifar100"):
        # Turn the max-pooling layer into a no-op for CIFAR.
        model.maxpool = nn.Identity()

    # Record the feature width before removing the classification layer so
    # the encoder outputs features rather than class scores.
    out_size = model.fc.in_features
    model.fc = nn.Identity()

    return nn.DataParallel(model), out_size
|
|