from transformers import Pipeline
import nltk
import requests
import torch

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

NEL_MODEL = "nel-mgenre-multilingual"


def get_wikipedia_page_props(input_str: str):
    """
    Retrieves the QID for a given Wikipedia page name from the specified language Wikipedia.
    If the request fails, it falls back to using the OpenRefine Wikidata API.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        str: The QID or "NIL" if the QID is not found.
    """
    try:
        # Split the mGENRE-style prediction into page name and language code
        page_name, language = input_str.split(" >> ")
        page_name = page_name.strip()
        language = language.strip()
    except ValueError:
        # Malformed input: no " >> " separator, so nothing to look up
        return "NIL", "None"

    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        # Attempt to fetch from Wikipedia API
        response = requests.get(wikipedia_url, params=wikipedia_params)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]

            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]

                if "wikibase_item" in page_props:
                    return page_props["wikibase_item"], language
                else:
                    return qid, language
            else:
                return qid, language
    except Exception as e:
        return qid, language


def get_wikipedia_title(qid, language="en"):
    """
    Resolve a Wikidata QID to the corresponding Wikipedia page title and URL
    for the given language. Returns ("NIL", "None") if no sitelink exists.
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    response = requests.get(url, params=params, timeout=10)
    data = response.json()

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"


class NelPipeline(Pipeline):

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]

        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # Extract the entity between [START] and [END]
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx  # left offset (start of the entity)
            rOffset = end_idx  # right offset (end of the entity)
        else:
            enclosed_entity = None
            lOffset = None
            rOffset = None

        # Generate predictions using the model
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Decode the predictions into readable text
        wikipedia_predictions = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )
        # Per-token transition scores (log-probabilities) of the generated sequence
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = transition_scores[0].sum()

        # Exponentiate the summed log-probabilities to obtain the sequence
        # probability, expressed as a percentage
        sequence_confidence = torch.exp(log_prob_sum)
        percentages = sequence_confidence.cpu().numpy() * 100.0

        # Return the predictions along with the extracted entity, lOffset, and rOffset
        return wikipedia_predictions, enclosed_entity, lOffset, rOffset, [percentages]

    def _forward(self, inputs):
        # Generation already runs in preprocess, so simply pass the outputs through
        return inputs

    def postprocess(self, outputs, **kwargs):
        """
        Turn the generation outputs into a list of linked-entity dictionaries,
        resolving each predicted Wikipedia page name to a Wikidata QID and a
        Wikipedia title/URL.

        :param outputs: tuple returned by preprocess:
            (wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages)
        :param kwargs: unused
        :return: list of dicts with surface form, QID, page name, URL, type,
            confidence and offsets
        """
        wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages = outputs
        results = []
        for idx, wikipedia_name in enumerate(wikipedia_predictions):
            # Resolve the predicted "page_name >> language" string to a QID
            qid, language = get_wikipedia_page_props(wikipedia_name)

            # Get Wikipedia title and URL
            wkpedia_pagename, url = get_wikipedia_title(qid, language)
            results.append(
                {
                    # "id": f"{lOffset}:{rOffset}:{enclosed_entity}:{NEL_MODEL}",
                    "surface": enclosed_entity,
                    "wkd_id": qid,
                    "wkpedia_pagename": wkpedia_pagename,
                    "wkpedia_url": url,
                    "type": "UNK",
                    "confidence_nel": round(percentages[idx], 2),
                    "lOffset": lOffset,
                    "rOffset": rOffset,
                }
            )

        return results
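

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the pipeline itself). It assumes an
# mGENRE-style seq2seq checkpoint such as "facebook/mgenre-wiki"; the checkpoint
# and any pipeline registration used in production may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    checkpoint = "facebook/mgenre-wiki"  # assumed checkpoint, for illustration only
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    nel = NelPipeline(model=model, tokenizer=tokenizer)

    # The entity to link is wrapped in [START] ... [END] markers.
    text = "The [START] Eiffel Tower [END] was completed in 1889."
    print(nel(text))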