File size: 775 Bytes
f50f696 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from . import fast_gp, fast_gp_mix
from .utils import get_batch_to_dataloader
def regression_prior_to_binary(get_batch_function):
def binarized_get_batch_function(*args, assert_on=False, **kwargs):
x, y, target_y = get_batch_function(*args, **kwargs)
if assert_on:
assert y is target_y, "y == target_y is assumed by this function"
y = y.sigmoid().bernoulli()
return x, y, y
return binarized_get_batch_function
Binarized_fast_gp_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp.get_batch))
Binarized_fast_gp_dataloader.num_outputs = 1
Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp_mix.get_batch))
Binarized_fast_gp_mix_dataloader.num_outputs = 1
|