Samuel Mueller
add link
04eb228
raw
history blame
4.6 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import gpytorch
import torch
import sys
import gpytorch
# Exact-inference GP regression model: the simplest GPyTorch setup.
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and an output-scaled RBF covariance.

    Args:
        train_x: training inputs handed to ``gpytorch.models.ExactGP``.
        train_y: training targets handed to ``gpytorch.models.ExactGP``.
        likelihood: likelihood object (``get_model`` passes a
            ``GaussianLikelihood``).
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf = gpytorch.kernels.RBFKernel()
        # ScaleKernel adds a learnable output scale on top of the RBF kernel.
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the MultivariateNormal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
def get_model(x, y, hyperparameters):
    """Construct an ExactGPModel whose hyperparameters are set by hand (not fitted).

    Args:
        x: training inputs.
        y: training targets.
        hyperparameters: mapping providing ``"noise"``, ``"outputscale"`` and
            ``"lengthscale"``; each value fills the matching parameter tensor.

    Returns:
        Tuple ``(model, likelihood)``. ``likelihood`` is the Gaussian likelihood
        object that was passed into the model.
    """
    # Keep the observation noise strictly above 1e-9.
    noise_floor = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
    model = ExactGPModel(x, y, likelihood)

    def filled(current, value):
        # Tensor shaped like `current` with every entry equal to `value`.
        return torch.ones_like(current) * value

    scale_kernel = model.covar_module
    model.likelihood.noise = filled(model.likelihood.noise, hyperparameters["noise"])
    scale_kernel.outputscale = filled(scale_kernel.outputscale, hyperparameters["outputscale"])
    scale_kernel.base_kernel.lengthscale = filled(
        scale_kernel.base_kernel.lengthscale, hyperparameters["lengthscale"]
    )
    return model, likelihood
# Messages returned by `infer` when the submitted table is rejected.
excuse = "Please only specify numbers, x values should be in [0,1] and y values in [-1,1]."
excuse_max_examples = "This model is trained to work with up to 4 input points."
# Fixed hyperparameters for the reference GP posterior. get_model reads only
# 'noise', 'outputscale' and 'lengthscale'; 'fast_computations' is not read by
# any code visible here.
hyperparameters = dict(
    noise=1e-4,
    outputscale=1.,
    lengthscale=.1,
    fast_computations=(False, False, False),
)
# NOTE(review): `conf` is not referenced by any visible code; the plotted PFN band
# uses center_prob=.682 instead.
conf = 0.5
def mean_and_bounds_for_gp(x, y, test_xs):
    """Exact GP predictive mean and a +/- one-standard-deviation band at ``test_xs``.

    The GP uses the fixed module-level ``hyperparameters``; nothing is fitted.

    Args:
        x: training inputs.
        y: training targets.
        test_xs: points at which to evaluate the predictive distribution
            (``infer`` passes an (m, 1) tensor).

    Returns:
        ``(mean, mean - std, mean + std)``, with one entry per test point.
    """
    gp_model, likelihood = get_model(x, y, hyperparameters)
    gp_model.eval()
    predictive = likelihood(gp_model(test_xs))
    # `.variance` is the diagonal of the predictive covariance. It avoids
    # building the dense m x m matrix, and unlike squeeze() + torch.diagonal it
    # still works when there is only a single test point.
    means = predictive.mean
    stds = predictive.variance.sqrt()
    return means, means - stds, means + stds
def mean_and_bounds_for_pnf(x, y, test_xs, choice):
    """PFN predictive mean and central 68.2% interval at ``test_xs``.

    Args:
        x: training inputs, concatenated in front of ``test_xs`` along dim 0.
        y: training targets.
        test_xs: query points.
        choice: suffix selecting the checkpoint file
            ``onefeature_gp_ls.1_pnf_{choice}.pt``.

    Returns:
        ``(mean, lower, upper)`` as produced by the model's criterion.
    """
    # Presumably needed so the pickled model's classes can be imported while
    # loading. This runs on every request, so add the entry only once instead of
    # growing sys.path without bound.
    if 'prior-fitting/' not in sys.path:
        sys.path.append('prior-fitting/')
    # NOTE(review): torch.load unpickles the file, and `choice` comes from the UI.
    # Make sure only trusted checkpoints are reachable through this path.
    model = torch.load(f'onefeature_gp_ls.1_pnf_{choice}.pt')
    # Training points come first; single_eval_pos marks where the query points
    # start. unsqueeze(1) adds a size-1 dim at position 1 (presumably the batch
    # dim of a sequence-first layout; confirm against the model).
    logits = model((torch.cat([x, test_xs], 0).unsqueeze(1), y.unsqueeze(1)), single_eval_pos=len(x))
    bounds = model.criterion.quantile(logits, center_prob=.682).squeeze(1)
    return model.criterion.mean(logits).squeeze(1), bounds[:, 0], bounds[:, 1]
def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):
    """Draw a mean line and a shaded band between ``lb`` and ``ub``.

    Args:
        ax_or_plt: anything with matplotlib-style ``plot`` and ``fill_between``
            (the pyplot module or an Axes).
        x: x positions; the trailing dimension is squeezed away.
        m: mean values to draw as a line.
        lb: lower edge of the band.
        ub: upper edge of the band.
        color: colour used for both the line and the band.
    """
    flat_x = x.squeeze(-1)
    ax_or_plt.plot(flat_x, m, color=color)
    # Light fill so the band doesn't hide the other model's curve.
    ax_or_plt.fill_between(flat_x, lb, ub, alpha=.1, color=color)
@torch.no_grad()
def infer(table, choice):
    """Plot the exact GP posterior and the chosen PFN approximation for the user's points.

    Args:
        table: 2-D numpy array from the Dataframe input; column 0 holds x and
            column 1 holds y. Cells are presumably strings, with '' for blank
            cells.
        choice: training-budget label selecting the PFN checkpoint.

    Returns:
        ``('', figure)`` on success, or ``(message, None)`` when the input is
        rejected (non-numeric, empty, too many points, or out of range).
    """
    # Drop rows whose cells are all blank. `otypes` lets np.vectorize accept a
    # zero-row table, and str() avoids a TypeError on non-string cells.
    cell_lengths = np.vectorize(lambda cell: len(str(cell)), otypes=[int])(table)
    table = table[cell_lengths.sum(1) != 0]
    try:
        table = table.astype(np.float32)
    except (ValueError, TypeError):
        return excuse, None
    if len(table) > 4:
        return excuse_max_examples, None
    # NaN slips past the range comparisons below, so reject non-finite values
    # (and an empty table) here.
    if len(table) == 0 or not np.isfinite(table).all():
        return excuse, None
    x = torch.tensor(table[:, 0]).unsqueeze(1)
    y = torch.tensor(table[:, 1])
    if (x < 0.).any() or (x > 1.).any() or (y < -1).any() or (y > 1).any():
        return excuse, None

    # Create the figure only after validation, and draw on explicit axes rather
    # than pyplot's global "current figure".
    fig = plt.figure(figsize=(4, 2), dpi=1000)
    try:
        ax = fig.add_subplot(111)
        ax.scatter(x, y)
        test_xs = torch.linspace(0, 1, 100).unsqueeze(1)
        plot_w_conf_interval(ax, test_xs, *mean_and_bounds_for_gp(x, y, test_xs), 'green')
        plot_w_conf_interval(ax, test_xs, *mean_and_bounds_for_pnf(x, y, test_xs, choice), 'blue')
    finally:
        # Unregister the figure from pyplot so high-dpi figures don't pile up
        # across requests. The returned Figure object can still be rendered.
        plt.close(fig)
    return '', fig
# Gradio front end: an editable x/y table plus a selector for the PFN training budget.
_data_table_input = gr.inputs.Dataframe(
    headers=["x", "y"],
    datatype=["number", "number"],
    row_count=2,
    type='numpy',
    default=[['.25', '.1'], ['.75', '.4']],
    label='The data: you can change this and increase the number of data points using the `enter` key.',
)
_budget_input = gr.inputs.Radio(
    ['160K', '800K', '4M'],
    type="value",
    default='160K',
    label='Number of Sampled Datasets in Training (Training Costs)',
)
_description = '''This is a demo of PFNs as we describe them in our recent paper (https://openreview.net/forum?id=KSugKcbNf9).
Lines represent means and shaded areas are the confidence interval (68.2% quantile). In green, we have the ground truth GP posterior and in blue we have our approximation.
We provide three models that are architecturally the same, but with different training budgets.
'''
_article = "<p style='text-align: center'><a href='https://openreview.net/forum?id=KSugKcbNf9'>Paper: Transformers Can Do Bayesian Inference</a></p>"
iface = gr.Interface(
    fn=infer,
    title='GP Posterior Approximation with Transformers',
    description=_description,
    article=_article,
    inputs=[_data_table_input, _budget_input],
    outputs=["text", "plot"],
)
iface.launch()