|
import time |
|
|
|
import torch |
|
from torch import nn |
|
import gpytorch |
|
|
|
from .utils import get_batch_to_dataloader |
|
from utils import default_device |
|
from .utils import order_by_y, normalize_data, normalize_by_used_features_f, Binarize |
|
|
|
|
|
|
|
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model: constant mean + output-scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Learnable constant prior mean; RBF kernel wrapped in a ScaleKernel
        # so the output scale can be set/learned independently.
        self.mean_module = gpytorch.means.ConstantMean()
        base_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)

    def forward(self, x):
        """Return the GP distribution (MultivariateNormal) evaluated at ``x``."""
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)
|
|
|
|
|
def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` on ``(x, y)`` with fixed hyperparameters.

    ``hyperparameters`` must supply "noise", "outputscale" and "lengthscale";
    each value is broadcast onto the matching gpytorch parameter.
    Returns ``(model, likelihood)``.
    """
    noise_constraint = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_constraint)
    model = ExactGPModel(x, y, likelihood)

    kernel = model.covar_module
    model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
    kernel.outputscale = torch.ones_like(kernel.outputscale) * hyperparameters["outputscale"]
    kernel.base_kernel.lengthscale = (
        torch.ones_like(kernel.base_kernel.lengthscale) * hyperparameters["lengthscale"]
    )
    return model, likelihood
|
|
|
|
|
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None, equidistant_x=False, fix_x=None):
    """Sample a batch of (x, y) regression sequences from a GP prior.

    Args:
        batch_size: number of independent sequences to draw.
        seq_len: number of points per sequence.
        num_features: input dimensionality (must be 1 when ``equidistant_x``).
        device: device on which tensors are created.
        hyperparameters: dict with "noise", "outputscale", "lengthscale"
            (optionally "fast_computations"); a (noise, outputscale,
            lengthscale) tuple/list; or None for defaults of 0.1 each.
        equidistant_x: if True, inputs are a fixed linspace grid on [0, 1].
        fix_x: optional fixed inputs of shape (seq_len, num_features),
            shared across the batch. Mutually exclusive with equidistant_x.

    Returns:
        ``(x, y, target)`` with ``x`` of shape (seq_len, batch_size,
        num_features) and ``y is target`` the noisy GP sample of shape
        (seq_len, batch_size).
    """
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0], "outputscale": hyperparameters[1], "lengthscale": hyperparameters[2]}
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        assert not (equidistant_x and (fix_x is not None)), \
            "equidistant_x and fix_x are mutually exclusive"
        if equidistant_x:
            assert num_features == 1
            x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1).to(device)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        else:
            x = torch.rand(batch_size, seq_len, num_features, device=device)

        # Empty training targets: the model is only used to define the prior.
        model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
        model.to(device)

        # Sample from the prior (not a posterior) and add observation noise
        # via the likelihood.
        with gpytorch.settings.prior_mode(True):
            d = model(x)
            d = likelihood(d)
            sample = d.sample().transpose(0, 1)

        return x.transpose(0, 1), sample, sample
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Expose the GP sampler as a torch-style dataloader class.
DataLoader = get_batch_to_dataloader(get_batch)

# Each position has a single scalar regression target.
DataLoader.num_outputs = 1
|
|
|
def get_model_on_device(x, y, hyperparameters, device):
    """Build a GP model on ``(x, y)`` (see ``get_model``) and move it to ``device``."""
    gp_model, gp_likelihood = get_model(x, y, hyperparameters)
    gp_model.to(device)
    return gp_model, gp_likelihood
|
|
|
|
|
@torch.no_grad()
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters=None, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
    """Sequentially evaluate GP one-step-ahead predictions on a batch.

    For each position ``t``, fits an exact GP on the first ``t`` points of
    every sequence and scores its prediction for the point at ``t``.

    Args:
        x: inputs of shape (seq_len, batch_size, num_features).
        y: noisy targets of shape (seq_len, batch_size).
        y_non_noisy: unused here; kept for interface compatibility.
        use_mse: if True, score with MSE instead of negative log-likelihood.
        hyperparameters: GP hyperparameters as a dict (see ``get_model``),
            a (noise, outputscale, lengthscale) tuple/list, or None.
        get_model_on_device: factory that builds the fitted GP on ``device``.
        device: torch device used for fitting/prediction.
        step_size: stride over sequence positions.
        start_pos: first position to evaluate; 0 prepends a dummy 0.0 loss.

    Returns:
        ``(all_losses, mean_losses, elapsed_seconds)`` where ``all_losses``
        has shape (num_steps, batch_size) and ``mean_losses`` is 1-D, both
        on CPU.
    """
    # Normalize hyperparameters: accept the same tuple/list form as
    # get_batch (previously a tuple crashed on .get below), and avoid the
    # mutable-default-argument pitfall ({} was only ever read, so None -> {}
    # preserves behavior).
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0], "outputscale": hyperparameters[1], "lengthscale": hyperparameters[2]}
    elif hyperparameters is None:
        hyperparameters = {}

    start_time = time.time()
    losses_after_t = [.0] if start_pos == 0 else []
    all_losses_after_t = []

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))), gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
            model.eval()

            # Predictive distribution for the single next point of each sequence.
            f = model(x[t].unsqueeze(1))
            pred = likelihood(f)
            means = pred.mean.squeeze()
            varis = pred.covariance_matrix.squeeze()

            assert len(means.shape) == len(varis.shape) == 1
            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                criterion = nn.MSELoss(reduction='none')
                ls = criterion(means, y[t])
            else:
                ls = -pred.log_prob(y[t].unsqueeze(1))

            losses_after_t.append(ls.mean())
            all_losses_after_t.append(ls.flatten())
    return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
|
|
|
if __name__ == '__main__':
    # Smoke run: draw one large batch from the GP prior and score it.
    hps = (.1, .1, .1)
    for _ in range(1):
        batch = get_batch(1000, 10, hyperparameters=hps, num_features=10)
        print(evaluate(*batch, use_mse=False, hyperparameters=hps))
|
|