File size: 3,756 Bytes
5ad4668
f50f696
 
 
 
 
5ad4668
f50f696
5ad4668
f50f696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ad4668
f50f696
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import gpytorch
import torch
import sys

import gpytorch

# Simplest possible GP model: exact (non-approximate) inference.
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF kernel.

    Args:
        train_x: training inputs, forwarded to ``ExactGP``.
        train_y: training targets, forwarded to ``ExactGP``.
        likelihood: GPyTorch likelihood, forwarded to ``ExactGP``.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf = gpytorch.kernels.RBFKernel()
        # The ScaleKernel wrapper adds a learnable output scale on top of the RBF kernel.
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the GP prior evaluated at ``x`` as a multivariate normal."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

def get_model(x, y, hyperparameters):
    """Build an exact GP whose hyperparameters are fixed, not learned.

    Args:
        x: training inputs handed to ``ExactGPModel``.
        y: training targets handed to ``ExactGPModel``.
        hyperparameters: mapping with at least the keys ``"noise"``,
            ``"outputscale"`` and ``"lengthscale"``. Each value is broadcast
            over the shape of the corresponding GPyTorch parameter.

    Returns:
        ``(model, likelihood)``.
    """
    # A tiny lower bound lets the noise be set to very small values such as 1e-4.
    likelihood = gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=gpytorch.constraints.GreaterThan(1.e-9)
    )
    model = ExactGPModel(x, y, likelihood)

    def _filled(current, key):
        # Same shape and dtype as ``current``, with every entry set to hyperparameters[key].
        return torch.ones_like(current) * hyperparameters[key]

    kernel = model.covar_module
    model.likelihood.noise = _filled(model.likelihood.noise, "noise")
    kernel.outputscale = _filled(kernel.outputscale, "outputscale")
    kernel.base_kernel.lengthscale = _filled(kernel.base_kernel.lengthscale, "lengthscale")
    return model, likelihood



# Messages that ``infer`` returns when it rejects the input table.
excuse = "Please only specify numbers, x values should be in [0,1] and y values in [-1,1]."
excuse_max_examples = "This model is trained to work with up to 4 input points."

# Fixed GP settings passed to ``get_model``.
# 'fast_computations' presumably holds the three flags of
# gpytorch.settings.fast_computations; verify where it is consumed.
hyperparameters = {
    'noise': 1e-4,
    'outputscale': 1.,
    'lengthscale': .1,
    'fast_computations': (False, False, False),
}

# NOTE(review): ``conf`` is not referenced anywhere in the visible code.
conf = .5

def mean_and_bounds_for_gp(x, y, test_xs):
    """Return the GP predictive mean and a ±1 standard-deviation band at ``test_xs``.

    The exact GP is built by ``get_model`` from the module-level
    ``hyperparameters``. No hyperparameter fitting is done.

    Args:
        x: training inputs, used as the GP's ``train_x``.
        y: training targets, used as the GP's ``train_y``.
        test_xs: query inputs with the same feature layout as ``x``.

    Returns:
        ``(means, means - stds, means + stds)``. ``stds`` are predictive
        standard deviations of ``likelihood(model(test_xs))``, so the
        observation noise is included.
    """
    gp_model, likelihood = get_model(x, y, hyperparameters)
    gp_model.eval()
    # Apply the 'fast_computations' switches declared in ``hyperparameters``.
    # Previously they were defined but silently ignored. The fallback mirrors
    # GPyTorch's own defaults.
    fast_flags = hyperparameters.get('fast_computations', (True, True, True))
    with gpytorch.settings.fast_computations(*fast_flags):
        predictive = likelihood(gp_model(test_xs))
        means = predictive.mean.squeeze()
        # Only the marginal variances are needed. ``variance`` is the covariance
        # diagonal, so the full test-by-test matrix is never materialised.
        stds = predictive.variance.squeeze().sqrt()
    return means, means - stds, means + stds


# PFN checkpoints already loaded by ``_load_pnf_model``, keyed by ``choice``.
_PNF_MODEL_CACHE = {}


def _load_pnf_model(choice):
    """Load the PFN checkpoint for ``choice`` on first use and return the cached model."""
    model = _PNF_MODEL_CACHE.get(choice)
    if model is None:
        # Unpickling the checkpoint needs the prior-fitting code to be importable.
        # Add the path once rather than appending a duplicate on every request.
        if 'prior-fitting/' not in sys.path:
            sys.path.append('prior-fitting/')
        # NOTE(review): torch.load unpickles arbitrary objects. Only point it at
        # trusted checkpoint files.
        model = torch.load(f'onefeature_gp_ls.1_pnf_{choice}.pt')
        # The model is only used for inference. Disable training-mode behaviour
        # such as dropout.
        model.eval()
        _PNF_MODEL_CACHE[choice] = model
    return model


def mean_and_bounds_for_pnf(x, y, test_xs, choice):
    """Return the PFN's predictive mean and central 68.2% interval at ``test_xs``.

    Args:
        x: training inputs, assumed (n, 1). They are concatenated with
            ``test_xs`` along dim 0.
        y: training targets, assumed (n,).
        test_xs: query inputs, assumed (m, 1).
        choice: checkpoint suffix, e.g. '160K', '800K' or '4M'.

    Returns:
        ``(mean, lower, upper)`` as computed by the model's ``criterion``.
    """
    model = _load_pnf_model(choice)

    # Training points come first, then the query points. ``unsqueeze(1)``
    # presumably adds a batch axis of size 1, and ``single_eval_pos`` marks
    # where the query positions begin (TODO confirm against the model code).
    logits = model(
        (torch.cat([x, test_xs], 0).unsqueeze(1), y.unsqueeze(1)),
        single_eval_pos=len(x),
    )
    # A central mass of .682 is about the probability within ±1σ of a Gaussian,
    # which makes the band comparable to the GP's mean ± std.
    bounds = model.criterion.quantile(logits, center_prob=.682).squeeze(1)
    return model.criterion.mean(logits).squeeze(1), bounds[:, 0], bounds[:, 1]

def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):
    """Draw a mean curve and a translucent band between ``lb`` and ``ub``.

    Args:
        ax_or_plt: anything with matplotlib-style ``plot`` and ``fill_between``
            methods (an Axes or the ``pyplot`` module).
        x: x positions with a trailing singleton axis, which is squeezed away.
        m: curve values.
        lb: lower edge of the band.
        ub: upper edge of the band.
        color: colour used for both the curve and the band.
    """
    xs = x.squeeze(-1)
    ax_or_plt.plot(xs, m, color=color)
    ax_or_plt.fill_between(xs, lb, ub, alpha=.1, color=color)




@torch.no_grad()
def infer(table, choice):
    """Gradio callback: plot GP and PFN predictions for the user's points.

    Args:
        table: numpy array from the Dataframe input. Column 0 holds x and
            column 1 holds y. Cells are presumably strings, as in the default
            rows. Rows whose cells are all empty are ignored.
        choice: PFN checkpoint label forwarded to ``mean_and_bounds_for_pnf``.

    Returns:
        ``('', figure)`` on success, otherwise ``(message, None)`` explaining
        why the input was rejected.
    """
    # Drop rows whose cells are all empty. np.vectorize raises ValueError on
    # size-0 input, so only filter a non-empty table.
    if table.size:
        cell_lengths = np.vectorize(len)(table)
        table = table[cell_lengths.sum(1) != 0]
    if table.size == 0:
        return "Please specify at least one (x, y) point.", None

    try:
        table = table.astype(np.float32)
    except ValueError:
        return excuse, None
    x = torch.tensor(table[:, 0]).unsqueeze(1)
    y = torch.tensor(table[:, 1])

    if len(x) > 4:
        return excuse_max_examples, None
    # Phrased as "inside the range" so that NaN entries (all comparisons False)
    # are rejected as well.
    x_ok = ((x >= 0.) & (x <= 1.)).all()
    y_ok = ((y >= -1) & (y <= 1)).all()
    if not (x_ok and y_ok):
        return excuse, None

    # Create the figure only after the input is accepted. The old code opened a
    # figure before validating and never closed it.
    fig, ax = plt.subplots()
    try:
        ax.scatter(x, y)
        test_xs = torch.linspace(0, 1, 100).unsqueeze(1)
        plot_w_conf_interval(ax, test_xs, *mean_and_bounds_for_gp(x, y, test_xs), 'green')
        plot_w_conf_interval(ax, test_xs, *mean_and_bounds_for_pnf(x, y, test_xs, choice), 'blue')
    finally:
        # Unregister the figure from pyplot so a long-running server does not
        # accumulate one open figure per request. The Figure object itself
        # stays usable for the caller.
        plt.close(fig)
    return '', fig

# Inputs: an editable (x, y) table and the PFN checkpoint selector.
# ``infer`` returns a status text and a plot.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.inputs.Dataframe(
            headers=["x", "y"],
            datatype=["number", "number"],
            row_count=2,
            type='numpy',
            default=[['.25', '.1'], ['.75', '.4']],
        ),
        gr.inputs.Radio(
            ['160K', '800K', '4M'],
            type="value",
            default='4M',
            label='Training Costs',
        ),
    ],
    outputs=["text", "plot"],
)
iface.launch()