File size: 2,431 Bytes
801501a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import annotations
import os
import pickle
import salad.spaghetti.constants as const
from salad.spaghetti.custom_types import *


class Options:
    """Configuration / hyper-parameter container for a model run.

    Defaults are assigned in ``__init__``; any keyword argument whose name
    matches an existing attribute overrides that default, and unknown keywords
    are silently ignored (see ``fill_args``). An instance can be pickled to
    ``<CHECKPOINTS_ROOT><model_name>_<tag>/options.pkl`` with ``save`` and
    restored with ``load``.
    """

    def __init__(self, **kwargs):
        # CUDA comes from the star import of salad.spaghetti.custom_types;
        # presumably it builds a torch device for GPU index 0 -- confirm.
        self.device = CUDA(0)
        # Run identity: combined with model_name it names the checkpoint folder
        # (see `info` / `cp_folder`).
        self.tag = 'airplanes'
        self.dataset_name = 'shapenet_airplanes_wm_sphere_sym_train'
        self.epochs = 2000
        self.model_name = 'spaghetti'
        # Network widths.
        self.dim_z = 256
        # NOTE(review): written as the literal 256 - 3 rather than
        # `self.dim_z - 3`, so overriding dim_z via kwargs does not update
        # pos_dim -- confirm whether that is intended.
        self.pos_dim = 256 - 3
        self.dim_h = 512
        self.dim_zh = 512
        # Part / mixture settings (names suggest a Gaussian-mixture
        # decomposition -- see the model code for exact semantics).
        self.num_gaussians = 16
        self.min_split = 4
        self.max_split = 12
        self.gmm_weight = 1
        self.decomposition_network = 'transformer'
        self.decomposition_num_layers = 4
        self.num_layers = 4
        self.num_heads = 4
        self.num_layers_head = 6
        self.num_heads_head = 8
        self.head_occ_size = 5
        self.head_occ_type = 'skip'
        # Data / batching.
        self.batch_size = 18
        self.num_samples = 2000
        self.dataset_size = -1  # presumably -1 means "use the whole dataset" -- confirm
        # Three flags each; presumably one per axis (x, y, z) -- confirm.
        self.symmetric = (True, False, False)
        self.data_symmetric = (True, False, False)
        # Optimisation schedule.
        self.lr_decay = .9
        self.lr_decay_every = 500
        self.warm_up = 2000
        self.reg_weight = 1e-4
        self.disentanglement = True
        self.use_encoder = True
        self.disentanglement_weight = 1
        # Augmentation magnitudes.
        self.augmentation_rotation = 0.3
        self.augmentation_scale = .2
        self.augmentation_translation = .3
        # Apply caller overrides last so they win over the defaults above.
        self.fill_args(kwargs)

    def load(self) -> "Options":
        """Return the options previously saved for this run, or ``self``.

        If ``save_path`` exists, the pickled ``Options`` instance is loaded and
        returned. Any attribute present on ``self`` but missing from the pickled
        object (e.g. an option added to ``__init__`` after the checkpoint was
        written) is copied over from ``self``. The loaded object's ``device`` is
        always replaced with this instance's device. If nothing has been saved,
        ``self`` is returned unchanged.

        Security: ``pickle.load`` can execute arbitrary code; only load
        checkpoint folders from a trusted source.
        """
        device = self.device
        if not os.path.isfile(self.save_path):
            return self
        print(f'loading options from {self.save_path}')
        with open(self.save_path, 'rb') as f:
            options = pickle.load(f)
        # Back-fill options that did not exist when the checkpoint was saved,
        # so older checkpoints keep working after new options are introduced.
        for name, value in vars(self).items():
            if not hasattr(options, name):
                setattr(options, name, value)
        options.device = device
        return options

    def save(self) -> None:
        """Pickle this instance to ``save_path``.

        Does nothing when ``cp_folder`` does not already exist; the folder is
        not created here.
        """
        if not os.path.isdir(self.cp_folder):
            return
        with open(self.save_path, 'wb') as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

    @property
    def info(self) -> str:
        """Run identifier of the form ``'<model_name>_<tag>'``."""
        return f'{self.model_name}_{self.tag}'

    @property
    def cp_folder(self) -> str:
        """Checkpoint folder for this run.

        ``const.CHECKPOINTS_ROOT`` and ``info`` are concatenated with no
        separator, so the root presumably ends with a path separator -- confirm
        in ``salad.spaghetti.constants``.
        """
        return f'{const.CHECKPOINTS_ROOT}{self.info}'

    @property
    def save_path(self) -> str:
        """Path of the pickled options file inside ``cp_folder``."""
        return f'{self.cp_folder}/options.pkl'

    def fill_args(self, args) -> None:
        """Override attributes from the mapping ``args``.

        Only keys naming an attribute that already exists (``hasattr``) are
        applied; other keys are ignored.

        NOTE(review): ``hasattr`` is also true for methods and properties, so a
        key such as ``'save'`` would shadow the method on this instance and
        ``'info'`` would raise on the read-only property -- callers should pass
        option names only.
        """
        for name, value in args.items():
            if hasattr(self, name):
                setattr(self, name, value)