Spaces:
Running
Running
Update Space (evaluate main: c447fc8e)
Browse files- bertscore.py +34 -54
- requirements.txt +1 -1
bertscore.py
CHANGED
@@ -15,8 +15,6 @@
|
|
15 |
|
16 |
import functools
|
17 |
from contextlib import contextmanager
|
18 |
-
from dataclasses import dataclass
|
19 |
-
from typing import List, Optional, Union
|
20 |
|
21 |
import bert_score
|
22 |
import datasets
|
@@ -99,42 +97,14 @@ Examples:
|
|
99 |
"""
|
100 |
|
101 |
|
102 |
-
@dataclass
|
103 |
-
class BERTScoreConfig(evaluate.info.Config):
|
104 |
-
|
105 |
-
name: str = "default"
|
106 |
-
|
107 |
-
pos_label: Union[str, int] = 1
|
108 |
-
average: str = "binary"
|
109 |
-
lang: Optional[str] = None
|
110 |
-
sample_weight: Optional[List[float]] = None
|
111 |
-
|
112 |
-
lang: Optional[str] = None
|
113 |
-
model_type: Optional[str] = None
|
114 |
-
num_layers: Optional[int] = None
|
115 |
-
verbose: bool = False
|
116 |
-
idf = bool = False
|
117 |
-
device: Optional[str] = None
|
118 |
-
batch_size: int = 64
|
119 |
-
nthreads: int = 4
|
120 |
-
all_layers: bool = False
|
121 |
-
rescale_with_baseline: bool = False
|
122 |
-
baseline_path: Optional[str] = None
|
123 |
-
use_fast_tokenizer: bool = False
|
124 |
-
|
125 |
-
|
126 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
127 |
class BERTScore(evaluate.Metric):
|
128 |
-
|
129 |
-
ALLOWED_CONFIG_NAMES = ["default"]
|
130 |
-
|
131 |
-
def _info(self, config):
|
132 |
return evaluate.MetricInfo(
|
133 |
description=_DESCRIPTION,
|
134 |
citation=_CITATION,
|
135 |
homepage="https://github.com/Tiiiger/bert_score",
|
136 |
inputs_description=_KWARGS_DESCRIPTION,
|
137 |
-
config=config,
|
138 |
features=[
|
139 |
datasets.Features(
|
140 |
{
|
@@ -160,12 +130,24 @@ class BERTScore(evaluate.Metric):
|
|
160 |
self,
|
161 |
predictions,
|
162 |
references,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
):
|
164 |
|
165 |
if isinstance(references[0], str):
|
166 |
references = [[ref] for ref in references]
|
167 |
|
168 |
-
if
|
169 |
idf_sents = [r for ref in references for r in ref]
|
170 |
else:
|
171 |
idf_sents = None
|
@@ -174,34 +156,32 @@ class BERTScore(evaluate.Metric):
|
|
174 |
scorer = bert_score.BERTScorer
|
175 |
|
176 |
if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
|
177 |
-
get_hash = functools.partial(get_hash, use_fast_tokenizer=
|
178 |
-
scorer = functools.partial(scorer, use_fast_tokenizer=
|
179 |
-
elif
|
180 |
raise ImportWarning(
|
181 |
"To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
|
182 |
"`bert-score` doesn't match this condition.\n"
|
183 |
'You can install it with `pip install "bert-score>=0.3.10"`.'
|
184 |
)
|
185 |
|
186 |
-
if
|
187 |
-
if
|
188 |
raise ValueError(
|
189 |
"Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
|
190 |
" must be specified"
|
191 |
)
|
192 |
-
model_type = bert_score.utils.lang2model[
|
193 |
-
else:
|
194 |
-
model_type = self.config.model_type
|
195 |
|
196 |
-
if
|
197 |
num_layers = bert_score.utils.model2layers[model_type]
|
198 |
|
199 |
hashcode = get_hash(
|
200 |
model=model_type,
|
201 |
num_layers=num_layers,
|
202 |
-
idf=
|
203 |
-
rescale_with_baseline=
|
204 |
-
use_custom_baseline=
|
205 |
)
|
206 |
|
207 |
with filter_logging_context():
|
@@ -209,22 +189,22 @@ class BERTScore(evaluate.Metric):
|
|
209 |
self.cached_bertscorer = scorer(
|
210 |
model_type=model_type,
|
211 |
num_layers=num_layers,
|
212 |
-
batch_size=
|
213 |
-
nthreads=
|
214 |
-
all_layers=
|
215 |
-
idf=
|
216 |
idf_sents=idf_sents,
|
217 |
-
device=
|
218 |
-
lang=
|
219 |
-
rescale_with_baseline=
|
220 |
-
baseline_path=
|
221 |
)
|
222 |
|
223 |
(P, R, F) = self.cached_bertscorer.score(
|
224 |
cands=predictions,
|
225 |
refs=references,
|
226 |
-
verbose=
|
227 |
-
batch_size=
|
228 |
)
|
229 |
output_dict = {
|
230 |
"precision": P.tolist(),
|
|
|
15 |
|
16 |
import functools
|
17 |
from contextlib import contextmanager
|
|
|
|
|
18 |
|
19 |
import bert_score
|
20 |
import datasets
|
|
|
97 |
"""
|
98 |
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
101 |
class BERTScore(evaluate.Metric):
|
102 |
+
def _info(self):
|
|
|
|
|
|
|
103 |
return evaluate.MetricInfo(
|
104 |
description=_DESCRIPTION,
|
105 |
citation=_CITATION,
|
106 |
homepage="https://github.com/Tiiiger/bert_score",
|
107 |
inputs_description=_KWARGS_DESCRIPTION,
|
|
|
108 |
features=[
|
109 |
datasets.Features(
|
110 |
{
|
|
|
130 |
self,
|
131 |
predictions,
|
132 |
references,
|
133 |
+
lang=None,
|
134 |
+
model_type=None,
|
135 |
+
num_layers=None,
|
136 |
+
verbose=False,
|
137 |
+
idf=False,
|
138 |
+
device=None,
|
139 |
+
batch_size=64,
|
140 |
+
nthreads=4,
|
141 |
+
all_layers=False,
|
142 |
+
rescale_with_baseline=False,
|
143 |
+
baseline_path=None,
|
144 |
+
use_fast_tokenizer=False,
|
145 |
):
|
146 |
|
147 |
if isinstance(references[0], str):
|
148 |
references = [[ref] for ref in references]
|
149 |
|
150 |
+
if idf:
|
151 |
idf_sents = [r for ref in references for r in ref]
|
152 |
else:
|
153 |
idf_sents = None
|
|
|
156 |
scorer = bert_score.BERTScorer
|
157 |
|
158 |
if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
|
159 |
+
get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
|
160 |
+
scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
|
161 |
+
elif use_fast_tokenizer:
|
162 |
raise ImportWarning(
|
163 |
"To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
|
164 |
"`bert-score` doesn't match this condition.\n"
|
165 |
'You can install it with `pip install "bert-score>=0.3.10"`.'
|
166 |
)
|
167 |
|
168 |
+
if model_type is None:
|
169 |
+
if lang is None:
|
170 |
raise ValueError(
|
171 |
"Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
|
172 |
" must be specified"
|
173 |
)
|
174 |
+
model_type = bert_score.utils.lang2model[lang.lower()]
|
|
|
|
|
175 |
|
176 |
+
if num_layers is None:
|
177 |
num_layers = bert_score.utils.model2layers[model_type]
|
178 |
|
179 |
hashcode = get_hash(
|
180 |
model=model_type,
|
181 |
num_layers=num_layers,
|
182 |
+
idf=idf,
|
183 |
+
rescale_with_baseline=rescale_with_baseline,
|
184 |
+
use_custom_baseline=baseline_path is not None,
|
185 |
)
|
186 |
|
187 |
with filter_logging_context():
|
|
|
189 |
self.cached_bertscorer = scorer(
|
190 |
model_type=model_type,
|
191 |
num_layers=num_layers,
|
192 |
+
batch_size=batch_size,
|
193 |
+
nthreads=nthreads,
|
194 |
+
all_layers=all_layers,
|
195 |
+
idf=idf,
|
196 |
idf_sents=idf_sents,
|
197 |
+
device=device,
|
198 |
+
lang=lang,
|
199 |
+
rescale_with_baseline=rescale_with_baseline,
|
200 |
+
baseline_path=baseline_path,
|
201 |
)
|
202 |
|
203 |
(P, R, F) = self.cached_bertscorer.score(
|
204 |
cands=predictions,
|
205 |
refs=references,
|
206 |
+
verbose=verbose,
|
207 |
+
batch_size=batch_size,
|
208 |
)
|
209 |
output_dict = {
|
210 |
"precision": P.tolist(),
|
requirements.txt
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
git+https://github.com/huggingface/evaluate@
|
2 |
bert_score
|
|
|
1 |
+
git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
|
2 |
bert_score
|