File size: 3,767 Bytes
6a12fa7 4d4008f fca4282 4d4008f fca4282 4d4008f a04d51c fca4282 6a12fa7 fca4282 40e1b0d fca4282 4d4008f 6a12fa7 a77d761 6a12fa7 0b61b47 6a12fa7 0b61b47 fca4282 0b61b47 28d14f1 0b61b47 97591d0 0b61b47 97591d0 6a12fa7 fca4282 6a12fa7 4d4008f 0b61b47 4d4008f 0b61b47 6a12fa7 0b61b47 6a12fa7 0b61b47 6a12fa7 0b61b47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import numpy
from transformers import TokenClassificationPipeline
class UniversalDependenciesPipeline(TokenClassificationPipeline):
    """Token-classification pipeline whose output is a CoNLL-U dependency parse.

    ``_forward`` runs the model once per content token (that token masked,
    its original id appended after the sequence).  ``postprocess`` reads the
    resulting logits as ``scores[head, dependent, label]`` (diagonal = root),
    decodes a tree with :meth:`chu_liu_edmonds`, and renders CoNLL-U.  Labels
    are split on ``"|"`` into UPOS / FEATS / DEPREL, and the label set must
    contain ``"X|_|goeswith"``.
    """

    def _forward(self, model_inputs):
        """Score every (head, dependent) pair in one batched forward pass.

        ``model_inputs["input_ids"][0]`` is assumed to start and end with a
        special token: only ``ids[1:-1]`` are treated as content.  For each
        content position ``i`` a copy of the ids is built with position ``i``
        replaced by ``mask_token_id`` and the original id appended at the end.

        Returns ``model_inputs`` plus ``"logits"`` of shape
        ``(n, n, num_labels)`` with ``n = len(input_ids) - 2``; row ``i`` comes
        from the copy in which content token ``i`` is masked.
        """
        import torch
        ids = model_inputs["input_ids"][0].tolist()
        if len(ids) <= 2:
            # No content tokens (e.g. empty input): the batch below would be
            # empty and torch.tensor([]) is a 1-D float tensor the model
            # cannot take, so return an empty (0, 0, num_labels) score tensor.
            logits = torch.zeros((0, 0, self.model.config.num_labels), device=self.device)
            return {"logits": logits, **model_inputs}
        batch = [ids[:i] + [self.tokenizer.mask_token_id] + ids[i + 1:] + [tok]
                 for i, tok in enumerate(ids[1:-1], 1)]
        with torch.no_grad():
            out = self.model(input_ids=torch.tensor(batch, device=self.device))
        # Drop position 0, the trailing special token and the appended id so
        # the remaining n columns line up with the n content tokens.
        return {"logits": out.logits[:, 1:-2, :], **model_inputs}

    def check_model_type(self, supported_models):
        """Accept any model: overrides the base-class check with a no-op."""

    def postprocess(self, model_outputs, **kwargs):
        """Decode the scores from ``_forward`` into a CoNLL-U string.

        If ``model_outputs`` has no ``"logits"`` key it is treated as an
        iterable of per-chunk outputs whose CoNLL-U strings are concatenated.

        Each label is split on ``"|"``: the first field becomes UPOS, the last
        DEPREL, anything in between FEATS.  When ``aggregation_strategy`` is
        given and is not ``"none"``, tokens continuing a ``goeswith`` chain
        and tokens whose offsets overlap the previous one are merged.

        Returns:
            A ``# text = ...`` comment line, one tab-separated line per word,
            and a terminating blank line.
        """
        if "logits" not in model_outputs:
            return "".join(self.postprocess(x, **kwargs) for x in model_outputs)
        config = self.model.config
        # scores[i, j, l]: score of label l for token j attaching to head i
        # (i == j means j is a root).  .float() first because Tensor.numpy()
        # does not support bfloat16.
        scores = model_outputs["logits"].float().numpy()
        n = scores.shape[0]
        # Label mask: diagonal cells may only take labels ending in "|root",
        # off-diagonal cells only labels that do not; label id 0 is never allowed.
        kind = [1 if i == 0 else -1 if label.endswith("|root") else 0
                for i, label in sorted(config.id2label.items())]
        scores += numpy.where(numpy.add.outer(numpy.identity(n), kind) == 0, 0, -numpy.inf)
        # goeswith mask: blocked[i, j] == 0 iff token j may attach to head i
        # with goeswith, i.e. j > i and every token strictly between them
        # already prefers goeswith under head i.
        goeswith = config.label2id["X|_|goeswith"]
        blocked = numpy.tri(n)
        for i in range(n):
            for j in range(i + 2, n):
                prefers = numpy.argmax(scores[i, j - 1]) == goeswith
                blocked[i, j] = blocked[i, j - 1] if prefers else 1
        scores[:, :, goeswith] += numpy.where(blocked == 0, 0, -numpy.inf)

        best = numpy.max(scores, axis=2)  # best label score per (head, dependent)
        best_label = numpy.argmax(scores, axis=2)
        heads = self.chu_liu_edmonds(best)
        roots = [i for i, h in enumerate(heads) if i == h]
        if len(roots) > 1:
            # For every root candidate, penalise heads outside the candidate
            # set; only the best-scoring candidate may stay root.  Re-decode.
            keep = roots[numpy.argmax(best[roots, roots])]
            penalty = numpy.min(best) - numpy.max(best)
            best[:, roots] += [[0 if j in roots and (i != j or i == keep) else penalty
                                for i in roots]
                               for j in range(n)]
            heads = self.chu_liu_edmonds(best)

        # _forward scored input_ids[1:-1]; slicing (rather than filtering out
        # empty offsets) keeps spans aligned with heads/labels even when a
        # content token has a zero-width offset.  Empty spans are dropped below.
        spans = [(s, e) for s, e in model_outputs["offset_mapping"][0].tolist()[1:-1]]
        labels = [config.id2label[best_label[h, i]].split("|") for i, h in enumerate(heads)]

        if kwargs.get("aggregation_strategy", "none") != "none":
            # Right to left so pops never shift indices still to be visited.
            for i, label in reversed(list(enumerate(labels[1:], 1))):
                in_chain = (label[-1] == "goeswith"
                            and {t[-1] for t in labels[heads[i] + 1:i + 1]} == {"goeswith"})
                if in_chain or spans[i - 1][1] > spans[i][0]:
                    heads = self._remove_from_heads(heads, i)
                    spans[i - 1] = (spans[i - 1][0], spans.pop(i)[1])
                    labels.pop(i)

        text = model_outputs["sentence"].replace("\n", " ")
        # Trim every kind of whitespace off each span (a tab must never reach
        # the tab-separated FORM column) and drop spans that become empty.
        for i in reversed(range(len(spans))):
            start, end = spans[i]
            word = text[start:end]
            start += len(word) - len(word.lstrip())
            end -= len(word) - len(word.rstrip())
            if start < end:
                spans[i] = (start, end)
            else:
                # NOTE(review): tokens headed by i are re-pointed to i - 1; for
                # i == 0 that yields -1, which is rendered below as head 0.
                heads = self._remove_from_heads(heads, i)
                spans.pop(i)
                labels.pop(i)

        rows = ["# text = " + text]
        for i, (start, end) in enumerate(spans):
            label = labels[i]
            head = 0 if heads[i] == i else heads[i] + 1
            # The last word, and any word directly followed by the next one,
            # gets SpaceAfter=No.
            space_after = i + 1 < len(spans) and end < spans[i + 1][0]
            rows.append("\t".join([str(i + 1), text[start:end], "_", label[0], "_",
                                   "|".join(label[1:-1]), str(head), label[-1], "_",
                                   "_" if space_after else "SpaceAfter=No"]))
        return "\n".join(rows) + "\n\n"

    @staticmethod
    def _remove_from_heads(heads, i):
        """Drop token ``i`` from a head list and renumber the remaining heads.

        Heads pointing after ``i`` shift down by one; heads pointing at ``i``
        itself become ``i - 1``.
        """
        return [h if h < i else h - 1 for k, h in enumerate(heads) if k != i]

    def chu_liu_edmonds(self, matrix):
        """Decode a dependency tree from a square score matrix.

        ``matrix[i, j]`` is the score of token ``i`` heading token ``j``; the
        diagonal ``matrix[j, j]`` is the score of ``j`` being a root.  Starts
        from the greedy per-column argmax; if that contains a cycle, one cycle
        is contracted into a single node and the method recurses
        (Chu-Liu/Edmonds style).  More than one token may end up as root.

        Returns:
            A list ``h`` with ``h[j]`` the head of ``j`` (``h[j] == j`` marks a
            root); ``[]`` for an empty matrix.
        """
        if matrix.shape[0] == 0:
            return []  # numpy.argmax cannot reduce over an empty axis
        h = numpy.argmax(matrix, axis=0)
        # x[i] is the greedy head of i, or -1 when i picked itself (root).
        x = [-1 if i == j else j for i, j in enumerate(h)]
        # Pass 1 repeatedly sets to -1 every token that heads nobody, peeling
        # away everything not on a cycle.  Pass 2 does pointer jumping
        # (x[i] = x[x[i]]) until stable, intended to collapse each cycle onto
        # one representative index.  Both passes update x in place.
        for step in [lambda x, i, j: -1 if i not in x else x[i],
                     lambda x, i, j: -1 if j < 0 else x[j]]:
            y = []
            while x != y:
                y = list(x)
                for i, j in enumerate(x):
                    x[i] = step(x, i, j)
        if max(x) < 0:
            return h.tolist()  # no cycle: greedy heads are already a forest
        # y: members of the cycle with the largest representative; x: the rest.
        y, x = [i for i, j in enumerate(x) if j == max(x)], [i for i, j in enumerate(x) if j < max(x)]
        z = matrix - numpy.max(matrix, axis=0)  # scores relative to each column's best head
        # Contracted graph: the tokens in x followed by one node standing for y.
        m = numpy.block([[z[x, :][:, x], numpy.max(z[x, :][:, y], axis=1).reshape(len(x), 1)],
                         [numpy.max(z[y, :][:, x], axis=0), numpy.max(z[y, y])]])
        # Map the contracted solution back to token indices; k[-1] stays in
        # contracted indices (the head chosen for the cycle node).
        k = [j if i == len(x) else x[j] if j < len(x) else y[numpy.argmax(z[y, x[i]])]
             for i, j in enumerate(self.chu_liu_edmonds(m))]
        h = [j if i in y else k[x.index(i)] for i, j in enumerate(h)]
        # Break the cycle at the member that best takes the cycle's external
        # head, or at the best root candidate if the cycle node became root.
        i = y[numpy.argmax(z[x[k[-1]], y] if k[-1] < len(x) else z[y, y])]
        h[i] = x[k[-1]] if k[-1] < len(x) else i
        return h
|