KoichiYasuoka
commited on
Commit
·
5b3c45d
1
Parent(s):
b441a93
support transformers>=4.28
Browse files
ud.py
CHANGED
@@ -32,6 +32,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
32 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
33 |
def postprocess(self,model_outputs,**kwargs):
|
34 |
import numpy
|
|
|
|
|
35 |
e=model_outputs["logits"].numpy()
|
36 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
37 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
|
|
32 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
33 |
def postprocess(self,model_outputs,**kwargs):
|
34 |
import numpy
|
35 |
+
if "logits" not in model_outputs:
|
36 |
+
return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
|
37 |
e=model_outputs["logits"].numpy()
|
38 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
39 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|