{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c0f9a1e",
   "metadata": {},
   "source": [
    "# RBF-GP predictions when input features are tiled\n",
    "\n",
    "Checks how a Gaussian process with a fixed-length-scale RBF kernel behaves when every input's feature vector is repeated `num_copies` times (`np.tile(x, (1, num_copies))`).\n",
    "\n",
    "**Finding.** Tiling multiplies every squared Euclidean distance by `num_copies`, so the tiled kernel matrix is the element-wise power `k(x) ** num_copies`. Off-diagonal similarities collapse towards 0, and the prediction for the held-out last point reverts to the GP prior (mean $\\to 0$, std $\\to 1$ for this unit-amplitude kernel).\n",
    "\n",
    "In an earlier, unseeded run (4 points, 10 features, length scale 0.6), the held-out prediction went from mean $\\approx -0.30$, std $\\approx 0.904$ at 1 copy to mean $\\approx -4.8 \\times 10^{-10}$, std $\\approx 1$ at 19 copies. Over the same range, the largest off-diagonal kernel value shrank from $0.3346$ to $9.3 \\times 10^{-10} \\approx 0.3346^{19}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a873fcbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the parent directory importable so the `priors` package resolves.\n",
    "sys.path.insert(0, '..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56023c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel\n",
    "from priors.utils import get_batch_to_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e7d2b41",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "All tunable values live here. Seeding every RNG makes *Restart & Run All* reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4c6f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "SEQ_LEN = 4         # number of points; the last one is held out for prediction\n",
    "NUM_FEATURES = 10   # feature dimension before tiling\n",
    "LENGTH_SCALE = 0.6  # RBF length scale, kept fixed (no hyperparameter fitting)\n",
    "MAX_COPIES = 19     # largest tiling factor to try\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "036c690b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_rbf_kernel(length_scale=LENGTH_SCALE):\n",
    "    \"\"\"Return an RBF kernel whose length scale is fixed (excluded from fitting).\"\"\"\n",
    "    return RBF(length_scale=length_scale, length_scale_bounds='fixed')\n",
    "\n",
    "\n",
    "def get_gp(length_scale=LENGTH_SCALE):\n",
    "    \"\"\"Return an unfitted GP regressor with a fixed RBF kernel.\n",
    "\n",
    "    `optimizer=None` keeps the kernel hyperparameters exactly as given, so every\n",
    "    fit below uses the same kernel and only the inputs change.\n",
    "    \"\"\"\n",
    "    return GaussianProcessRegressor(\n",
    "        kernel=make_rbf_kernel(length_scale),\n",
    "        random_state=0, optimizer=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8b1e5a7",
   "metadata": {},
   "source": [
    "## Sample targets from the GP prior\n",
    "\n",
    "Draw random inputs and a single GP-prior sample at those inputs. The sample serves as the regression targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff8a3cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.rand(SEQ_LEN, NUM_FEATURES, dtype=torch.float64).numpy()\n",
    "\n",
    "# sklearn requires 0 <= random_state < 2**32, and `random.randint` includes its upper bound.\n",
    "sample_seed = random.randint(0, 2 ** 32 - 1)\n",
    "y = get_gp().sample_y(x, random_state=sample_seed).ravel()\n",
    "assert y.shape == (SEQ_LEN,)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f2e9c30",
   "metadata": {},
   "source": [
    "## Held-out prediction vs. number of feature copies\n",
    "\n",
    "Fit on the first `SEQ_LEN - 1` tiled points and predict the last one. As `num_copies` grows, the mean should approach the prior mean 0 and the std should approach the prior std 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46fe34a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "for num_copies in range(1, MAX_COPIES + 1):\n",
    "    x_tiled = np.tile(x, (1, num_copies))\n",
    "    gp = get_gp().fit(x_tiled[:-1], y[:-1])\n",
    "    mean, std = gp.predict(x_tiled[-1:], return_std=True)\n",
    "    print(f'copies={num_copies:2d}  mean={mean[0]: .3e}  std={std[0]:.8f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b47a1d95",
   "metadata": {},
   "source": [
    "## Kernel matrices: original vs. tiled inputs\n",
    "\n",
    "Tiling adds `num_copies` identical squared-distance terms. Because exponentials multiply when exponents add, the tiled kernel is the element-wise `num_copies`-th power of the original kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87752b3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "kernel = make_rbf_kernel()\n",
    "k_base = kernel(x)\n",
    "k_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a409ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tiled inputs explicitly instead of reusing the loop variable from above.\n",
    "k_tiled = kernel(np.tile(x, (1, MAX_COPIES)))\n",
    "np.testing.assert_allclose(k_tiled, k_base ** MAX_COPIES, rtol=1e-6, atol=1e-12)\n",
    "k_tiled"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}