Gosse Minnema
Re-enable LOME
2890e34
from typing import *
from allennlp.training.metrics import Metric
from overrides import overrides
import numpy as np
import logging
from .base_f import BaseF
from ..utils import Span, max_match
# Logger shared by everything in this metric module.
logger = logging.getLogger(name='srl_metric')
@Metric.register('srl')
class SRLMetric(Metric):
    """
    Semantic-role-labelling metric registered as ``'srl'``.

    Four ``BaseF`` counters are kept:

    * ``tri-i`` / ``tri-c``: trigger (event) identification, unlabelled / labelled;
    * ``arg-i`` / ``arg-c``: argument identification, unlabelled / labelled.

    Each call compares one predicted ``Span`` tree with one gold ``Span`` tree.
    It adds the resulting true-positive, false-positive and false-negative
    counts to those counters.
    """

    def __init__(self, check_type: Optional[bool] = None):
        """
        :param check_type: Deprecated and otherwise ignored. Passing any
            non-``None`` value only logs a warning.
        """
        self.tri_i = BaseF('tri-i')
        self.tri_c = BaseF('tri-c')
        self.arg_i = BaseF('arg-i')
        self.arg_c = BaseF('arg-c')
        if check_type is not None:
            logger.warning('Check type argument is deprecated.')

    def _counters(self) -> List[BaseF]:
        """Return the four counters in reporting order."""
        return [self.tri_i, self.tri_c, self.arg_i, self.arg_c]

    @staticmethod
    def _accumulate(metric: BaseF, tp: int, fp: int, fn: int) -> None:
        """Check that no count is negative, then add (tp, fp, fn) to ``metric``."""
        assert tp >= 0 and fp >= 0 and fn >= 0
        metric.tp += tp
        metric.fp += fp
        metric.fn += fn

    def reset(self) -> None:
        """Reset all four counters."""
        for metric in self._counters():
            metric.reset()

    def get_metric(self, reset: bool) -> Dict[str, Any]:
        """
        Merge the metric dictionaries of all four counters.

        :param reset: Forwarded unchanged to each counter's ``get_metric``.
        :return: One dict containing the entries of ``tri-i``, ``tri-c``,
            ``arg-i`` and ``arg-c``, merged in that order. A later counter
            wins on a key clash.
        """
        ret: Dict[str, Any] = dict()
        for metric in self._counters():
            ret.update(metric.get_metric(reset))
        return ret

    @overrides
    def __call__(self, prediction: Span, gold: Span):
        """
        Score one predicted tree against its gold tree and accumulate the counts.

        ``tri-c`` and ``tri-i`` are updated by span matching.
        ``arg-c`` and ``arg-i`` are updated by tuple matching (see ``tuple_eval``).
        The older ``with_label_arg`` / ``without_label_arg`` scorers are kept
        on the class but are not called here.
        """
        self.with_label_event(prediction, gold)
        self.without_label_event(prediction, gold)
        self.tuple_eval(prediction, gold)

    def tuple_eval(self, prediction: Span, gold: Span):
        """
        Accumulate argument scores by matching per-argument tuples.

        Every argument of every event becomes one tuple:

        * labelled tuples are ``(event.label, arg.boundary, arg.label)`` and
          are scored into ``arg-c``;
        * unlabelled tuples drop ``arg.label`` and are scored into ``arg-i``.

        The event boundary is not part of either tuple. True positives are
        ``max_match`` over the 0/1 equality matrix between predicted and gold
        tuples (presumably a maximum one-to-one matching; confirm in utils).

        False positives and false negatives come from node counts. This
        assumes the trees are root -> events -> arguments, so that
        ``n_nodes - len(span) - 1`` is the number of argument nodes.
        TODO: confirm against ``Span``.
        """
        def extract_tuples(vr: Span, parent_boundary: bool):
            # One tuple per (event, argument) pair. The event boundary is
            # included only when ``parent_boundary`` is set.
            labeled, unlabeled = list(), list()
            for event in vr:
                for arg in event:
                    if parent_boundary:
                        labeled.append((event.boundary, event.label, arg.boundary, arg.label))
                        unlabeled.append((event.boundary, event.label, arg.boundary))
                    else:
                        labeled.append((event.label, arg.boundary, arg.label))
                        unlabeled.append((event.label, arg.boundary))
            return labeled, unlabeled

        def equal_matrix(l1, l2):
            # 0/1 matrix whose entry [i, j] is 1 iff l1[i] == l2[j].
            # Use the builtin ``int`` as dtype: the ``np.int`` alias (which
            # pointed to it) was removed in NumPy 1.24.
            return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=int)

        pred_label, pred_unlabel = extract_tuples(prediction, False)
        gold_label, gold_unlabel = extract_tuples(gold, False)
        if not pred_label or not gold_label:
            # One side has no arguments, so nothing can match.
            # This also avoids calling max_match on an empty matrix.
            arg_c_tp = arg_i_tp = 0
        else:
            arg_c_tp = max_match(equal_matrix(pred_label, gold_label))
            arg_i_tp = max_match(equal_matrix(pred_unlabel, gold_unlabel))

        n_pred_args = prediction.n_nodes - len(prediction) - 1
        n_gold_args = gold.n_nodes - len(gold) - 1
        self._accumulate(self.arg_i, arg_i_tp, n_pred_args - arg_i_tp, n_gold_args - arg_i_tp)
        self._accumulate(self.arg_c, arg_c_tp, n_pred_args - arg_c_tp, n_gold_args - arg_c_tp)

    def with_label_event(self, prediction: Span, gold: Span):
        """
        Accumulate labelled trigger scores into ``tri-c``.

        ``match(gold, True, 2)`` presumably counts label-aware matching nodes
        down to the event level. The ``- 1`` is assumed to drop the root match.
        TODO: confirm against ``Span.match``.
        """
        trigger_tp = prediction.match(gold, True, 2) - 1
        self._accumulate(self.tri_c, trigger_tp,
                         len(prediction) - trigger_tp, len(gold) - trigger_tp)

    def with_label_arg(self, prediction: Span, gold: Span):
        """
        Accumulate labelled argument scores into ``arg-c``.

        Not called by ``__call__``; ``tuple_eval`` is used instead.

        Argument true positives are the full-tree label-aware match (ignoring
        parent boundaries) minus the root and the trigger matches.
        """
        trigger_tp = prediction.match(gold, True, 2) - 1
        role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp
        role_fp = (prediction.n_nodes - 1 - len(prediction)) - role_tp
        role_fn = (gold.n_nodes - 1 - len(gold)) - role_tp
        self._accumulate(self.arg_c, role_tp, role_fp, role_fn)

    def without_label_event(self, prediction: Span, gold: Span):
        """
        Accumulate unlabelled trigger scores into ``tri-i``.

        This mirrors ``with_label_event`` but calls ``match`` with the label
        flag set to ``False``.
        """
        tri_i_tp = prediction.match(gold, False, 2) - 1
        self._accumulate(self.tri_i, tri_i_tp,
                         len(prediction) - tri_i_tp, len(gold) - tri_i_tp)

    def without_label_arg(self, prediction: Span, gold: Span):
        """
        Accumulate unlabelled argument scores into ``arg-i``.

        Not called by ``__call__``; ``tuple_eval`` is used instead. The
        method works on clones of both trees in two passes:

        1. Each predicted event is paired with the first gold event for which
           ``p.match(g, True, 1) == 1``. The pair is credited with
           ``p.match(g, False) - 1`` argument matches, and paired events are
           then removed from the clones.
        2. The remaining events are compared pairwise when their labels agree.
           ``max_match`` over those scores is added to the true positives.
        """
        n_gold_arg = gold.n_nodes - len(gold) - 1
        n_pred_arg = prediction.n_nodes - len(prediction) - 1
        # Children are removed below, so work on clones (presumably copies)
        # and leave the caller's trees untouched.
        prediction, gold = prediction.clone(), gold.clone()

        arg_i_tp = 0
        matched_pairs: List[Tuple[Span, Span]] = list()
        for p in prediction:
            for g in gold:
                if p.match(g, True, 1) == 1:
                    arg_i_tp += (p.match(g, False) - 1)
                    matched_pairs.append((p, g))
                    break
        # NOTE(review): nothing prevents two predicted events from pairing
        # with the same gold event. In that case ``gold.remove_child(g)`` runs
        # twice for it. Confirm this is safe for ``Span``.
        for p, g in matched_pairs:
            prediction.remove_child(p)
            gold.remove_child(g)

        # Builtin ``int`` dtype: the ``np.int`` alias was removed in NumPy 1.24.
        sub_matches = np.zeros([len(prediction), len(gold)], dtype=int)
        for p_idx, p in enumerate(prediction):
            for g_idx, g in enumerate(gold):
                if p.label == g.label:
                    sub_matches[p_idx, g_idx] = p.match(g, False, -1, True)
        arg_i_tp += max_match(sub_matches)

        self._accumulate(self.arg_i, arg_i_tp, n_pred_arg - arg_i_tp, n_gold_arg - arg_i_tp)