{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "pfn-demo-title",
   "metadata": {},
   "source": [
    "# GP vs. prior-fitted network (PFN) demo\n",
    "\n",
    "Interactive Gradio demo: enter up to 4 points $(x, y)$ with $x \\in [0, 1]$ and $y \\in [-1, 1]$.\n",
    "The plot shows the exact Gaussian-process posterior (green) and the prediction of a prior-fitted network loaded from `onefeature_gp_ls.1_pnf_{choice}.pt` (blue). Each is drawn as a mean curve with a shaded band covering roughly the central 68%.\n",
    "\n",
    "The GP uses fixed, untrained hyperparameters (RBF lengthscale 0.1, outputscale 1, noise 1e-4). The `ls.1` in the checkpoint names suggests the networks were trained on this same prior.\n",
    "Running the notebook requires the `prior-fitting/` code directory and the `.pt` checkpoints, both relative to the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b73f00ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import sys\n",
    "\n",
    "import gpytorch\n",
    "import gradio as gr\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "config-md",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Messages returned by `infer` when the input table is rejected.\n",
    "excuse = 'Please only specify numbers, x values should be in [0,1] and y values in [-1,1].'\n",
    "excuse_max_examples = 'This model is trained to work with up to 4 input points.'\n",
    "excuse_min_examples = 'Please specify at least one input point.'\n",
    "MAX_POINTS = 4  # keep in sync with excuse_max_examples\n",
    "\n",
    "# Fixed GP hyperparameters. 'fast_computations' holds the three flags of\n",
    "# gpytorch.settings.fast_computations (covar_root_decomposition, log_prob, solves);\n",
    "# all False means exact Cholesky-based linear algebra.\n",
    "hyperparameters = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .1, 'fast_computations': (False, False, False)}\n",
    "\n",
    "# Checkpoint suffixes offered in the UI; load_pnf_model only accepts these.\n",
    "PNF_CHOICES = ('160K', '800K', '4M')\n",
    "# Appended to sys.path before torch.load, presumably so the pickled model classes can be imported.\n",
    "PNF_CODE_DIR = 'prior-fitting/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "gp-md",
   "metadata": {},
   "source": [
    "## Gaussian-process baseline\n",
    "\n",
    "Exact GP inference with the fixed hyperparameters above; nothing is fitted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gp",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will use the simplest form of GP model, exact inference\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    \"\"\"Exact GP with a constant mean and a scaled RBF kernel.\"\"\"\n",
    "\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super().__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "\n",
    "def get_model(x, y, hyperparameters):\n",
    "    \"\"\"Build an ExactGPModel on (x, y) with fixed hyperparameters (no training).\n",
    "\n",
    "    `hyperparameters` must provide 'noise', 'outputscale' and 'lengthscale'.\n",
    "    Returns (model, likelihood).\n",
    "    \"\"\"\n",
    "    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))\n",
    "    model = ExactGPModel(x, y, likelihood)\n",
    "    model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters['noise']\n",
    "    model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters['outputscale']\n",
    "    model.covar_module.base_kernel.lengthscale = (\n",
    "        torch.ones_like(model.covar_module.base_kernel.lengthscale) * hyperparameters['lengthscale']\n",
    "    )\n",
    "    return model, likelihood\n",
    "\n",
    "\n",
    "def mean_and_bounds_for_gp(x, y, test_xs):\n",
    "    \"\"\"GP posterior predictive at test_xs: (mean, mean - std, mean + std).\n",
    "\n",
    "    One standard deviation either side of a Gaussian mean is its central ~68.3% band. This\n",
    "    matches the center_prob=.682 interval drawn for the PFN.\n",
    "    \"\"\"\n",
    "    gp_model, likelihood = get_model(x, y, hyperparameters)\n",
    "    gp_model.eval()  # in eval mode an ExactGP returns the posterior given its training data\n",
    "    with gpytorch.settings.fast_computations(*hyperparameters['fast_computations']):\n",
    "        pred = likelihood(gp_model(test_xs))\n",
    "        means = pred.mean.squeeze()\n",
    "        # .variance gives the diagonal of the predictive covariance without building the\n",
    "        # full matrix (gpytorch also clamps it at settings.min_variance before the sqrt).\n",
    "        stds = pred.variance.squeeze().sqrt()\n",
    "    return means, means - stds, means + stds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "pfn-md",
   "metadata": {},
   "source": [
    "## Prior-fitted network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pfn",
   "metadata": {},
   "outputs": [],
   "source": [
    "@functools.lru_cache(maxsize=None)\n",
    "def load_pnf_model(choice):\n",
    "    \"\"\"Load the PFN checkpoint for `choice` once and cache it for later requests.\n",
    "\n",
    "    Raises ValueError for a choice outside PNF_CHOICES, so a client-supplied value\n",
    "    cannot select an arbitrary file.\n",
    "    \"\"\"\n",
    "    if choice not in PNF_CHOICES:\n",
    "        raise ValueError(f'Unknown model choice {choice!r}; expected one of {PNF_CHOICES}.')\n",
    "    if PNF_CODE_DIR not in sys.path:  # append once, not on every request\n",
    "        sys.path.append(PNF_CODE_DIR)\n",
    "    # NOTE: torch.load unpickles the file, which can execute arbitrary code; only load trusted checkpoints.\n",
    "    model = torch.load(f'onefeature_gp_ls.1_pnf_{choice}.pt')\n",
    "    model.eval()  # inference only: disables training-mode layers such as dropout\n",
    "    return model\n",
    "\n",
    "\n",
    "def mean_and_bounds_for_pnf(x, y, test_xs, choice):\n",
    "    \"\"\"PFN predictive mean and central 68.2% interval at test_xs, given points (x, y).\n",
    "\n",
    "    Returns (mean, lower, upper), which are plotted against test_xs.\n",
    "    \"\"\"\n",
    "    model = load_pnf_model(choice)\n",
    "    # Context points and query points are concatenated along dim 0. unsqueeze(1) presumably adds\n",
    "    # a batch dim of size 1, and single_eval_pos=len(x) presumably marks where the queries start\n",
    "    # (TODO confirm against the prior-fitting model API).\n",
    "    logits = model((torch.cat([x, test_xs], 0).unsqueeze(1), y.unsqueeze(1)), single_eval_pos=len(x))\n",
    "    bounds = model.criterion.quantile(logits, center_prob=.682).squeeze(1)\n",
    "    return model.criterion.mean(logits).squeeze(1), bounds[:, 0], bounds[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "handler-md",
   "metadata": {},
   "source": [
    "## Request handler\n",
    "\n",
    "`infer` validates the table coming from the Gradio UI, runs both models and returns `(message, figure)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "handler",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color, label=None):\n",
    "    \"\"\"Draw mean curve m over x and shade the band between lb and ub.\n",
    "\n",
    "    ax_or_plt: a matplotlib Axes or the pyplot module (both have plot/fill_between).\n",
    "    x: inputs with a trailing singleton dim, e.g. shape (n, 1); it is squeezed for plotting.\n",
    "    label: optional legend label for the mean curve.\n",
    "    \"\"\"\n",
    "    xs = x.squeeze(-1)\n",
    "    ax_or_plt.plot(xs, m, color=color, label=label)\n",
    "    ax_or_plt.fill_between(xs, lb, ub, alpha=.1, color=color)\n",
    "\n",
    "\n",
    "def _is_empty_cell(cell):\n",
    "    \"\"\"True for a table cell the user left blank (None or whitespace-only).\"\"\"\n",
    "    return cell is None or str(cell).strip() == ''\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def infer(table, choice):\n",
    "    \"\"\"Gradio handler: validate the (x, y) table, then plot GP and PFN predictions.\n",
    "\n",
    "    Returns (message, figure): if the input is rejected, an explanatory message and None;\n",
    "    otherwise '' and a matplotlib Figure.\n",
    "    \"\"\"\n",
    "    if table is None:\n",
    "        return excuse_min_examples, None\n",
    "    # Drop rows in which every cell is blank. Passing otypes lets np.vectorize handle a table\n",
    "    # with zero rows; without it, it cannot infer the output type and raises ValueError.\n",
    "    blank = np.vectorize(_is_empty_cell, otypes=[bool])(table)\n",
    "    table = table[~blank.all(1)]\n",
    "\n",
    "    try:\n",
    "        table = table.astype(np.float32)\n",
    "    except (ValueError, TypeError):\n",
    "        return excuse, None\n",
    "    x = torch.tensor(table[:, 0]).unsqueeze(1)\n",
    "    y = torch.tensor(table[:, 1])\n",
    "\n",
    "    if len(x) == 0:\n",
    "        return excuse_min_examples, None\n",
    "    if len(x) > MAX_POINTS:\n",
    "        return excuse_max_examples, None\n",
    "    # NaN fails every comparison, so the range check below would let it through.\n",
    "    if not (torch.isfinite(x).all() and torch.isfinite(y).all()):\n",
    "        return excuse, None\n",
    "    if (x < 0.).any() or (x > 1.).any() or (y < -1).any() or (y > 1).any():\n",
    "        return excuse, None\n",
    "\n",
    "    test_xs = torch.linspace(0, 1, 100).unsqueeze(1)\n",
    "    # Compute both predictions before creating the figure, so an exception here cannot\n",
    "    # leave an unclosed figure behind.\n",
    "    gp_curves = mean_and_bounds_for_gp(x, y, test_xs)\n",
    "    pnf_curves = mean_and_bounds_for_pnf(x, y, test_xs, choice)\n",
    "\n",
    "    # Use an explicit Axes instead of the pyplot state machine, which all requests share.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.scatter(x.squeeze(-1), y, color='black', zorder=3, label='input points')\n",
    "    plot_w_conf_interval(ax, test_xs, *gp_curves, 'green', label='GP posterior')\n",
    "    plot_w_conf_interval(ax, test_xs, *pnf_curves, 'blue', label=f'PFN ({choice})')\n",
    "    ax.set(xlabel='x', ylabel='y')\n",
    "    ax.legend()\n",
    "    # Unregister the figure from pyplot so figures do not pile up across requests.\n",
    "    # Gradio can still render the Figure object itself.\n",
    "    plt.close(fig)\n",
    "    return '', fig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "launch-md",
   "metadata": {},
   "source": [
    "## Launch the demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "launch",
   "metadata": {},
   "outputs": [],
   "source": [
    "iface = gr.Interface(\n",
    "    fn=infer,\n",
    "    inputs=[\n",
    "        gr.inputs.Dataframe(headers=['x', 'y'], datatype=['number', 'number'], row_count=2, type='numpy',\n",
    "                            default=[['.25', '.1'], ['.75', '.4']]),\n",
    "        gr.inputs.Radio(list(PNF_CHOICES), type='value', default='4M', label='Training Costs'),\n",
    "    ],\n",
    "    outputs=['text', 'plot'],\n",
    ")\n",
    "iface.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}