from transformers import Pipeline
import nltk
import requests
import torch

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

NEL_MODEL = "nel-mgenre-multilingual"


def get_wikipedia_page_props(input_str: str):
    """
    Retrieves the Wikidata QID for a given Wikipedia page name on the specified
    language edition of Wikipedia.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        tuple: (QID or "NIL" if not found, language code).
    """
    try:
        # Split the input into page name and language code
        page_name, language = input_str.split(" >> ")
        page_name = page_name.strip()
        language = language.strip()
    except ValueError:
        return "Invalid input format. Use 'page_name >> language'.", "None"

    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        # Attempt to fetch the page properties from the Wikipedia API
        response = requests.get(wikipedia_url, params=wikipedia_params)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]
            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]
                if "wikibase_item" in page_props:
                    return page_props["wikibase_item"], language
        # No QID found for this page
        return qid, language
    except Exception:
        # Network error or unexpected response shape: fall back to "NIL"
        return qid, language
def get_wikipedia_title(qid, language="en"):
url = f"https://www.wikidata.org/w/api.php"
params = {
"action": "wbgetentities",
"format": "json",
"ids": qid,
"props": "sitelinks/urls",
"sitefilter": f"{language}wiki",
}
response = requests.get(url, params=params)
data = response.json()
try:
title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
return title, url
except KeyError:
return "NIL", "None"


class NelPipeline(Pipeline):

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # Extract the entity enclosed between the [START] and [END] markers
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx  # left offset (start of the entity)
            rOffset = end_idx  # right offset (end of the entity)
        else:
            enclosed_entity = None
            lOffset = None
            rOffset = None

        # Generate predictions using the model
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Decode the predictions into readable text
        wikipedia_predictions = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )

        # Compute per-token transition scores for the generated sequence
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = sum(transition_scores[0])
        # Exponentiate the summed log probabilities to get the sequence probability
        sequence_confidence = torch.exp(log_prob_sum)
        percentages = sequence_confidence.cpu().numpy() * 100.0

        # Return the predictions along with the extracted entity and its offsets
        return wikipedia_predictions, enclosed_entity, lOffset, rOffset, [percentages]

    def _forward(self, inputs):
        # Generation already happens in preprocess, so forward is a pass-through
        return inputs

    def postprocess(self, outputs, **kwargs):
        """
        Turn the preprocess outputs into a list of entity-linking results,
        resolving each predicted Wikipedia page name to its Wikidata QID,
        canonical page title, and URL.
        """
        wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages = outputs

        results = []
        for idx, wikipedia_name in enumerate(wikipedia_predictions):
            # Resolve the predicted "page_name >> language" string to a QID
            qid, language = get_wikipedia_page_props(wikipedia_name)
            # Get the canonical Wikipedia title and URL for that QID
            wkpedia_pagename, url = get_wikipedia_title(qid, language)

            results.append(
                {
                    "surface": enclosed_entity,
                    "wkd_id": qid,
                    "wkpedia_pagename": wkpedia_pagename,
                    "wkpedia_url": url,
                    "type": "UNK",
                    "confidence_nel": round(percentages[idx], 2),
                    "lOffset": lOffset,
                    "rOffset": rOffset,
                }
            )

        return results
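

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how
# this pipeline could be registered and invoked, assuming a compatible
# mGENRE-style seq2seq checkpoint. The task name "generic-nel" and MODEL_NAME
# below are placeholders, not values taken from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
    from transformers.pipelines import PIPELINE_REGISTRY

    MODEL_NAME = "impresso-project/" + NEL_MODEL  # assumed checkpoint location

    PIPELINE_REGISTRY.register_pipeline(
        "generic-nel",  # hypothetical task name
        pipeline_class=NelPipeline,
        pt_model=AutoModelForSeq2SeqLM,
    )

    nel = pipeline(
        "generic-nel",
        model=AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
    )

    # The entity to link is wrapped in [START] ... [END] markers.
    print(nel("The [START] Federal Council [END] is the executive body of Switzerland."))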