|
from . import fast_gp, fast_gp_mix |
|
from .utils import get_batch_to_dataloader |
|
|
|
def regression_prior_to_binary(get_batch_function): |
|
|
|
def binarized_get_batch_function(*args, assert_on=False, **kwargs): |
|
x, y, target_y = get_batch_function(*args, **kwargs) |
|
if assert_on: |
|
assert y is target_y, "y == target_y is assumed by this function" |
|
y = y.sigmoid().bernoulli() |
|
return x, y, y |
|
|
|
return binarized_get_batch_function |
|
|
|
|
|
Binarized_fast_gp_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp.get_batch)) |
|
Binarized_fast_gp_dataloader.num_outputs = 1 |
|
|
|
|
|
Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp_mix.get_batch)) |
|
Binarized_fast_gp_mix_dataloader.num_outputs = 1 |
|
|