Samuel Mueller
working locally
f50f696
raw
history blame
775 Bytes
from . import fast_gp, fast_gp_mix
from .utils import get_batch_to_dataloader
def regression_prior_to_binary(get_batch_function):
    """Turn a regression-style ``get_batch`` function into a binary-label one.

    The returned wrapper forwards all positional and keyword arguments
    (except ``assert_on``) to ``get_batch_function``, which must return a
    3-tuple ``(x, y, target_y)``. The ``y`` it produces is passed through
    ``.sigmoid().bernoulli()`` and the result is returned as both the
    labels and the targets. ``target_y`` from the wrapped function is
    discarded.

    Args:
        get_batch_function: Callable returning ``(x, y, target_y)``; ``y``
            must provide ``sigmoid()`` whose result provides ``bernoulli()``
            (presumably a ``torch.Tensor`` -- confirm against the priors).

    Returns:
        A function with signature ``(*args, assert_on=False, **kwargs)``
        returning ``(x, binary_y, binary_y)``, where both label entries are
        the very same object.
    """

    def binarized_get_batch_function(*args, assert_on=False, **kwargs):
        inputs, latent, latent_target = get_batch_function(*args, **kwargs)

        # Opt-in sanity check: the original targets are thrown away below,
        # which is only lossless if they are the same object as ``y``.
        assert not assert_on or latent is latent_target, "y == target_y is assumed by this function"

        # Squash to (0, 1) and sample one 0/1 label per entry.
        probabilities = latent.sigmoid()
        labels = probabilities.bernoulli()
        return inputs, labels, labels

    return binarized_get_batch_function
def _binarized_dataloader(prior_module):
    """Build a dataloader from ``prior_module.get_batch`` with binarized labels.

    The module's ``get_batch`` is wrapped by ``regression_prior_to_binary``,
    handed to ``get_batch_to_dataloader``, and the resulting object gets
    ``num_outputs = 1`` set on it before being returned.
    """
    binary_get_batch = regression_prior_to_binary(prior_module.get_batch)
    loader = get_batch_to_dataloader(binary_get_batch)
    loader.num_outputs = 1
    return loader


# Binarized variants of the fast GP prior and the fast GP mixture prior.
Binarized_fast_gp_dataloader = _binarized_dataloader(fast_gp)
Binarized_fast_gp_mix_dataloader = _binarized_dataloader(fast_gp_mix)