Martin Dočekal committed
Commit · 6aed907 · Parent(s): 8f4e42e
bootstrapping, return dict keys change
Files changed: rouge_raw.py (+185 -34)
rouge_raw.py
CHANGED
@@ -35,12 +35,127 @@ Module for raw ROUGE score calculation from:
 
 :author: Martin Dočekal
 """
-
+import collections
 import re
 from typing import Sequence, Optional
 
 import datasets
 import evaluate
+import numpy as np
+
+
+class AggregateScore(collections.namedtuple("AggregateScore", ["low", "mid", "high"])):
+    """
+    Tuple containing confidence intervals for scores.
+    Taken from: https://github.com/google-research/google-research/blob/master/rouge/scoring.py
+    """
+
+
+class Score(
+        collections.namedtuple("Score", ["precision", "recall", "fmeasure"])):
+    """Tuple containing precision, recall, and f-measure values."""
+
+
+class BootstrapAggregator(object):
+    """Aggregates scores to provide confidence intervals.
+    Taken from: https://github.com/google-research/google-research/blob/master/rouge/scoring.py
+
+    Sample usage:
+        scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'])
+        aggregator = Aggregator()
+        aggregator.add_scores(scorer.score("one two three", "one two"))
+        aggregator.add_scores(scorer.score("one two five six", "seven eight"))
+        result = aggregator.aggregate()
+        print result
+        {'rougeL': AggregateScore(
+            low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
+            mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
+            high=Score(precision=1.0, recall=0.66, fmeasure=0.80)),
+         'rouge1': AggregateScore(
+            low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
+            mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
+            high=Score(precision=1.0, recall=0.66, fmeasure=0.80))}
+    """
+
+    def __init__(self, confidence_interval=0.95, n_samples=1000):
+        """Initializes a BootstrapAggregator object.
+
+        Args:
+            confidence_interval: Confidence interval to compute on the mean as a
+                decimal.
+            n_samples: Number of samples to use for bootstrap resampling.
+
+        Raises:
+            ValueError: If invalid argument is given.
+        """
+
+        if confidence_interval < 0 or confidence_interval > 1:
+            raise ValueError("confidence_interval must be in range [0, 1]")
+        if n_samples <= 0:
+            raise ValueError("n_samples must be positive")
+
+        self._n_samples = n_samples
+        self._confidence_interval = confidence_interval
+        self._scores = collections.defaultdict(list)
+
+    def add_scores(self, scores):
+        """Adds a sample for future aggregation.
+
+        Args:
+            scores: Dict mapping score_type strings to a namedtuple object/class
+                representing a score.
+        """
+
+        for score_type, score in scores.items():
+            self._scores[score_type].append(score)
+
+    def aggregate(self):
+        """Aggregates scores previously added using add_scores.
+
+        Returns:
+            A dict mapping score_type to AggregateScore objects.
+        """
+
+        result = {}
+        for score_type, scores in self._scores.items():
+            # Stack scores into a 2-d matrix of (sample, measure).
+            score_matrix = np.vstack(tuple(scores))
+            # Percentiles are returned as (interval, measure).
+            percentiles = self._bootstrap_resample(score_matrix)
+            # Extract the three intervals (low, mid, high).
+            intervals = tuple(
+                (scores[0].__class__(*percentiles[j, :]) for j in range(3)))
+            result[score_type] = AggregateScore(
+                low=intervals[0], mid=intervals[1], high=intervals[2])
+        return result
+
+    def _bootstrap_resample(self, matrix):
+        """Performs bootstrap resampling on a matrix of scores.
+
+        Args:
+            matrix: A 2-d matrix of (sample, measure).
+
+        Returns:
+            A 2-d matrix of (bounds, measure). There are three bounds: low (row 0),
+            mid (row 1) and high (row 2). Mid is always the mean, while low and high
+            bounds are specified by self._confidence_interval (which defaults to 0.95
+            meaning it will return the 2.5th and 97.5th percentiles for a 95%
+            confidence interval on the mean).
+        """
+
+        # Matrix of (bootstrap sample, measure).
+        sample_mean = np.zeros((self._n_samples, matrix.shape[1]))
+        for i in range(self._n_samples):
+            sample_idx = np.random.choice(
+                np.arange(matrix.shape[0]), size=matrix.shape[0])
+            sample = matrix[sample_idx, :]
+            sample_mean[i, :] = np.mean(sample, axis=0)
+
+        # Take percentiles on the estimate of the mean using bootstrap samples.
+        # Final result is a (bounds, measure) matrix.
+        percentile_delta = (1 - self._confidence_interval) / 2
+        q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta])
+        return np.percentile(sample_mean, q, axis=0)
 
 
 class RougeRawOriginal:
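Review note: the BootstrapAggregator added above implements a percentile bootstrap. It resamples the per-document scores with replacement, averages each resample, and reads the low/mid/high bounds off the percentiles of those means. A minimal self-contained sketch of the same procedure (plain NumPy; the function name and toy data here are illustrative only):

    import numpy as np

    def bootstrap_bounds(scores, confidence_interval=0.95, n_samples=1000):
        # scores: (documents, measures) matrix, e.g. one
        # (precision, recall, fmeasure) row per document.
        scores = np.asarray(scores, dtype=float)
        sample_mean = np.zeros((n_samples, scores.shape[1]))
        for i in range(n_samples):
            # Resample documents with replacement, then average their scores.
            idx = np.random.choice(scores.shape[0], size=scores.shape[0])
            sample_mean[i] = scores[idx].mean(axis=0)
        delta = (1 - confidence_interval) / 2
        # Rows: low (2.5th), mid (50th), high (97.5th percentile) for a 95% CI.
        q = 100 * np.array([delta, 0.5, 1 - delta])
        return np.percentile(sample_mean, q, axis=0)

    low, mid, high = bootstrap_bounds([[0.5, 0.3, 0.4],
                                       [1.0, 0.7, 0.8],
                                       [0.2, 0.2, 0.2]])

Note that the mid row is the median of the bootstrap means (the 50th percentile), which is what the code computes even though its docstring describes mid as the mean.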
@@ -51,6 +166,7 @@ class RougeRawOriginal:
 
     class FScore:
         """F1 score representation."""
+
        def __init__(self, correct, gold, system):
            self.p = correct / system if system else 0.
            self.r = correct / gold if gold else 0.
@@ -58,6 +174,7 @@
 
    def _rouge_n(self, n, gold_words, system_words):
        """Compute Rouge-n for given words."""
+
        def n_grams(n, words):
            ngrams = {}
            total = 0
@@ -108,27 +225,56 @@
            "L": self._rouge_l(lc_gold_words, lc_system_words),
        }
 
-    def corpus(self, gold, system):
+    def corpus(self, gold, system, aggregate=True):
        """Compute RougeRAW-1, RougeRAW-2, RougeRAW-L for given corpora.
        Each corpus should be a collection of documents, each document a string.
+
+        If aggregate is True, the lower, mid, and upper bounds of the confidence interval are returned.
        """
 
        assert isinstance(gold, list) and isinstance(system, list), "Expected list arguments"
        assert len(gold) == len(system), "Given corpora should be of the same length"
 
-        rouge = {key: self.FScore(0, 0, 0) for key in ["1", "2", "L"]}
+
+        if aggregate:
+            aggregator = BootstrapAggregator()
+        else:
+            rouge = {key: self.FScore(0, 0, 0) for key in ["1", "2", "L"]}
 
        if len(gold):
            for gold_document, system_document in zip(gold, system):
                for key, value in self.document(gold_document, system_document).items():
-                    rouge[key].p += value.p
-                    rouge[key].r += value.r
-                    rouge[key].f += value.f
-
-            for key in rouge:
-                rouge[key].p /= len(gold)
-                rouge[key].r /= len(gold)
-                rouge[key].f /= len(gold)
+                    if aggregate:
+                        aggregator.add_scores({
+                            key: Score(precision=value.p, recall=value.r, fmeasure=value.f)
+                        })
+                    else:
+                        rouge[key].p += value.p
+                        rouge[key].r += value.r
+                        rouge[key].f += value.f
+
+            if not aggregate:
+                for key in rouge:
+                    rouge[key].p /= len(gold)
+                    rouge[key].r /= len(gold)
+                    rouge[key].f /= len(gold)
+
+        if aggregate:
+            rouge = {}
+            # convert the named tuple to a dict
+
+            for k, ag_score in aggregator.aggregate().items():
+                rouge[k + "_low_precision"] = float(ag_score.low.precision)
+                rouge[k + "_low_recall"] = float(ag_score.low.recall)
+                rouge[k + "_low_fmeasure"] = float(ag_score.low.fmeasure)
+
+                rouge[k + "_mid_precision"] = float(ag_score.mid.precision)
+                rouge[k + "_mid_recall"] = float(ag_score.mid.recall)
+                rouge[k + "_mid_fmeasure"] = float(ag_score.mid.fmeasure)
+
+                rouge[k + "_high_precision"] = float(ag_score.high.precision)
+                rouge[k + "_high_recall"] = float(ag_score.high.recall)
+                rouge[k + "_high_fmeasure"] = float(ag_score.high.fmeasure)
 
        return rouge
 
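With this change the shape of corpus()'s return value depends on the aggregate flag: FScore objects keyed "1"/"2"/"L" without aggregation, or a flat dict of floats carrying bootstrap bounds when aggregating. A quick sketch of both shapes (assuming rouge_raw.py is importable as rouge_raw; texts are toy data):

    from rouge_raw import RougeRawOriginal

    gold = ["the cat is on the mat", "hello there"]
    system = ["a cat sat on the mat", "hello there general kenobi"]

    # aggregate=False: corpus-averaged FScore objects keyed by "1", "2", "L".
    plain = RougeRawOriginal().corpus(gold, system, aggregate=False)
    print(plain["1"].p, plain["1"].r, plain["1"].f)

    # aggregate=True (the default): flat dict of floats with bootstrap bounds,
    # keys such as "1_low_precision", "2_mid_fmeasure", "L_high_recall".
    agg = RougeRawOriginal().corpus(gold, system)
    print(agg["1_mid_fmeasure"])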
@@ -178,15 +324,18 @@ Args:
    select: (Optional) string. The name of the metric to return. One of: 'rougeraw1_precision', 'rougeraw1_recall', 'rougeraw1_fmeasure', 'rougeraw2_precision', 'rougeraw2_recall', 'rougeraw2_fmeasure', 'rougerawl_precision', 'rougerawl_recall', 'rougerawl_fmeasure'.
        If None, all metrics are returned as a dictionary.
Returns:
-    rougeraw1_precision
-    rougeraw1_recall
-    rougeraw1_fmeasure
-    rougeraw2_precision
-    rougeraw2_recall
-    rougeraw2_fmeasure
-    rougerawl_precision
-    rougerawl_recall
-    rougerawl_fmeasure
+    1_precision
+    1_recall
+    1_fmeasure
+    2_precision
+    2_recall
+    2_fmeasure
+    l_precision
+    l_recall
+    l_fmeasure
+
+    if aggregate is True there are also low, mid and high values for each metric. Thus, e.g.:
+        1_low_precision
Examples:
    >>> rougeraw = evaluate.load('CZLC/rouge_raw')
    >>> predictions = ["the cat is on the mat", "hello there"]
@@ -217,21 +366,23 @@ class RougeRaw(evaluate.Metric):
            ],
        )
 
-    def _compute(self, predictions: Sequence[str], references: Sequence[str], select: Optional[str] = None):
-        res = RougeRawOriginal().corpus(references, predictions)
-
-        res = {
-            "rougeraw1_precision": res["1"].p,
-            "rougeraw1_recall": res["1"].r,
-            "rougeraw1_fmeasure": res["1"].f,
-            "rougeraw2_precision": res["2"].p,
-            "rougeraw2_recall": res["2"].r,
-            "rougeraw2_fmeasure": res["2"].f,
-            "rougerawl_precision": res["L"].p,
-            "rougerawl_recall": res["L"].r,
-            "rougerawl_fmeasure": res["L"].f,
-        }
+    def _compute(self, predictions: Sequence[str], references: Sequence[str], select: Optional[str] = None,
+                 aggregate: bool = True):
+        res = RougeRawOriginal().corpus(references, predictions, aggregate=aggregate)
+
+        if not aggregate:
+            res = {
+                "1_precision": res["1"].p,
+                "1_recall": res["1"].r,
+                "1_fmeasure": res["1"].f,
+                "2_precision": res["2"].p,
+                "2_recall": res["2"].r,
+                "2_fmeasure": res["2"].f,
+                "l_precision": res["L"].p,
+                "l_recall": res["L"].r,
+                "l_fmeasure": res["L"].f,
+            }
 
        if select is not None:
            return res[select]
        return res
-
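At the metric level the new flag flows through _compute, so callers choose between aggregated bounds and the plain renamed keys. A usage sketch mirroring the docstring example, assuming the extra keyword arguments are forwarded to _compute the way evaluate.Metric.compute normally forwards them (exact values vary with the bootstrap's randomness):

    import evaluate

    rougeraw = evaluate.load('CZLC/rouge_raw')
    predictions = ["the cat is on the mat", "hello there"]
    references = ["the cat is on the mat", "hello there general kenobi"]

    # Default aggregate=True: keys like "1_low_precision" ... "L_high_fmeasure".
    bounds = rougeraw.compute(predictions=predictions, references=references)

    # aggregate=False: plain corpus averages under the renamed keys
    # "1_precision", ..., "l_fmeasure" (note the lowercase "l" here).
    flat = rougeraw.compute(predictions=predictions, references=references, aggregate=False)

    # select still picks a single value out of whichever dict is returned.
    f1 = rougeraw.compute(predictions=predictions, references=references,
                          aggregate=False, select="1_fmeasure")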