File size: 13,044 Bytes
708dec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358

import time
import pickle
import logging
import os
import numpy as np
import torch
import torch.nn as nn


from collections import OrderedDict
from yaml import safe_dump
from yacs.config import load_cfg, CfgNode#, _to_dict
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus
from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from maskrcnn_benchmark.utils.flops import profile


def choice(x):
    """Return one uniformly random element of ``x``.

    ``x`` may be any finite iterable; anything that is not already a tuple is
    materialised into one first so it can be indexed.  A single call to
    ``np.random.randint`` selects the position.
    """
    pool = x if isinstance(x, tuple) else tuple(x)
    return pool[np.random.randint(len(pool))]


def gather_candidates(all_candidates):
    """Collect the candidate lists of every process and de-duplicate them.

    ``all_gather`` returns one list per process; these are concatenated in
    the order returned and collapsed through a ``set``, so the order of the
    resulting list is whatever set iteration yields.
    """
    flattened = []
    for per_process in all_gather(all_candidates):
        flattened.extend(per_process)
    return list(set(flattened))


def gather_stats(all_candidates):
    """Merge the per-process statistics dictionaries into a single dict.

    The dictionaries are applied in the order ``all_gather`` returns them,
    so when a key appears in more than one of them the value from the last
    one wins.
    """
    merged = {}
    for per_process_stats in all_gather(all_candidates):
        for cand, value in per_process_stats.items():
            merged[cand] = value
    return merged


def compute_on_dataset(model, rngs, data_loader, device=None):
    """Run ``model`` in eval mode over ``data_loader`` and collect its outputs.

    Args:
        model: callable model; invoked as ``model(images, rngs=rngs)``.
        rngs: forwarded unchanged to the model as the ``rngs`` keyword.
        data_loader: iterable of ``(images, targets, image_ids)`` batches.
        device: device the images are moved to.  ``None`` (the default) means
            ``cfg.MODEL.DEVICE`` read at call time.  The previous default was
            evaluated once at import time, i.e. before any configuration file
            could have been merged into ``cfg``.

    Returns:
        dict mapping each image id to the corresponding model output, moved
        to the CPU.
    """
    if device is None:
        device = cfg.MODEL.DEVICE
    model.eval()
    results_dict = {}
    cpu_device = torch.device("cpu")
    for images, _targets, image_ids in data_loader:
        with torch.no_grad():
            output = model(images.to(device), rngs=rngs)
            output = [o.to(cpu_device) for o in output]
        results_dict.update(zip(image_ids, output))
    return results_dict


def bn_statistic(model, rngs, data_loader, device=None, max_iter=500):
    """Reset and re-estimate the model's running statistics for one path.

    Every buffer whose name contains ``running_mean`` is set to 0 and every
    buffer whose name contains ``running_var`` is set to 1.  The model is then
    switched to train mode and fed batches from ``data_loader`` under
    ``torch.no_grad()`` (no gradients are computed; the forward passes are
    run only for their effect on the buffers).

    Args:
        model: module exposing ``named_buffers()``/``train()``; invoked as
            ``model(images, targets, rngs)``.
        rngs: forwarded unchanged to the model.
        data_loader: iterable of ``(images, targets, _)`` batches.
        device: device the batch is moved to.  ``None`` (the default) means
            ``cfg.MODEL.DEVICE`` read at call time rather than frozen at
            import time.
        max_iter: stop after this many batches (at least one batch is always
            processed if the loader is non-empty).

    Returns:
        The same ``model`` object, left in train mode.
    """
    if device is None:
        device = cfg.MODEL.DEVICE

    for name, buf in model.named_buffers():
        if 'running_mean' in name:
            nn.init.constant_(buf, 0)
        if 'running_var' in name:
            nn.init.constant_(buf, 1)

    model.train()
    for iteration, (images, targets, _) in enumerate(data_loader, 1):
        images = images.to(device)
        targets = [target.to(device) for target in targets]
        with torch.no_grad():
            # The returned losses are irrelevant here; only the forward pass matters.
            model(images, targets, rngs)
        if iteration >= max_iter:
            break

    return model


def inference(
        model,
        rngs,
        data_loader,
        iou_types=("bbox",),
        box_only=False,
        device="cuda",
        expected_results=(),
        expected_results_sigma_tol=4,
        output_folder=None,
):
    """Predict on ``data_loader`` with the path ``rngs`` and evaluate the result.

    Predictions are computed locally, all processes are synchronised, and the
    per-process predictions are accumulated.  Only the main process goes on
    to call ``evaluate``; every other process returns ``None``.
    """
    # convert to a torch.device for efficiency
    torch_device = torch.device(device)
    local_predictions = compute_on_dataset(model, rngs, data_loader, torch_device)

    # make sure every process has finished predicting before accumulating
    synchronize()
    merged_predictions = _accumulate_predictions_from_multiple_gpus(local_predictions)

    if is_main_process():
        return evaluate(
            dataset=data_loader.dataset,
            predictions=merged_predictions,
            output_folder=output_folder,
            box_only=box_only,
            iou_types=iou_types,
            expected_results=expected_results,
            expected_results_sigma_tol=expected_results_sigma_tol,
        )
    return None


def fitness(cfg, model, rngs, val_loaders):
    """Evaluate the path ``rngs`` of ``model`` on the validation loaders.

    Args:
        cfg: configuration node; ``MODEL.MASK_ON``, ``MODEL.DEVICE``,
            ``TEST.EXPECTED_RESULTS`` and ``TEST.EXPECTED_RESULTS_SIGMA_TOL``
            are read.
        model: model handed to ``inference``.
        rngs: path selection forwarded to ``inference``.
        val_loaders: iterable of validation data loaders.

    Returns:
        Whatever ``inference`` returned for the LAST loader (``None`` on
        non-main processes), or ``None`` when ``val_loaders`` is empty.
        Previously an empty ``val_loaders`` raised ``UnboundLocalError``.
    """
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)

    results = None
    for data_loader_val in val_loaders:
        results = inference(
            model,
            rngs,
            data_loader_val,
            iou_types=iou_types,
            box_only=False,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        )
        synchronize()

    return results


class EvolutionTrainer(object):
    """Evolutionary search over the paths of a supernet.

    A candidate is a tuple with one choice index per searchable position
    (``len(cand) == len(self.states)``).  Each epoch every candidate is
    scored (running statistics re-estimated on ``train_loader``, then bbox AP
    measured through ``fitness``), the best candidates are kept in
    ``keep_top_k`` and a new population is built from mutation, crossover and
    random sampling.
    """

    def __init__(self, cfg, model, flops_limit=None, is_distributed=True):
        """Set up the search state.

        Args:
            cfg: configuration node; ``OUTPUT_DIR`` and the ``SEARCH.*``
                entries are read here, the node is kept as ``self.cfg``.
            model: the supernet.  When ``is_distributed`` is true the real
                model is expected under ``model.module``.
            flops_limit: optional upper bound (in MFLOPs, see ``legal``) for
                a candidate's backbone; ``None`` disables the check.
            is_distributed: whether ``model`` is a distributed wrapper.
        """
        self.log_dir = cfg.OUTPUT_DIR
        self.checkpoint_name = os.path.join(self.log_dir, 'evolution.pth')
        self.is_distributed = is_distributed

        # One entry per searchable position: the number of choices there
        # (used as the exclusive upper bound for np.random.randint).
        self.states = model.module.mix_nums if is_distributed else model.mix_nums
        # Pickle round-trip = independent deep copy of the supernet weights;
        # they are restored before every candidate evaluation.
        self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict()))
        self.flops_limit = flops_limit
        self.model = model

        self.candidates = []
        self.vis_dict = {}  # candidate -> measured score

        world_size = get_world_size()
        self.max_epochs = cfg.SEARCH.MAX_EPOCH
        self.select_num = cfg.SEARCH.SELECT_NUM
        # Per-process quotas.  These are true divisions and may be floats;
        # they are only ever used in comparisons and arithmetic.
        self.population_num = cfg.SEARCH.POPULATION_NUM / world_size
        self.mutation_num = cfg.SEARCH.MUTATION_NUM / world_size
        self.crossover_num = cfg.SEARCH.CROSSOVER_NUM / world_size
        # NOTE(review): the mutation probability is divided by the world size
        # as well, so the per-position mutation rate shrinks as processes are
        # added -- confirm this is intended and not a copy/paste slip.
        self.mutation_prob = cfg.SEARCH.MUTATION_PROB / world_size

        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.cfg = cfg

    @staticmethod
    def _pick(seq):
        """Return a uniformly random element of the non-empty sequence ``seq``."""
        return seq[np.random.randint(len(seq))]

    @staticmethod
    def _cfg_to_dict(node):
        """Recursively convert a (nested) dict-like config node into plain dicts.

        Non-dict values are returned unchanged.
        """
        if isinstance(node, dict):
            return {key: EvolutionTrainer._cfg_to_dict(val) for key, val in node.items()}
        return node

    def save_checkpoint(self):
        """Write the search state to ``self.checkpoint_name`` (main process only)."""
        if not is_main_process():
            return
        os.makedirs(self.log_dir, exist_ok=True)
        info = {
            'candidates': self.candidates,
            'vis_dict': self.vis_dict,
            'keep_top_k': self.keep_top_k,
            'epoch': self.epoch,
        }
        torch.save(info, self.checkpoint_name)
        print('Save checkpoint to', self.checkpoint_name)

    def load_checkpoint(self):
        """Restore the search state if a checkpoint exists.

        Returns:
            ``True`` when a checkpoint was loaded, ``False`` when none exists.
        """
        if not os.path.exists(self.checkpoint_name):
            return False
        # torch.load unpickles the file: only load checkpoints that were
        # written by save_checkpoint.
        info = torch.load(self.checkpoint_name)
        self.candidates = info['candidates']
        self.vis_dict = info['vis_dict']
        self.keep_top_k = info['keep_top_k']
        self.epoch = info['epoch']
        print('Load checkpoint from', self.checkpoint_name)
        return True

    def legal(self, cand):
        """Return whether ``cand`` is new and within the FLOPs budget.

        A candidate is rejected when it already has an entry in
        ``self.vis_dict`` or, if ``self.flops_limit`` is set, when the
        profiled backbone FLOPs (divided by 1e6) exceed that limit.
        """
        assert isinstance(cand, tuple) and len(cand) == len(self.states)
        if cand in self.vis_dict:
            return False

        if self.flops_limit is not None:
            net = self.model.module.backbone if self.is_distributed else self.model.backbone
            inp = (1, 3, 224, 224)
            flops, _params = profile(net, inp, extra_args={'paths': list(cand)})
            flops = flops / 1e6
            print('flops:', flops)
            if flops > self.flops_limit:
                return False

        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into ``self.keep_top_k[k]`` and keep the best ``k``.

        The merged list is de-duplicated (first occurrence kept) before it is
        sorted with ``key``/``reverse``, so a candidate can occupy at most
        one slot even if it is passed in more than once.
        """
        assert k in self.keep_top_k
        merged = list(OrderedDict.fromkeys(self.keep_top_k[k] + list(candidates)))
        merged.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = merged[:k]

    def eval_candidates(self, train_loader, val_loader):
        """Score every candidate in ``self.candidates``.

        For each candidate the supernet weights are restored, the running
        statistics are re-estimated on ``train_loader`` and ``fitness`` is
        run on ``val_loader``.  Only the main process records the bbox AP in
        ``self.vis_dict``.
        """
        for cand in self.candidates:
            t0 = time.time()

            # load back supernet state dict
            self.model.load_state_dict(self.supernet_state_dict)
            # Re-estimate running statistics for this path; pass the device
            # explicitly so the value from this trainer's config is used.
            model = bn_statistic(
                self.model, list(cand), train_loader, device=self.cfg.MODEL.DEVICE
            )
            # fitness
            evals = fitness(self.cfg, model, list(cand), val_loader)

            if is_main_process():
                acc = evals[0].results['bbox']['AP']
                self.vis_dict[cand] = acc
                print('candiate ', cand)
                print('time: {}s'.format(time.time() - t0))
                print('acc ', acc)

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Endless generator of candidates, produced ``batchsize`` at a time."""
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                yield cand

    def random_can(self, num):
        """Sample random legal candidates until at least ``num`` are collected.

        Returns an empty list when ``num <= 0``.  There is no attempt limit:
        the loop only ends once enough legal candidates have been found.
        """
        candidates = []
        cand_iter = self.stack_random_cand(
            lambda: tuple(np.random.randint(i) for i in self.states)
        )
        while len(candidates) < num:
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            candidates.append(cand)
        return candidates

    def get_mutation(self, k, mutation_num, m_prob):
        """Create new legal candidates by mutating members of ``keep_top_k[k]``.

        Each position of a randomly chosen parent is re-sampled with
        probability ``m_prob``.  At most ``mutation_num * 10`` candidates are
        tried, so fewer than ``mutation_num`` (possibly zero) results may be
        returned.  An empty parent pool yields an empty list.
        """
        assert k in self.keep_top_k
        parents = self.keep_top_k[k]
        if not parents:
            return []

        res = []
        attempts_left = mutation_num * 10

        def random_func():
            cand = list(self._pick(parents))
            for i, num_choices in enumerate(self.states):
                if np.random.random_sample() < m_prob:
                    cand[i] = np.random.randint(num_choices)
            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and attempts_left > 0:
            # Charge the budget per attempt, not per success; otherwise the
            # loop never terminates once no legal candidate is left.
            attempts_left -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)

        return res

    def get_crossover(self, k, crossover_num):
        """Create new legal candidates by crossing members of ``keep_top_k[k]``.

        Two parents are drawn at random (they may be the same one) and each
        position of the child is taken from either parent with equal
        probability.  At most ``10 * crossover_num`` candidates are tried, so
        fewer than ``crossover_num`` results may be returned.  An empty
        parent pool yields an empty list.
        """
        assert k in self.keep_top_k
        parents = self.keep_top_k[k]
        if not parents:
            return []

        res = []
        attempts_left = 10 * crossover_num

        def random_func():
            p1 = self._pick(parents)
            p2 = self._pick(parents)
            return tuple(self._pick((i, j)) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and attempts_left > 0:
            # Charge the budget per attempt so exhaustion ends the loop.
            attempts_left -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)

        return res

    def train(self, train_loader, val_loader):
        """Run the evolutionary search until ``self.max_epochs`` is reached.

        Resumes from ``self.checkpoint_name`` when it exists; otherwise the
        initial population is sampled at random.  A checkpoint is written at
        the end of every epoch.
        """
        logger = logging.getLogger("maskrcnn_benchmark.evolution")

        if not self.load_checkpoint():
            self.candidates = gather_candidates(self.random_can(self.population_num))

        while self.epoch < self.max_epochs:
            self.eval_candidates(train_loader, val_loader)
            # Scores are recorded on the main process only; share them.
            self.vis_dict = gather_stats(self.vis_dict)

            # Ascending sort on (1 - score) puts the highest score first.
            self.update_top_k(self.candidates, k=self.select_num, key=lambda x: 1 - self.vis_dict[x])
            self.update_top_k(self.candidates, k=50, key=lambda x: 1 - self.vis_dict[x])

            if is_main_process():
                logger.info('Epoch {} : top {} result'.format(self.epoch+1, len(self.keep_top_k[self.select_num])))
                for i, cand in enumerate(self.keep_top_k[self.select_num]):
                    logger.info('     No.{} {} perf = {}'.format(i+1, cand, self.vis_dict[cand]))

            mutation = gather_candidates(self.get_mutation(self.select_num, self.mutation_num, self.mutation_prob))
            crossover = gather_candidates(self.get_crossover(self.select_num, self.crossover_num))
            # NOTE(review): population_num is a per-process quota while
            # len(mutation)/len(crossover) are counts gathered from all
            # processes; with several processes this difference can be <= 0,
            # in which case no random candidates are added -- confirm.
            rand = gather_candidates(self.random_can(self.population_num - len(mutation) - len(crossover)))

            # The three lists are de-duplicated independently; drop overlaps
            # between them so no candidate is evaluated twice.
            self.candidates = list(OrderedDict.fromkeys(mutation + crossover + rand))

            self.epoch += 1
            self.save_checkpoint()

    def save_candidates(self, cand, template):
        """Export the ``cand``-th best architecture as a config and weights.

        Args:
            cand: 1-based rank into ``self.keep_top_k[self.select_num]``.
            template: path of the supernet YAML config to specialise.

        Writes ``<template basename>`` with ``.yaml`` replaced by
        ``_cand<cand>.yaml`` and ``init_cand<cand>.pth`` into
        ``self.cfg.OUTPUT_DIR``.
        """
        paths = self.keep_top_k[self.select_num][cand-1]

        with open(template, "r") as f:
            super_cfg = load_cfg(f)

        layer_search = super_cfg.MODEL.BACKBONE.LAYER_SEARCH
        search_spaces = {mix_ops: layer_search[mix_ops] for mix_ops in layer_search}
        search_layers = super_cfg.MODEL.BACKBONE.LAYER_SETUP

        layer_setup = []
        for i, layer in enumerate(search_layers):
            name, setup = get_layer_name(layer, search_spaces)
            if not isinstance(name, list):
                name = [name]
            name = name[paths[i]]

            layer_setup.append("('{}', {})".format(name, str(setup)[1:-1]))
        super_cfg.MODEL.BACKBONE.LAYER_SETUP = layer_setup

        # Plain dicts are required for safe_dump; the yacs helper formerly
        # used here is not imported by this file.
        cand_cfg = self._cfg_to_dict(super_cfg)
        del cand_cfg['MODEL']['BACKBONE']['LAYER_SEARCH']
        # Rewrite the extension on the file name only, never on OUTPUT_DIR.
        cfg_name = os.path.basename(template).replace('.yaml', '_cand{}.yaml'.format(cand))
        with open(os.path.join(self.cfg.OUTPUT_DIR, cfg_name), 'w') as f:
            f.write(safe_dump(cand_cfg))

        super_weight = self.supernet_state_dict
        cand_weight = OrderedDict()
        cand_keys = ['layers.{}.ops.{}'.format(i, c) for i, c in enumerate(paths)]

        for key, val in super_weight.items():
            if 'ops' in key:
                # Keep only the weights of the selected op of each layer and
                # strip the '.ops.<index>' part from their names.
                for ck in cand_keys:
                    if ck in key:
                        cand_weight[key.replace(ck, ck.split('.ops.')[0])] = val
            else:
                cand_weight[key] = val

        torch.save({'model': cand_weight}, os.path.join(self.cfg.OUTPUT_DIR, 'init_cand{}.pth'.format(cand)))