File size: 1,902 Bytes
803ef9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def get_data_mean_and_stdev(dataset):
    if dataset == 'CIFAR10' or dataset == 'CIFAR100':
        mean = [0.5, 0.5, 0.5]
        std  = [0.5, 0.5, 0.5]
    elif dataset == 'STL-10':
        mean = [0.491, 0.482, 0.447]
        std  = [0.247, 0.244, 0.262]
    elif dataset == 'ImageNet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    elif dataset == 'aircraft':
        mean = [0.486, 0.507, 0.525]
        std  = [0.266, 0.260, 0.276]
    elif dataset == 'cu_birds':
        mean = [0.483, 0.491, 0.424] 
        std  = [0.228, 0.224, 0.259]
    elif dataset == 'dtd':
        mean = [0.533, 0.474, 0.426]
        std  = [0.261, 0.250, 0.259]
    elif dataset == 'fashionmnist':
        mean = [0.348, 0.348, 0.348] 
        std  = [0.347, 0.347, 0.347]
    elif dataset == 'mnist':
        mean = [0.170, 0.170, 0.170]
        std  = [0.320, 0.320, 0.320]
    elif dataset == 'traffic_sign':
        mean = [0.335, 0.291, 0.295]
        std  = [0.267, 0.249, 0.251]
    elif dataset == 'vgg_flower':
        mean = [0.518, 0.410, 0.329]
        std  = [0.296, 0.249, 0.285]
    else:
        raise Exception('Dataset %s not supported.'%dataset)
    return mean, std

def get_data_nclass(dataset):
    if dataset == 'cifar10':
        nclass = 10
    elif dataset == 'cifar100cifar10':
        nclass = 100
    elif dataset == 'stl-10':
        nclass = 10
    elif dataset == 'ImageNet':
        nclass = 1000
    elif dataset == 'aircraft':
        nclass = 102
    elif dataset == 'cu_birds':
        nclass = 200
    elif dataset == 'dtd':
        nclass = 47
    elif dataset == 'fashionmnist':
        nclass = 10
    elif dataset == 'mnist':
        nclass = 10
    elif dataset == 'traffic_sign':
        nclass = 43
    elif dataset == 'vgg_flower':
        nclass = 102
    else:
        raise Exception('Dataset %s not supported.'%dataset)
    return nclass