diff --git "a/prior-fitting/notebooks/TabularEvalSimple.ipynb" "b/prior-fitting/notebooks/TabularEvalSimple.ipynb" new file mode 100644--- /dev/null +++ "b/prior-fitting/notebooks/TabularEvalSimple.ipynb" @@ -0,0 +1,2638 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from train import train\n", + "import priors\n", + "import utils\n", + "\n", + "import numpy as np\n", + "\n", + "from datasets import load_openml_list, valid_dids_classification, test_dids_classification\n", + "from tabular import evaluate, get_model, get_default_spec\n", + "\n", + "from tabular import bayes_net_metric, gp_metric, knn_metric, ridge_metric, catboost_metric, xgb_metric, logistic_metric, tabnet_metric" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading test datasets...\n", + "wine 973\n", + "covertype 1596\n", + "\n", + " Loading valid datasets...\n", + "ionosphere 59\n" + ] + } + ], + "source": [ + "### Loads small list of datasets\n", + "print('Loading test datasets...')\n", + "test_datasets, test_datasets_df = load_openml_list(test_dids_classification[0:2], filter_for_nan=True)\n", + "ds = test_datasets\n", + "\n", + "print('\\n Loading valid datasets...')\n", + "valid_datasets, valid_datasets_df = load_openml_list(valid_dids_classification[0:2], filter_for_nan=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading test datasets...\n", + "kr-vs-kp 3\n", + "credit-g 31\n", + "vehicle 54\n", + "wine 973\n", + "kc1 1067\n", + "airlines 1169\n", + "bank-marketing 1461\n", + "blood-transfusion-service-center 1464\n", + "phoneme 1489\n", + "covertype 1596\n", + "numerai28.6 23517\n", + "connect-4 40668\n", + "car 40975\n", + "Australian 40981\n", + "segment 40984\n", + "jungle_chess_2pcs_raw_endgame_complete 41027\n", + "sylvine 41146\n", + "MiniBooNE 41150\n", + "dionis 41167\n", + "jannis 41168\n", + "helena 41169\n", + "\n", + " Loading valid datasets...\n", + "haberman 43\n", + "ionosphere 59\n", + "sa-heart 1498\n", + "cleve 40710\n" + ] + } + ], + "source": [ + "### Loads all datasets\n", + "print('Loading test datasets...')\n", + "test_datasets, test_datasets_df = load_openml_list(test_dids_classification, filter_for_nan=True)\n", + "ds = test_datasets\n", + "\n", + "print('\\n Loading valid datasets...')\n", + "valid_datasets, valid_datasets_df = load_openml_list(valid_dids_classification, filter_for_nan=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting params" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# After how many training samples should evaluatuion be done?\n", + "# Trained models have not been trained to evaluate after 30 samples\n", + "# so performance will drop\n", + "eval_positions = [30]\n", + "\n", + "# What is the maximum number of features?\n", + "# Pretrained models have to use 60\n", + "max_features = 60\n", + "\n", + "# How many samples should be loaded for one dataset?\n", + "# Samples after the training sequence are used for evaluation\n", + "seq_len = 100\n", + "\n", + "# How many subsamples of datasets should be drawn for each dataset\n", + "max_samples = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "gp_model_checkpoint_dir = \"results/tabular_model_gp.ckpt\"\n", + "gp_model_config = {'batch_size': 512,\n", + " 'bptt': 100,\n", + " 'dropout': 0.5,\n", + " 'emsize': 512,\n", + " 'epochs': 100,\n", + " 'eval_positions': [10, 20, 40, 80],\n", + " 'lr': 6.271726842985807e-05,\n", + " 'nhead': 4,\n", + " 'nhid_factor': 2,\n", + " 'nlayers': 5,\n", + " 'num_features': 60,\n", + " 'prior_lengthscale': 0.00014803074521613278,\n", + " 'prior_noise': 0.001,\n", + " 'prior_normalize_by_used_features': True,\n", + " 'prior_num_features_used_sampler': {'uniform_int_sampler_f(1,max_features)': '.. at 0x7f21e832e550>'},\n", + " 'prior_order_y': False,\n", + " 'prior_outputscale': 2.3163584733185836,\n", + " 'prior_type': 'gp'}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "bnn_model_checkpoint_dir = \"results/tabular_model_bnn.ckpt\"\n", + "bnn_model_config = {'batch_size': 512,\n", + " 'bptt': 50,\n", + " 'dropout': 0.5,\n", + " 'emsize': 512,\n", + " 'epochs': 100,\n", + " 'eval_positions': [10, 20, 40],\n", + " 'lr': 1.6421403128751275e-05,\n", + " 'nhead': 4,\n", + " 'nhid_factor': 2,\n", + " 'nlayers': 5,\n", + " 'num_features': 60,\n", + " 'prior_activations': \"\",\n", + " 'prior_dropout_sampler': {'lambda: 0.0': ' at 0x7f613c1364c0>'},\n", + " 'prior_emsize_sampler': {'scaled_beta_sampler_f(2.0, 4.0, 150, 2)': '.. at 0x7f613c136310>'},\n", + " 'prior_is_causal': False,\n", + " 'prior_nlayers_sampler': {'lambda: 3': ' at 0x7f613c136790>'},\n", + " 'prior_noise_std_gamma_k': 1.8663049257557085,\n", + " 'prior_noise_std_gamma_theta': 0.05275478076173361,\n", + " 'prior_normalize_by_used_features': False,\n", + " 'prior_num_features_used_sampler': {'scaled_beta_sampler_f(1.0, 1.6, max_features, 2)': '.. at 0x7f613c136550>'},\n", + " 'prior_order_y': True,\n", + " 'prior_sigma_gamma_k': 3.6187797729244253,\n", + " 'prior_sigma_gamma_theta': 0.06773738681062867,\n", + " 'prior_type': 'mlp'}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading PFN" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda device\n", + "DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 512, 'seq_len': 50, 'num_features': 60, 'hyperparameters': (' at 0x7f613c136790>', '.. at 0x7f613c136310>', \"\", .. at 0x7f56719e5940>, .. at 0x7f56719e5b80>, ' at 0x7f613c1364c0>', True, '.. at 0x7f613c136550>', None, False, None, None, None, True, False, None, 0.0), 'batch_size_per_gp_sample': 8}, 'num_features': 60, 'num_outputs': 1}\n" + ] + } + ], + "source": [ + "model_type = 'bnn'\n", + "if model_type == 'gp':\n", + " raise Exception(\"Not Implemented\")\n", + " config = gp_model_config\n", + " checkpoint_dir = gp_model_checkpoint_dir\n", + "elif model_type == 'bnn':\n", + " config = bnn_model_config\n", + " checkpoint_dir = bnn_model_checkpoint_dir\n", + "\n", + "model = get_model(config, device, eval_positions, should_train=False)\n", + "model_state, _ = torch.load(checkpoint_dir)\n", + "model[2].load_state_dict(model_state)\n", + "model = model[2]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation of PFN and Baselines on all datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating wine\n", + "\t Eval position 30 done..\n", + "Evaluating covertype\n", + "\t Eval position 30 done..\n" + ] + }, + { + "data": { + "text/plain": [ + "{'metric': 'auc',\n", + " 'wine_mean_metric_at_30': 0.9587367346938775,\n", + " 'wine_time': 0.45734596252441406,\n", + " 'covertype_mean_metric_at_30': 0.9624857142857144,\n", + " 'covertype_time': 0.4606165885925293,\n", + " 'mean_metric_at_30': 0.960611224489796,\n", + " 'mean_metric': 0.960611224489796}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = 'cuda'\n", + "result = evaluate(ds, model.to(device), 'transformer'\n", + " , max_features = max_features\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , rescale_features=config[\"prior_normalize_by_used_features\"]\n", + " , extend_features=True, plot=False, overwrite=True, save=False)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Bayesian NN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hidden": true + }, + "outputs": [], + "source": [ + "result = evaluate(ds, bayes_net_metric, 'bayes_net'\n", + " , bptt=seq_len\n", + " , evaal_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Gaussian Process" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating wine\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 1/5 [00:05<00:22, 5.68s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9857142857142858\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 2/5 [00:10<00:14, 4.88s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9714285714285714\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 60%|██████ | 3/5 [00:14<00:09, 4.61s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [00:18<00:04, 4.55s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [00:23<00:00, 4.65s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7571428571428571\n", + "\t Eval position 30 done..\n", + "Evaluating covertype\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 1/5 [00:06<00:24, 6.12s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9285714285714286\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 2/5 [00:11<00:17, 5.75s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 60%|██████ | 3/5 [00:17<00:11, 5.80s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7142857142857143\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [00:22<00:05, 5.44s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8142857142857143\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [00:27<00:00, 5.50s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8714285714285714\n", + "\t Eval position 30 done..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "result = evaluate(ds, gp_metric, 'gp'\n", + " , bptt=seq_len\n", + " , evaal_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hidden": true + }, + "outputs": [], + "source": [ + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Tabnet" + ] + }, + { + "cell_type": "code", + "execution_count": 1403, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating blood-transfusion-service-center\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2 [00:00\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cpu'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m7\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtabnet_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'tabnet'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/hollmann/prior-fitting/results_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'_tabnet.npy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m 194\u001b[0m rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m 195\u001b[0m \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m 225\u001b[0m max_samples=max_samples)\n\u001b[1;32m 226\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 310\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 328\u001b[0m \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mtabnet_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m 565\u001b[0m \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTabNetClassifier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcat_idxs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcat_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_a\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'n_d'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 567\u001b[0;31m clf.fit(\n\u001b[0m\u001b[1;32m 568\u001b[0m \u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;31m#eval_set=[(X_valid, y_valid)], patience=15\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_train, y_train, eval_set, eval_name, eval_metric, loss_fn, weights, max_epochs, patience, batch_size, virtual_batch_size, num_workers, drop_last, callbacks, pin_memory, from_unsupervised)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 223\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;31m# Apply predict epoch to all eval sets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py\u001b[0m in \u001b[0;36m_train_epoch\u001b[0;34m(self, train_loader)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 434\u001b[0;31m \u001b[0mbatch_logs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 435\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_logs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py\u001b[0m in \u001b[0;36m_train_batch\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0mbatch_logs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"batch_size\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 464\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "result = evaluate(ds, tabnet_metric, 'tabnet'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### KNN" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating wine\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:01<00:00, 11.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\t Eval position 30 done..\n", + "Evaluating covertype\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:01<00:00, 16.24it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\t Eval position 30 done..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'metric': 'auc',\n", + " 'wine_mean_metric_at_30': 0.8925102040816327,\n", + " 'wine_time': 1.7465898990631104,\n", + " 'covertype_mean_metric_at_30': 0.792734693877551,\n", + " 'covertype_time': 1.247727394104004,\n", + " 'mean_metric_at_30': 0.8426224489795919,\n", + " 'mean_metric': 0.8426224489795919}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = evaluate(ds, knn_metric, 'knn'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Logistic Regression" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating wine\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 5%|▌ | 1/20 [00:01<00:34, 1.83s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n", + "[4.12326081e-04 9.19611425e-01 3.97752626e-02 9.74932808e-01\n", + " 2.92177993e-03 9.98121211e-01 4.40905019e-03 9.33577461e-01\n", + " 1.50555390e-01 9.94204349e-01 3.87256524e-04 9.75964168e-01\n", + " 6.52074221e-03 5.44187545e-01 2.76466189e-04 7.75604223e-01\n", + " 6.95195131e-04 8.77684487e-01 3.72080751e-04 9.98190223e-01\n", + " 1.30797581e-02 9.98647814e-01 2.64468869e-01 9.99135067e-01\n", + " 2.63580359e-05 6.23341075e-01 1.93786767e-03 9.08927121e-01\n", + " 3.61541354e-03 9.98969131e-01 1.74570882e-02 9.08004442e-01\n", + " 1.55259675e-02 9.57181495e-01 2.78544674e-03 9.99731021e-01\n", + " 7.09553284e-02 2.29517949e-01 3.19418012e-02 6.85265538e-01\n", + " 5.43416196e-03 9.67286609e-01 3.30096232e-02 8.83952036e-01\n", + " 9.05628129e-04 9.25079429e-01 1.30549388e-03 9.87857426e-01\n", + " 1.93246531e-02 7.14752885e-01 6.55393417e-03 9.91341790e-01\n", + " 2.71002767e-01 9.72008984e-01 1.95572800e-02 9.70650138e-01\n", + " 4.02880801e-04 9.36476310e-01 1.71522394e-03 9.87288208e-01\n", + " 5.34800821e-03 9.87928080e-01 2.29681197e-03 9.12132603e-01\n", + " 1.32190621e-01 9.99164477e-01 2.50227955e-04 9.78312822e-01\n", + " 8.01561154e-02 9.71095930e-01]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 10%|█ | 2/20 [00:03<00:35, 1.99s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': False, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n", + "[2.34966376e-03 9.97811502e-01 4.42977805e-03 9.99565614e-01\n", + " 2.59753692e-03 9.61724168e-01 2.61071324e-03 9.99933616e-01\n", + " 1.46586638e-03 9.85141705e-01 1.99387649e-04 9.14871930e-01\n", + " 2.31569875e-02 9.77123157e-01 4.55479338e-04 9.98882992e-01\n", + " 2.01825925e-03 9.12376267e-01 2.19126373e-01 9.95130404e-01\n", + " 5.17808246e-04 9.41544348e-01 5.26481881e-03 5.45981587e-01\n", + " 1.29234306e-04 7.59193071e-01 5.28180762e-04 8.64066925e-01\n", + " 5.49872741e-04 9.99114519e-01 3.82477700e-03 9.97710659e-01\n", + " 3.32963599e-01 9.98419850e-01 2.61277431e-05 5.27080955e-01\n", + " 6.53393899e-03 7.76305159e-01 6.04063168e-03 9.98567996e-01\n", + " 3.01862055e-02 8.02008437e-01 1.81304029e-02 9.54083311e-01\n", + " 7.13274150e-03 9.99326682e-01 1.00578841e-01 2.94002156e-01\n", + " 2.58978601e-02 7.49132883e-01 4.50584482e-03 9.77984332e-01\n", + " 8.14691986e-02 8.71252875e-01 9.75149695e-04 9.42233545e-01\n", + " 9.85537736e-04 9.96837487e-01 1.44663880e-02 7.96780562e-01\n", + " 7.87276898e-03 9.94137661e-01 4.49145912e-01 9.83950313e-01\n", + " 1.61118424e-02 9.59959963e-01 4.32656352e-04 9.19760156e-01\n", + " 1.82114427e-03 9.92688659e-01]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 15%|█▌ | 3/20 [00:05<00:32, 1.92s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1.0, 'fit_intercept': False, 'max_iter': 500, 'penalty': 'l1', 'solver': 'saga', 'tol': 0.0001}\n", + "[0.06974362 0.42336164 0.01703207 0.97388489 0.22809311 0.864173\n", + " 0.03105408 0.83391439 0.01444774 0.91769675 0.12392376 0.80023415\n", + " 0.09652814 0.85240131 0.08174257 0.74563011 0.32383395 0.87731599\n", + " 0.68395194 0.94467399 0.25368484 0.99224602 0.09027845 0.88636919\n", + " 0.09235925 0.99655498 0.9651366 0.97910275 0.5482745 0.93189756\n", + " 0.01159998 0.96038068 0.0267969 0.9155232 0.05236642 0.96382213\n", + " 0.43594785 0.81424108 0.10509157 0.98534026 0.01147495 0.09514501\n", + " 0.29550364 0.96793223 0.06416272 0.02983371 0.03576493 0.09970959\n", + " 0.06283597 0.06050325 0.04146495 0.10931557 0.00956397 0.03881107\n", + " 0.07353775 0.23500227 0.02755604 0.28342607 0.1337751 0.08567451\n", + " 0.01716663 0.0149318 0.01978608 0.23753092 0.01930092 0.09254563\n", + " 0.04065118 0.13216921 0.0324691 0.09542866]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 4/20 [00:07<00:29, 1.86s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1.0, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'l1', 'solver': 'saga', 'tol': 0.01}\n", + "[0.70131459 0.01678711 0.60492504 0.01069737 0.71376805 0.02748887\n", + " 0.98449608 0.05130859 0.97420844 0.40381504 0.96752984 0.02519502\n", + " 0.82707777 0.10708214 0.69000753 0.2040882 0.97747093 0.40330462\n", + " 0.67280445 0.19142666 0.96526066 0.08645192 0.99053094 0.35871556\n", + " 0.65781758 0.44227517 0.68420729 0.08220921 0.89025336 0.45462276\n", + " 0.84425192 0.54211743 0.88263528 0.08692686 0.98430602 0.17645702\n", + " 0.8538925 0.16913417 0.97691351 0.81127546 0.93985646 0.16094987\n", + " 0.96264753 0.04228152 0.90056355 0.19237197 0.93328752 0.07411261\n", + " 0.89027042 0.58041776 0.7955457 0.18435328 0.99249684 0.03125832\n", + " 0.70961402 0.2673135 0.97250892 0.01717529 0.70037093 0.12556494\n", + " 0.70165636 0.09616296 0.2272858 0.08929946 0.32233496 0.00673966\n", + " 0.5958783 0.02302669 0.15701894 0.0500672 ]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 25%|██▌ | 5/20 [00:09<00:27, 1.85s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n", + "[6.52714586e-01 9.98635531e-01 6.00991816e-03 9.55565635e-01\n", + " 7.09951283e-03 8.14360044e-01 6.81791679e-03 9.99801606e-01\n", + " 4.36909361e-01 8.91309699e-01 3.48875884e-01 9.96234204e-01\n", + " 2.69088658e-02 9.99950326e-01 2.39151361e-01 9.33560167e-01\n", + " 6.36510214e-01 9.11335046e-01 6.19797405e-02 9.89353304e-01\n", + " 1.05815283e-01 9.75894050e-01 9.19643764e-01 9.86137300e-01\n", + " 2.66815067e-03 9.99040711e-01 3.01042737e-02 9.62981963e-01\n", + " 2.01891457e-01 9.99483583e-01 9.72869351e-01 9.92534311e-01\n", + " 3.48164720e-01 9.98274099e-01 9.57299314e-03 9.95362359e-01\n", + " 1.34150611e-01 9.96665681e-01 8.69588961e-03 9.88952714e-01\n", + " 4.11556516e-01 8.06033390e-01 8.89376753e-02 9.99933419e-01\n", + " 1.79285952e-02 2.46006114e-02 4.19293642e-01 9.99020227e-01\n", + " 1.39752377e-03 8.27256545e-02 7.09215549e-04 1.88990783e-02\n", + " 2.49824281e-02 7.11231223e-04 9.35730222e-03 2.34447294e-03\n", + " 9.25927278e-04 2.82376339e-02 2.25198163e-03 3.71604751e-04\n", + " 4.15443651e-03 2.59381661e-01 3.13232724e-02 2.91566784e-02\n", + " 3.94860122e-04 7.75759927e-04 5.23700697e-05 1.50473088e-01\n", + " 3.39148707e-03 5.03053356e-03]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 6/20 [00:11<00:25, 1.83s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n", + "[9.83583576e-01 5.31016587e-04 9.29274459e-01 3.85459019e-02\n", + " 9.71639017e-01 1.18775105e-03 9.98134567e-01 5.21240611e-03\n", + " 8.86056657e-01 2.38149719e-01 9.92077944e-01 7.38237983e-04\n", + " 9.58905441e-01 1.24559396e-02 5.60988697e-01 3.11511168e-04\n", + " 7.85096097e-01 9.90275798e-04 8.82864038e-01 5.31628645e-04\n", + " 9.98564605e-01 1.48625315e-02 9.97921985e-01 2.86616976e-01\n", + " 9.98944054e-01 4.72361925e-05 6.31113590e-01 6.26047465e-03\n", + " 8.84982042e-01 9.00147613e-03 9.98913206e-01 3.01658308e-02\n", + " 8.72072354e-01 1.88641325e-02 9.61498458e-01 5.51973166e-03\n", + " 9.99559769e-01 8.98409294e-02 2.73225895e-01 3.36784201e-02\n", + " 7.20221648e-01 7.07506732e-03 9.77283000e-01 1.00082860e-01\n", + " 8.83214682e-01 6.98692395e-04 9.33891209e-01 1.52199602e-03\n", + " 9.93985713e-01 2.50527583e-02 7.55022990e-01 9.41495432e-03\n", + " 9.90193157e-01 3.03874037e-01 9.70953038e-01 2.07865013e-02\n", + " 9.67583165e-01 5.22901980e-04 9.28441639e-01 2.16943826e-03\n", + " 9.90123427e-01 7.18394185e-03 9.90211745e-01 2.42021104e-03\n", + " 9.38244511e-01 1.23170538e-01 9.99344646e-01 3.81413969e-04\n", + " 9.78230408e-01 7.60275466e-02]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 35%|███▌ | 7/20 [00:13<00:24, 1.85s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.0001}\n", + "[3.31313562e-05 9.99999464e-01 7.95159296e-01 9.99997804e-01\n", + " 3.13777454e-04 9.95388564e-01 3.10136131e-04 9.75628399e-01\n", + " 3.62288555e-04 9.99999677e-01 5.60118984e-01 9.94055581e-01\n", + " 2.82618459e-01 9.99898183e-01 2.05644609e-03 9.99999978e-01\n", + " 7.96792726e-02 9.92154817e-01 5.86060898e-01 9.88876051e-01\n", + " 1.15897862e-02 9.99597733e-01 5.01256743e-02 9.98633899e-01\n", + " 9.73034582e-01 9.99549765e-01 5.99983839e-05 9.99995143e-01\n", + " 6.18802761e-03 9.95558768e-01 6.38389742e-02 9.99998705e-01\n", + " 9.99182384e-01 9.99817107e-01 4.12569504e-01 9.99983173e-01\n", + " 4.46397918e-04 9.99892199e-01 3.40415387e-02 9.99945259e-01\n", + " 6.76523004e-04 9.99672836e-01 1.59490335e-01 9.25574947e-01\n", + " 4.39227379e-02 9.99999945e-01 7.88443823e-04 1.14832115e-02\n", + " 3.73727289e-01 9.99995097e-01 1.43936707e-05 6.01971612e-02\n", + " 1.11791732e-05 2.47806691e-03 8.31325116e-04 5.91754942e-06\n", + " 3.79556417e-04 2.44141540e-04 3.43753601e-06 9.67005991e-03\n", + " 5.44429224e-05 6.99912685e-06 1.40831013e-04 1.64769326e-01\n", + " 4.10106657e-03 3.14697787e-03 1.65498495e-06 3.17770130e-05\n", + " 6.87332367e-08 1.85409476e-01]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 8/20 [00:14<00:22, 1.85s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n", + "[9.99664268e-01 7.62876099e-01 9.98386486e-01 6.27311096e-03\n", + " 9.37337218e-01 6.21016618e-03 8.18091840e-01 5.71440730e-03\n", + " 9.99733629e-01 4.49018505e-01 9.11202742e-01 4.15045796e-01\n", + " 9.93656045e-01 2.90506359e-02 9.99924234e-01 2.44542240e-01\n", + " 9.26934168e-01 6.76509634e-01 9.09437589e-01 7.14948195e-02\n", + " 9.85587705e-01 7.66412054e-02 9.72836935e-01 8.36812757e-01\n", + " 9.82542331e-01 2.44985077e-03 9.98048341e-01 3.79889220e-02\n", + " 9.58489559e-01 1.95237163e-01 9.99317980e-01 9.59677506e-01\n", + " 9.92652818e-01 3.36662240e-01 9.97738733e-01 1.03584140e-02\n", + " 9.94715794e-01 1.52240798e-01 9.96180878e-01 7.25841492e-03\n", + " 9.86733845e-01 2.56595318e-01 7.50204506e-01 1.22199744e-01\n", + " 9.99891118e-01 2.30040057e-02 3.16352395e-02 4.69433536e-01\n", + " 9.98253315e-01 2.25872669e-03 6.35848289e-02 1.08197788e-03\n", + " 1.72929146e-02 2.31652202e-02 5.00682185e-04 8.54084689e-03\n", + " 3.33725153e-03 1.54900694e-03 3.73832650e-02 3.21548385e-03\n", + " 3.95950758e-04 4.80783210e-03 2.40004379e-01 3.00124757e-02\n", + " 2.95706753e-02 4.96784512e-04 1.01762749e-03 6.84382194e-05\n", + " 1.55359082e-01 3.53251038e-03]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 45%|████▌ | 9/20 [00:16<00:20, 1.86s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'l2', 'solver': 'saga', 'tol': 1e-10}\n", + "[0.4998356 0.50015137 0.49987409 0.50029966 0.49983398 0.50031188\n", + " 0.50009655 0.50023837 0.49986525 0.5002024 0.49988854 0.50012421\n", + " 0.49990913 0.50033222 0.50001204 0.50011976 0.50011477 0.50028586\n", + " 0.49994748 0.50036626 0.50007451 0.50013711 0.50013507 0.50012796\n", + " 0.49997282 0.50019476 0.49997082 0.50019112 0.5001566 0.50017832\n", + " 0.49986975 0.50027281 0.49993339 0.50016413 0.50001914 0.50030764\n", + " 0.50022348 0.50024722 0.50009243 0.500288 0.49992337 0.50020947\n", + " 0.50000242 0.50023711 0.49991418 0.50019936 0.50010521 0.50013699\n", + " 0.50003493 0.50039762 0.49996634 0.49993259 0.50006249 0.50027367\n", + " 0.49986551 0.49999889 0.49985123 0.4999571 0.50000606 0.49985354\n", + " 0.49994856 0.49987396 0.4999105 0.49993825 0.49989787 0.49983024\n", + " 0.49994391 0.50008364 0.50003337 0.50000305]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 45%|████▌ | 9/20 [00:17<00:21, 1.92s/it]\n", + "ERROR:root:Internal Python error in the inspect module.\n", + "Below is the traceback from this internal error.\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Traceback (most recent call last):\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3441, in run_code\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n", + " File \"/tmp/ipykernel_62303/3352290182.py\", line 1, in \n", + " result = evaluate(ds, logistic_acc, 'logistic'\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 195, in evaluate\n", + " ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 227, in evaluate_dataset\n", + " r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 314, in evaluate_position\n", + " acc_eval_pos, outputs = batch_pred(model, eval_xs, eval_ys, categorical_feats, start=eval_position)\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 329, in batch_pred\n", + " acc, output = acc_function(eval_x[:start], eval_y[:start], eval_x[start:], eval_y[start:], categorical_feats)\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 380, in logistic_acc\n", + " clf.fit(x, y.long())\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n", + " return f(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 841, in fit\n", + " self._run_search(evaluate_candidates)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 1296, in _run_search\n", + " evaluate_candidates(ParameterGrid(self.param_grid))\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 795, in evaluate_candidates\n", + " out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 1044, in __call__\n", + " while self.dispatch_one_batch(iterator):\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 859, in dispatch_one_batch\n", + " self._dispatch(tasks)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 777, in _dispatch\n", + " job = self._backend.apply_async(batch, callback=cb)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\", line 208, in apply_async\n", + " result = ImmediateResult(func)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\", line 572, in __init__\n", + " self.results = batch()\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 262, in __call__\n", + " return [func(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 262, in \n", + " return [func(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\", line 222, in __call__\n", + " return self.function(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\", line 625, in _fit_and_score\n", + " test_scores = _score(estimator, X_test, y_test, scorer, error_score)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\", line 687, in _score\n", + " scores = scorer(estimator, X_test, y_test)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_scorer.py\", line 397, in _passthrough_scorer\n", + " return estimator.score(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/base.py\", line 500, in score\n", + " return accuracy_score(y, self.predict(X), sample_weight=sample_weight)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n", + " return f(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\", line 210, in accuracy_score\n", + " return _weighted_sum(score, sample_weight, normalize)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\", line 133, in _weighted_sum\n", + " return np.average(sample_score, weights=sample_weight)\n", + " File \"<__array_function__ internals>\", line 5, in average\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/lib/function_base.py\", line 380, in average\n", + " avg = a.mean(axis)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\", line 167, in _mean\n", + " rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\", line 71, in _count_reduce_items\n", + " axis = tuple(range(arr.ndim))\n", + "KeyboardInterrupt\n", + "\n", + "During handling of the above exception, another exception occurred:\n", + "\n", + "Traceback (most recent call last):\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 2061, in showtraceback\n", + " stb = value._render_traceback_()\n", + "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n", + "\n", + "During handling of the above exception, another exception occurred:\n", + "\n", + "Traceback (most recent call last):\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n", + " return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 248, in wrapped\n", + " return f(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 281, in _fixed_getinnerframes\n", + " records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1541, in getinnerframes\n", + " frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1503, in getframeinfo\n", + " lines, lnum = findsource(frame)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 182, in findsource\n", + " lines = linecache.getlines(file, globals_dict)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/linecache.py\", line 46, in getlines\n", + " return updatecache(filename, module_globals)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/linecache.py\", line 136, in updatecache\n", + " with tokenize.open(fullname) as fp:\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/tokenize.py\", line 394, in open\n", + " encoding, lines = detect_encoding(buffer.readline)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/tokenize.py\", line 363, in detect_encoding\n", + " first = read_or_stop()\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/tokenize.py\", line 321, in read_or_stop\n", + " return readline()\n", + "KeyboardInterrupt\n" + ] + }, + { + "ename": "TypeError", + "evalue": "object of type 'NoneType' has no len()", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_62303/3352290182.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m result = evaluate(ds, logistic_acc, 'logistic'\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mseq_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_positions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 195\u001b[0;31m ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m 196\u001b[0m rescale_features=rescale_features_factor, max_samples=max_samples)\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 227\u001b[0;31m r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m 228\u001b[0m max_samples=max_samples)\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mlogistic_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 380\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 381\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36m_run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m 1295\u001b[0m \u001b[0;34m\"\"\"Search all candidates in param_grid\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1296\u001b[0;31m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mParameterGrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1297\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mevaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m 794\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 795\u001b[0;31m out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n\u001b[0m\u001b[1;32m 796\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1045\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 859\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 860\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mjob_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 778\u001b[0m \u001b[0;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 209\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m 263\u001b[0m for func, args, kwargs in self.items]\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m 263\u001b[0m for func, args, kwargs in self.items]\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0mfit_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 625\u001b[0;31m \u001b[0mtest_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscorer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merror_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0mscore_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mfit_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_score\u001b[0;34m(estimator, X_test, y_test, scorer, error_score)\u001b[0m\n\u001b[1;32m 686\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 687\u001b[0;31m \u001b[0mscores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscorer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 688\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_scorer.py\u001b[0m in \u001b[0;36m_passthrough_scorer\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;34m\"\"\"Function that wraps estimator.score\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 397\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 398\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/base.py\u001b[0m in \u001b[0;36mscore\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 499\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0maccuracy_score\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 500\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0maccuracy_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 501\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36maccuracy_score\u001b[0;34m(y_true, y_pred, normalize, sample_weight)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_weighted_sum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnormalize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36m_weighted_sum\u001b[0;34m(sample_score, sample_weight, normalize)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnormalize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maverage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 134\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0msample_weight\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36maverage\u001b[0;34m(*args, **kwargs)\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/lib/function_base.py\u001b[0m in \u001b[0;36maverage\u001b[0;34m(a, axis, weights, returned)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 380\u001b[0;31m \u001b[0mavg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 381\u001b[0m \u001b[0mscl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mavg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mrcount\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_count_reduce_items\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkeepdims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwhere\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrcount\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwhere\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mumr_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcount\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_count_reduce_items\u001b[0;34m(arr, axis, keepdims, where)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 72\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2060\u001b[0m \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2061\u001b[0;31m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2062\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'KeyboardInterrupt' object has no attribute '_render_traceback_'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2061\u001b[0m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2062\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2063\u001b[0;31m stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m 2064\u001b[0m value, tb, tb_offset=tb_offset)\n\u001b[1;32m 2065\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1365\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1366\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1367\u001b[0;31m return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1368\u001b[0m self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m 1369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1266\u001b[0m \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1267\u001b[0;31m return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1268\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1269\u001b[0m )\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1122\u001b[0m \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1124\u001b[0;31m formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m 1125\u001b[0m tb_offset)\n\u001b[1;32m 1126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m 1080\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1081\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1082\u001b[0;31m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1083\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1084\u001b[0m \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 383\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()" + ] + } + ], + "source": [ + "result = evaluate(ds, logistic_metric, 'logistic'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Ridge Classification" + ] + }, + { + "cell_type": "code", + "execution_count": 1303, + "metadata": { + "hidden": true, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating kr-vs-kp\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 5%|▌ | 1/20 [00:00<00:02, 6.91it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8428571428571429\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 10%|█ | 2/20 [00:00<00:02, 6.81it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 15%|█▌ | 3/20 [00:00<00:02, 6.95it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7857142857142857\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 4/20 [00:00<00:02, 6.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9428571428571428\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 25%|██▌ | 5/20 [00:00<00:02, 6.70it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 6/20 [00:00<00:02, 6.71it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 35%|███▌ | 7/20 [00:01<00:01, 6.68it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8714285714285714\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 8/20 [00:01<00:01, 6.66it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9571428571428572\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 45%|████▌ | 9/20 [00:01<00:01, 6.79it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8285714285714286\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py:187: LinAlgWarning: Ill-conditioned matrix (rcond=8.92231e-11): result may not be accurate.\n", + " dual_coef = linalg.solve(K, y, sym_pos=True,\n", + " 50%|█████ | 10/20 [00:01<00:01, 6.82it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 55%|█████▌ | 11/20 [00:01<00:01, 6.77it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6714285714285714\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 60%|██████ | 12/20 [00:01<00:01, 6.76it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8285714285714286\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 65%|██████▌ | 13/20 [00:01<00:01, 6.71it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8857142857142857\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 70%|██████��� | 14/20 [00:02<00:00, 6.69it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 75%|███████▌ | 15/20 [00:02<00:00, 6.64it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 75%|███████▌ | 15/20 [00:02<00:00, 6.37it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_71014/1206641945.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_ds_with_selector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mridge_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'ridge'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/hollmann/prior-fitting/results/tabular/results_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'_ridge.npy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m 191\u001b[0m rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m 192\u001b[0m \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m 222\u001b[0m max_samples=max_samples)\n\u001b[1;32m 223\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 324\u001b[0;31m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 325\u001b[0m \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mridge_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGridSearchCV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'ridge'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCV\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 350\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecision_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# extra_args > 0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 843\u001b[0m \u001b[0;31m# multimetric is determined here because in the case of a callable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36m_run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m 1294\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1295\u001b[0m \u001b[0;34m\"\"\"Search all candidates in param_grid\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1296\u001b[0;31m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mParameterGrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1297\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mevaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m 793\u001b[0m n_splits, n_candidates, n_candidates * n_splits))\n\u001b[1;32m 794\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 795\u001b[0;31m out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n\u001b[0m\u001b[1;32m 796\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 797\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterating\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_iterator\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1045\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1046\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m 857\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 859\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 860\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 861\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 775\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mjob_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 778\u001b[0m \u001b[0;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 779\u001b[0m \u001b[0;31m# called before we get here, causing self._jobs to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 209\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;31m# Don't delay the application, to avoid keeping the input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m 263\u001b[0m for func, args, kwargs in self.items]\n\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m 263\u001b[0m for func, args, kwargs in self.items]\n\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m 596\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 598\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 599\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 943\u001b[0m compute_sample_weight(self.class_weight, y))\n\u001b[1;32m 944\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 945\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 946\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 591\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 592\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 593\u001b[0;31m self.coef_, self.n_iter_ = _ridge_regression(\n\u001b[0m\u001b[1;32m 594\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 595\u001b[0m \u001b[0mmax_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msolver\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36m_ridge_regression\u001b[0;34m(X, y, alpha, sample_weight, solver, max_iter, tol, verbose, random_state, return_n_iter, return_intercept, X_scale, X_offset, check_input)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0mK\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msafe_sparse_dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdense_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m \u001b[0mdual_coef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_solve_cholesky_kernel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mK\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 464\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 465\u001b[0m \u001b[0mcoef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msafe_sparse_dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdual_coef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdense_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36m_solve_cholesky_kernel\u001b[0;34m(K, y, alpha, sample_weight, copy)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# use the fall-back solution below in case a LinAlgError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;31m# is raised\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m dual_coef = linalg.solve(K, y, sym_pos=True,\n\u001b[0m\u001b[1;32m 188\u001b[0m overwrite_a=False)\n\u001b[1;32m 189\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinAlgError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/scipy/linalg/basic.py\u001b[0m in \u001b[0;36msolve\u001b[0;34m(a, b, sym_pos, lower, overwrite_a, overwrite_b, debug, check_finite, assume_a, transposed)\u001b[0m\n\u001b[1;32m 252\u001b[0m overwrite_b=overwrite_b)\n\u001b[1;32m 253\u001b[0m \u001b[0m_solve_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m \u001b[0mrcond\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpocon\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0manorm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0m_solve_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlamch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrcond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "result = evaluate(ds, ridge_metric, 'ridge'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### XG Boost" + ] + }, + { + "cell_type": "code", + "execution_count": 1346, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating kr-vs-kp\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2 [00:00\n", + " result = evaluate(ds, xgb_acc, 'xgb', bptt, eval_positions, device='cpu', max_samples=2, overwrite=True)\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 190, in evaluate\n", + " ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 221, in evaluate_dataset\n", + " r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 309, in evaluate_position\n", + " acc_eval_pos, outputs = batch_pred(model, eval_xs, eval_ys, categorical_feats, start=eval_position)\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 324, in batch_pred\n", + " acc, output = acc_function(eval_x[:start], eval_y[:start], eval_x[start:], eval_y[start:], categorical_feats)\n", + " File \"/home/hollmann/prior-fitting/tabular.py\", line 628, in xgb_acc\n", + " clf.fit(x, y.astype(int))\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n", + " return f(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 880, in fit\n", + " self.best_estimator_.fit(X, y, **fit_params)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\", line 433, in inner_f\n", + " return f(**kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/sklearn.py\", line 1176, in fit\n", + " self._Booster = train(\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\", line 189, in train\n", + " bst = _train_internal(params, dtrain,\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\", line 81, in _train_internal\n", + " bst.update(dtrain, i, obj)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\", line 1496, in update\n", + " _check_call(_LIB.XGBoosterUpdateOneIter(self.handle,\n", + "KeyboardInterrupt\n", + "\n", + "During handling of the above exception, another exception occurred:\n", + "\n", + "Traceback (most recent call last):\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 2061, in showtraceback\n", + " stb = value._render_traceback_()\n", + "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n", + "\n", + "During handling of the above exception, another exception occurred:\n", + "\n", + "Traceback (most recent call last):\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n", + " return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 248, in wrapped\n", + " return f(*args, **kwargs)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 281, in _fixed_getinnerframes\n", + " records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1541, in getinnerframes\n", + " frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1499, in getframeinfo\n", + " filename = getsourcefile(frame) or getfile(frame)\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 709, in getsourcefile\n", + " if getattr(getmodule(object, filename), '__loader__', None) is not None:\n", + " File \"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 745, in getmodule\n", + " for modname, module in sys.modules.copy().items():\n", + "KeyboardInterrupt\n" + ] + }, + { + "ename": "TypeError", + "evalue": "object of type 'NoneType' has no len()", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_71014/762088562.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxgb_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'xgb'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m 191\u001b[0m rescale_features=rescale_features_factor, max_samples=max_samples)\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m 222\u001b[0m max_samples=max_samples)\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 324\u001b[0;31m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 325\u001b[0m \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mxgb_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 628\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 629\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m 879\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0my\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 880\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_estimator_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 881\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/sklearn.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight, base_margin, eval_set, eval_metric, early_stopping_rounds, verbose, xgb_model, sample_weight_eval_set, base_margin_eval_set, feature_weights, callbacks)\u001b[0m\n\u001b[1;32m 1175\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1176\u001b[0;31m self._Booster = train(\n\u001b[0m\u001b[1;32m 1177\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(params, dtrain, num_boost_round, evals, obj, feval, maximize, early_stopping_rounds, evals_result, verbose_eval, xgb_model, callbacks)\u001b[0m\n\u001b[1;32m 188\u001b[0m \"\"\"\n\u001b[0;32m--> 189\u001b[0;31m bst = _train_internal(params, dtrain,\n\u001b[0m\u001b[1;32m 190\u001b[0m \u001b[0mnum_boost_round\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_boost_round\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\u001b[0m in \u001b[0;36m_train_internal\u001b[0;34m(params, dtrain, num_boost_round, evals, obj, feval, xgb_model, callbacks, evals_result, maximize, verbose_eval, early_stopping_rounds)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 81\u001b[0;31m \u001b[0mbst\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 82\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mafter_iteration\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbst\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, dtrain, iteration, fobj)\u001b[0m\n\u001b[1;32m 1495\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfobj\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1496\u001b[0;31m _check_call(_LIB.XGBoosterUpdateOneIter(self.handle,\n\u001b[0m\u001b[1;32m 1497\u001b[0m \u001b[0mctypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miteration\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2060\u001b[0m \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2061\u001b[0;31m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2062\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'KeyboardInterrupt' object has no attribute '_render_traceback_'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m 2061\u001b[0m \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2062\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2063\u001b[0;31m stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m 2064\u001b[0m value, tb, tb_offset=tb_offset)\n\u001b[1;32m 2065\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1365\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1366\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1367\u001b[0;31m return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1368\u001b[0m self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m 1369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1266\u001b[0m \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1267\u001b[0;31m return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m 1268\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1269\u001b[0m )\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m 1122\u001b[0m \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1124\u001b[0;31m formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m 1125\u001b[0m tb_offset)\n\u001b[1;32m 1126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m 1080\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1081\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1082\u001b[0;31m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1083\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1084\u001b[0m \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 383\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()" + ] + } + ], + "source": [ + "result = evaluate(ds, ridge_metric, 'ridge'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hidden": true + }, + "outputs": [], + "source": [ + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Catboost" + ] + }, + { + "cell_type": "code", + "execution_count": 1338, + "metadata": { + "hidden": true, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating kr-vs-kp\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 3.91it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6428571428571429\n", + "\t Eval position 30 done..\n", + "Evaluating credit-g\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 5.58it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5571428571428572\n", + "\t Eval position 30 done..\n", + "Evaluating vehicle\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████████| 1/1 [00:00<00:00, 6.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5285714285714286\n", + "\t Eval position 30 done..\n", + "Evaluating wine\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.83it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8714285714285714\n", + "\t Eval position 30 done..\n", + "Evaluating kc1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.38it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6571428571428571\n", + "\t Eval position 30 done..\n", + "Evaluating airlines\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.96it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7714285714285715\n", + "\t Eval position 30 done..\n", + "Evaluating bank-marketing\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6571428571428571\n", + "\t Eval position 30 done..\n", + "Evaluating blood-transfusion-service-center\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 7.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5285714285714286\n", + "\t Eval position 30 done..\n", + "Evaluating phoneme\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7428571428571429\n", + "\t Eval position 30 done..\n", + "Evaluating covertype\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 2.03it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6\n", + "\t Eval position 30 done..\n", + "Evaluating numerai28.6\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.47it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5142857142857142\n", + "\t Eval position 30 done..\n", + "Evaluating connect-4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 4.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9428571428571428\n", + "\t Eval position 30 done..\n", + "Evaluating car\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7\n", + "\t Eval position 30 done..\n", + "Evaluating Australian\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 5.96it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8857142857142857\n", + "\t Eval position 30 done..\n", + "Evaluating segment\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 5.98it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n", + "\t Eval position 30 done..\n", + "Evaluating jungle_chess_2pcs_raw_endgame_complete\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.91it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7142857142857143\n", + "\t Eval position 30 done..\n", + "Evaluating sylvine\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 6.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8142857142857143\n", + "\t Eval position 30 done..\n", + "Evaluating MiniBooNE\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 5.49it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8142857142857143\n", + "\t Eval position 30 done..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating dionis\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 5.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9142857142857143\n", + "\t Eval position 30 done..\n", + "Evaluating jannis\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:00<00:00, 4.91it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5571428571428572\n", + "\t Eval position 30 done..\n", + "Evaluating helena\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1 [00:00 3629\u001b[0;31m cv_result = self._object._tune_hyperparams(\n\u001b[0m\u001b[1;32m 3630\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train_pool\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m_catboost.pyx\u001b[0m in \u001b[0;36m_catboost._CatBoost._tune_hyperparams\u001b[0;34m()\u001b[0m\n", + "\u001b[0;32m_catboost.pyx\u001b[0m in \u001b[0;36m_catboost._CatBoost._tune_hyperparams\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_71014/2224404849.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalid_datasets\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mselector\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'valid'\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtest_datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcatboost_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'catboost'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m 191\u001b[0m rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m 192\u001b[0m \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m 222\u001b[0m max_samples=max_samples)\n\u001b[1;32m 223\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 324\u001b[0;31m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 325\u001b[0m \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mcatboost_acc\u001b[0;34m(x, y, test_x, test_y, categorical_feats)\u001b[0m\n\u001b[1;32m 593\u001b[0m logging_level='Silent')\n\u001b[1;32m 594\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 595\u001b[0;31m grid_search_result = model.grid_search(param_grid['catboost'],\n\u001b[0m\u001b[1;32m 596\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/catboost/core.py\u001b[0m in \u001b[0;36mgrid_search\u001b[0;34m(self, param_grid, X, y, cv, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, log_cout, log_cerr)\u001b[0m\n\u001b[1;32m 3730\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Parameter grid value is not iterable (key={!r}, value={!r})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3731\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3732\u001b[0;31m return self._tune_hyperparams(\n\u001b[0m\u001b[1;32m 3733\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3734\u001b[0m \u001b[0mpartition_random_seed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpartition_random_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcalc_cv_statistics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcalc_cv_statistics\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/catboost/core.py\u001b[0m in \u001b[0;36m_tune_hyperparams\u001b[0;34m(self, param_grid, X, y, cv, n_iter, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, log_cout, log_cerr)\u001b[0m\n\u001b[1;32m 3627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3628\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mlog_fixup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_cout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_cerr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0m_get_train_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3629\u001b[0;31m cv_result = self._object._tune_hyperparams(\n\u001b[0m\u001b[1;32m 3630\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train_pool\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3631\u001b[0m \u001b[0mfold_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpartition_random_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstratified\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "result = evaluate(ds, catboost_metric, 'catboost'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hidden": true + }, + "outputs": [], + "source": [ + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "## Running bayesian inference on one dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['wine',\n", + " tensor([[1.1640e+01, 2.0600e+00, 2.4600e+00, ..., 1.0000e+00, 2.7500e+00,\n", + " 6.8000e+02],\n", + " [1.3860e+01, 1.3500e+00, 2.2700e+00, ..., 1.0100e+00, 3.5500e+00,\n", + " 1.0450e+03],\n", + " [1.2340e+01, 2.4500e+00, 2.4600e+00, ..., 8.0000e-01, 3.3800e+00,\n", + " 4.3800e+02],\n", + " ...,\n", + " [1.3450e+01, 3.7000e+00, 2.6000e+00, ..., 8.5000e-01, 1.5600e+00,\n", + " 6.9500e+02],\n", + " [1.1560e+01, 2.0500e+00, 3.2300e+00, ..., 9.3000e-01, 3.6900e+00,\n", + " 4.6500e+02],\n", + " [1.2820e+01, 3.3700e+00, 2.3000e+00, ..., 7.2000e-01, 1.7500e+00,\n", + " 6.8500e+02]]),\n", + " tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n", + " 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.]),\n", + " []]" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# NOTE: Dataset must be ordered 1,0,1,0,...\n", + "one_ds = ds[0]\n", + "one_ds" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "hidden": true, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.07891185]\n", + " [0.9558054 ]\n", + " [0.05336431]\n", + " [0.98093307]\n", + " [0.07384931]\n", + " [0.8429412 ]\n", + " [0.03347415]\n", + " [0.973169 ]\n", + " [0.04434767]\n", + " [0.9206543 ]\n", + " [0.03534983]\n", + " [0.7528224 ]\n", + " [0.23227222]\n", + " [0.96119374]\n", + " [0.20475379]\n", + " [0.97402555]\n", + " [0.04753578]\n", + " [0.8418446 ]\n", + " [0.20400934]\n", + " [0.9496556 ]]\n" + ] + } + ], + "source": [ + "result_list = []\n", + "eval_position = eval_positions[0]\n", + "rescale_features = max_features / one_ds[1].shape[1] if config['prior_normalize_by_used_features'] else 1\n", + "\n", + "model = model.to(device)\n", + "model = model.eval()\n", + "\n", + "# Data to run inference on\n", + "eval_xs = one_ds[1][0:50]\n", + "# Extending dataset to standardized feature length\n", + "eval_xs = torch.cat([eval_xs, torch.zeros((eval_xs.shape[0], max_features - eval_xs.shape[1]))], -1).to(device)\n", + "eval_ys = one_ds[2][0:50].to(device)\n", + "\n", + "for i, pos in enumerate(range(eval_position, eval_xs.shape[0])):\n", + " eval_x = torch.cat([eval_xs[:eval_position], eval_xs[pos].unsqueeze(0)])\n", + " eval_y = eval_ys[:eval_position]\n", + "\n", + " # Center data using training positions so that it matches priors\n", + " mean = eval_x.mean(0)\n", + " std = eval_x.std(0) + .000001\n", + " eval_x = (eval_x - mean) / std\n", + " eval_x = eval_x / rescale_features\n", + "\n", + " result_list += [torch.sigmoid(model((eval_x.unsqueeze(1), eval_y.unsqueeze(1).float()), single_eval_pos=eval_position)).squeeze(-1).detach().cpu().numpy()[0]]\n", + "\n", + "print(np.array(result_list))" + ] + }, + { + "cell_type": "code", + "execution_count": 1229, + "metadata": { + "hidden": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating kr-vs-kp\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/5 [00:00\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalid_datasets\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mselector\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'valid'\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtest_datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbayes_net_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'bayes_net'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/hollmann/prior-fitting/results_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'_bayes_net.npy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m 194\u001b[0m rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m 195\u001b[0m \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m 225\u001b[0m max_samples=max_samples)\n\u001b[1;32m 226\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m 310\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 328\u001b[0m \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbayes_net_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGridSearchCV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'bayes'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 516\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 517\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# extra_args > 0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 843\u001b[0m \u001b[0;31m# multimetric is determined here because in the case of a callable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36m_run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m 1294\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1295\u001b[0m \u001b[0;34m\"\"\"Search all candidates in param_grid\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1296\u001b[0;31m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mParameterGrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1297\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mevaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m 793\u001b[0m n_splits, n_candidates, n_candidates * n_splits))\n\u001b[1;32m 794\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 795\u001b[0;31m out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n\u001b[0m\u001b[1;32m 796\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 797\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterating\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_iterator\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1045\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1046\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m 857\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 859\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 860\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 861\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 775\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mjob_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 778\u001b[0m \u001b[0;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 779\u001b[0m \u001b[0;31m# called before we get here, causing self._jobs to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 209\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;31m# Don't delay the application, to avoid keeping the input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m 263\u001b[0m for func, args, kwargs in self.items]\n\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m 263\u001b[0m for func, args, kwargs in self.items]\n\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m 596\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 598\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 599\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 480\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 481\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 482\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 483\u001b[0m \u001b[0;31m# Return the classifier\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/svi.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;31m# get loss and compute gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpoutine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mparam_capture\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_and_grads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m params = set(\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/trace_elbo.py\u001b[0m in \u001b[0;36mloss_and_grads\u001b[0;34m(self, model, guide, *args, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;31m# grab a trace from the generator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mmodel_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide_trace\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_traces\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(\n\u001b[1;32m 142\u001b[0m \u001b[0mmodel_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide_trace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/elbo.py\u001b[0m in \u001b[0;36m_get_traces\u001b[0;34m(self, model, guide, args, kwargs)\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_particles\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/trace_elbo.py\u001b[0m in \u001b[0;36m_get_trace\u001b[0;34m(self, model, guide, args, kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0magainst\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \"\"\"\n\u001b[0;32m---> 57\u001b[0;31m model_trace, guide_trace = get_importance_trace(\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\"flat\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_plate_nesting\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m )\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/enum.py\u001b[0m in \u001b[0;36mget_importance_trace\u001b[0;34m(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0mmodel_trace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprune_subsample_sites\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0mmodel_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_log_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m \u001b[0mguide_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_score_parts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_validation_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/trace_struct.py\u001b[0m in \u001b[0;36mcompute_log_prob\u001b[0;34m(self, site_filter)\u001b[0m\n\u001b[1;32m 248\u001b[0m \u001b[0;34m\"log_prob_sum at site '{}'\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 249\u001b[0m )\n\u001b[0;32m--> 250\u001b[0;31m warn_if_inf(\n\u001b[0m\u001b[1;32m 251\u001b[0m \u001b[0msite\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"log_prob_sum\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0;34m\"log_prob_sum at site '{}'\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/util.py\u001b[0m in \u001b[0;36mwarn_if_inf\u001b[0;34m(value, msg, allow_posinf, allow_neginf, filename, lineno)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumbers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m ):\n\u001b[1;32m 134\u001b[0m warnings.warn_explicit(\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "result = evaluate(ds, bayes_net_acc, 'bayes_net'\n", + " , bptt=seq_len\n", + " , eval_position_range=eval_positions\n", + " , device=device\n", + " , max_samples=20\n", + " , overwrite=True\n", + " , save=False)\n", + "result" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:prior-fitting]", + "language": "python", + "name": "conda-env-prior-fitting-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}