Samuel Mueller commited on
Commit
981d877
·
1 Parent(s): 04eb228

legend added

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -55,9 +55,9 @@ def mean_and_bounds_for_pnf(x,y,test_xs, choice):
55
  bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)
56
  return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]
57
 
58
- def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color):
59
- ax_or_plt.plot(x.squeeze(-1),m, color=color)
60
- ax_or_plt.fill_between(x.squeeze(-1), lb, ub, alpha=.1, color=color)
61
 
62
 
63
 
@@ -74,22 +74,24 @@ def infer(table, choice):
74
  return excuse, None
75
  x = torch.tensor(table[:,0]).unsqueeze(1)
76
  y = torch.tensor(table[:,1])
77
- fig = plt.figure(figsize=(4,2),dpi=1000)
78
 
79
  if len(x) > 4:
80
  return excuse_max_examples, None
81
  if (x<0.).any() or (x>1.).any() or (y<-1).any() or (y>1).any():
82
  return excuse, None
83
 
84
- plt.scatter(x,y)
85
 
86
 
87
 
88
  test_xs = torch.linspace(0,1,100).unsqueeze(1)
89
 
90
- plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green')
91
- plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue')
92
-
 
 
93
 
94
 
95
  return '', plt.gcf()
 
55
  bounds = model.criterion.quantile(logits,center_prob=.682).squeeze(1)
56
  return model.criterion.mean(logits).squeeze(1), bounds[:,0], bounds[:,1]
57
 
58
+ def plot_w_conf_interval(ax_or_plt, x, m, lb, ub, color, label_prefix):
59
+ ax_or_plt.plot(x.squeeze(-1),m, color=color, label=label_prefix+' mean')
60
+ ax_or_plt.fill_between(x.squeeze(-1), lb, ub, alpha=.1, color=color, label=label_prefix+' conf. interval')
61
 
62
 
63
 
 
74
  return excuse, None
75
  x = torch.tensor(table[:,0]).unsqueeze(1)
76
  y = torch.tensor(table[:,1])
77
+ fig = plt.figure(figsize=(8,4),dpi=1000)
78
 
79
  if len(x) > 4:
80
  return excuse_max_examples, None
81
  if (x<0.).any() or (x>1.).any() or (y<-1).any() or (y>1).any():
82
  return excuse, None
83
 
84
+ plt.scatter(x,y, color='black', label='Examples in given dataset')
85
 
86
 
87
 
88
  test_xs = torch.linspace(0,1,100).unsqueeze(1)
89
 
90
+ plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_gp(x,y,test_xs), 'green', 'GP')
91
+ plot_w_conf_interval(plt, test_xs, *mean_and_bounds_for_pnf(x,y,test_xs, choice), 'blue', 'PFN')
92
+
93
+ plt.legend(ncol=2,bbox_to_anchor=[0.5,-.08],loc="upper center")
94
+ plt.tight_layout()
95
 
96
 
97
  return '', plt.gcf()