almost_stochastic_order / almost_stochastic_order.py
# Copyright 2022 The HuggingFace Evaluate Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Almost Stochastic Order test for model comparison."""
from typing import Optional
import datasets
from deepsig import aso
import evaluate
_DESCRIPTION = """
The Almost Stochastic Order test is a non-parametric test that measures to what extent the score distributions of two models differ from each other, based on the Wasserstein distance. It can be used to compare the predictions of two models.
"""
_KWARGS_DESCRIPTION = """
Args:
predictions1 (`list` of `float`): Predictions for model 1.
predictions2 (`list` of `float`): Predictions for model 2.
Kwargs:
confidence_level (`float`): Confidence level under which the result is obtained. Default is 0.95.
num_bootstrap_iterations (`int`): Number of bootstrap iterations used to compute the upper bound to the test statistic. Default is 1000.
dt (`float`): Differential for t during numerical integral calculation. Default is 0.005.
num_jobs (`int` or None): Number of jobs to use for the test. If None, this defaults to the value specified in the num_process attribute.
show_progress (`bool`): If True, a progress bar is shown when computing the test statistic. Default is False.
seed (`int` or None): Set seed for reproducibility purposes. If None, this defaults to the value specified in the seed attribute.
Returns:
violation_ratio (`float`): (Frequentist upper bound to) the degree to which the stochastic order is violated. When it is smaller than 0.5, the model producing predictions1 performs better than the other model at the confidence level specified by the confidence_level argument (default is 0.95). Ulmer et al. (2022) recommend rejecting the null hypothesis when violation_ratio is below 0.2.
Examples:
>>> aso = evaluate.load("kaleidophon/almost_stochastic_order")
>>> results = aso.compute(predictions1=[-7, 123.45, 43, 4.91, 5], predictions2=[1337.12, -9.74, 1, 2, 3.21])
>>> print(results)
{'violation_ratio': 1.0}
"""
_CITATION = """
@article{ulmer2022deep,
title={deep-significance - Easy and Meaningful Statistical Significance Testing in the Age of Neural Networks},
author={Ulmer, Dennis and Hardmeier, Christian and Frellsen, Jes},
journal={arXiv preprint arXiv:2204.06815},
year={2022}
}
@inproceedings{dror2019deep,
author = {Rotem Dror and
Segev Shlomov and
Roi Reichart},
editor = {Anna Korhonen and
David R. Traum and
Llu{\'{\i}}s M{\`{a}}rquez},
title = {Deep Dominance - How to Properly Compare Deep Neural Models},
booktitle = {Proceedings of the 57th Conference of the Association for Computational
Linguistics, {ACL} 2019, Florence, Italy, July 28-August 2, 2019,
Volume 1: Long Papers},
pages = {2773--2785},
publisher = {Association for Computational Linguistics},
year = {2019}
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlmostStochasticOrder(evaluate.Comparison):
def _info(self):
return evaluate.ComparisonInfo(
module_type="comparison",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions1": datasets.Value("float"),
"predictions2": datasets.Value("float"),
}
),
)
def _compute(
self, predictions1, predictions2,
confidence_level: float = 0.95,
num_bootstrap_iterations: int = 1000,
dt: float = 0.005,
num_jobs: Optional[int] = None,
show_progress: bool = False,
seed: Optional[int] = None,
**kwargs
):
# Set seed
if seed is None:
seed = self.seed
        # Set number of jobs; fall back to the module's num_process attribute
        if num_jobs is None:
            num_jobs = self.num_process
        # Compute the test statistic, forwarding the user-supplied confidence level
        violation_ratio = aso(
            scores_a=predictions1,
            scores_b=predictions2,
            confidence_level=confidence_level,
            num_bootstrap_iterations=num_bootstrap_iterations,
            dt=dt,
            num_jobs=num_jobs,
            seed=seed,
            show_progress=show_progress,
        )
return {"violation_ratio": violation_ratio}