|
from collections import defaultdict |
|
from typing import Tuple |
|
|
|
from rex.metrics import calc_p_r_f1_from_tp_fp_fn, safe_division |
|
from rex.metrics.base import MetricBase |
|
from rex.metrics.tagging import tagging_prf1 |
|
from rex.utils.batch import decompose_batch_into_instances |
|
from rex.utils.iteration import windowed_queue_iter |
|
from rex.utils.random import generate_random_string_with_datetime |
|
from sklearn.metrics import accuracy_score, matthews_corrcoef |
|
|
|
|
|
class MrcNERMetric(MetricBase):
    """Entity-level metric for MRC-style NER decoding."""

    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
        """Decompose a batch into aligned per-instance gold/pred typed-entity sets."""
        golds = []
        preds = []

        instances = decompose_batch_into_instances(raw_batch)
        assert len(instances) == len(out_batch["pred"])

        for idx, ins in enumerate(instances):
            # Each MRC query targets one entity type; tag every entity with it.
            ent_type = ins["ent_type"]
            golds.append(
                {
                    "id": ins["id"],
                    "ents": {(ent_type, ent) for ent in ins["gold_ents"]},
                }
            )
            preds.append(
                {
                    "id": ins["id"],
                    "ents": {(ent_type, ent) for ent in out_batch["pred"][idx]},
                }
            )

        return golds, preds

    def calculate_scores(self, golds: list, preds: list) -> dict:
        """Merge per-query entity sets by instance id, then score with typed PRF1."""
        merged_gold = defaultdict(set)
        merged_pred = defaultdict(set)

        for item in golds:
            merged_gold[item["id"]].update(item["ents"])
        for item in preds:
            merged_pred[item["id"]].update(item["ents"])
        assert len(merged_gold) == len(merged_pred)

        # Iterate gold ids for both sides so the two lists stay aligned.
        gold_ents = [merged_gold[_id] for _id in merged_gold]
        pred_ents = [merged_pred[_id] for _id in merged_gold]

        return tagging_prf1(gold_ents, pred_ents, type_idx=0)
|
|
|
|
|
class MrcSpanMetric(MetricBase):
    """Span-level metric for MRC-style span extraction."""

    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
        """Decompose a batch into aligned per-instance gold/pred span sets."""
        golds = []
        preds = []

        instances = decompose_batch_into_instances(raw_batch)
        assert len(instances) == len(out_batch["pred"])

        for idx, ins in enumerate(instances):
            golds.append(
                {
                    "id": ins["id"],
                    # Tuples so the spans are hashable set members.
                    "spans": {tuple(span) for span in ins["gold_spans"]},
                }
            )
            preds.append(
                {
                    "id": ins["id"],
                    "spans": set(out_batch["pred"][idx]),
                }
            )

        return golds, preds

    def calculate_scores(self, golds: list, preds: list) -> dict:
        """Merge per-query span sets by instance id, then score untyped PRF1."""
        merged_gold = defaultdict(set)
        merged_pred = defaultdict(set)

        for item in golds:
            merged_gold[item["id"]].update(item["spans"])
        for item in preds:
            merged_pred[item["id"]].update(item["spans"])
        assert len(merged_gold) == len(merged_pred)

        # Iterate gold ids for both sides so the two lists stay aligned.
        gold_spans = [merged_gold[_id] for _id in merged_gold]
        pred_spans = [merged_pred[_id] for _id in merged_gold]

        return tagging_prf1(gold_spans, pred_spans, type_idx=None)
|
|
|
|
|
def calc_char_event(golds, preds):
    """
    Calculate char-level event argument scores

    References:
        - https://aistudio.baidu.com/aistudio/competition/detail/46/0/submit-result

    Args:
        golds: a list of gold answers (a list of `event_list`), len=#data,
            format is a list of `event_list`
        preds: a list of pred answers, len=#data
    """

    def _best_char_f1(gold_arg, pred_args):
        # Best char-overlap F1 between one gold argument and every predicted
        # argument that shares its event type and role.
        gtype, grole, gstring = gold_arg
        gchars = set(gstring)
        scores = []
        for parg in pred_args:
            if parg[0] != gtype or parg[1] != grole:
                continue
            pchars = set(str(parg[-1]))
            overlap = len(pchars & gchars)
            prec = safe_division(overlap, len(pchars))
            rec = safe_division(overlap, len(gchars))
            scores.append(safe_division(2 * prec * rec, prec + rec))
        return max(scores, default=0.0)

    pscore = num_gargs = num_pargs = 0
    for gold_events, pred_events in zip(golds, preds):
        # Flatten every event into (event_type, role, argument) triples.
        gold_args = [
            (ev.get("event_type"), arg.get("role"), arg.get("argument"))
            for ev in gold_events
            for arg in ev.get("arguments", [])
        ]
        pred_args = [
            (ev.get("event_type"), arg.get("role"), arg.get("argument"))
            for ev in pred_events
            for arg in ev.get("arguments", [])
        ]

        num_gargs += len(gold_args)
        num_pargs += len(pred_args)
        for gold_arg in gold_args:
            pscore += _best_char_f1(gold_arg, pred_args)

    # Soft precision/recall: pscore is a sum of per-argument best F1 values.
    p = safe_division(pscore, num_pargs)
    r = safe_division(pscore, num_gargs)
    f1 = safe_division(2 * p * r, p + r)
    return {
        "p": p,
        "r": r,
        "f1": f1,
        "pscore": pscore,
        "num_pargs": num_pargs,
        "num_gargs": num_gargs,
    }
|
|
|
|
|
def calc_trigger_identification_metrics(golds, preds):
    """Trigger identification PRF1: a hit is an exact trigger match per instance."""
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gset = {ev["trigger"] for ev in gold_events}
        pset = {ev["trigger"] for ev in pred_events}
        hits = len(gset & pset)
        tp += hits
        fp += len(pset) - hits
        fn += len(gset) - hits
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
|
def calc_trigger_classification_metrics(golds, preds):
    """Trigger classification PRF1: both trigger and event type must match."""
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gset = {(ev["trigger"], ev["event_type"]) for ev in gold_events}
        pset = {(ev["trigger"], ev["event_type"]) for ev in pred_events}
        hits = len(gset & pset)
        tp += hits
        fp += len(pset) - hits
        fn += len(gset) - hits
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
|
def calc_arg_identification_metrics(golds, preds):
    """Calculate argument identification metrics.

    Notice:
        An entity could take different roles in an event, so the base
        counts (tp+fp / tp+fn) are over distinct (role, argument,
        event type) triples, while matching itself ignores the role.
    """
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gold_args = {
            (arg["role"], arg["argument"], ev["event_type"])
            for ev in gold_events
            for arg in ev["arguments"]
        }
        pred_args = {
            (arg["role"], arg["argument"], ev["event_type"])
            for ev in pred_events
            for arg in ev["arguments"]
        }

        # Identification drops the role: compare (argument, event_type) only.
        gold_wo_role = {ga[1:] for ga in gold_args}
        matched = sum(1 for pa in pred_args if pa[1:] in gold_wo_role)
        tp += matched
        fp += len(pred_args) - matched
        fn += len(gold_args) - matched
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
|
def calc_arg_classification_metrics(golds, preds):
    """Strict argument classification: (argument, role, event_type) must all match."""
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gold_triples = {
            (arg["argument"], arg["role"], ev["event_type"])
            for ev in gold_events
            for arg in ev["arguments"]
        }
        pred_triples = {
            (arg["argument"], arg["role"], ev["event_type"])
            for ev in pred_events
            for arg in ev["arguments"]
        }
        hits = len(gold_triples & pred_triples)
        tp += hits
        fp += len(pred_triples) - hits
        fn += len(gold_triples) - hits
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
|
def calc_ent(golds, preds):
    """Typed entity PRF1.

    Args:
        golds, preds: [(type, index list), ...]
    """
    # The entity type sits at tuple index 0, so score with typed tagging PRF1.
    return tagging_prf1(golds, preds, type_idx=0)
|
|
|
|
|
def calc_rel(golds, preds):
    """Relation PRF1 plus an entity-only view built from each triple's non-type parts."""

    def _entity_parts(instances):
        # Drop the relation type (index 0) and flatten the remaining parts
        # into one list per instance.
        stripped = []
        for ins in instances:
            parts = []
            for triple in ins:
                parts.extend(triple[1:])
            stripped.append(parts)
        return stripped

    return {
        "ent": tagging_prf1(_entity_parts(golds), _entity_parts(preds), type_idx=None),
        "rel": tagging_prf1(golds, preds, type_idx=None),
    }
|
|
|
|
|
def calc_cls(golds, preds):
    """Multi-label classification metrics.

    Args:
        golds, preds: one collection of label strings per instance

    Returns:
        dict with:
            - "mcc": Matthews correlation coefficient over joined label strings
            - "acc": exact-match accuracy over joined label strings
            - "mf1": PRF1 from `tagging_prf1` (untyped)
    """
    metrics = {
        "mcc": -1,
        "acc": -1,
        "mf1": tagging_prf1(golds, preds, type_idx=None),
    }
    # Collapse each label collection into one canonical string so the
    # exact-match comparison is order-insensitive.
    y_true = [" ".join(sorted(gold)) for gold in golds]
    y_pred = [" ".join(sorted(pred)) for pred in preds]
    if y_true and y_pred:
        metrics["acc"] = accuracy_score(y_true, y_pred)
        # Fix: MCC used to be computed unconditionally, which raises inside
        # sklearn when the inputs are empty; keep it under the same guard.
        metrics["mcc"] = matthews_corrcoef(y_true, y_pred)
    else:
        metrics["acc"] = 0.0
        metrics["mcc"] = 0.0
    return metrics
|
|
|
|
|
def calc_span(golds, preds, mode="span"):
    """Exact-match and token-level F1 for (possibly multi-part) span predictions.

    Args:
        golds, preds: per-instance lists of spans; each span is a sequence of
            parts, each part a tuple of token indices
        mode: "span" — a multi-index part is an inclusive (start, end) range;
              "w2" — a part already enumerates its token indices

    Returns:
        dict with "em" (exact-match ratio) and "f1" (token-level P/R/F1)

    Raises:
        ValueError: if `mode` is neither "span" nor "w2"
    """
    # Fail fast on a bad mode instead of raising lazily (and messageless)
    # only when the first multi-index part is encountered.
    if mode not in ("span", "w2"):
        raise ValueError(f"unknown mode: {mode!r}, expected 'span' or 'w2'")

    def _token_set(spans) -> set:
        # Expand every span part into the set of token indices it covers.
        tokens = set()
        for span in spans:
            for part in span:
                if len(part) == 1:
                    tokens.add(part[0])
                elif len(part) > 1:
                    if mode == "w2":
                        tokens.update(part)
                    else:  # mode == "span": inclusive (start, end) range
                        tokens.update(range(part[0], part[1] + 1))
        return tokens

    metrics = {
        "em": -1,
        "f1": None,
    }
    acc_num = 0
    tp = fp = fn = 0
    for gold, pred in zip(golds, preds):
        if gold == pred:
            acc_num += 1
        # Build each token set once instead of re-materializing it per set op.
        gold_tokens = _token_set(gold)
        pred_tokens = _token_set(pred)
        tp += len(gold_tokens & pred_tokens)
        fp += len(pred_tokens - gold_tokens)
        fn += len(gold_tokens - pred_tokens)
    # EM is defined as 0.0 (not a division error) when there are no instances.
    metrics["em"] = acc_num / len(golds) if golds else 0.0
    metrics["f1"] = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
    return metrics
|
|
|
|
|
class MultiPartSpanMetric(MetricBase):
    """Metric for a unified multi-part span decoding scheme.

    Each gold/pred "span" is a tuple whose first element is a label-span key;
    that key maps (via ``span_to_label``) to a task record with at least a
    ``task`` and a ``string`` field, and the remaining tuple elements are the
    task payload (entity span, relation pair, event trigger/argument, ...).
    """

    def _encode_span_to_label_dict(self, span_to_label: dict) -> list:
        # Flatten {span_key: label_info} into a list of {"key", "val"} records
        # so the (tuple-keyed) mapping survives serialization/batching.
        span_to_label_list = []
        for key, val in span_to_label.items():
            span_to_label_list.append({"key": key, "val": val})
        return span_to_label_list

    def _decode_span_to_label(self, span_to_label_list: list) -> dict:
        # Inverse of `_encode_span_to_label_dict`; keys are re-tupled so they
        # are hashable and match span heads during scoring.
        span_to_label = {}
        for content in span_to_label_list:
            span_to_label[tuple(content["key"])] = content["val"]
        return span_to_label

    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
        """Decompose a batch into aligned per-instance gold/pred records.

        Gold records also carry the encoded span->label mapping and the raw
        gold content (including the schema) needed by `calculate_scores`.
        """
        gold_instances = []
        pred_instances = []

        batch_gold = decompose_batch_into_instances(raw_batch)
        assert len(batch_gold) == len(out_batch["pred"])

        for i, gold in enumerate(batch_gold):
            # Fall back to a random id when the raw data provides none.
            ins_id = gold["raw"].get("id", generate_random_string_with_datetime())

            span_to_label_list = self._encode_span_to_label_dict(gold["span_to_label"])
            gold["span_to_label"] = span_to_label_list
            gold_instances.append(
                {
                    "id": ins_id,
                    "span_to_label_list": span_to_label_list,
                    "raw_gold_content": gold,
                    "spans": set(
                        tuple(multi_part_span) for multi_part_span in gold["spans"]
                    ),
                }
            )
            pred_instances.append(
                {
                    "id": ins_id,
                    "spans": set(
                        tuple(multi_part_span)
                        for multi_part_span in out_batch["pred"][i]
                    ),
                }
            )

        return gold_instances, pred_instances

    def calculate_scores(self, golds: list, preds: list) -> dict:
        """Route every span to its task bucket and compute per-task metrics."""
        # Raw (un-routed) span sets, scored as-is under "general_spans".
        general_gold_spans, general_pred_spans = [], []

        # Per-task accumulators: one inner list per instance.
        gold_cls_list, pred_cls_list = [], []

        gold_ent_list, pred_ent_list = [], []

        gold_rel_list, pred_rel_list = [], []

        gold_event_list, pred_event_list = [], []

        gold_span_list, pred_span_list = [], []

        gold_discon_ent_list, pred_discon_ent_list = [], []

        gold_hyper_rel_list, pred_hyper_rel_list = [], []

        for gold, pred in zip(golds, preds):
            general_gold_spans.append(gold["spans"])
            general_pred_spans.append(pred["spans"])
            span_to_label = self._decode_span_to_label(gold["span_to_label_list"])
            gold_clses, pred_clses = [], []
            gold_ents, pred_ents = [], []
            gold_rels, pred_rels = [], []
            # trigger -> partially-built event; event type ("lm" spans) and
            # arguments ("lr" spans) arrive independently and are merged here.
            gold_trigger_to_event = defaultdict(
                lambda: {"event_type": "", "arguments": []}
            )
            pred_trigger_to_event = defaultdict(
                lambda: {"event_type": "", "arguments": []}
            )
            gold_events, pred_events = [], []
            gold_spans, pred_spans = [], []
            gold_discon_ents, pred_discon_ents = [], []
            gold_hyper_rels, pred_hyper_rels = [], []

            raw_schema = gold["raw_gold_content"]["raw"]["schema"]
            for span in gold["spans"]:
                if span[0] in span_to_label:
                    label = span_to_label[span[0]]
                    # Every task expects a fixed span arity; spans that carry a
                    # known label but the wrong arity are silently dropped.
                    if label["task"] == "cls" and len(span) == 1:
                        gold_clses.append(label["string"])
                    elif label["task"] == "ent" and len(span) == 2:
                        gold_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "rel" and len(span) == 3:
                        gold_rels.append((label["string"], *span[1:]))
                    elif label["task"] == "event":
                        # "lm": (label, trigger) fixes the event type;
                        # "lr": (label, trigger, argument) adds a role-tagged argument.
                        if label["type"] == "lm" and len(span) == 2:
                            gold_trigger_to_event[span[1]]["event_type"] = label["string"]
                        elif label["type"] == "lr" and len(span) == 3:
                            gold_trigger_to_event[span[1]]["arguments"].append(
                                {"argument": span[2], "role": label["string"]}
                            )
                    elif label["task"] == "discontinuous_ent" and len(span) > 1:
                        gold_discon_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label:
                        # (rel label, head, tail, qualifier label, qualifier value)
                        q_label = span_to_label[span[3]]
                        gold_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4]))
                else:
                    # Head is not a labeled span: treat as a plain extraction span.
                    gold_spans.append(tuple(span))
            for trigger, item in gold_trigger_to_event.items():
                # NOTE(review): unlike the pred loop below, a gold event type
                # absent from the schema would raise KeyError here — presumably
                # gold data is always schema-consistent; verify upstream.
                legal_roles = raw_schema["event"][item["event_type"]]
                gold_events.append(
                    {
                        "trigger": trigger,
                        "event_type": item["event_type"],
                        # Keep only arguments whose role is legal for this event type.
                        "arguments": [
                            arg
                            for arg in filter(
                                lambda arg: arg["role"] in legal_roles,
                                item["arguments"],
                            )
                        ],
                    }
                )

            # Mirror of the gold routing above, applied to predicted spans
            # (the label mapping always comes from the gold side).
            for span in pred["spans"]:
                if span[0] in span_to_label:
                    label = span_to_label[span[0]]
                    if label["task"] == "cls" and len(span) == 1:
                        pred_clses.append(label["string"])
                    elif label["task"] == "ent" and len(span) == 2:
                        pred_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "rel" and len(span) == 3:
                        pred_rels.append((label["string"], *span[1:]))
                    elif label["task"] == "event":
                        if label["type"] == "lm" and len(span) == 2:
                            pred_trigger_to_event[span[1]]["event_type"] = label["string"]
                        elif label["type"] == "lr" and len(span) == 3:
                            pred_trigger_to_event[span[1]]["arguments"].append(
                                {"argument": span[2], "role": label["string"]}
                            )
                    elif label["task"] == "discontinuous_ent" and len(span) > 1:
                        pred_discon_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label:
                        q_label = span_to_label[span[3]]
                        pred_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4]))
                else:
                    # Head is not a labeled span: treat as a plain extraction span.
                    pred_spans.append(tuple(span))
            for trigger, item in pred_trigger_to_event.items():
                # Drop predicted events whose type is not in the schema (incl.
                # the defaultdict's "" placeholder when no "lm" span arrived).
                if item["event_type"] not in raw_schema["event"]:
                    continue
                legal_roles = raw_schema["event"][item["event_type"]]
                pred_events.append(
                    {
                        "trigger": trigger,
                        "event_type": item["event_type"],
                        "arguments": [
                            arg
                            for arg in filter(
                                lambda arg: arg["role"] in legal_roles,
                                item["arguments"],
                            )
                        ],
                    }
                )

            gold_cls_list.append(gold_clses)
            pred_cls_list.append(pred_clses)
            gold_ent_list.append(gold_ents)
            pred_ent_list.append(pred_ents)
            gold_rel_list.append(gold_rels)
            pred_rel_list.append(pred_rels)
            gold_event_list.append(gold_events)
            pred_event_list.append(pred_events)
            gold_span_list.append(gold_spans)
            pred_span_list.append(pred_spans)
            gold_discon_ent_list.append(gold_discon_ents)
            pred_discon_ent_list.append(pred_discon_ents)
            gold_hyper_rel_list.append(gold_hyper_rels)
            pred_hyper_rel_list.append(pred_hyper_rels)

        # One sub-metric per task, delegating to the task-specific helpers.
        metrics = {
            "general_spans": tagging_prf1(
                general_gold_spans, general_pred_spans, type_idx=None
            ),
            "cls": calc_cls(gold_cls_list, pred_cls_list),
            "ent": calc_ent(gold_ent_list, pred_ent_list),
            "rel": calc_rel(gold_rel_list, pred_rel_list),
            "event": {
                "trigger_id": calc_trigger_identification_metrics(
                    gold_event_list, pred_event_list
                ),
                "trigger_cls": calc_trigger_classification_metrics(
                    gold_event_list, pred_event_list
                ),
                "arg_id": calc_arg_identification_metrics(
                    gold_event_list, pred_event_list
                ),
                "arg_cls": calc_arg_classification_metrics(
                    gold_event_list, pred_event_list
                ),
                "char_event": calc_char_event(gold_event_list, pred_event_list),
            },
            "discontinuous_ent": tagging_prf1(
                gold_discon_ent_list, pred_discon_ent_list, type_idx=None
            ),
            "hyper_rel": tagging_prf1(
                gold_hyper_rel_list, pred_hyper_rel_list, type_idx=None
            ),

            "span": calc_span(gold_span_list, pred_span_list),
        }

        return metrics
|
|