Emanuela Boros committed
Commit d868172
Parent(s): 57b93e7
added pipeline

Browse files
- generic_nel.py +152 -0
generic_nel.py
ADDED
@@ -0,0 +1,152 @@
from transformers import Pipeline
import numpy as np
import torch
import nltk

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")
from nltk.chunk import conlltags2tree
from nltk import pos_tag
from nltk.tree import Tree
import requests
import torch.nn.functional as F
import re, string


def get_wikipedia_page_props(input_str: str):
    """
    Retrieves the QID for a given Wikipedia page name from the specified language Wikipedia.
    If the lookup fails, "NIL" is returned.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        str: The QID, or "NIL" if the QID is not found.
    """
    try:
        # Split the input into the page title and the Wikipedia language edition
        page_name, language = input_str.split(" >> ")
        page_name = page_name.strip()
        language = language.strip()
    except ValueError:
        return "Invalid input format. Use 'page_name >> language'."

    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        # Query the Wikipedia API for the page properties, which hold the Wikidata item ID
        response = requests.get(wikipedia_url, params=wikipedia_params)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]

            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]

                if "wikibase_item" in page_props:
                    return page_props["wikibase_item"]
                else:
                    return qid
            else:
                return qid
        return qid
    except Exception:
        return qid


def get_wikipedia_title(qid, language="en"):
    """
    Resolves a Wikidata QID to the corresponding Wikipedia title and URL
    in the requested language edition.
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    response = requests.get(url, params=params)
    data = response.json()

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"


class NelPipeline(Pipeline):
    """
    Entity-linking pipeline: the underlying generative model produces candidate
    Wikipedia page names ("page_name >> language"), which are then resolved to
    Wikidata QIDs, titles and URLs.
    """

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]

        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # Generate candidate Wikipedia page names with beam search
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt"),
            num_beams=5,
            num_return_sequences=5,
            max_new_tokens=30,
        )
        wikipedia_predictions = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        print(f"Decoded: {wikipedia_predictions}")

        return wikipedia_predictions

    def _forward(self, inputs):
        # Generation already happened in preprocess; pass the candidates through
        return inputs

    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the generated candidates.

        :param outputs: decoded "page_name >> language" strings from preprocess
        :param kwargs: unused
        :return: a list of dicts with the Wikipedia title, the Wikidata QID and the URL
        """
        # outputs
        #
        # predictions = {}
        # confidence_scores = {}
        # for task, logits in tokens_result.logits.items():
        #     predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
        #     confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
        #
        # entities = {}
        # for task in predictions.keys():
        #     words_list, preds_list, confidence_list = realign(
        #         text_sentence,
        #         predictions[task],
        #         confidence_scores[task],
        #         self.tokenizer,
        #         self.id2label[task],
        #     )
        #
        #     entities[task] = get_entities(words_list, preds_list, confidence_list, text)
        #
        # postprocessed_entities = self.postprocess_entities(entities, text_sentence)
        results = []
        for wikipedia_name in outputs:
            # Get QID
            qid = get_wikipedia_page_props(wikipedia_name)
            print(f"{wikipedia_name} -- QID: {qid}")

            # Get Wikipedia title and URL
            title, url = get_wikipedia_title(qid)
            results.append({"title": title, "qid": qid, "url": url})

        return results
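
The two helpers above chain a Wikipedia "pageprops" lookup (page name to QID) with a Wikidata "wbgetentities" sitelinks lookup (QID to title and URL). A minimal sketch of how they might be exercised on their own, assuming network access; the "Paris >> en" candidate and the expected QID are illustrative assumptions, not outputs of the committed model:

# Illustrative sketch only -- not part of generic_nel.py.
from generic_nel import get_wikipedia_page_props, get_wikipedia_title

candidate = "Paris >> en"  # hypothetical candidate in the "page_name >> language" format

qid = get_wikipedia_page_props(candidate)             # a QID such as "Q90", or "NIL" on failure
title, url = get_wikipedia_title(qid, language="en")
print({"title": title, "qid": qid, "url": url})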
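
NelPipeline follows the Transformers custom-pipeline interface: candidates are generated with beam search in preprocess, _forward is a pass-through, and postprocess resolves each candidate to a QID, title and URL. A sketch of how the class might be registered and invoked; the task name, checkpoint id, AutoModelForSeq2SeqLM class and mention markup are assumptions, not defined by this commit:

# Illustrative sketch only -- not part of generic_nel.py.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from generic_nel import NelPipeline

# Register the custom pipeline under a hypothetical task name.
PIPELINE_REGISTRY.register_pipeline(
    "generic-nel",
    pipeline_class=NelPipeline,
    pt_model=AutoModelForSeq2SeqLM,
)

model_id = "org-name/nel-model"  # placeholder checkpoint id
nel = pipeline(
    "generic-nel",
    model=AutoModelForSeq2SeqLM.from_pretrained(model_id),
    tokenizer=AutoTokenizer.from_pretrained(model_id),
)

# The mention markup is an assumption about the expected input format.
print(nel("[START] Paris [END] is the capital of France."))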