{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "111c502f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the repository root importable (this notebook lives one level below it).\n",
    "import sys\n",
    "sys.path.insert(0,'..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e6b59ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "# Project-local modules, importable thanks to the sys.path cell above.\n",
    "from train import train\n",
    "import priors\n",
    "import encoders\n",
    "import positional_encodings\n",
    "import utils\n",
    "import bar_distribution\n",
    "import transformer\n",
    "\n",
    "from samlib.utils import chunker\n",
    "\n",
    "# NOTE(review): 'Losses' is used further below (Losses.mse) but is never imported\n",
    "# anywhere in this notebook -- presumably it lives in train.py; confirm and import it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "acf7423d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra keyword arguments shared by all train(...) calls below.\n",
    "kwargs = \\\n",
    "{\n",
    "    'nlayers': 6,\n",
    "    'dropout': 0.0, 'steps_per_epoch': 100,\n",
    "}\n",
    "\n",
    "\n",
    "def train_and_compare_fast_gp_mix(*args, **kwargs):\n",
    "    # Evaluate the exact GP-mixture baseline on a fresh 10000-sample batch,\n",
    "    # then train the transformer with the same settings and return both results.\n",
    "    hps = kwargs['extra_prior_kwargs_dict']['hyperparameters']\n",
    "    num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
    "    baseline_res = priors.fast_gp_mix.evaluate(\n",
    "        *args[0].get_batch_method(10000,kwargs['bptt'],num_features, hyperparameters=hps),\n",
    "        hyperparameters=hps,\n",
    "        use_mse=Losses.mse == args[2])  # NOTE(review): Losses is not imported in this notebook\n",
    "    print(baseline_res, 'with fast_gp_mix')\n",
    "\n",
    "    res = train(*args, **kwargs)\n",
    "    return res, baseline_res\n",
    "\n",
    "def train_and_compare_fast_gp(*args, num_evals=1000, **kwargs):\n",
    "    # Same as above, but against the single-GP gpytorch baseline, evaluated on CPU.\n",
    "    hps = kwargs['extra_prior_kwargs_dict']['hyperparameters']\n",
    "    num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
    "    baseline_res = priors.fast_gp.evaluate(\n",
    "        *args[0].get_batch_method(num_evals,kwargs['bptt'],num_features, hyperparameters=hps, device='cpu'),\n",
    "        hyperparameters=hps,\n",
    "        use_mse=Losses.mse == args[2], device='cpu')\n",
    "    print(baseline_res, 'with fast_gp')\n",
    "\n",
    "    res = train(*args, **kwargs)\n",
    "    return res, baseline_res\n",
    "\n",
    "def train_and_compare_gp(*args, num_evals=10000, **kwargs):\n",
    "    # Same as above, against the plain (non-gpytorch) GP baseline.\n",
    "    num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
    "    baseline_res = priors.gp.evaluate(\n",
    "        *args[0].get_batch_method(num_evals,kwargs['bptt'],num_features),\n",
    "        use_mse=Losses.mse == args[2])\n",
    "    print(baseline_res, 'with gp')  # fixed label: this is the gp baseline, not fast_gp\n",
    "\n",
    "    res = train(*args, **kwargs)\n",
    "    return res, baseline_res\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "da083e24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gpytorch\n",
    "hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
    "\n",
    "import numpy as np, scipy.stats as st\n",
    "\n",
    "def compute_mean_and_conf_interval(accuracies, confidence=.95):\n",
    "    # Mean and half-width of the t-based confidence interval along the last axis.\n",
    "    accuracies = np.array(accuracies)\n",
    "    # fixed: was len(accuracies) -- for the 2d (seq x batch) input produced by bl()\n",
    "    # that is the sequence length, not the number of samples the sem is taken over.\n",
    "    n = accuracies.shape[-1]\n",
    "    m, se = np.mean(accuracies, -1), st.sem(accuracies, -1)\n",
    "    h = se * st.t.ppf((1 + confidence) / 2., n-1)\n",
    "    return m, h\n",
    "\n",
    "\n",
    "def bl(hps,bptt, num_evals=100, num_features=1, step_size=1, evals_per_batch=None, speedups=(False,False,False,False)):\n",
    "    # GP baseline over num_evals samples, evaluated in chunks of evals_per_batch.\n",
    "    # speedups = (fast_pred_var, *fast_computations) gpytorch approximation toggles.\n",
    "    if evals_per_batch is None:\n",
    "        evals_per_batch = num_evals\n",
    "    else:\n",
    "        assert num_evals%evals_per_batch == 0\n",
    "    results = []\n",
    "    for batch_i in range(num_evals//evals_per_batch):\n",
    "        # Always sample the data exactly; only the evaluation may use approximations.\n",
    "        with gpytorch.settings.fast_computations(False,False,False):\n",
    "            batch = priors.fast_gp.get_batch(evals_per_batch,bptt,num_features, hyperparameters=hps)\n",
    "        with gpytorch.settings.fast_pred_var(speedups[0]), gpytorch.settings.fast_computations(*speedups[1:]):\n",
    "            all_res, baseline_res,_ = priors.fast_gp.evaluate(\n",
    "                *batch,\n",
    "                hyperparameters=hps, step_size=step_size\n",
    "            )\n",
    "        print(baseline_res, 'with fast_gp')\n",
    "\n",
    "        results.append(all_res)\n",
    "    all_results = torch.cat(results,1) # seq x batch_size\n",
    "    return compute_mean_and_conf_interval(all_results) # mean array, var array\n",
    "\n",
    "\n",
    "#settings = [{'num_evals':n,} for n in [100,1000]]\n",
    "\n",
    "#js = [ex.submit(bl, hps, 2000, step_size=100, evals_per_batch=2, num_features=5, **kwargs) for kwargs in settings]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8088aa12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# below you can simply replace the prior to priors.fast_gp_mix to do experiments over mixtures of GPs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165e683c",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_features = 5\n",
    "hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
    "# Sample targets once to place the bar-distribution bucket borders.\n",
    "ys = priors.fast_gp.get_batch(100000,20,num_features, hyperparameters=hps)[1]\n",
    "# Hyperparameter sweep. NOTE(review): nhid is computed in the comprehension but is\n",
    "# never passed to train(...) -- check whether it should be forwarded (e.g. nhid=nhid).\n",
    "fivefeature_jobs = [\n",
    "    train(priors.fast_gp.DataLoader, bar_distribution.FullSupportBarDistribution(bar_distribution.get_bucket_limits(num_borders, ys=ys)), enc, emsize=emsize, nhead=nhead, warmup_epochs=warmup_epochs, y_encoder_generator=y_enc, pos_encoder_generator=pos_enc,\n",
    "          batch_size=batch_size, scheduler=decay, extra_prior_kwargs_dict={'num_features': num_features, 'fuse_x_y': False, 'hyperparameters': hps},\n",
    "          epochs=epochs, lr=lr, input_normalization=input_norm, bptt=2010, single_eval_pos_gen=single_eval_pos,aggregate_k_gradients=step_every, **kwargs)\n",
    "    for enc in [encoders.Linear] for y_enc in [encoders.Linear] for emsize in [512] for nhead in [4] for nhid in [emsize*2] for epochs in [50*25,100*25,200*25,400*25]\n",
    "    for warmup_epochs in [epochs//4] for input_norm in [False]\n",
    "    for batch_size in [4] for step_every in [100//batch_size] for lr in [.0001,.0003,.001] for decay in [utils.get_cosine_schedule_with_warmup] for num_borders in [1000,10000]\n",
    "    for single_eval_pos in [utils.get_weighted_single_eval_pos_sampler(2000)]\n",
    "    for pos_enc in [positional_encodings.PositionalEncoding if single_eval_pos is None else positional_encodings.NoPositionalEncoding]\n",
    "    for redo in range(1)\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "15d01f3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np, scipy.stats as st\n",
    "\n",
    "# NOTE(review): this redefines (shadows) compute_mean_and_conf_interval from the\n",
    "# earlier cell; this version reduces over all elements of a 1d input.\n",
    "def compute_mean_and_conf_interval(accuracies, confidence=.95):\n",
    "    accuracies = np.array(accuracies)\n",
    "    n = len(accuracies)\n",
    "    m, se = np.mean(accuracies), st.sem(accuracies)\n",
    "    h = se * st.t.ppf((1 + confidence) / 2., n-1)\n",
    "    return m, h\n",
    "hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
    "\n",
    "@torch.inference_mode()\n",
    "def run_test(model,device='cuda:0',step_size=100, start_pos=1, batch_size=1000, sub_batch_size=10, seq_len=2000):\n",
    "    # Sweep the number of context points (eval_pos) and collect NLL / MSE curves,\n",
    "    # evaluating batch_size samples per position in chunks of sub_batch_size.\n",
    "    assert batch_size % sub_batch_size == 0\n",
    "    model.to(device)\n",
    "\n",
    "    model.eval()\n",
    "    nlls = []\n",
    "    nll_confidences = []\n",
    "    mses = []\n",
    "    max_mses = []\n",
    "    eval_positions = []\n",
    "\n",
    "    def get_metrics(model, eval_pos, batch_size):\n",
    "        x,y, target_y = priors.fast_gp.get_batch(batch_size=batch_size, seq_len=eval_pos+1, num_features=5,hyperparameters=hps, device=device)\n",
    "        logits = model((x,y), single_eval_pos=eval_pos)\n",
    "        if isinstance(model.criterion,nn.GaussianNLLLoss):\n",
    "            # Gaussian head: logits carry (mean, raw variance); no bar-distribution MSEs.\n",
    "            nll = model.criterion(logits[0][...,0], target_y[eval_pos], var=logits[0][...,1].abs())\n",
    "            return nll, 0., 0.\n",
    "        means = model.criterion.mean(logits) # num_evals x batch_size\n",
    "        # Midpoint of the argmax bucket = mode prediction of the bar distribution.\n",
    "        maxs = (model.criterion.borders[logits.argmax(-1)] + model.criterion.borders[logits.argmax(-1)+1])/2\n",
    "        mse = nn.MSELoss()\n",
    "        nll = model.criterion(logits[0], target_y[eval_pos])\n",
    "        return nll, mse(means[0], target_y[eval_pos]), mse(maxs[0], target_y[eval_pos])\n",
    "\n",
    "    for eval_pos in range(start_pos, seq_len, step_size):\n",
    "        eval_positions.append(eval_pos)\n",
    "        print(eval_pos)\n",
    "\n",
    "        nll = []\n",
    "        mean_mse = []\n",
    "        max_mse = []\n",
    "        for i in range(batch_size//sub_batch_size):\n",
    "            batch_nll, batch_mean_mse, batch_max_mse = get_metrics(model, eval_pos, sub_batch_size)\n",
    "            nll.append(batch_nll)\n",
    "            mean_mse.append(batch_mean_mse)\n",
    "            max_mse.append(batch_max_mse)\n",
    "\n",
    "        nll = torch.cat(nll)\n",
    "        mean_mse = torch.tensor(mean_mse).mean()\n",
    "        max_mse = torch.tensor(max_mse).mean()\n",
    "\n",
    "        mses.append(mean_mse)\n",
    "        max_mses.append(max_mse)\n",
    "        nlls.append(nll.mean())\n",
    "        nll_confidences.append(compute_mean_and_conf_interval(nll.to('cpu'))[1])\n",
    "    return eval_positions, torch.stack(mses).to('cpu'), torch.stack(max_mses).to('cpu'), torch.stack(nlls).to('cpu'), torch.tensor(nll_confidences).to('cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755e88e4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}