{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b73f00ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running locally at: http://127.0.0.1:7860/\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
    }
],
"source": [
"import gradio as gr\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import gpytorch\n",
"import torch\n",
"import sys\n",
"\n",
"import gpytorch\n",
"\n",
"# We will use the simplest form of GP model, exact inference\n",
"class ExactGPModel(gpytorch.models.ExactGP):\n",
" def __init__(self, train_x, train_y, likelihood):\n",
" super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
" self.mean_module = gpytorch.means.ConstantMean()\n",
" self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
"\n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
" return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
"\n",
"def get_model(x, y, hyperparameters):\n",
" likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))\n",
" model = ExactGPModel(x, y, likelihood)\n",
" model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters[\"noise\"]\n",
" model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters[\"outputscale\"]\n",
" model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \\\n",
" hyperparameters[\"lengthscale\"]\n",
" return model, likelihood\n",
"\n",
"\n",
"\n",
"excuse = \"Please only specify numbers, x values should be in [0,1] and y values in [-1,1].\"\n",
"excuse_max_examples = \"This model is trained to work with up to 4 input points.\"\n",
"hyperparameters = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .1, 'fast_computations': (False,False,False)}\n",
"\n",
"\n",
"conf = .5\n",
"\n",
"def mean_and_bounds_for_gp(x,y,test_xs):\n",
" gp_model, likelihood = get_model(x,y,hyperparameters)\n",
" gp_model.eval()\n",
" l = likelihood(gp_model(test_xs))\n",
" means = l.mean.squeeze()\n",
" varis = torch.diagonal(l.covariance_matrix.squeeze())\n",
" stds = varis.sqrt()\n",
" return means, means-stds, means+stds\n",
"\n",
"\n",
"def mean_and_bounds_for_pnf(x,y,test_xs, choice):\n",
" sys.path.append('prior-fitting/')\n",
" model = torch.load(f'onefeature_gp_ls.1_pnf_{choice}.pt')\n",
"\n",
"\n",
" logits = model((torch.cat([x,test_xs],0).unsqueeze(1),y.unsqueeze(1)),single_eval_pos=len(x))\n",
" bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)\n",
" return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]\n",
"\n",
"def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):\n",
" ax_or_plt.plot(x.squeeze(-1),m, color=color)\n",
" ax_or_plt.fill_between(x.squeeze(-1), lb, ub, alpha=.1, color=color)\n",
"\n",
"\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def infer(table, choice):\n",
" vfunc = np.vectorize(lambda s: len(s))\n",
" non_empty_row_mask = (vfunc(table).sum(1) != 0)\n",
" table = table[non_empty_row_mask]\n",
"\n",
" try:\n",
" table = table.astype(np.float32)\n",
" except ValueError:\n",
" return excuse, None\n",
" x = torch.tensor(table[:,0]).unsqueeze(1)\n",
" y = torch.tensor(table[:,1])\n",
" fig = plt.figure()\n",
"\n",
" if len(x) > 4:\n",
" return excuse_max_examples, None\n",
" if (x<0.).any() or (x>1.).any() or (y<-1).any() or (y>1).any():\n",
" return excuse, None\n",
"\n",
" plt.scatter(x,y)\n",
"\n",
"\n",
" \n",
" test_xs = torch.linspace(0,1,100).unsqueeze(1)\n",
" \n",
" plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green')\n",
" plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue')\n",
"\n",
"\n",
" \n",
" return '', plt.gcf()\n",
"\n",
"iface = gr.Interface(fn=infer, \n",
" inputs=[\n",
" gr.inputs.Dataframe(headers=[\"x\", \"y\"], datatype=[\"number\", \"number\"], row_count=2, type='numpy', default=[['.25','.1'],['.75','.4']]),\n",
" gr.inputs.Radio(['160K','800K','4M'], type=\"value\", default='4M', label='Training Costs')\n",
" ], outputs=[\"text\",\"plot\"])\n",
"iface.launch()\n",
"\n"
]
},
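  {
   "cell_type": "markdown",
   "id": "7f1c2d3e",
   "metadata": {},
   "source": [
    "A quick sanity check of the GP helper defined above, without launching the Gradio interface: a minimal sketch that conditions the exact GP on the same two points used as the dataframe default and plots the posterior mean with a +/- 1 standard deviation band. The names `demo_x`, `demo_y`, and `demo_test_xs` are introduced here purely for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: query the exact-GP helper directly on the two default example points.\n",
    "demo_x = torch.tensor([[.25], [.75]])\n",
    "demo_y = torch.tensor([.1, .4])\n",
    "demo_test_xs = torch.linspace(0, 1, 100).unsqueeze(1)\n",
    "\n",
    "with torch.no_grad():\n",
    "    demo_mean, demo_lb, demo_ub = mean_and_bounds_for_gp(demo_x, demo_y, demo_test_xs)\n",
    "\n",
    "plt.figure()\n",
    "plt.scatter(demo_x, demo_y)\n",
    "plot_w_conf_interval(plt, demo_test_xs, demo_mean, demo_lb, demo_ub, 'green')\n",
    "plt.show()"
   ]
  },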
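  {
   "cell_type": "markdown",
   "id": "9b4c5d6e",
   "metadata": {},
   "source": [
    "The PFN path additionally needs the pretrained checkpoints (`onefeature_gp_ls.1_pnf_160K.pt`, `onefeature_gp_ls.1_pnf_800K.pt`, `onefeature_gp_ls.1_pnf_4M.pt`) and the `prior-fitting/` repository next to this notebook, exactly as in the main cell above. The sketch below reuses `demo_x`, `demo_y`, and `demo_test_xs` from the previous cell and only runs the GP-vs-PFN comparison if the selected checkpoint is actually present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c5d6e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Hedged sketch: overlay the exact-GP band (green) and the PFN band (blue),\n",
    "# but only if the pretrained checkpoint used by mean_and_bounds_for_pnf exists.\n",
    "demo_choice = '4M'\n",
    "checkpoint = f'onefeature_gp_ls.1_pnf_{demo_choice}.pt'\n",
    "\n",
    "if os.path.exists(checkpoint):\n",
    "    with torch.no_grad():\n",
    "        pnf_mean, pnf_lb, pnf_ub = mean_and_bounds_for_pnf(demo_x, demo_y, demo_test_xs, demo_choice)\n",
    "    plt.figure()\n",
    "    plt.scatter(demo_x, demo_y)\n",
    "    plot_w_conf_interval(plt, demo_test_xs, demo_mean, demo_lb, demo_ub, 'green')\n",
    "    plot_w_conf_interval(plt, demo_test_xs, pnf_mean, pnf_lb, pnf_ub, 'blue')\n",
    "    plt.show()\n",
    "else:\n",
    "    print(f'{checkpoint} not found; skipping the PFN comparison.')"
   ]
  },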
{
"cell_type": "code",
"execution_count": null,
"id": "72c0c821",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}