File size: 5,653 Bytes
f50f696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import time

import torch
from torch import nn
import gpytorch

from .utils import get_batch_to_dataloader
from utils import default_device
from .utils import order_by_y, normalize_data, normalize_by_used_features_f, Binarize


class ExactGPModel(gpytorch.models.ExactGP):
    """Simplest exact-inference GP: constant mean and a scaled RBF kernel.

    Args:
        train_x: training inputs, forwarded unchanged to ``gpytorch.models.ExactGP``.
        train_y: training targets, forwarded unchanged to ``gpytorch.models.ExactGP``.
        likelihood: likelihood object, forwarded unchanged to ``gpytorch.models.ExactGP``.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf = gpytorch.kernels.RBFKernel()
        # ScaleKernel adds the learnable outputscale on top of the RBF lengthscale.
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the ``MultivariateNormal`` built from the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` with fixed hyperparameters.

    Args:
        x: training inputs for the GP.
        y: training targets for the GP.
        hyperparameters: mapping with keys ``"noise"``, ``"outputscale"`` and
            ``"lengthscale"``; each value is broadcast over the corresponding
            model parameter.

    Returns:
        ``(model, likelihood)`` where ``likelihood`` is the ``GaussianLikelihood``
        (noise constrained to be greater than 1e-9) the model was built with.
    """
    noise_floor = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
    model = ExactGPModel(x, y, likelihood)

    # Each attribute name coincides with its key in ``hyperparameters``.
    settable = (
        (model.likelihood, "noise"),
        (model.covar_module, "outputscale"),
        (model.covar_module.base_kernel, "lengthscale"),
    )
    for owner, attr in settable:
        current = getattr(owner, attr)
        # ones_like keeps the existing parameter's shape/dtype/device.
        setattr(owner, attr, torch.ones_like(current) * hyperparameters[attr])
    return model, likelihood


@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None, equidistant_x=False, fix_x=None):
    """Sample a batch of noisy functions from a fixed-hyperparameter GP prior.

    Args:
        batch_size: number of independent samples (B).
        seq_len: number of input points per sample (T).
        num_features: input dimensionality (H).
        device: inputs are created on / moved to this device, and so is the model.
        hyperparameters: one of
            - a dict with keys ``"noise"``, ``"outputscale"`` and ``"lengthscale"``,
              optionally with ``"fast_computations"``: a 3-tuple of flags unpacked
              into ``gpytorch.settings.fast_computations``;
            - a ``(noise, outputscale, lengthscale)`` tuple or list;
            - ``None``, meaning 0.1 for all three.
        equidistant_x: if True, use ``seq_len`` evenly spaced points in [0, 1]
            for every batch element (requires ``num_features == 1``).
        fix_x: optional tensor of shape ``(seq_len, num_features)`` used as the
            inputs of every batch element. Cannot be combined with ``equidistant_x``.

    Returns:
        ``(x, y, y)``: ``x`` has shape ``(T, B, H)`` and ``y``, the noisy prior
        sample, has shape ``(T, B)``. The same target tensor is returned twice.
    """
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0], "outputscale": hyperparameters[1], "lengthscale": hyperparameters[2]}
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        assert not (equidistant_x and (fix_x is not None))
        if equidistant_x:
            assert num_features == 1
            x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1).to(device)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        else:
            x = torch.rand(batch_size, seq_len, num_features, device=device)

        # An empty target tensor is passed: sampling below happens under
        # prior_mode, which (per its name) evaluates the prior, not a posterior.
        model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
        model.to(device)
        with gpytorch.settings.prior_mode(True):
            # Passing the latent distribution through the likelihood adds
            # observation noise to the draw; (B, T) -> (T, B).
            sample = likelihood(model(x)).sample().transpose(0, 1)
    return x.transpose(0, 1), sample, sample  # x.shape = (T,B,H)

# TODO: Reintegrate this code
# num_features_used = num_features_used_sampler()
# prior_outputscale = prior_outputscale_sampler()
# prior_lengthscale = prior_lengthscale_sampler()
#
# x, sample = normalize_data(x), normalize_data(sample)
#
# if is_binary_classification:
#     sample = (sample > torch.median(sample, dim=0)[0]).float()
#
# if normalize_by_used_features:
#     x = normalize_by_used_features_f(x, num_features_used, num_features)
#
# # # if is_binary_classification and order_y:
# # #     x, sample = order_by_y(x, sample)
# #
# # Append empty features if enabled
# x = torch.cat([x, torch.zeros((x.shape[0], x.shape[1], num_features - num_features_used), device=device)], -1)

# Module-level DataLoader built by the project helper get_batch_to_dataloader
# from this module's get_batch sampler.
DataLoader = get_batch_to_dataloader(get_batch)
# NOTE(review): presumably read by training code to size the model's output
# head (get_batch yields one target value per position); confirm against callers.
DataLoader.num_outputs = 1

def get_model_on_device(x, y, hyperparameters, device):
    """Build the exact-GP model for ``(x, y)`` via ``get_model`` and move it to ``device``.

    Returns:
        ``(model, likelihood)`` exactly as produced by ``get_model``; ``nn.Module.to``
        moves the model in place and returns the same object.
    """
    model, likelihood = get_model(x, y, hyperparameters)
    return model.to(device), likelihood


@torch.no_grad()
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters=None, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
    """Score one-step-ahead exact-GP predictions along a sequence.

    For each ``t`` in ``range(max(start_pos, 1), len(x), step_size)`` a GP with the
    given (fixed) hyperparameters is conditioned on positions ``[:t]`` and used to
    predict position ``t`` independently for every batch element.

    Args:
        x: inputs, indexed sequence-first as ``(T, B, H)`` (the layout returned by
            ``get_batch``).
        y: targets, indexed as ``(T, B)``.
        y_non_noisy: not used by this function.
        use_mse: if True, the per-element loss is the squared error of the
            predictive mean; otherwise it is the negative log-likelihood of ``y[t]``
            under the predictive distribution.
        hyperparameters: a dict, forwarded to ``get_model_on_device`` and optionally
            containing ``"fast_computations"`` (a 3-tuple of flags). A
            ``(noise, outputscale, lengthscale)`` tuple/list is accepted as in
            ``get_batch``. ``None`` means an empty dict.
        get_model_on_device: factory ``(x, y, hyperparameters, device) -> (model, likelihood)``.
        device: forwarded to ``get_model_on_device``.
        step_size: stride between evaluated positions.
        start_pos: first position to evaluate (clamped to at least 1).

    Returns:
        ``(all_losses, mean_losses, seconds)``. ``all_losses`` holds one row of
        per-batch-element losses per evaluated position; it has 0 rows if nothing
        was evaluated. ``mean_losses`` holds the batch-mean loss per position,
        prefixed with ``0.0`` when ``start_pos == 0``. Both tensors are on CPU.
        ``seconds`` is the wall-clock time spent.
    """
    start_time = time.time()
    if hyperparameters is None:
        hyperparameters = {}
    elif isinstance(hyperparameters, (tuple, list)):
        # Same positional convention get_batch accepts.
        hyperparameters = {"noise": hyperparameters[0], "outputscale": hyperparameters[1], "lengthscale": hyperparameters[2]}

    losses_after_t = [.0] if start_pos == 0 else []
    all_losses_after_t = []
    mse = nn.MSELoss(reduction='none')  # built once instead of per iteration

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))), gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
            model.eval()

            # x[t] is (B, H); unsqueeze gives each batch element a single test point.
            pred = likelihood(model(x[t].unsqueeze(1)))
            # reshape(-1) instead of squeeze(): squeeze() would also drop the batch
            # dimension when B == 1, making the shape checks below fail.
            means = pred.mean.reshape(-1)
            varis = pred.covariance_matrix.reshape(-1)

            assert len(means.shape) == len(varis.shape) == 1
            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                ls = mse(means, y[t])
            else:
                ls = -pred.log_prob(y[t].unsqueeze(1))

            losses_after_t.append(ls.mean())
            all_losses_after_t.append(ls.flatten())

    if all_losses_after_t:
        all_losses = torch.stack(all_losses_after_t)
    else:
        # No position evaluated (e.g. len(x) <= max(start_pos, 1)); torch.stack([]) raises.
        all_losses = x.new_empty((0, x.shape[1]))
    return all_losses.to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time

if __name__ == '__main__':
    # Hyperparameters as a dict: the form accepted both by get_batch() and by the
    # model factory that evaluate() forwards them to (which indexes them by key).
    hps = {"noise": .1, "outputscale": .1, "lengthscale": .1}
    print(
        evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))