SaraPieri committed
Commit 626ec32 · Parent(s): 1f88702

First soup!

Files changed:
- __pycache__/utils.cpython-36.pyc +0 -0
- __pycache__/zeroshot.cpython-36.pyc +0 -0
- environment.yml +81 -0
- figure.png +0 -0
- finetune.py +196 -0
- helper.bash +1 -0
- utils.py +140 -0
- zeroshot.py +23 -0
__pycache__/utils.cpython-36.pyc
ADDED
Binary file (4.76 kB)
__pycache__/zeroshot.cpython-36.pyc
ADDED
Binary file (974 Bytes)
environment.yml
ADDED
@@ -0,0 +1,81 @@
name: model_soups
channels:
  - pytorch
  - defaults
dependencies:
  - astroid=2.6.6=py36h06a4308_0
  - blas=1.0=mkl
  - ca-certificates=2022.4.26=h06a4308_0
  - certifi=2021.5.30=py36h06a4308_0
  - cudatoolkit=11.0.221=h6bb024c_0
  - dataclasses=0.8=pyh4f3eec9_6
  - freetype=2.11.0=h70c0345_0
  - intel-openmp=2022.0.1=h06a4308_3633
  - isort=5.9.3=pyhd3eb1b0_0
  - jpeg=9b=h024ee3a_2
  - lazy-object-proxy=1.6.0=py36h27cfd23_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1234567_1
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtiff=4.2.0=h85742a9_0
  - libuv=1.40.0=h7b6447c_0
  - libwebp-base=1.2.2=h7f8727e_0
  - lz4-c=1.9.3=h295c915_1
  - mccabe=0.6.1=py36_1
  - mkl=2020.2=256
  - mkl-service=2.3.0=py36he8ac12f_0
  - mkl_fft=1.3.0=py36h54f3939_0
  - mkl_random=1.1.1=py36h0573a6f_0
  - ncurses=6.3=h7f8727e_2
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - numpy=1.19.2=py36h54aff64_0
  - numpy-base=1.19.2=py36hfa32c7d_0
  - olefile=0.46=pyhd3eb1b0_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=1.1.1o=h7f8727e_0
  - pillow=8.3.1=py36h2c7a002_0
  - pip=21.2.2=py36h06a4308_0
  - pylint=2.9.6=py36h06a4308_1
  - python=3.6.13=h12debd9_1
  - pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0
  - readline=8.1.2=h7f8727e_1
  - setuptools=58.0.4=py36h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.38.3=hc218d9a_0
  - tk=8.6.11=h1ccaba5_1
  - toml=0.10.2=pyhd3eb1b0_0
  - torchvision=0.8.2=py36_cu110
  - typed-ast=1.4.3=py36h7f8727e_1
  - typing-extensions=4.1.1=hd3eb1b0_0
  - typing_extensions=4.1.1=pyh06a4308_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - wrapt=1.12.1=py36h7b6447c_1
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - zstd=1.4.9=haebb681_0
  - pip:
    - charset-normalizer==2.0.12
    - clip==0.1.0
    - cycler==0.11.0
    - ftfy==6.0.3
    - idna==3.3
    - importlib-resources==5.4.0
    - kiwisolver==1.3.1
    - matplotlib==3.3.4
    - pandas==1.1.5
    - pyparsing==3.0.9
    - python-dateutil==2.8.2
    - pytz==2022.1
    - regex==2022.4.24
    - requests==2.27.1
    - tqdm==4.64.0
    - urllib3==1.26.9
    - wcwidth==0.2.5
    - wget==3.2
    - zipp==3.6.0
prefix: /home/mitchnw/anaconda3/envs/model_soups
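This environment can typically be recreated with "conda env create -f environment.yml"; the prefix: line reflects the original author's machine and may need to be edited or dropped elsewhere. Below is a minimal sanity check (not part of the commit) that the pinned stack resolves, assuming the environment is created and activated:

# Quick import check for the packages finetune.py relies on.
import torch
import torchvision
import clip  # the CLIP package imported by finetune.py below

print(torch.__version__)          # expected from the pins: 1.7.1
print(torchvision.__version__)    # expected from the pins: 0.8.2
print(torch.cuda.is_available())  # True only if cudatoolkit 11.0 matches a visible GPU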
figure.png
ADDED
finetune.py
ADDED
@@ -0,0 +1,196 @@
import argparse
import os
import torch
import clip
from tqdm import tqdm
import time
from utils import ModelWrapper, maybe_dictionarize_batch, cosine_lr
from zeroshot import zeroshot_classifier
from torchvision import transforms, datasets

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser('~/data'),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--model-location",
        type=str,
        default=os.path.expanduser('~/ssd/checkpoints/soups'),
        help="Where to download the models.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--warmup-length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-5,
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--model",
        default='ViT-B/32',
        help='Model to use -- you can try another like ViT-L/14',
    )
    parser.add_argument(
        "--name",
        default='finetune_cp',
        help='Filename for the checkpoints.',
    )
    parser.add_argument(
        "--timm-aug", action="store_true", default=False,
    )
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        help='Checkpoint path to load the model.',
    )

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_arguments()
    DEVICE = 'cuda'

    template = [lambda x: f"a photo generated by {x}."]

    base_model, preprocess = clip.load(args.model, 'cuda', jit=False)

    train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                           transforms.RandomResizedCrop(224),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor()])

    test_transforms = transforms.Compose([transforms.RandomRotation(30),
                                          transforms.RandomResizedCrop(224),
                                          transforms.ToTensor()])

    train_data = datasets.ImageFolder(args.data_location + '/train', transform=train_transforms)
    test_data = datasets.ImageFolder(args.data_location + '/test', transform=test_transforms)

    train_dset = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True)
    test_dset = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, num_workers=args.workers)

    clf = zeroshot_classifier(base_model, ['humans', 'AI'], template, DEVICE)
    NUM_CLASSES = 2
    feature_dim = base_model.visual.output_dim

    model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True, initial_weights=clf, checkpoint_path=args.checkpoint_path)
    for p in model.parameters():
        p.data = p.data.float()

    model = model.cuda()
    devices = [x for x in range(torch.cuda.device_count())]
    model = torch.nn.DataParallel(model, device_ids=devices)

    model_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)

    num_batches = len(train_dset)
    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)

    loss_fn = torch.nn.CrossEntropyLoss()

    model_path = os.path.join(args.model_location, f'{args.name}.pt')
    print('Saving model to', model_path)
    torch.save(model.module.state_dict(), model_path)

    for epoch in range(args.epochs):
        # Train
        model.train()
        end = time.time()
        for i, batch in enumerate(train_dset):
            step = i + epoch * num_batches
            scheduler(step)
            optimizer.zero_grad()
            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
            data_time = time.time() - end

            logits = model(inputs)
            loss = loss_fn(logits, labels)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            if i % 20 == 0:
                percent_complete = 100.0 * i / len(train_dset)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(train_dset)}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
                )

        # Evaluate
        test_loader = test_dset
        model.eval()

        last_accuracy = 0.0

        with torch.no_grad():
            print('*' * 80)
            print('Starting eval')
            correct, count = 0.0, 0.0
            pbar = tqdm(test_loader)
            for batch in pbar:
                batch = maybe_dictionarize_batch(batch)
                inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)

                logits = model(inputs)

                loss = loss_fn(logits, labels)

                pred = logits.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()
                count += len(logits)
                pbar.set_description(
                    f"Val loss: {loss.item():.4f} Acc: {100*correct/count:.2f}")
            top1 = correct / count
            print(f'Val acc at epoch {epoch}: {100*top1:.2f}')

            curr_acc = 100 * top1
            if curr_acc > last_accuracy:
                print('Current acc: {}, Last acc: {}'.format(curr_acc, last_accuracy))
                last_accuracy = curr_acc
                model_path = os.path.join(args.model_location, f'{args.name}.pt')
                print('Saving model to', model_path)
                torch.save(model.module.state_dict(), model_path)
            else:
                print('Not saving the model')
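Not part of the commit, but a minimal sketch of how a checkpoint written by finetune.py could be reloaded for inference, assuming a CUDA device and the get_model_from_sd helper defined in utils.py below (the checkpoint path is illustrative):

import torch
import clip

from utils import get_model_from_sd

# Illustrative path: finetune.py saves to <model-location>/<name>.pt.
state_dict = torch.load('models/finetune_cp.pt', map_location='cpu')

base_model, preprocess = clip.load('ViT-B/32', 'cuda', jit=False)
model = get_model_from_sd(state_dict, base_model)  # ModelWrapper + DataParallel, on GPU
model.eval()

with torch.no_grad():
    # A real input would be an image passed through `preprocess` or the test transforms above.
    dummy = torch.zeros(1, 3, 224, 224).cuda()
    logits = model(dummy)
    print(logits.argmax(dim=1))  # predicted class index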
helper.bash
ADDED
@@ -0,0 +1 @@
python finetune.py --data-location /l/users/u21010238/data/AiorNot --model-location /home/sara.pieri/Documents/model-soups/models --batch-size 56 --name finetune_cp_AiorNot_model_0 --checkpoint_path /home/sara.pieri/Documents/model-soups/models/model_0.pt
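helper.bash launches a single fine-tuning run with machine-specific paths; building a soup would typically mean repeating this command with different --name, --lr, or --checkpoint_path values so that several checkpoints are available for averaging (see the sketch after utils.py below).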
utils.py
ADDED
@@ -0,0 +1,140 @@
import torch
import math
import time
import numpy as np

class ModelWrapper(torch.nn.Module):
    def __init__(self, model, feature_dim, num_classes, normalize=False, initial_weights=None, checkpoint_path=None):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize

        if initial_weights is None:
            initial_weights = torch.zeros_like(self.classification_head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))

        self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
        self.classification_head.bias = torch.nn.Parameter(torch.zeros_like(self.classification_head.bias))

        # Note: modified. Get rid of the language part.
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')

        if checkpoint_path:
            print("Loading checkpoint", checkpoint_path)
            checkpoint = torch.load(checkpoint_path)
            checkpoint.pop('classification_head.weight')
            checkpoint.pop('classification_head.bias')
            model.load_state_dict(checkpoint, strict=False)

    def forward(self, images, return_features=False):
        features = self.model.encode_image(images)
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        if return_features:
            return logits, features
        return logits

def get_model_from_sd(state_dict, base_model):
    feature_dim = state_dict['classification_head.weight'].shape[1]
    num_classes = state_dict['classification_head.weight'].shape[0]
    model = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    for p in model.parameters():
        p.data = p.data.float()
    model.load_state_dict(state_dict)
    model = model.cuda()
    devices = [x for x in range(torch.cuda.device_count())]
    return torch.nn.DataParallel(model, device_ids=devices)

def maybe_dictionarize_batch(batch):
    if isinstance(batch, dict):
        return batch
    if len(batch) == 2:
        return {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        return {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')

def test_model_on_dataset(model, dataset):

    model.eval()
    device = 'cuda'
    with torch.no_grad():
        top1, correct, n = 0., 0., 0.
        end = time.time()
        loader = dataset.test_loader
        if type(dataset).__name__ == 'ImageNet2p':
            loader = dataset.train_loader
            # assert to make sure the imagenet held-out minival logic is consistent across machines.
            # tested on a few machines but if this fails for you please submit an issue and we will resolve.
            assert dataset.train_dataset.__getitem__(dataset.sampler.indices[1000])['image_paths'].endswith('n01675722_4108.JPEG')

        for i, batch in enumerate(loader):
            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].cuda(), batch['labels'].cuda()
            data_time = time.time() - end
            y = labels
            if 'image_paths' in batch:
                image_paths = batch['image_paths']

            logits = model(inputs)

            projection_fn = getattr(dataset, 'project_logits', None)
            if projection_fn is not None:
                logits = projection_fn(logits, device)

            if hasattr(dataset, 'project_labels'):
                y = dataset.project_labels(y, device)
            if isinstance(logits, list):
                logits = logits[0]

            pred = logits.argmax(dim=1, keepdim=True).to(device)
            if hasattr(dataset, 'accuracy'):
                acc1, num_total = dataset.accuracy(logits, y, image_paths, None)
                correct += acc1
                n += num_total
            else:
                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            batch_time = time.time() - end
            end = time.time()
            if i % 20 == 0:
                percent_complete = 100.0 * i / len(loader)
                print(
                    f"[{percent_complete:.0f}% {i}/{len(loader)}]\t"
                    f"Acc: {100 * (correct/n):.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                )

        top1 = correct / n
    return top1

def assign_learning_rate(param_group, new_lr):
    param_group["lr"] = new_lr

def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length

def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)
    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)
    return _lr_adjuster
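Not part of this commit, but a minimal sketch of the uniform-soup step the repository is building towards: average the state dicts of several fine-tuned checkpoints and load the result with get_model_from_sd. The checkpoint paths are illustrative, each assumed to come from a separate finetune.py run:

import torch
import clip

from utils import get_model_from_sd

# Illustrative ingredient checkpoints produced by repeated finetune.py runs.
ingredient_paths = [
    'models/finetune_cp_AiorNot_model_0.pt',
    'models/finetune_cp_AiorNot_model_1.pt',
]

# Uniform soup: element-wise average of all parameters.
soup = None
for path in ingredient_paths:
    state_dict = torch.load(path, map_location='cpu')
    if soup is None:
        soup = {k: v.clone().float() for k, v in state_dict.items()}
    else:
        for k, v in state_dict.items():
            soup[k] += v.float()
soup = {k: v / len(ingredient_paths) for k, v in soup.items()}

base_model, _ = clip.load('ViT-B/32', 'cuda', jit=False)
soup_model = get_model_from_sd(soup, base_model)  # ready for evaluation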
zeroshot.py
ADDED
@@ -0,0 +1,23 @@
import torch
import clip
from tqdm import tqdm

def zeroshot_classifier(model, classnames, templates, device):
    print('Building zero-shot classifier.')
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return 100 * zeroshot_weights.t()
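For reference, and mirroring how finetune.py uses this function above, the returned (num_classes x feature_dim) matrix can seed the classification head in ModelWrapper; a minimal sketch, assuming a CUDA device:

import clip

from utils import ModelWrapper
from zeroshot import zeroshot_classifier

base_model, preprocess = clip.load('ViT-B/32', 'cuda', jit=False)
template = [lambda c: f"a photo generated by {c}."]

# Zero-shot text embeddings become the initial weights of the linear head.
clf = zeroshot_classifier(base_model, ['humans', 'AI'], template, 'cuda')
model = ModelWrapper(base_model, base_model.visual.output_dim, 2,
                     normalize=True, initial_weights=clf)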