Spaces:
Running
Running
Update Space (evaluate main: 8e481b15)
Browse files- bertscore.py +24 -22
bertscore.py
CHANGED
@@ -105,12 +105,20 @@ class BERTScore(evaluate.Metric):
|
|
105 |
citation=_CITATION,
|
106 |
homepage="https://github.com/Tiiiger/bert_score",
|
107 |
inputs_description=_KWARGS_DESCRIPTION,
|
108 |
-
features=
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
codebase_urls=["https://github.com/Tiiiger/bert_score"],
|
115 |
reference_urls=[
|
116 |
"https://github.com/Tiiiger/bert_score",
|
@@ -135,6 +143,15 @@ class BERTScore(evaluate.Metric):
|
|
135 |
baseline_path=None,
|
136 |
use_fast_tokenizer=False,
|
137 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
get_hash = bert_score.utils.get_hash
|
139 |
scorer = bert_score.BERTScorer
|
140 |
|
@@ -171,6 +188,7 @@ class BERTScore(evaluate.Metric):
|
|
171 |
nthreads=nthreads,
|
172 |
all_layers=all_layers,
|
173 |
idf=idf,
|
|
|
174 |
device=device,
|
175 |
lang=lang,
|
176 |
rescale_with_baseline=rescale_with_baseline,
|
@@ -190,19 +208,3 @@ class BERTScore(evaluate.Metric):
|
|
190 |
"hashcode": hashcode,
|
191 |
}
|
192 |
return output_dict
|
193 |
-
|
194 |
-
def add_batch(self, predictions=None, references=None, **kwargs):
|
195 |
-
"""Add a batch of predictions and references for the metric's stack."""
|
196 |
-
# References can be strings or lists of strings
|
197 |
-
# Let's change strings to lists of strings with one element
|
198 |
-
if references is not None:
|
199 |
-
references = [[ref] if isinstance(ref, str) else ref for ref in references]
|
200 |
-
super().add_batch(predictions=predictions, references=references, **kwargs)
|
201 |
-
|
202 |
-
def add(self, prediction=None, reference=None, **kwargs):
|
203 |
-
"""Add one prediction and reference for the metric's stack."""
|
204 |
-
# References can be strings or lists of strings
|
205 |
-
# Let's change strings to lists of strings with one element
|
206 |
-
if isinstance(reference, str):
|
207 |
-
reference = [reference]
|
208 |
-
super().add(prediction=prediction, reference=reference, **kwargs)
|
|
|
105 |
citation=_CITATION,
|
106 |
homepage="https://github.com/Tiiiger/bert_score",
|
107 |
inputs_description=_KWARGS_DESCRIPTION,
|
108 |
+
features=[
|
109 |
+
datasets.Features(
|
110 |
+
{
|
111 |
+
"predictions": datasets.Value("string", id="sequence"),
|
112 |
+
"references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
|
113 |
+
}
|
114 |
+
),
|
115 |
+
datasets.Features(
|
116 |
+
{
|
117 |
+
"predictions": datasets.Value("string", id="sequence"),
|
118 |
+
"references": datasets.Value("string", id="sequence"),
|
119 |
+
}
|
120 |
+
),
|
121 |
+
],
|
122 |
codebase_urls=["https://github.com/Tiiiger/bert_score"],
|
123 |
reference_urls=[
|
124 |
"https://github.com/Tiiiger/bert_score",
|
|
|
143 |
baseline_path=None,
|
144 |
use_fast_tokenizer=False,
|
145 |
):
|
146 |
+
|
147 |
+
if isinstance(references[0], str):
|
148 |
+
references = [[ref] for ref in references]
|
149 |
+
|
150 |
+
if idf:
|
151 |
+
idf_sents = [r for ref in references for r in ref]
|
152 |
+
else:
|
153 |
+
idf_sents = None
|
154 |
+
|
155 |
get_hash = bert_score.utils.get_hash
|
156 |
scorer = bert_score.BERTScorer
|
157 |
|
|
|
188 |
nthreads=nthreads,
|
189 |
all_layers=all_layers,
|
190 |
idf=idf,
|
191 |
+
idf_sents=idf_sents,
|
192 |
device=device,
|
193 |
lang=lang,
|
194 |
rescale_with_baseline=rescale_with_baseline,
|
|
|
208 |
"hashcode": hashcode,
|
209 |
}
|
210 |
return output_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|