|
from transformers import AutoTokenizer
|
|
import re
|
|
import string
|
|
|
|
|
|
class TF_Tokenizer:
    """Callable wrapper around a Hugging Face ``AutoTokenizer``.

    Calling an instance returns ``tokenize(txt)`` of the wrapped tokenizer,
    i.e. a list of token strings.
    """

    def __init__(self, model_str):
        """Load the pretrained tokenizer identified by ``model_str``."""
        # Must be stored on the instance: __call__ reads ``self.tok``.
        self.tok = AutoTokenizer.from_pretrained(model_str)

    def __call__(self, txt):
        """Return the token strings the wrapped tokenizer produces for ``txt``."""
        return self.tok.tokenize(txt)
|
|
|
|
|
|
class WS_Tokenizer:
    r"""Split text into ``\w+`` runs and single ASCII punctuation characters.

    Whitespace is discarded. Every character in ``string.punctuation`` becomes
    its own token and every maximal ``\w+`` run becomes one token. Characters
    matching neither (e.g. non-ASCII symbols such as "€") are dropped.
    """

    # re.escape is required: interpolating string.punctuation raw into a
    # character class turns its "\]" into an escaped "]", so a literal
    # backslash in the text would never be matched.
    _PATTERN = re.compile(r"[{}]|\w+".format(re.escape(string.punctuation)))

    def __call__(self, txt):
        """Return the list of tokens found in ``txt``, in order."""
        return self._PATTERN.findall(txt)
|
|
|
|
|
|
def convert_spans_to_bio(txt, roles, tokenizer_func):
    """Convert character-offset spans into one BIO tag per token.

    ``txt`` is tokenized with ``tokenizer_func``. A cursor is then walked
    through ``txt`` to recover each token's start offset: whitespace is
    skipped, then the cursor advances by the token's surface length.
    WordPiece continuation tokens (``##`` prefix) advance by ``len(tok) - 2``
    and ``[UNK]`` by one character, matching the accounting in
    ``convert_bio_to_spans``.

    A token starting exactly at a role's ``start`` gets ``B-<label>``.
    Later tokens that start before that role's ``end`` get ``I-<label>``.
    All other tokens get ``O``. A role whose start does not coincide with a
    token start never produces a ``B-`` tag.

    NOTE(review): this assumes the tokenizer returns verbatim substrings of
    ``txt`` in order (no lowercasing/normalization) — confirm for the
    tokenizers in use.

    Args:
        txt: Input text.
        roles: Iterable of dicts with ``"start"``, ``"end"`` (exclusive) and
            ``"label"`` keys. If several roles share a start offset, the
            first one listed wins.
        tokenizer_func: Callable mapping ``txt`` to a list of token strings.

    Returns:
        A list of tag strings, one per token returned by ``tokenizer_func``.
    """
    # First role listed for each start offset (equivalent to a stable sort by
    # start followed by list.index, but with O(1) lookups).
    role_at = {}
    for role in roles:
        role_at.setdefault(role["start"], role)

    ttxt = tokenizer_func(txt)
    n = len(txt)
    c = 0  # cursor: offset in txt where the current token starts
    cr = -1  # exclusive end of the most recently started role
    prev = "O"  # I- tag for tokens still inside that role
    troles = []
    for tok in ttxt:
        # Skip any whitespace (newlines, tabs, ...), not only " ", and never
        # run past the end of txt.
        while c < n and txt[c].isspace():
            c += 1
        if c >= n:
            break

        role = role_at.get(c)
        if role is not None:
            cr = role["end"]
            prev = "I-" + role["label"]
            troles.append("B-" + role["label"])
        elif c < cr:
            troles.append(prev)
        else:
            troles.append("O")

        # Advance by the number of characters the token covers in txt.
        if tok.startswith("##"):
            c += len(tok) - 2
        elif tok == "[UNK]":
            c += 1
        else:
            c += len(tok)

    # Tokens left over once the text is exhausted are tagged "O".
    troles += ["O"] * (len(ttxt) - len(troles))
    return troles
|
|
|
|
|
|
def convert_bio_to_spans(txt, troles, tokenizer_func):
    """Convert per-token BIO tags back into character-offset spans.

    Re-tokenizes ``txt`` with ``tokenizer_func`` and walks the tokens in
    step with ``troles``, tracking a character cursor into ``txt``. Span
    handling:

    * A span is opened on each ``B-<label>`` tag.
    * It is closed by the next ``B-`` or ``O`` tag, or by the end of input.
    * Any other tag (e.g. ``I-``) changes no state, so the open span simply
      keeps growing.
    * An ``I-`` tag with no earlier ``B-`` opens nothing.

    WordPiece continuation tokens (``##`` prefix) and ``[UNK]`` tokens only
    advance the cursor, by ``len(tok) - 2`` and ``1`` characters
    respectively; their tags are ignored. NOTE(review): the 1-character width
    for ``[UNK]`` is an approximation — an unknown token may cover more text.

    Args:
        txt: The text the tags were produced for.
        troles: One tag string per token (``"O"``, ``"B-<label>"``,
            ``"I-<label>"``). If shorter than the tokenization, the extra
            tokens are dropped; if longer, the extra tags are ignored.
        tokenizer_func: Callable mapping ``txt`` to a list of token strings.

    Returns:
        A list of ``{"start": int, "end": int, "label": str}`` dicts, where
        ``end`` is the cursor position just past the span's last token.
    """
    c = 0  # cursor into txt
    c2 = 0  # cursor just past the most recently consumed token (span end)
    cr = -1  # offset where the most recent span was closed
    cs = -1  # start of the most recent span opened by a B- tag (-1: none yet)
    # A span is "open" while cs >= cr. The initial -1/-1 state counts as open,
    # which is why every emit below is additionally guarded by cs >= 0.
    prev = "O"  # label of the span starting at cs

    roles = []
    ttxt = tokenizer_func(txt)

    # Only walk as many tokens as there are tags.
    if len(ttxt) != len(troles):
        ttxt = ttxt[: len(troles)]

    for j, tok in enumerate(ttxt):
        if c >= len(txt):
            break

        while c < len(txt) and txt[c].isspace():
            c += 1

        if tok[:2] == "##" or tok == "[UNK]":
            # Sub-word continuation / unknown token: advance only, tag ignored.
            c += len(tok) - 2 if tok[:2] == "##" else 1
        else:
            if troles[j].startswith("B-"):
                if cs >= cr:
                    # Close the currently open span (if a real one exists).
                    cr = c
                    if cs >= 0:
                        roles.append({"start": cs, "end": c2, "label": prev})
                # Open a new span at this token.
                cs = c
                prev = troles[j][2:]
            else:
                if troles[j] == "O":
                    if cs >= cr:
                        # An O tag closes the open span; cr > cs afterwards.
                        cr = c
                        if cs >= 0:
                            roles.append({"start": cs, "end": c2, "label": prev})
            c += len(tok)
        c2 = c

    # Flush a span still open at the end of the input.
    if cs >= cr:
        if cs >= 0:
            roles.append({"start": cs, "end": c2, "label": prev})

    return roles
|
|
|
|
|
|
def span2bio(txt, labels):
    r"""Tokenize ``txt`` and tag each token with a BIO label from ``labels``.

    Tokens are single ASCII punctuation characters or ``\w+`` runs;
    whitespace and any other characters are dropped. Each token's start
    offset is taken directly from its regex match, so newlines, tabs or
    unmatched characters cannot shift the alignment.

    A token starting exactly at a span's ``start`` gets ``B-<label>``. Later
    tokens that start before that span's ``end`` get ``I-<label>``. All other
    tokens get ``O``.

    Args:
        txt: Input text.
        labels: Iterable of dicts with ``"start"``, ``"end"`` (exclusive) and
            ``"label"`` keys giving character offsets into ``txt``. When
            several spans share a start offset, the first one listed wins.

    Returns:
        ``(tokens, tags)`` — two lists of equal length.
    """
    # Escape the punctuation so "\" and "]" are matched literally inside the
    # character class instead of forming an escape sequence.
    pattern = r"[{}]|\w+".format(re.escape(string.punctuation))
    matches = list(re.finditer(pattern, txt))
    tokens = [m.group() for m in matches]

    # First span listed for each start offset.
    by_start = {}
    for span in labels:
        by_start.setdefault(span["start"], span)

    span_end = -1  # exclusive end of the most recently started span
    inside = "O"  # I- tag for tokens still inside that span
    tags = []
    for m in matches:
        start = m.start()
        span = by_start.get(start)
        if span is not None:
            span_end = span["end"]
            inside = "I-" + span["label"]
            tags.append("B-" + span["label"])
        elif start < span_end:
            tags.append(inside)
        else:
            tags.append("O")

    return tokens, tags
|
|
|