Emanuela Boros committed
Commit d868172 · 1 Parent(s): 57b93e7

added pipeline

Files changed (1)
  1. generic_nel.py +152 -0
generic_nel.py ADDED
@@ -0,0 +1,152 @@
+ from transformers import Pipeline
+ import numpy as np
+ import torch
+ import nltk
+ from nltk.chunk import conlltags2tree
+ from nltk import pos_tag
+ from nltk.tree import Tree
+ import requests
+ import torch.nn.functional as F
+ import re, string
+
+ # Fetch the POS-tagger resources once at import time
+ nltk.download("averaged_perceptron_tagger")
+ nltk.download("averaged_perceptron_tagger_eng")
+
+
+ def get_wikipedia_page_props(input_str: str):
+     """
+     Retrieves the Wikidata QID for a given Wikipedia page name from the
+     specified language edition of Wikipedia.
+
+     Args:
+         input_str (str): The input string in the format "page_name >> language".
+
+     Returns:
+         str: The QID, or "NIL" if the QID cannot be found.
+     """
+     try:
+         # Split the input into the page title and the language code
+         page_name, language = input_str.split(" >> ")
+         page_name = page_name.strip()
+         language = language.strip()
+     except ValueError:
+         return "Invalid input format. Use 'page_name >> language'."
+
+     wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
+     wikipedia_params = {
+         "action": "query",
+         "prop": "pageprops",
+         "format": "json",
+         "titles": page_name,
+     }
+
+     qid = "NIL"
+     try:
+         # Query the Wikipedia API for the page properties
+         response = requests.get(wikipedia_url, params=wikipedia_params)
+         response.raise_for_status()
+         data = response.json()
+
+         if "pages" in data["query"]:
+             page_id = list(data["query"]["pages"].keys())[0]
+
+             if "pageprops" in data["query"]["pages"][page_id]:
+                 page_props = data["query"]["pages"][page_id]["pageprops"]
+
+                 if "wikibase_item" in page_props:
+                     return page_props["wikibase_item"]
+
+         return qid
+     except Exception:
+         return qid
+
+
+ def get_wikipedia_title(qid, language="en"):
+     """Resolves a QID to the title and URL of the corresponding Wikipedia article."""
+     url = "https://www.wikidata.org/w/api.php"
+     params = {
+         "action": "wbgetentities",
+         "format": "json",
+         "ids": qid,
+         "props": "sitelinks/urls",
+         "sitefilter": f"{language}wiki",
+     }
+
+     response = requests.get(url, params=params)
+     data = response.json()
+
+     try:
+         title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
+         url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
+         return title, url
+     except KeyError:
+         # The entity has no sitelink for the requested language edition
+         return "NIL", "None"
+
+
+ class NelPipeline(Pipeline):
+     """Generation-based named entity linking pipeline: the model generates
+     "page_name >> language" candidates that are resolved to Wikidata QIDs."""
+
+     def _sanitize_parameters(self, **kwargs):
+         preprocess_kwargs = {}
+         if "text" in kwargs:
+             preprocess_kwargs["text"] = kwargs["text"]
+
+         return preprocess_kwargs, {}, {}
+
+     def preprocess(self, text, **kwargs):
+         # Generate candidate Wikipedia page names with beam search
+         outputs = self.model.generate(
+             **self.tokenizer([text], return_tensors="pt"),
+             num_beams=5,
+             num_return_sequences=5,
+             max_new_tokens=30,
+         )
+         wikipedia_predictions = self.tokenizer.batch_decode(
+             outputs, skip_special_tokens=True
+         )
+         print(f"Decoded: {wikipedia_predictions}")
+
+         return wikipedia_predictions
+
+     def _forward(self, inputs):
+         # Generation already happens in preprocess; pass the candidates through
+         return inputs
+
+     def postprocess(self, outputs, **kwargs):
+         """
+         Postprocess the outputs of the model.
+         :param outputs: decoded "page_name >> language" candidates
+         :param kwargs:
+         :return: list of dicts with the Wikipedia title, QID and URL
+         """
+         # outputs
+         #
+         # predictions = {}
+         # confidence_scores = {}
+         # for task, logits in tokens_result.logits.items():
+         #     predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
+         #     confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
+         #
+         # entities = {}
+         # for task in predictions.keys():
+         #     words_list, preds_list, confidence_list = realign(
+         #         text_sentence,
+         #         predictions[task],
+         #         confidence_scores[task],
+         #         self.tokenizer,
+         #         self.id2label[task],
+         #     )
+         #
+         #     entities[task] = get_entities(words_list, preds_list, confidence_list, text)
+         #
+         # postprocessed_entities = self.postprocess_entities(entities, text_sentence)
+         results = []
+         for wikipedia_name in outputs:
+             # Get the QID for the predicted page name
+             qid = get_wikipedia_page_props(wikipedia_name)
+             print(f"{wikipedia_name} -- QID: {qid}")
+
+             # Get the Wikipedia title and URL for that QID
+             title, url = get_wikipedia_title(qid)
+             results.append({"title": title, "qid": qid, "url": url})
+
+         return results
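For readers landing on this commit, a minimal usage sketch follows. It is illustrative only: the checkpoint name and the mention markers in the example sentence are assumptions, not part of this commit; the only contract the pipeline itself imposes is that the underlying seq2seq model produce candidates in the "page_name >> language" format that get_wikipedia_page_props expects.

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # Hypothetical checkpoint; substitute the actual NEL model this pipeline targets.
    MODEL_NAME = "impresso-project/nel-mgenre-multilingual"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    # Instantiate the custom pipeline directly with the loaded model and tokenizer.
    nel = NelPipeline(model=model, tokenizer=tokenizer)

    # The mention markers below are an assumed convention for flagging the entity span;
    # the model returns up to five "page_name >> language" candidates, which
    # postprocess() resolves into {"title", "qid", "url"} dicts.
    print(nel("Napoleon was exiled to [START] Elba [END] in 1814."))

Alternatively, the class could be registered with the Transformers PIPELINE_REGISTRY so it loads through pipeline(...), but that wiring is not part of this file.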