|
from collections import defaultdict |
|
|
|
from rex.metrics.tagging import tagging_prf1 |
|
from rex.utils.io import load_jsonlines |
|
from rex.utils.position import find_all_positions |
|
|
|
|
|
def main(
    middle_filepath: str = "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_labelmap_Rel_updateTag_bs32/middle/test.final.jsonl",
):
    """Drop into the debugger on every instance whose predicted spans differ from gold.

    Args:
        middle_filepath: JSON-lines dump of model outputs; each line must carry
            ``gold`` and ``pred`` dicts, both with a ``spans`` entry. Defaults to
            the ACE05-EN test dump this helper was originally written for.
    """
    for ins in load_jsonlines(middle_filepath):
        # Kept as named locals so they are easy to inspect at the breakpoint.
        gold = ins["gold"]
        pred = ins["pred"]
        if gold["spans"] != pred["spans"]:
            breakpoint()
|
|
|
|
|
def check_ent_string_matching_upper_bound(filepath: str, strategy: str = "first"):
    """Estimate the entity-F1 upper bound of recovering spans by string matching.

    For every instance in ``filepath``, each gold entity's surface string is
    searched for in the instance text with ``find_all_positions``. At most one
    occurrence is chosen as that entity's "predicted" span. The predictions
    are scored against the deduplicated gold entities with ``tagging_prf1``,
    and the micro results are printed.

    Gold entities are processed longest-string first. Among equal lengths,
    they keep their order of appearance in the file.

    Args:
        filepath: JSON-lines file; each line needs ``text`` and ``ans["ent"]``
            (a list of dicts with ``text``, ``type`` and ``span`` keys).
        strategy: ``"first"`` picks the first occurrence not already claimed
            by an entity of the same type. Any other value picks the first
            occurrence that does not overlap any span claimed so far in the
            instance.
    """

    def _check_overlap(x, y):
        # NOTE(review): both ends are compared as inclusive, so spans that only
        # touch (x[1] == y[0]) count as overlapping. If spans are end-exclusive,
        # confirm that this is intended.
        return not (x[0] > y[1] or y[0] > x[1])

    data = load_jsonlines(filepath)
    golds = []
    preds = []
    for ins in data:
        text = ins["text"]
        # Deduplicate while preserving input order. A plain ``set`` of string
        # tuples has hash-randomised order, which made the tie order of the
        # (stable) length sort -- and therefore the greedy matching below --
        # vary from run to run.
        gold_ents = list(
            dict.fromkeys(
                (ent["text"], ent["type"], tuple(ent["span"]))
                for ent in ins["ans"]["ent"]
            )
        )
        gold_ents.sort(key=lambda x: len(x[0]), reverse=True)

        pred_ents = []
        matched = set()  # (ent_type, span) pairs claimed so far in this instance
        for ent_string, ent_type, _ in gold_ents:
            for position in find_all_positions(text, ent_string):
                # Normalise so the span is hashable even if positions are lists.
                span = tuple(position)
                if strategy == "first":
                    if (ent_type, span) in matched:
                        continue
                elif any(_check_overlap(g, span) for _, g in matched):
                    # Reject only this occurrence; later ones may still be free.
                    continue
                matched.add((ent_type, span))
                pred_ents.append((ent_string, ent_type, span))
                # Both strategies keep just the first acceptable occurrence.
                break

        golds.append(gold_ents)
        preds.append(pred_ents)

    results = tagging_prf1(golds, preds)

    print(f"filepath: {filepath}, Strategy: {strategy}")
    print(f"Results: {results['micro']}")
|
|
|
|
|
def check_rel_tanl_upper_bound(filepath):
    """Estimate the relation-F1 upper bound of TANL-style tail resolution.

    This simulates a decoder that keeps each gold relation's head span but
    knows only the tail entity's surface text. The tail span is recovered from
    the instance's gold entities only when exactly one distinct entity span
    carries that text. Otherwise the relation is dropped. Predictions are
    scored against the gold relations with ``tagging_prf1``, and the micro
    results are printed.

    Args:
        filepath: JSON-lines file; each line needs ``ans["ent"]`` (dicts with
            ``text`` and ``span``) and ``ans["rel"]`` (dicts with ``head`` and
            ``tail`` sub-dicts carrying ``text``/``span``, plus ``relation``).
    """
    data = load_jsonlines(filepath)
    golds = []
    preds = []
    for ins in data:
        # Surface text -> every distinct gold entity span with that text.
        ent_text_to_spans = defaultdict(set)
        for ent in ins["ans"]["ent"]:
            ent_text_to_spans[ent["text"]].add(tuple(ent["span"]))

        gold_rels = list(
            dict.fromkeys(
                (
                    tuple(rel["head"]["span"]),
                    rel["relation"],
                    tuple(rel["tail"]["span"]),
                )
                for rel in ins["ans"]["rel"]
            )
        )

        pred_rels = []
        for rel in ins["ans"]["rel"]:
            # ``.get`` avoids inserting empty sets into the defaultdict.
            tail_spans = ent_text_to_spans.get(rel["tail"]["text"], ())
            if len(tail_spans) != 1:
                # Tail text is unknown or ambiguous, so it cannot be resolved.
                continue
            (tail_span,) = tail_spans
            pred_rels.append(
                (tuple(rel["head"]["span"]), rel["relation"], tail_span)
            )
        # Collapse duplicates exactly as the gold side does. Otherwise a
        # duplicated gold relation would yield extra predictions against a
        # single gold entry.
        pred_rels = list(dict.fromkeys(pred_rels))

        golds.append(gold_rels)
        preds.append(pred_rels)

    results = tagging_prf1(golds, preds)

    print(f"filepath: {filepath}")
    print(f"Results: {results['micro']}")
|
|
|
|
|
if __name__ == "__main__":
    # Report the TANL-style relation upper bound for each UIE relation test split.
    rel_test_files = (
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/ace05-rel/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/conll04/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/nyt/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/scierc/test.jsonl",
    )
    for rel_test_file in rel_test_files:
        check_rel_tanl_upper_bound(rel_test_file)
|
|