# Scraped page residue (Hugging Face Spaces status banner), not part of the module: "Spaces: Runtime error"
# Copyright 2022 The HuggingFace Evaluate Authors | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Almost Stochastic Order test for model comparison.""" | |
from typing import Optional | |
import datasets | |
from deepsig import aso | |
import evaluate | |
# Module description; passed to `evaluate.ComparisonInfo(description=...)` in `_info`.
_DESCRIPTION = """
The Almost Stochastic Order test is a non-parametric test that measures to what extent the distributions of predictions differ from each other, using the Wasserstein distance. It can be used to compare the predictions of two models.
"""
# Argument/return documentation for `_compute`; passed as `inputs_description` in `_info`.
_KWARGS_DESCRIPTION = """
Args:
    predictions1 (`list` of `float`): Predictions for model 1.
    predictions2 (`list` of `float`): Predictions for model 2.
Kwargs:
    confidence_level (`float`): Confidence level under which the result is obtained. Default is 0.95.
    num_bootstrap_iterations (`int`): Number of bootstrap iterations used to compute the upper bound of the test statistic. Default is 1000.
    dt (`float`): Differential for t during numerical integral calculation. Default is 0.005.
    num_jobs (`int` or None): Number of jobs to use for the test. If None, this defaults to the value specified in the num_process attribute.
    show_progress (`bool`): If True, a progress bar is shown when computing the test statistic. Default is False.
    seed (`int` or None): Seed for reproducibility purposes. If None, this defaults to the value specified in the seed attribute.
Returns:
    violation_ratio (`float`): (Frequentist upper bound on the) degree of violation of the stochastic order. When it is smaller than 0.5, the model producing predictions1 performs better than the other model at a confidence level specified by the confidence_level argument (default is 0.95). Ulmer et al. (2022) recommend rejecting the null hypothesis when violation_ratio is under 0.2.
Examples:
    >>> aso = evaluate.load("kaleidophon/almost_stochastic_order")
    >>> results = aso.compute(predictions1=[-7, 123.45, 43, 4.91, 5], predictions2=[1337.12, -9.74, 1, 2, 3.21])
    >>> print(results)
    {'violation_ratio': 1.0}
"""
# BibTeX entries for the ASO implementation (Ulmer et al., 2022) and the original
# test (Dror et al., 2019); passed as `citation` in `_info`.
_CITATION = """
@article{ulmer2022deep,
  title={deep-significance - Easy and Meaningful Statistical Significance Testing in the Age of Neural Networks},
  author={Ulmer, Dennis and Hardmeier, Christian and Frellsen, Jes},
  journal={arXiv preprint arXiv:2204.06815},
  year={2022}
}

@inproceedings{dror2019deep,
  author    = {Rotem Dror and
               Segev Shlomov and
               Roi Reichart},
  editor    = {Anna Korhonen and
               David R. Traum and
               Llu{\'{\i}}s M{\`{a}}rquez},
  title     = {Deep Dominance - How to Properly Compare Deep Neural Models},
  booktitle = {Proceedings of the 57th Conference of the Association for Computational
               Linguistics, {ACL} 2019, Florence, Italy, July 28-August 2, 2019,
               Volume 1: Long Papers},
  pages     = {2773--2785},
  publisher = {Association for Computational Linguistics},
  year      = {2019}
}
"""
class AlmostStochasticOrder(evaluate.Comparison):
    """Almost Stochastic Order (ASO) comparison between the scores of two models.

    Thin wrapper exposing ``deepsig.aso`` through the ``evaluate`` comparison
    API. The result is reported under the ``"violation_ratio"`` key.
    """

    def _info(self):
        """Return the module metadata and the per-example input features."""
        return evaluate.ComparisonInfo(
            module_type="comparison",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions1": datasets.Value("float"),
                    "predictions2": datasets.Value("float"),
                }
            ),
        )

    def _compute(
        self,
        predictions1,
        predictions2,
        confidence_level: float = 0.95,
        num_bootstrap_iterations: int = 1000,
        dt: float = 0.005,
        num_jobs: Optional[int] = None,
        show_progress: bool = False,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Run the ASO test on two collections of scores.

        Args:
            predictions1: Scores produced by model 1 (``scores_a`` for ``aso``).
            predictions2: Scores produced by model 2 (``scores_b`` for ``aso``).
            confidence_level: Confidence level forwarded to ``aso``.
            num_bootstrap_iterations: Number of bootstrap iterations forwarded to ``aso``.
            dt: Integration step forwarded to ``aso``.
            num_jobs: Number of jobs; ``None`` falls back to ``self.num_process``.
            show_progress: Whether ``aso`` shows a progress bar.
            seed: Random seed; ``None`` falls back to ``self.seed``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            dict: ``{"violation_ratio": <value returned by aso>}``.
        """
        # Per-call arguments override the module-level settings
        # (presumably configured via `evaluate.load(..., seed=..., num_process=...)`).
        if seed is None:
            seed = self.seed
        if num_jobs is None:
            num_jobs = self.num_process

        # `confidence_level` used to be accepted here but never passed on,
        # so `aso` silently ran with its own default.
        # NOTE(review): assumes the installed deepsig's `aso` takes
        # `confidence_level` with 0.95-style semantics — confirm the pinned version.
        violation_ratio = aso(
            scores_a=predictions1,
            scores_b=predictions2,
            confidence_level=confidence_level,
            num_bootstrap_iterations=num_bootstrap_iterations,
            dt=dt,
            num_jobs=num_jobs,
            show_progress=show_progress,
            seed=seed,
        )
        return {"violation_ratio": violation_ratio}