Emanuela Boros committed
Commit 357be93 · Parent(s): a945a9c
added other details
Files changed: generic_nel.py (+47 -11)
generic_nel.py
CHANGED
@@ -1,9 +1,11 @@
 from transformers import Pipeline
 import nltk
+import requests
 
 nltk.download("averaged_perceptron_tagger")
 nltk.download("averaged_perceptron_tagger_eng")
-
+
+NEL_MODEL = "nel-mgenre-multilingual"
 
 
 def get_wikipedia_page_props(input_str: str):
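The added requests import supports the Wikipedia/Wikidata lookups used further down in the file (get_wikipedia_page_props, get_wikipedia_title), and NEL_MODEL names the linking model used to build result ids. The body of get_wikipedia_page_props is not shown in this diff; purely as an illustrative sketch (endpoint, parameters, and return handling are assumptions, not the commit's code), a title-to-QID lookup against the public MediaWiki API could look like this:

import requests

def get_wikipedia_page_props(input_str: str):
    # Illustrative only: resolve a Wikipedia page title to its Wikidata QID
    # via the MediaWiki API (prop=pageprops exposes the wikibase_item field).
    params = {
        "action": "query",
        "prop": "pageprops",
        "titles": input_str,
        "format": "json",
    }
    response = requests.get("https://en.wikipedia.org/w/api.php", params=params, timeout=10)
    pages = response.json().get("query", {}).get("pages", {})
    for page in pages.values():
        return page.get("pageprops", {}).get("wikibase_item")  # e.g. "Q243"
    return None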
@@ -87,25 +89,36 @@ class NelPipeline(Pipeline):
         return preprocess_kwargs, {}, {}
 
     def preprocess(self, text, **kwargs):
-
+        # Extract the entity between [START] and [END]
+        start_token = "[START]"
+        end_token = "[END]"
+
+        if start_token in text and end_token in text:
+            start_idx = text.index(start_token) + len(start_token)
+            end_idx = text.index(end_token)
+            enclosed_entity = text[start_idx:end_idx].strip()
+            lOffset = start_idx  # left offset (start of the entity)
+            rOffset = end_idx  # right offset (end of the entity)
+        else:
+            enclosed_entity = None
+            lOffset = None
+            rOffset = None
+
+        # Generate predictions using the model
         outputs = self.model.generate(
             **self.tokenizer([text], return_tensors="pt").to(self.device),
             num_beams=5,
             num_return_sequences=5,
             max_new_tokens=30,
         )
-        # print(outputs)
-        # token_ids, scores = outputs.sequences, outputs.sequences_scores
-        # scores_tensor = scores.clone().detach()
-        # probabilities = torch.exp(scores_tensor)
-        # percentages = (probabilities * 100.0).cpu().numpy().tolist()
 
+        # Decode the predictions into readable text
         wikipedia_predictions = self.tokenizer.batch_decode(
             outputs, skip_special_tokens=True
         )
-        # print(f"Decoded: {wikipedia_predictons}")np.round(percentages[i], 2)
 
-
+        # Return the predictions along with the extracted entity, lOffset, and rOffset
+        return wikipedia_predictions, enclosed_entity, lOffset, rOffset
 
     def _forward(self, inputs):
         return inputs
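preprocess now expects the target mention to be wrapped in [START] and [END] markers and returns the beam-search candidates together with the extracted surface form and its character offsets. A small self-contained illustration of that offset logic, using an invented sample sentence:

text = "The [START] Eiffel Tower [END] was completed in 1889."
start_token, end_token = "[START]", "[END]"

start_idx = text.index(start_token) + len(start_token)  # 11
end_idx = text.index(end_token)                          # 25
enclosed_entity = text[start_idx:end_idx].strip()        # "Eiffel Tower"

# lOffset/rOffset are offsets into the marked-up input, not the raw sentence
lOffset, rOffset = start_idx, end_idx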
@@ -117,7 +130,20 @@ class NelPipeline(Pipeline):
         :param kwargs:
         :return:
         """
-
+
+        # {
+        #     "surface": sentences[i].split("[START]")[1].split("[END]")[0],
+        #     "lOffset": lOffset,
+        #     "rOffset": rOffset,
+        #     "type": "UNK",
+        #     "id": f"{lOffset}:{rOffset}:{surface}:{NEL_MODEL}",
+        #     "wkd_id": get_wikipedia_page_props(wikipedia_titles[i * 2]),
+        #     "wkpedia_pagename": wikipedia_titles[
+        #         i * 2
+        #     ],  # This can be improved with a real API call to get the QID
+        #     "confidence_nel": np.round(percentages[i], 2),
+        # }
+        wikipedia_predictions, enclosed_entity, lOffset, rOffset = outputs
         results = []
         for idx, wikipedia_name in enumerate(wikipedia_predictions):
             # Get QID
@@ -127,7 +153,17 @@ class NelPipeline(Pipeline):
             # Get Wikipedia title and URL
             title, url = get_wikipedia_title(qid)
             results.append(
-                {
+                {
+                    "id": f"{lOffset}:{rOffset}:{enclosed_entity}:{NEL_MODEL}",
+                    "surface": enclosed_entity,
+                    "title": title,
+                    "wkd_id": qid,
+                    "url": url,
+                    "type": "UNK",
+                    "confidence_nel": 0.0,
+                    "lOffset": lOffset,
+                    "rOffset": rOffset,
+                }
             )
 
         return results
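With these changes, postprocess emits one dictionary per decoded candidate, combining the QID, Wikipedia title, and URL with the surface form and offsets captured in preprocess. A hypothetical entry for the sample sentence above (title, QID, and URL are illustrative values; confidence_nel is hard-coded to 0.0 in this commit):

{
    "id": "11:25:Eiffel Tower:nel-mgenre-multilingual",
    "surface": "Eiffel Tower",
    "title": "Eiffel Tower",
    "wkd_id": "Q243",
    "url": "https://en.wikipedia.org/wiki/Eiffel_Tower",
    "type": "UNK",           # entity type is not predicted here
    "confidence_nel": 0.0,   # placeholder score in this commit
    "lOffset": 11,
    "rOffset": 25,
}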