SaraPieri committed
Commit 626ec32 · 1 Parent(s): 1f88702

First soup!
__pycache__/utils.cpython-36.pyc ADDED
Binary file (4.76 kB)

__pycache__/zeroshot.cpython-36.pyc ADDED
Binary file (974 Bytes)

environment.yml ADDED
@@ -0,0 +1,81 @@
+ name: model_soups
+ channels:
+   - pytorch
+   - defaults
+ dependencies:
+   - astroid=2.6.6=py36h06a4308_0
+   - blas=1.0=mkl
+   - ca-certificates=2022.4.26=h06a4308_0
+   - certifi=2021.5.30=py36h06a4308_0
+   - cudatoolkit=11.0.221=h6bb024c_0
+   - dataclasses=0.8=pyh4f3eec9_6
+   - freetype=2.11.0=h70c0345_0
+   - intel-openmp=2022.0.1=h06a4308_3633
+   - isort=5.9.3=pyhd3eb1b0_0
+   - jpeg=9b=h024ee3a_2
+   - lazy-object-proxy=1.6.0=py36h27cfd23_0
+   - lcms2=2.12=h3be6417_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - libffi=3.3=he6710b0_2
+   - libgcc-ng=11.2.0=h1234567_1
+   - libpng=1.6.37=hbc83047_0
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libtiff=4.2.0=h85742a9_0
+   - libuv=1.40.0=h7b6447c_0
+   - libwebp-base=1.2.2=h7f8727e_0
+   - lz4-c=1.9.3=h295c915_1
+   - mccabe=0.6.1=py36_1
+   - mkl=2020.2=256
+   - mkl-service=2.3.0=py36he8ac12f_0
+   - mkl_fft=1.3.0=py36h54f3939_0
+   - mkl_random=1.1.1=py36h0573a6f_0
+   - ncurses=6.3=h7f8727e_2
+   - ninja=1.10.2=h06a4308_5
+   - ninja-base=1.10.2=hd09550d_5
+   - numpy=1.19.2=py36h54aff64_0
+   - numpy-base=1.19.2=py36hfa32c7d_0
+   - olefile=0.46=pyhd3eb1b0_0
+   - openjpeg=2.4.0=h3ad879b_0
+   - openssl=1.1.1o=h7f8727e_0
+   - pillow=8.3.1=py36h2c7a002_0
+   - pip=21.2.2=py36h06a4308_0
+   - pylint=2.9.6=py36h06a4308_1
+   - python=3.6.13=h12debd9_1
+   - pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0
+   - readline=8.1.2=h7f8727e_1
+   - setuptools=58.0.4=py36h06a4308_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - sqlite=3.38.3=hc218d9a_0
+   - tk=8.6.11=h1ccaba5_1
+   - toml=0.10.2=pyhd3eb1b0_0
+   - torchvision=0.8.2=py36_cu110
+   - typed-ast=1.4.3=py36h7f8727e_1
+   - typing-extensions=4.1.1=hd3eb1b0_0
+   - typing_extensions=4.1.1=pyh06a4308_0
+   - wheel=0.37.1=pyhd3eb1b0_0
+   - wrapt=1.12.1=py36h7b6447c_1
+   - xz=5.2.5=h7f8727e_1
+   - zlib=1.2.12=h7f8727e_2
+   - zstd=1.4.9=haebb681_0
+   - pip:
+     - charset-normalizer==2.0.12
+     - clip==0.1.0
+     - cycler==0.11.0
+     - ftfy==6.0.3
+     - idna==3.3
+     - importlib-resources==5.4.0
+     - kiwisolver==1.3.1
+     - matplotlib==3.3.4
+     - pandas==1.1.5
+     - pyparsing==3.0.9
+     - python-dateutil==2.8.2
+     - pytz==2022.1
+     - regex==2022.4.24
+     - requests==2.27.1
+     - tqdm==4.64.0
+     - urllib3==1.26.9
+     - wcwidth==0.2.5
+     - wget==3.2
+     - zipp==3.6.0
+ prefix: /home/mitchnw/anaconda3/envs/model_soups
+
figure.png ADDED
finetune.py ADDED
@@ -0,0 +1,196 @@
+ import argparse
+ import os
+ import torch
+ import clip
+ import os
+ from tqdm import tqdm
+ import time
+ from utils import ModelWrapper, maybe_dictionarize_batch, cosine_lr
+ from zeroshot import zeroshot_classifier
+ import torch
+ from torchvision import transforms, datasets
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data-location",
+         type=str,
+         default=os.path.expanduser('~/data'),
+         help="The root directory for the datasets.",
+     )
+     parser.add_argument(
+         "--model-location",
+         type=str,
+         default=os.path.expanduser('~/ssd/checkpoints/soups'),
+         help="Where to download the models.",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=256,
+     )
+     parser.add_argument(
+         "--workers",
+         type=int,
+         default=8,
+     )
+     parser.add_argument(
+         "--epochs",
+         type=int,
+         default=8,
+     )
+     parser.add_argument(
+         "--warmup-length",
+         type=int,
+         default=500,
+     )
+     parser.add_argument(
+         "--lr",
+         type=float,
+         default=2e-5,
+     )
+     parser.add_argument(
+         "--wd",
+         type=float,
+         default=0.1,
+     )
+     parser.add_argument(
+         "--model",
+         default='ViT-B/32',
+         help='Model to use -- you can try another like ViT-L/14'
+     )
+     parser.add_argument(
+         "--name",
+         default='finetune_cp',
+         help='Filename for the checkpoints.'
+     )
+     parser.add_argument(
+         "--timm-aug", action="store_true", default=False,
+     )
+     parser.add_argument(
+         "--checkpoint_path",
+         default=None,
+         help='Checkpoint path to load the model'
+     )
+
+     return parser.parse_args()
+
+ if __name__ == '__main__':
+     args = parse_arguments()
+     DEVICE = 'cuda'
+
+
+     template = [lambda x : f"a photo generated by {x}."]
+
+
+     base_model, preprocess = clip.load(args.model, 'cuda', jit=False)
+
+
+     train_transforms = transforms.Compose([transforms.RandomRotation(30),
+                                            transforms.RandomResizedCrop(224),
+                                            transforms.RandomHorizontalFlip(),
+                                            transforms.ToTensor()])
+
+     test_transforms = transforms.Compose([transforms.RandomRotation(30),
+                                           transforms.RandomResizedCrop(224),
+                                           transforms.ToTensor()])
+
+
+     train_data = datasets.ImageFolder(args.data_location + '/train', transform=train_transforms)
+     test_data = datasets.ImageFolder(args.data_location + '/test', transform=test_transforms)
+
+     train_dset = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle = True)
+     test_dset = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, num_workers=args.workers)
+
+     clf = zeroshot_classifier(base_model, ['humans', 'AI'], template, DEVICE)
+     NUM_CLASSES = 2
+     feature_dim = base_model.visual.output_dim
+
+     model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True, initial_weights=clf, checkpoint_path = args.checkpoint_path)
+     for p in model.parameters():
+         p.data = p.data.float()
+
+     model = model.cuda()
+     devices = [x for x in range(torch.cuda.device_count())]
+     model = torch.nn.DataParallel(model, device_ids=devices)
+
+     model_parameters = [p for p in model.parameters() if p.requires_grad]
+     optimizer = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)
+
+     num_batches = len(train_dset)
+     scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)
+
+     loss_fn = torch.nn.CrossEntropyLoss()
+
+     model_path = os.path.join(args.model_location, f'{args.name}.pt')
+     print('Saving model to', model_path)
+     torch.save(model.module.state_dict(), model_path)
+
+     for epoch in range(args.epochs):
+         # Train
+         model.train()
+         end = time.time()
+         for i, batch in enumerate(train_dset):
+             step = i + epoch * num_batches
+             scheduler(step)
+             optimizer.zero_grad()
+             batch = maybe_dictionarize_batch(batch)
+             inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
+             data_time = time.time() - end
+
+             logits = model(inputs)
+             loss = loss_fn(logits, labels)
+
+             loss.backward()
+
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+             optimizer.step()
+
+             batch_time = time.time() - end
+             end = time.time()
+
+             if i % 20 == 0:
+                 percent_complete = 100.0 * i / len(train_dset)
+                 print(
+                     f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(train_dset)}]\t"
+                     f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
+                 )
+
+         ## Evaluate
+         test_loader = test_dset
+         model.eval()
+
+         last_accuracy = 0.0
+
+         with torch.no_grad():
+             print('*'*80)
+             print('Starting eval')
+             correct, count = 0.0, 0.0
+             pbar = tqdm(test_loader)
+             for batch in pbar:
+                 batch = maybe_dictionarize_batch(batch)
+                 inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
+
+                 logits = model(inputs)
+
+                 loss = loss_fn(logits, labels)
+
+                 pred = logits.argmax(dim=1, keepdim=True)
+                 correct += pred.eq(labels.view_as(pred)).sum().item()
+                 count += len(logits)
+                 pbar.set_description(
+                     f"Val loss: {loss.item():.4f} Acc: {100*correct/count:.2f}")
+             top1 = correct / count
+         print(f'Val acc at epoch {epoch}: {100*top1:.2f}')
+
+         curr_acc = 100*top1
+         if curr_acc > last_accuracy:
+             print('Current acc: {}, Last acc: {}'.format(curr_acc, last_accuracy))
+             last_accuracy = curr_acc
+             model_path = os.path.join(args.model_location, f'{args.name}.pt')
+             print('Saving model to', model_path)
+             torch.save(model.module.state_dict(), model_path)
+         else:
+             print('Not saving the model')
+
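Usage sketch (not part of the diff): how a checkpoint written by finetune.py could be reloaded for inference via utils.get_model_from_sd. The model name mirrors the --model default above; the checkpoint path is a placeholder and should point at your own saved file.

# Minimal sketch: reload a finetune.py checkpoint for inference.
# Assumptions: default --model 'ViT-B/32' and --name 'finetune_cp' under the
# default --model-location; adjust the path to your setup.
import os
import torch
import clip
from utils import get_model_from_sd

base_model, preprocess = clip.load('ViT-B/32', 'cuda', jit=False)
ckpt = os.path.expanduser('~/ssd/checkpoints/soups/finetune_cp.pt')  # placeholder path
state_dict = torch.load(ckpt, map_location='cuda')
model = get_model_from_sd(state_dict, base_model)  # wraps in ModelWrapper + DataParallel
model.eval()  # feed images through `preprocess` before calling model(images)
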
helper.bash ADDED
@@ -0,0 +1 @@
+ python finetune.py --data-location /l/users/u21010238/data/AiorNot --model-location /home/sara.pieri/Documents/model-soups/models --batch-size 56 --name finetune_cp_AiorNot_model_0 --checkpoint_path /home/sara.pieri/Documents/model-soups/models/model_0.pt
utils.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ import math
+ import time
+ import numpy as np
+
+ class ModelWrapper(torch.nn.Module):
+     def __init__(self, model, feature_dim, num_classes, normalize=False, initial_weights=None, checkpoint_path = None):
+         super(ModelWrapper, self).__init__()
+         self.model = model
+         self.classification_head = torch.nn.Linear(feature_dim, num_classes)
+         self.normalize = normalize
+
+         if initial_weights is None:
+             initial_weights = torch.zeros_like(self.classification_head.weight)
+             torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
+
+         self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
+         self.classification_head.bias = torch.nn.Parameter(torch.zeros_like(self.classification_head.bias))
+
+         # Note: modified. Get rid of the language part.
+         if hasattr(self.model, 'transformer'):
+             delattr(self.model, 'transformer')
+
+         if checkpoint_path:
+             print("Loading checkpoint", checkpoint_path)
+             checkpoint = torch.load(checkpoint_path)
+             checkpoint.pop('classification_head.weight')
+             checkpoint.pop('classification_head.bias')
+             model.load_state_dict(checkpoint, strict=False)
+
+     def forward(self, images, return_features=False):
+         features = self.model.encode_image(images)
+         if self.normalize:
+             features = features / features.norm(dim=-1, keepdim=True)
+         logits = self.classification_head(features)
+         if return_features:
+             return logits, features
+         return logits
+
+ def get_model_from_sd(state_dict, base_model):
+     feature_dim = state_dict['classification_head.weight'].shape[1]
+     num_classes = state_dict['classification_head.weight'].shape[0]
+     model = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
+     for p in model.parameters():
+         p.data = p.data.float()
+     model.load_state_dict(state_dict)
+     model = model.cuda()
+     devices = [x for x in range(torch.cuda.device_count())]
+     return torch.nn.DataParallel(model, device_ids=devices)
+
+ def maybe_dictionarize_batch(batch):
+     if isinstance(batch, dict):
+         return batch
+     if len(batch) == 2:
+         return {'images': batch[0], 'labels': batch[1]}
+     elif len(batch) == 3:
+         return {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
+     else:
+         raise ValueError(f'Unexpected number of elements: {len(batch)}')
+
+
+ def test_model_on_dataset(model, dataset):
+
+     model.eval()
+     device = 'cuda'
+     with torch.no_grad():
+         top1, correct, n = 0., 0., 0.
+         end = time.time()
+         loader = dataset.test_loader
+         if type(dataset).__name__ == 'ImageNet2p':
+             loader = dataset.train_loader
+             # assert to make sure the imagenet held-out minival logic is consistent across machines.
+             # tested on a few machines but if this fails for you please submit an issue and we will resolve.
+             assert dataset.train_dataset.__getitem__(dataset.sampler.indices[1000])['image_paths'].endswith('n01675722_4108.JPEG')
+
+         for i, batch in enumerate(loader):
+             batch = maybe_dictionarize_batch(batch)
+             inputs, labels = batch['images'].cuda(), batch['labels'].cuda()
+             data_time = time.time() - end
+             y = labels
+             if 'image_paths' in batch:
+                 image_paths = batch['image_paths']
+
+             logits = model(inputs)
+
+             projection_fn = getattr(dataset, 'project_logits', None)
+             if projection_fn is not None:
+                 logits = projection_fn(logits, device)
+
+             if hasattr(dataset, 'project_labels'):
+                 y = dataset.project_labels(y, device)
+             if isinstance(logits, list):
+                 logits = logits[0]
+
+
+             pred = logits.argmax(dim=1, keepdim=True).to(device)
+             if hasattr(dataset, 'accuracy'):
+                 acc1, num_total = dataset.accuracy(logits, y, image_paths, None)
+                 correct += acc1
+                 n += num_total
+             else:
+                 correct += pred.eq(y.view_as(pred)).sum().item()
+                 n += y.size(0)
+
+             batch_time = time.time() - end
+             end = time.time()
+             if i % 20 == 0:
+                 percent_complete = 100.0 * i / len(loader)
+                 print(
+                     f"[{percent_complete:.0f}% {i}/{len(loader)}]\t"
+                     f"Acc: {100 * (correct/n):.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
+                 )
+
+         top1 = correct / n
+     return top1
+
+
+ def assign_learning_rate(param_group, new_lr):
+     param_group["lr"] = new_lr
+
+ def _warmup_lr(base_lr, warmup_length, step):
+     return base_lr * (step + 1) / warmup_length
+
+ def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+     if not isinstance(base_lrs, list):
+         base_lrs = [base_lrs for _ in optimizer.param_groups]
+     assert len(base_lrs) == len(optimizer.param_groups)
+     def _lr_adjuster(step):
+         for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+             if step < warmup_length:
+                 lr = _warmup_lr(base_lr, warmup_length, step)
+             else:
+                 e = step - warmup_length
+                 es = steps - warmup_length
+                 lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+             assign_learning_rate(param_group, lr)
+     return _lr_adjuster
+
+
+
zeroshot.py ADDED
@@ -0,0 +1,23 @@
+ import argparse
+ import os
+ import torch
+ import clip
+ import os
+ from tqdm import tqdm
+
+ def zeroshot_classifier(model, classnames, templates, device):
+     print('Building zero-shot classifier.')
+     with torch.no_grad():
+         zeroshot_weights = []
+         for classname in tqdm(classnames):
+             texts = [template(classname) for template in templates] #format with class
+             texts = clip.tokenize(texts).to(device) #tokenize
+             class_embeddings = model.encode_text(texts)
+             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+             class_embedding = class_embeddings.mean(dim=0)
+             class_embedding /= class_embedding.norm()
+             zeroshot_weights.append(class_embedding)
+         zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
+     return 100*zeroshot_weights.t()
+
+
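Usage sketch (illustrative only): the weights returned by zeroshot_classifier can score images directly, without the linear head from utils.ModelWrapper. The class names and template mirror those hard-coded in finetune.py; the image path is a placeholder.

# Minimal sketch: zero-shot scoring of a single image with the returned weights.
import torch
import clip
from PIL import Image
from zeroshot import zeroshot_classifier

device = 'cuda'
model, preprocess = clip.load('ViT-B/32', device, jit=False)
template = [lambda c: f"a photo generated by {c}."]
clf = zeroshot_classifier(model, ['humans', 'AI'], template, device)  # shape [2, feature_dim]

image = preprocess(Image.open('example.jpg')).unsqueeze(0).to(device)  # placeholder image path
with torch.no_grad():
    feats = model.encode_image(image)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    logits = feats @ clf.t()  # cosine similarities scaled by 100
print(['humans', 'AI'][logits.argmax(dim=-1).item()])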