Samuel Mueller
commited on
Commit
·
981d877
1
Parent(s):
04eb228
legend added
Browse files
app.py
CHANGED
@@ -55,9 +55,9 @@ def mean_and_bounds_for_pnf(x,y,test_xs, choice):
|
|
55 |
bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)
|
56 |
return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]
|
57 |
|
58 |
-
def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):
|
59 |
-
ax_or_plt.plot(x.squeeze(-1),m, color=color)
|
60 |
-
ax_or_plt.fill_between(x.squeeze(-1), lb, ub, alpha=.1, color=color)
|
61 |
|
62 |
|
63 |
|
@@ -74,22 +74,24 @@ def infer(table, choice):
|
|
74 |
return excuse, None
|
75 |
x = torch.tensor(table[:,0]).unsqueeze(1)
|
76 |
y = torch.tensor(table[:,1])
|
77 |
-
fig = plt.figure(figsize=(4
|
78 |
|
79 |
if len(x) > 4:
|
80 |
return excuse_max_examples, None
|
81 |
if (x<0.).any() or (x>1.).any() or (y<-1).any() or (y>1).any():
|
82 |
return excuse, None
|
83 |
|
84 |
-
plt.scatter(x,y)
|
85 |
|
86 |
|
87 |
|
88 |
test_xs = torch.linspace(0,1,100).unsqueeze(1)
|
89 |
|
90 |
-
plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green')
|
91 |
-
plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue')
|
92 |
-
|
|
|
|
|
93 |
|
94 |
|
95 |
return '', plt.gcf()
|
|
|
55 |
bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)
|
56 |
return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]
|
57 |
|
58 |
+
def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color, label_prefix):
|
59 |
+
ax_or_plt.plot(x.squeeze(-1),m, color=color, label=label_prefix+' mean')
|
60 |
+
ax_or_plt.fill_between(x.squeeze(-1), lb, ub, alpha=.1, color=color, label=label_prefix+' conf. interval')
|
61 |
|
62 |
|
63 |
|
|
|
74 |
return excuse, None
|
75 |
x = torch.tensor(table[:,0]).unsqueeze(1)
|
76 |
y = torch.tensor(table[:,1])
|
77 |
+
fig = plt.figure(figsize=(8,4),dpi=1000)
|
78 |
|
79 |
if len(x) > 4:
|
80 |
return excuse_max_examples, None
|
81 |
if (x<0.).any() or (x>1.).any() or (y<-1).any() or (y>1).any():
|
82 |
return excuse, None
|
83 |
|
84 |
+
plt.scatter(x,y, color='black', label='Examples in given dataset')
|
85 |
|
86 |
|
87 |
|
88 |
test_xs = torch.linspace(0,1,100).unsqueeze(1)
|
89 |
|
90 |
+
plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green', 'GP')
|
91 |
+
plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue', 'PFN')
|
92 |
+
|
93 |
+
plt.legend(ncol=2,bbox_to_anchor=[0.5,-.08],loc="upper center")
|
94 |
+
plt.tight_layout()
|
95 |
|
96 |
|
97 |
return '', plt.gcf()
|