Samuel Mueller
committed on
Commit
·
f50f696
1
Parent(s):
ab8ac48
working locally
Browse files- .gitignore +3 -0
- SettingUpTheWebiste.ipynb +209 -0
- app.py +102 -3
- prior-fitting/.gitignore +129 -0
- prior-fitting/README.md +1 -0
- prior-fitting/acquisition_functions.py +18 -0
- prior-fitting/bar_distribution.py +147 -0
- prior-fitting/decoders.py +30 -0
- prior-fitting/encoders.py +95 -0
- prior-fitting/losses.py +12 -0
- prior-fitting/mcmc_svi_transformer_on_bayesian.py +443 -0
- prior-fitting/notebooks/BayesianModels_And_Custom_Pyro_Modules.ipynb +524 -0
- prior-fitting/notebooks/FewShotOmniglot.ipynb +168 -0
- prior-fitting/notebooks/SetupForGPFittingExperiments.ipynb +270 -0
- prior-fitting/notebooks/TabularEvalSimple.ipynb +0 -0
- prior-fitting/notebooks/Untitled.ipynb +180 -0
- prior-fitting/positional_encodings.py +70 -0
- prior-fitting/presentation/heatmap_bardistribution.py +97 -0
- prior-fitting/priors/__init__.py +4 -0
- prior-fitting/priors/binarized_regression.py +21 -0
- prior-fitting/priors/fast_gp.py +130 -0
- prior-fitting/priors/fast_gp_mix.py +307 -0
- prior-fitting/priors/gp.py +70 -0
- prior-fitting/priors/mlp.py +208 -0
- prior-fitting/priors/omniglot.py +98 -0
- prior-fitting/priors/prior.py +12 -0
- prior-fitting/priors/pyro.py +39 -0
- prior-fitting/priors/ridge.py +38 -0
- prior-fitting/priors/stroke.py +143 -0
- prior-fitting/priors/utils.py +102 -0
- prior-fitting/requirements.txt +13 -0
- prior-fitting/tabular.py +725 -0
- prior-fitting/train.py +288 -0
- prior-fitting/transformer.py +91 -0
- prior-fitting/utils.py +115 -0
- requirements.txt +13 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.ipynb_checkpoints
|
2 |
+
flagged
|
3 |
+
.idea
|
SettingUpTheWebiste.ipynb
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "963a04b2",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": []
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": null,
|
14 |
+
"id": "8ebc97aa",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": []
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"cell_type": "code",
|
21 |
+
"execution_count": 1,
|
22 |
+
"id": "b73f00ce",
|
23 |
+
"metadata": {},
|
24 |
+
"outputs": [
|
25 |
+
{
|
26 |
+
"name": "stdout",
|
27 |
+
"output_type": "stream",
|
28 |
+
"text": [
|
29 |
+
"Running locally at: http://127.0.0.1:7860/\n",
|
30 |
+
"To create a public link, set `share=True` in `launch()`.\n"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"data": {
|
35 |
+
"text/html": [
|
36 |
+
"\n",
|
37 |
+
" <iframe\n",
|
38 |
+
" width=\"900\"\n",
|
39 |
+
" height=\"500\"\n",
|
40 |
+
" src=\"http://127.0.0.1:7860/\"\n",
|
41 |
+
" frameborder=\"0\"\n",
|
42 |
+
" allowfullscreen\n",
|
43 |
+
" ></iframe>\n",
|
44 |
+
" "
|
45 |
+
],
|
46 |
+
"text/plain": [
|
47 |
+
"<IPython.lib.display.IFrame at 0x7f8f67cba520>"
|
48 |
+
]
|
49 |
+
},
|
50 |
+
"metadata": {},
|
51 |
+
"output_type": "display_data"
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"data": {
|
55 |
+
"text/plain": [
|
56 |
+
"(<Flask 'gradio.networking'>, 'http://127.0.0.1:7860/', None)"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
"execution_count": 1,
|
60 |
+
"metadata": {},
|
61 |
+
"output_type": "execute_result"
|
62 |
+
}
|
63 |
+
],
|
64 |
+
"source": [
|
65 |
+
"import gradio as gr\n",
|
66 |
+
"import numpy as np\n",
|
67 |
+
"import matplotlib.pyplot as plt\n",
|
68 |
+
"import gpytorch\n",
|
69 |
+
"import torch\n",
|
70 |
+
"import sys\n",
|
71 |
+
"\n",
|
72 |
+
"import gpytorch\n",
|
73 |
+
"\n",
|
74 |
+
"# We will use the simplest form of GP model, exact inference\n",
|
75 |
+
"class ExactGPModel(gpytorch.models.ExactGP):\n",
|
76 |
+
" def __init__(self, train_x, train_y, likelihood):\n",
|
77 |
+
" super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
|
78 |
+
" self.mean_module = gpytorch.means.ConstantMean()\n",
|
79 |
+
" self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
|
80 |
+
"\n",
|
81 |
+
" def forward(self, x):\n",
|
82 |
+
" mean_x = self.mean_module(x)\n",
|
83 |
+
" covar_x = self.covar_module(x)\n",
|
84 |
+
" return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
|
85 |
+
"\n",
|
86 |
+
"def get_model(x, y, hyperparameters):\n",
|
87 |
+
" likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))\n",
|
88 |
+
" model = ExactGPModel(x, y, likelihood)\n",
|
89 |
+
" model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters[\"noise\"]\n",
|
90 |
+
" model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters[\"outputscale\"]\n",
|
91 |
+
" model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \\\n",
|
92 |
+
" hyperparameters[\"lengthscale\"]\n",
|
93 |
+
" return model, likelihood\n",
|
94 |
+
"\n",
|
95 |
+
"\n",
|
96 |
+
"\n",
|
97 |
+
"excuse = \"Please only specify numbers, x values should be in [0,1] and y values in [-1,1].\"\n",
|
98 |
+
"excuse_max_examples = \"This model is trained to work with up to 4 input points.\"\n",
|
99 |
+
"hyperparameters = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .1, 'fast_computations': (False,False,False)}\n",
|
100 |
+
"\n",
|
101 |
+
"\n",
|
102 |
+
"conf = .5\n",
|
103 |
+
"\n",
|
104 |
+
"def mean_and_bounds_for_gp(x,y,test_xs):\n",
|
105 |
+
" gp_model, likelihood = get_model(x,y,hyperparameters)\n",
|
106 |
+
" gp_model.eval()\n",
|
107 |
+
" l = likelihood(gp_model(test_xs))\n",
|
108 |
+
" means = l.mean.squeeze()\n",
|
109 |
+
" varis = torch.diagonal(l.covariance_matrix.squeeze())\n",
|
110 |
+
" stds = varis.sqrt()\n",
|
111 |
+
" return means, means-stds, means+stds\n",
|
112 |
+
"\n",
|
113 |
+
"\n",
|
114 |
+
"def mean_and_bounds_for_pnf(x,y,test_xs, choice):\n",
|
115 |
+
" sys.path.append('prior-fitting/')\n",
|
116 |
+
" model = torch.load(f'onefeature_gp_ls.1_pnf_{choice}.pt')\n",
|
117 |
+
"\n",
|
118 |
+
"\n",
|
119 |
+
" logits = model((torch.cat([x,test_xs],0).unsqueeze(1),y.unsqueeze(1)),single_eval_pos=len(x))\n",
|
120 |
+
" bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)\n",
|
121 |
+
" return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]\n",
|
122 |
+
"\n",
|
123 |
+
"def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):\n",
|
124 |
+
" ax_or_plt.plot(x.squeeze(-1),m, color=color)\n",
|
125 |
+
" ax_or_plt.fill_between(x.squeeze(-1), lb, ub, alpha=.1, color=color)\n",
|
126 |
+
"\n",
|
127 |
+
"\n",
|
128 |
+
"\n",
|
129 |
+
"\n",
|
130 |
+
"@torch.no_grad()\n",
|
131 |
+
"def infer(table, choice):\n",
|
132 |
+
" vfunc = np.vectorize(lambda s: len(s))\n",
|
133 |
+
" non_empty_row_mask = (vfunc(table).sum(1) != 0)\n",
|
134 |
+
" table = table[non_empty_row_mask]\n",
|
135 |
+
"\n",
|
136 |
+
" try:\n",
|
137 |
+
" table = table.astype(np.float32)\n",
|
138 |
+
" except ValueError:\n",
|
139 |
+
" return excuse, None\n",
|
140 |
+
" x = torch.tensor(table[:,0]).unsqueeze(1)\n",
|
141 |
+
" y = torch.tensor(table[:,1])\n",
|
142 |
+
" fig = plt.figure()\n",
|
143 |
+
"\n",
|
144 |
+
" if len(x) > 4:\n",
|
145 |
+
" return excuse_max_examples, None\n",
|
146 |
+
" if (x<0.).any() or (x>1.).any() or (y<-1).any() or (y>1).any():\n",
|
147 |
+
" return excuse, None\n",
|
148 |
+
"\n",
|
149 |
+
" plt.scatter(x,y)\n",
|
150 |
+
"\n",
|
151 |
+
"\n",
|
152 |
+
" \n",
|
153 |
+
" test_xs = torch.linspace(0,1,100).unsqueeze(1)\n",
|
154 |
+
" \n",
|
155 |
+
" plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green')\n",
|
156 |
+
" plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue')\n",
|
157 |
+
"\n",
|
158 |
+
"\n",
|
159 |
+
" \n",
|
160 |
+
" return '', plt.gcf()\n",
|
161 |
+
"\n",
|
162 |
+
"iface = gr.Interface(fn=infer, \n",
|
163 |
+
" inputs=[\n",
|
164 |
+
" gr.inputs.Dataframe(headers=[\"x\", \"y\"], datatype=[\"number\", \"number\"], row_count=2, type='numpy', default=[['.25','.1'],['.75','.4']]),\n",
|
165 |
+
" gr.inputs.Radio(['160K','800K','4M'], type=\"value\", default='4M', label='Training Costs')\n",
|
166 |
+
" ], outputs=[\"text\",\"plot\"])\n",
|
167 |
+
"iface.launch()\n",
|
168 |
+
"\n"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"id": "a3a377e3",
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": []
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": null,
|
182 |
+
"id": "72c0c821",
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [],
|
185 |
+
"source": []
|
186 |
+
}
|
187 |
+
],
|
188 |
+
"metadata": {
|
189 |
+
"kernelspec": {
|
190 |
+
"display_name": "Python 3 (ipykernel)",
|
191 |
+
"language": "python",
|
192 |
+
"name": "python3"
|
193 |
+
},
|
194 |
+
"language_info": {
|
195 |
+
"codemirror_mode": {
|
196 |
+
"name": "ipython",
|
197 |
+
"version": 3
|
198 |
+
},
|
199 |
+
"file_extension": ".py",
|
200 |
+
"mimetype": "text/x-python",
|
201 |
+
"name": "python",
|
202 |
+
"nbconvert_exporter": "python",
|
203 |
+
"pygments_lexer": "ipython3",
|
204 |
+
"version": "3.9.5"
|
205 |
+
}
|
206 |
+
},
|
207 |
+
"nbformat": 4,
|
208 |
+
"nbformat_minor": 5
|
209 |
+
}
|
app.py
CHANGED
@@ -1,7 +1,106 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
iface.launch()
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import gpytorch
|
5 |
+
import torch
|
6 |
+
import sys
|
7 |
|
8 |
+
import gpytorch
|
|
|
9 |
|
10 |
+
# We will use the simplest form of GP model, exact inference
|
11 |
+
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
|
21 |
+
|
22 |
+
def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` on ``(x, y)`` with hand-set hyperparameters.

    ``hyperparameters`` is indexed with the keys "noise", "outputscale" and
    "lengthscale". Returns the tuple ``(model, likelihood)``.
    """
    noise_floor = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
    model = ExactGPModel(x, y, likelihood)

    # Every value is multiplied into a ones-tensor shaped like the parameter
    # it replaces.
    model.likelihood.noise = (
        torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
    )
    scaled_kernel = model.covar_module
    scaled_kernel.outputscale = (
        torch.ones_like(scaled_kernel.outputscale) * hyperparameters["outputscale"]
    )
    rbf = scaled_kernel.base_kernel
    rbf.lengthscale = torch.ones_like(rbf.lengthscale) * hyperparameters["lengthscale"]
    return model, likelihood
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
# Messages returned as the text output when a submitted table is rejected.
excuse = "Please only specify numbers, x values should be in [0,1] and y values in [-1,1]."
excuse_max_examples = "This model is trained to work with up to 4 input points."

# Fixed hyperparameters handed to get_model() for the GP baseline.
hyperparameters = dict(
    noise=1e-4,
    outputscale=1.,
    lengthscale=.1,
    fast_computations=(False, False, False),
)

# NOTE(review): not referenced anywhere else in the visible app.py.
conf = .5
|
39 |
+
|
40 |
+
def mean_and_bounds_for_gp(x,y,test_xs):
    """Evaluate the GP baseline at ``test_xs``.

    Returns ``(mean, mean - std, mean + std)``, where ``std`` is the square
    root of the diagonal of the predictive covariance matrix.
    """
    gp_model, likelihood = get_model(x, y, hyperparameters)
    gp_model.eval()
    predictive = likelihood(gp_model(test_xs))
    center = predictive.mean.squeeze()
    spread = torch.diagonal(predictive.covariance_matrix.squeeze()).sqrt()
    return center, center - spread, center + spread
|
48 |
+
|
49 |
+
|
50 |
+
def mean_and_bounds_for_pnf(x,y,test_xs, choice):
    """Evaluate the saved model selected by ``choice`` at ``test_xs``.

    Loads ``onefeature_gp_ls.1_pnf_{choice}.pt``, runs it on the training
    points followed by the test points, and returns ``(mean, lower, upper)``
    where the bounds come from ``model.criterion.quantile`` with
    ``center_prob=.682``.
    """
    # Presumably needed so the classes referenced by the pickled checkpoint
    # can be imported. Only add the entry once: this runs on every request and
    # would otherwise grow sys.path without bound.
    pnf_dir = 'prior-fitting/'
    if pnf_dir not in sys.path:
        sys.path.append(pnf_dir)

    # NOTE(review): torch.load unpickles the file, and `choice` is interpolated
    # into the path unchecked. Confirm callers can only pass trusted values.
    model = torch.load(f'onefeature_gp_ls.1_pnf_{choice}.pt')

    inputs = torch.cat([x,test_xs],0).unsqueeze(1)
    logits = model((inputs, y.unsqueeze(1)), single_eval_pos=len(x))
    bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)
    return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]
|
58 |
+
|
59 |
+
def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):
    """Draw the curve ``m`` over ``x`` and shade the band between ``lb`` and ``ub``.

    ``ax_or_plt`` only needs to provide ``plot`` and ``fill_between``.
    """
    xs = x.squeeze(-1)
    ax_or_plt.plot(xs, m, color=color)
    ax_or_plt.fill_between(xs, lb, ub, alpha=.1, color=color)
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
@torch.no_grad()
def infer(table, choice):
    """Validate the submitted (x, y) table and plot both models' predictions.

    Rows whose cells are all empty strings are dropped. Returns
    ``(message, figure)``: on rejection ``message`` is one of the excuse
    strings and ``figure`` is ``None``; on success ``message`` is ``''``.
    """
    cell_lengths = np.vectorize(len)(table)
    table = table[cell_lengths.sum(1) != 0]

    try:
        table = table.astype(np.float32)
    except ValueError:
        return excuse, None
    x = torch.tensor(table[:,0]).unsqueeze(1)
    y = torch.tensor(table[:,1])

    if len(x) > 4:
        return excuse_max_examples, None
    # Written as "all values inside the range" so that NaN, which fails every
    # comparison, is rejected as well.
    x_in_range = ((x >= 0.) & (x <= 1.)).all()
    y_in_range = ((y >= -1) & (y <= 1)).all()
    if not (x_in_range and y_in_range):
        return excuse, None

    # Create the figure only after validation so that rejected requests do not
    # leave an unused figure registered with pyplot.
    fig = plt.figure()
    plt.scatter(x,y)

    test_xs = torch.linspace(0,1,100).unsqueeze(1)

    plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green')
    plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue')

    return '', fig
|
97 |
+
|
98 |
+
# Gradio UI: an editable (x, y) table plus a model selector feed infer(),
# whose two return values fill the text and plot outputs.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.inputs.Dataframe(
            headers=["x", "y"],
            datatype=["number", "number"],
            row_count=2,
            type='numpy',
            default=[['.25','.1'],['.75','.4']],
        ),
        gr.inputs.Radio(
            ['160K','800K','4M'],
            type="value",
            default='4M',
            label='Training Costs',
        ),
    ],
    outputs=["text","plot"],
)
iface.launch()
|
104 |
+
|
105 |
+
|
106 |
+
|
prior-fitting/.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
prior-fitting/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# prior-fitting
|
prior-fitting/acquisition_functions.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from botorch.acquisition import AcquisitionFunction
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
|
5 |
+
class ExpectedImprovement(AcquisitionFunction):
    """Acquisition function built on the submodule registered as ``model``.

    NOTE(review): ``forward`` has no return statement, so it yields ``None``
    despite the ``-> Tensor`` annotation, and ``best_f`` / ``maximize`` are
    never read. This looks unfinished.
    """

    def forward(self, X: Tensor, best_f, maximize=True) -> Tensor:  # X: evaluation_points x feature_dim
        assert X.dim() == 2

        surrogate = self.get_submodule('model')

        # Both values are computed but not used further.
        y = surrogate(X)
        full_range = surrogate.full_range
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
prior-fitting/bar_distribution.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class BarDistribution(nn.Module):
    """Piecewise-constant density over the buckets delimited by ``borders``.

    The logits passed to each method are softmaxed over the last dimension to
    give one probability per bucket.
    """

    def __init__(self, borders: torch.Tensor): # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
        super().__init__()
        assert len(borders.shape) == 1
        # Buffers so that they are stored and moved together with the module.
        self.register_buffer('borders', borders)
        self.register_buffer('bucket_widths', self.borders[1:] - self.borders[:-1])
        full_width = self.bucket_widths.sum()
        assert (full_width - (self.borders[-1] - self.borders[0])).abs() < 1e-4, f'diff: {full_width - (self.borders[-1] - self.borders[0])}'
        assert (torch.argsort(borders) == torch.arange(len(borders))).all(), "Please provide sorted borders!"
        self.num_bars = len(borders) - 1

    def map_to_bucket_idx(self, y):
        """Return, for every entry of ``y``, the index of the bucket it falls into."""
        target_sample = torch.searchsorted(self.borders, y) - 1
        # Values sitting exactly on the outer borders belong to the outer buckets.
        target_sample[y == self.borders[0]] = 0
        target_sample[y == self.borders[-1]] = self.num_bars - 1
        return target_sample

    def forward(self, logits, y): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
        target_sample = self.map_to_bucket_idx(y)
        assert (target_sample >= 0).all() and (target_sample < self.num_bars).all(), f'y {y} not in support set for borders (min_y, max_y) {self.borders}'
        assert logits.shape[-1] == self.num_bars, f'{logits.shape[-1]} vs {self.num_bars}'

        bucket_log_probs = torch.log_softmax(logits, -1)
        # Dividing the bucket probability by the bucket width gives a density.
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)

        return -scaled_bucket_log_probs.gather(-1,target_sample.unsqueeze(-1)).squeeze(-1)

    def mean(self, logits):
        """Return the probability-weighted average of the bucket midpoints."""
        bucket_means = self.borders[:-1] + self.bucket_widths/2
        p = torch.softmax(logits, -1)
        return p @ bucket_means

    def quantile(self, logits, center_prob=.682):
        """Return the bounds of the central interval holding ``center_prob`` mass.

        The result has the shape of ``logits`` with the last dimension replaced
        by 2: ``[..., 0]`` is the lower bound and ``[..., 1]`` the upper bound.
        """
        logits_shape = logits.shape
        logits = logits.view(-1, logits.shape[-1])
        side_prob = (1-center_prob)/2
        probs = logits.softmax(-1)
        flipped_probs = probs.flip(-1)
        cumprobs = torch.cumsum(probs, -1)
        flipped_cumprobs = torch.cumsum(flipped_probs, -1)

        def find_lower_quantile(probs, cumprobs, side_prob, borders):
            idx = int(torch.searchsorted(cumprobs, side_prob).clamp(0, len(cumprobs)-1)) # this might not do the right for outliers

            # No mass lies before the first bucket. Indexing cumprobs[-1] here
            # would wrap around to the total mass and corrupt the result.
            left_prob = cumprobs[idx-1] if idx > 0 else cumprobs.new_zeros(())
            rest_prob = side_prob - left_prob
            left_border, right_border = borders[idx:idx+2]
            # Interpolate linearly inside the bucket that contains the quantile.
            return left_border + (right_border-left_border)*rest_prob/probs[idx]

        results = []
        for p,cp,f_p,f_cp in zip(probs, cumprobs, flipped_probs, flipped_cumprobs):
            # The upper bound is found by running the same search on the flipped
            # distribution and flipped borders.
            r = find_lower_quantile(p, cp, side_prob, self.borders), find_lower_quantile(f_p, f_cp, side_prob, self.borders.flip(0))
            results.append(r)

        return torch.tensor(results).reshape(*logits_shape[:-1],2)

    def mode(self, logits):
        """Return the midpoint of the bucket with the largest logit."""
        mode_inds = logits.argmax(-1)
        bucket_means = self.borders[:-1] + self.bucket_widths/2
        return bucket_means[mode_inds]

    def ei(self, logits, best_f, maximize=True): # logits: evaluation_points x batch x feature_dim
        """Return the probability-weighted sum of per-bucket improvement terms over ``best_f``."""
        if maximize:
            bucket_contributions = torch.tensor(
                [max((bucket_max + max(bucket_min, best_f)) / 2 - best_f,0) for
                 bucket_min, bucket_max in zip(self.borders[:-1], self.borders[1:])], dtype=logits.dtype, device=logits.device)
        else:
            bucket_contributions = torch.tensor(
                [-min((min(bucket_max,best_f) + bucket_min) / 2 - best_f,0) for # min on max instead of max on min, and compare min < instead of max >
                 bucket_min, bucket_max in zip(self.borders[:-1], self.borders[1:])], dtype=logits.dtype, device=logits.device)
        p = torch.softmax(logits, -1)
        return p @ bucket_contributions
|
81 |
+
|
82 |
+
|
83 |
+
class FullSupportBarDistribution(BarDistribution):
    """Bar distribution whose first and last buckets use half-normal tails."""

    @staticmethod
    def halfnormal_with_p_weight_before(range_max,p=.5):
        """Return a HalfNormal whose ``p``-quantile equals ``range_max``."""
        unit_quantile = torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
        return torch.distributions.HalfNormal(range_max / unit_quantile)

    def forward(self, logits, y): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
        assert self.num_bars > 1
        bucket_idx = self.map_to_bucket_idx(y)
        # Targets outside the borders are assigned to the outermost buckets.
        bucket_idx.clamp_(0,self.num_bars-1)
        assert logits.shape[-1] == self.num_bars

        scaled_bucket_log_probs = torch.log_softmax(logits, -1) - torch.log(self.bucket_widths)
        log_probs = scaled_bucket_log_probs.gather(-1,bucket_idx.unsqueeze(-1)).squeeze(-1)

        left_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        right_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])

        # TODO look over it again
        # In the outer buckets the "+ log(width)" term cancels the uniform
        # scaling above, leaving log(bucket prob) + half-normal log density.
        in_first = bucket_idx == 0
        log_probs[in_first] += left_tail.log_prob((self.borders[1]-y[in_first]).clamp(min=.00000001)) + torch.log(self.bucket_widths[0])
        in_last = bucket_idx == self.num_bars-1
        log_probs[in_last] += right_tail.log_prob(y[in_last]-self.borders[-2]) + torch.log(self.bucket_widths[-1])

        return -log_probs

    def mean(self, logits):
        """Return the weighted bucket means, using half-normal means for the outer buckets."""
        centers = self.borders[:-1] + self.bucket_widths / 2
        weights = torch.softmax(logits, -1)
        left_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        right_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])
        # The tails are measured from the inner border of each outer bucket.
        centers[0] = -left_tail.mean + self.borders[1]
        centers[-1] = right_tail.mean + self.borders[-2]
        return weights @ centers
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
def get_bucket_limits(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=None):
    """Return ``num_outputs + 1`` bucket borders.

    With ``ys`` the inner borders are placed midway between neighbouring
    groups of sorted samples, so each bucket holds the same number of them;
    trailing samples that do not fill a whole group are dropped. Without
    ``ys`` the buckets split ``full_range`` into equal widths.

    Raises:
        ValueError: if ``ys`` has fewer than ``num_outputs`` elements.
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        if len(ys) < num_outputs:
            raise ValueError(f'Need at least {num_outputs} ys to estimate {num_outputs} buckets, got {len(ys)}.')
        # Compute the remainder before truncating, so the message below
        # reports the number actually dropped.
        num_cut = len(ys) % num_outputs
        if num_cut: ys = ys[:-num_cut]
        print(f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut} ys.')
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
        full_range = torch.tensor(full_range)
        ys_sorted = ys.sort(0).values
        # Each inner border is the midpoint between the last sample of one
        # group and the first sample of the next.
        bucket_limits = (ys_sorted[ys_per_bucket-1::ys_per_bucket][:-1]+ys_sorted[ys_per_bucket::ys_per_bucket])/2
        print(full_range)
        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)],0)

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat([full_range[0] + torch.arange(num_outputs).float()*class_width, torch.tensor(full_range[1]).unsqueeze(0)], 0)

    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
    return bucket_limits
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
prior-fitting/decoders.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import random
|
4 |
+
|
5 |
+
|
6 |
+
class ScaledDecoder(nn.Module):
    """Decoder whose outputs are divided by an input-dependent temperature.

    The temperature is a softmax-weighted mix of ten fixed values between 1
    and 160, predicted by ``linear2`` from the hidden representation.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        self.linear2 = nn.Linear(nhid, 10)

    def forward(self, x):
        x = self.linear(x)
        x = nn.functional.gelu(x)
        temps = self.linear2(x).softmax(-1) @ torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device)
        # The former random debug print indexed temps[:,:2], which raises for
        # 2-D inputs (temps is then 1-D), so it has been removed.
        return self.linear1(x) / temps.unsqueeze(-1)
|
21 |
+
|
22 |
+
class FixedScaledDecoder(nn.Module):
    """Two-layer GELU MLP whose output is divided by the sum of the learnable vector ``T``."""

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        layers = [nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)]
        self.mapper = nn.Sequential(*layers)
        # 10000 entries of 1/10000 each, so the divisor starts out at about 1.
        self.T = nn.Parameter(torch.ones(10000)/10000)

    def forward(self, x):
        temperature = self.T.sum()
        return self.mapper(x) / temperature
|
30 |
+
|
prior-fitting/encoders.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
7 |
+
|
8 |
+
class _PositionalEncoding(nn.Module):
|
9 |
+
def __init__(self, d_model, dropout=0.):
|
10 |
+
super().__init__()
|
11 |
+
self.dropout = nn.Dropout(p=dropout)
|
12 |
+
self.d_model = d_model
|
13 |
+
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
|
14 |
+
|
15 |
+
def forward(self, x):# T x B x num_features
|
16 |
+
assert self.d_model % x.shape[-1]*2 == 0
|
17 |
+
d_per_feature = self.d_model // x.shape[-1]
|
18 |
+
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
19 |
+
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
20 |
+
interval_size = 10
|
21 |
+
div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
|
22 |
+
#print(div_term/2/math.pi)
|
23 |
+
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
24 |
+
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
25 |
+
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
26 |
+
|
27 |
+
|
28 |
+
class EmbeddingEncoder(nn.Module):
    """Encode real-valued features by binning them and averaging learned embeddings.

    Every feature has its own block of ``num_embs`` embeddings; values are
    binned uniformly over ``min_max`` and clamped to the outermost bins.
    """

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        self.min_max = (-2,+2)

    @property
    def width(self):
        """Width of the value range covered by the bins."""
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        """Map values to integer bin indices in ``[0, num_embs - 1]``."""
        split_size = self.width / self.num_embs
        # Shift by the range minimum first, then divide. The original
        # `x - lo // split_size` divided only the minimum, which squeezed the
        # whole range into a handful of bins.
        # NOTE(review): this changes the indices seen by any checkpoint trained
        # with the old binning.
        return ((x - self.min_max[0]) / split_size).floor().int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        x_idxs = self.discretize(x)
        # Offset each feature into its own block of the embedding table.
        x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        return self.embeddings(x_idxs).mean(-2)
|
52 |
+
|
53 |
+
Linear = nn.Linear


def MLP(num_features, emsize):
    """Return a two-layer ReLU MLP mapping ``num_features + 1`` inputs to ``emsize`` outputs."""
    return nn.Sequential(nn.Linear(num_features+1,emsize*2),
                         nn.ReLU(),
                         nn.Linear(emsize*2,emsize))
|
57 |
+
|
58 |
+
class Conv(nn.Module):
    """Encode flattened square images with a small CNN followed by a linear layer.

    The last input dimension must be a perfect square; it is reshaped to a
    single-channel ``size x size`` image. Any number of leading dimensions is
    accepted and restored on the output.
    """

    def __init__(self, input_size, emsize):
        # NOTE: input_size is accepted but not used.
        super().__init__()
        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64,emsize)

    def forward(self, x):
        size = math.isqrt(x.shape[-1])
        assert size*size == x.shape[-1]
        # Conv2d takes one batch dimension at most, so fold all leading
        # dimensions into it and unfold them again at the end.
        leading_shape = x.shape[:-1]
        x = x.reshape(-1, 1, size, size)
        for conv in self.convs:
            # Each 3x3 convolution shrinks the image by 2; stop before it gets too small.
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)
        return self.linear(x).reshape(*leading_shape, self.linear.out_features)
|
76 |
+
|
77 |
+
|
78 |
+
# Encoder factory with the (num_features, emsize) call signature; the first
# argument is ignored and only emsize is forwarded as d_model.
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
79 |
+
|
80 |
+
|
81 |
+
class CanEmb(nn.Embedding):
    """Embedding whose per-feature vectors are concatenated along the last axis.

    The requested ``embedding_dim`` is split evenly across ``num_features``,
    so the concatenated output has exactly ``embedding_dim`` entries.
    """

    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        per_feature_dim, leftover = divmod(embedding_dim, num_features)
        assert leftover == 0
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        """Look up every index and merge the last two axes of the result."""
        embedded = super().forward(x)
        leading_dims = embedded.shape[:-2]
        return embedded.view(*leading_dims, -1)
|
90 |
+
|
91 |
+
def get_Canonical(num_classes):
    """Return an encoder factory that builds a ``CanEmb`` over ``num_classes`` ids."""
    def build_encoder(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)

    return build_encoder
|
93 |
+
|
94 |
+
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder factory that builds an ``EmbeddingEncoder``.

    ``num_embs_per_feature`` is forwarded as the encoder's ``num_embs``.
    """
    def build_encoder(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return build_encoder
|
prior-fitting/losses.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class ScaledSoftmaxCE(nn.Module):
    # NOTE(review): this loss looks unfinished. forward() builds intermediate
    # tensors but has no return statement, so calling the module yields None;
    # `label` and `temp_scales` are never used.
    def forward(self, x, label):
        # Everything except the last 10 entries of the final dimension.
        logits = x[..., :-10]
        # The last 10 entries of the final dimension (presumably per-class
        # temperature scales, going by the name -- TODO confirm).
        temp_scales = x[..., -10:]

        # NOTE(review): despite the name, these are softmax probabilities,
        # not log-probabilities.
        logprobs = logits.softmax(-1)
|
prior-fitting/mcmc_svi_transformer_on_bayesian.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import scipy.stats as st
|
2 |
+
from train import Losses
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
import time
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
|
14 |
+
import pyro
|
15 |
+
import pyro.distributions as dist
|
16 |
+
from pyro.nn import PyroModule, PyroSample
|
17 |
+
import torch.nn as nn
|
18 |
+
from pyro.infer.autoguide import AutoDiagonalNormal
|
19 |
+
from pyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS
|
20 |
+
from pyro import infer
|
21 |
+
import matplotlib.gridspec as gridspec
|
22 |
+
import os.path
|
23 |
+
import glob
|
24 |
+
from train import train, get_weighted_single_eval_pos_sampler
|
25 |
+
import priors
|
26 |
+
import encoders
|
27 |
+
from pyro.infer import SVGD, RBFSteinKernel
|
28 |
+
|
29 |
+
class CausalModel(PyroModule):
    """Two-layer linear Bayesian classifier with a random 0/1 mask on fc1's prior scale.

    ``model_spec`` must provide 'num_features' and 'embed'. Calling the module
    with ``x=None`` also samples the inputs, so it can be used as a data
    generator; otherwise it scores/samples 'obs' for the given ``x``.
    """

    def __init__(self, model_spec, device='cuda'):
        super().__init__()

        self.device = device
        self.num_features = model_spec['num_features']

        mu, sigma = torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)

        self.fc1 = PyroModule[nn.Linear](self.num_features, model_spec['embed'])
        # The mask is sampled once, at construction time, with one 0/1 entry per fc1 weight.
        # NOTE(review): the probs tensor is created on the default device, not moved to
        # self.device like mu -- confirm this works when device is a GPU.
        self.drop = pyro.sample('drop', dist.Categorical(probs=torch.tensor([0.5, 0.5]).expand([model_spec['embed'], self.num_features, 2]))).float()
        # Prior std is ~1 where the mask is 1 and 1e-7 where it is 0.
        self.fc1.weight = PyroSample(dist.Normal(mu, 0.0000001+self.drop).expand([model_spec['embed'], self.num_features]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(mu, sigma).expand([model_spec['embed']]).to_event(1))

        # Output layer has two units, one per class; standard normal priors.
        self.fc2 = PyroModule[nn.Linear](model_spec['embed'], 2)
        self.fc2.weight = PyroSample(dist.Normal(mu, sigma).expand([2, model_spec['embed']]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(mu, sigma).expand([2]).to_event(1))

        # No non-linearity between the two layers.
        self.model = torch.nn.Sequential(self.fc1, self.fc2)

        self.to(self.device)

    def forward(self, x=None, y=None, seq_len=1):
        """Return ``(x, obs)``; samples ``x`` when it is None and 'obs' when ``y`` is None."""
        if x is None:
            # Draw seq_len inputs, each with num_features standard-normal entries.
            with pyro.plate("x_plate", seq_len):
                d_ = dist.Normal(torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)).expand(
                    [self.num_features]).to_event(1)
                x = pyro.sample("x", d_)

        out = self.model(x)
        mu = out.squeeze()
        # Softmax over dim 1 -- assumes mu is 2-D (rows x 2 classes) after the squeeze.
        softmax = torch.nn.Softmax(dim=1)
        # sigma = pyro.sample("sigma", dist.Uniform(torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)))
        with pyro.plate("data", out.shape[0]):
            # d_ = dist.Normal(mu, sigma)
            # obs = pyro.sample("obs", d_, obs=y)
            s = softmax(mu)
            # Class label per row; conditioned on y when it is given.
            obs = pyro.sample('obs', dist.Categorical(probs=s), obs=y).float()

        return x, obs
|
69 |
+
|
70 |
+
class BayesianModel(PyroModule):
    """Two-layer linear Bayesian classifier with standard normal priors on all weights.

    ``model_spec`` must provide 'num_features' and 'embed'. Calling the module
    with ``x=None`` also samples the inputs, so it can be used as a data
    generator; otherwise it scores/samples 'obs' for the given ``x``.
    """

    def __init__(self, model_spec, device='cuda'):
        super().__init__()

        self.device = device
        self.num_features = model_spec['num_features']

        # Shared prior parameters: N(0, 1) for every weight and bias.
        mu, sigma = torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)

        self.fc1 = PyroModule[nn.Linear](self.num_features, model_spec['embed'])
        self.fc1.weight = PyroSample(
            dist.Normal(mu, sigma).expand([model_spec['embed'], self.num_features]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(mu, sigma).expand([model_spec['embed']]).to_event(1))

        # Output layer has two units, one per class.
        self.fc2 = PyroModule[nn.Linear](model_spec['embed'], 2)
        self.fc2.weight = PyroSample(dist.Normal(mu, sigma).expand([2, model_spec['embed']]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(mu, sigma).expand([2]).to_event(1))

        # No non-linearity between the two layers.
        self.model = torch.nn.Sequential(self.fc1, self.fc2)

        self.to(self.device)

    def forward(self, x=None, y=None, seq_len=1):
        """Return ``(x, obs)``; samples ``x`` when it is None and 'obs' when ``y`` is None."""
        if x is None:
            # Draw seq_len inputs, each with num_features standard-normal entries.
            with pyro.plate("x_plate", seq_len):
                d_ = dist.Normal(torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)).expand(
                    [self.num_features]).to_event(1)
                x = pyro.sample("x", d_)

        out = self.model(x)
        mu = out.squeeze()
        # Softmax over dim 1 -- assumes mu is 2-D (rows x 2 classes) after the squeeze.
        softmax = torch.nn.Softmax(dim=1)
        # sigma = pyro.sample("sigma", dist.Uniform(torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)))
        with pyro.plate("data", out.shape[0]):
            # d_ = dist.Normal(mu, sigma)
            # obs = pyro.sample("obs", d_, obs=y)
            s = softmax(mu)
            # Class label per row; conditioned on y when it is given.
            obs = pyro.sample('obs', dist.Categorical(probs=s), obs=y).float()

        return x, obs
|
110 |
+
|
111 |
+
|
112 |
+
def get_transformer_config(model_spec):
    """Return the transformer training hyper-parameters for a model spec.

    Args:
        model_spec: Mapping providing 'num_features' and 'seq_len'; both are
            copied into the config unchanged.

    Returns:
        A dict of hyper-parameters; everything except 'num_features' and
        'seq_len' is a fixed constant.
    """
    # 'dropout' used to be listed twice in this literal (both 0.0); the
    # duplicate key has been removed.
    return {
        'lr': 2.006434218345026e-05,
        'epochs': 400,
        'dropout': 0.0,
        'emsize': 256,
        'batch_size': 256,
        'nlayers': 5,
        'num_outputs': 1,
        'num_features': model_spec['num_features'],
        'steps_per_epoch': 100,
        'nhead': 4,
        'seq_len': model_spec['seq_len'],
        'nhid_factor': 2,
    }
|
126 |
+
|
127 |
+
|
128 |
+
def get_model(model_generator, config, should_train=True, device='cuda'):
    """Build a transformer on the Pyro prior via ``train`` and return its result.

    Args:
        model_generator: Passed to the prior as its 'model' entry.
        config: Hyper-parameter dict (see ``get_transformer_config``).
        should_train: When False, ``train`` is called with ``epochs=0``.
        device: Forwarded to ``train`` as ``gpu_device``.
    """
    n_epochs = config['epochs'] if should_train else 0

    return train(
        priors.pyro.DataLoader,
        Losses.bce,
        encoders.Linear,
        emsize=config['emsize'],
        nhead=config['nhead'],
        y_encoder_generator=encoders.Linear,
        pos_encoder_generator=None,
        batch_size=config['batch_size'],
        nlayers=config['nlayers'],
        nhid=config['emsize'] * config['nhid_factor'],
        epochs=n_epochs,
        # Warm-up is always a quarter of the configured epochs, even when not training.
        warmup_epochs=config['epochs'] // 4,
        bptt=config['seq_len'],
        gpu_device=device,
        dropout=config['dropout'],
        steps_per_epoch=config['steps_per_epoch'],
        single_eval_pos_gen=get_weighted_single_eval_pos_sampler(100),
        extra_prior_kwargs_dict={
            'num_outputs': config['num_outputs'],
            'num_features': config['num_features'],
            'canonical_args': None,
            'fuse_x_y': False,
            'model': model_generator,
        },
        lr=config['lr'],
        verbose=True,
    )
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
def plot_features(data, targets):
    """Draw a grid of pairwise feature scatter plots coloured by ``targets``.

    The grid has one row and one column per entry of ``data.shape[1]``; the
    cell at (row, col) plots column ``row`` against column ``col``.
    """
    n_dims = data.shape[1]
    figure = plt.figure(constrained_layout=True, figsize=(12, 12))
    grid = gridspec.GridSpec(ncols=n_dims, nrows=n_dims, figure=figure)
    for row in range(n_dims):
        for col in range(n_dims):
            axis = figure.add_subplot(grid[row, col])
            xs = data[:, row].detach().cpu().numpy()
            ys = data[:, col].detach().cpu().numpy()
            colours = targets[:].detach().cpu().numpy()
            axis.scatter(xs, ys, c=colours)
|
170 |
+
|
171 |
+
|
172 |
+
def evaluate_preds(preds, y_test):
    """Score posterior-predictive samples against the test labels.

    Returns:
        ``(acc, nll, mse)``: mean agreement of the thresholded samples with
        ``y_test``, binary cross-entropy of the per-position mean of the
        thresholded samples, and the mean of ``Losses.mse`` on the same means.
    """
    hard_predictions = preds['obs'] > 0.5  # TODO: 0.5 or 0
    agreement = hard_predictions == y_test
    acc = agreement.float().mean()

    # Average the hard predictions over the sample axis.
    means = hard_predictions.float().mean(axis=0)

    bce = nn.BCELoss()
    nll = bce(means.float(), y_test.float())
    mse = Losses.mse(means, y_test).mean()

    return acc, nll, mse
|
182 |
+
|
183 |
+
|
184 |
+
def load_results(path, task='steps', base_dir='/home/hollmann/prior-fitting'):
    """Load saved result files and summarise their NLLs.

    Every file matching ``{base_dir}/{path}_*.npy`` is read. For
    ``task == 'steps'`` a file holds ``(nll, acc, elapsed)`` and the entries
    are ordered by elapsed time; otherwise it holds
    ``(samples, nll, acc, elapsed)`` and the entries are ordered by
    ``samples``.

    Args:
        path: Path prefix of the result files, relative to ``base_dir``.
        task: 'steps' or anything else (treated as the samples layout).
        base_dir: Directory the prefix is resolved against. Defaults to the
            location that was previously hard-coded.

    Returns:
        ``(files, times, samples, means, conf)`` as NumPy arrays in sorted
        order, where ``means``/``conf`` are the mean and confidence half-width
        of each file's NLLs. In 'steps' mode ``samples`` holds the file names.
    """
    results_nll = []
    times = []
    samples_list = []

    files = glob.glob(f'{base_dir}/{path}_*.npy')
    for file in files:
        print(file)
        with open(file, 'rb') as f:
            # NOTE: allow_pickle=True unpickles the payload, which can run
            # arbitrary code -- only load result files you produced yourself.
            if task == 'steps':
                nll, acc, elapsed = np.load(f, allow_pickle=True)
                samples_list.append(file)
            else:
                samples, nll, acc, elapsed = np.load(f, allow_pickle=True)
                samples_list.append(samples)
        times.append(elapsed)
        results_nll.append(nll)

    results_nll = np.array(results_nll)
    times = np.array(times)
    files = np.array(files)
    samples = np.array(samples_list)

    # One pass per row instead of computing the statistics twice.
    stats = [compute_mean_and_conf_interval(row) for row in results_nll]
    means = np.array([mean for mean, _ in stats])
    conf = np.array([half_width for _, half_width in stats])

    sort_key = times if task == 'steps' else samples
    sorter = np.argsort(sort_key, axis=0)

    return files[sorter], times[sorter], samples[sorter], means[sorter], conf[sorter]
|
219 |
+
|
220 |
+
def plot_with_confidence_intervals(ax_or_pyplot, x, mean, confidence, **common_kwargs):
    """Plot ``mean`` over ``x`` with a translucent ``mean +/- confidence`` band.

    All keyword arguments go to ``plot``; the band receives the same ones
    except 'label' and 'marker'.
    """
    ax_or_pyplot.plot(x, mean, **common_kwargs)
    band_kwargs = {
        key: value for key, value in common_kwargs.items() if key not in ('label', 'marker')
    }
    lower = mean - confidence
    upper = mean + confidence
    ax_or_pyplot.fill_between(x, lower, upper, alpha=.1, **band_kwargs)
|
227 |
+
|
228 |
+
|
229 |
+
def compute_mean_and_conf_interval(accuracies, confidence=.95):
    """Return the sample mean and the half-width of its Student-t confidence interval."""
    values = np.array(accuracies)
    degrees_of_freedom = len(values) - 1
    mean = np.mean(values)
    std_err = st.sem(values)
    t_quantile = st.t.ppf((1 + confidence) / 2., degrees_of_freedom)
    return mean, std_err * t_quantile
|
235 |
+
|
236 |
+
|
237 |
+
def generate_toy_data(model, bptt, device='cpu'):
    """Draw 100 datasets from ``model`` and stack them along a new first axis.

    The torch RNG is re-seeded with 0 first, so repeated calls draw the same
    random numbers. Returns ``(X, y)`` moved to ``device``.
    """
    torch.manual_seed(0)
    draws = [model(seq_len=bptt) for _ in range(100)]
    X = torch.stack([features for features, _ in draws], 0)
    y = torch.stack([labels for _, labels in draws], 0)
    return X.to(device), y.to(device)
|
250 |
+
|
251 |
+
|
252 |
+
|
253 |
+
def eval_svi(X, y, device, model_sampler, training_samples_n, num_train_steps, num_pred_samples, lr=1e-3, num_particles=1, svgd=False):
    """Fit a fresh model per dataset with SVI (or SVGD) and score it on the held-out part.

    Each dataset ``X[i], y[i]`` is split along its second axis: the first
    ``training_samples_n`` positions are used for fitting, the rest for testing.

    Returns:
        ``(nll, acc)`` NumPy arrays with one entry per dataset.
    """
    X_test, y_test = X[:, training_samples_n:], y[:, training_samples_n:]
    X_train, y_train = X[:, 0:training_samples_n], y[:, 0:training_samples_n]

    nll_list = []
    acc_list = []
    for sample_id in tqdm(list(range(0, X_test.shape[0]))):
        # A new model, guide and optimizer for every dataset.
        model = model_sampler()
        guide = AutoDiagonalNormal(model).to(device)
        adam = pyro.optim.Adam({"lr": lr})
        svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=num_particles))

        if svgd:
            # NOTE(review): SVGD replaces the SVI object, but predictions below
            # still come from `guide`, which SVGD does not appear to train -- confirm.
            kernel = RBFSteinKernel()
            svi = SVGD(model, kernel, adam, num_particles=50, max_plate_nesting=0)

        # Drop parameters left over from the previous dataset.
        pyro.clear_param_store()

        X_test_sample, y_test_sample, X_train_sample, y_train_sample = X_test[sample_id], y_test[sample_id], X_train[
            sample_id], y_train[sample_id]

        # Placeholders only; overwritten after training.
        acc, nll, mse = 0.0, 0.0, 0.0
        # bar = tqdm(list(range(num_train_steps)))
        bar = list(range(num_train_steps))
        for epoch in bar:
            loss = svi.step(X_train_sample, y_train_sample)
            # if epoch % 100 == 1:
            #    bar.set_postfix(loss=f'{loss / X_train_sample.shape[0]:.3f}', test_nll=f'{nll:.3f}', test_acc=f'{acc:.3f}')

        # Draw num_pred_samples posterior-predictive samples for the test inputs.
        predictive = Predictive(model, guide=guide, num_samples=num_pred_samples)
        preds = predictive(X_test_sample)
        acc, nll, mse = evaluate_preds(preds, y_test_sample)
        nll_list += [nll.detach().cpu().numpy()]
        acc_list += [acc.detach().cpu().numpy()]

    return np.array(nll_list), np.array(acc_list)
|
289 |
+
|
290 |
+
|
291 |
+
def eval_mcmc(X, y, device, model_sampler, training_samples_n, warmup_steps, num_pred_samples):
    """Run NUTS on a fresh model per dataset and score it on the held-out part.

    Each dataset is split along its second axis at ``training_samples_n``.
    Returns ``(nll, acc)`` NumPy arrays with one entry per dataset.
    """
    X_test = X[:, training_samples_n:].to(device)
    y_test = y[:, training_samples_n:].to(device)
    X_train = X[:, 0:training_samples_n].to(device)
    y_train = y[:, 0:training_samples_n].to(device)

    acc_list = []
    nll_list = []
    for sample_id in tqdm(list(range(0, X_test.shape[0]))):
        X_test_sample = X_test[sample_id]
        y_test_sample = y_test[sample_id]
        X_train_sample = X_train[sample_id]
        y_train_sample = y_train[sample_id]

        model = model_sampler()
        sampler = MCMC(NUTS(model), num_samples=num_pred_samples, num_chains=1, disable_progbar=True,
                       warmup_steps=warmup_steps, mp_context="fork")
        sampler.run(X_train_sample, y_train_sample)

        posterior = sampler.get_samples()
        preds = infer.mcmc.util.predictive(model, posterior, X_test_sample, None)
        acc, nll, mse = evaluate_preds(preds, y_test_sample)
        nll_list.append(nll.detach().cpu().numpy())
        acc_list.append(acc.detach().cpu().numpy())

    return np.array(nll_list), np.array(acc_list)
|
310 |
+
|
311 |
+
|
312 |
+
def eval_transformer(X, y, device, model, training_samples_n):
    """Score a trained transformer on every dataset, one dataset per forward pass.

    Returns:
        ``(acc, nll, elapsed)`` -- note the order differs from eval_svi/eval_mcmc,
        which return ``(nll, acc)``. ``acc`` and ``nll`` are averaged over axis 0
        of the test positions; ``elapsed`` is the wall-clock time of the forward
        passes in seconds.
    """
    # Swap the first two axes of X and y before chunking along axis 1.
    X_sample, y_sample = X.transpose(0, 1), y.transpose(0, 1).float()
    bs = 1
    samples = []
    for i in range(0, X_sample.shape[1] // bs):
        samples += [(X_sample[:, bs * i:bs * (i + 1)], y_sample[:, bs * i:bs * (i + 1)])]

    # NOTE(review): the chunks in `samples` were sliced before this point and the
    # normalised X_sample is a new tensor that is never read again, so the model
    # is fed un-normalised inputs. Confirm whether normalisation was intended.
    mean = X_sample[:training_samples_n].mean(0)
    std = X_sample[:training_samples_n].std(0) + .000001
    X_sample = (X_sample - mean) / std

    start = time.time()
    output = torch.cat(
        [model.to(device)((X_sample_chunk, y_sample_chunk), single_eval_pos=training_samples_n).squeeze(-1) for
         (X_sample_chunk, y_sample_chunk) in samples], 1)
    elapsed = time.time() - start

    output = output.detach().cpu()
    # Outputs are treated as logits: sigmoid, then threshold at 0.5.
    acc = ((torch.sigmoid(output) > 0.5) == y_sample[training_samples_n:].cpu().bool()).float().mean(axis=0)
    nll = nn.BCELoss(reduction='none')(torch.sigmoid(output.float()), y_sample[training_samples_n:].cpu().float()).mean(
        axis=0)
    return acc, nll, elapsed
|
334 |
+
|
335 |
+
|
336 |
+
def training_steps(method, X, y, model_spec, device='cpu', path_interfix='', overwrite=False,
                   base_dir='/home/hollmann/prior-fitting'):
    """Evaluate a solver for a doubling budget of steps and save each result.

    For every budget ``s`` in 2, 4, ..., 4096 the solver is run with ``s``
    training/warm-up steps and ``s`` predictive samples, using the first 100
    positions of each dataset for fitting. Results are written to
    ``{base_dir}/{path_interfix}/results_{method}_training_steps_{s}.npy`` as a
    3-element object array ``(nll, acc, elapsed)``.

    Args:
        method: 'svi', 'svgd' or 'mcmc'.
        X, y: Datasets forwarded to the evaluation function.
        model_spec: Forwarded as the evaluation function's ``model_sampler``.
        device: Device forwarded to the evaluation function.
        path_interfix: Sub-directory below ``base_dir``.
        overwrite: Recompute budgets whose result file already exists.
        base_dir: Root of the results tree; defaults to the previously
            hard-coded location.

    Raises:
        ValueError: If ``method`` is not one of the supported solvers.
    """
    training_samples_n = 100
    for s in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
        path = f'{base_dir}/{path_interfix}/results_{method}_training_steps_{s}.npy'
        if os.path.isfile(path) and not overwrite:
            print(f'already done {s}')
            continue

        start = time.time()
        if method == 'svi':
            nll, acc = eval_svi(X, y, device, model_spec, training_samples_n, num_train_steps=s, num_pred_samples=s, svgd=False)
        elif method == 'svgd':
            nll, acc = eval_svi(X, y, device, model_spec, training_samples_n, num_train_steps=s, num_pred_samples=s, svgd=True)
        elif method == 'mcmc':
            nll, acc = eval_mcmc(X, y, device, model_spec, training_samples_n, warmup_steps=s, num_pred_samples=s)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `nll`.
            raise ValueError(f"unknown method {method!r}; expected 'svi', 'svgd' or 'mcmc'")
        elapsed = time.time() - start

        print(s)
        print('NLL ', compute_mean_and_conf_interval(nll))
        print('ACC ', compute_mean_and_conf_interval(acc))
        print('TIME ', elapsed)

        # Build the object array explicitly: handing np.save a ragged tuple
        # is deprecated and raises on NumPy >= 1.24.
        record = np.empty(3, dtype=object)
        for slot, item in enumerate((np.array(nll), np.array(acc), elapsed)):
            record[slot] = item

        # Do not lose a finished run to a missing results directory.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            np.save(f, record)

        print(f'Saved results at {path}')
|
362 |
+
|
363 |
+
|
364 |
+
def training_samples(method, X, y, model_spec, evaluation_points, steps = None, device='cpu', path_interfix='', overwrite=False,
                     base_dir='/home/hollmann/prior-fitting'):
    """Evaluate a solver for several training-set sizes and save each result.

    For every ``training_samples_n`` in ``evaluation_points`` the solver is
    fitted on that many positions per dataset and scored on the rest. Results
    are written to ``{base_dir}/{path_interfix}/results_{method}_
    {num_pred_samples}_training_samples_{n}.npy`` as a 4-element object array
    ``(training_samples_n, nll, acc, elapsed)``.

    Args:
        method: 'svi', 'svgd' or 'mcmc'.
        X, y: Datasets forwarded to the evaluation function.
        model_spec: Forwarded as the evaluation function's ``model_sampler``.
        evaluation_points: Training-set sizes to evaluate.
        steps: Overrides every step/sample budget when truthy; otherwise SVI
            uses 1024 steps/samples and MCMC 512 warm-up steps/samples.
        device: Device forwarded to the evaluation function.
        path_interfix: Sub-directory below ``base_dir``.
        overwrite: Recompute sizes whose result file already exists.
        base_dir: Root of the results tree; defaults to the previously
            hard-coded location.

    Raises:
        ValueError: If ``method`` is not one of the supported solvers.
    """
    num_pred_samples_mcmc = steps if steps else 512
    warmup_steps = steps if steps else 512

    num_pred_samples_svi = steps if steps else 1024
    num_train_steps = steps if steps else 1024

    # NOTE(review): 'svgd' runs through eval_svi but takes the MCMC
    # predictive-sample count here; kept as-is because it is part of the
    # result file name -- confirm this is intended.
    num_pred_samples = num_pred_samples_svi if method == 'svi' else num_pred_samples_mcmc

    for training_samples_n in evaluation_points:
        path = f'{base_dir}/{path_interfix}/results_{method}_{num_pred_samples}_training_samples_{training_samples_n}.npy'
        if os.path.isfile(path) and not overwrite:
            print(f'already done {training_samples_n}')
            continue

        start = time.time()
        if method == 'svi':
            nll, acc = eval_svi(X, y, device, model_spec, training_samples_n, num_train_steps=num_train_steps, num_pred_samples=num_pred_samples)
        elif method == 'svgd':
            nll, acc = eval_svi(X, y, device, model_spec, training_samples_n, num_train_steps=num_train_steps, num_pred_samples=num_pred_samples, svgd=True)
        elif method == 'mcmc':
            nll, acc = eval_mcmc(X, y, device, model_spec, training_samples_n, warmup_steps=warmup_steps, num_pred_samples=num_pred_samples)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `nll`.
            raise ValueError(f"unknown method {method!r}; expected 'svi', 'svgd' or 'mcmc'")
        elapsed = time.time() - start

        print('NLL ', compute_mean_and_conf_interval(nll))
        print('ACC ', compute_mean_and_conf_interval(acc))
        print('TIME ', elapsed)

        # Build the object array explicitly: handing np.save a ragged tuple
        # is deprecated and raises on NumPy >= 1.24.
        record = np.empty(4, dtype=object)
        for slot, item in enumerate((training_samples_n, np.array(nll), np.array(acc), elapsed)):
            record[slot] = item

        # Do not lose a finished run to a missing results directory.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            np.save(f, record)
|
394 |
+
|
395 |
+
### MAIN
|
396 |
+
def get_default_model_spec(size):
    """Return the model spec dict for a named or explicit model size.

    ``size`` is 'big', 'small', or a string of the form
    ``'<num_features>_<embed>_<nlayers>'``. The sequence length is always 300.
    """
    if size == 'big':
        dims = (8, 64, 2)
    elif size == 'small':
        dims = (3, 5, 2)
    else:
        parts = size.split("_")
        dims = (int(parts[0]), int(parts[1]), int(parts[2]))

    num_features, embed, nlayers = dims
    return {'nlayers': nlayers, 'embed': embed, 'num_features': num_features, "seq_len": 300}
|
413 |
+
|
414 |
+
def get_default_evaluation_points():
    """Return the default training-set sizes: 2, 7, 12, ... up to (excluding) 100."""
    first, stop, step = 2, 100, 5
    return [*range(first, stop, step)]
|
416 |
+
|
417 |
+
if __name__ == '__main__':
    # Command-line entry point: generate toy data from the Bayesian prior and
    # benchmark one solver on it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--solver', default='svi', type=str, choices=['svi', 'svgd', 'mcmc'])
    parser.add_argument('--task', default='steps', type=str, choices=['steps', 'samples'])
    parser.add_argument('--model_size', default='small', type=str)

    args = parser.parse_args()

    model_spec = get_default_model_spec(args.model_size)
    evaluation_points = get_default_evaluation_points()
    device = 'cuda:0' if args.solver == 'svi' else 'cpu'

    torch.manual_seed(0)
    test_model = BayesianModel(model_spec, device=device)

    # Keep the data on the solver's device; eval_svi does not move it itself.
    X, y = generate_toy_data(test_model, model_spec['seq_len'], device=device)
    model_sampler = lambda: BayesianModel(model_spec, device=device)

    # The SVGD variant is selected with --solver svgd. The former
    # `svgd=args.svgd` keyword referenced an argument that was never defined
    # and a parameter neither function accepts.
    if args.task == 'steps':
        training_steps(args.solver, X, y, model_sampler, device=device,
                       path_interfix=f'results/timing_{args.model_size}_model')
    elif args.task == 'samples':
        training_samples(args.solver, X, y, model_sampler, evaluation_points, device=device,
                         path_interfix=f'results/timing_{args.model_size}_model')
|
441 |
+
|
442 |
+
|
443 |
+
|
prior-fitting/notebooks/BayesianModels_And_Custom_Pyro_Modules.ipynb
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 56,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"import numpy as np\n",
|
11 |
+
"\n",
|
12 |
+
"import priors\n",
|
13 |
+
"from train import train, get_weighted_single_eval_pos_sampler\n",
|
14 |
+
"import encoders\n",
|
15 |
+
"import positional_encodings\n",
|
16 |
+
"import utils\n",
|
17 |
+
"import bar_distribution\n",
|
18 |
+
"import decoders\n",
|
19 |
+
"from datasets import *\n",
|
20 |
+
"import os\n",
|
21 |
+
"\n",
|
22 |
+
"from tqdm import tqdm\n",
|
23 |
+
"import time\n",
|
24 |
+
"\n",
|
25 |
+
"import torch\n",
|
26 |
+
"import pandas as pd\n",
|
27 |
+
"import numpy as np\n",
|
28 |
+
"import matplotlib.pyplot as plt\n",
|
29 |
+
"\n",
|
30 |
+
"import torch.nn as nn\n",
|
31 |
+
"import os.path\n",
|
32 |
+
"import glob\n",
|
33 |
+
"\n",
|
34 |
+
"from mcmc_svi_transformer_on_bayesian import get_model, get_default_model_spec, generate_toy_data, load_results, plot_with_confidence_intervals, training_steps, training_samples, get_default_evaluation_points, compute_mean_and_conf_interval, eval_transformer\n"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": 4,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"# %load_ext autoreload\n",
|
44 |
+
"\n",
|
45 |
+
"# %autoreload 2"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 3,
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"## DEFINE A PRIOR MODEL ##\n",
|
55 |
+
"# We define a Bayesian Model as a prior for all methods\n",
|
56 |
+
"# This can be replaced by other models that inherit from PyroModule.\n",
|
57 |
+
"class BayesianModel(PyroModule):\n",
|
58 |
+
" def __init__(self, model_spec, device='cuda'):\n",
|
59 |
+
" super().__init__()\n",
|
60 |
+
"\n",
|
61 |
+
" self.device = device\n",
|
62 |
+
" self.num_features = model_spec['num_features']\n",
|
63 |
+
"\n",
|
64 |
+
" mu, sigma = torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)\n",
|
65 |
+
"\n",
|
66 |
+
" self.fc1 = PyroModule[nn.Linear](self.num_features, model_spec['embed'])\n",
|
67 |
+
" self.fc1.weight = PyroSample(\n",
|
68 |
+
" dist.Normal(mu, sigma).expand([model_spec['embed'], self.num_features]).to_event(2))\n",
|
69 |
+
" self.fc1.bias = PyroSample(dist.Normal(mu, sigma).expand([model_spec['embed']]).to_event(1))\n",
|
70 |
+
"\n",
|
71 |
+
" self.fc2 = PyroModule[nn.Linear](model_spec['embed'], 2)\n",
|
72 |
+
" self.fc2.weight = PyroSample(dist.Normal(mu, sigma).expand([2, model_spec['embed']]).to_event(2))\n",
|
73 |
+
" self.fc2.bias = PyroSample(dist.Normal(mu, sigma).expand([2]).to_event(1))\n",
|
74 |
+
"\n",
|
75 |
+
" self.model = torch.nn.Sequential(self.fc1, self.fc2)\n",
|
76 |
+
"\n",
|
77 |
+
" self.to(self.device)\n",
|
78 |
+
"\n",
|
79 |
+
" def forward(self, x=None, y=None, seq_len=1):\n",
|
80 |
+
" if x is None:\n",
|
81 |
+
" with pyro.plate(\"x_plate\", seq_len):\n",
|
82 |
+
" d_ = dist.Normal(torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)).expand(\n",
|
83 |
+
" [self.num_features]).to_event(1)\n",
|
84 |
+
" x = pyro.sample(\"x\", d_)\n",
|
85 |
+
"\n",
|
86 |
+
" out = self.model(x)\n",
|
87 |
+
" mu = out.squeeze()\n",
|
88 |
+
" softmax = torch.nn.Softmax(dim=1)\n",
|
89 |
+
" with pyro.plate(\"data\", out.shape[0]):\n",
|
90 |
+
" s = softmax(mu)\n",
|
91 |
+
" obs = pyro.sample('obs', dist.Categorical(probs=s), obs=y).float()\n",
|
92 |
+
"\n",
|
93 |
+
" return x, obs"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"cell_type": "code",
|
98 |
+
"execution_count": 69,
|
99 |
+
"metadata": {},
|
100 |
+
"outputs": [],
|
101 |
+
"source": [
|
102 |
+
"results_directory = 'results' # Where to save results\n",
|
103 |
+
"model_spec_size = 'small' # Size of the BNN model to evaluate, also try big\n",
|
104 |
+
"bptt = 100 # Number of samples in each dataset\n",
|
105 |
+
"\n",
|
106 |
+
"# Training samples seen after which to evaluate the methods\n",
|
107 |
+
"evaluation_points = [2, 7, 12, 17, 22, 27, 32, 37, 42, 47, 52, 57, 62, 67, 72, 77, 82, 87, 92]\n",
|
108 |
+
"\n",
|
109 |
+
"# Function which generates a model from the prior\n",
|
110 |
+
"model_sampler = lambda : BayesianModel(get_default_model_spec(model_spec_size), device = device)\n",
|
111 |
+
"\n",
|
112 |
+
"global_results = {} # Dict in which to save results\n",
|
113 |
+
"task = 'samples' # Task to evaluate, only option is samples, keep fixed"
|
114 |
+
]
|
115 |
+
},
|
116 |
+
{
|
117 |
+
"cell_type": "code",
|
118 |
+
"execution_count": null,
|
119 |
+
"metadata": {},
|
120 |
+
"outputs": [],
|
121 |
+
"source": [
|
122 |
+
"!mkdir {results_directory}"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "markdown",
|
127 |
+
"metadata": {
|
128 |
+
"heading_collapsed": true
|
129 |
+
},
|
130 |
+
"source": [
|
131 |
+
"### Evaluate SVI and MCMC"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": 25,
|
137 |
+
"metadata": {
|
138 |
+
"hidden": true
|
139 |
+
},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"method = 'svi'\n",
|
143 |
+
"steps = 1\n",
|
144 |
+
"device = 'cuda'\n",
|
145 |
+
"path_interfix = f'{results_directory}/timing_{model_spec_size}_model_test'"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": 26,
|
151 |
+
"metadata": {
|
152 |
+
"hidden": true
|
153 |
+
},
|
154 |
+
"outputs": [],
|
155 |
+
"source": [
|
156 |
+
"!mkdir {path_interfix}"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "code",
|
161 |
+
"execution_count": 27,
|
162 |
+
"metadata": {
|
163 |
+
"hidden": true
|
164 |
+
},
|
165 |
+
"outputs": [
|
166 |
+
{
|
167 |
+
"name": "stderr",
|
168 |
+
"output_type": "stream",
|
169 |
+
"text": [
|
170 |
+
"100%|██████████| 100/100 [00:02<00:00, 37.13it/s]\n",
|
171 |
+
"/home/hollmann/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/lib/npyio.py:528: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
172 |
+
" arr = np.asanyarray(arr)\n"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"name": "stdout",
|
177 |
+
"output_type": "stream",
|
178 |
+
"text": [
|
179 |
+
"NLL (51.540817, 1.832436208065078)\n",
|
180 |
+
"ACC (0.48459178, 0.01832436154844232)\n",
|
181 |
+
"TIME 2.6950523853302\n"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"name": "stderr",
|
186 |
+
"output_type": "stream",
|
187 |
+
"text": [
|
188 |
+
"100%|██████████| 100/100 [00:02<00:00, 35.89it/s]\n"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"name": "stdout",
|
193 |
+
"output_type": "stream",
|
194 |
+
"text": [
|
195 |
+
"NLL (48.569893, 1.8696300575034437)\n",
|
196 |
+
"ACC (0.51430106, 0.01869630134377999)\n",
|
197 |
+
"TIME 2.788970708847046\n"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"name": "stderr",
|
202 |
+
"output_type": "stream",
|
203 |
+
"text": [
|
204 |
+
"100%|██████████| 100/100 [00:03<00:00, 31.80it/s]\n"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"name": "stdout",
|
209 |
+
"output_type": "stream",
|
210 |
+
"text": [
|
211 |
+
"NLL (51.034092, 1.807273770560027)\n",
|
212 |
+
"ACC (0.48965907, 0.018072737823868815)\n",
|
213 |
+
"TIME 3.1472866535186768\n"
|
214 |
+
]
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"name": "stderr",
|
218 |
+
"output_type": "stream",
|
219 |
+
"text": [
|
220 |
+
"100%|██████████| 100/100 [00:02<00:00, 38.48it/s]\n"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"name": "stdout",
|
225 |
+
"output_type": "stream",
|
226 |
+
"text": [
|
227 |
+
"NLL (50.216866, 2.0121896389094833)\n",
|
228 |
+
"ACC (0.4978313, 0.02012189562034928)\n",
|
229 |
+
"TIME 2.600956439971924\n"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"name": "stderr",
|
234 |
+
"output_type": "stream",
|
235 |
+
"text": [
|
236 |
+
" 55%|█████▌ | 55/100 [00:01<00:01, 38.41it/s]\n"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"ename": "KeyboardInterrupt",
|
241 |
+
"evalue": "",
|
242 |
+
"output_type": "error",
|
243 |
+
"traceback": [
|
244 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
245 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
246 |
+
"\u001b[0;32m/tmp/ipykernel_9449/1948451174.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgenerate_toy_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m training_samples(method\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
247 |
+
"\u001b[0;32m~/prior-fitting/mcmc_svi_transformer_on_bayesian.py\u001b[0m in \u001b[0;36mtraining_samples\u001b[0;34m(method, X, y, model_spec, evaluation_points, steps, device, path_interfix, overwrite)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'svi'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mnll\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_svi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining_samples_n\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_train_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_train_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_pred_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_pred_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'svgd'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0mnll\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_svi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining_samples_n\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mnum_train_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_train_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_pred_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_pred_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msvgd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
248 |
+
"\u001b[0;32m~/prior-fitting/mcmc_svi_transformer_on_bayesian.py\u001b[0m in \u001b[0;36meval_svi\u001b[0;34m(X, y, device, model_sampler, training_samples_n, num_train_steps, num_pred_samples, lr, num_particles, svgd)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0mpredictive\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPredictive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_pred_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredictive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_test_sample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnll\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate_preds\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test_sample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mnll_list\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnll\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
249 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
250 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/predictive.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 272\u001b[0m )\n\u001b[0;32m--> 273\u001b[0;31m return _predictive(\n\u001b[0m\u001b[1;32m 274\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 275\u001b[0m \u001b[0mposterior_samples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
251 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/predictive.py\u001b[0m in \u001b[0;36m_predictive\u001b[0;34m(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mparallel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m return _predictive_sequential(\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0mposterior_samples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
252 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/predictive.py\u001b[0m in \u001b[0;36m_predictive_sequential\u001b[0;34m(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)\u001b[0m\n\u001b[1;32m 46\u001b[0m ]\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m )\n",
|
253 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36mget_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0mCalls\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mpoutine\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mreturns\u001b[0m \u001b[0mits\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0minstead\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \"\"\"\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmsngr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
254 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 172\u001b[0m )\n\u001b[1;32m 173\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mValueError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
255 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
256 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
257 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
258 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
259 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
260 |
+
"\u001b[0;32m/tmp/ipykernel_9449/3309204952.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, y, seq_len)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"x\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mmu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0msoftmax\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
261 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
262 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
263 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
264 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
265 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
266 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/nn/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprior\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"sample\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# if not a distribution\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[0mprior\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprior\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 479\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfullname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprior\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 480\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfullname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
267 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/primitives.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(name, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 162\u001b[0m }\n\u001b[1;32m 163\u001b[0m \u001b[0;31m# apply the stack and return its return value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mapply_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
268 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mpointer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpointer\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m \u001b[0mframe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"stop\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
269 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_process_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0mon\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mThe\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mupdated\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mplace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \"\"\"\n\u001b[0;32m--> 141\u001b[0;31m \u001b[0mmethod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"_pyro_{}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"type\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
270 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
271 |
+
]
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"source": [
|
275 |
+
"X, y = generate_toy_data(test_model, bptt, device)\n",
|
276 |
+
"\n",
|
277 |
+
"training_samples(method\n",
|
278 |
+
" , X\n",
|
279 |
+
" , y\n",
|
280 |
+
" , model_sampler\n",
|
281 |
+
" , evaluation_points\n",
|
282 |
+
" , steps=steps\n",
|
283 |
+
" , device=device\n",
|
284 |
+
" , path_interfix=path_interfix)"
|
285 |
+
]
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"cell_type": "markdown",
|
289 |
+
"metadata": {
|
290 |
+
"heading_collapsed": true
|
291 |
+
},
|
292 |
+
"source": [
|
293 |
+
"### Training Transformer on Prior (Skip this step to reuse results)"
|
294 |
+
]
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"cell_type": "code",
|
298 |
+
"execution_count": 41,
|
299 |
+
"metadata": {
|
300 |
+
"hidden": true
|
301 |
+
},
|
302 |
+
"outputs": [],
|
303 |
+
"source": [
|
304 |
+
"device = 'cuda'"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"cell_type": "code",
|
309 |
+
"execution_count": 49,
|
310 |
+
"metadata": {
|
311 |
+
"hidden": true
|
312 |
+
},
|
313 |
+
"outputs": [],
|
314 |
+
"source": [
|
315 |
+
"config = {'lr': 2.006434218345026e-05\n",
|
316 |
+
" , 'epochs': 160\n",
|
317 |
+
" , 'dropout': 0.0\n",
|
318 |
+
" , 'emsize': 256\n",
|
319 |
+
" , 'batch_size': 256\n",
|
320 |
+
" , 'nlayers': 5\n",
|
321 |
+
" , 'num_outputs': 1\n",
|
322 |
+
" , 'num_features': model_spec['num_features']\n",
|
323 |
+
" , 'steps_per_epoch': 100\n",
|
324 |
+
" , 'nhead': 4\n",
|
325 |
+
" , 'dropout': 0.0\n",
|
326 |
+
" , 'seq_len': model_spec['seq_len']\n",
|
327 |
+
" , 'nhid_factor': 2}"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": 51,
|
333 |
+
"metadata": {
|
334 |
+
"hidden": true
|
335 |
+
},
|
336 |
+
"outputs": [
|
337 |
+
{
|
338 |
+
"name": "stdout",
|
339 |
+
"output_type": "stream",
|
340 |
+
"text": [
|
341 |
+
"Using cuda device\n",
|
342 |
+
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 256, 'seq_len': 300, 'num_outputs': 1, 'num_features': 3, 'canonical_args': None, 'model': <function <lambda> at 0x7f6f42f49f70>}, 'num_features': 3, 'num_outputs': 1}\n"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"ename": "KeyboardInterrupt",
|
347 |
+
"evalue": "",
|
348 |
+
"output_type": "error",
|
349 |
+
"traceback": [
|
350 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
351 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
352 |
+
"\u001b[0;32m/tmp/ipykernel_9449/1283571267.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtransformer_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_sampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshould_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mmodel_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults_directory\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'bayesian_models_transformer_checkpoint_{model_spec_size}_epochs_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'epochs'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'.cpkt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtransformer_model\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
353 |
+
"\u001b[0;32m~/prior-fitting/mcmc_svi_transformer_on_bayesian.py\u001b[0m in \u001b[0;36mget_model\u001b[0;34m(model_generator, config, should_train, device)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0mepochs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mshould_train\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'epochs'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m model = train(priors.pyro.DataLoader\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mLosses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbce\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mencoders\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
354 |
+
"\u001b[0;32m~/prior-fitting/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(priordataloader_class, criterion, encoder_generator, emsize, nhid, nlayers, nhead, dropout, epochs, steps_per_epoch, batch_size, bptt, lr, warmup_epochs, input_normalization, y_encoder_generator, pos_encoder_generator, decoder, extra_prior_kwargs_dict, scheduler, load_weights_from_this_state_dict, validation_period, single_eval_pos_gen, gpu_device, aggregate_k_gradients, verbose)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mepoch_start_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mtotal_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal_positional_losses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_to_get_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforward_time\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'validate'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mvalidation_period\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mwith\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
355 |
+
"\u001b[0;32m~/prior-fitting/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0maggregate_k_gradients\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0maggregate_k_gradients\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 95\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 96\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
356 |
+
"\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/torch/nn/utils/clip_grad.py\u001b[0m in \u001b[0;36mclip_grad_norm_\u001b[0;34m(parameters, max_norm, norm_type, error_if_nonfinite)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mtotal_norm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mtotal_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mtotal_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0merror_if_nonfinite\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m raise RuntimeError(\n",
|
357 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
358 |
+
]
|
359 |
+
}
|
360 |
+
],
|
361 |
+
"source": [
|
362 |
+
"transformer_model = get_model(model_sampler, config, should_train = True)\n",
|
363 |
+
"model_path = os.path.join(results_directory, f'bayesian_models_transformer_checkpoint_{model_spec_size}_epochs_'+config['epochs']+'.cpkt')\n",
|
364 |
+
"torch.save((transformer_model[2].state_dict(), None), model_path)\n"
|
365 |
+
]
|
366 |
+
},
|
367 |
+
{
|
368 |
+
"cell_type": "markdown",
|
369 |
+
"metadata": {},
|
370 |
+
"source": [
|
371 |
+
"### Evaluating Transformer"
|
372 |
+
]
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"cell_type": "code",
|
376 |
+
"execution_count": 52,
|
377 |
+
"metadata": {},
|
378 |
+
"outputs": [
|
379 |
+
{
|
380 |
+
"name": "stdout",
|
381 |
+
"output_type": "stream",
|
382 |
+
"text": [
|
383 |
+
"Using cuda device\n",
|
384 |
+
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 256, 'seq_len': 300, 'num_outputs': 1, 'num_features': 3, 'canonical_args': None, 'model': <function <lambda> at 0x7f6f42f49f70>}, 'num_features': 3, 'num_outputs': 1}\n"
|
385 |
+
]
|
386 |
+
},
|
387 |
+
{
|
388 |
+
"data": {
|
389 |
+
"text/plain": [
|
390 |
+
"<All keys matched successfully>"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
"execution_count": 52,
|
394 |
+
"metadata": {},
|
395 |
+
"output_type": "execute_result"
|
396 |
+
}
|
397 |
+
],
|
398 |
+
"source": [
|
399 |
+
"loaded_epoch = config['epochs']\n",
|
400 |
+
"transformer_model = get_model(model_sampler, config, should_train = False)\n",
|
401 |
+
"path = os.path.join(results_directory, F'bayesian_models_transformer_checkpoint_{model_spec_size}_epochs_{loaded_epoch}.cpkt')\n",
|
402 |
+
"model_state, optimizer_state = torch.load(path)\n",
|
403 |
+
"transformer_model[2].load_state_dict(model_state)"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": 57,
|
409 |
+
"metadata": {},
|
410 |
+
"outputs": [],
|
411 |
+
"source": [
|
412 |
+
"X, y = generate_toy_data(test_model, bptt, device)"
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": 73,
|
418 |
+
"metadata": {},
|
419 |
+
"outputs": [],
|
420 |
+
"source": [
|
421 |
+
"results_acc = []\n",
|
422 |
+
"results_nll = []\n",
|
423 |
+
"transformer_model[2].eval()\n",
|
424 |
+
"for training_samples_n in evaluation_points:\n",
|
425 |
+
" acc, nll, elapsed = eval_transformer(X, y, model=transformer_model[2], training_samples_n=training_samples_n, device=device)\n",
|
426 |
+
" results_acc.append(acc)\n",
|
427 |
+
" results_nll.append(nll)\n",
|
428 |
+
"mean = np.array([compute_mean_and_conf_interval(nll)[0] for nll in results_nll])\n",
|
429 |
+
"conf = np.array([compute_mean_and_conf_interval(nll)[1] for nll in results_nll])\n",
|
430 |
+
"\n",
|
431 |
+
"global_results['transformer'] = (None, np.array(evaluation_points), mean, conf)\n"
|
432 |
+
]
|
433 |
+
},
|
434 |
+
{
|
435 |
+
"cell_type": "markdown",
|
436 |
+
"metadata": {},
|
437 |
+
"source": [
|
438 |
+
"## Plotting results"
|
439 |
+
]
|
440 |
+
},
|
441 |
+
{
|
442 |
+
"cell_type": "code",
|
443 |
+
"execution_count": 71,
|
444 |
+
"metadata": {},
|
445 |
+
"outputs": [
|
446 |
+
{
|
447 |
+
"name": "stdout",
|
448 |
+
"output_type": "stream",
|
449 |
+
"text": [
|
450 |
+
]
|
451 |
+
}
|
452 |
+
],
|
453 |
+
"source": [
|
454 |
+
"files, times, samples, mean, conf = load_results(f'{results_directory}/timing_{model_size}_model/results_svi_training_{task}', task=task)\n",
|
455 |
+
"global_results['svi'] = (times/100, samples, mean, conf)\n",
|
456 |
+
"files, times, samples, mean, conf = load_results(f'{results_directory}/timing_{model_size}_model/results_mcmc_training_{task}', task=task)\n",
|
457 |
+
"global_results['mcmc'] = (times/100, samples,mean, conf)\n"
|
458 |
+
]
|
459 |
+
},
|
460 |
+
{
|
461 |
+
"cell_type": "code",
|
462 |
+
"execution_count": 74,
|
463 |
+
"metadata": {},
|
464 |
+
"outputs": [
|
465 |
+
{
|
466 |
+
"data": {
|
467 |
+
"text/plain": [
|
468 |
+
"<matplotlib.legend.Legend at 0x7f6f1c2ba7c0>"
|
469 |
+
]
|
470 |
+
},
|
471 |
+
"execution_count": 74,
|
472 |
+
"metadata": {},
|
473 |
+
"output_type": "execute_result"
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"data": {
|
477 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAEoCAYAAAAub0k8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABkaUlEQVR4nO3dd5xcV33//9e5ZdrubFXfVe9dtiwXDLaMwYWOIfQaCCGE/L4hgQAB0wkQQgKh+QsJhFACGPgGgw0GG1dwkWTLlizJ6mVX0vY69Zbz++POzM5WrVaz2vZ5Phhm5s6dmbvj1d73nPM55yitNUIIIYSYWYyJPgAhhBBCXHwSAIQQQogZSAKAEEIIMQNJABBCCCFmIAkAQgghxAwkAUAIIYSYgayJeuNZs2bpJUuWTNTbCyGEEFParl27WrXWs8f6/AkLAEuWLGHnzp0T9fZCCCHElKaUOnEhz5cuACGEEGIGkgAghBBCzEASAIQQQogZaMJqAIQQQkwdjuPQ0NBAOp2e6EOZcSKRCPX19di2XdLXlQAghBDinBoaGojH4yxZsgSl1EQfzoyhtaatrY2GhgaWLl1a0teWLgAhhBDnlE6nqa2tlZP/RaaUora2dlxaXiQACCGEGBU5+U+M8frcJQAIIYSYdl70ohfR2dk50YcxqUkNgBBCiGnnrrvumuhDmPSkBUAIIcSUkEgkePGLX8zmzZvZsGED3/ve93jNa15TePz+++/npS99KRDMNtva2jpRhzolTFgLgJ6oNxZCCHFBPvmrZ9h3urukr7luQQUff+n6Eff57W9/y4IFC7jzzjsB6Orq4tZbbyWRSFBWVsZPfvITXvva15b0uKazCWsB0FoigBBCiNHbuHEj99xzDx/84Ad56KGHqKys5KabbuJXv/oVruty55138vKXv3yiD3PKmLAWAF8CgBBCTEnn+qY+XlatWsWuXbu46667+PCHP8wNN9zAa1/7Wr7+9a9TU1PDtm3biMfjE3JsU9GEtgC4nj9Rby+EEGKKOX36NLFYjDe96U28//3v54knnmD79u088cQTfPvb35bm//M0gTUAmqznY5lShyiEEOLc9uzZwwc+8AEMw8C2bb75zW9imiYveclL+K//+i++973vTfQhTilqovriN27ZrB/84+NUl4Un5P2FEEKM3v79+1m7du1EH8aMNdTnr5TapbW+bKyvOaFfv5NOdiLfXgghhJixJjQAOL5D1pU6ACGEEOJim/gAIIWAQgghxEU3oQHAlRYAIYQQYkJMaADwtSbjZGVSICGEEOIim/AxeFnflW4AIYQQ4iKb8AAg3QBCCCHExTdhAcD1gmZ/R0sAEEIIIS62CQsAx9vS+Frj+R5Zz8P3pQ5ACCHE8I4fP86aNWt45zvfyYYNG3jjG9/IPffcw9VXX83KlSt5/PHH6e3t5e1vfzsbN25k06ZN/PznPwegvLycD37wg2zdupUXvOAFPP7442zfvp1ly5Zxxx13AOB5Hu9///sLz/3qV786kT/uuJvQxYCONqdYMTcWdAN4YSKGOVGHI4QQYrR+8yE4u6e0rzlvI9z8+XPudvjwYW6//Xa+9a1vsW3bNn70ox/x8MMPc8cdd/BP//RPrF69msrKSvbsCY6vo6MDgEQiwfbt2/nCF77AK1/5Sj760Y/y+9//nn379vHWt76Vl73sZXzrW9/i2LFjPPnkk1iWRXt7e2l/xklmwgKAMhM8dao3FwBcMq5PxJYAIIQQYnhLly5l48aNAKxfv57rr78epRQbN27k+PHjnDp1ih//+MeF/aurqwEIhULcdNNNQLCscDgcxrbtwvMA7rnnHt797ndjWcGpsaam5iL+ZBffxAUAu5MnTnbwqsvmyIyAQggxlYzim/p4CYf71o8xDKNw3zAMXNfFMAyUUoOeZ9t2YftQz4NgldqhnjtdTeAoAM2BtoNkXR9Xu7l6AKkDEEIIMXY33HADX/va1wr3810Ao3
3ubbfdVggE070LYEKHAXqhE+w/k0Rrjeu70goghBDignz0ox+lo6ODDRs2sHnzZu67775RP/ed73wnixYtYtOmTWzevJkf/ehH43ikE2/ClgOuWF6h6973Il664O9569XzKLPLqY6UUxmzJ+R4hBBCDE+WA55Y02o54IgZJVR2iidPdgPBhEAZz5uowxFCCCFmlAkMAGF8o5uj7afpSbu4voPW4Mi0wEIIIcS4m8AAEAFARU6ypyGBp3087UkdgBBCCHERjCoAKKVuUko9q5Q6rJT60DD7bFdK7VZKPaOUeuBcrxk2bMJmhFDZKXaf7AVkXQAhhBDiYjnnPABKKRP4OvBCoAHYoZS6Q2u9r2ifKuAbwE1a65NKqTmjefPlFSs44pziqVwAcHwXx/Nn3FhMIYQQ4mIbTQvA5cBhrfVRrXUW+DHw8gH7vAH4hdb6JIDWunk0b74qvpys0UhDZw8tPdmgDgBkeWAhhBBinI0mANQBp4ruN+S2FVsFVCul7ldK7VJKvWU0b746vgyNhxlpYPfJXlzfRWuN48mEQEIIIfrr7OzkG9/4xri/z+tf/3o2bdrEv/3bv437e02k0QSAodriB56hLWAr8GLgRuBWpdSqQS+k1LuUUjuVUjvb2ztYGV8GQFllQ6EOQKYFFkIIMZThAoBXwiHkZ8+e5U9/+hNPP/0073vf+0b1nPzMgaVSyp9nJKMJAA3AwqL79cDpIfb5rdY6obVuBR4ENg98Ia31t7TWl2mtL6upqaHCjjM/Op+KqiAABDMCOoU6ACGEECLvQx/6EEeOHGHLli1s27aN6667jje84Q2FxYFe8YpXsHXrVtavX8+3vvWtwvPKy8v5yEc+wubNm7nyyitpamoC4Pbbby/MGHjNNdcAwXTAzc3NbNmyhYceeojdu3dz5ZVXsmnTJl75ylcWphbevn07//iP/8i1117LV77yFbZv38773vc+rrnmGtauXcuOHTu45ZZbWLlyJR/96EcLx/KDH/yAyy+/nC1btvCXf/mXhZN9eXk5H/vYx7jiiit45JFHLsrnOZrFgHYAK5VSS4FG4HUEff7Ffgl8TSllASHgCmDEtpN8jd+qiuU8lnmazqTDibYMK+eEAGR1QCGEmKS+8PgXONB+oKSvuaZmDR+8/IMj7vP5z3+evXv3snv3bu6//35e/OIXs3fvXpYuXQrAd77zHWpqakilUmzbto1XvepV1NbWkkgkuPLKK/nsZz/LP/zDP/Dtb3+bj370o3zqU5/i7rvvpq6ujs7OTgDuuOMOXvKSl7B7924ANm3axFe/+lWuvfZaPvaxj/HJT36SL3/5y0DQIvHAA8Ggt1/96leEQiEefPBBvvKVr/Dyl7+cXbt2UVNTw/Lly3nf+95Hc3MzP/nJT/jjH/+Ibdu85z3v4Yc//CFvectbSCQSbNiwgU996lMl/VxHcs4WAK21C7wXuBvYD/xUa/2MUurdSql35/bZD/wWeBp4HPgPrfXekV63EADiy0n73Si7nd0ne/B00JQihYBCCCFGcvnllxdO/gD//u//XviWf+rUKQ4dOgQESwG/5CUvAWDr1q2F5X+vvvpq3va2t/Htb397yGb3rq4uOjs7ufbaawF461vfyoMPPlh4/LWvfW2//V/2spcBwXLD69evZ/78+YTDYZYtW8apU6e499572bVrF9u2bWPLli3ce++9HD16FADTNHnVq15Vok9mdEa1HLDW+i7grgHbbhtw/4vAF0f7xgqFAlZVrABg1qxGdp9cwisunY3nu2RdGQYohBCT0bm+qV8sZWVlhdv3338/99xzD4888gixWIzt27eTTqeB/ksBm6ZZ6LO/7bbbeOyxx7jzzjvZsmVL4Vv/WN4f6LfE8MBli103KHJ/61vfyuc+97lBrxWJRDDNi9vqPWEzASo0lmmwMFZH1IxSXXOavQ0JXE/j+C6eL8sDCyGE6BOPx+np6Rnysa6uLqqrq4nFYhw4cIBHH330nK935M
gRrrjiCj71qU8xa9YsTp061e/xyspKqqureeihhwD4/ve/X2gNGIvrr7+en/3sZzQ3ByPl29vbOXHixJhf70KNqgVgPJhtR7ANcJTBivgyziSPkXJ8Dp5NcumiKBAh6/pEQ1IHIIQQAmpra7n66qvZsGED0WiUuXPnFh676aabuO2229i0aROrV6/myiuvPOfrfeADH+DQoUNorbn++uvZvHnzoBPy9773Pd797neTTCZZtmwZ3/3ud8d8/OvWreMzn/kMN9xwA77vY9s2X//611m8ePGYX/NCTNhywJctMPWdv7ud9thifnL8F/zi1K/oefYTvOHyRbz56gVUh2uI2CaVUVkeWAghJposBzyxptVywABlDY+hVFAHoNEsmt8cFAL6Hr72ZT4AIYQQYpxMXACwo4RPPIJtGKyMLwegdtZpDpxNksx6uL6DrzWujAYQQgghSm7iAkA4jnV2DyGvl3K7jLroAnz7GJ4PexsSOL4DyHBAIYQQYjxMXAAIxVHap6xxJxBMCNSUOUzIpLAuACDdAEIIMUnIDK0TY7w+9wkMAFF0pIroqUcxlGJ1xQp63V5W1CeCAKCDMZPSAiCEEBMvEonQ1tYmIeAi01rT1tZGJBIp+WtP2DBAUKglzyN8/CGsqzWr4sGEQHNmN3L/iTLaerNUhFxsbZN1fULWhNYrCiHEjFZfX09DQwMtLS0TfSgzTiQSob6+vuSvO4EBAFh6DcaBX1HecZAF8RWUWTFU+BiwiqdO9rKwqgrbsMl6EgCEEGIi2bbdb9pdMfVN7Fl1yXPRyqDs1GMYuQmBmtKHiEdMnjrVWygEdKQOQAghhCipiQ0A0Wr0vE2ETz6CaShWx1fSkGxk/UKD3Sd7cbwsgCwPLIQQQpTYhLerq6XXYrccIJztYFXFcjSaBbPP0NLj0NCRwdMeGhkOKIQQQpTSxAeAZcHCCvHTO1gRX45CYUaOAPDkyR7c/HwA0g0ghBBClMyEBwBmr8Uvm0Xs1KOUWVEWxuo4kz7MnAo76AaQ+QCEEEKIkpvAAKByVwp/yfMIndqBqXxWVazgUPdhNtdHePpULxk3qANwfY0vywMLIYQQJTFxAUCpvptLr8XI9lLe+gyrKpaT9FIsmdtOIuNz4GxPoQBQ6gCEEEKI0pjAAND31ubi56INi/KGxwoTAtmRowC5boCgDiAj3QBCCCFESUxsDYBpB9fhcrwFW4ieepQF0XmUW2U0pA6zdFaE3VIIKIQQQpTcBAeAUOGmv+R5WO1HiaVbWVWxgoPdR9iyMMK+00l6MplgH63xpA5ACCGEuGATGwCscOGmWrYdgPjpx1gVX05j6jRr5mZxPc3TjV2F/aQVQAghhLhwE98CkCsGtGpW4sUXEGt4jFUVKwEIR45hGYonT/TI8sBCCCFECU1sAFCq0A2gDANnyXMIN+5iVVk9CsXx5BHWzo/0LwT0vIk8YiGEEGJamPiJgKy+NY79Jdeg3DSz255lcdlCDvYcZnNdmKPNKdp6kwBoHawNIIQQQoixmwQBoK8OwFh0BdoMBXUAFSs41HOULfMMNLDrZGdhP+kGEEIIIS7MxAcAwwTDAsCyy8nWbSV68lFWxZeT9tJEy84SDRk8caIbTwfN/xIAhBBCiAsz8QEACq0AlmHiLHkOVncjG81yAI70HmFzXZTdJ3sL8wHI8sBCCCHEhZlUAQDAX/JcAFa2HqLCjvNs92G21Nmc7cpysr0XAA04ngQAIYQQYqwmRwAoGg5oVi7GrV5CvPHxXB3AYS6ZH3QR7DreWXiKrAsghBBCjN3kCABFwwFtwyKz6CrCZ3azOraYM6kmKsuS1JaZ7DrZ1bcwkNQBCCGEEGM2OQIAFLoBbMMmu/gqlO9yiRMU/R3qOcKW+mBdgKzUAQghhBAXbBIFgGA+AKUU/oIt+HaMS9uOYSqTgz1HuGSBTXfK49mzfdMCy+qAQgghxNhMngBQNBzQtmJkF15OTeMOlpQt5FD3YbYsCFYO3H
G8vfAUqQMQQgghxmbyBAAAq68OILvoKsxEC2tDszncc5SqGCysttl1XBYGEkIIIS7UJAsAQTeAbdhkF10JwCWZDBk/y4nEKS6pC7G3sZdUNguA58vywEIIIcRYTK4AkBsOaBkmlM/GmbWaba3HADjYfZhLFthkXM3uho7CU2RdACGEEOL8Ta4AUDQc0FIW2cVXsrhpP9V2BQe7j7BxQQhD9a8DkEJAIYQQ4vxNrgAARcMBLTKLnoOhfdZZVRzsOUxZyGD1nBC7TnQWdpc6ACGEEOL8Tb4AYObXBbBw56zFj1RySTpFc7qFzmwXW+pCHDybpDsV1AH4WuNKN4AQQghxXiZhALDAsLANCwyT7MIruKylqA6gLoSvYceJ1sJTZDigEEIIcX4mXwAAsEIYysA0DDKLrmJDbzsWBgd7jrBmjk3YUjx+rGg+AOkGEEIIIc7L5AwA+W4AZZFddAUhDFaaZRzsPoxtKjbOD7HrRNF8ANICIIQQQpyXyRkArDAohW1Y6Eglztz1XJJKcqT3GK7vckldiFPtaZq60wBoLa0AQgghxPmYnAEgNxzQNoLpf7OLruLSjrM4vsPxxEm21AVDBR871lcHIPMBCCGEEKM3OQMA5AKAhVKQXXwVmzNB1f/B7sMsqbGojBhSByCEEEKM0eQNALlpgS1l4dauZHa4irna5GD3YQyl2FIXYteJjsKSwLI8sBBCCDF6kzcAmMEwQNuwQCmyi65icyrJwZ7DAFxSF6I94XKsNQGARooBhRBCiNEaVQBQSt2klHpWKXVYKfWhIR7frpTqUkrtzl0+VpKjs8KFOoDM4uewJZWkNdNOe6aDLXXBSIHHjrUVdpduACGEEGJ0zhkAlFIm8HXgZmAd8Hql1Lohdn1Ia70ld/lUSY7ODActAIBTfxmbsi4Q1AHMjZssqDB5XAKAEEIIcd5G0wJwOXBYa31Ua50Ffgy8fHwPK8cKYxgmhjLQoTJW1K4lpCl0A2ypC7H7VFdhKmDX1/iyPLAQQghxTqMJAHXAqaL7DbltA12llHpKKfUbpdT6khxdYThg0ArgL76a9Zk0hzr3A3BJXZhk1mffme7CU6QOQAghhDi30QQANcS2gV+znwAWa603A18F/nfIF1LqXUqpnUqpnS0tLaM7wqIAkF10FZvTWY4mGnB8h80LQijoPxxQAoAQQghxTqMJAA3AwqL79cDp4h201t1a697c7bsAWyk1a+ALaa2/pbW+TGt92ezZs0d3hFakEAC8qkVsMspw8DnWe4J4xGDFLIvHiyYEkjoAIYQQ4txGEwB2ACuVUkuVUiHgdcAdxTsopeYppVTu9uW5120b9EpjYVpYZhilAKVYM3crAAe7ngVgS12Yvad7SOYKBD1f40kdgBBCCDGicwYArbULvBe4G9gP/FRr/YxS6t1KqXfndns1sFcp9RTw78DrdAln5VFWGEsFrQDxJddQ57gcaXkCCOYD8Hx48mRnYX9pBRBCCCFGZo1mp1yz/l0Dtt1WdPtrwNdKe2hFrAiWYeH4Ltm6S9m002FH4iRaa9bNC2GbsPN4O1evCHodsq5PNGSO2+EIIYQQU93knQmwmBXGNu3C7Q3R+bTi0JZuI2wp1s8L9asDyHieTAsshBBCjGBqBAClsK1o4e7aeZcBcLT5cSDoBjjckqKtNwMEywP3ZNyLf5xCCCHEFDE1AgBg2jEMFRzuwmU3EvF9juQCQH5a4J3H+4YDprIeGde7+AcqhBBCTAFTJgAE6wIE/fpGVT3rfYMDyQYAltdaxMOKHcf6DzzoTrnSFSCEEEIMYeoEANPGzi0RDLA+Vsch5ZDNdGEaik0LQjx+vL3fCd/Xmu60dAUIIYQQA02dAADYVqxwe838bbhKcebE74FgWuCmHodTHal+z0k7HmlHugKEEEKIYlMqAFihssLt1UteAMCRlp1AsDAQwI6jrYOe1512ZJEgIYQQosiUCgDKimDl6gCqorUsxOLZZANozYIKkznl5q
A6AAhGBXSnnYt9uEIIIcSkNaUCAIaBbfd1A6wvW8jTliLUfgSlFFvqQuw62TXkVMAZ15euACGEECJnagUAwLb7ugHWzNtGm2XSc+p+IJgPoDvjcfBs15DP7U47sk6AEEIIwRQMAMV1AOvmXgrA0ZZdAGxeENQBPH506KWGtYbulHQFCCGEEFMvANhRjFwdwJL4QmKY7M+2YGS6qY6ZLKu12XG8Y9jnZz2/sHKgEEIIMVNNuQAAfd0ApjJYU76Qp8Ihyk/3jQZ4qrGXdHb4b/q9aRfXkxUDhRBCzFxTMgBYob5CwHVzNnMwZGM0PALAlgUhsp5mz9HGoM1/CBpkgiAhhBAz2pQMALZdXri9rno1nlKcaH0SfI8N820sAx4/0YnKdA/7Go7nk5AFg4QQQsxQUzMAWGHILQ+8tnIlAHsNl0jbQaK2wWULo/z86Q4a27rASQz7OomMiyNdAUIIIWagKRkAlFKYdrAuQEWonIXRuTwZCVPe+BgAf/3cCpSCj93dgJ/qBjcz5OtoglEBsmCQEEKImWZKBgDo3w2wvmYNT0djlDUGdQCzyhT/cN1c9pxN8R+Pt2BkusAfurnf9TW90hUghBBihpm6ASAUA6UAWFe1kk6laek6hplqB+CKJYqXrq3iuztaeKKhB5XuAD10c38y65F1pStACCHEzDF1A4BhgxVM/LOuahUAT0XClDc+DoCvff7q6grqK0PcencDPcksKt057Ot1p6UrQAghxMwxZQOAZVgYVlAHsLi8jjIryhNlFZQ3PlrYRxlZPv7CebQlXf7pD6fBzaAyPUO+nudreqQrQAghxAwxZQMAgJWbEMhQBmsrV/JUWZyy07v69fcvrPX4qyvncO/hbu7Y14lyEuAkh3y9VNYj48qCQUIIIaa/KR0AguGAFhDUARzVGVJukljzM4V9PN/jls0xti0s418eOMPxjgxGpge87JCv2Z1y8WXBICGEENPclA4AlmEV1QGsxEezJxrt1w0AkPHS3Hr9PMKWwUd/20DW9TDSneAP/rbva02PzBIohBBimpvSAcA2bDDCAKytWoFCsat2EWW5+QCKRSMZbr1+Ac+2pPnmo82g/dzIgMHf9tOuR9qRrgAhhBDT15QOAIYyggmBlKLcLmNxeR1PxcqJdB7H6j3bb1/Xd9m22OLVG6v5wRNtPHqiF+W7w44M6E470hUghBBi2prSAQDANu1+3QDPeD34UJgVsFjKS/Peq2eztCbMx3/fSHvSRXkZVLZ30L5aByFACCGEmI6mfACwlAVGEADWV62ix0txqGoB1QfvRA0o9NNa45LkszfV05vx+PQ9jWitgwDgpAa9dsb1SWWlK0AIIcT0M+UDQNACENQBrKsKFgZ6bPXziXQcZs7Obw7a3/Fd6qs0/99z5/Lw8V5ufzqYOdDIdIM3+Bt/T9rBk64AIYQQ08zUDwCGjTJMMC3qy+YTt8vYYxu0rXsNNc/+koqj9w56TtJL8aoNlTx3STlfebiJw61pQGMMMV1wfsEgIYQQYjqZ8gEAcsMBTbswIdD+rkN0XfEuknM2Mv/RLxHqPN5vf601CS/Bx15QRzxs8pHfNpB2/WFHBmQ9n2RWhgYKIYSYPqZFALANG8xgWuD11Ss50duINj3OXvdxfCtK/QOfRA3o43d8l0jY5RMvrONoe4Z/f7gJAOU5qEz3oPfoTbu4niwYJIQQYnqYHgHAtMGyQanCwkAHOo8QrZ7H6Ws+Qqj7FPMf+dKgb/ZJN8m2RVHeeEkttz/dzgNHgxO/clOobKLfvhrokq4AIYQQ08T0CACGHdywQqypXI6BYl/nIUylUEsup2Xz26g8/geqDt7R73laa3qdBO+5ag6rZ0f49D2naekNTvIq2wNuut/+rq/plQWDhBBCTAPTIgAYysBQBhghYlaUFRVLuPfMH8l4WcKmSeLSN9Fbdzlzd3yDSOuBfs91fAdfZfnMTfVkXJ+P/74RP9dSYKS7+i0sBJDIuDjSFSCEEG
KKmxYBAHKtALkJgd656vU0Js/yX4dvB6AsHKbpmo/gRaupe+CTwZC/Ikk3xaIqm7+/Zj47TiX44RNtuUd0riiw/wm/M+mQdSUECCGEmLqmVwDIDQfcOmsjL66/np8du5N9nYcAiFXW0njtx7FTbSx4+PP9Tuq+9ul1E7x8fRXPX17B1x9pYl9TUDSofG/QdMG+1nQks/RmXPQQawkIIYQQk930CgAAZnD9l2vewKxIDV/ccxtZL4upFNbCjTRd9h7ijY9Su/fH/Z6f9bJk/Sz/eP18amMWt97dQDI3C6DyskOODEhkXNoTWRkdIIQQYsqZNgHAMiwUCsxgVsAyK8bfrf8LTiZO89+Hfw5AyDBJb7yFriXXMXv3d4idfbLfayTcJPGwwadvrOdUZ5YvPdi3oJBykuD0HxkAQWFgeyJLQooDhRBCTCHTJgAopQoTAqGCbdtmb+bm+uv4ybFfcaDrCACxkE3rcz9AtqKeugc/g5VsLbyGr30SbpJL68p4+7ZZ3LGvk98f7Co8bmR6wc0Mem8N9OZaA2TaYCGEEFPBtAkAkJsRUCkwQ4Vt7179Jmoj1UFXgB8M8Ssrr+D0tZ/AcFPUPfhp8PsW/Ml4GbJelr+4fA4b50X5pz+c5kx3flEhjZHp6rd/McfzaevNyAJCQgghJr1pFQD66gDChW3ldoz3rf8Ljvc28IPDvwDAUIrQ/JWcufLviDXvYfaT/9nvdXrdBIah+dSN9WgNt97diJv/Zq99jFRbUBPgJIMFhIoKATXBMsKdySy+tAYIIYSYpKZnALBC/bZfMXsLN9Zdy/8cu4ODXUdz+xq4a26kY9VLmfXMjyk/+cfC/r72SbpJ6itDfPC6+Tx1Jsl3d7T0vaD2UU4SI9ONkWrDSDSjkq2odFdQJ+BlyTgerYkMaUdaA4QQQkw+0yoAmIaZmxDIhHCs32N/teZNVIcq+Oc9t+HkJveJ2hbtV7yXVO0qFvzx89g9pwv7p70MWd/h5jVV3Ly6kv94vIWnTieHeWeN8l2Um8LI9GCk2jESTZBoobujha7ODvxsGnwZLSCEEGJymFYBAIpaAULlhSGBAHG7nL9d/06O9Z7ih0f+X2F7WVkZZ679BChF/QOfQHnZwmO9Ti9aa/5h+3zmxW1uvbuB3szov9Er30O5abLJLjramsh0nYbeZki2Q6YnKCiUUCCEEGICTN8AABCpCIoCc54zZysvXPA8fnT0lxzuPg4E9QDh2Qs5c/WHiLQfZu7jXyvsnx8VUB42+cyN9TT3OnzuvtNjmvzH15rulEtPKo1205DpDYJAb1P/UOBIS4EQQojxN70DgGFCON7v8feseQuVoTj/vOc23FxXgG0YeMufR+uG11N96NdUHvldYf+0l8bxHTbOj/GXV87hdwe7ufNAF2OVdnw6kg7Z4smDfC9oDcj0QqoDEs2Q7h60eqEQQghRKqMKAEqpm5RSzyqlDiulPjTCftuUUp5S6tWlO8TzYxlW/w12BOy+UQEVoXL+dt07ONJzgh8d/WVhe9S26Nr6DhJzNzPv0X8j3HG08Fivm0BrzVu2zuLSuhj/fP8ZDrb0XynwfHi+pivlkMi6aIY4yWsN2QQkWsBJjfl9hBBCiOGcMwAopUzg68DNwDrg9UqpdcPs9wXg7lIf5PkoTAhULFwBRt+PevXcy7h+/tX84Mj/40jPicL2smiEM9d+DD9URt0Dn8BwgqI/z/dIeilMQ/HJG+qJWoq3/eQo393RguuN/Vt6MuvRkXRwh2vy9z1IdUKiLRhuKIQQQpTIaFoALgcOa62Paq2zwI+Blw+x398APweaS3h8Y9KvGwCCOoBwZb9Nf732rcTtMr645/8WugIMpYhVz6HxebcS6jnN/D/9S6EZPuWmcH2XeXGbH71xBdcui/ONR5p520+PXnBrQGfSIZkdYSphLwuJ1iAMSH2AEEKIEhhNAKgDThXdb8htK1BK1QGvBG4r3aGN3aAAAGDZ/YYGVobi/O26d3
Co+xg/Ofarvt0MA7X4UloueQcVJ+6n+kDfiIF8V0BtzOJzL1rIF160kJZeh7f85Aj/99FmnDEuCqSBRNajI+XgjdTv76SCboHs4DUJhBBCiPMxmgCghtg28Cz1ZeCDWusRx8gppd6llNqplNrZ0tIy0q4XZFAXQN6AoYHPm3c52+ddxX8f/jnHevoyTsSy6Nn0enrqr2LurtuItOwDwPVdUl5fn/zzV1Twkzet4IaVwTwBb/7xUZ5pGnufvev5dCSypJxhagMgWMY43Q29LUOuSyCEEEKMxmgCQAOwsOh+PXB6wD6XAT9WSh0HXg18Qyn1ioEvpLX+ltb6Mq31ZbNnzx7bEY+CZViEi6YD7mfA0MC/Wfc2yu0yvrjnNryiOf7LoyGanvdhnNgs6h/4FGY6qPxPeel++1VFLT51Yz3/+tJFdKc9/vynR/nqH5vIuGNvDejNeLQlsvRkRqoPcIOhg8n2YdcmEEIIIYYzmgCwA1iplFqqlAoBrwPuKN5Ba71Ua71Ea70E+BnwHq31/5b6YM9HPBQPZgUcaMDQwKpQBf/furfzbPdRfnr814XtCkWsoprGaz+Bme5gwcP/BNpHa02P24uv+5+Yn7c0zk/etIKXrq3iv3e18sb/OcJTZ4abOfDctO4bMtiZcki7w5zk3UzQLZDpkWGDQgghRu2cAUBr7QLvJaju3w/8VGv9jFLq3Uqpd4/3AY6VoQziofjQDw4YGnjtvCu5Zu4VfO/Qzzje21DYbhkG5oJ1NF3+XspP72DWnh8CQVdAl9ONN6DHIx42+egL6vjqyxeTcTV/cfsx/vXBM6ScCyvcczyfnrRLWyJLMuviDzzRax3MISDDBoUQQoySGsusdqVw2WWX6Z07d477+3Rlush4Q/SVaw3JtkJVfUemiz9/+P0siM3l36/4JKZhFnbtSWWZ9eBnqTh6Lydf8M8kF2wFgpBRYceHrDlIZD2+9scmfrang7oKm1tfUMfW+rKS/EwKCFkG0ZCJbQyR4cxQ0NVhDlEMKYQQYlpQSu3SWl821udPu5kABxq2K2DA0MDqcCV/s+7tHOg6ws9O3NVv1/KoTfNVf0emajF1D30GKxEUMPrap8vpJusPHqNfFjL54HULuO2WJSilePcvjvP5+06TyF54f70GMq5PZ9KhI9c90K9oMD9sMN0lwwaFEEIMadoHgBG7AgYMDbxu3lVcPecyvnvodk72Nha2KxRl5XFOX/sJDC9D3YOfCorwIKgJcHqGbmUAttaX8T9vWM4bttTyiz0dvO6HR3j0RG/Jfj63qHsgkXX7DyPMJmXYoBBCiCFN+wAAEDbDRKzI0A8WDQ1USvG3699BxAzxxb3/F6+o0M8yDKw5yzhz1fuJtTzD8v99KzV7f4yZ7sqFgF5S7tD97xHb4H3XzOM//mwpEUvxN788wafvaaTnPFYWPBetg5kF2xNZulNFaw3khw0mWmXYoBBCiIIZEQAA4vYwXQHQb2hgTbiK9659G/s6D/GL47/pv5tlkl75Ak5t/xRO2WzmPvEtVvzsNcx/+HNEW54h4STodYb/tr1pfowfvH45b906izv3d/LaHxzmoWM9JfsZ8zKeT1fKoT1ZNKeA5wRDBlMd0i0ghBBi+hcBFst6WToznUM/6KSDb8oEzfq3PvEv7Grbw7ev/gL1ZfMLu2mChXw8XxPuOEbVwTuoPPp7TCdJunoFHatfRnrFzZRFZ6PUUHMoBfY1pfj0PY0cbstw0+pK/v6aeVRFh5nA6AIpBWHLIGqbWIYBhgWx2n7rIwghhJhaLrQIcEYFAIDubDdpd5i5+9PdQRAAWtPtvOPhf2BJeT3/esXHMItaDzytSWTcwtS/hpOk4ti9VD/7SyIdR/HsMnqW34i74Q3ompXDHovj+Xx3Ryvf2dlCZTgoGnz+iorS/bBDsE2DiG0QDoVRZbP6TYokhBBi6pAAcJ601rSl2wZN5JN7sN/Mer9rfJAv7Pkm71nzFl615OZBu2c8j0
TG7Zt/R2uiLc9QffAO4scfwPAdsvO3kl73GjJLrh92WN7BljSfvqeRAy1prl9RwQe2z6c2Nj6tAXlKQTgcIVIxG9syz/0EIYQQk4oEgDEYsSvAdYJ+coKw8JEn/pndbfv49tVfoK5s3qDdfa1JZF2yA6b+NdOdVB7+LdUHf0Wo9wx+tJbUmltIr30Vfvn8Qa/j+prv72rl24+1ELUNbtlYzS0bqplfEbrgn3ck2gxhxGqIhiwitolpSIuAEEJMBRIAxqgn2zNs1T7ZXsgE0/i2pNt5x8MfYHl8EV+6/NZhCwkd3yeRcfH8gbP0+cRP72LWoV8TPvkwKEV20TWk1r0Gp/4qGPB6x9ozfP1PTYXiwKuXxPmzTTVcsagMY5ya67UZRkerAQiZwQRDYcsYsYZBCCHExJIAMEYjdgUAJDuCynngtw3388W9/5e/Wft2XrH4huFfE03K8UhnvUFr+SmlqEr1ED94B9Fnf4GRaserWEhq7atJr345OlLdb/+zPVl+saeD/32mg46UR31liFdtrOal66qojJS+e0BbEXSkqu94gbBtBvUC0kUghBCTjgSAC+B4Dh2ZjqEf9L2gHkBrtNZ8eNfn2dPxLN+++gssiM0d8XVd3yeZ9QpFgsXK7DKimISP3UNk308JnX0CbYbILLuB1LrX4s7Z2K8wL+v63Hekh5/taWf36SRhU3HD6kpevbGGdXOjF/TzD6TtGDo8uAjRUIqInRtFYMrIASGEmAwkAFygEbsCioYGNqfaeMfDH2BV5VK+uO0jw88pUCTteiSz7qBF+qJWlDIrmIHQbD9EdN/thA/9GsNJ4NSuJr3utaRXvWxQ0eDBljQ/39POb57tIuX4rJ8b5dWbanjBygoiVmlOzNouQ4eHmTmRvlEEEcvEkHoBIYSYMBIALpDWmvZ0+6CV/QqKhgbeeeoP/Osz3+Zv1/05L130wlG9vp8bMpgd0BoQMcOU2+WF+yqbIHz4LqL7forVfhC3chGJK/+e7KJrBw3V68143Hmgk5893c7xjiyVEZOXravilo011FdeeNGgDpWjQ+Uj7pNfkChiS72AEEJMBAkAJTBiV0DR0ECtNR/c+Tn2dR7iW1d//pxdAcWyvkcy4/UrEgyZIeJWef+Tp9aETj1E2SNfwuo6TrbuSnqvej/eEPMJaK3Z2ZDgZ0938MDRbnwNVy0u58821XDV4vILquj3w3GwR7d6oVIQsU2itoktXQRCCHFRSAAokd5sL0k3OfSDRUMDm1It/MUfP0TIsPnkJX/H+upVo36PoYoELcOiYqhpin2H6L6fEtt1GyrbS3rtq0ls/St0tGbI127udfh/ezv4f3s7aEu6LKiwuWVDNS9fXz3mGQb9cAXYsXPvWMQ0FGHLIGQZhExpGRBCiPEiAaBEztkVUDQ08HhvA7c+8S+0pNr4uw1/wQ1115zXe7m+TyLr4ea6BUzDpMKOY6rB1fYq3UnZrm8S2Xc72o6RvPRdpNa/fthJhVxPc//Rbm5/up0nGpOETMULVlbw6o01bJgXPe8Tsh+uBHtsxYaKoGYglAsE0joghBClIwGghEbsCoB+QwO7sj18aveX2d2+j9csfQnvXPX6ftMFj0ZxkaChDCrsOJYx9Ld1s+MI5Y/8C6GGP41YH1DsSFuan+/p4K79nSQcnzWzI7x6Uw03rqokYo/2WBV+pBKGW03xPCgFYdMsBAKZdEgIIcZOAkCJjdgVUDQ0EMD1Xb6+/3vcceoerph9CR/Z/N5Cdf9oFRcJKqWI23FCxtDf7gFCJx+m7NF/weo8NmJ9QLFE1uM3B7q4/el2jrZniNkGz11azvOXV/CcJXGi5wwDCj9SBVb4vH62czENVegqkEJCIYQ4PxIASuycXQFFQwPzfnnyd3xt//dYWDafz1z6gfMqDszL+h6JdFAbEDOjhM3w8EMNfYfIvtsp2/XNoD5gzatIXPaeYesDin+2J08nuetAJw8c6aEz7RG2FFctDsLA85bGKQ
8PN+mPwo9Wgzk+UxMrwDL7agdCJRrWKIQQ05UEgHHg+A4d6RG6AoqGBuY92baXT+7+MqD4xJb3saV23Xm/r0aTzLqknaA1wDZsIkYY27CH/Has0p3Edt1GdN9Pc/UBf0Fq/RuGrQ8o5vqa3Y1J/nCkm/uOdNOacLENxRWLyrhuRQXXLI0PLh5UBn6kelSvf6GUohAEQqYhExAJIcQAEgDGyYhdAQDZRHAp+vgaE2f56BNfpDHZxN+sfRsvXfSCMb236/tkXI+M6xfqA0JGiLAZwh6ie8DsOErZo18ifOph3IpcfcDikesDivlas+dsij8c7ua+w92c6XEwFWytL+P5KyrYvryib3VCZeBHa2CYWoXxYqiguyDoKujbrui703970e0Bn0P/x4bfTwghJjMJAONEa01HpgPXd4ffyfcg0wNutrCp10ny2ae+yuOtu3nFoht4z5q3YBpjm0tfo8m4PhnX7zdiIGyECBvhQa/bvz7gCnqv+sA56wMGvafW7G9OB2HgSDcnO7MoYMuCGNetqOC65RXMqwjjR2thjD/XZFbcFWGbSoYyCiEmLQkA4+icXQGFHdNBEMh9lp72+fazP+L243dyae0Gbt38f6g4x8x65+L6PinHw3H9QqODbViEzDBhI9RXL+A7RPb9LFcf0EN6zS25+oDa835PrTVH2jL84XA3fzjSzZG2DADr50Z5/spKrtuwhLqa0U0WNJVZhsLOdUXYpoxeEEJMDhIAxlnCSZBwEufe0feDuQKKagN+23A///bMfzA3OpvPXPp+FpXXXfDx+FqT9TzSjl+YVTBfLxA2QoSMEEopVLqL2BO3EX3mJ2grSnLru0ZdHzCcEx2ZQsvA/ubg51w1t5zrVs/hujVzWDpr+ocByHVHFLUSSH2CEGIiSAAYZ6PqCijmOpDpDroHgL0dz/LxJ/8Vx3f56Oa/4fLZW0p2bI7vk86tOpj/rxjUC9iEzDAhw87VB/wr4VMP5eoD/o7s4u2jrg8YzunuLPcd7uYPR3p4+kxQK7GkNsZ1a+ZwzcrZrJ4XnzHflPMFi3bhoqTbQAgx7iQAXASO79CZ7kQzys9Ka3ASkE2CDqYPvvWJL3Gs5yTvXvMmbll8c0lPEL7WZNygVcAv+u9pKIOwGSZihIk0Pkr5I/+C1XmUbP1z6L7us+ccNjhazSm474TDfc+28OTJDnwNFRGLrYuruXxpDZcvraG++vzmR5jKpI5ACHExSAC4SEbdFVDMc4PaAM8h5ab5/J5v8HDTDm6q287/Wf/nI074M1YZzyPr+INWH7QMizAGVc/+kvjjX8GPzqLrxi/j1a4uyftqM4SOVNORdHj8eDuPHwsuzT1B3cCCqgiXLwnCwGWLa6iMjf9QwskkX0dgG8EoBqWCEQyGCrpw8teTmu+DId0dQkwWEgAuovZ0++i7AoplU5Dtxfc9vnf45/zgyC/YUL2aT2x5H9XhytIfKOBpTdpxC0MJi1W0H2XeHz6Mke2he/tnyC4b3dLG56KtCDpS1Xdfa062J4MwcLydXSc6SGQ8FLB6XjxoHVhSw6aFlYSt6Tei4Hyp3P8ZKhjcaCiVCwvB9eDtQ+870Gj+jQ+7i9bgpsFJgZfBsMIQjoM1PhNCCSFGTwLAReT6Lh3pjtF3BRTzfcj2gJPhD2f+xBf33EZ1qJJPb30/y+OLS3+wORpN1vNJO31DCQGsVDsL7/8EkZa9JC59F8mtfwXnuZbBkO9nRdGh8iGHCLq+z/7TPTx2rI0dxzvY09iF52vClsHmhVWFQLBybjnGZP82PN15DspNodw06L7fG6XAMgzMUBg7WoFlh6UIUogJIgHgIks6SXqd3rG/gJuFTDfPdhzi1ie+RMJN8uFNf81z524r3UEO99a+T9r1yOZaBZSXZd5j/07V4bvILN5Oz3WfDU7epaIMtGGCCi7aMHK3g+uEo3nyVCePH2tnx7F2jrYGXSxVUZvLllRzxdJati2tZn7l2FYjFOdJ++CmUE
4KNcqWLm2GIVyOZYexTYVtGliGjIwQF4HWweJsXjb4wmFFLri4eaqRADABOtIdOL4z9hfQGrK9tHY38rEn/5Vnu47w5ytfyxuWvfyi9QM7vo/j+TiuT/yZnzN35zdwKxfTfdNX8SsWXpRjgKAdOx8SWhIej5/q4fET3ew40UVrbzDB0sKaaKF+YOviauKRmVU/MO7cTO7bfgbG0rpFEAR0OF6YITJfCGmZCtuQ4ZKiBHw/ONl72eDE7zv9+66UCkKAHS35wmWTlQSACeD5Hl3ZrrHVA/R7IZdMsoUvPfVN7j3zR54//zm8f8NfEh6nBXeG42uNcWoHtfd8HLSm4/mfQy963kU9hoG01hxtz/D4qQSPnUrwREOClBM0Rdum6uvzZnAfuSpcD+hHZ3D/ORQ/D8KWyep5cTbVV7KpvpJFNbHJX5w3Fr4XnPSdZL8m/gulrUiuC2jwVNEDQ4GVazG44PfUGq2D6OIXbuvCuUFr+o2Ogb4visNNJT3UtqH2LX6Kacjwz5LKf7svnPCHWaBtKMoAOwJWdFrXq0gAmCBaa3qcHtJu+tw7n+u1Mgn+58AP+c+DP2F15XI+dcnfMStSmiF658PobqTqNx/C7DhOzxXvJbvp7YAxxu+EpeV4PnvPpniiMUnS8YM/+oCvAQ0+oFH4AFrho3KPK7RSuZNC/vHciYK+k0f+BNGbcdl3upvudBDuqqI2G3NhYFN9FWvnx0cuWNR+7luJ7rutc0eXew9tmMEJUpkXt8kyV9Cn3BTKy557/wt5qxFqQYoVhwLLUIUTef6/r/b7TubF2/3chsnwu1nMUMHPYeZ+HtNQQc3EDJkTY8y07n+y97IjVKaeJ8PqCwPmxV3DZLxJAJhgKTdFb7Z3bIWBxXyPP564l3968svErCivXvIiblhwzbiNEhiOcpLE7/00kWMPklp9M9nrPoJlV+D7iqyrB32TmnKUgVaqXy1CsM0ITlbKwPc9TrQlebqhi6cau9nT2MPJjiDoWYZizdwYm+bH2LygjI3zo8yOWX0n/bEcj2GBYeW6QqzgOEq52NIwBX3jT6HtKNoum5brRpwPBX1hoF84mKGtBr434IR/AV2q58O0+7oJpsHvpASAScD1XboyXXj6PJqohnGk9QD/vvur7O04gKVMrp57GS+uv55Latf3zfc/3rRPbOd/Ub7zP3HmrKP75n8iUrGImBXF0xrH88m4GrdoBsLpriPp8vTZJE+fSfL0mRT7mlJkveCnr6uwc4Egxub5MZbWhEvwjU/liiaDQNAXDqzR/eEaQ0Hf+FFoO4YOlZVkpMl0U9xqYCqVaw2Z+q0GWmscT+P6Pp6TGfDt3h+w71AvMKpNQBCwDKOvG8Y0gm5Cc6RwZYWDMGBFpuz8FhIAJgmtNd3ZbjJephQvxvG2/dx17Df8rvFBepxe5kfn8KL667ipfjs14aoLf49RCB19gIp7P40Oxei66XP48zYQs6LErKAqPz/E0HF9sp4urE0wE2RdnwMtaZ4+k+SpM0mePp2kPRUEwLKQwcZ5MTbNj7J5QYz1c6OUhUr5baM4HFh9Iy0MC3z3ggv6xo9Ch8rQdkyCwCgEJ7XBXQmWoTAmWTgoPtk7XvDlwPV18PuY6R73Lqfh5GuCDENh0j8kGLmgYBoGmKFc8eDUGkkgAWCSSTpJEk7iwrsEAHyPbKqDhxoe4s6GP/BU+z5MZfKcOVt5Uf11bJ21CXOc/5CabUeo+s2HMBIt9Fz7D6TXvAjTMCi3ygYVK3pak3VzgcAfPAHRdKa1prHL6QsEZ5IcacugAUPBylkRNs2PsXp2hKqISWXUpDJiURkxqYiYWJPsD/q4UkbQImCXTak/tpNJYT4Go393wsUYaeH7Gsf3cT2N6wW3B4V/rVFOApVNMPmC6GBGLiQYClQohmlHMUORQkjI/5YWij8nye+tBIBJyPEdujJd+KXqb/VcyPZyqvM4dzXcx92ND9Dl9DA3Moubc60Cs8exaFClu6j83a
2EGneR3PQaeq/6azAsbMOi3I5hDzOlseP3tQ7MpO6CvJ6Mx96zKZ46k+Sp00meaUoVRjIMVB4yqIyYVEaDUDDwUjXE9qg9xdcYUAbazrcITOGfYxIpdavBqE72A3lZVKYLdT5V+5ORMoJRLVYkaCEYbrcBN/JxYeBIkfy/1YFhIhoyxzwTqgSAScrXPj3ZntJ0CeQ5Gcj2knXT/KlpJ3c2/IEn2vZioLhyzqW8uP75bJu9ZXxaBXyX8j99jdie28nWX0bXCz+NjlQAEDZDlFkxrBH6pmdyd0Ge62uaex260h5dKS+4Tru5a4/Owva+bYns8CHSNlSuJaHvErIG/Lcfsh918MaBfwaG+q+jgEXVYdbNibJuboRZZSWaj0EZ+KEysCQIjKfhWg3yTeK+H/wbdf0gsDveeRb9ah+V6UG5qfH7ISaINkywomir9MWDFRGb6Bi7CCUATHIl7RLIyyYhmwCtOZ1s4q5Tf+C3jQ/Qke1idqSGm+uCVoG50Vmle8+cyIE7iT/wRfzy2XTe/AW8mmVA8MclYkYos6KjKlb0dfDHJuuO4Q/NDOJ6mq6MR1eqLxT0XdyiIBEECNcb/DmO9pw6cL+BT3P9YBnofHabXWaxbm6UtXOirJ0TYd3cKFXRCxm9UDQxFEZulIZRNHLDGHwRJaHUBY66c1IY2Z6LPMpkYmjTDoKAFS1JYJUAMM05nkNXtoRdAlCYTRAnBToYifBI8xPc2XAvO1v3oIBtszfz4vrruXL2JZglTK3W2b1U3v2PKCdJ9/UfJ7u0b9KgoPI2eK+BTWH5U0r/CVWC236ufiDfzIgunoGlr+lMEby+qcyp3fw9RaUdn2db0uxrTrG/KcW+5hQnOvoKvBZU2KydE80Fgwhr50QpD4/jcKvicEDfUE6tFANDxLn/WF/g79NM/H30vaC5f4KK/CaWQlvhoIvAioz5VSQAzAC+9unOdJP1S/wPxfdyQaCvq+FssoXfNN7Hbxruoy3TSW24mpvqruVF9c9nXmx2Sd7W6G2h8u4PYzfvp3fbO0lufWtJv5G5+amKh6kfULmgYSkLy8hfT69JPqaK3ozHgeY0+5uDQLCvKcXp7r5x3YuqQqydG2VdrpVg9ewoUXt8v71nXJ+ejEdvxqc749GT8ci4OrfsMpi5GSEHXhsq6EM3oFAUZuRmjex/PXhbZcTCDlkDhm5O0KRPF4HKJlDZXqZCkd+4U0Yw+ZUVCeYaOA8SAGaQhJMg4SRK/8KeC5mefhNqeL7HY61P8utTf2BHy240cNmsjbx04Qu5as6lFz6vgJsh/sA/Ez34W9LLttPz/I8EBV0lptGFgkLXH35528GhwB6xLmHI9xpi2lhx/jpTLgeai1oKmlI0J4L5CAwFS2vCha6D9XOjrJgVIVxUv+D5mt6sR08mOJEXTuZpj56sR0862NaT9ftuZ4L9ezMemSG6QsaboaC+MsTSmjBLqsMsrQnnboeIhXItE0WTPWll9k36NJW6M7xsMLRvwueXmJy0YeXqBSKjqheQADDDZL0s3dnu0nYJ5OUKBQfOm92cauO3jfdxV8N9tKTbWVq+kLeseBXPnbvtwoKA1kSf/gnlj3wdt2ZZMF9AxYIL/CFKRykKrQO2YRXCwUjy9QnBWGZm5AiG8dCacNjfnGZfUypoLWhK0ZGbO8E0gpNnxtX0ZEYufoTgZBsPm7mLQXnYpCJsUh42iBduB4/l9wtbqjBttO9rfE3uogdcD96mdRBKNLlrHQx71bn9PR0UeB5vz3KsPcPJrgxFq28zL27nQkGoKBiEi2omVL8Wg+JJoCZN64H2UdneYP0IMSraDOXqBYafX0ACwAw0bl0CEPy1clKFQsFinva5/8wjfP/ILziVOM3S8oW8ecWreN4FBoHQyceo+P3HwTBIbPsLnNmrcWuWBpNrTDL5UGAbQTAYTSgIhkLluiT8mTmKodS01jT1OuxrCkLByc4MMdsgHjGJh4ITeEXEJB4y+m7nTuixST
4E0vU0p7qCMHC8I8PR9gzHc7czbt/vTk3UZElNX2vB0powS6vDzCqz+v98yghOJqYNRui8m5kv2BiL/LrTHvubU5zpdoJ/Nzr4t+P5waiYYFvweQ18zPM1ru677fka16dw2yt6rCxksqDCZl7cZkFFiHkVNgviNrVlVmHRr4ml+oYUDlipUALADNab7SXpjlOi9n1wEoVCwWL5IPCDI7/gZCEI3MLz5l4+5iBgdp6i8rcfxuo4BgSL73iV9bi1y3OXFXg1y/AqFky6Jk9Dqb6ug1GEAmklEGPha82ZbofjHRmOtfddjndk6Mn0nVzLQ0bQSpALBPnb8+N2MEVwfg0JM4Q2Q2DY49NK4Hu5mfzOPZy5N+P1KxDd35ymoWt0X3CC4YjBtWUoTEVhiGIwXDGYwS8/VXK/xxV0ZzxOdwdDbIvZhmJe3C4EguA6CAjz4zazy+0RJ+FyfRdPe4TNEi4vnK8XsKNgWBIAZrpx7RKAIQsF8zzt80CuReBk4jRLyut58/JbuGbeFWMLAtrH7D6N2X4Uq+0wVtsRrLYjmF0NufX5wLeieLXLcGtXBMGgJggIOhy/0J+0pM63+0BaCcRYaa1pS7oDQkGWo+0Z2pN9fe0hU7GoKsSS6jCLc/UFS6rDLKoKE42Gwci1EpihCw7ZIxX5DRwNsj83GiS/5/x4MBpk7dxgJMji6lDfyb3oJJ8/wZdKMutxpsfhbI/D6W6Hsz1ZznQ7waXHoS3Zv27BVDCn3GZ+LhDMrwgxL24xu1xRHfOpimnKQxHidnnJjrGYNizilbVEo2NrKb0oAUApdRPwFcAE/kNr/fkBj78c+DRB95oL/K3W+uGRXlMCQH+e79Gd7cbxx3FVLNcJgsAQK2952ueBs4/y/cO/4GSikSXl9bwpFwRKMrGQk8LqOF4UCoJrI9Pddwzlc4sCQRAOvKqFpV0Z7wIVhwJLmYXroZqj860Enq+HHc0wJbhp7KansFr24lWvIFt35aBmTDF+utMexzoynMi1FBzvyHKiI0NDV9+cDABzy+0gENSEWVwdZnFtlKWz4tTEYygrPPoJbAYU+WVdn0NtmULtxv6mFEfbM4X3nlVmsa7oZL92TpSa2OT5N1ss4/qc7ckHgmwhGOTvt/S6/f6NKqCmzKTMNjFV0EphFrU8FLdCWLnFnIr3sVT//c2iAGQZweiTaCzGn1+ziqrY8LMNDmfcA4BSygQOAi8EGoAdwOu11vuK9ikHElprrZTaBPxUa71mpNeVADCY1ppep5fUeM+k5aRzhYKDWxw87fPg2cf478M/52SikcXldbx5+atKFwSKaY2RbC20EuRDgdl5vDCNqDZs3JolhVDg1S7DK5+HH6vNrS438f17wTCyvtEH+daCoUKB6/uF4jJdXGjmB+k52KYndh0FN4Xd9DT2mZ3Yp3diN+9BFQVT346RXfg8skuvJ7PoeTAOIz/EuWVdn1NdWU50ZINgkAsIJzqyJIumnC4LGSypDroQFtfGWFxbzuJZFdTPimMXrx2gNV6qi6NNHbmTfZr9TSkOt2WChX2AqogZDOnMzfOwbk6U2eUXuR6hxLK+Q9pLk/WyOJ6mNeHR1OPR3BtctydB+1auJiGoUfD8vlY+Txffp18tg6eLthfVPBTXMQA8+IHrWFR7/v+OLkYAuAr4hNb6xtz9DwNorT83wv7f0VqvHel1JQAML+tlSTrJ8SkQzNM6qA/IJoccxpsPAt8/8nNO9DayqKyON6+4hWvnXTnuCxDhOZidJ/t1IVjthzETrf1/BCuMH63Fj9XgxWrwY7PwYzX4sdrcJXc7WgPmxf9GYhpGodhwpFAwFE1fOMhXnefDQT4s+KUKC/kT/ukdhM7sxGreg/JdtDJxZ63FWXAZzvxtOHPWY7XsI3zsXsIn7sNItaPNENn655BZej3ZRdeiI5UXeDDiQmmtaUm4hVBwoiNbaEHID8WEoPm7rirMkpoY1WU2R5p7ONiSLixzXR4ycvM39J3w58
XtSV18OVpaa9J+hrSXxjvHmgVhMzx+XQBaE62opbw8PqbP9WIEgFcDN2mt35m7/2bgCq31ewfs90rgc8Ac4MVa60dGel0JAOfm+i4pN0XaTZd2KuFivheMFnDSQz9cCAK/4HhvA4vKFvDm5bdw7fyrxj8IDKDSXVjtRzF6WzBS7RiJVsxUO0ayDSPRFmxLdw39c0SqCoHAKw4HhdtBeNCh8nFtVTCUgamM3GyGRt99wyzMoDgWvtaFYWnn5KRQZ3ZjNO5ANTyOOvt04YSv56zDr9+Grr8cf/4lEO77w6d1MHzO9TSO62KefZLQsXsJH7sXM3EWrUycBduCMLDkOvwSTTolSieR9QotBic6MhxvD263JV2W1YRZOzfK+tzJvr4yNLlP9r6L1X4Iq3kP2orizLsEP1434r9fz/dIeWkyfmbY+UQGGs8AAFBeWUu0bGz1TxcjAPwZcOOAAHC51vpvhtn/GuBjWusXDPHYu4B3ASxatGjriRMnxnrcM4qvfdJumqSbHL9CwSEmEhp4DA81Pc5/H/55IQi8afktbJ+AIDAizwmCQDIXDAqXdsyi20aybcjpS7VhgmH3jcnOT96Sm6ylb6y2iVZGYZx28fbi/fptN0y0GcaPVuNHq3LXfRcdqcAwQ6UPCE4KTj8JDY/Dqcfh7B7wneC45q6H+sth4eWw4NJ+J/xzKdQ4eD7+2Wcwj/ye0LF7sLpOoFG4czeTWXo9mSXPx6+oH9uxC5GjMt3YTU9jNe3Gbnoq6Joa0F3qRWfhztuCM+8SnLlbcGetBsMm4wXf9p0xTF400wPAeXUB5PY5BmzTWrcOt4+0AIxN2k2TclPjVyw4Qn0A5IPADr5/+Occ6z3FwrIFvGn5K7lu/nMmVxA4F61R2d7BQSHVCb6L0l4w5tn3gnoE7QW3c9fBbb9v+1D7aH/wdieNke4Kbg88JBQ6UjkgHPTd1tFqdKwWFa1BxWZhRioxDHNwQHCScHr3MCf8DcHJvv5yqLsEQqX9w+Z6Hl7rYTj0O8yj92K1HggOqXZNUDOw9Hq86uUlfU8xDWmN2XUCq+mp4GTftBur40jwkDJwa1bhzt2MM28LzpxNKDeFffbJ4NK0G7PnNAC+FSE1ay2p2etJztlIavZa/PP8nZ/pAcAiKAK8HmgkKAJ8g9b6maJ9VgBHckWAlwK/Aur1CC8uAeDCOL5Dyk2RcTOl7x44R30ABEHg4aYd/HdREHjj8lfw/HnPKenCQ9NSbtlUI9XR/5LsQKU6MNKd/bdneoZ+GcPqHxAiVZg9Z7Ga9xWa9L05a/HrtgbN+gsuwQjHMZRCYeTmsB/f0KY7T+If+j0c/j3mmd0AuFVLySy5nuzS63FnrS1Zl4simA5aDTFnf377iMc64j+jkf+N5aeQzk9UI9NJnyc3jd2yr+/bfdNTGOkOAPxQHGfu5uCEP3czzpyNIxaeOr6D092A1fQE0aa9RJv3EOk4gtI+GkWmehnJORtIzdlAcvYG3PK5Ix7ajA4AuTd5EfBlgmGA39Faf1Yp9W4ArfVtSqkPAm8BHCAFfECGAV4cvvZJuSlSbqr03QPnqA/Iv//DTTv4/pFfcLTnJPWx+Txn7lYiRpiQGSJshAibuYsRImSGiBTdHvi4bUyPIqOS8px+oUCligNC/7DgR6vJ1l2Ks+ASnHkbg5ES55A/QRq5UFAcDhTBdfF2hRrbf6PeJjh8L/rw7+DUDpT28Mvn57oJrseZu7nfULXiE3q/xXqKTuiGCoZVGar/KpMTzdd9VeKu7+dmt/MndnTHJGIkmoNv92d3B9/uWw+gdNA871Yuzp3st+DM3YxXveyccxrki/oyXgZ3iGZ+5aSItu4n1ryHaPNeoi37MHPdB05sdl8gmLORTNXSfr+HMz4AjAcJAKWltSbjZUi6ySH/AVyQc9QHQBAE/ti0kx8d/V+O9zaQHWMXhUIRNkOEDDt3nQ8HYcKmTcSMUF82j1UVy1hZsYS6svlTq+thGlEKFE
buWuXCQn7R5nyo6AsL+TChCAo6jaMPYBy5B078CeVl0ZFKdNUSqKxDVdajKuqgog4q6yG+AKzzHyc92fQFAz83FS540z0Y+B5W+0Gss33f7s3eoIlem2GcORuKvt1vgmhN8LyiUJfPm33Lh3t4voenPTztkvEdPN8f/eRbvke44yixlr1Em/cSa96LnWwBwLNjQbfBnI0k52zAn7uZ8tic0n0eA0gAECXleA5JN0lmFNN3nt8Lj1wfUMzXPlnfIeNlyfpZ0l6WrJcl42fJeFkyXoaM7xRtywTbc9vSfiZ4LLct2D9Lyk1zKnG6EDAiZpgVFUtYWbGUlRVLWFWxjEVlC6QbYgpR2QShk48SatiB1X0ao+csRu/ZwlwQENRHqLLZQRioqIPKOqioz13XQXxeMPvdFBWMF/fxc5NGeRq8qTpxFIDWhDqeJXLoTkKHf4NKBCdXXTYHveAS9PxLghqU2WtQZqjvBD+gFUdrjauDE73ru7i+h6vdEbtYPK3xdPBZen4wbNbTGt8fubPU6m0i1ryXaMteYs17CHccQ6GD0S11V5Be8SKyS64LRgqVkAQAMS483wuGEXrp0nUPaB0UmWUTE7YMuOd7nEg0cqj7GIe6j3Oo6xiHe46TzgWekGGzPL6YlZVLC8FgSflC7Ek0o6A4B9/DSLRg9pzF7DmD0X0Gq/dMcL/7DKq3uV8hpVYGqnxuEAYKAaG4BWFe34ySWoPvgpsBLwNuOnc7G4RcLxPcLzyev6SDfdyi5xTvE58Pi64MCizPYzTFSDytcbzgZFaYRGYSBgNDBbPbhRKnsQ/dhfnsr1HtR4I1CpZeC6tuhLqtwWc0TPeRr/3gJK+9wone015JW0fyw2Xz82t4uVUhh2qFMbK9RFv2U9G8h/jxP2D2nEabYTKLryWz4mayC59bktApAUCMK601aS9N0kniDVF9PiajqA+4mDzt05g4w8HuY4VgcLj7GIlcP5+tLJbGF+YCwVJWVi5lWflCQlP4W+OM5rm5gHAmuHSfwew9g9kdBAbV21xYewJAKxMVLgc3i/YywQiOMdKoYCpkKwxmGG2FwbRRXY0oLxO817yNsOiqIBDMv6Tk3RfFLQb5meYuVo2BAizTwDKCa9vpxjx0N+z/FTTuCnaquwzWvgRW3gTRqkGvUTjB564d3xu/Ic6jpMnPzhe0Fni5SbdsI0S5WYbZ9BThw3cRPnI3RroDPxQnu/QFpFfcjDP/stFPtTyABABx0WS9bDB6oFTdA6OoD5govvY5k2zmUPexfsGgx+kFgul7F5fXsSofCiqWsiy+iKgVmeAjFxfMczASzUEwyIUElenJnbBDaDMcrKJnhXNrtufuF9/O72eF0WY4dz+UmxNiiG+xbgb77F5CjTsJNe7Eaj4QVJ5bYfwFl+IvvAIWXYk5Zz3GOC3nW1x8GNQZXNioBAUYhsI2FZZhFK5xM3D0vuCkf+zBYJhpzXJY+1JY8xKorEfnm+K1l2vGz3/Dd6dUzUPYDFEZKjpB+y6cfAQO/BoO/T5oES2bDatfhF7zkmBuDaX6/YzFP27+nJvfZkarMMLnLtYdigQAMSYln2XwPOoDJpLWmqZUay4M9AWDzmywKJGBYlF5HSsrlrK6chmrKpaxvGIxkVIuBypmBJXpxT79ZBAIGnZidRwHwA/Hceq24izchl9/Oap6KaZh5iZ96psAqpQ0xcEgmM1xqH7xYAEbhW0pzNwJv9Avr/1gXon9v4JDd0O2F102G3/VzWRX34w3ayUeGl/7uRP/5P5bMFqDAkAxJw3H7of9v4bjDwRfhKoWByFo7Uugeum53yBSCaGxrachAUBckJIOI8zXBzipSR8Eimmtac10BIGg62ghFLRnOoEgFCwury+EgpUVSyUUiPNmJFqxG3cRatxFqGEnZm8TAF7ZHLL1l5Gtvwynbit+2azcEEcDS5kYuVCQH1mRPyWrohEXFzI80/V9PK2xDSM3iiNQ+AbfvB914FeYB3+L0d
uEtmNklm0nteoGnAWXjrnpe6oYMQAUS3cFLQLP3gknHwV00Bqw5iWw6kUQH2a+AQkAYqLl6wRSbqo0wwh9L0jDXja4PseCG5NNXyg4ysGuYxzsPsqzXUf7WgqUweKyOlblWglWVS5leXwxYakpEKOhNWZXA3bjTkINQSjIL43tVi8hW7eVbP02nAVb0OHz6x9WheFzqt8QTVW4b+SCBP1CQ7653tM+dJ8hdOh3RA7+Dqv9CNowyS68kvSqG8gsfi7YM6ebbNQBoFhvExz8bdAy0LQHUMEMnGteAitvCE76eRIAxGQyLqsR+n4QBnwnFwxKPFfBRRCEgvZCIMhfF4eCJeX1hUCwqmIZy+OLpNBQnJv2sVoPEWrYGbQSnHkK5aaDaW9nryFbdyl+2axcTUIkKDy0+m5rM4y2+4oStRUOKvDPo0VAZXoIH7mPyKHfYZ/ejULjzN1AetWNpJc/Hz1EMd9MMKYAUKzjOBy4Ew78Krht2LD0miAMLNsejFCRACAmG9d3g/kExmO6Yd/PtQxkp2wggNzSq+n2XCAIug8Odh2lywmm7zWVGYSCymWsqghaCRbE5lIVqpAZD8XwvCx20zOFQGA37Rty/YiRaGXkihf7QoHOhYbioKCtMEamh9DJx1C+g1u5kPSqG8msfCFepSzidMEBIE9raH4mKB48cBckmoMpjVe/CF7wCahaeN4vKQFAjLtxnW44T2tws7lA4E7KUQWjpbWmOd3WLxAc7D5Kd270AUDMjLIgNof5sbksyF3qctezIrUyu6Hoz8uiskmUl0E5aZSXATeLctMoN1PYjpcJ7hcuwXwFaqjtRdtQJpklV5NedSPu7DXjuiT2VFOyAFDM96BhR1AvcOQ+eO8OiNWc98tIABAXjda6EARKNp/A8G/W10LgOkHXwRQaOjRQfvTBsd5TnE42cSbZxOlUE6eTzZxJNuEWfZ62spgbnZ0LBnNYEJtXCAnzo7OlS0GIi2hcAkCxUDlEJmYeAJkaTYyaUoqYHSNmx8h4GVJOqrR1Av3fLDd5SghC5AKBC34m6D7QAy/jcxilopRiXmw282KzBz3maZ/WdBunk025S3NwnWpib8ezJL2+Nc8VilmRahZE57KgbB4LonMK4WBBbC7l9tjGEwshJsgEjqKQACDGJGyGCZthHN8JCga9bOnrBIopBZYNDDOBypChwAuCwyQPC6YymBudzdzobC6p3dDvMa01XU5PUThoyrUgNPNYy5OFoYp5ETNMVaiCylAFlXacylBwqQpVUGEH11WhCipCcapCccqs2LgvCSzEVJX1shzoOsy6qpXMipx/E/1kJwFAXBDbsKkMV47PugPnwzCAUZ7IBoUFL3c9+cKCUqpw0l5XtXLQ4yk3zelUEAhOJ5toy3TQle2hM9tNV7abE70NdDk9hXUSBjKUQWUuGORDQRAcKnLhoYKqcCWV4Soqw5VUWGWEDHNKzfMgxGgl3CTPdBxkT8ezPN2xn2c7j+Bol49u/hteu+ylE314JScBQJSEaZiUh8op02W5qT41OjcrWL7OxNd+YVvx/fy++etxN6awUBQOKG5Z0EFBT/72RRa1IiyPL2Z5fPGI+6W9DF3Zntylm06nm+5sD51OD11ObyE0HOttpDPbTU+2Z9j/FiEjRDwUJx4qJ27HidtlxO0yyu0yKqwyyq0YFVaMcruMuF3e97hVJqs0ikmlI9PFno4DucuzHOk+jo/GVCYrK5bwysU3cemsDTx37raJPtRxIQFAlJRSCluNfZ7zfsEBDRp8ghCRDw6udnNLg16EoYPnExZg6K4I/L7tkAsKuigwDLhddHVBFKBMMEwidoRIrJq5BPdRRu566GpvT3v0ZHvozHTSnemmM9NJV6aLnmxPcHF6CrfPplo43HWM7mw3aW/kxaHKrFhfWLDLKbf6QkJFqJzacDWzwjXMilQzK1Ijsy2KkskX4uZP+E93HOBU4jQAYSPE2qoVvHH5K9lUs5a1lSsKa4KMex
HgBJIAICaV/FSno+mXDtYK7wsD+ctFaUUYzvkGhpHkQ0E+MAzaxoDtOjixn+PkPhqmMqkKV1EVrjqv5zm+Q0+2h95s76CgMPDSm+2hNXE6t08vrh4c6OJ2WSEQ1IZrmB2pYVakhlnhICDMjtRQYcdlTgUxiNaaE4lG9rTvZ0/Hs+zpOEBzug2AcquMDdWruanuWjZWr2FV5bIZuVz4zPuJxbSRb22wjf4tDgMDgavdqbkwSf6kNoVObrZhUxOpoeY8C6a01iTdJG2pNlqTLbQkm2lNNtOaaqE11Uprqo0jPSfpyHQNCni2YTMrXE1tJGg9GCok1ISrZ+Qf+GnLMMG0g1n1cisrep7Doc5D7Gndw9Nte9nbtp/u3IRcteEqNlav5bXVa9hYs5ql5Qul+BUJAGIasgwLa8Af+/y64/lQ4PjO1AwF05RSijK7jDK7jEUVi4beSWtcJ017qoWWZBOtyVZaU820pttpS3fQkm7jYPdR/tS8k6w/eCKpMiuGbVjYhoWlLGzDDm4bFraysAyz8JjVb7/gfv524TUGPF5hlzMnUsuc6CwqpVWidBRghnDw6XRSdLkJOrPddGQ66Mx00pHu4HDnYZ5pfabQBVVXXsdz6q5m4+yNbJq1ifmxeaihCn/94utJNjzoIpAAIGYE0zAxMQkX9Sn72sfzPRzfKYQDz/cmtgtBDE8prFCUOaFFzKksCgl+7g+474D20K5DT6aLlnQrremOICBk2unO9uLofAAMWoYc3yu0Ejm+S8ZJ4vhObj+vaD+3337nEjJs5kRmMSdaG4SCwu1ZzInOYk6kdsYvHOVpnx6nl85MFx3Z7sLIlQ4nKE7tzBWmdma66Mx00ls0k2YxS1ksqljETUtvYuOsjWyctZHaaO0Qe57jG39xQa/vAbmwoKygFW4aBgQJAGLGMpSBYRrYZl8Xgta60ELgahfHc8Z/1kNxYQwjV3sR/HdUYagoq6HCX8TyfDDI/2FXKqiRgNx1rmYiv734ehjBMrm54Jj7XXF8h65MF83JZpp7z9KcOE1zoonmVCs7Wp+ifYiui0o7XggDheui2zXhqinVTO3nT+i5kSYduRN6Z+G6h85sV3BSzwajUPwhwraBQUW4gupwNVWRKlZUrwhuh6uoilQValPy98usstK0tigF5hCnRDMMoQpw08FS51522oQBCQBCFFFKYZv2sKEgf5HugynAyI14oLTftJVShW6CYnNjc1lVvapvg9bgZsBN42STtKbbaU630pxqozndVrjdmDzLk23P9JvxEcBSJrMiNcyJ1DI7Uku5XUbEDBcuUTNMxIz0bbPyj0Vyj0eImuExD730fI8up6ffCXzgCT0/pLQr2zPsCR2CoruqUAVV4QoWltexMbKBqkgtVZHq4FJ0co+H4phqkg0XVQrsaHDRetqEAQkAQpzDUKHA135fS4HUFIihKAV2BOwIdqic+ZE48915w06i1Oskc6GgNQgI+et0G890HiThpkh7GZwh6htGYitrUDgYeLENi26nN9fk3k2X002Pkxj6x0IRt8sKM04uKqujsrpvlsnKcCVVuYmjqsJVVEaqsIxwUKw31DfsqWYahYFp8F9DiIvPUAYhM9RvYZ58oaHjOVN79IEoPdMCMw7hODgZcFPB6pdFyu0Y5fYilsWHKYLM8XyPtJ8h7WZIe32XlJfuux7wWDr3WPB4cN3t9NKcbiXtZnC1S9wupzJUwfKKRcEMkHZFbibIOFW5JvjKcCUV4UpM0yLoPjH7ukwucOjplFQcBnw/CANuesqEAQkAQpTIUIWGxcMR860FUmQ4w9nh4OJ7QRBw0uc1tbJpmJQZMcqs2IUfi2GAYRWdvPMn9FxdxTnqIUQRw4BQLLgUhwF36Gm4JwMJAEKMo+GGJHo6qD73tV+4LcWGM4xhBkvBhsqD1gA3FZwsxiMf9jvRW8Ft05KT+3iZImFAAoAQF1m+pSA0xDCwfBDIdyfkw4J0JUxzVii45E8WTio3FO
085QsfDTOYJKfft3sxIQaFgVzQmwRhQAKAEJOIZVhYWDCgCDo/7XFx60E+KEiXwjRSfLJwndzJIt2/VSC/xoM54Bu9nOgnP8OAUFlwyYeBCVwgSwKAEFPAcNMeQ9Cl4Gu/UHQ48FK8AqOYQiw7uOh47ttirtBuOlTSi74wMIHkN0mIKS7fpWAz8iqMnu/1W1mxcBmwTQLDJJMfTihEiUkAEGKGyAeF0SgEghECQ742Ib+EsxBiapEAIIQYRCmFqUYfGAqtB1rj4xdqE4brkhBCTDwJAEKIC2Yoo/+89SPkhuLWheLgUNyiUPyYEGJ8SAAQQlxUY21dGKqwUYodhRg7CQBCiEltUOvCCAbWLnjaG7KeIR8WpH5BzGQSAIQQ08b5ti4AhRBQ3P0wKCigB42UkPAgpjoJAEKIGU0phUKNupWhWHF4KL7k6xmKbwsx2UgAEEKIMTqf8FAcBvKjJDzt4ft926VFYeIZykChUEphYAw5+dZ0IQFACCEugn61DMP0UBSPhhjYgtCvqyIXFCZjN0T+5Jm/XRyOio974LZi53r8XO+df9/CiTx3O38s+f8Wg/afYVMpSwAQQohJQimFpc7/z3I+CORrGNDBibO4VmHQdf527mRbOGnnTojFJ8fi7fkTZ/C//vuN9wk0f6zFoUBrXTjJi/MjAUAIIaa4wolacV4FkFNNcctC38YJOphpQCKTEEIIMQNJABBCCCFmIAkAQgghxAwkAUAIIYSYgSQACCGEEDOQBAAhhBBiBhpVAFBK3aSUelYpdVgp9aEhHn+jUurp3OVPSqnNpT9UIYQQQpTKOQOAUsoEvg7cDKwDXq+UWjdgt2PAtVrrTcCngW+V+kCFEEIIUTqjaQG4HDistT6qtc4CPwZeXryD1vpPWuuO3N1HgfrSHqYQQgghSmk0AaAOOFV0vyG3bTjvAH5zIQclhBBCiPE1mqmAh5poccjVGZRS1xEEgOcO8/i7gHcBLFq0aJSHKIQQQohSG00AaAAWFt2vB04P3EkptQn4D+BmrXXbUC+ktf4WufoApVSLUurEeR+xOJdZQOtEH8QMJZ/9xJHPfuLIZz9xVl/Ik0cTAHYAK5VSS4FG4HXAG4p3UEotAn4BvFlrfXA0b6y1nn2exypGQSm1U2t92UQfx0wkn/3Ekc9+4shnP3GUUjsv5PnnDABaa1cp9V7gboJVrL+jtX5GKfXu3OO3AR8DaoFv5FZrcuUXQgghhJi8RrUcsNb6LuCuAdtuK7r9TuCdpT00IYQQQowXmQlw+pE5GCaOfPYTRz77iSOf/cS5oM9eaT1kQb8QQgghpjFpARBCCCFmIAkAU5RSaqFS6j6l1H6l1DNKqf+T216jlPq9UupQ7rp6oo91ulJKmUqpJ5VSv87dl8/+IlBKVSmlfqaUOpD7/b9KPvuLQyn1vtzfm71Kqf9RSkXksx8fSqnvKKWalVJ7i7YN+1krpT6cW6/nWaXUjaN5DwkAU5cL/L3Wei1wJfDXuTUaPgTcq7VeCdybuy/Gx/8B9hfdl8/+4vgK8Fut9RpgM8F/A/nsx5lSqg74/4DLtNYbCEaFvQ757MfLfwE3Ddg25Ged+9v/OmB97jnfyK3jMyIJAFOU1vqM1vqJ3O0egj+CdQTrNHwvt9v3gFdMyAFOc0qpeuDFBJNf5clnP86UUhXANcB/Amits1rrTuSzv1gsIKqUsoAYwaRw8tmPA631g0D7gM3DfdYvB36stc5orY8BhwnW8RmRBIBpQCm1BLgEeAyYq7U+A0FIAOZM4KFNZ18G/gHwi7bJZz/+lgEtwHdz3S//oZQqQz77cae1bgT+BTgJnAG6tNa/Qz77i2m4z/p81+wBJABMeUqpcuDnwN9qrbsn+nhmAqXUS4BmrfWuiT6WGcgCLgW+qbW+BEggTc4XRa6/+eXAUmABUKaUetPEHpXIGfWaPcUkAExhSimb4OT/Q631L3Kbm5RS83OPzweaJ+r4prGrgZcppY4TLI/9fK
XUD5DP/mJoABq01o/l7v+MIBDIZz/+XgAc01q3aK0dgunfn4N89hfTcJ/1qNbsGUgCwBSlgjmX/xPYr7X+16KH7gDemrv9VuCXF/vYpjut9Ye11vVa6yUEhTd/0Fq/Cfnsx53W+ixwSimVXwTlemAf8tlfDCeBK5VSsdzfn+sJao/ks794hvus7wBep5QK59btWQk8fq4Xk4mApiil1HOBh4A99PVD/yNBHcBPgUUE/2D/TGs9sJBElIhSajvwfq31S5RStchnP+6UUlsIii9DwFHg7QRfZuSzH2dKqU8CryUYhfQkwRTw5chnX3JKqf8BthOsttgEfBz4X4b5rJVSHwH+nOC/zd9qrX9zzveQACCEEELMPNIFIIQQQsxAEgCEEEKIGUgCgBBCCDEDSQAQQgghZiAJAEIIIcQMJAFACCGEmIEkAAghhBAzkAQAIYQQYgb6/wHsvdpSnMDsJAAAAABJRU5ErkJggg==\n",
|
478 |
+
"text/plain": [
|
479 |
+
"<Figure size 504x288 with 1 Axes>"
|
480 |
+
]
|
481 |
+
},
|
482 |
+
"metadata": {
|
483 |
+
"needs_background": "light"
|
484 |
+
},
|
485 |
+
"output_type": "display_data"
|
486 |
+
}
|
487 |
+
],
|
488 |
+
"source": [
|
489 |
+
"y_min = min([global_results[k][2].min() for k in global_results])\n",
|
490 |
+
"y_max = max([global_results[k][2].max() for k in global_results])\n",
|
491 |
+
"\n",
|
492 |
+
"fig2 = plt.figure(constrained_layout=True, figsize=(7, 4))\n",
|
493 |
+
"axes = plt.axes()\n",
|
494 |
+
"axes.set_xlim(2, 100)\n",
|
495 |
+
"#axes.set_ylim(y_min, y_max)\n",
|
496 |
+
"for k in global_results:\n",
|
497 |
+
" plot_with_confidence_intervals(plt, global_results[k][1], global_results[k][2], global_results[k][3], label=k)\n",
|
498 |
+
" #plt.plot(global_results_train_steps[k][1], global_results_train_steps[k][0], label=k)\n",
|
499 |
+
"plt.legend(loc=\"upper right\")"
|
500 |
+
]
|
501 |
+
}
|
502 |
+
],
|
503 |
+
"metadata": {
|
504 |
+
"kernelspec": {
|
505 |
+
"display_name": "Python [conda env:prior-fitting]",
|
506 |
+
"language": "python",
|
507 |
+
"name": "conda-env-prior-fitting-py"
|
508 |
+
},
|
509 |
+
"language_info": {
|
510 |
+
"codemirror_mode": {
|
511 |
+
"name": "ipython",
|
512 |
+
"version": 3
|
513 |
+
},
|
514 |
+
"file_extension": ".py",
|
515 |
+
"mimetype": "text/x-python",
|
516 |
+
"name": "python",
|
517 |
+
"nbconvert_exporter": "python",
|
518 |
+
"pygments_lexer": "ipython3",
|
519 |
+
"version": "3.9.6"
|
520 |
+
}
|
521 |
+
},
|
522 |
+
"nbformat": 4,
|
523 |
+
"nbformat_minor": 2
|
524 |
+
}
|
prior-fitting/notebooks/FewShotOmniglot.ipynb
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"id": "976fbfea",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import sys\n",
|
11 |
+
"sys.path.insert(0,'..')"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 4,
|
17 |
+
"id": "4b164f6b",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"import torch\n",
|
22 |
+
"from torch import nn\n",
|
23 |
+
"\n",
|
24 |
+
"from tqdm import tqdm\n",
|
25 |
+
"\n",
|
26 |
+
"\n",
|
27 |
+
"from train import train\n",
|
28 |
+
"import priors\n",
|
29 |
+
"import encoders\n",
|
30 |
+
"import positional_encodings\n",
|
31 |
+
"import utils\n",
|
32 |
+
"import bar_distribution\n",
|
33 |
+
"\n",
|
34 |
+
"\n",
|
35 |
+
"from samlib.utils import chunker"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": null,
|
41 |
+
"id": "29d423b4",
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"mykwargs = \\\n",
|
46 |
+
"{\n",
|
47 |
+
" 'bptt': 5*5+1,\n",
|
48 |
+
"'nlayers': 6,\n",
|
49 |
+
" 'dropout': 0.0, 'steps_per_epoch': 100,\n",
|
50 |
+
" 'batch_size': 100}\n",
|
51 |
+
"mnist_jobs_5shot_pi_prior_search = [\n",
|
52 |
+
" pretrain_and_eval( {'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n",
|
53 |
+
" 'translations': False, 'jonas_style': True}, priors.stroke.DataLoader, Losses.ce, enc, emsize=emsize, nhead=nhead, warmup_epochs=warmup_epochs, nhid=nhid, y_encoder_generator=encoders.get_Canonical(5), lr=lr, epochs=epochs, single_eval_pos_gen=mykwargs['bptt']-1,\n",
|
54 |
+
" extra_prior_kwargs_dict={'num_features': 28*28, 'fuse_x_y': False, 'num_outputs':5, 'only_train_for_last_idx': True,\n",
|
55 |
+
" 'min_max_strokes': (1,max_strokes), 'min_max_len': (min_len, max_len), 'min_max_width': (min_width, max_width), 'max_offset': max_offset, 'max_target_offset': max_target_offset},\n",
|
56 |
+
" **mykwargs)\n",
|
57 |
+
" for max_strokes, min_len, max_len, min_width, max_width, max_offset, max_target_offset in random_hypers\n",
|
58 |
+
" for enc in [encoders.Linear] for emsize in [1024] for nhead in [4] for nhid in [emsize*2] for warmup_epochs in [5] for lr in [.00001] for epochs in [128,1024] for _ in range(1)]\n",
|
59 |
+
"\n",
|
60 |
+
"\n",
|
61 |
+
"\n"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": null,
|
67 |
+
"id": "deb93d1e",
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"\n",
|
72 |
+
"@torch.inference_mode()\n",
|
73 |
+
"def get_acc(finetuned_model, eval_pos, device='cpu', steps=100, train_mode=False, **mykwargs):\n",
|
74 |
+
" finetuned_model.to(device)\n",
|
75 |
+
" finetuned_model.eval()\n",
|
76 |
+
"\n",
|
77 |
+
" t_dl = priors.omniglot.DataLoader(steps, batch_size=1000, seq_len=mykwargs['bptt'], train=train_mode,\n",
|
78 |
+
" **mykwargs['extra_prior_kwargs_dict'])\n",
|
79 |
+
"\n",
|
80 |
+
" ps = []\n",
|
81 |
+
" ys = []\n",
|
82 |
+
" for x, y in tqdm(t_dl):\n",
|
83 |
+
" p = finetuned_model(tuple(e.to(device) for e in x), single_eval_pos=eval_pos)\n",
|
84 |
+
" ps.append(p)\n",
|
85 |
+
" ys.append(y)\n",
|
86 |
+
"\n",
|
87 |
+
" ps = torch.cat(ps, 1)\n",
|
88 |
+
" ys = torch.cat(ys, 1)\n",
|
89 |
+
"\n",
|
90 |
+
" def acc(ps, ys):\n",
|
91 |
+
" return (ps.argmax(-1) == ys.to(ps.device)).float().mean()\n",
|
92 |
+
"\n",
|
93 |
+
" a = acc(ps[eval_pos], ys[eval_pos]).cpu()\n",
|
94 |
+
" print(a.item())\n",
|
95 |
+
" return a\n",
|
96 |
+
"\n",
|
97 |
+
"\n",
|
98 |
+
"def train_and_eval(*args, **kwargs):\n",
|
99 |
+
" r = train(*args, **kwargs)\n",
|
100 |
+
" model = r[-1]\n",
|
101 |
+
" acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n",
|
102 |
+
" model.to('cpu')\n",
|
103 |
+
" return [acc]\n",
|
104 |
+
"\n",
|
105 |
+
"def pretrain_and_eval(extra_prior_kwargs_dict_eval,*args, **kwargs):\n",
|
106 |
+
" r = train(*args, **kwargs)\n",
|
107 |
+
" model = r[-1]\n",
|
108 |
+
" kwargs['extra_prior_kwargs_dict'] = extra_prior_kwargs_dict_eval\n",
|
109 |
+
" acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n",
|
110 |
+
" model.to('cpu')\n",
|
111 |
+
" return r, acc"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"execution_count": null,
|
117 |
+
"id": "706ecbb7",
|
118 |
+
"metadata": {},
|
119 |
+
"outputs": [],
|
120 |
+
"source": [
|
121 |
+
"\n",
|
122 |
+
"emsize = 1024\n",
|
123 |
+
"# mnist_jobs_5shot_pi[20].result()[-1].state_dict()\n",
|
124 |
+
"mykwargs = \\\n",
|
125 |
+
" {'bptt': 5 * 5 + 1,\n",
|
126 |
+
" 'nlayers': 6,\n",
|
127 |
+
" 'nhead': 4, 'emsize': emsize,\n",
|
128 |
+
" 'encoder_generator': encoders.Linear, 'nhid': emsize * 2}\n",
|
129 |
+
"results = train_and_eval(priors.omniglot.DataLoader, Losses.ce, y_encoder_generator=encoders.get_Canonical(5),\n",
|
130 |
+
" load_weights_from_this_state_dict=mnist_jobs_5shot_pi_prior_search[67][0][-1].state_dict(), epochs=32, lr=.00001, dropout=dropout,\n",
|
131 |
+
" single_eval_pos_gen=mykwargs['bptt'] - 1,\n",
|
132 |
+
" extra_prior_kwargs_dict={'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n",
|
133 |
+
" 'translations': True, 'jonas_style': True},\n",
|
134 |
+
" batch_size=100, steps_per_epoch=200, **mykwargs)\n",
|
135 |
+
"\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "code",
|
140 |
+
"execution_count": null,
|
141 |
+
"id": "611554b2",
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [],
|
144 |
+
"source": []
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"metadata": {
|
148 |
+
"kernelspec": {
|
149 |
+
"display_name": "Python 3 (ipykernel)",
|
150 |
+
"language": "python",
|
151 |
+
"name": "python3"
|
152 |
+
},
|
153 |
+
"language_info": {
|
154 |
+
"codemirror_mode": {
|
155 |
+
"name": "ipython",
|
156 |
+
"version": 3
|
157 |
+
},
|
158 |
+
"file_extension": ".py",
|
159 |
+
"mimetype": "text/x-python",
|
160 |
+
"name": "python",
|
161 |
+
"nbconvert_exporter": "python",
|
162 |
+
"pygments_lexer": "ipython3",
|
163 |
+
"version": "3.9.5"
|
164 |
+
}
|
165 |
+
},
|
166 |
+
"nbformat": 4,
|
167 |
+
"nbformat_minor": 5
|
168 |
+
}
|
prior-fitting/notebooks/SetupForGPFittingExperiments.ipynb
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 4,
|
6 |
+
"id": "111c502f",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import sys\n",
|
11 |
+
"sys.path.insert(0,'..')"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 10,
|
17 |
+
"id": "e6b59ce3",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"import torch\n",
|
22 |
+
"from torch import nn\n",
|
23 |
+
"\n",
|
24 |
+
"\n",
|
25 |
+
"from train import train\n",
|
26 |
+
"import priors\n",
|
27 |
+
"import encoders\n",
|
28 |
+
"import positional_encodings\n",
|
29 |
+
"import utils\n",
|
30 |
+
"import bar_distribution\n",
|
31 |
+
"import transformer\n",
|
32 |
+
"\n",
|
33 |
+
"from samlib.utils import chunker"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": 12,
|
39 |
+
"id": "acf7423d",
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"kwargs = \\\n",
|
44 |
+
"{\n",
|
45 |
+
" 'nlayers': 6, \n",
|
46 |
+
" 'dropout': 0.0, 'steps_per_epoch': 100, \n",
|
47 |
+
"}\n",
|
48 |
+
" \n",
|
49 |
+
" \n",
|
50 |
+
"def train_and_compare_fast_gp_mix(*args, **kwargs):\n",
|
51 |
+
" hps = kwargs['extra_prior_kwargs_dict']['hyperparameters']\n",
|
52 |
+
" num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
|
53 |
+
" baseline_res = priors.fast_gp_mix.evaluate(\n",
|
54 |
+
" *args[0].get_batch_method(10000,kwargs['bptt'],num_features, hyperparameters=hps),\n",
|
55 |
+
" hyperparameters=hps, \n",
|
56 |
+
" use_mse=Losses.mse == args[2])\n",
|
57 |
+
" print(baseline_res, 'with fast_gp_mix')\n",
|
58 |
+
" \n",
|
59 |
+
" res = train(*args, **kwargs)\n",
|
60 |
+
" return res, baseline_res\n",
|
61 |
+
"\n",
|
62 |
+
"def train_and_compare_fast_gp(*args, num_evals=1000, **kwargs):\n",
|
63 |
+
" hps = kwargs['extra_prior_kwargs_dict']['hyperparameters']\n",
|
64 |
+
" num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
|
65 |
+
" baseline_res = priors.fast_gp.evaluate(\n",
|
66 |
+
" *args[0].get_batch_method(num_evals,kwargs['bptt'],num_features, hyperparameters=hps, device='cpu'),\n",
|
67 |
+
" hyperparameters=hps, \n",
|
68 |
+
" use_mse=Losses.mse == args[2], device='cpu')\n",
|
69 |
+
" print(baseline_res, 'with fast_gp')\n",
|
70 |
+
" \n",
|
71 |
+
" res = train(*args, **kwargs)\n",
|
72 |
+
" return res, baseline_res\n",
|
73 |
+
"\n",
|
74 |
+
"def train_and_compare_gp(*args, num_evals=10000, **kwargs):\n",
|
75 |
+
" num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
|
76 |
+
" baseline_res = priors.gp.evaluate(\n",
|
77 |
+
" *args[0].get_batch_method(num_evals,kwargs['bptt'],num_features),\n",
|
78 |
+
" use_mse=Losses.mse == args[2])\n",
|
79 |
+
" print(baseline_res, 'with fast_gp')\n",
|
80 |
+
" \n",
|
81 |
+
" res = train(*args, **kwargs)\n",
|
82 |
+
" return res, baseline_res\n",
|
83 |
+
"\n"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": 13,
|
89 |
+
"id": "da083e24",
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"import gpytorch\n",
|
94 |
+
"hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
|
95 |
+
"\n",
|
96 |
+
"import numpy as np, scipy.stats as st\n",
|
97 |
+
"\n",
|
98 |
+
"def compute_mean_and_conf_interval(accuracies, confidence=.95):\n",
|
99 |
+
" accuracies = np.array(accuracies)\n",
|
100 |
+
" n = len(accuracies)\n",
|
101 |
+
" m, se = np.mean(accuracies, -1), st.sem(accuracies, -1)\n",
|
102 |
+
" h = se * st.t.ppf((1 + confidence) / 2., n-1)\n",
|
103 |
+
" return m, h\n",
|
104 |
+
"\n",
|
105 |
+
"\n",
|
106 |
+
"def bl(hps,bptt, num_evals=100, num_features=1, step_size=1, evals_per_batch=None, speedups=(False,False,False,False)):\n",
|
107 |
+
" if evals_per_batch is None:\n",
|
108 |
+
" evals_per_batch = num_evals\n",
|
109 |
+
" else:\n",
|
110 |
+
" assert num_evals%evals_per_batch == 0\n",
|
111 |
+
" results = []\n",
|
112 |
+
" for batch_i in range(num_evals//evals_per_batch):\n",
|
113 |
+
" with gpytorch.settings.fast_computations(False,False,False):\n",
|
114 |
+
" batch = priors.fast_gp.get_batch(evals_per_batch,bptt,num_features, hyperparameters=hps)\n",
|
115 |
+
" with gpytorch.settings.fast_pred_var(speedups[0]), gpytorch.settings.fast_computations(*speedups[1:]):\n",
|
116 |
+
" all_res, baseline_res,_ = priors.fast_gp.evaluate(\n",
|
117 |
+
" *batch,\n",
|
118 |
+
" hyperparameters=hps, step_size=step_size\n",
|
119 |
+
" )\n",
|
120 |
+
" print(baseline_res, 'with fast_gp')\n",
|
121 |
+
" \n",
|
122 |
+
" results.append(all_res)\n",
|
123 |
+
" all_results = torch.cat(results,1) # seq x batch_size\n",
|
124 |
+
" return compute_mean_and_conf_interval(all_results) # mean array, confidence-interval half-width array\n",
|
125 |
+
" \n",
|
126 |
+
" \n",
|
127 |
+
"#settings = [{'num_evals':n,} for n in [100,1000]]\n",
|
128 |
+
" \n",
|
129 |
+
"#js = [ex.submit(bl, hps, 2000, step_size=100, evals_per_batch=2, num_features=5, **kwargs) for kwargs in settings]\n",
|
130 |
+
"\n",
|
131 |
+
"\n"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": null,
|
137 |
+
"id": "8088aa12",
|
138 |
+
"metadata": {},
|
139 |
+
"outputs": [],
|
140 |
+
"source": [
|
141 |
+
"# Below, you can simply replace the prior with priors.fast_gp_mix to run experiments over mixtures of GPs."
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": null,
|
147 |
+
"id": "165e683c",
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [],
|
150 |
+
"source": [
|
151 |
+
"num_features = 5\n",
|
152 |
+
"hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
|
153 |
+
"ys = priors.fast_gp.get_batch(100000,20,num_features, hyperparameters=hps)[1]\n",
|
154 |
+
"fivefeature_jobs = [\n",
|
155 |
+
" train(priors.fast_gp.DataLoader, bar_distribution.FullSupportBarDistribution(bar_distribution.get_bucket_limits(num_borders, ys=ys)), enc, emsize=emsize, nhead=nhead, warmup_epochs=warmup_epochs, y_encoder_generator=y_enc, pos_encoder_generator=pos_enc,\n",
|
156 |
+
" batch_size=batch_size, scheduler=decay, extra_prior_kwargs_dict={'num_features': num_features, 'fuse_x_y': False, 'hyperparameters': hps},\n",
|
157 |
+
" epochs=epochs, lr=lr, input_normalization=input_norm, bptt=2010, single_eval_pos_gen=single_eval_pos,aggregate_k_gradients=step_every, **kwargs) \n",
|
158 |
+
" for enc in [encoders.Linear] for y_enc in [encoders.Linear] for emsize in [512] for nhead in [4] for nhid in [emsize*2] for epochs in [50*25,100*25,200*25,400*25] \n",
|
159 |
+
" for warmup_epochs in [epochs//4] for input_norm in [False]\n",
|
160 |
+
" for batch_size in [4] for step_every in [100//batch_size] for lr in [.0001,.0003,.001] for decay in [utils.get_cosine_schedule_with_warmup] for num_borders in [1000,10000] \n",
|
161 |
+
" for single_eval_pos in [utils.get_weighted_single_eval_pos_sampler(2000)]\n",
|
162 |
+
" for pos_enc in [positional_encodings.PositionalEncoding if single_eval_pos is None else positional_encodings.NoPositionalEncoding] \n",
|
163 |
+
" for redo in range(1)\n",
|
164 |
+
"]\n",
|
165 |
+
"\n",
|
166 |
+
"\n",
|
167 |
+
"\n"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"execution_count": 14,
|
173 |
+
"id": "15d01f3b",
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [],
|
176 |
+
"source": [
|
177 |
+
"import numpy as np, scipy.stats as st\n",
|
178 |
+
"\n",
|
179 |
+
"def compute_mean_and_conf_interval(accuracies, confidence=.95):\n",
|
180 |
+
" accuracies = np.array(accuracies)\n",
|
181 |
+
" n = len(accuracies)\n",
|
182 |
+
" m, se = np.mean(accuracies), st.sem(accuracies)\n",
|
183 |
+
" h = se * st.t.ppf((1 + confidence) / 2., n-1)\n",
|
184 |
+
" return m, h\n",
|
185 |
+
"hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
|
186 |
+
"\n",
|
187 |
+
"@torch.inference_mode()\n",
|
188 |
+
"def run_test(model,device='cuda:0',step_size=100, start_pos=1, batch_size=1000, sub_batch_size=10, seq_len=2000):\n",
|
189 |
+
" assert batch_size % sub_batch_size == 0\n",
|
190 |
+
" model.to(device)\n",
|
191 |
+
"\n",
|
192 |
+
" model.eval()\n",
|
193 |
+
" nlls = []\n",
|
194 |
+
" nll_confidences = []\n",
|
195 |
+
" mses = []\n",
|
196 |
+
" max_mses = []\n",
|
197 |
+
" eval_positions = []\n",
|
198 |
+
" \n",
|
199 |
+
" def get_metrics(model, eval_pos, batch_size):\n",
|
200 |
+
" x,y, target_y = priors.fast_gp.get_batch(batch_size=batch_size, seq_len=eval_pos+1, num_features=5,hyperparameters=hps, device=device)\n",
|
201 |
+
" logits = model((x,y), single_eval_pos=eval_pos)\n",
|
202 |
+
" if isinstance(model.criterion,nn.GaussianNLLLoss):\n",
|
203 |
+
" nll = model.criterion(logits[0][...,0], target_y[eval_pos], var=logits[0][...,1].abs())\n",
|
204 |
+
" return nll, 0., 0.\n",
|
205 |
+
" means = model.criterion.mean(logits) # num_evals x batch_size\n",
|
206 |
+
" maxs = (model.criterion.borders[logits.argmax(-1)] + model.criterion.borders[logits.argmax(-1)+1])/2\n",
|
207 |
+
" mse = nn.MSELoss()\n",
|
208 |
+
" nll = model.criterion(logits[0], target_y[eval_pos])\n",
|
209 |
+
" return nll, mse(means[0], target_y[eval_pos]), mse(maxs[0], target_y[eval_pos])\n",
|
210 |
+
" \n",
|
211 |
+
" \n",
|
212 |
+
" for eval_pos in range(start_pos, seq_len, step_size):\n",
|
213 |
+
" eval_positions.append(eval_pos)\n",
|
214 |
+
" print(eval_pos)\n",
|
215 |
+
" \n",
|
216 |
+
" nll = []\n",
|
217 |
+
" mean_mse = []\n",
|
218 |
+
" max_mse = []\n",
|
219 |
+
" for i in range(batch_size//sub_batch_size):\n",
|
220 |
+
" batch_nll, batch_mean_mse, batch_max_mse = get_metrics(model, eval_pos, sub_batch_size)\n",
|
221 |
+
" nll.append(batch_nll)\n",
|
222 |
+
" mean_mse.append(batch_mean_mse)\n",
|
223 |
+
" max_mse.append(batch_max_mse)\n",
|
224 |
+
" \n",
|
225 |
+
" nll = torch.cat(nll)\n",
|
226 |
+
" mean_mse = torch.tensor(mean_mse).mean()\n",
|
227 |
+
" max_mse = torch.tensor(max_mse).mean()\n",
|
228 |
+
" \n",
|
229 |
+
" \n",
|
230 |
+
" mses.append(mean_mse)\n",
|
231 |
+
" max_mses.append(max_mse)\n",
|
232 |
+
" nlls.append(nll.mean())\n",
|
233 |
+
" nll_confidences.append(compute_mean_and_conf_interval(nll.to('cpu'))[1])\n",
|
234 |
+
" return eval_positions, torch.stack(mses).to('cpu'), torch.stack(max_mses).to('cpu'), torch.stack(nlls).to('cpu'), torch.tensor(nll_confidences).to('cpu')\n",
|
235 |
+
"\n",
|
236 |
+
"\n",
|
237 |
+
"\n"
|
238 |
+
]
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"cell_type": "code",
|
242 |
+
"execution_count": null,
|
243 |
+
"id": "755e88e4",
|
244 |
+
"metadata": {},
|
245 |
+
"outputs": [],
|
246 |
+
"source": []
|
247 |
+
}
|
248 |
+
],
|
249 |
+
"metadata": {
|
250 |
+
"kernelspec": {
|
251 |
+
"display_name": "Python 3 (ipykernel)",
|
252 |
+
"language": "python",
|
253 |
+
"name": "python3"
|
254 |
+
},
|
255 |
+
"language_info": {
|
256 |
+
"codemirror_mode": {
|
257 |
+
"name": "ipython",
|
258 |
+
"version": 3
|
259 |
+
},
|
260 |
+
"file_extension": ".py",
|
261 |
+
"mimetype": "text/x-python",
|
262 |
+
"name": "python",
|
263 |
+
"nbconvert_exporter": "python",
|
264 |
+
"pygments_lexer": "ipython3",
|
265 |
+
"version": "3.9.5"
|
266 |
+
}
|
267 |
+
},
|
268 |
+
"nbformat": 4,
|
269 |
+
"nbformat_minor": 5
|
270 |
+
}
|
prior-fitting/notebooks/TabularEvalSimple.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
prior-fitting/notebooks/Untitled.ipynb
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"id": "a873fcbb",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import sys\n",
|
11 |
+
"sys.path.insert(0,'..')"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 5,
|
17 |
+
"id": "56023c88",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"import random\n",
|
22 |
+
"\n",
|
23 |
+
"import numpy as np\n",
|
24 |
+
"import torch\n",
|
25 |
+
"from torch import nn\n",
|
26 |
+
"from sklearn.gaussian_process import GaussianProcessRegressor\n",
|
27 |
+
"from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel\n",
|
28 |
+
"from priors.utils import get_batch_to_dataloader"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 68,
|
34 |
+
"id": "036c690b",
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"def get_gp():\n",
|
39 |
+
" gp = GaussianProcessRegressor(\n",
|
40 |
+
" kernel=RBF(length_scale=.6, length_scale_bounds='fixed'),\n",
|
41 |
+
" random_state=0, optimizer=None)\n",
|
42 |
+
" return gp"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 77,
|
48 |
+
"id": "ff8a3cd1",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"seq_len = 4\n",
|
53 |
+
"num_features = 10\n",
|
54 |
+
"x = torch.rand(seq_len, num_features)\n",
|
55 |
+
"gpr = get_gp()\n",
|
56 |
+
"y = gpr.sample_y(x, random_state=random.randint(0, 2 ** 32)).squeeze()\n"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 78,
|
62 |
+
"id": "46fe34a9",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"name": "stdout",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"[-0.29995838] [0.90399136]\n",
|
70 |
+
"[-0.1039504] [0.98874968]\n",
|
71 |
+
"[-0.03414801] [0.99876344]\n",
|
72 |
+
"[-0.01104748] [0.99986603]\n",
|
73 |
+
"[-0.00356252] [0.9999855]\n",
|
74 |
+
"[-0.00114827] [0.99999843]\n",
|
75 |
+
"[-0.00037014] [0.99999983]\n",
|
76 |
+
"[-0.00011934] [0.99999998]\n",
|
77 |
+
"[-3.8486538e-05] [1.]\n",
|
78 |
+
"[-1.24147253e-05] [1.]\n",
|
79 |
+
"[-4.00568455e-06] [1.]\n",
|
80 |
+
"[-1.2927993e-06] [1.]\n",
|
81 |
+
"[-4.17353027e-07] [1.]\n",
|
82 |
+
"[-1.34771328e-07] [1.]\n",
|
83 |
+
"[-4.35327732e-08] [1.]\n",
|
84 |
+
"[-1.40657691e-08] [1.]\n",
|
85 |
+
"[-4.54613576e-09] [1.]\n",
|
86 |
+
"[-1.46979425e-09] [1.]\n",
|
87 |
+
"[-4.75345491e-10] [1.]\n"
|
88 |
+
]
|
89 |
+
}
|
90 |
+
],
|
91 |
+
"source": [
|
92 |
+
"for num_copies in range(1,20):\n",
|
93 |
+
" gp = get_gp()\n",
|
94 |
+
" x_copied = x.tile((1,num_copies))\n",
|
95 |
+
" gp.fit(x_copied[:-1],y[:-1])\n",
|
96 |
+
" m,s = gp.predict(x_copied[-1].reshape(1,-1), return_std=True)\n",
|
97 |
+
" print(m,s)"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": 79,
|
103 |
+
"id": "87752b3d",
|
104 |
+
"metadata": {},
|
105 |
+
"outputs": [
|
106 |
+
{
|
107 |
+
"data": {
|
108 |
+
"text/plain": [
|
109 |
+
"array([[1. , 0.1047567 , 0.17720387, 0.33463634],\n",
|
110 |
+
" [0.1047567 , 1. , 0.14686013, 0.04858264],\n",
|
111 |
+
" [0.17720387, 0.14686013, 1. , 0.32035965],\n",
|
112 |
+
" [0.33463634, 0.04858264, 0.32035965, 1. ]])"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
"execution_count": 79,
|
116 |
+
"metadata": {},
|
117 |
+
"output_type": "execute_result"
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"k = RBF(length_scale=.6, length_scale_bounds='fixed')\n",
|
122 |
+
"k(x)"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": 80,
|
128 |
+
"id": "6a409ae5",
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [
|
131 |
+
{
|
132 |
+
"data": {
|
133 |
+
"text/plain": [
|
134 |
+
"array([[1.00000000e+00, 2.41799081e-19, 5.26006251e-15, 9.26592960e-10],\n",
|
135 |
+
" [2.41799081e-19, 1.00000000e+00, 1.48311381e-16, 1.10443925e-25],\n",
|
136 |
+
" [5.26006251e-15, 1.48311381e-16, 1.00000000e+00, 4.04686299e-10],\n",
|
137 |
+
" [9.26592960e-10, 1.10443925e-25, 4.04686299e-10, 1.00000000e+00]])"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
"execution_count": 80,
|
141 |
+
"metadata": {},
|
142 |
+
"output_type": "execute_result"
|
143 |
+
}
|
144 |
+
],
|
145 |
+
"source": [
|
146 |
+
"k = RBF(length_scale=.6, length_scale_bounds='fixed')\n",
|
147 |
+
"k(x_copied)"
|
148 |
+
]
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"execution_count": null,
|
153 |
+
"id": "24141432",
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [],
|
156 |
+
"source": []
|
157 |
+
}
|
158 |
+
],
|
159 |
+
"metadata": {
|
160 |
+
"kernelspec": {
|
161 |
+
"display_name": "Python 3 (ipykernel)",
|
162 |
+
"language": "python",
|
163 |
+
"name": "python3"
|
164 |
+
},
|
165 |
+
"language_info": {
|
166 |
+
"codemirror_mode": {
|
167 |
+
"name": "ipython",
|
168 |
+
"version": 3
|
169 |
+
},
|
170 |
+
"file_extension": ".py",
|
171 |
+
"mimetype": "text/x-python",
|
172 |
+
"name": "python",
|
173 |
+
"nbconvert_exporter": "python",
|
174 |
+
"pygments_lexer": "ipython3",
|
175 |
+
"version": "3.9.5"
|
176 |
+
}
|
177 |
+
},
|
178 |
+
"nbformat": 4,
|
179 |
+
"nbformat_minor": 5
|
180 |
+
}
|
prior-fitting/positional_encodings.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
# Protocol for positional encodings.
|
8 |
+
# __init__(d_model, max_len=..[, more optionals])
|
9 |
+
# forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
|
10 |
+
|
11 |
+
|
12 |
+
class NoPositionalEncoding(nn.Module):
    """Identity module that satisfies the positional-encoding protocol.

    ``d_model`` and ``max_len`` are accepted only so the constructor matches the
    other encodings; neither is stored or used.
    """

    def __init__(self, d_model, max_len=None):
        super().__init__()

    def forward(self, x):
        # Nothing is added: the very same tensor object is handed back.
        return x
|
19 |
+
|
20 |
+
|
21 |
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to the input.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is precomputed for ``max_len`` positions
    and stored as the (non-trainable) buffer ``pe`` of shape (max_len, 1, d_model).

    :param d_model: size of the last input dimension (may be even or odd).
    :param max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns;
        # trimming div_term avoids the shape mismatch that used to raise here.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # -> (max_len, 1, d_model), broadcasts over batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions."""
        x = self.pe[:x.size(0), :] + x
        return x
|
35 |
+
|
36 |
+
|
37 |
+
class LearnedPositionalEncoding(nn.Module):
    """Adds a trainable embedding per sequence position to the input.

    The embedding table has shape (max_len, d_model) and is initialised from a
    normal distribution with standard deviation ``d_model ** -0.5``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.max_seq_len = max_len
        table = torch.empty(max_len, d_model)
        self.positional_embeddings = nn.Parameter(table)
        nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)

    def forward(self, x):
        """Add the first ``seq_len`` embeddings to ``x`` of shape (seq_len, bs, d_model)."""
        seq_len, bs, d_model = x.shape
        assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
        # One embedding per position, repeated (as a view) across the batch axis.
        per_position = self.positional_embeddings[:seq_len].unsqueeze(1)
        return per_position.expand(seq_len, bs, d_model) + x
|
50 |
+
|
51 |
+
|
52 |
+
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
    """Learned positional embeddings whose rows are randomly permuted on every call.

    A fresh permutation is drawn per forward pass and shared by the whole batch.
    """
    # TODO check whether it is a problem to use the same perm. for full batch
    # NOTE(review): the view below groups the *feature* axis into pairs, so the
    # permutation shuffles single positions, while the assertion asks for an even
    # max_len (which suggests pairs of positions were intended) - confirm.

    def forward(self, x):
        seq_len, bs, d_model = x.shape
        table = self.positional_embeddings
        num_positions = len(table)
        assert seq_len <= num_positions, 'seq_len can be at most max_len.'
        assert num_positions % 2 == 0, 'Please specify an even max_len.'

        paired_embs = table.view(num_positions, -1, 2)
        shuffled = paired_embs[torch.randperm(len(paired_embs))]
        pos_emb = shuffled.view(*table.shape)[:seq_len]

        return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
prior-fitting/presentation/heatmap_bardistribution.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
An example of how to use this:
|
3 |
+
x ,y , y_target = priors.fast_gp.get_batch(1,100,num_features, hyperparameters=(1e-4,1.,.6), equidistant_x=True)
|
4 |
+
fig, ax = pyplot.subplots(figsize=[10,10])
|
5 |
+
plot_model_and_orig_curve(ax, SOME_MODEL, x, y, given_indices=[10,40,60])
|
6 |
+
|
7 |
+
Don't worry, it is normal for this to be slow...
|
8 |
+
"""
|
9 |
+
import matplotlib.patches as patches
|
10 |
+
import seaborn as sns
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def add_rect(ax, coord, height, width, color):
    """Draw a filled rectangle without an edge colour on ``ax``.

    NOTE: ``height`` and ``width`` are forwarded positionally to
    ``patches.Rectangle(xy, width, height)``, so ``height`` is actually the
    horizontal extent and ``width`` the vertical one. The parameter names are
    kept as they are for backward compatibility.

    :param ax: axes that receive the patch.
    :param coord: (x, y) of the rectangle's anchor corner.
    :param height: extent along the x-axis (see note).
    :param width: extent along the y-axis (see note).
    :param color: face colour of the rectangle.
    """
    ax.add_patch(
        patches.Rectangle(coord, height, width, linewidth=1, edgecolor='none', facecolor=color)
    )
|
19 |
+
|
20 |
+
|
21 |
+
def heatmap_with_box_sizes(ax, data: torch.Tensor, x_starts, x_ends, y_starts, y_ends,
                           palette=sns.color_palette("rocket", as_cmap=True), set_lims=True):
    """
    Draw a heat map whose cells may have individual sizes.

    Beware all x and y arrays should be sorted from small to large and the data will appear in that same order:
    Small indexes map to lower x/y-axis values.

    :param ax: axes to draw on.
    :param data: 2-d tensor indexed as ``data[row_i, col_i]``; it is min-max normalised before colouring.
    :param x_starts: left border of every column.
    :param x_ends: right border of every column.
    :param y_starts: lower border of every row.
    :param y_ends: upper border of every row.
    :param palette: callable mapping a float in [0, 1] to a colour.
    :param set_lims: if true, the axis limits are set to the outermost borders.
    """
    if set_lims:
        ax.set_xlim(x_starts[0], x_ends[-1])
        ax.set_ylim(y_starts[0], y_ends[-1])

    data_min = data.min()
    value_range = data.max() - data_min
    if value_range == 0:
        # Constant data would yield 0/0 = NaN for every cell; use the lowest colour instead.
        data = torch.zeros_like(data)
    else:
        data = (data - data_min) / value_range

    for col_i, (col_start, col_end) in enumerate(zip(x_starts, x_ends)):
        for row_i, (row_start, row_end) in enumerate(zip(y_starts, y_ends)):
            add_rect(ax, (col_start, row_start), col_end - col_start, row_end - row_start,
                     palette(data[row_i, col_i].item()))

    # Debug output kept from the original implementation.
    print(ax.get_ylim())
|
39 |
+
|
40 |
+
|
41 |
+
def plot_bar_distribution(ax, x: torch.Tensor, bar_borders: torch.Tensor, predictions: torch.Tensor, **kwargs):
    """Plot per-point bar (bucket) predictions as a heat map over x.

    After squeezing, ``x`` must be 1-d, ``predictions`` 2-d with one row per x
    value and one column per bar, and ``bar_borders`` 1-d with one more entry
    than there are bars. Each prediction is divided by its bar's width before
    plotting. Extra keyword arguments go to ``heatmap_with_box_sizes``.
    """
    x = x.squeeze()
    predictions = predictions.squeeze()
    assert len(x.shape) == 1
    assert len(predictions.shape) == 2
    assert len(predictions) == len(x)
    assert len(bar_borders.shape) == 1
    assert len(bar_borders) - 1 == predictions.shape[1]

    lower_borders, upper_borders = bar_borders[:-1], bar_borders[1:]

    x, order = x.sort(0)
    print(x.shape, predictions.shape, order.shape)

    predictions = predictions[order] / (upper_borders - lower_borders)
    print(predictions.shape)

    # Column borders lie halfway between neighbouring (sorted) x values;
    # the outermost columns end at the first/last x value itself.
    midpoints = (x[1:] + x[:-1]) / 2
    x_starts = torch.cat([x[0].unsqueeze(0), midpoints])
    x_ends = torch.cat([midpoints, x[-1].unsqueeze(0)])

    heatmap_with_box_sizes(ax, predictions.T, x_starts, x_ends, lower_borders, upper_borders, **kwargs)
|
61 |
+
|
62 |
+
|
63 |
+
def plot_model_w_eval_pos(ax, model, x, y, single_eval_pos, softmax=False, min_max_y=None, **kwargs):
    """Run ``model`` on ``(x, y)`` and plot its bar predictions for ``x[single_eval_pos:]``.

    :param softmax: if true, a softmax over the last dimension is applied to the model output.
    :param min_max_y: optional ``(low, high)`` pair; when given, only the bars whose
        borders are located between these values (via ``torch.searchsorted``) are plotted.
    :param kwargs: forwarded to ``plot_bar_distribution``.
    """
    with torch.no_grad():
        model.eval()
        y_pred = model((x, y), single_eval_pos=single_eval_pos)
        if softmax:
            y_pred = y_pred.softmax(-1)

        borders = model.criterion.borders
        if min_max_y:
            lowest_bar = torch.searchsorted(borders, min_max_y[0])
            highest_bar = min(torch.searchsorted(borders, min_max_y[1]), len(borders))
            # n borders delimit n - 1 bars, hence the "- 1" on the prediction slice.
            y_pred = y_pred[..., lowest_bar:highest_bar - 1]
            borders = borders[lowest_bar:highest_bar]
        plot_bar_distribution(ax, x[single_eval_pos:], borders, y_pred, **kwargs)
|
77 |
+
|
78 |
+
|
79 |
+
def plot_model_and_orig_curve(ax, model, x, y, given_indices=None):
    """
    Plot the original curve, mark the points given to the model and overlay the
    model's predictive bar distribution as a heat map.

    :param ax: A standard pyplot ax
    :param model: A Transformer Model with `single_eval_pos`
    :param x: Input tensor indexed by sequence position along dim 0
        (presumably of shape (seq_len, 1, 1), as returned by the priors' get_batch - confirm)
    :param y: Target tensor indexed by sequence position along dim 0
        (presumably of shape (seq_len, 1) - confirm)
    :param given_indices: The indexes in y which should be given to the model (the training points).
        Defaults to ``[0]``.
    :return: None
    """
    if given_indices is None:
        # Avoid a shared mutable default; a list is required so that indexing keeps dim 0.
        given_indices = [0]

    # The given points are prepended as the training part; the full curve follows as the part to predict.
    x_winput = torch.cat([x[given_indices], x], 0)
    y_winput = torch.cat([y[given_indices], y], 0)

    ax.plot(x.squeeze(), y.squeeze(), color='grey')
    ax.plot(x.squeeze()[given_indices], y.squeeze()[given_indices], 'o', color='black')
    plot_model_w_eval_pos(ax, model, x_winput, y_winput, len(given_indices),
                          min_max_y=(y.min() - .3, y.max() + .3), softmax=True,
                          palette=sns.cubehelix_palette(start=2, rot=0, dark=0.4, light=1, as_cmap=True))
|
96 |
+
|
97 |
+
|
prior-fitting/priors/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import fast_gp, gp, ridge, stroke, fast_gp_mix, mlp, omniglot, binarized_regression, pyro
|
2 |
+
|
3 |
+
|
4 |
+
|
prior-fitting/priors/binarized_regression.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import fast_gp, fast_gp_mix
|
2 |
+
from .utils import get_batch_to_dataloader
|
3 |
+
|
4 |
+
def regression_prior_to_binary(get_batch_function):
    """Turn a regression ``get_batch`` function into a binary-classification one.

    The wrapped function forwards all arguments to ``get_batch_function`` and
    replaces the targets by Bernoulli samples drawn with probability ``sigmoid(y)``.
    With ``assert_on=True`` it additionally checks that the prior returned the
    very same object for ``y`` and ``target_y``.
    """

    def binarized_get_batch_function(*args, assert_on=False, **kwargs):
        x, y, target_y = get_batch_function(*args, **kwargs)
        if assert_on:
            assert y is target_y, "y == target_y is assumed by this function"
        binary_y = y.sigmoid().bernoulli()
        # The binarized labels serve both as input y and as target.
        return x, binary_y, binary_y

    return binarized_get_batch_function
|
14 |
+
|
15 |
+
|
16 |
+
# Data loader that draws from the GP prior and binarizes the targets.
Binarized_fast_gp_dataloader = get_batch_to_dataloader(
    regression_prior_to_binary(fast_gp.get_batch)
)
Binarized_fast_gp_dataloader.num_outputs = 1

# Same as above, but drawing from the mixture-of-GPs prior.
Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(
    regression_prior_to_binary(fast_gp_mix.get_batch)
)
Binarized_fast_gp_mix_dataloader.num_outputs = 1
|
prior-fitting/priors/fast_gp.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import gpytorch
|
6 |
+
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
from utils import default_device
|
9 |
+
from .utils import order_by_y, normalize_data, normalize_by_used_features_f, Binarize
|
10 |
+
|
11 |
+
|
12 |
+
# We will use the simplest form of GP model, exact inference
|
13 |
+
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """Return the multivariate normal defined by the mean and covariance at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
|
23 |
+
|
24 |
+
|
25 |
+
def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` with fixed hyperparameters.

    :param x: training inputs handed to the GP.
    :param y: training targets handed to the GP.
    :param hyperparameters: mapping with the keys "noise", "outputscale" and "lengthscale".
    :return: tuple ``(model, likelihood)``.
    """
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
    model = ExactGPModel(x, y, likelihood)

    def _filled_like(param, value):
        # Keep the shape/dtype/device of the existing parameter, overwrite its value.
        return torch.ones_like(param) * value

    model.likelihood.noise = _filled_like(model.likelihood.noise, hyperparameters["noise"])
    scaled_kernel = model.covar_module
    scaled_kernel.outputscale = _filled_like(scaled_kernel.outputscale, hyperparameters["outputscale"])
    scaled_kernel.base_kernel.lengthscale = _filled_like(scaled_kernel.base_kernel.lengthscale,
                                                         hyperparameters["lengthscale"])
    return model, likelihood
|
33 |
+
|
34 |
+
|
35 |
+
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None, equidistant_x=False, fix_x=None):
    """Sample a batch of data sets from the GP prior.

    :param batch_size: number of independent data sets.
    :param seq_len: number of points per data set.
    :param num_features: number of input features per point.
    :param device: device on which inputs and model live.
    :param hyperparameters: dict with "noise", "outputscale", "lengthscale" (and optionally
        "fast_computations"), or a ``(noise, outputscale, lengthscale)`` tuple/list.
        ``None`` uses .1 for all three.
    :param equidistant_x: use an equidistant grid on [0, 1] as inputs (requires ``num_features == 1``).
    :param fix_x: optional tensor of shape (seq_len, num_features) used as inputs for every data set;
        mutually exclusive with ``equidistant_x``.
    :return: ``(x, y, target_y)`` where ``x`` has shape (seq_len, batch_size, num_features) and
        ``y`` and ``target_y`` are the same sampled tensor.
    """
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0], "outputscale": hyperparameters[1], "lengthscale": hyperparameters[2]}
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))):
        assert not (equidistant_x and (fix_x is not None))
        if equidistant_x:
            assert num_features == 1
            x = torch.linspace(0,1.,seq_len).unsqueeze(0).repeat(batch_size,1).unsqueeze(-1).to(device)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        else:
            x = torch.rand(batch_size, seq_len, num_features, device=device)
        # The model gets no training targets; it is only evaluated in prior mode below.
        model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
        model.to(device)
        with gpytorch.settings.prior_mode(True):
            d = model(x)
            d = likelihood(d)
            sample = d.sample().transpose(0, 1)
    return x.transpose(0, 1), sample, sample  # x.shape = (T,B,H)
|
63 |
+
|
64 |
+
# TODO: Reintegrate this code
|
65 |
+
# num_features_used = num_features_used_sampler()
|
66 |
+
# prior_outputscale = prior_outputscale_sampler()
|
67 |
+
# prior_lengthscale = prior_lengthscale_sampler()
|
68 |
+
#
|
69 |
+
# x, sample = normalize_data(x), normalize_data(sample)
|
70 |
+
#
|
71 |
+
# if is_binary_classification:
|
72 |
+
# sample = (sample > torch.median(sample, dim=0)[0]).float()
|
73 |
+
#
|
74 |
+
# if normalize_by_used_features:
|
75 |
+
# x = normalize_by_used_features_f(x, num_features_used, num_features)
|
76 |
+
#
|
77 |
+
# # # if is_binary_classification and order_y:
|
78 |
+
# # # x, sample = order_by_y(x, sample)
|
79 |
+
# #
|
80 |
+
# # Append empty features if enabled
|
81 |
+
# x = torch.cat([x, torch.zeros((x.shape[0], x.shape[1], num_features - num_features_used), device=device)], -1)
|
82 |
+
|
83 |
+
# Data loader class that draws its batches from `get_batch` above.
DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1
|
85 |
+
|
86 |
+
def get_model_on_device(x,y,hyperparameters,device):
    """Build the GP via ``get_model`` and move the model to ``device``.

    :return: tuple ``(model, likelihood)``.
    """
    gp_model, gp_likelihood = get_model(x, y, hyperparameters)
    gp_model.to(device)
    return gp_model, gp_likelihood
|
90 |
+
|
91 |
+
|
92 |
+
@torch.no_grad()
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters=None, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
    """Evaluate the exact GP posterior at increasing numbers of training points.

    For every evaluated position ``t`` a GP is conditioned on the first ``t`` points of each
    data set and scored on point ``t`` (negative log-likelihood, or squared error of the
    predictive mean when ``use_mse`` is set).

    :param x: inputs, indexed (seq_len, batch, features) as returned by ``get_batch``.
    :param y: targets, indexed (seq_len, batch).
    :param y_non_noisy: unused; present so ``evaluate(*get_batch(...))`` works.
    :param use_mse: score with squared error instead of negative log-likelihood.
    :param hyperparameters: dict with "noise", "outputscale", "lengthscale" (and optionally
        "fast_computations"), or a ``(noise, outputscale, lengthscale)`` tuple/list as also
        accepted by ``get_batch``. ``None`` is treated as an empty dict.
    :param get_model_on_device: factory ``(x, y, hyperparameters, device) -> (model, likelihood)``.
    :param device: device passed to the factory.
    :param step_size: stride between evaluated positions.
    :param start_pos: first evaluated position (position 0 is never fitted; a loss of 0. is
        recorded for it when ``start_pos == 0``).
    :return: ``(all_losses, mean_losses, seconds)``: the per-data-set losses stacked over the
        evaluated positions, the mean loss per position, and the elapsed wall-clock time.
    """
    start_time = time.time()

    # Accept the same hyperparameter formats as get_batch; a tuple used to crash on `.get`.
    if hyperparameters is None:
        hyperparameters = {}
    elif isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0], "outputscale": hyperparameters[1], "lengthscale": hyperparameters[2]}

    losses_after_t = [.0] if start_pos == 0 else []
    all_losses_after_t = []

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)

            model.eval()
            f = model(x[t].unsqueeze(1))
            l = likelihood(f)
            # reshape(-1) instead of squeeze() so that a batch of size 1 stays one-dimensional.
            means = l.mean.reshape(-1)
            varis = l.covariance_matrix.reshape(-1)

            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                c = nn.MSELoss(reduction='none')
                ls = c(means, y[t])
            else:
                ls = -l.log_prob(y[t].unsqueeze(1))

            losses_after_t.append(ls.mean())
            all_losses_after_t.append(ls.flatten())
    return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
|
125 |
+
|
126 |
+
if __name__ == '__main__':
    # Smoke test: sample a batch from the prior and score the exact GP on it.
    # The hyperparameters are given as a dict because `evaluate` and `get_model`
    # look them up by key.
    hps = {"noise": .1, "outputscale": .1, "lengthscale": .1}
    for redo_idx in range(1):
        print(
            evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
|
prior-fitting/priors/fast_gp_mix.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import functools
|
3 |
+
import random
|
4 |
+
import math
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
import gpytorch
|
10 |
+
from botorch.models import SingleTaskGP
|
11 |
+
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
|
12 |
+
from botorch.fit import fit_gpytorch_model
|
13 |
+
from gpytorch.mlls import ExactMarginalLogLikelihood
|
14 |
+
from gpytorch.likelihoods import GaussianLikelihood
|
15 |
+
from gpytorch.priors.torch_priors import GammaPrior
|
16 |
+
from gpytorch.constraints import GreaterThan
|
17 |
+
|
18 |
+
|
19 |
+
from bar_distribution import BarDistribution
|
20 |
+
from utils import default_device
|
21 |
+
from .utils import get_batch_to_dataloader
|
22 |
+
from . import fast_gp
|
23 |
+
|
24 |
+
def get_model(x, y, hyperparameters: dict, sample=True):
    """Build a ``SingleTaskGP`` with Gamma priors on noise, lengthscale and outputscale.

    Prior parameters are read from ``hyperparameters`` (keys ``noise_concentration``,
    ``noise_rate``, ``nu``, ``lengthscale_concentration``, ``lengthscale_rate``,
    ``outputscale_concentration``, ``outputscale_rate``); missing keys fall back to the
    defaults written below.

    :param x: training inputs.
    :param y: training targets; a trailing output dimension is appended before use.
    :param sample: if true, return a model drawn via ``pyro_sample_from_prior``;
        otherwise return the unsampled model.
    :return: tuple ``(model, likelihood)``.
    """
    targets = y.unsqueeze(-1)
    # A throwaway model is built only to read off the batch shape used by the modules below.
    aug_batch_shape = SingleTaskGP(x, targets)._aug_batch_shape

    noise_prior = GammaPrior(hyperparameters.get('noise_concentration',1.1), hyperparameters.get('noise_rate',0.05))
    noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
    noise_constraint = GreaterThan(
        MIN_INFERRED_NOISE_LEVEL,
        transform=None,
        initial_value=noise_prior_mode,
    )
    likelihood = GaussianLikelihood(
        noise_prior=noise_prior,
        batch_shape=aug_batch_shape,
        noise_constraint=noise_constraint,
    )

    matern_kernel = gpytorch.kernels.MaternKernel(
        nu=hyperparameters.get('nu',2.5),
        ard_num_dims=x.shape[-1],
        batch_shape=aug_batch_shape,
        lengthscale_prior=gpytorch.priors.GammaPrior(hyperparameters.get('lengthscale_concentration',3.0), hyperparameters.get('lengthscale_rate',6.0)),
    )
    covar_module = gpytorch.kernels.ScaleKernel(
        matern_kernel,
        batch_shape=aug_batch_shape,
        outputscale_prior=gpytorch.priors.GammaPrior(hyperparameters.get('outputscale_concentration',.5), hyperparameters.get('outputscale_rate',0.15)),
    )
    model = SingleTaskGP(x, targets, covar_module=covar_module, likelihood=likelihood)

    likelihood = model.likelihood
    if sample:
        sampled_model = model.pyro_sample_from_prior()
        return sampled_model, sampled_model.likelihood

    assert not(hyperparameters.get('sigmoid', False)) and not(hyperparameters.get('y_minmax_norm', False)), "Sigmoid and y_minmax_norm can only be used to sample models..."
    return model, likelihood
|
56 |
+
|
57 |
+
|
58 |
+
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              batch_size_per_gp_sample=None, num_outputs=1,
              fix_to_range=None, equidistant_x=False):
    '''
    Sample a batch of sequences from a mixture of GP priors.

    This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates over
    a mixture of GP priors: every group of `batch_size_per_gp_sample` sequences shares one GP whose hyper-parameters are
    drawn in `get_model`, and which only depends on a uniformly drawn number of leading input dimensions
    (the remaining dimensions are set to zero and the active ones are rescaled by `num_features / num_of_dims`).

    :param batch_size: number of sequences to return.
    :param seq_len: number of points per sequence.
    :param num_features: number of input dimensions of x.
    :param device: device on which inputs are created and the GP is evaluated.
    :param hyperparameters: optional dict; read here for 'fast_computations', 'y_minmax_norm' and 'sigmoid',
        and passed on to `get_model`.
    :param batch_size_per_gp_sample: number of sequences sharing one sampled GP, defaults to max(batch_size // 10, 1).
    :param num_outputs: must be 1.
    :param fix_to_range: optional (low, high); if given, twice as many candidates are drawn and only sequences whose
        values all lie in [low, high) are kept (rejection sampling, retried until enough candidates are accepted).
    :param equidistant_x: use an equidistant grid on [0, 1] instead of uniform random inputs (requires num_features == 1).
    :return: (x, sample, target_sample) with x of shape (T, B, H) and sample of shape (T, B); target_sample is sample.
    '''
    assert num_outputs == 1
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))):
        batch_size_per_gp_sample = (batch_size_per_gp_sample or max(batch_size // 10,1))
        assert batch_size % batch_size_per_gp_sample == 0

        # With a target range we draw twice as many candidates so that rejection has something to choose from.
        candidate_factor = 2**(fix_to_range is not None)
        total_num_candidates = batch_size * candidate_factor
        num_candidates = batch_size_per_gp_sample * candidate_factor
        if equidistant_x:
            assert num_features == 1
            x = torch.linspace(0,1.,seq_len, device=device).unsqueeze(0).repeat(total_num_candidates,1).unsqueeze(-1)
        else:
            x = torch.rand(total_num_candidates, seq_len, num_features, device=device)

        total_dims = num_features
        xs = []  # inputs that belong to the entries of `samples`, in the same order
        samples = []
        for i in range(0,total_num_candidates,num_candidates):
            # Number of input dimensions this GP depends on, drawn uniformly from 1..num_features.
            num_of_dims = random.randint(1, total_dims)
            x_chunk = x[i:i + num_candidates]  # a view, so the in-place edits below also update x
            model, likelihood = get_model(x_chunk[..., :num_of_dims],
                                          torch.zeros(num_candidates, x.shape[1], device=device), hyperparameters)
            x_chunk[..., num_of_dims:] = 0
            x_chunk[..., :num_of_dims] *= total_dims/num_of_dims
            model.to(device)
            successful_sample = 0
            throwaway_share = 0.
            while successful_sample < 1:
                with gpytorch.settings.prior_mode(True):
                    # The kernel was built with ard_num_dims=num_of_dims, so only the active columns are passed in.
                    d = model(x_chunk[..., :num_of_dims])
                    d = likelihood(d)
                    sample = d.sample() # bs_per_gp_s x T
                if hyperparameters.get('y_minmax_norm'):
                    # keepdim so that the per-sequence extrema broadcast over the time axis
                    sample_min = sample.min(1, keepdim=True)[0]
                    sample_max = sample.max(1, keepdim=True)[0]
                    sample = ((sample - sample_min) / (sample_max - sample_min))
                if hyperparameters.get('sigmoid'):
                    sample = sample.sigmoid()
                if fix_to_range is None:
                    xs.append(x_chunk)
                    samples.append(sample.transpose(0, 1))
                    successful_sample = True
                    continue
                smaller_mask = sample < fix_to_range[0]
                larger_mask = sample >= fix_to_range[1]
                in_range_mask = ~ (smaller_mask | larger_mask).any(1)
                throwaway_share += (~in_range_mask[:batch_size_per_gp_sample]).sum()/batch_size_per_gp_sample
                if in_range_mask.sum() < batch_size_per_gp_sample:
                    # Not enough candidates inside the range: count the failure and draw again.
                    successful_sample -= 1
                    if successful_sample < -100:
                        print("Please change hyper-parameters (e.g. decrease outputscale_mean) it"
                              "seems like the range is set to tight for your hyper-parameters.")
                    continue

                # Keep the accepted inputs together with their samples so that x and sample stay aligned.
                xs.append(x_chunk[in_range_mask][:batch_size_per_gp_sample])
                sample = sample[in_range_mask][:batch_size_per_gp_sample]
                samples.append(sample.transpose(0, 1))
                successful_sample = True
        if random.random() < .01:
            print('throwaway share', throwaway_share/(batch_size//batch_size_per_gp_sample))

        sample = torch.cat(samples, 1)
        x = torch.cat(xs, 0)
        # TODO think about enabling the line below
        #sample = sample - sample[0, :].unsqueeze(0).expand(*sample.shape)
        x = x.transpose(0,1)
        assert x.shape[:2] == sample.shape[:2]
        target_sample = sample
    return x, sample, target_sample # x.shape = (T,B,H)
|
138 |
+
|
139 |
+
|
140 |
+
class DataLoader(get_batch_to_dataloader(get_batch)):
    """Data loader over ``get_batch`` with an MSE-based validation helper."""
    num_outputs = 1

    @torch.no_grad()
    def validate(self, model, step_size=1, start_pos=0):
        """Return the MSE of the predicted means for a range of evaluation positions.

        Only models whose criterion is a ``BarDistribution`` are evaluated; for any
        other criterion the placeholder value ``123.`` is returned. The model is put
        into eval mode for the evaluation and back into train mode afterwards.
        """
        if not isinstance(model.criterion, BarDistribution):
            return 123.

        (x, y), target_y = self.gbm(**self.get_batch_kwargs, fuse_x_y=self.fuse_x_y)
        model.eval()
        mse = nn.MSELoss()
        position_losses = []
        for eval_pos in range(start_pos, len(x), step_size):
            logits = model((x, y), single_eval_pos=eval_pos)
            predicted_means = model.criterion.mean(logits)  # num_evals x batch_size
            position_losses.append(mse(predicted_means[0], target_y[eval_pos]))
        model.train()
        return torch.stack(position_losses)
|
157 |
+
|
158 |
+
|
159 |
+
@torch.enable_grad()
def get_fitted_model(x, y, hyperparameters, device):
    """Build the (unsampled) GP for ``x``/``y`` and fit it via ``fit_gpytorch_model``.

    Gradients are explicitly enabled because callers may run under ``no_grad``.
    Returns ``(model, likelihood)`` with the model moved to ``device``.
    """
    gp, gp_likelihood = get_model(x, y, hyperparameters, sample=False)
    gp.to(device)
    marginal_log_likelihood = ExactMarginalLogLikelihood(gp_likelihood, gp)
    gp.train()
    fit_gpytorch_model(marginal_log_likelihood)
    return gp, gp_likelihood
|
170 |
+
|
171 |
+
|
172 |
+
# Reuse fast_gp.evaluate, but let it obtain its GP through this module's get_fitted_model.
evaluate = functools.partial(fast_gp.evaluate, get_model_on_device=get_fitted_model)
|
173 |
+
|
174 |
+
def get_mcmc_model(x, y, hyperparameters, device, num_samples, warmup_steps):
    """Run NUTS over the GP hyper-parameters and load the drawn samples into the model.

    :param x: training inputs.
    :param y: training targets.
    :param hyperparameters: passed on to ``get_model``.
    :param device: device the data and the model are moved to.
    :param num_samples: number of MCMC samples to keep.
    :param warmup_steps: number of NUTS warm-up steps.
    :return: ``(model, likelihood)`` with the model in eval mode and the MCMC
        samples loaded through ``pyro_load_from_samples``.
    """
    from pyro.infer.mcmc import NUTS, MCMC
    import pyro
    x = x.to(device)
    y = y.to(device)
    model, likelihood = get_model(x, y, hyperparameters, sample=False)
    model.to(device)

    def pyro_model(x, y):
        sampled_model = model.pyro_sample_from_prior()
        output = sampled_model.likelihood(sampled_model(x))
        # Register the observed targets; without this site NUTS would only see the
        # hyper-parameter priors and never condition on the data.
        pyro.sample("obs", output, obs=y)
        return y

    nuts_kernel = NUTS(pyro_model, adapt_step_size=True)
    mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
    mcmc_run.run(x, y)
    model.pyro_load_from_samples(mcmc_run.get_samples())
    model.eval()
    return model, likelihood
|
200 |
+
# output = model(x[-1].unsqueeze(1).repeat(1, num_samples 1))
|
201 |
+
# return output.mean
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
def get_mean_logdensity(dists, x: torch.Tensor, full_range=None):
    """Log density of ``x`` under an equally weighted mixture of Normal components.

    Every element of the ``mean``/``variance`` of each distribution in ``dists``
    becomes one Normal component; the result is ``logsumexp`` of the component
    log densities minus ``log(number of components)``.

    :param dists: iterable of objects exposing ``mean`` and ``variance`` tensors.
    :param x: point at which the density is evaluated.
    :param full_range: optional ``(low, high)``; if given, each component is
        renormalised by its probability mass inside that range.
    :return: scalar tensor with the mixture log density.
    """
    def _flat(values):
        # squeeze() turns a single-element tensor into a 0-d tensor, which torch.cat rejects.
        values = values.squeeze()
        return values.unsqueeze(0) if values.dim() == 0 else values

    means = torch.cat([_flat(d.mean) for d in dists], 0)
    variances = torch.cat([_flat(d.variance) for d in dists], 0)
    assert len(means.shape) == 1 and len(variances.shape) == 1
    dist = torch.distributions.Normal(means, variances.sqrt())
    logprobs = dist.log_prob(x)
    if full_range is not None:
        low = torch.as_tensor(full_range[0], dtype=means.dtype, device=means.device)
        high = torch.as_tensor(full_range[1], dtype=means.dtype, device=means.device)
        # probability mass each component places inside [low, high]
        used_weight = 1. - (dist.cdf(low) + (1. - dist.cdf(high)))
        if torch.isinf(-torch.log(used_weight)).any() or torch.isinf(torch.log(used_weight)).any():
            print('factor is inf', -torch.log(used_weight))
        logprobs = logprobs - torch.log(used_weight)
    assert len(logprobs.shape) == 1
    return torch.logsumexp(logprobs, 0) - math.log(len(logprobs))
|
221 |
+
|
222 |
+
|
223 |
+
def evaluate_(x, y, y_non_noisy, hyperparameters=None, device=default_device, num_samples=100, warmup_steps=300,
              full_range=None, min_seq_len=0, use_likelihood=False):
    """Evaluate the MCMC-based GP baseline position by position.

    For every position ``t`` and every batch element, an MCMC model is obtained from
    the first ``t`` points (``get_mcmc_model``) and the negative mean log density of
    ``y[t]`` under its prediction at ``x[t]`` is recorded.

    :param x: inputs, indexed as ``x[t, b]``.
    :param y: targets, indexed as ``y[t, b]``.
    :param y_non_noisy: unused; kept for a signature compatible with the other evaluators.
    :param hyperparameters: optional dict, read for 'fast_computations' and passed to ``get_mcmc_model``.
    :param device: device the data is moved to.
    :param num_samples: number of MCMC samples per fit.
    :param warmup_steps: number of MCMC warm-up steps per fit.
    :param full_range: optional ``(low, high)`` passed to ``get_mean_logdensity``.
    :param min_seq_len: first evaluated position (at least 1); with 0 a leading ``0.`` loss is reported.
    :param use_likelihood: if true, pass the prediction through the likelihood before scoring.
    :return: ``(losses_after_t, elapsed_seconds, all_losses)``.
    """
    # The default is None, so fall back to an empty dict before any .get (as get_batch does).
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
        x = x.to(device)
        y = y.to(device)
        start_time = time.time()
        losses_after_t = [.0] if min_seq_len == 0 else []
        all_losses = []

        for t in range(max(min_seq_len,1), len(x)):
            loss_sum = 0.
            step_losses = []
            start_step = time.time()
            for b_i in range(x.shape[1]):
                # `done` counts failed attempts downwards; a success sets it to 1 and ends the retry loop.
                done = 0
                while done < 1:
                    try:
                        model, likelihood = get_mcmc_model(x[:t, b_i], y[:t, b_i], hyperparameters, device, num_samples=num_samples, warmup_steps=warmup_steps)
                        model.eval()

                        with torch.no_grad():
                            dists = model(x[t, b_i, :].unsqueeze(
                                0))  # TODO check what is going on here! Does the GP interpret the input wrong?
                            if use_likelihood:
                                dists = likelihood(dists)
                            l = -get_mean_logdensity([dists], y[t, b_i], full_range)
                        done = 1
                    except Exception as e:
                        done -= 1
                        print('Trying again..')
                        print(traceback.format_exc())
                        print(e)
                    finally:
                        if done < -10:
                            print('Too many retries...')
                            exit()

                step_losses.append(l.item())
                print(f'current average loss at step {t} is {sum(step_losses)/len(step_losses)} with {(time.time()-start_step)/len(step_losses)} s per eval.')
                loss_sum += l

            loss_sum /= x.shape[1]
            all_losses.append(step_losses)
            print(f'loss after step {t} is {loss_sum}')
            losses_after_t.append(loss_sum)
            print(f'losses so far {torch.tensor(losses_after_t)}')
        return torch.tensor(losses_after_t), time.time() - start_time, all_losses
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == '__main__':
    import argparse

    def _str_to_bool(value):
        # argparse's type=bool treats every non-empty string (including "False") as True.
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--seq_len', type=int)
    parser.add_argument('--min_seq_len', type=int, default=0)
    parser.add_argument('--warmup_steps', type=int)
    parser.add_argument('--num_samples', type=int)
    parser.add_argument('--min_y', type=int)
    parser.add_argument('--max_y', type=int)
    parser.add_argument('--dim', type=int, default=1)
    parser.add_argument('--use_likelihood', default=True, type=_str_to_bool)
    parser.add_argument('--device', default='cpu')
    # The misspelled flag is kept for existing callers; the correctly spelled one is an alias.
    parser.add_argument('--outputscale_concentraion', '--outputscale_concentration',
                        dest='outputscale_concentraion', default=2., type=float)
    parser.add_argument('--noise_concentration', default=1.1, type=float)
    parser.add_argument('--noise_rate', default=.05, type=float)

    args = parser.parse_args()

    print('min_y:', args.min_y)
    full_range = (None if args.min_y is None else (args.min_y,args.max_y))

    hps = {'outputscale_concentration': args.outputscale_concentraion, 'noise_concentration': args.noise_concentration,
           'noise_rate': args.noise_rate, 'fast_computations': (False,False,False)}
    x, y, _ = get_batch(args.batch_size, args.seq_len, args.dim, fix_to_range=full_range, hyperparameters=hps)
    print('RESULT:', evaluate_(x, y, y, device=args.device, warmup_steps=args.warmup_steps,
                               num_samples=args.num_samples, full_range=full_range, min_seq_len=args.min_seq_len,
                               hyperparameters=hps, use_likelihood=args.use_likelihood))
|
306 |
+
|
307 |
+
|
prior-fitting/priors/gp.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from sklearn.gaussian_process import GaussianProcessRegressor
|
8 |
+
from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel
|
9 |
+
from .utils import get_batch_to_dataloader
|
10 |
+
|
11 |
+
|
12 |
+
# Default RBF length scale: used by get_gp when no length_scale is passed and as evaluate's default.
length_scale_sampling_gp = .6
|
13 |
+
|
14 |
+
def get_gp(length_scale=None):
    """Return a ``GaussianProcessRegressor`` with a fixed-length-scale RBF kernel.

    A falsy ``length_scale`` falls back to ``length_scale_sampling_gp``. The kernel
    bounds are 'fixed' and no optimizer is set.
    """
    rbf_kernel = RBF(
        length_scale=length_scale or length_scale_sampling_gp,
        length_scale_bounds='fixed',
    )
    return GaussianProcessRegressor(kernel=rbf_kernel, random_state=0, optimizer=None)
|
18 |
+
|
19 |
+
|
20 |
+
def get_batch(batch_size, seq_len, num_features, noisy_std=None):
    """Sample a batch of functions from the GP prior returned by ``get_gp``.

    :param batch_size: number of sequences.
    :param seq_len: number of points per sequence.
    :param num_features: number of input dimensions.
    :param noisy_std: passed to ``get_gp`` as its ``length_scale`` argument.
    :return: ``(x, y, y)`` with ``x`` of shape (seq_len, batch_size, num_features)
        and ``y`` of shape (seq_len, batch_size); inputs are uniform in [0, 1).
    """
    x_t = torch.rand(batch_size, seq_len, num_features)

    gpr = get_gp(noisy_std)
    y_t = torch.zeros(batch_size, seq_len)

    for i, x_i in enumerate(x_t):
        # Integer seeds must lie in [0, 2**32 - 1]; random.randint includes both end points.
        seed = random.randint(0, 2 ** 32 - 1)
        drawn = gpr.sample_y(x_i.numpy(), random_state=seed).squeeze()
        y_t[i] = torch.as_tensor(drawn, dtype=y_t.dtype)
    x, y = x_t.transpose(0, 1), y_t.transpose(0, 1)
    return x, y, y
|
36 |
+
|
37 |
+
|
38 |
+
# Data loader class built from this module's get_batch; it yields a single regression output.
DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1
|
40 |
+
|
41 |
+
def evaluate(x, y, y_non_noisy, use_mse=False, length_scale=length_scale_sampling_gp):
    """Score an exact GP fit on growing prefixes of every sequence.

    For each position ``t >= 1`` and each batch element, a GP from ``get_gp`` is fit
    on the first ``t`` points and scored on point ``t`` with either MSE or the full
    Gaussian negative log likelihood.

    :return: ``(losses_after_t, elapsed_seconds)``; the first loss entry is ``0.``.
    """
    began = time.time()
    num_sequences = x.shape[1]
    criterion = nn.MSELoss() if use_mse else nn.GaussianNLLLoss(full=True)
    losses_after_t = [.0]
    for t in range(1, len(x)):
        running_loss = 0.
        for b_i in range(num_sequences):
            fitted = get_gp(length_scale).fit(x[:t, b_i], y[:t, b_i])
            means, stds = fitted.predict(x[t, b_i].unsqueeze(0), return_std=True)
            assert len(means) == 1 == len(stds)
            prediction = torch.tensor(means)
            target = y[t, b_i].unsqueeze(-1)
            if use_mse:
                running_loss += criterion(prediction, target)
            else:
                running_loss += criterion(prediction, target, var=torch.tensor(stds) ** 2)

        losses_after_t.append(running_loss / num_sequences)

    return torch.tensor(losses_after_t), time.time()-began
|
63 |
+
|
64 |
+
if __name__ == '__main__':
    # Compare the true sampling length scale against slightly larger and smaller ones.
    ls = .1
    for alpha in {ls, ls * 1.1, ls * .9}:
        print(alpha)
        for redo_idx in range(1):
            batch = get_batch(1000, 10, noisy_std=ls, num_features=10)
            print(
                evaluate(*batch, use_mse=False, length_scale=alpha))
|
prior-fitting/priors/mlp.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from utils import default_device
|
9 |
+
from .utils import get_batch_to_dataloader
|
10 |
+
from .utils import order_by_y, normalize_data, normalize_by_used_features_f, Binarize
|
11 |
+
from .utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, scaled_beta_sampler_f, uniform_int_sampler_f
|
12 |
+
|
13 |
+
|
14 |
+
def canonical_pre_processing(x, canonical_args):
    """Standardise categorical feature columns of ``x`` in place.

    ``canonical_args`` holds one entry per feature (last axis of ``x``): ``None``
    leaves the column untouched, an integer ``k`` standardises the column with the
    mean and standard deviation of ``0..k-1``. Returns the modified ``x``.
    """
    assert x.shape[2] == len(canonical_args)
    class_values = [
        None if num_classes is None else torch.arange(num_classes).float()
        for num_classes in canonical_args
    ]
    for feature_dim, values in enumerate(class_values):
        if values is None:
            continue
        column = x[:, :, feature_dim]
        x[:, :, feature_dim] = (column - values.mean()) / values.std()
    return x
|
21 |
+
|
22 |
+
|
23 |
+
# Default values that make up the default `hyperparameters` tuple of get_batch below.
DEFAULT_NUM_LAYERS = 2
DEFAULT_HIDDEN_DIM = 100
DEFAULT_ACTIVATION_MODULE = torch.nn.ReLU
DEFAULT_INIT_STD = .1
DEFAULT_HIDDEN_NOISE_STD = .1
DEFAULT_FIXED_DROPOUT = 0.
DEFAULT_IS_BINARY_CLASSIFICATION = False
|
30 |
+
|
31 |
+
|
32 |
+
class GaussianNoise(nn.Module):
    """Module that adds zero-mean Gaussian noise with standard deviation ``std``."""

    def __init__(self, std):
        super().__init__()
        self.std = std

    def forward(self, x):
        # Fresh noise is drawn on every call, with the same shape as the input.
        noise = torch.normal(torch.zeros_like(x), self.std)
        return x + noise
|
39 |
+
|
40 |
+
|
41 |
+
def causes_sampler_f(num_causes_sampler):
    """Draw per-cause means and standard deviations.

    The number of causes comes from ``num_causes_sampler()``. Means are standard
    normal; each standard deviation is the absolute value of a standard normal
    draw multiplied by the matching mean.
    """
    count = num_causes_sampler()
    means = np.random.normal(0, 1, (count))
    scale = np.random.normal(0, 1, (count))
    return means, np.abs(scale * means)
|
46 |
+
|
47 |
+
def categorical_features_sampler(max_features):
    """Sample a set of categorical features.

    Returns ``(features, ordinal)``: ``features`` holds one array of uniform random
    numbers per categorical feature (its length is the sampled number of classes)
    and ``ordinal`` the matching flags saying whether the feature is ordinal.
    """
    num_categorical_features_sampler = scaled_beta_sampler_f(0.5, .8, max_features, 0)
    classes_per_feature_sampler = scaled_beta_sampler_f(0.1, 2.0, 10, 1)
    classes_per_feature_sampler_ordinal = scaled_beta_sampler_f(0.1, 2.0, 200, 1)

    features, ordinal = [], []
    for _ in range(0, num_categorical_features_sampler()):
        feature_is_ordinal = random.choice([True, False])
        ordinal.append(feature_is_ordinal)
        # ordinal features draw their class count from the wider sampler
        if feature_is_ordinal:
            classes = classes_per_feature_sampler_ordinal()
        else:
            classes = classes_per_feature_sampler()
        features.append(np.random.rand(classes))
    return features, ordinal
|
60 |
+
|
61 |
+
|
62 |
+
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=(DEFAULT_NUM_LAYERS, DEFAULT_HIDDEN_DIM, DEFAULT_ACTIVATION_MODULE, DEFAULT_INIT_STD, DEFAULT_HIDDEN_NOISE_STD, DEFAULT_FIXED_DROPOUT, DEFAULT_IS_BINARY_CLASSIFICATION),
              batch_size_per_gp_sample=None, num_outputs=1, canonical_args=None, sampling='normal'):
    """Sample a batch of data sets from randomly initialised MLPs.

    ``batch_size // batch_size_per_gp_sample`` MLPs are created and each one generates
    ``batch_size_per_gp_sample`` data sets of ``seq_len`` points.

    :param batch_size: number of data sets in the batch.
    :param seq_len: number of points per data set.
    :param num_features: width of the returned x; unused features are zero-padded.
    :param device: device on which the networks and tensors are created.
    :param hyperparameters: sequence of 17 entries, unpacked below (samplers are called without arguments).
        NOTE(review): the default tuple only has 7 entries, so the unpacking fails unless the caller
        passes a full sequence.
    :param batch_size_per_gp_sample: data sets generated per MLP, defaults to max(batch_size // 8, 1).
    :param num_outputs: must be 1.
    :param canonical_args: unused here.
    :param sampling: 'normal' or 'uniform', the distribution of the network inputs.
    :return: ``(x, y, y)``; the per-data-set outputs are concatenated along dim 1.
    """
    assert num_outputs == 1
    num_layers_sampler, hidden_dim_sampler, activation_module, init_std_sampler, noise_std_sampler, dropout_prob_sampler, is_binary_classification, num_features_used_sampler, causes_sampler, is_causal, pre_sample_causes, pre_sample_weights, y_is_effect, order_y, normalize_by_used_features, categorical_features_sampler, nan_prob = hyperparameters

    sample_batch_size = batch_size

    # At least one data set per model, so small batches do not lead to a modulo by zero below.
    batch_size_per_gp_sample = batch_size_per_gp_sample or max(sample_batch_size // 8, 1)
    assert sample_batch_size % batch_size_per_gp_sample == 0, 'Please choose a batch_size divisible by batch_size_per_gp_sample.'
    num_models = sample_batch_size // batch_size_per_gp_sample
    # standard kaiming uniform init currently...

    def get_model():
        class MLP(torch.nn.Module):
            def __init__(self):
                super(MLP, self).__init__()

                # All architecture choices are drawn once per model.
                self.dropout_prob = dropout_prob_sampler()
                self.noise_std = noise_std_sampler()
                self.init_std = init_std_sampler()
                self.num_features_used = num_features_used_sampler()
                self.categorical_features, self.categorical_features_is_ordinal = categorical_features_sampler(self.num_features_used)
                if is_causal:
                    # (means, stds) of the causes, repeated along the sequence dimension
                    self.causes = causes_sampler()
                    self.causes = (torch.tensor(self.causes[0], device=device).unsqueeze(0).unsqueeze(0).tile((seq_len,1,1)), torch.tensor(self.causes[1], device=device).unsqueeze(0).unsqueeze(0).tile((seq_len,1,1)))
                    self.num_causes = self.causes[0].shape[2]
                else:
                    self.num_causes = self.num_features_used
                self.num_layers = num_layers_sampler()
                self.hidden_dim = hidden_dim_sampler()

                if is_causal:
                    # x and y are read from hidden activations, so there must be enough of them
                    self.hidden_dim = max(self.hidden_dim, 2 * self.num_features_used+1)

                #print('cat', self.categorical_features, self.categorical_features_is_ordinal, self.num_features_used)

                assert(self.num_layers > 2)

                self.layers = [nn.Linear(self.num_causes, self.hidden_dim, device=device)]
                # Each further layer is activation -> linear -> noise; the last one maps to num_outputs.
                self.layers += [module for layer_idx in range(self.num_layers-1) for module in [
                    nn.Sequential(*[
                        activation_module()
                        , nn.Linear(self.hidden_dim, num_outputs if layer_idx == self.num_layers - 2 else self.hidden_dim, device=device)
                        , GaussianNoise(torch.abs(torch.normal(torch.zeros((num_outputs if layer_idx == self.num_layers - 2 else self.hidden_dim),device=device), self.noise_std))) if pre_sample_weights else GaussianNoise(self.noise_std)
                    ])
                ]]
                self.layers = nn.Sequential(*self.layers)

                self.binarizer = Binarize() if is_binary_classification else lambda x : x

                # Initialize Model parameters
                for i, p in enumerate(self.layers.parameters()):
                    # the first parameter tensor is never dropped out
                    dropout_prob = self.dropout_prob if i > 0 else 0.0
                    nn.init.normal_(p, std=self.init_std / (1. - dropout_prob))
                    with torch.no_grad():
                        p *= torch.bernoulli(torch.zeros_like(p) + 1. - dropout_prob)

            def forward(self):
                # Draw the network inputs ("causes") of shape (seq_len, 1, num_causes).
                if sampling == 'normal':
                    if is_causal and pre_sample_causes:
                        causes = torch.normal(self.causes[0], self.causes[1].abs()).float()
                    else:
                        causes = torch.normal(0., 1., (seq_len, 1, self.num_causes), device=device).float()
                elif sampling == 'uniform':
                    causes = torch.rand((seq_len, 1, self.num_causes), device=device)
                else:
                    raise ValueError(f'Sampling is set to invalid setting: {sampling}.')

                # Keep the activations after every layer; the inputs and the first layer's output are dropped.
                outputs = [causes]
                for layer in self.layers:
                    outputs.append(layer(outputs[-1]))
                outputs = outputs[2:]

                if is_causal:
                    # x (and optionally y) are randomly chosen units among all kept activations
                    outputs_flat = torch.cat(outputs, -1)
                    random_perm = torch.randperm(outputs_flat.shape[-1]-1, device=device)
                    random_idx_y = [-1] if y_is_effect else random_perm[0:num_outputs]
                    y = outputs_flat[:, :, random_idx_y]

                    random_idx = random_perm[num_outputs:num_outputs + self.num_features_used]
                    x = outputs_flat[:, :, random_idx]
                else:
                    y = outputs[-1][:, :, :]
                    x = causes

                if len(self.categorical_features) > 0:
                    # Discretise randomly chosen feature columns by counting exceeded thresholds.
                    random_perm = torch.randperm(x.shape[-1], device=device)
                    for i, (categorical_feature, is_ordinal) in enumerate(zip(self.categorical_features, self.categorical_features_is_ordinal)):
                        idx = random_perm[i]
                        temp = normalize_data(x[:, :, idx])
                        if is_ordinal:
                            x[:, :, idx] = (temp > (torch.tensor(categorical_feature, device=device, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1) - 0.5)).sum(axis=0)
                        else:
                            x[:, :, idx] = (temp > (torch.tensor(categorical_feature, device=device,
                                                                 dtype=torch.float32).unsqueeze(-1).unsqueeze(-1) - 0.5)).sum(
                                axis=0) * (127 * len(categorical_feature) + 1) % len(categorical_feature)

                # if nan_prob > 0:
                #    nan_value = random.choice([-999,-1,0, -10])
                #    x[torch.rand(x.shape, device=device) > (1-nan_prob)] = nan_value

                x, y = normalize_data(x), normalize_data(y)

                # Binarize output if enabled
                y = self.binarizer(y)

                if normalize_by_used_features:
                    x = normalize_by_used_features_f(x, self.num_features_used, num_features)

                if is_binary_classification and order_y:
                    x, y = order_by_y(x,y)

                # Append empty features if enabled
                x = torch.cat([x, torch.zeros((x.shape[0], x.shape[1], num_features - self.num_features_used), device=device)], -1)

                return x, y

        return MLP()

    models = [get_model() for _ in range(num_models)]

    sample = sum([[model() for _ in range(0,batch_size_per_gp_sample)] for model in models],[])

    x, y = zip(*sample)
    y = torch.cat(y, 1).squeeze(-1).detach()
    x = torch.cat(x, 1).detach()

    return x, y, y
|
204 |
+
|
205 |
+
|
206 |
+
# Data loader class built from this module's get_batch; it yields a single output per point.
DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1
|
208 |
+
|
prior-fitting/priors/omniglot.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from torch.utils import data
|
5 |
+
from torchvision import transforms
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from datasets import omniglotNshot
|
9 |
+
import utils
|
10 |
+
|
11 |
+
|
12 |
+
def _compute_maxtranslations(single_image_tensor, dim, background):
    """Return how far a 2-d image can be shifted along ``dim`` without losing content.

    The result is ``[-first, size - last - 1]`` where ``first``/``last`` are the first
    and last indices along ``dim`` whose line is not entirely ``background``.
    """
    assert len(single_image_tensor.shape) == 2
    is_background_line = (single_image_tensor == background).all(dim=1 - dim)
    content_rows = (is_background_line == False).nonzero()
    first, last = content_rows[0], content_rows[-1]
    room_after = single_image_tensor.shape[dim] - last - 1
    return torch.cat([-first, room_after]).cpu().tolist()
|
17 |
+
|
18 |
+
|
19 |
+
def compute_maxtranslations_x_y(single_image_tensor, background):
    """Return the allowed shift ranges along dim 1 and dim 0 (in that order)."""
    range_dim1 = _compute_maxtranslations(single_image_tensor, 1, background)
    range_dim0 = _compute_maxtranslations(single_image_tensor, 0, background)
    return range_dim1, range_dim0
|
22 |
+
|
23 |
+
|
24 |
+
def translate(img, trans_x, trans_y):
    """Shift ``img`` by ``(trans_x, trans_y)`` with nearest-neighbour interpolation and fill value 0."""
    # affine is applied to the image with a leading dimension added, which is removed again afterwards
    with_leading_dim = img.unsqueeze(0)
    shifted = transforms.functional.affine(
        with_leading_dim,
        angle=0.0,
        translate=[trans_x, trans_y],
        scale=1.0,
        interpolation=transforms.InterpolationMode.NEAREST,
        shear=[0.0, 0.0],
        fill=0.,
    )
    return shifted.squeeze(0)
|
28 |
+
|
29 |
+
def translate_omniglot(image_tensor, background=0.):
    """Randomly shift every image in ``image_tensor`` within its content-preserving range.

    The last two dimensions are treated as one image; all leading dimensions are
    flattened for the loop and restored in the returned tensor. Images are written
    back into the flattened view of the input.
    """
    height_width = image_tensor.shape[-2:]
    flat_image_tensor = image_tensor.view(-1, *height_width)
    for index, image in enumerate(flat_image_tensor):
        max_x, max_y = compute_maxtranslations_x_y(image, background)
        shift_x = random.randint(*max_x)
        shift_y = random.randint(*max_y)
        flat_image_tensor[index] = translate(image, shift_x, shift_y)
    return flat_image_tensor.view(*image_tensor.shape)
|
35 |
+
|
36 |
+
|
37 |
+
class DataLoader(data.DataLoader):
    """Few-shot Omniglot loader yielding ``((images, labels), target_labels)`` per step."""

    def __init__(self, num_steps, batch_size, seq_len, num_features, num_outputs, num_classes_used=1200, fuse_x_y=False, train=True, translations=True, jonas_style=False):
        # TODO position before last is predictable by counting..
        # Must stay the first statement: it stores exactly the constructor arguments on self.
        utils.set_locals_in_self(locals())
        assert not fuse_x_y, 'So far don\' support fusing.'
        # images are square, so num_features has to be a perfect square
        imgsz = math.isqrt(num_features)
        assert imgsz * imgsz == num_features
        # all positions but the last (the query) are split evenly over the classes
        k_shot = (seq_len - 1) // num_outputs
        assert k_shot * num_outputs == seq_len - 1
        shared_kwargs = dict(batchsz=batch_size, n_way=num_outputs, k_shot=k_shot, k_query=1, imgsz=imgsz)
        if jonas_style:
            self.d = omniglotNshot.OmniglotNShotJonas('omniglot', **shared_kwargs)
        else:
            self.d = omniglotNshot.OmniglotNShot('omniglot', num_train_classes_used=num_classes_used, **shared_kwargs)

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        # Eval at pos
        def to_sequence(support_x, support_y, query_x, query_y):
            # append the first query example as the last sequence position
            images = np.concatenate([support_x, query_x[:, :1]], 1)
            labels = np.concatenate([support_y, query_y[:, :1]], 1)
            labels = torch.from_numpy(labels).transpose(0, 1)
            # only the last position carries a target; -100 masks all others
            target_y = labels.clone().detach()
            target_y[:-1] = -100
            images = torch.from_numpy(images)
            if self.translations and self.train:
                images = translate_omniglot(images)
            flat_images = images.view(*images.shape[:2], -1).transpose(0, 1)
            return (flat_images, labels), target_y

        return (to_sequence(*self.d.next(mode='train' if self.train else 'test')) for _ in range(self.num_steps))

    @torch.no_grad()
    def validate(self, finetuned_model, eval_pos=-1):
        """Return the accuracy of ``finetuned_model`` at ``eval_pos`` on the test split."""
        finetuned_model.eval()
        device = next(iter(finetuned_model.parameters())).device

        # the test loader is created once and cached on the instance
        if not hasattr(self, 't_dl'):
            self.t_dl = DataLoader(num_steps=self.num_steps, batch_size=self.batch_size, seq_len=self.seq_len,
                                   num_features=self.num_features, num_outputs=self.num_outputs, fuse_x_y=self.fuse_x_y,
                                   train=False)

        predictions, targets = [], []
        for x, y in self.t_dl:
            on_device = tuple(e.to(device) for e in x)
            predictions.append(finetuned_model(on_device, single_eval_pos=eval_pos))
            targets.append(y)

        predictions = torch.cat(predictions, 1)
        targets = torch.cat(targets, 1)

        at_pos = predictions[eval_pos]
        hits = at_pos.argmax(-1) == targets[eval_pos].to(at_pos.device)
        return hits.float().mean().cpu()
|
prior-fitting/priors/prior.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
|
3 |
+
|
4 |
+
class PriorDataLoader(DataLoader):
    # Interface description for prior data loaders; the class itself adds no behaviour.
    pass
    # Expected of implementations:
    # - __init__ accepts num_steps as first argument

    # - the following attributes are set on class or object level:
    #   num_features: int
    #   num_outputs: int
    #   fuse_x_y: bool
    # - Optional: a validate function that accepts a transformer model
|
prior-fitting/priors/pyro.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from utils import default_device
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
|
9 |
+
|
10 |
+
def get_batch(batch_size, seq_len, batch_size_per_gp_sample=None, **config):
    """Sample a batch from models created by ``config['model']``.

    ``config['model']()`` is called once per model; every model is then called
    ``batch_size_per_gp_sample`` times as ``model(seq_len=seq_len)`` and has to
    return an ``(x, y)`` pair.

    :param batch_size: number of sequences in the batch.
    :param seq_len: passed on to every model call.
    :param batch_size_per_gp_sample: sequences drawn per model, defaults to
        ``max(batch_size // 16, 1)``.
    :param config: must contain the key ``'model'``.
    :return: ``(x, y, y)``; samples are stacked along dim 1 and ``x`` is
        standardised along dim 0.
    """
    # At least one sequence per model, so batch sizes below 16 do not end in a modulo by zero.
    batch_size_per_gp_sample = batch_size_per_gp_sample or max(batch_size // 16, 1)
    assert batch_size % batch_size_per_gp_sample == 0, 'Please choose a batch_size divisible by batch_size_per_gp_sample.'
    num_models = batch_size // batch_size_per_gp_sample
    # standard kaiming uniform init currently...

    models = [config['model']() for _ in range(num_models)]

    sample = [model(seq_len=seq_len) for model in models for _ in range(batch_size_per_gp_sample)]

    def normalize_data(data):
        # small constant avoids a division by zero for constant features
        mean = data.mean(0)
        std = data.std(0) + .000001
        return (data - mean) / std

    x, y = zip(*sample)

    y = torch.stack(y, 1).squeeze(-1).detach()
    x = torch.stack(x, 1).detach()

    x = normalize_data(x)

    return x, y, y
|
35 |
+
|
36 |
+
|
37 |
+
# Data loader class built from this module's get_batch; it yields a single output per point.
DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1
|
39 |
+
|
prior-fitting/priors/ridge.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import time
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from sklearn.linear_model import Ridge
|
8 |
+
from .utils import get_batch_to_dataloader
|
9 |
+
|
10 |
+
def get_batch(batch_size, seq_len, num_features, noisy_std = .1):
    """Draw a batch of noisy linear regression problems.

    One weight vector is sampled per batch entry from a normal with mean 0 and
    std 0.1, inputs are uniform in [0, 1), and zero-mean Gaussian noise with
    std ``noisy_std`` is added to the linear response.

    :return: ``(x, y, y_non_noisy)`` where ``x`` has shape
        ``(seq_len, batch_size, num_features)`` and both targets have shape
        ``(seq_len, batch_size)``.
    """
    weights = torch.normal(0., .1, size=(batch_size, num_features))
    inputs = torch.rand(seq_len, batch_size, num_features)
    clean_targets = torch.einsum('bf,tbf->tb', weights, inputs)
    # noisy_std is the standard deviation of the observation noise
    noise = torch.normal(torch.zeros_like(clean_targets), noisy_std)
    noisy_targets = clean_targets + noise
    return inputs, noisy_targets, clean_targets
|
17 |
+
|
18 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
19 |
+
DataLoader.num_outputs = 1
|
20 |
+
|
21 |
+
|
22 |
+
def evaluate(x,y,y_non_noisy, alpha=0.):
    """Measure how well ridge regression predicts each next position.

    For every position ``t >= 1`` and every batch entry, a ``Ridge`` model is
    fitted on the first ``t`` points and asked to predict point ``t``; the
    squared error against the noise-free target is averaged over the batch.

    :return: ``(losses, seconds)`` where ``losses[0]`` is fixed to 0 and
        ``losses[t]`` is the mean loss at position ``t``; ``seconds`` is the
        wall-clock time spent.
    """
    started = time.time()
    mse = nn.MSELoss()
    batch_size = x.shape[1]
    mean_losses = [.0]
    for t in range(1, len(x)):
        total = 0.
        for col in range(batch_size):
            regressor = Ridge(alpha=alpha)
            regressor.fit(x[:t, col], y[:t, col])
            prediction = regressor.predict(x[t, col].unsqueeze(0))
            total += mse(y_non_noisy[t, col].unsqueeze(0), torch.tensor(prediction))
        mean_losses.append(total / batch_size)
    return torch.tensor(mean_losses), time.time() - started
|
35 |
+
|
36 |
+
if __name__ == '__main__':
    # Small manual benchmark: print the per-position ridge loss for several
    # regularisation strengths.
    # num_features is a required argument of get_batch; 10 is an arbitrary
    # demo value.
    for alpha in [.001,.01,.5,1.]:
        print(alpha, evaluate(*get_batch(1000, 10, num_features=10, noisy_std=.01), alpha=alpha))
|
prior-fitting/priors/stroke.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw, ImageFilter
|
2 |
+
import random
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
|
9 |
+
def mnist_prior(num_classes=2, size=28, min_max_strokes=(1,3), min_max_len=(5/28,20/28), min_max_start=(2/28,25/28),
                min_max_width=(1/28,4/28), max_offset=4/28, max_target_offset=2/28):
    """Build random stroke-based image generators, one per class.

    Every class is a fixed set of strokes (start point, length, direction).
    Calling a returned generator draws those strokes with a random line
    width, a random global offset and random end-point jitter, and returns
    the result as a slightly blurred greyscale PIL image.

    All ``min_max_*`` / ``max_*`` arguments are fractions of ``size``.

    :return: list of ``num_classes`` zero-argument callables, each returning
        a ``size`` x ``size`` PIL image.
    """
    def random_length():
        return random.randint(int(size * min_max_len[0]), int(size * min_max_len[1]))

    def random_start():
        return (random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])),
                random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])))

    classes = []
    for _ in range(num_classes):
        num_strokes = random.randint(*min_max_strokes)
        len_strokes = [random_length() for _ in range(num_strokes)]
        stroke_start_points = [random_start() for _ in range(num_strokes)]
        stroke_directions = []
        for stroke in range(num_strokes):
            sp, length = stroke_start_points[stroke], len_strokes[stroke]
            counter = 0
            while True:
                # Every third attempt (including the first) the start point and
                # length are re-drawn so the search cannot get stuck on a stroke
                # that has no valid direction.
                if counter % 3 == 0:
                    length = random_length()
                    sp = random_start()
                    stroke_start_points[stroke], len_strokes[stroke] = sp, length
                radians = random.random() * 2 * math.pi
                x_vel = math.cos(radians) * length
                y_vel = math.sin(radians) * length
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                # Accept the direction only if the end point stays on the canvas.
                if not any(n > size - 1 or n < 0 for n in new_p):
                    break
                counter += 1
            stroke_directions.append(radians)
        classes.append((len_strokes, stroke_start_points, stroke_directions))

    generator_functions = []
    for c in classes:
        # `c=c` binds the current class as a default to avoid late binding.
        def g(c=c):
            len_strokes, stroke_start_points, stroke_directions = c
            i = Image.fromarray(np.zeros((size, size), dtype=np.uint8))
            draw = ImageDraw.Draw(i)
            width = random.randint(int(size * min_max_width[0]), int(size * min_max_width[1]))
            offset = random.randint(int(-size * max_offset), int(size * max_offset)), random.randint(int(- size * max_offset), int(size * max_offset))
            for sp, length, radians in zip(stroke_start_points, len_strokes, stroke_directions):
                sp = (sp[0] + offset[0], sp[1] + offset[1])
                x_vel = math.cos(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                y_vel = math.sin(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                # The shared class description must not be mutated here: appending
                # to stroke_directions on every draw made it grow without bound.
                draw.line([round(x) for x in sp + new_p], fill=128, width=width)
            a_i = np.array(i)
            # Replace the flat stroke colour with per-pixel random intensities.
            a_i[a_i == 128] = np.random.randint(200, 255, size=a_i.shape)[a_i == 128]
            return Image.fromarray(a_i).filter(ImageFilter.GaussianBlur(.2))

        generator_functions.append(g)
    return generator_functions
|
64 |
+
|
65 |
+
|
66 |
+
# g1,g2 = mnist_prior(2)
|
67 |
+
|
68 |
+
# for i in [g1() for _ in range(10)]:
|
69 |
+
# display(i.resize((200,200)))
|
70 |
+
|
71 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
72 |
+
|
73 |
+
|
74 |
+
def normalize(x):
    """Standardise ``x`` using the mean and std over all of its elements."""
    centered = x - x.mean()
    # The epsilon keeps the division finite for constant inputs.
    scale = x.std() + .000001
    return centered / scale
|
76 |
+
|
77 |
+
from os import path, listdir
|
78 |
+
import random
|
79 |
+
|
80 |
+
def get_batch(batch_size, seq_len, num_features=None, noisy_std=None, only_train_for_last_idx=False, normalize_x=False, num_outputs=2, use_saved_from=None, **kwargs): # num_features = 28*28=784
    """Generate a batch of stroke-image classification sequences.

    For every batch entry a fresh set of ``num_outputs`` class generators is
    created with ``mnist_prior`` and ``seq_len`` images are drawn from them.

    :param num_features: number of pixels per image; must be a perfect square.
    :param noisy_std: accepted but not used by this function.
    :param only_train_for_last_idx: if set, the first ``seq_len - 1`` items
        contain every class equally often (shuffled) and only the last item
        has a real target; all other targets are -100.
    :param normalize_x: standardise every image with ``normalize``.
    :param use_saved_from: if given, a random pre-computed batch is loaded
        from a sub-directory of this path instead of being generated.
    :param kwargs: forwarded to ``mnist_prior``.
    :return: ``(x, y, target_y)`` with ``x`` viewed as
        ``(seq_len, batch_size, -1)``.
    """
    if use_saved_from is not None:
        directory = path.join(use_saved_from, f'len_{seq_len}_out_{num_outputs}_features_{num_features}_bs_{batch_size}')
        filename = random.choice(listdir(directory))
        # NOTE(review): torch.load unpickles the file; only use trusted directories.
        return torch.load(path.join(directory,filename))

    size = math.isqrt(num_features)
    assert size * size == num_features, 'num_features needs to be the square of an integer.'
    if only_train_for_last_idx:
        assert (seq_len-1) % num_outputs == 0

    def maybe_normalize(image):
        return normalize(image) if normalize_x else image

    images_per_entry = []
    labels_per_entry = []
    targets_per_entry = []
    for _ in range(batch_size):
        gs = mnist_prior(num_outputs, size, **kwargs)
        num_classes = len(gs)
        if only_train_for_last_idx:
            repeats = (seq_len - 1) // num_outputs
            class_ids = [c for c in range(num_classes) for _ in range(repeats)]
            random.shuffle(class_ids)
            class_ids.append(random.randint(0, num_classes - 1))
            # -100 marks positions without a target; only the last one has it.
            target = [-100] * len(class_ids)
            target[-1] = class_ids[-1]
        else:
            class_ids = [random.randint(0, num_classes - 1) for _ in range(seq_len)]
            target = class_ids
        sequence = torch.cat([maybe_normalize(ToTensor()(gs[c]())) for c in class_ids], 0)
        images_per_entry.append(sequence)
        labels_per_entry.append(torch.tensor(class_ids))
        targets_per_entry.append(torch.tensor(target))
    x = torch.stack(images_per_entry, 1).view(seq_len, batch_size, -1)
    y = torch.stack(labels_per_entry, 1)
    target_y = torch.stack(targets_per_entry, 1)
    return x,y,target_y
|
115 |
+
|
116 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
117 |
+
DataLoader.num_outputs = 2
|
118 |
+
|
119 |
+
if __name__ == '__main__':
    # Manual smoke test: build two class generators and render one batch.
    g1, g2 = mnist_prior(2, size=3)

    # for i in range(10):
    # print(PILToTensor()(g1()))
    # display(ToPILImage()(PILToTensor()(g1())).resize((200,200)))
    # display(g2().resize((200,200)))

    size = 10
    # get_batch returns three values (x, y, target_y).
    x, y, target_y = get_batch(1, 10, num_features=size * size)

    x_ = x[..., :-1].squeeze(1)
    last_y = x[..., -1].squeeze(1)
    y = y.squeeze(1)

    # print(y)

    for i, y_, last_y_, x__ in zip(x_, y, last_y, x.squeeze(1)):
        # print(y_)
        # print(i.shape)
        # print(x__)
        # Only the full row has size * size values; `i` lacks the last one
        # and cannot be viewed as a square image.
        img = ToPILImage()(x__.view(size, size))
        # display(img.resize((200,200)))

    print(y, last_y)
|
prior-fitting/priors/utils.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from utils import set_locals_in_self
|
6 |
+
from itertools import repeat
|
7 |
+
from .prior import PriorDataLoader
|
8 |
+
from torch import nn
|
9 |
+
import numpy as np
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import matplotlib.gridspec as gridspec
|
12 |
+
import scipy.stats as stats
|
13 |
+
|
14 |
+
def get_batch_to_dataloader(get_batch_method_):
    """Wrap a ``get_batch`` function into a ``PriorDataLoader`` subclass.

    The returned class yields ``num_steps`` batches per iteration, each
    produced by calling ``get_batch_method_`` with the keyword arguments
    given to the constructor.

    :param get_batch_method_: callable returning ``(x, y, target_y)``.
    :return: the generated data loader class (not an instance).
    """
    class DL(PriorDataLoader):
        get_batch_method = get_batch_method_

        # Caution, you might need to set self.num_features manually if it is not part of the args.
        def __init__(self, num_steps, fuse_x_y=False, **get_batch_kwargs):
            # presumably stores num_steps, fuse_x_y and get_batch_kwargs as
            # attributes on self (they are read below) - verify against
            # utils.set_locals_in_self. Must stay the first statement so that
            # locals() only contains the constructor arguments.
            set_locals_in_self(locals())
            # The value after the `or` is the class attribute set before instantiation.
            self.num_features = get_batch_kwargs.get('num_features') or self.num_features
            self.num_outputs = get_batch_kwargs.get('num_outputs') or self.num_outputs
            print('DataLoader.__dict__', self.__dict__)

        @staticmethod
        def gbm(*args, fuse_x_y=True, **kwargs):
            # Produce one batch. With fuse_x_y the targets, shifted by one
            # position along dimension 0 (first entry zero-filled), are
            # appended to x as an extra last feature; otherwise x and y are
            # returned as a pair.
            x, y, target_y = get_batch_method_(*args, **kwargs)
            if fuse_x_y:
                return torch.cat([x, torch.cat([torch.zeros_like(y[:1]), y[:-1]], 0).unsqueeze(-1).float()],
                                 -1), target_y
            else:
                return (x, y), target_y

        def __len__(self):
            return self.num_steps

        def __iter__(self):
            # Batches are generated lazily, one per step.
            return iter(self.gbm(**self.get_batch_kwargs, fuse_x_y=self.fuse_x_y) for _ in range(self.num_steps))


    return DL
|
43 |
+
|
44 |
+
|
45 |
+
def plot_features(data, targets):
    """Draw a grid of pairwise feature scatter plots coloured by ``targets``.

    Tensors are detached and converted to numpy first. The grid has one row
    and one column per feature (``data.shape[1]``).
    """
    if torch.is_tensor(data):
        data = data.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()
    num_features = data.shape[1]
    figure = plt.figure(constrained_layout=True, figsize=(12, 12))
    grid = gridspec.GridSpec(ncols=num_features, nrows=num_features, figure=figure)
    for row in range(0, num_features):
        for col in range(0, num_features):
            axis = figure.add_subplot(grid[row, col])
            axis.scatter(data[:, row], data[:, col], c=targets[:])
|
56 |
+
|
57 |
+
|
58 |
+
def plot_prior(prior):
    """Show a 50-bin density histogram of 10000 draws from ``prior``.

    ``prior`` is a zero-argument sampler. The smallest drawn value is printed
    before the plot is shown.
    """
    draws = np.array([prior() for _ in range(0, 10000)])
    plt.hist(draws, 50, density=True)
    print(draws.min())
    plt.show()
|
63 |
+
|
64 |
+
# Sampler factories: each takes distribution parameters and returns a
# zero-argument function that draws one value per call.

def trunc_norm_sampler_f(mu, sigma):
    """Normal(mu, sigma) truncated to the interval [0, 1]."""
    def sample():
        lower = (0 - mu) / sigma
        upper = (1 - mu) / sigma
        return stats.truncnorm(lower, upper, loc=mu, scale=sigma).rvs(1)[0]
    return sample


def beta_sampler_f(a, b):
    """Beta(a, b)."""
    def sample():
        return np.random.beta(a, b)
    return sample


def gamma_sampler_f(a, b):
    """Gamma with shape ``a`` and scale ``b``."""
    def sample():
        return np.random.gamma(a, b)
    return sample


def uniform_sampler_f(a, b):
    """Uniform float via ``np.random.uniform(a, b)``."""
    def sample():
        return np.random.uniform(a, b)
    return sample


def uniform_int_sampler_f(a, b):
    """Uniform integer via ``np.random.randint(a, b)`` (``b`` exclusive)."""
    def sample():
        return np.random.randint(a, b)
    return sample


def zipf_sampler_f(a, b, c):
    """Zipf(a) draw shifted by ``b`` and capped at ``c``."""
    def sample():
        return min(b + np.random.zipf(a), c)
    return sample


def scaled_beta_sampler_f(a, b, scale, minimum):
    """Beta(a, b) draw stretched and rounded onto integers from ``minimum``."""
    def sample():
        stretched = beta_sampler_f(a, b)() * (scale - minimum + 1) - 0.5
        return minimum + round(stretched)
    return sample
|
71 |
+
|
72 |
+
|
73 |
+
def normalize_data(data):
    """Standardise ``data`` along dimension 0 (zero mean, unit std)."""
    centered = data - data.mean(0)
    # The epsilon keeps the division finite for constant columns.
    return centered / (data.std(0) + .000001)
|
79 |
+
|
80 |
+
|
81 |
+
def normalize_by_used_features_f(x, num_features_used, num_features):
    """Divide ``x`` by the fraction of features that are in use."""
    used_fraction = num_features_used / num_features
    return x / used_fraction
|
83 |
+
|
84 |
+
|
85 |
+
class Binarize(nn.Module):
    """Module mapping a tensor to 0/1 floats by thresholding at its median."""

    def __init__(self, p=0.5):
        super().__init__()
        # Stored for callers; forward() does not read it.
        self.p = p

    def forward(self, x):
        # Elements strictly above the median over all of x become 1.0.
        threshold = torch.median(x)
        return torch.gt(x, threshold).float()
|
92 |
+
|
93 |
+
|
94 |
+
def order_by_y(x, y):
    """Reorder ``x`` and ``y`` along dimension 0 by the sorted values of ``y``.

    The sort direction is chosen at random. The sorted order is then split in
    two halves whose elements are interleaved (first of half one, first of
    half two, second of half one, ...). Sorting uses ``y[:, 0, 0]``.
    """
    ascending = random.randint(0, 1)
    sort_key = y if ascending else -y
    sorted_idx = torch.argsort(sort_key, dim=0)[:, 0, 0]
    # Interleave the lower and upper half of the sorted indices.
    interleaved = sorted_idx.reshape(2, -1).t().reshape(-1)
    return x[interleaved], y[interleaved]
|
101 |
+
|
102 |
+
|
prior-fitting/requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python >= 3.9 is recommended.
|
2 |
+
gpytorch==1.5.0
|
3 |
+
pyro-ppl==1.7.0
|
4 |
+
torch==1.9.0
|
5 |
+
scikit-learn==0.24.2
|
6 |
+
pyyaml==5.4.1
|
7 |
+
blitz-bayesian-pytorch==0.2.7
|
8 |
+
seaborn==0.11.2
|
9 |
+
xgboost==1.4.0
|
10 |
+
tqdm==4.62.1
|
11 |
+
numpy==1.21.2
|
12 |
+
openml==0.12.2
|
13 |
+
catboost==0.26.1
|
prior-fitting/tabular.py
ADDED
@@ -0,0 +1,725 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from catboost import CatBoostClassifier, Pool
|
2 |
+
from sklearn.model_selection import GridSearchCV
|
3 |
+
from sklearn.model_selection import KFold
|
4 |
+
from sklearn.model_selection import ParameterGrid
|
5 |
+
|
6 |
+
import pyro
|
7 |
+
import pyro.distributions as dist
|
8 |
+
from pyro.nn import PyroModule, PyroSample
|
9 |
+
from pyro.infer.autoguide import AutoDiagonalNormal
|
10 |
+
from pyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS
|
11 |
+
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
|
12 |
+
from sklearn.metrics import accuracy_score, roc_auc_score
|
13 |
+
import argparse
|
14 |
+
import itertools
|
15 |
+
|
16 |
+
from train import train, get_weighted_single_eval_pos_sampler, Losses
|
17 |
+
import priors
|
18 |
+
import encoders
|
19 |
+
from sklearn import preprocessing
|
20 |
+
|
21 |
+
from sklearn.base import BaseEstimator, ClassifierMixin
|
22 |
+
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from datasets import *
|
26 |
+
import xgboost as xgb
|
27 |
+
import matplotlib.pyplot as plt
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
import torch
|
31 |
+
from sklearn import neighbors, datasets
|
32 |
+
from sklearn.gaussian_process import GaussianProcessClassifier
|
33 |
+
from sklearn.gaussian_process.kernels import RBF
|
34 |
+
|
35 |
+
from priors.utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, scaled_beta_sampler_f, uniform_int_sampler_f
|
36 |
+
|
37 |
+
from tqdm import tqdm
|
38 |
+
import time
|
39 |
+
import random
|
40 |
+
|
41 |
+
import os
|
42 |
+
|
43 |
+
CV = 5
|
44 |
+
param_grid = {}
|
45 |
+
metric_used = roc_auc_score
|
46 |
+
|
47 |
+
def get_uniform_single_eval_pos_sampler(max_len):
    """
    Build a sampler that picks every evaluation position with equal weight.

    :param max_len: positions are drawn from ``range(max_len)``.
    :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
    """
    def sample_eval_pos():
        positions = range(max_len)
        return random.choices(positions)[0]

    return sample_eval_pos
|
53 |
+
|
54 |
+
|
55 |
+
def get_mlp_prior_hyperparameters(config):
    """Assemble the positional hyperparameter tuple for the MLP prior.

    Sampler entries of ``config`` are dicts; their first value is used.
    Entries that only apply to causal priors are ``None`` unless
    ``config['prior_is_causal']`` is truthy.

    :param config: flat configuration dict with ``prior_*`` keys.
    :return: tuple of 17 hyperparameters in the order the prior expects.
    """
    def first_value(key):
        return list(config[key].values())[0]

    def if_causal(fetch):
        # Causal-only keys are read lazily so non-causal configs may omit them.
        return fetch() if config['prior_is_causal'] else None

    sigma_sampler = gamma_sampler_f(config["prior_sigma_gamma_k"], config["prior_sigma_gamma_theta"])
    noise_std_sampler = gamma_sampler_f(config["prior_noise_std_gamma_k"], config["prior_noise_std_gamma_theta"])

    return (
        first_value("prior_nlayers_sampler"),
        first_value("prior_emsize_sampler"),
        config["prior_activations"],
        sigma_sampler,
        noise_std_sampler,
        first_value("prior_dropout_sampler"),
        True,
        first_value("prior_num_features_used_sampler"),
        if_causal(lambda: first_value("prior_causes_sampler")),
        config["prior_is_causal"],
        if_causal(lambda: config["prior_pre_sample_causes"]),
        if_causal(lambda: config["prior_pre_sample_weights"]),
        if_causal(lambda: config["prior_y_is_effect"]),
        config["prior_order_y"],
        config["prior_normalize_by_used_features"],
        if_causal(lambda: first_value("prior_categorical_feats")),
        0.0,
    )
|
79 |
+
|
80 |
+
|
81 |
+
def get_gp_mix_prior_hyperparameters(config):
    """Translate ``prior_*`` config entries into GP-mix prior hyperparameters.

    :param config: flat configuration dict.
    :return: dict of keyword hyperparameters for the GP-mix prior.
    """
    return {'lengthscale_concentration': config["prior_lengthscale_concentration"],
            'nu': config["prior_nu"],
            'outputscale_concentration': config["prior_outputscale_concentration"],
            # NOTE(review): 'categorical_data' is filled from the y_minmax_norm
            # setting; this looks like a copy-paste slip, but no dedicated
            # config key is visible here, so the value is left unchanged.
            'categorical_data': config["prior_y_minmax_norm"],
            # Previously read from prior_lengthscale_concentration by mistake.
            'y_minmax_norm': config["prior_y_minmax_norm"],
            'noise_concentration': config["prior_noise_concentration"],
            'noise_rate': config["prior_noise_rate"]}
|
89 |
+
|
90 |
+
|
91 |
+
def get_gp_prior_hyperparameters(config):
    """Assemble the positional hyperparameter tuple for the GP prior.

    Output scale and length scale are returned as zero-argument callables
    that read the current value from ``config`` each time they are called.

    :param config: flat configuration dict with ``prior_*`` keys.
    :return: 7-tuple ``(noise, outputscale_fn, lengthscale_fn, True,
        num_features_used_sampler, normalize_by_used_features, order_y)``.
    """
    def outputscale():
        return config['prior_outputscale']

    def lengthscale():
        # lengthscale; presumably higher means more separation - TODO confirm
        return config['prior_lengthscale']

    noise = config['prior_noise']
    features_used_sampler = list(config['prior_num_features_used_sampler'].values())[0]
    normalize_by_used = config['prior_normalize_by_used_features']
    order_y = config['prior_order_y']

    return (noise, outputscale, lengthscale, True, features_used_sampler, normalize_by_used, order_y)
|
101 |
+
|
102 |
+
|
103 |
+
def get_meta_gp_prior_hyperparameters(config):
    """Assemble GP prior hyperparameters with sampled output/length scales.

    Both scales are drawn by truncated-normal samplers whose standard
    deviation is the configured mean multiplied by the matching ``*_std_f``
    factor.

    :param config: flat configuration dict with ``prior_*`` keys.
    :return: 7-tuple ``(noise, outputscale_sampler, lengthscale_sampler,
        True, num_features_used_sampler, normalize_by_used_features,
        order_y)``.
    """
    lengthscale_mean = config["prior_lengthscale_mean"]
    lengthscale_sampler = trunc_norm_sampler_f(lengthscale_mean, config["prior_lengthscale_mean"] * config["prior_lengthscale_std_f"])
    outputscale_mean = config["prior_outputscale_mean"]
    outputscale_sampler = trunc_norm_sampler_f(outputscale_mean, config["prior_outputscale_mean"] * config["prior_outputscale_std_f"])

    noise = config['prior_noise']
    features_used_sampler = list(config['prior_num_features_used_sampler'].values())[0]

    return (
        noise,
        outputscale_sampler,
        lengthscale_sampler,  # lengthscale; presumably higher means more separation - TODO confirm
        True,
        features_used_sampler,
        config['prior_normalize_by_used_features'],
        config['prior_order_y'],
    )
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
def get_model(config, device, eval_positions, should_train=True, verbose=False):
    """Build (and optionally train) a transformer on the configured prior.

    ``config['prior_type']`` selects the prior data loader and how its
    hyperparameters are assembled; the result of ``train(...)`` is returned
    unchanged.

    :param config: flat configuration dict (model, training and ``prior_*``
        settings).
    :param device: passed to ``train`` as ``gpu_device``.
    :param eval_positions: evaluation positions; the single-eval-position
        sampler draws from ``range(max(eval_positions) + 1)``.
    :param should_train: if false, ``train`` is called with 0 epochs.
    :param verbose: forwarded to ``train``.
    :raises ValueError: if ``config['prior_type']`` is not one of ``'mlp'``,
        ``'gp'``, ``'custom_gp_mix'`` or ``'gp_mix'``.
    """
    extra_kwargs = {}
    prior_type = config['prior_type']
    if prior_type == 'mlp':
        prior_hyperparameters = get_mlp_prior_hyperparameters(config)
        model_proto = priors.mlp.DataLoader
        extra_kwargs['batch_size_per_gp_sample'] = 8
    elif prior_type == 'gp':
        prior_hyperparameters = get_gp_prior_hyperparameters(config)
        model_proto = priors.fast_gp.DataLoader
    elif prior_type == 'custom_gp_mix':
        prior_hyperparameters = get_meta_gp_prior_hyperparameters(config)
        model_proto = priors.fast_gp.DataLoader
    elif prior_type == 'gp_mix':
        prior_hyperparameters = get_gp_mix_prior_hyperparameters(config)
        model_proto = priors.fast_gp_mix.DataLoader
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(
            f"Unknown prior_type {prior_type!r}; expected one of "
            "'mlp', 'gp', 'custom_gp_mix', 'gp_mix'.")

    epochs = 0 if not should_train else config['epochs']
    model = train(model_proto
                  , Losses.bce
                  , encoders.Linear
                  , emsize=config['emsize']
                  , nhead=config['nhead']
                  , y_encoder_generator=encoders.Linear
                  , pos_encoder_generator=None
                  , batch_size=config['batch_size']
                  , nlayers=config['nlayers']
                  , nhid=config['emsize'] * config['nhid_factor']
                  , epochs=epochs
                  , warmup_epochs=epochs // 4
                  , bptt=config['bptt']
                  , gpu_device=device
                  , dropout=config['dropout']
                  , steps_per_epoch=100
                  , single_eval_pos_gen=get_uniform_single_eval_pos_sampler(max(eval_positions) + 1)
                  , extra_prior_kwargs_dict={
                      'num_features': config['num_features']
                      # , 'canonical_args': None
                      , 'fuse_x_y': False
                      , 'hyperparameters': prior_hyperparameters
                      , **extra_kwargs
                  }
                  , lr=config['lr']
                  , verbose=verbose)

    return model
|
164 |
+
|
165 |
+
|
166 |
+
## General eval
|
167 |
+
|
168 |
+
def evaluate(datasets, model, method, bptt, eval_position_range, device, max_features=0, plot=False, extend_features=False, save=True, rescale_features=False, overwrite=False,
             max_samples=40, path_interfix=''):
    """Evaluate ``model`` on several datasets, caching per-dataset results.

    For every dataset a cached ``.npy`` result is loaded if it exists and
    ``overwrite`` is false; otherwise ``evaluate_dataset`` is run and, when
    ``save`` is set, the result is written back to that path.

    :param datasets: iterable of ``[name, X, y, categorical_feats]`` entries.
    :param method: label used in the result file name.
    :param eval_position_range: evaluation positions; result keys are built
        per position.
    :param extend_features: zero-pad ``X`` to ``max_features`` columns.
    :param rescale_features: together with ``extend_features``, pass the
        factor ``X.shape[1] / max_features`` on for rescaling.
    :param path_interfix: sub-directory inserted into the result path.
    :return: dict with per-dataset entries plus ``mean_metric_at_<pos>`` for
        every position and an overall ``mean_metric``.
    """
    # eval_position_range: last entry is the one used to calculate the metric; up to index is used for training
    result = {'metric': 'auc'}

    # NOTE(review): metric_sum is accumulated below but never read or
    # returned in this function.
    metric_sum = 0
    for [name, X, y, categorical_feats] in datasets:
        result_ds = {}
        # NOTE(review): hard-coded absolute path; the directory must already exist.
        path = f'/home/hollmann/prior-fitting/results/tabular/{path_interfix}/results_{method}_{name}.npy'
        if (os.path.isfile(path)) and not overwrite:
            with open(path, 'rb') as f:
                # NOTE(review): allow_pickle=True unpickles the file; only
                # load result files from trusted sources.
                result_ds = np.load(f, allow_pickle=True).tolist()
                # Older result files stored the time under a shared 'time'
                # key; move it to the per-dataset key.
                if 'time' in result_ds:
                    result_ds[name+'_time'] = result_ds['time']
                    del result_ds['time']
                result.update(result_ds)
                mean_metric = float(result[name + '_mean_metric_at_' + str(eval_position_range[-1])])
                metric_sum += mean_metric
                print(f'Loaded saved result for {name} (mean metric {mean_metric})')
                continue

        print('Evaluating ' + str(name))
        rescale_features_factor = X.shape[1] / max_features if rescale_features and extend_features else 1.0
        if extend_features:
            # Zero-pad the feature dimension up to max_features.
            X = torch.cat([X, torch.zeros((X.shape[0], max_features - X.shape[1]))], -1)

        start_time = time.time()
        ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,
                                     rescale_features=rescale_features_factor, max_samples=max_samples)
        elapsed = time.time() - start_time

        # ds_result holds one (metric, outputs, ys) triple per eval position.
        for i, r in enumerate(ds_result):
            metric, outputs, ys = r
            if save:
                result_ds[name + '_per_ds_metric_at_' + str(eval_position_range[i])] = metric
                result_ds[name + '_outputs_at_' + str(eval_position_range[i])] = outputs
                result_ds[name + '_ys_at_' + str(eval_position_range[i])] = ys

            # Metric computed over all flattened targets and outputs at this position.
            result_ds[name + '_mean_metric_at_' + str(eval_position_range[i])] = metric_used(ys.detach().cpu().flatten(), outputs.flatten())
        result_ds[name + '_time'] = elapsed

        if save:
            with open(path, 'wb') as f:
                np.save(f, result_ds)

        result.update(result_ds)
        # `metric` is the one left over from the last eval position above.
        metric_sum += float(metric[-1].mean())

    # Average every per-position metric over all datasets.
    for pos in eval_position_range:
        result[f'mean_metric_at_{pos}'] = np.array([result[d[0] + '_mean_metric_at_' + str(pos)] for d in datasets]).mean()

    result['mean_metric'] = np.array([result['mean_metric_at_' + str(pos)] for pos in eval_position_range]).mean()

    return result
|
222 |
+
|
223 |
+
|
224 |
+
def evaluate_dataset(X, y, categorical_feats, model, bptt, eval_position_range, plot=False, rescale_features=1.0,
                     max_samples=40):
    """Run ``evaluate_position`` for every position in ``eval_position_range``.

    :param plot: if set, plot the mean metric per evaluation position.
    :return: list with one ``evaluate_position`` result (a
        ``(metric, outputs, ys)`` triple) per evaluation position.
    """
    result = []
    for eval_position in eval_position_range:
        r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,
                              max_samples=max_samples)
        result.append(r)
        print('\t Eval position ' + str(eval_position) + ' done..')

    if plot:
        # Each result is a (metric, outputs, ys) triple; only the metric array
        # (index 0) has a mean. Calling .mean() on the triple itself failed.
        plt.plot(np.array(list(eval_position_range)), np.array([r[0].mean() for r in result]))

    return result
|
237 |
+
|
238 |
+
|
239 |
+
def evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=1.0, max_samples=40):
    """Evaluate ``model`` on sliding windows of a dataset at one position.

    Windows of length ``bptt`` are cut from ``X``/``y`` at every offset and a
    fixed-seed random subset of at most ``max_samples`` windows is kept. In
    every window the first ``eval_position`` rows are the training part and
    the remaining rows are predicted.

    :param model: either an ``nn.Module`` that is called directly, or a
        baseline metric function that is handed to ``batch_pred``.
    :return: ``(metric, outputs, ys)`` where ``ys`` are the targets from
        ``eval_position`` onwards.
    """
    # right now permutation style is to test performance on one before the last element
    # eval_position = bptt - eval_positions

    # TODO: Make sure that no bias exists
    # assert(eval_position % 2 == 0)

    eval_xs = []
    eval_ys = []
    # NOTE(review): if len(X) <= bptt no window is produced and the
    # torch.stack calls below fail on an empty list.
    num_evals = len(X) - bptt  # len(X)-bptt-(bptt-eval_position)+1

    # Generate permutations of evaluation data
    # with torch.random.fork_rng():
    #     torch.random.manual_seed(13)
    #     ps = [torch.randperm(2*(bptt - eval_position)) for _ in range(num_evals)]

    for i in range(num_evals):
        # Select chunk of data with extra evaluation positions that can be discarded
        # x_ = X[i:i+bptt+(bptt-eval_position)].clone()
        # y_ = y[i:i+bptt+(bptt-eval_position)].clone()

        # # Permutate evaluation positions
        # perm_range = slice(eval_position,bptt+(bptt - eval_position))
        # x_[perm_range] = x_[perm_range][ps[i]]
        # y_[perm_range] = y_[perm_range][ps[i]]

        # # Discard extra evaluation positions
        # x_ = x_[0:bptt]
        # y_ = y_[0:bptt]

        # One window of bptt consecutive rows starting at offset i.
        x_ = X[i:i + bptt].clone()
        y_ = y[i:i + bptt].clone()

        eval_xs.append(x_)
        eval_ys.append(y_)

    # eval data will be ordered in training range and
    # will be a random subset of 2*eval_position data points in eval positions
    # Windows are stacked along dimension 1.
    eval_xs = torch.stack(eval_xs, 1)
    eval_ys = torch.stack(eval_ys, 1)

    # Limit to N samples per dataset
    # fork_rng keeps the fixed seed from leaking into the global RNG state.
    with torch.random.fork_rng():
        torch.random.manual_seed(13)
        sel = torch.randperm(eval_xs.shape[1])
        eval_xs = eval_xs[:, sel[0:max_samples], :]
        eval_ys = eval_ys[:, sel[0:max_samples]]
    #
    # if quantile_transform:
    #     for sample in range(0, eval_xs.shape[1]):
    #         quantile_transformer = preprocessing.QuantileTransformer(random_state=0, n_quantiles=eval_xs.shape[0])
    #         quantile_transformer.fit(eval_xs[:eval_position, sample].cpu())
    #         eval_xs[:, sample] = torch.tensor(quantile_transformer.transform(eval_xs[:, sample].cpu()))

    if isinstance(model, nn.Module):
        model.eval()
        # One output row per predicted position, one column per window.
        outputs = np.zeros(shape=(len(list(range(eval_position, eval_xs.shape[0]))), eval_xs.shape[1]))
        for i, pos in enumerate(range(eval_position, eval_xs.shape[0])):
            # Training rows followed by the single row to be predicted.
            eval_x = torch.cat([eval_xs[:eval_position], eval_xs[pos].unsqueeze(0)])
            eval_y = eval_ys[:eval_position]

            # Center data so that it matches priors.
            # NOTE(review): the statistics are taken over eval_x, which
            # includes the row being predicted, not only training positions.
            mean = eval_x.mean(0)
            std = eval_x.std(0) + .000001
            eval_x = (eval_x - mean) / std
            eval_x = eval_x / rescale_features

            output = torch.sigmoid(model((eval_x, eval_y.float()), single_eval_pos=eval_position)).squeeze(-1)
            outputs[i, :] = output.detach().cpu().numpy()

        # One metric value per window, over all predicted positions.
        metric_per_t = np.array([metric_used(eval_ys[eval_position:, i].cpu(), outputs[:, i]) for i in range(eval_xs.shape[1])])
        return metric_per_t, outputs, eval_ys[eval_position:]
    else:
        # Non-module models are metric functions evaluated window by window.
        metric_eval_pos, outputs = batch_pred(model, eval_xs, eval_ys, categorical_feats, start=eval_position)

        return metric_eval_pos, outputs, eval_ys[eval_position:]
|
315 |
+
|
316 |
+
|
317 |
+
def batch_pred(metric_function, eval_xs, eval_ys, categorical_feats, start=2):
    """Apply a baseline metric function to every sample of a batch.

    The batch is split along dimension 1. Each sample is standardised with
    the mean and std of its first ``start`` rows; those rows form the
    training split and the remaining rows the test split.

    :param metric_function: called as ``f(train_x, train_y, test_x, test_y,
        categorical_feats)`` and must return ``(metric, output)``.
    :return: ``(metrics, outputs)`` as numpy arrays, ``outputs`` transposed.
    """
    metrics, outputs = [], []
    samples = list(zip(eval_xs.transpose(0, 1), eval_ys.transpose(0, 1)))
    for sample_x, sample_y in tqdm(samples):  # sample_x is one sample, i.e. bptt x num_features
        train_rows = sample_x[:start]
        center = train_rows.mean(0)
        scale = train_rows.std(0) + .000001
        sample_x = (sample_x - center) / scale

        metric, output = metric_function(sample_x[:start], sample_y[:start], sample_x[start:], sample_y[start:], categorical_feats)
        metrics.append(metric)
        outputs.append(output)
    return np.array(metrics), np.array(outputs).T
|
332 |
+
|
333 |
+
## Ridge
|
334 |
+
|
335 |
+
|
336 |
+
from sklearn.linear_model import RidgeClassifier
|
337 |
+
# param_grid['ridge'] = {'alpha': [0, 0.1, .5, 1.0, 2.0], 'fit_intercept': [True, False]} # 'normalize': [False],
|
338 |
+
def ridge_metric(x, y, test_x, test_y, cat_features):
    """Grid-search a ``RidgeClassifier`` and score it on the test split.

    Args:
        x, y: training inputs and labels (torch tensors; moved to CPU here).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: accepted for signature parity with the other baselines;
            not used by this function.

    Returns:
        ``(metric, pred)`` where ``pred`` is the output of
        ``decision_function`` on ``test_x`` and ``metric`` is
        ``metric_used(test_y, pred)``.
    """
    import warnings

    x, y, test_x, test_y = x.cpu(), y.cpu(), test_x.cpu(), test_y.cpu()

    # Hyper-parameters are read from the module-level `param_grid['ridge']`.
    clf = GridSearchCV(RidgeClassifier(), param_grid['ridge'], cv=min(CV, x.shape[0]//2))

    # Silence warnings only for the duration of this call. The previous
    # implementation replaced `warnings.warn` globally, which disabled every
    # warning in the process for good.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        clf.fit(x, y.long())
        pred = clf.decision_function(test_x)
        metric = metric_used(test_y.numpy(), pred)

    return metric, pred
|
359 |
+
|
360 |
+
|
361 |
+
from sklearn.linear_model import LogisticRegression
|
362 |
+
param_grid['logistic'] = {'solver': ['saga'], 'penalty': ['l1', 'l2', 'none'], 'tol': [1e-2, 1e-4, 1e-10], 'max_iter': [500], 'fit_intercept': [True, False], 'C': [1e-5, 0.001, 0.01, 0.1, 1.0, 2.0]}  # 'normalize': [False],


def logistic_metric(x, y, test_x, test_y, cat_features):
    """Grid-search a ``LogisticRegression`` and score it on the test split.

    Args:
        x, y: training inputs and labels (torch tensors; moved to CPU here).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: accepted for signature parity with the other baselines;
            not used by this function.

    Returns:
        ``(metric, pred)`` where ``pred`` is column 1 of ``predict_proba`` on
        ``test_x`` and ``metric`` is ``metric_used(test_y, pred)``.
    """
    import warnings

    x, y, test_x, test_y = x.cpu(), y.cpu(), test_x.cpu(), test_y.cpu()

    clf = GridSearchCV(LogisticRegression(), param_grid['logistic'], cv=min(CV, x.shape[0]//2))

    # Scoped suppression instead of overwriting `warnings.warn` process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        clf.fit(x, y.long())
        pred = clf.predict_proba(test_x)[:, 1]
        metric = metric_used(test_y.numpy(), pred)

    return metric, pred
|
384 |
+
|
385 |
+
|
386 |
+
## KNN
|
387 |
+
param_grid['knn'] = {'n_neighbors (max number of samples)': np.arange(1, 6)}


def knn_metric(x, y, test_x, test_y, cat_features):
    """Grid-search a k-nearest-neighbours classifier and score it.

    The searched ``n_neighbors`` values are built locally as
    ``1 .. min(6, len(y) - 1) - 1``; the module-level ``param_grid['knn']``
    entry is not read by this function.

    Args:
        x, y: training inputs and labels (torch tensors; moved to CPU here).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: unused; kept for signature parity with other baselines.

    Returns:
        ``(metric, pred)`` where ``pred`` is column 1 of ``predict_proba`` on
        ``test_x`` and ``metric`` is ``metric_used(test_y, pred)``.
    """
    x, y, test_x, test_y = (t.cpu() for t in (x, y, test_x, test_y))

    neighbour_counts = np.arange(1, min(6, len(y) - 1))
    search = GridSearchCV(
        neighbors.KNeighborsClassifier(),
        {'n_neighbors': neighbour_counts},
        cv=min(CV, x.shape[0]//2),
    )
    search.fit(x, y.long())

    pred = search.predict_proba(test_x)[:, 1]
    metric = metric_used(test_y.cpu().numpy(), pred)
    return metric, pred
|
407 |
+
|
408 |
+
|
409 |
+
## Bayesian NN
|
410 |
+
class BayesianModel(PyroModule):
    """Pyro model of a small Bayesian network with a two-class output.

    Two ``nn.Linear`` layers are chained with ``torch.nn.Sequential``; all
    weights and biases get a standard-normal prior via ``PyroSample``.
    NOTE(review): there is no non-linearity between the layers and
    ``model_spec['nlayers']`` is not read here. Confirm this is intended.

    Args:
        model_spec: dict providing ``'num_features'`` and ``'embed'``.
        device: device on which priors and the module are placed.
    """

    def __init__(self, model_spec, device='cuda'):
        super().__init__()

        self.device = device
        self.num_features = model_spec['num_features']
        embed = model_spec['embed']

        mu, sigma = torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)

        self.fc1 = PyroModule[nn.Linear](self.num_features, embed)
        self.fc1.weight = PyroSample(
            dist.Normal(mu, sigma).expand([embed, self.num_features]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(mu, sigma).expand([embed]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](embed, 2)
        self.fc2.weight = PyroSample(dist.Normal(mu, sigma).expand([2, embed]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(mu, sigma).expand([2]).to_event(1))

        self.model = torch.nn.Sequential(self.fc1, self.fc2)

        self.to(self.device)

    def forward(self, x=None, y=None, seq_len=1):
        """Sample (or condition on) class labels for ``x``.

        Args:
            x: inputs; when ``None``, ``seq_len`` rows are sampled from a
                standard normal inside the ``"x_plate"`` plate.
            y: observed labels passed as ``obs`` to the ``'obs'`` sample site,
                or ``None`` to sample labels.
            seq_len: number of rows to sample when ``x`` is ``None``.

        Returns:
            ``(x, obs)`` with ``obs`` converted to float.
        """
        if x is None:
            with pyro.plate("x_plate", seq_len):
                d_ = dist.Normal(torch.tensor([0.0]).to(self.device), torch.tensor([1.0]).to(self.device)).expand(
                    [self.num_features]).to_event(1)
                x = pyro.sample("x", d_)

        out = self.model(x)
        logits = out.squeeze()
        if logits.dim() == 1 and out.dim() > 1:
            # For a single-row batch squeeze() also drops the batch axis, which
            # made the softmax over dim=1 fail; put the axis back.
            logits = logits.unsqueeze(0)
        probs = torch.softmax(logits, dim=1)
        with pyro.plate("data", out.shape[0]):
            obs = pyro.sample('obs', dist.Categorical(probs=probs), obs=y).float()

        return x, obs
|
449 |
+
|
450 |
+
|
451 |
+
class BayesianNNClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn style wrapper that fits ``BayesianModel`` with SVI.

    Args:
        num_features: number of input features.
        n_layers: stored on the estimator; not read by ``fit`` (the model
            spec hard-codes ``'nlayers': 2``).
        embed: hidden width passed to ``BayesianModel``.
        lr: learning rate for ``pyro.optim.Adam``.
        device: device used for the model and the data.
        num_training_steps: number of SVI steps run by ``fit``.
        num_samples_for_prediction: number of posterior samples drawn by
            ``predict`` / ``predict_proba``.

    The last two arguments used to be hard-coded to 400 inside ``__init__``.
    They are constructor parameters now because ``param_grid['bayes']`` lists
    them, and ``GridSearchCV`` can only set parameters that appear in the
    ``__init__`` signature.
    """

    def __init__(self, num_features, n_layers, embed, lr, device,
                 num_training_steps=400, num_samples_for_prediction=400):
        self.num_training_steps = num_training_steps
        self.num_samples_for_prediction = num_samples_for_prediction
        self.embed = embed
        self.n_layers = n_layers
        self.lr = lr
        self.num_features = num_features
        self.device = device

    @property
    def num_steps(self):
        """Backward-compatible alias of ``num_training_steps``."""
        return self.num_training_steps

    @num_steps.setter
    def num_steps(self, value):
        self.num_training_steps = value

    @property
    def num_pred_samples(self):
        """Backward-compatible alias of ``num_samples_for_prediction``."""
        return self.num_samples_for_prediction

    @num_pred_samples.setter
    def num_pred_samples(self, value):
        self.num_samples_for_prediction = value

    def fit(self, X, y):
        """Fit the variational posterior; ``X`` and ``y`` must be torch tensors."""
        model_spec = {'nlayers': 2, 'embed': self.embed, 'num_features': self.num_features}

        self.model = BayesianModel(model_spec, device=self.device)
        self.guide = AutoDiagonalNormal(self.model).to(self.device)
        self.adam = pyro.optim.Adam({"lr": self.lr})
        self.svi = SVI(self.model, self.guide, self.adam, loss=Trace_ELBO())

        pyro.clear_param_store()

        X = X.to(self.device)
        y = y.to(self.device)

        for _ in tqdm(range(self.num_training_steps)):
            self.svi.step(X, y)

        # Return the classifier
        return self

    def _posterior_label_mean(self, X):
        """Return the mean of the sampled ``'obs'`` labels for each row of ``X``."""
        X = X.to(self.device)
        predictive = Predictive(self.model, guide=self.guide, num_samples=self.num_samples_for_prediction)
        preds = predictive(X)['obs']
        return preds.float().mean(axis=0).detach().cpu()

    def predict(self, X):
        """Return hard labels: 1 where the mean sampled label exceeds 0.5."""
        return (self._posterior_label_mean(X) > 0.5).long()

    def predict_proba(self, X):
        """Return the mean sampled label per row.

        NOTE: this is the single tensor of per-row means, not the
        ``(n_samples, n_classes)`` matrix scikit-learn estimators usually
        return; ``bayes_net_metric`` consumes it in this form.
        """
        return self._posterior_label_mean(X)

    def score(self, X, y):
        return super().score(X, y)
|
500 |
+
|
501 |
+
param_grid['bayes'] = {'embed': [5, 10, 30, 64], 'lr': [1e-3, 1e-4], 'num_training_steps': [400], 'num_samples_for_prediction': [400]}


def bayes_net_metric(x, y, test_x, test_y, cat_features):
    """Grid-search ``BayesianNNClassifier`` and score it on the test split.

    Args:
        x, y: training inputs and labels (torch tensors).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: unused; kept for signature parity with other baselines.

    Returns:
        ``(metric, pred)`` where ``pred`` is the tensor returned by
        ``predict_proba`` and ``metric`` is ``metric_used(test_y, pred)``.
    """
    device = x.device

    clf = BayesianNNClassifier(x.shape[1], 2, 1, 1e-3, device)

    # GridSearchCV raises on grid keys that are not constructor parameters of
    # the estimator, so search only over the entries it actually exposes.
    tunable = clf.get_params()
    search_space = {name: values for name, values in param_grid['bayes'].items() if name in tunable}

    clf = GridSearchCV(clf, search_space, cv=5)
    # fit model to data
    clf.fit(x.cpu(), y.long().cpu())

    pred = clf.predict_proba(test_x)
    metric = metric_used(test_y.cpu().numpy(), pred.cpu().numpy())

    return metric, pred
|
516 |
+
|
517 |
+
## GP
|
518 |
+
param_grid['gp'] = {'params_y_scale': [0.05, 0.1, 0.5, 1.0, 5.0, 10.0],
                    'params_length_scale': [0.1, 0.5, 1.0, 2.0]}


def gp_metric(x, y, test_x, test_y, cat_features):
    """Grid-search a ``GaussianProcessClassifier`` over scaled RBF kernels.

    One kernel ``scale * RBF(length)`` is built for every combination of
    ``param_grid['gp']['params_y_scale']`` and
    ``param_grid['gp']['params_length_scale']``.

    Args:
        x, y: training inputs and labels (torch tensors; moved to CPU here).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: unused; kept for signature parity with other baselines.

    Returns:
        ``(metric, pred)`` where ``pred`` is column 1 of ``predict_proba`` on
        ``test_x`` and ``metric`` is ``metric_used(test_y, pred)``.
    """
    import warnings

    x, y, test_x, test_y = x.cpu(), y.cpu(), test_x.cpu(), test_y.cpu()

    # Read the values from the module-level grid instead of repeating them in
    # a local variable that shadowed `param_grid`.
    gp_grid = param_grid['gp']
    kernels = [scale * RBF(length)
               for scale, length in itertools.product(gp_grid['params_y_scale'], gp_grid['params_length_scale'])]

    # Scoped suppression instead of overwriting `warnings.warn` process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        clf = GridSearchCV(GaussianProcessClassifier(), {'kernel': kernels}, cv=min(CV, x.shape[0]//2))
        clf.fit(x, y.long())
        pred = clf.predict_proba(test_x)[:, 1]
        metric = metric_used(test_y.numpy(), pred)

    return metric, pred
|
541 |
+
|
542 |
+
|
543 |
+
## Tabnet
|
544 |
+
# https://github.com/dreamquark-ai/tabnet
|
545 |
+
param_grid['tabnet'] = {'n_d': [2, 4], 'n_steps': [2,4,6], 'gamma': [1.3], 'optimizer_params': [{'lr': 2e-2}, {'lr': 2e-1}]}
#param_grid['tabnet'] = {'n_d': [2], 'n_steps': [2], 'optimizer_params': [{'lr': 2e-2}, {'lr': 2e-1}]}


def tabnet_metric(x, y, test_x, test_y, cat_features):
    """Select TabNet hyper-parameters by K-fold CV, refit and score.

    Every configuration of ``param_grid['tabnet']`` is trained on each
    training fold and scored on the matching validation fold; the
    configuration with the highest mean score is refit on all of ``x`` / ``y``.

    Args:
        x, y: training inputs and labels (torch tensors).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: passed to ``TabNetClassifier`` as ``cat_idxs``.

    Returns:
        ``(metric, pred)`` where ``pred`` is ``1 - predict_proba(test_x)[:, 0]``
        and ``metric`` is ``metric_used(test_y, pred)``.
    """
    x, y, test_x, test_y = x.cpu().numpy(), y.cpu().numpy(), test_x.cpu().numpy(), test_y.cpu().numpy()

    candidates = list(ParameterGrid(param_grid['tabnet']))
    # The splitter is deterministic (no shuffling), so one instance serves all
    # candidates.
    kf = KFold(n_splits=min(5, x.shape[0]//2), random_state=None, shuffle=False)

    mean_metrics = []
    for params in candidates:
        fold_metrics = []
        for train_index, valid_index in kf.split(x):
            X_train, X_valid = x[train_index], x[valid_index]
            y_train, y_valid = y[train_index], y[valid_index]

            clf = TabNetClassifier(verbose=True, cat_idxs=cat_features, n_a=params['n_d'], **params)
            clf.fit(X_train, y_train)

            # Score the fold against its own labels. The old code used
            # `test_y.cpu()`, which is neither a tensor here nor the right
            # labels for `X_valid`.
            fold_metrics.append(metric_used(y_valid, clf.predict(X_valid)))
        mean_metrics.append(np.mean(fold_metrics))

    best_params = candidates[int(np.argmax(mean_metrics))]
    # Keep n_a tied to n_d exactly as during model selection.
    clf = TabNetClassifier(cat_idxs=cat_features, n_a=best_params['n_d'], **best_params)
    clf.fit(x, y)

    pred = 1 - clf.predict_proba(test_x)[:, 0]
    metric = metric_used(test_y, pred)

    return metric, pred
|
591 |
+
|
592 |
+
|
593 |
+
# Catboost
|
594 |
+
param_grid['catboost'] = {'learning_rate': [0.1, 0.5, 1.0],
                          'depth': [2, 4, 7],
                          'l2_leaf_reg': [0.0, 0.5, 1],
                          'iterations': [10, 40, 70],
                          'loss_function': ['Logloss']}


def catboost_metric(x, y, test_x, test_y, categorical_feats):
    """Grid-search a ``CatBoostClassifier`` and score it on the test split.

    Args:
        x, y: training inputs and labels (torch tensors).
        test_x, test_y: held-out inputs and labels (torch tensors).
        categorical_feats: column positions cast to ``int`` in the data
            frames. NOTE(review): they are not passed to CatBoost as
            ``cat_features`` - confirm whether that is intended.

    Returns:
        ``(metric, pred)`` where ``pred`` is column 1 of ``predict_proba`` on
        ``test_x`` and ``metric`` is ``metric_used(test_y, pred)``.
    """
    import warnings

    x, y, test_x, test_y = x.cpu().numpy(), y.cpu().numpy(), test_x.cpu().numpy(), test_y.cpu().numpy()

    def make_pd_from_np(x):
        data = pd.DataFrame(x)
        for c in categorical_feats:
            data.iloc[:, c] = data.iloc[:, c].astype('int')
        return data

    # Scoped suppression instead of overwriting `warnings.warn` process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

        x = make_pd_from_np(x)
        test_x = make_pd_from_np(test_x)

        model = CatBoostClassifier(iterations=2,
                                   depth=2,
                                   learning_rate=1,
                                   loss_function='Logloss',
                                   logging_level='Silent')

        model.grid_search(param_grid['catboost'],
                          X=x,
                          y=y,
                          cv=5,
                          plot=False,
                          verbose=False)  # randomized_search with n_iter

        pred = model.predict_proba(test_x)[:, 1]
        # test_y is already a numpy array at this point; calling `.cpu()` on
        # it raised AttributeError.
        metric = metric_used(test_y, pred)

    return metric, pred
|
635 |
+
|
636 |
+
|
637 |
+
# XGBoost
|
638 |
+
param_grid['xgb'] = {
    'min_child_weight': [0.5, 1.0],
    'learning_rate': [0.02, 0.2],
    #'gamma': [0.1, 0.2, 0.5, 1, 2],
    'subsample': [0.5, 0.8],
    'max_depth': [1, 2],
    'colsample_bytree': [0.8],  #0.5,
    'eval_metric': ['logloss'],
    'n_estimators': [100]
}


def xgb_metric(x, y, test_x, test_y, cat_features):
    """Grid-search an ``XGBClassifier`` and score it on the test split.

    Args:
        x, y: training inputs and labels (torch tensors).
        test_x, test_y: held-out inputs and labels (torch tensors).
        cat_features: unused; kept for signature parity with other baselines.

    Returns:
        ``(metric, pred)`` where ``pred`` is column 1 of ``predict_proba`` on
        ``test_x`` and ``metric`` is ``metric_used(test_y, pred)``.
    """
    # `.cpu()` first so the conversion also works for tensors on another device.
    x, test_x = x.cpu().numpy(), test_x.cpu().numpy()
    y, test_y = y.cpu().numpy().astype(int), test_y.cpu().numpy().astype(int)

    clf = GridSearchCV(xgb.XGBClassifier(use_label_encoder=False), param_grid['xgb'], cv=5, n_jobs=4, verbose=2)
    # fit model to data
    clf.fit(x, y)

    print(clf.best_params_)

    pred = clf.predict_proba(test_x)[:, 1]
    # Report the same metric as every other baseline in this file; this
    # function previously returned a hard-coded accuracy at threshold 0.5.
    metric = metric_used(test_y, pred)
    return metric, pred
|
665 |
+
|
666 |
+
def get_default_spec(test_datasets, valid_datasets):
    """Return the default evaluation settings for the tabular benchmark.

    Args:
        test_datasets: iterable of 4-tuples whose second item has a ``shape``.
        valid_datasets: same layout as ``test_datasets``.

    Returns:
        ``(bptt, eval_positions, max_features, max_samples)`` where
        ``max_features`` is the largest ``shape[1]`` found in either list and
        the other three values are the constants ``100``, ``[30]`` and ``20``.
    """
    bptt = 100
    eval_positions = [30]
    max_samples = 20

    max_features = max(
        X.shape[1]
        for datasets in (test_datasets, valid_datasets)
        for (_, X, _, _) in datasets
    )

    return bptt, eval_positions, max_features, max_samples
|
673 |
+
|
674 |
+
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean; `type=bool` treats any non-empty string as True."""
        return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('--method', default='ridge', type=str)
    parser.add_argument('--did', default=-1, type=int)
    # Accepts `--overwrite`, `--overwrite True` and `--overwrite False`.
    parser.add_argument('--overwrite', default=False, type=_str2bool, nargs='?', const=True)
    args = parser.parse_args()

    test_datasets, _ = load_openml_list(test_dids_classification)
    valid_datasets, _ = load_openml_list(valid_dids_classification)

    selector = 'test'
    ds = valid_datasets if selector == 'valid' else test_datasets
    if args.did > -1:
        ds = ds[args.did:args.did+1]

    bptt, eval_positions, max_features, max_samples = get_default_spec(test_datasets, valid_datasets)

    # Baseline name -> metric function. Unknown names map to None, as before.
    baselines = {
        'bayes': bayes_net_metric,
        'gp': gp_metric,
        'ridge': ridge_metric,
        'knn': knn_metric,
        'catboost': catboost_metric,
        'tabnet': tabnet_metric,
        'xgb': xgb_metric,  # Uses lots of cpu so difficult to time
        'logistic': logistic_metric,
    }
    clf = baselines.get(args.method)
    device = 'cpu'

    start_time = time.time()
    result = evaluate(ds, clf, args.method, bptt, eval_positions, device=device, max_samples=max_samples, overwrite=args.overwrite, save=True)
    result['time_spent'] = time.time() - start_time

    with open(f'/home/hollmann/prior-fitting/results/tabular/results_{selector}_{args.method}.npy', 'wb') as f:
        np.save(f, result)
|
prior-fitting/train.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from transformer import TransformerModel
|
8 |
+
from bar_distribution import BarDistribution, FullSupportBarDistribution, get_bucket_limits
|
9 |
+
from utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler
|
10 |
+
import priors
|
11 |
+
import encoders
|
12 |
+
import positional_encodings
|
13 |
+
|
14 |
+
class Losses:
    """Namespace of ready-made loss criteria usable as ``criterion`` in ``train``."""

    # The torch criteria are built with reduction='none', i.e. they return
    # un-reduced losses.
    gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
    mse = nn.MSELoss(reduction='none')
    ce = nn.CrossEntropyLoss(reduction='none')
    bce = nn.BCEWithLogitsLoss(reduction='none')
    # The class itself, not an instance: call it to build a bar distribution.
    get_BarDistribution = BarDistribution
|
20 |
+
|
21 |
+
|
22 |
+
def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.2,
          epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, warmup_epochs=10, input_normalization=False,
          y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict=None, scheduler=get_cosine_schedule_with_warmup,
          load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, gpu_device='cuda:0',
          aggregate_k_gradients=1, verbose=True
          ):
    """Train a ``TransformerModel`` on batches produced by a prior data loader.

    Args:
        priordataloader_class: called with ``num_steps``, ``batch_size``,
            ``seq_len`` and ``extra_prior_kwargs_dict`` to build the loader.
        criterion: loss; ``nn.GaussianNLLLoss``, ``nn.MSELoss``,
            ``nn.BCEWithLogitsLoss`` and bar distributions get dedicated
            handling, anything else is called on ``(N, n_out)`` outputs.
        encoder_generator: called as ``(num_features, emsize)`` for the input
            encoder.
        emsize, nhid, nlayers, nhead, dropout: forwarded to ``TransformerModel``.
        epochs: number of epochs.
        steps_per_epoch: batches per epoch.
        batch_size, bptt: batch size and sequence length given to the loader.
        lr: learning rate; ``get_openai_lr(model)`` is used when ``None``.
        warmup_epochs: forwarded to ``scheduler``.
        input_normalization: forwarded to ``TransformerModel``.
        y_encoder_generator: called as ``(1, emsize)``; when ``None`` the model
            receives ``y_encoder=None``.
        pos_encoder_generator: defaults to
            ``positional_encodings.NoPositionalEncoding``.
        decoder: forwarded to ``TransformerModel``.
        extra_prior_kwargs_dict: extra keyword arguments for the loader
            (``None`` means none).
        scheduler: factory called as ``(optimizer, warmup_epochs, epochs)``.
        load_weights_from_this_state_dict: optional state dict loaded into the
            model before training.
        validation_period: ``dl.validate(model)`` runs every this many epochs
            if the loader has a ``validate`` attribute.
        single_eval_pos_gen: an int, ``None``, or a callable returning one per
            batch; forwarded to the model as ``single_eval_pos``.
        gpu_device: device used when CUDA is available.
        aggregate_k_gradients: number of batches whose gradients are
            accumulated before each optimizer step; must divide ``len(dl)``.
        verbose: print a summary line after every epoch.

    Returns:
        ``(total_loss, total_positional_losses, model)`` of the last epoch,
        with the model moved to the CPU.
    """
    # Avoid a shared mutable default argument.
    if extra_prior_kwargs_dict is None:
        extra_prior_kwargs_dict = {}

    device = gpu_device if torch.cuda.is_available() else 'cpu:0'
    print(f'Using {device} device')
    dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, seq_len=bptt, **extra_prior_kwargs_dict)

    encoder = encoder_generator(dl.num_features+1 if dl.fuse_x_y else dl.num_features, emsize)
    n_out = dl.num_outputs
    if isinstance(criterion, nn.GaussianNLLLoss):
        n_out *= 2
    elif isinstance(criterion, BarDistribution) or "BarDistribution" in criterion.__class__.__name__:  # TODO remove this fix (only for dev)
        assert n_out == 1
        n_out = criterion.num_bars
    # The parameter defaults to None, so do not call it unconditionally.
    y_encoder = y_encoder_generator(1, emsize) if y_encoder_generator is not None else None
    model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout,
                             y_encoder=y_encoder, input_normalization=input_normalization,
                             pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt*2),
                             decoder=decoder
                             )
    model.criterion = criterion
    if load_weights_from_this_state_dict is not None:
        model.load_state_dict(load_weights_from_this_state_dict)
    model.to(device)

    # learning rate
    if lr is None:
        lr = get_openai_lr(model)
        print(f"Using OpenAI max lr of {lr}.")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = scheduler(optimizer, warmup_epochs, epochs)

    def train_epoch():
        """Run one pass over ``dl`` and return loss and timing statistics."""
        model.train()  # Turn on the train mode
        total_loss = 0.
        total_positional_losses = 0.
        total_positional_losses_recorded = 0
        before_get_batch = time.time()
        assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
        for batch, (data, targets) in enumerate(dl):
            time_to_get_batch = time.time() - before_get_batch
            before_forward = time.time()
            single_eval_pos = single_eval_pos_gen() if callable(single_eval_pos_gen) else single_eval_pos_gen
            output = model(tuple(e.to(device) for e in data) if isinstance(data, tuple) else data.to(device)
                           , single_eval_pos=single_eval_pos)

            forward_time = time.time() - before_forward

            if single_eval_pos is not None:
                # Only the positions from single_eval_pos onwards are scored.
                targets = targets[single_eval_pos:]

            if isinstance(criterion, nn.GaussianNLLLoss):
                assert output.shape[-1] == 2, \
                    'need to write a little bit of code to handle multiple regression targets at once'

                mean_pred = output[..., 0]
                var_pred = output[..., 1].abs()
                losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
            elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
                losses = criterion(output.flatten(), targets.to(device).flatten())
            else:
                losses = criterion(output.reshape(-1, n_out), targets.to(device).flatten())
            losses = losses.view(*output.shape[0:2]).squeeze(-1)

            loss = losses.mean()
            loss.backward()
            # Step only after `aggregate_k_gradients` batches were accumulated.
            if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                optimizer.step()
                optimizer.zero_grad()

            step_time = time.time() - before_forward

            total_loss += loss.item()
            total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
                nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)*loss.cpu().detach()

            total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
                nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)

            before_get_batch = time.time()
        return total_loss / steps_per_epoch, (
                total_positional_losses / total_positional_losses_recorded).tolist(), time_to_get_batch, forward_time, step_time

    # Returned unchanged when `epochs` is 0.
    total_loss = float('inf')
    total_positional_losses = float('inf')
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_epoch()
        if hasattr(dl, 'validate') and epoch % validation_period == 0:
            with torch.no_grad():
                val_score = dl.validate(model)
        else:
            val_score = None

        if verbose:
            print('-' * 89)
            print(
                f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
                f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
                f' forward time {forward_time:5.2f}' + (f'val score {val_score}' if val_score is not None else ''))
            print('-' * 89)

        scheduler.step()
    return total_loss, total_positional_losses, model.to('cpu')
|
136 |
+
|
137 |
+
def _parse_args(config_parser, parser):
    """Parse the command line, letting a YAML config file override defaults.

    Args:
        config_parser: parser that only knows ``--config``; it is run with
            ``parse_known_args`` so all other arguments are left over.
        parser: the main parser; values from the config file are installed as
            its defaults before it parses the remaining arguments.

    Returns:
        ``(args, args_text)`` where ``args_text`` is the YAML dump of
        ``args.__dict__``.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            # An empty file loads as None, which `set_defaults(**cfg)` cannot
            # unpack; treat it as "no overrides".
            cfg = yaml.safe_load(f) or {}
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == '__main__':
    config_parser = argparse.ArgumentParser(description='Only used as a first parser for the config file path.')
    config_parser.add_argument('--config')
    parser = argparse.ArgumentParser()
    parser.add_argument('prior')
    parser.add_argument('--loss_function', default='barnll')
    # Optional Arg's for `--loss_function barnll`
    parser.add_argument('--min_y', type=float, help='barnll can only model y in strict ranges, this is the minimum y can take.')
    parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
    parser.add_argument('--num_buckets', default=100, type=int)
    #parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
    parser.add_argument("--extra_prior_kwargs_dict", default={'fuse_x_y': False}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
    parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
    parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
    parser.add_argument('--pos_encoder', default='sinus', type=str, help='Specify depending on the prior.')
    parser.add_argument('--bptt', default=10, type=int)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--warmup_epochs', default=50, type=int)
    parser.add_argument('--validation_period', default=10, type=int)
    parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to ')
    parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")

    # these can likely be mostly left at defaults
    parser.add_argument('--emsize', default=512, type=int)  # sometimes even larger is better e.g. 1024
    parser.add_argument('--nlayers', default=6, type=int)
    parser.add_argument('--nhid', default=None, type=int)  # 2*emsize is the default
    parser.add_argument('--nhead', default=4, type=int)  # nhead = emsize / 64 in the original paper
    parser.add_argument('--dropout', default=.0, type=float)
    parser.add_argument('--steps_per_epoch', default=10, type=int)
    parser.add_argument('--batch_size', default=1000, type=int)
    parser.add_argument('--lr', '--learning_rate', default=.001, type=float)  # try also .0003, .0001, go lower with lower batch size

    args, _ = _parse_args(config_parser, parser)

    if args.nhid is None:
        args.nhid = 2*args.emsize

    # Everything popped from `args.__dict__` below is consumed here; whatever
    # remains is forwarded to `train` as keyword arguments.
    prior = args.__dict__.pop('prior')

    # CLI name -> submodule of `priors`. Resolved with getattr so that only the
    # selected submodule is touched, as in the former if/elif chain.
    prior_modules = {'gp': 'fast_gp', 'ridge': 'ridge', 'stroke': 'stroke', 'mix_gp': 'fast_gp_mix'}
    if prior not in prior_modules:
        raise NotImplementedError(f'Prior == {prior}.')
    prior = getattr(priors, prior_modules[prior]).DataLoader

    loss_function = args.__dict__.pop('loss_function')
    num_buckets = args.__dict__.pop('num_buckets')
    max_y = args.__dict__.pop('max_y')
    min_y = args.__dict__.pop('min_y')

    def get_y_sample():
        """Draw one large batch from the prior and return its last element (the targets)."""
        dl = prior(num_steps=1, batch_size=args.batch_size * args.steps_per_epoch, seq_len=args.bptt,
                   **args.extra_prior_kwargs_dict)
        y_sample = next(iter(dl))[-1]
        print(f'Creating Bar distribution with borders from y sample of size {y_sample.numel()}')
        return y_sample

    if loss_function == 'ce':
        criterion = nn.CrossEntropyLoss(reduction='none')
    elif loss_function == 'gaussnll':
        criterion = nn.GaussianNLLLoss(reduction='none', full=True)
    elif loss_function == 'mse':
        criterion = nn.MSELoss(reduction='none')
    elif loss_function == 'barnll':
        criterion = BarDistribution(borders=get_bucket_limits(num_buckets, full_range=(min_y,max_y)))
    elif loss_function == 'adaptivebarnll':
        borders = get_bucket_limits(num_buckets, ys=get_y_sample(), full_range=(min_y,max_y))
        criterion = BarDistribution(borders=borders)
    elif loss_function == 'adaptivefullsupportbarnll':
        # Raise instead of assert so the check survives `python -O`; the
        # message now names the option that is actually accepted.
        if min_y is not None or max_y is not None:
            raise ValueError("Please do not specify `min_y` and `max_y` with `adaptivefullsupportbarnll`.")
        borders = get_bucket_limits(num_buckets, ys=get_y_sample())
        criterion = FullSupportBarDistribution(borders=borders)
    else:
        raise NotImplementedError(f'loss_function == {loss_function}.')

    encoder = args.__dict__.pop('encoder')
    y_encoder = args.__dict__.pop('y_encoder')

    def get_encoder_generator(encoder):
        """Map an encoder name from the command line to its class in `encoders`."""
        if encoder == 'linear':
            encoder_generator = encoders.Linear
        elif encoder == 'mlp':
            encoder_generator = encoders.MLP
        elif encoder == 'positional':
            encoder_generator = encoders.Positional
        else:
            raise NotImplementedError(f'A {encoder} encoder is not valid.')
        return encoder_generator

    encoder_generator = get_encoder_generator(encoder)
    y_encoder_generator = get_encoder_generator(y_encoder)

    pos_encoder = args.__dict__.pop('pos_encoder')

    if pos_encoder == 'none':
        pos_encoder_generator = None
    elif pos_encoder == 'sinus':
        pos_encoder_generator = positional_encodings.PositionalEncoding
    elif pos_encoder == 'learned':
        pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
    elif pos_encoder == 'paired_scrambled_learned':
        pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
    else:
        raise NotImplementedError(f'pos_encoder == {pos_encoder} is not valid.')

    permutation_invariant_max_eval_pos = args.__dict__.pop('permutation_invariant_max_eval_pos')
    permutation_invariant_sampling = args.__dict__.pop('permutation_invariant_sampling')
    if permutation_invariant_max_eval_pos is not None:
        if permutation_invariant_sampling == 'weighted':
            get_sampler = get_weighted_single_eval_pos_sampler
        elif permutation_invariant_sampling == 'uniform':
            get_sampler = get_uniform_single_eval_pos_sampler
        else:
            raise ValueError(
                f"permutation_invariant_sampling == {permutation_invariant_sampling} is not valid; "
                "use 'weighted' or 'uniform'.")
        args.__dict__['single_eval_pos_gen'] = get_sampler(permutation_invariant_max_eval_pos)

    print("ARGS for `train`:", args.__dict__)

    train(prior, criterion, encoder_generator,
          y_encoder_generator=y_encoder_generator,pos_encoder_generator=pos_encoder_generator,
          **args.__dict__)
|
288 |
+
|
prior-fitting/transformer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
+
from torch.nn.modules.transformer import MultiheadAttention, _get_activation_fn
|
9 |
+
|
10 |
+
from utils import SeqBN
|
11 |
+
|
12 |
+
|
13 |
+
class TransformerModel(nn.Module):
    """Transformer encoder wrapped with pluggable input encoders and a decoder head.

    Inputs are embedded by ``encoder`` (and ``y_encoder`` for targets),
    optionally normalized and position-encoded, passed through a stack of
    ``TransformerEncoderLayer`` modules and finally mapped to ``n_out`` values
    per position by ``decoder``.

    Args:
        encoder: Module applied to the x inputs.
        n_out: Output size of the default decoder's last linear layer.
        ninp: Model width passed to the transformer layers.
        nhead: Number of attention heads.
        nhid: Feed-forward width of the transformer layers; also the hidden
            width of the default decoder.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout passed to every encoder layer.
        y_encoder: Module applied to ``y.unsqueeze(-1)``; needed whenever
            ``forward`` receives an ``(x, y)`` tuple.
        pos_encoder: Optional module applied to the embedded sequence.
        decoder: Optional factory called as ``decoder(ninp, nhid, n_out)``;
            when ``None`` a Linear-GELU-Linear head is used.
        input_normalization: When true, ``SeqBN(ninp)`` is applied to the
            embedded sequence before the positional encoder.
    """

    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, y_encoder=None, pos_encoder=None, decoder=None, input_normalization=False):
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation='gelu')
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.input_ln = SeqBN(ninp) if input_normalization else None

        self.init_weights()

    @staticmethod
    def _additive_mask_from_bool(allowed):
        """Turn a boolean mask into a float mask: 0.0 where True, -inf where False."""
        return allowed.float().masked_fill(allowed == 0, float('-inf')).masked_fill(allowed == 1, float(0.0))

    @staticmethod
    def generate_square_subsequent_mask(sz):
        """Return a ``sz x sz`` causal mask.

        Entry ``[i, j]`` is 0.0 when ``j <= i`` and ``-inf`` otherwise.
        """
        allowed = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return TransformerModel._additive_mask_from_bool(allowed)

    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        """Return a ``sz x sz`` mask for a train/query split of the sequence.

        The first ``sz - query_size`` columns are open (0.0) for every row.
        The last ``query_size`` columns are closed (``-inf``) except on the
        diagonal, so each of those positions is only visible to itself.
        """
        train_size = sz-query_size
        allowed = torch.zeros(sz,sz) == 0  # start with everything allowed
        allowed[:,train_size:].zero_()  # close the query columns ...
        allowed |= torch.eye(sz) == 1  # ... but keep every diagonal entry open
        return TransformerModel._additive_mask_from_bool(allowed)

    def init_weights(self):
        """Zero the output projections of every encoder layer.

        Both the feed-forward output (``linear2``) and the attention output
        projection (``self_attn.out_proj``) have weights and biases set to 0.
        """
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            nn.init.zeros_(layer.self_attn.out_proj.weight)
            nn.init.zeros_(layer.self_attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        """Run the model and return decoder outputs for the evaluated positions.

        Args:
            src: Either a tuple ``(x, y)`` or a single fused tensor. Dimension
                0 indexes sequence positions (it is what the mask size and
                ``single_eval_pos`` slicing refer to). With the assertions
                active, only the tuple form combined with ``single_eval_pos``
                is accepted.
            src_mask: Optional attention mask; when ``None`` one is generated
                from the sequence length and ``single_eval_pos``.
            single_eval_pos: Number of leading positions whose y values are
                added to the x embeddings; outputs are returned for the
                positions from this index onwards.

        Returns:
            The decoder output sliced to the evaluated positions.

        Raises:
            AssertionError: If ``single_eval_pos`` is ``None`` or ``src`` is
                not a tuple (when assertions are enabled).
        """
        assert single_eval_pos is not None, 'Single eval pos is required now.'
        fuse_x_y = not isinstance(src, tuple)
        assert not(fuse_x_y and single_eval_pos is not None), \
            'Don\'t use both fuxe_x_y and single_eval_pos (permutation equivariant setup) at the same time.'
        if src_mask is None:
            x_src = src if fuse_x_y else src[0]
            if single_eval_pos is None:
                # Interleaved x/y layout needs twice as many positions.
                src_mask = self.generate_square_subsequent_mask(len(x_src) if fuse_x_y else 2*len(x_src)).to(x_src.device)
            else:
                src_mask = self.generate_D_q_matrix(len(x_src), len(x_src)-single_eval_pos).to(x_src.device)
        if not fuse_x_y:
            x_src, y_src = src
            x_src = self.encoder(x_src)
            y_src = self.y_encoder(y_src.unsqueeze(-1))
            if single_eval_pos is None:
                # Interleave x and y embeddings along the position dimension.
                src = torch.stack([x_src, y_src], 1).view(-1, *x_src.shape[1:])
            else:
                # Leading positions carry x + y; the remaining ones carry x only.
                train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
                src = torch.cat([train_x, x_src[single_eval_pos:]], 0)
        else:
            src = self.encoder(src)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        if fuse_x_y:
            return output
        elif single_eval_pos is None:
            return output[0::2]
        else:
            return output[single_eval_pos:]
|
prior-fitting/utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.optim.lr_scheduler import LambdaLR
|
8 |
+
|
9 |
+
# copied from huggingface
|
10 |
+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    During the first ``num_warmup_steps`` steps the multiplier grows linearly
    from 0 to 1. Afterwards it follows
    ``0.5 * (1 + cos(pi * 2 * num_cycles * progress))`` clipped at 0, where
    ``progress`` is the fraction of post-warmup steps already taken.
    """

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            steps_after_warmup = float(current_step - num_warmup_steps)
            decay_steps = float(max(1, num_training_steps - num_warmup_steps))
            progress = steps_after_warmup / decay_steps
            cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        # Still warming up: scale linearly with the step index.
        return float(current_step) / float(max(1, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
23 |
+
|
24 |
+
# copied from huggingface
|
25 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a schedule whose multiplier rises linearly from 0 to 1 over the warmup phase and then falls linearly
    back to 0 at ``num_training_steps``.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the warmup phase in steps.
        num_training_steps (:obj:`int`):
            Total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` implementing the schedule.
    """

    def lr_lambda(current_step: int):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            return float(current_step) / float(max(1, num_warmup_steps))
        steps_left = float(num_training_steps - current_step)
        decay_steps = float(max(1, num_training_steps - num_warmup_steps))
        # Clip at zero once training runs past num_training_steps.
        return max(0.0, steps_left / decay_steps)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
52 |
+
|
53 |
+
|
54 |
+
def get_openai_lr(transformer_model):
    """Return ``0.003239 - 0.0001395 * ln(N)`` where N is the model's total parameter count."""
    num_params = 0
    for param in transformer_model.parameters():
        num_params += param.numel()
    return 0.003239 - 0.0001395 * math.log(num_params)
|
57 |
+
|
58 |
+
|
59 |
+
def get_weighted_single_eval_pos_sampler(max_len):
    """
    Build a sampler for `single_eval_pos` over the positions 0 .. `max_len` - 1.
    Position i is drawn with weight 1 / (`max_len` - i), so later positions are more likely.
    :return: Zero-argument callable that can be fed to `train()` as `single_eval_pos_gen`.
    """
    def sample_position():
        candidates = range(max_len)
        weights = [1 / (max_len - pos) for pos in range(max_len)]
        drawn = random.choices(candidates, weights)
        return drawn[0]

    return sample_position
|
66 |
+
|
67 |
+
|
68 |
+
def get_uniform_single_eval_pos_sampler(max_len):
    """
    Build a sampler that draws every position in 0 .. `max_len` - 1 with equal weight.
    :return: Zero-argument callable that can be fed to `train()` as `single_eval_pos_gen`.
    """
    def sample_position():
        drawn = random.choices(range(max_len))
        return drawn[0]

    return sample_position
|
74 |
+
|
75 |
+
|
76 |
+
class SeqBN(nn.Module):
    """Batch normalization over the last dimension of an arbitrarily shaped tensor.

    All leading dimensions are flattened into one, ``nn.BatchNorm1d`` is
    applied to the resulting ``(-1, d_model)`` matrix, and the result is
    restored to the input's shape.

    Args:
        d_model: Size of the last dimension of the inputs.
    """

    def __init__(self, d_model):
        super().__init__()
        self.bn = nn.BatchNorm1d(d_model)
        self.d_model = d_model

    def forward(self, x):
        """Normalize ``x`` and return a tensor of the same shape.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``d_model``.
        """
        assert self.d_model == x.shape[-1]
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted; for contiguous inputs it is still a view.
        flat_x = x.reshape(-1, self.d_model)
        flat_x = self.bn(flat_x)
        return flat_x.view(*x.shape)
|
87 |
+
|
88 |
+
|
89 |
+
def set_locals_in_self(locals):
    """Copy every entry of ``locals`` except ``'self'`` onto ``locals['self']`` as an attribute.

    Intended to be called as ``set_locals_in_self(locals())`` from inside a method.
    """
    target = locals['self']
    for name, value in locals.items():
        if name == 'self':
            continue
        setattr(target, name, value)
|
93 |
+
|
94 |
+
|
95 |
+
# Device string used as the default: the first CUDA device when torch reports
# CUDA as available, otherwise the CPU.
default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'
|
96 |
+
|
97 |
+
|
98 |
+
# Copied from StackOverflow, but we do an eval on the values additionally
|
99 |
+
class StoreDictKeyPair(argparse.Action):
    """argparse action that collects ``KEY=VALUE`` tokens into a dict.

    Each value is passed through ``eval`` so that numbers, booleans, lists
    etc. arrive as Python objects; values that cannot be evaluated are stored
    as plain strings.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        self._nargs = nargs
        super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` (an iterable of ``KEY=VALUE`` strings) and store the dict on ``namespace``."""
        my_dict = {}
        for kv in values:
            # Split on the first '=' only, so values may themselves contain '='.
            k, v = kv.split("=", 1)
            try:
                # SECURITY NOTE(review): eval executes arbitrary expressions
                # taken from the command line; only use with trusted input.
                my_dict[k] = eval(v)
            except (NameError, SyntaxError):
                # Not a Python expression (bare word, path, ...): keep the raw string.
                my_dict[k] = v
        setattr(namespace, self.dest, my_dict)
        print("dict values: {}".format(my_dict))
|
114 |
+
|
115 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python >= 3.9 is recommended
|
2 |
+
gpytorch==1.5.0
|
3 |
+
pyro-ppl==1.7.0
|
4 |
+
torch==1.9.0
|
5 |
+
scikit-learn==0.24.2
|
6 |
+
pyyaml==5.4.1
|
7 |
+
blitz-bayesian-pytorch==0.2.7
|
8 |
+
seaborn==0.11.2
|
9 |
+
xgboost==1.4.0
|
10 |
+
tqdm==4.62.1
|
11 |
+
numpy==1.21.2
|
12 |
+
openml==0.12.2
|
13 |
+
catboost==0.26.1
|