lvwerra HF staff commited on
Commit
b0f3cfd
·
1 Parent(s): eadb728

Update Space (evaluate main: c447fc8e)

Browse files
Files changed (2) hide show
  1. bertscore.py +34 -54
  2. requirements.txt +1 -1
bertscore.py CHANGED
@@ -15,8 +15,6 @@
15
 
16
  import functools
17
  from contextlib import contextmanager
18
- from dataclasses import dataclass
19
- from typing import List, Optional, Union
20
 
21
  import bert_score
22
  import datasets
@@ -99,42 +97,14 @@ Examples:
99
  """
100
 
101
 
102
- @dataclass
103
- class BERTScoreConfig(evaluate.info.Config):
104
-
105
- name: str = "default"
106
-
107
- pos_label: Union[str, int] = 1
108
- average: str = "binary"
109
- lang: Optional[str] = None
110
- sample_weight: Optional[List[float]] = None
111
-
112
- lang: Optional[str] = None
113
- model_type: Optional[str] = None
114
- num_layers: Optional[int] = None
115
- verbose: bool = False
116
- idf = bool = False
117
- device: Optional[str] = None
118
- batch_size: int = 64
119
- nthreads: int = 4
120
- all_layers: bool = False
121
- rescale_with_baseline: bool = False
122
- baseline_path: Optional[str] = None
123
- use_fast_tokenizer: bool = False
124
-
125
-
126
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
127
  class BERTScore(evaluate.Metric):
128
- CONFIG_CLASS = BERTScoreConfig
129
- ALLOWED_CONFIG_NAMES = ["default"]
130
-
131
- def _info(self, config):
132
  return evaluate.MetricInfo(
133
  description=_DESCRIPTION,
134
  citation=_CITATION,
135
  homepage="https://github.com/Tiiiger/bert_score",
136
  inputs_description=_KWARGS_DESCRIPTION,
137
- config=config,
138
  features=[
139
  datasets.Features(
140
  {
@@ -160,12 +130,24 @@ class BERTScore(evaluate.Metric):
160
  self,
161
  predictions,
162
  references,
 
 
 
 
 
 
 
 
 
 
 
 
163
  ):
164
 
165
  if isinstance(references[0], str):
166
  references = [[ref] for ref in references]
167
 
168
- if self.config.idf:
169
  idf_sents = [r for ref in references for r in ref]
170
  else:
171
  idf_sents = None
@@ -174,34 +156,32 @@ class BERTScore(evaluate.Metric):
174
  scorer = bert_score.BERTScorer
175
 
176
  if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
177
- get_hash = functools.partial(get_hash, use_fast_tokenizer=self.config.use_fast_tokenizer)
178
- scorer = functools.partial(scorer, use_fast_tokenizer=self.config.use_fast_tokenizer)
179
- elif self.config.use_fast_tokenizer:
180
  raise ImportWarning(
181
  "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
182
  "`bert-score` doesn't match this condition.\n"
183
  'You can install it with `pip install "bert-score>=0.3.10"`.'
184
  )
185
 
186
- if self.config.model_type is None:
187
- if self.config.lang is None:
188
  raise ValueError(
189
  "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
190
  " must be specified"
191
  )
192
- model_type = bert_score.utils.lang2model[self.config.lang.lower()]
193
- else:
194
- model_type = self.config.model_type
195
 
196
- if self.config.num_layers is None:
197
  num_layers = bert_score.utils.model2layers[model_type]
198
 
199
  hashcode = get_hash(
200
  model=model_type,
201
  num_layers=num_layers,
202
- idf=self.config.idf,
203
- rescale_with_baseline=self.config.rescale_with_baseline,
204
- use_custom_baseline=self.config.baseline_path is not None,
205
  )
206
 
207
  with filter_logging_context():
@@ -209,22 +189,22 @@ class BERTScore(evaluate.Metric):
209
  self.cached_bertscorer = scorer(
210
  model_type=model_type,
211
  num_layers=num_layers,
212
- batch_size=self.config.batch_size,
213
- nthreads=self.config.nthreads,
214
- all_layers=self.config.all_layers,
215
- idf=self.config.idf,
216
  idf_sents=idf_sents,
217
- device=self.config.device,
218
- lang=self.config.lang,
219
- rescale_with_baseline=self.config.rescale_with_baseline,
220
- baseline_path=self.config.baseline_path,
221
  )
222
 
223
  (P, R, F) = self.cached_bertscorer.score(
224
  cands=predictions,
225
  refs=references,
226
- verbose=self.config.verbose,
227
- batch_size=self.config.batch_size,
228
  )
229
  output_dict = {
230
  "precision": P.tolist(),
 
15
 
16
  import functools
17
  from contextlib import contextmanager
 
 
18
 
19
  import bert_score
20
  import datasets
 
97
  """
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
101
  class BERTScore(evaluate.Metric):
102
+ def _info(self):
 
 
 
103
  return evaluate.MetricInfo(
104
  description=_DESCRIPTION,
105
  citation=_CITATION,
106
  homepage="https://github.com/Tiiiger/bert_score",
107
  inputs_description=_KWARGS_DESCRIPTION,
 
108
  features=[
109
  datasets.Features(
110
  {
 
130
  self,
131
  predictions,
132
  references,
133
+ lang=None,
134
+ model_type=None,
135
+ num_layers=None,
136
+ verbose=False,
137
+ idf=False,
138
+ device=None,
139
+ batch_size=64,
140
+ nthreads=4,
141
+ all_layers=False,
142
+ rescale_with_baseline=False,
143
+ baseline_path=None,
144
+ use_fast_tokenizer=False,
145
  ):
146
 
147
  if isinstance(references[0], str):
148
  references = [[ref] for ref in references]
149
 
150
+ if idf:
151
  idf_sents = [r for ref in references for r in ref]
152
  else:
153
  idf_sents = None
 
156
  scorer = bert_score.BERTScorer
157
 
158
  if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
159
+ get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
160
+ scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
161
+ elif use_fast_tokenizer:
162
  raise ImportWarning(
163
  "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
164
  "`bert-score` doesn't match this condition.\n"
165
  'You can install it with `pip install "bert-score>=0.3.10"`.'
166
  )
167
 
168
+ if model_type is None:
169
+ if lang is None:
170
  raise ValueError(
171
  "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
172
  " must be specified"
173
  )
174
+ model_type = bert_score.utils.lang2model[lang.lower()]
 
 
175
 
176
+ if num_layers is None:
177
  num_layers = bert_score.utils.model2layers[model_type]
178
 
179
  hashcode = get_hash(
180
  model=model_type,
181
  num_layers=num_layers,
182
+ idf=idf,
183
+ rescale_with_baseline=rescale_with_baseline,
184
+ use_custom_baseline=baseline_path is not None,
185
  )
186
 
187
  with filter_logging_context():
 
189
  self.cached_bertscorer = scorer(
190
  model_type=model_type,
191
  num_layers=num_layers,
192
+ batch_size=batch_size,
193
+ nthreads=nthreads,
194
+ all_layers=all_layers,
195
+ idf=idf,
196
  idf_sents=idf_sents,
197
+ device=device,
198
+ lang=lang,
199
+ rescale_with_baseline=rescale_with_baseline,
200
+ baseline_path=baseline_path,
201
  )
202
 
203
  (P, R, F) = self.cached_bertscorer.score(
204
  cands=predictions,
205
  refs=references,
206
+ verbose=verbose,
207
+ batch_size=batch_size,
208
  )
209
  output_dict = {
210
  "precision": P.tolist(),
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  bert_score
 
1
+ git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
2
  bert_score