parser / udpipe2 /udpipe2_dataset.py
anasampa2's picture
Upload 110 files
0e5da39 verified
raw
history blame
20.5 kB
# This file is part of UDPipe 2 <http://github.com/ufal/udpipe>.
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University in Prague, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import io
import pickle
import re
import sys
import numpy as np
# Version string of this module.
__version__ = "2.1.1-dev"
class UDPipe2Dataset:
    """In-memory dataset of tab-separated, CoNLL-U-style sentences.

    Every token line is split on tabs, the first column is dropped and the
    following ``FACTORS`` columns are stored as *factors*. For each factor the
    dataset keeps the original strings, a vocabulary and per-sentence id
    sequences; the FORMS factor additionally keeps character-level data and a
    leading ``<root>`` token. Lemmas are not stored as plain strings ids but as
    ids of *lemma rules* (see `_gen_lemma_rule`).
    """

    # Indices of the factors, i.e., of the columns following the first one.
    FORMS = 0
    LEMMAS = 1
    UPOS = 2
    XPOS = 3
    FEATS = 4
    HEAD = 5
    DEPREL = 6
    DEPS = 7
    MISC = 8
    # Number of factors read from every token line.
    FACTORS = 9
    # Positions of the variant ids and of the contextualized embeddings in the
    # word-level list returned by `next_batch` (appended after the factors).
    VARIANT = 9
    EMBEDDINGS = 10

    FACTORS_MAP = {"FORMS": FORMS, "LEMMAS": LEMMAS, "UPOS": UPOS, "XPOS": XPOS, "FEATS": FEATS,
                   "HEAD": HEAD, "DEPREL": DEPREL, "DEPS": DEPS, "MISC": MISC}

    # Lines stored verbatim as "extras": lines starting with `#`, or whose
    # first column starts with digits followed by `-` or by `.`.
    re_extras = re.compile(r"^#|^\d+-|^\d+\.")
    # A `# variant = NAME` comment selecting the variant of the following sentence.
    re_variant = re.compile(r"^#\s*variant\s*=\s*(\S+)")

    class _Factor:
        """Vocabulary and per-sentence data of a single factor (column)."""

        # Id of the `<root>` entry in the word, alphabet and charseq vocabularies.
        ROOT = 2

        def __init__(self, with_root, characters, train=None):
            """Create an empty factor.

            Args:
                with_root: whether every sentence starts with a `<root>` token.
                characters: whether character-level data are kept.
                train: optional factor whose word (and alphabet) vocabularies
                    are reused; the objects are shared, not copied.
            """
            self.words_map = train.words_map if train else {'<pad>': 0, '<unk>': 1, '<root>': 2}
            self.words = train.words if train else ['<pad>', '<unk>', '<root>']
            self.word_ids = []  # per-sentence sequences of word ids
            self.strings = []  # per-sentence lists of the original strings
            self.with_root = with_root
            self.characters = characters
            if characters:
                self.alphabet_map = train.alphabet_map if train else {'<pad>': 0, '<unk>': 1, '<root>': 2}
                self.alphabet = train.alphabet if train else ['<pad>', '<unk>', '<root>']
                # Character sequences are always local to this factor instance.
                self.charseqs_map = {'<pad>': 0, '<unk>': 1, '<root>': 2}
                self.charseqs = [[0], [1], [2]]
                self.charseq_ids = []

    def __init__(self, path=None, text=None, embeddings=None, train=None, shuffle_batches=True,
                 override_variant=None, max_sentence_len=None, max_sentences=None):
        """Load the dataset.

        Args:
            path: file to read (UTF-8); when None, `text` is parsed instead.
            text: the data as a string, used only when `path` is None.
            embeddings: either a list of per-sentence `np.ndarray`s, or an
                iterable of paths loadable by `np.load`, whose arrays are
                concatenated along axis 1. None (the default) means no
                contextualized embeddings.
            train: optional dataset whose mappings (vocabularies, lemma rule
                setting, variant map) are reused; unknown items then map to
                `<unk>` instead of extending the vocabularies.
            shuffle_batches: whether batches are drawn in random order.
            override_variant: if not None, the variant used for all sentences.
            max_sentence_len: if truthy, tokens beyond this count are skipped
                and loaded embeddings are truncated accordingly.
            max_sentences: if not None, reading stops at the first empty line
                after which at least this many sentences were read.
        """
        # A mutable `[]` default would be shared by (and stored in) every
        # instance created without embeddings; use None as the sentinel.
        if embeddings is None:
            embeddings = []

        # Create factors and other variables
        self._factors = [
            self._Factor(f == self.FORMS, f == self.FORMS, train._factors[f] if train else None)
            for f in range(self.FACTORS)
        ]
        self._extras = []
        form_dict = {}
        self._lr_allow_copy = train._lr_allow_copy if train else None
        lemma_dict_with_copy, lemma_dict_no_copy = {}, {}
        self._variant_map = train._variant_map if train else {}
        self._variants = []

        # Load contextualized embeddings
        if isinstance(embeddings, list) and all(isinstance(embedding, np.ndarray) for embedding in embeddings):
            self._embeddings = embeddings
        else:
            self._embeddings = []
            for embeddings_path in embeddings:
                with np.load(embeddings_path, allow_pickle=True) as embeddings_file:
                    # Count the arrays explicitly, so that an empty file is
                    # checked too instead of relying on a (possibly undefined
                    # or stale) loop variable.
                    loaded = 0
                    for i, (_, value) in enumerate(embeddings_file.items()):
                        if max_sentence_len: value = value[:max_sentence_len]
                        if i >= len(self._embeddings): self._embeddings.append(value)
                        else: self._embeddings[i] = np.concatenate([self._embeddings[i], value], axis=1)
                        loaded = i + 1
                    assert loaded == len(self._embeddings)
        self._embeddings_size = self._embeddings[0].shape[1] if self._embeddings else 0

        # Load the sentences
        with (open(path, "r", encoding="utf-8") if path is not None else io.StringIO(text)) as file:
            in_sentence = False
            variant = ""
            for line in file:
                line = line.rstrip("\r\n")

                if line:
                    if self.re_extras.match(line):
                        variant_match = self.re_variant.match(line)
                        if variant_match:
                            variant = variant_match.group(1)
                        # Extras are indexed by sentence and then by the number
                        # of tokens of that sentence preceding the extra line.
                        if in_sentence:
                            while len(self._extras) < len(self._factors[0].word_ids): self._extras.append([])
                            while len(self._extras[-1]) <= len(self._factors[0].word_ids[-1]) - self._factors[0].with_root:
                                self._extras[-1].append("")
                        else:
                            while len(self._extras) <= len(self._factors[0].word_ids): self._extras.append([])
                            if not self._extras[-1]: self._extras[-1].append("")
                        self._extras[-1][-1] += ("\n" if self._extras[-1][-1] else "") + line
                        continue

                    # Skip the tokens exceeding the maximum sentence length
                    if max_sentence_len and in_sentence and len(self._factors[0].strings[-1]) - self._factors[0].with_root >= max_sentence_len:
                        continue

                    columns = line.split("\t")[1:]
                    for f in range(self.FACTORS):
                        factor = self._factors[f]
                        if not in_sentence:
                            # Starting a new sentence; freeze the ids of the previous one.
                            if factor.word_ids: factor.word_ids[-1] = np.array(factor.word_ids[-1], np.int32)
                            factor.word_ids.append([])
                            factor.strings.append([])
                            if factor.characters: factor.charseq_ids.append([])
                            if factor.with_root:
                                factor.word_ids[-1].append(factor.ROOT)
                                factor.strings[-1].append(factor.words[factor.ROOT])
                                if factor.characters: factor.charseq_ids[-1].append(factor.ROOT)

                        word = columns[f]
                        factor.strings[-1].append(word)

                        # Preprocess word
                        if f == self.LEMMAS and self._lr_allow_copy is not None:
                            word = self._gen_lemma_rule(columns[self.FORMS], columns[self.LEMMAS], self._lr_allow_copy)

                        # Character-level information
                        if factor.characters:
                            if word not in factor.charseqs_map:
                                factor.charseqs_map[word] = len(factor.charseqs)
                                factor.charseqs.append([])
                                for c in word:
                                    if c not in factor.alphabet_map:
                                        if train:
                                            c = '<unk>'
                                        else:
                                            factor.alphabet_map[c] = len(factor.alphabet)
                                            factor.alphabet.append(c)
                                    factor.charseqs[-1].append(factor.alphabet_map[c])
                            factor.charseq_ids[-1].append(factor.charseqs_map[word])

                        # Word-level information
                        if f == self.HEAD:
                            factor.word_ids[-1].append(int(word) if word != "_" else -1)
                        elif f == self.FORMS and not train:
                            # Placeholder; the ids are assigned once the form counts are known.
                            factor.word_ids[-1].append(0)
                            form_dict[word] = form_dict.get(word, 0) + 1
                        elif f == self.LEMMAS and self._lr_allow_copy is None:
                            # Placeholder; the ids are assigned once `_lr_allow_copy` is decided.
                            factor.word_ids[-1].append(0)
                            lemma_dict_with_copy[self._gen_lemma_rule(columns[self.FORMS], word, True)] = 1
                            lemma_dict_no_copy[self._gen_lemma_rule(columns[self.FORMS], word, False)] = 1
                        else:
                            if word not in factor.words_map:
                                if train:
                                    word = '<unk>'
                                else:
                                    factor.words_map[word] = len(factor.words)
                                    factor.words.append(word)
                            factor.word_ids[-1].append(factor.words_map[word])
                    if not in_sentence:
                        if override_variant is not None: variant = override_variant
                        if (variant not in self._variant_map) and (not train):
                            self._variant_map[variant] = len(self._variant_map)
                        self._variants.append(self._variant_map.get(variant, 0))
                    in_sentence = True
                else:
                    in_sentence = False
                    if max_sentences is not None and len(self._factors[self.FORMS].word_ids) >= max_sentences:
                        break

        # Finalize forms if needed; forms seen fewer than twice become `<unk>`
        if not train:
            forms = self._factors[self.FORMS]
            for i in range(len(forms.word_ids)):
                for j in range(forms.with_root, len(forms.word_ids[i])):
                    word = "<unk>" if form_dict[forms.strings[i][j]] < 2 else forms.strings[i][j]
                    if word not in forms.words_map:
                        forms.words_map[word] = len(forms.words)
                        forms.words.append(word)
                    forms.word_ids[i][j] = forms.words_map[word]

        # Finalize lemmas if needed
        if self._lr_allow_copy is None:
            # Allow copying only when it produces strictly fewer distinct rules.
            self._lr_allow_copy = len(lemma_dict_with_copy) < len(lemma_dict_no_copy)
            lemmas = self._factors[self.LEMMAS]
            forms = self._factors[self.FORMS]
            for i in range(len(lemmas.word_ids)):
                for j in range(lemmas.with_root, len(lemmas.word_ids[i])):
                    word = self._gen_lemma_rule(forms.strings[i][j - lemmas.with_root + forms.with_root],
                                                lemmas.strings[i][j], self._lr_allow_copy)
                    if word not in lemmas.words_map:
                        lemmas.words_map[word] = len(lemmas.words)
                        lemmas.words.append(word)
                    lemmas.word_ids[i][j] = lemmas.words_map[word]

        # Compute sentence lengths (without the root token)
        forms = self._factors[self.FORMS]
        sentences = len(forms.word_ids)
        self._sentence_lens = np.zeros([sentences], np.int32)
        for i, sentence_word_ids in enumerate(forms.word_ids):
            self._sentence_lens[i] = len(sentence_word_ids) - forms.with_root

        self._shuffle_batches = shuffle_batches
        self._permutation = np.random.permutation(len(self._sentence_lens)) if self._shuffle_batches else np.arange(len(self._sentence_lens))

        # The embeddings must match the sentences one to one, token by token
        if self._embeddings:
            assert sentences == len(self._embeddings)
            for i in range(sentences):
                assert self._sentence_lens[i] == len(self._embeddings[i]), "{} {} {}".format(i, self._sentence_lens[i], len(self._embeddings[i]))

    @property
    def sentence_lens(self):
        """Array of sentence lengths, not counting the root token."""
        return self._sentence_lens

    @property
    def factors(self):
        """List of the `_Factor` objects, indexed by the factor constants."""
        return self._factors

    @property
    def variants(self):
        """Number of distinct variants in the variant map."""
        return len(self._variant_map)

    @property
    def embeddings_size(self):
        """Second dimension of the contextualized embeddings, 0 if there are none."""
        return self._embeddings_size

    def save_mappings(self, path):
        """Pickle the mappings of this dataset, without the sentences, to `path`.

        The saved object is a `UDPipe2Dataset` created without running
        `__init__`; it carries only `_lr_allow_copy`, `_variant_map`,
        `_embeddings_size` and factors sharing the vocabularies of this dataset.
        """
        mappings = UDPipe2Dataset.__new__(UDPipe2Dataset)
        for field in ["_lr_allow_copy", "_variant_map", "_embeddings_size"]:
            setattr(mappings, field, getattr(self, field))
        mappings._factors = [
            mappings._Factor(factor.with_root, factor.characters, factor)
            for factor in self._factors
        ]

        with open(path, "wb") as mappings_file:
            pickle.dump(mappings, mappings_file, protocol=3)

    @staticmethod
    def load_mappings(path):
        """Load mappings stored by `save_mappings`.

        SECURITY: unpickling can execute arbitrary code; only load files
        coming from a trusted source.
        """
        with open(path, "rb") as mappings_file:
            return pickle.load(mappings_file)

    def epoch_finished(self):
        """Return True and start a new epoch if all sentences were consumed."""
        if len(self._permutation) == 0:
            self._permutation = np.random.permutation(len(self._sentence_lens)) if self._shuffle_batches else np.arange(len(self._sentence_lens))
            return True
        return False

    def next_batch(self, batch_size, max_form_length=64):
        """Return the next batch of at most `batch_size` sentences.

        Should be called only while sentences remain in the current epoch
        (`np.max` fails on an empty batch).

        Args:
            batch_size: maximum number of sentences in the batch.
            max_form_length: character sequences are cut to this length.

        Returns:
            A tuple `(sentence_lens, word_ids, charseq_ids, charseqs, charseq_lens)`:
            `word_ids` contains a zero-padded id matrix for every factor,
            followed by the variant ids and, if embeddings are present, by
            the contextualized embeddings; the three `charseq*` lists are
            indexed by factor and contain empty lists for factors without
            character-level data.
        """
        batch_size = min(batch_size, len(self._permutation))
        batch_perm = self._permutation[:batch_size]
        self._permutation = self._permutation[batch_size:]

        # General data
        batch_sentence_lens = self._sentence_lens[batch_perm]
        max_sentence_len = np.max(batch_sentence_lens)

        # Word-level data
        batch_word_ids = []
        for factor in self._factors:
            batch_word_ids.append(np.zeros([batch_size, max_sentence_len + factor.with_root], np.int32))
            for i in range(batch_size):
                batch_word_ids[-1][i, 0:batch_sentence_lens[i] + factor.with_root] = factor.word_ids[batch_perm[i]]

        # Variants
        batch_word_ids.append(np.zeros([batch_size], np.int32))
        for i in range(batch_size):
            batch_word_ids[-1][i] = self._variants[batch_perm[i]]

        # Contextualized embeddings; the root position is left zeroed
        if self._embeddings:
            forms = self._factors[self.FORMS]
            batch_word_ids.append(np.zeros([batch_size, max_sentence_len + forms.with_root, self.embeddings_size], np.float16))
            for i in range(batch_size):
                batch_word_ids[-1][i, forms.with_root:forms.with_root + len(self._embeddings[batch_perm[i]])] = \
                    self._embeddings[batch_perm[i]]

        # Character-level data
        batch_charseq_ids, batch_charseqs, batch_charseq_lens = [], [], []
        for factor in self._factors:
            if not factor.characters:
                batch_charseq_ids.append([])
                batch_charseqs.append([])
                batch_charseq_lens.append([])
                continue

            batch_charseq_ids.append(np.zeros([batch_size, max_sentence_len + factor.with_root], np.int32))
            # Renumber the character sequences so that only those present in
            # the batch are returned.
            charseqs_map = {}
            charseqs = []
            for i in range(batch_size):
                for j, charseq_id in enumerate(factor.charseq_ids[batch_perm[i]]):
                    if charseq_id not in charseqs_map:
                        charseqs_map[charseq_id] = len(charseqs)
                        charseqs.append(factor.charseqs[charseq_id][:max_form_length])
                    batch_charseq_ids[-1][i, j] = charseqs_map[charseq_id]

            batch_charseq_lens.append(np.array([len(charseq) for charseq in charseqs], np.int32))
            batch_charseqs.append(np.zeros([len(charseqs), np.max(batch_charseq_lens[-1])], np.int32))
            for i, charseq in enumerate(charseqs):
                batch_charseqs[-1][i, 0:len(charseq)] = charseq

        return self._sentence_lens[batch_perm], batch_word_ids, batch_charseq_ids, batch_charseqs, batch_charseq_lens

    def write_sentence(self, output, index, overrides):
        """Write sentence `index` to `output`, followed by an empty line.

        Args:
            output: a file-like object to print to.
            index: index of the sentence.
            overrides: None, or a sequence indexed by factor; every non-None
                item is a per-token sequence (including the root position for
                factors with a root) of values replacing the stored strings.
                HEAD values are written as numbers (`_` if negative), string
                LEMMAS/XPOS values are used directly and other values are
                looked up in the factor vocabulary. An overridden lemma is
                treated as a lemma rule and applied to the form.
        """
        for i in range(self._sentence_lens[index] + 1):
            # Start by writing extras
            if index < len(self._extras) and i < len(self._extras[index]) and self._extras[index][i]:
                print(self._extras[index][i], file=output)
            if i == self._sentence_lens[index]: break

            fields = [str(i + 1)]
            for f in range(self.FACTORS):
                factor = self._factors[f]
                offset = i + factor.with_root

                field = factor.strings[index][offset]

                # Overrides
                if overrides is not None and f < len(overrides) and overrides[f] is not None:
                    override = overrides[f][offset]
                    if f == self.HEAD:
                        field = str(override) if override >= 0 else "_"
                    elif (f == self.LEMMAS or f == self.XPOS) and isinstance(override, str):
                        field = override
                    else:
                        field = factor.words[override]
                    if f == self.LEMMAS:
                        # Here `fields[-1]` is the form, written just before the lemma.
                        try:
                            field = self._apply_lemma_rule(fields[-1], field)
                        except Exception:
                            print("Applying lemma rule failed for form '{}' and rule '{}', using the form as lemma".format(
                                fields[-1], field), file=sys.stderr)
                            field = fields[-1]
                        # Do not generate empty lemmas
                        field = field or fields[-1]

                fields.append(field)

            print("\t".join(fields), file=output)
        print(file=output)

    @staticmethod
    def _min_edit_script(source, target, allow_copy):
        """Return an edit script transforming `source` to `target`.

        The script consists of `→` (copy a character, only if `allow_copy`),
        `-` (delete a character) and `+c` (insert character `c`); copies are
        free and every deletion or insertion costs one.
        """
        # a[i][j] is the (cost, script) pair for source[:i] and target[:j].
        a = [[(len(source) + len(target) + 1, None)] * (len(target) + 1) for _ in range(len(source) + 1)]
        for i in range(0, len(source) + 1):
            for j in range(0, len(target) + 1):
                if i == 0 and j == 0:
                    a[i][j] = (0, "")
                else:
                    if allow_copy and i and j and source[i - 1] == target[j - 1] and a[i-1][j-1][0] < a[i][j][0]:
                        a[i][j] = (a[i-1][j-1][0], a[i-1][j-1][1] + "→")
                    if i and a[i-1][j][0] < a[i][j][0]:
                        a[i][j] = (a[i-1][j][0] + 1, a[i-1][j][1] + "-")
                    if j and a[i][j-1][0] < a[i][j][0]:
                        a[i][j] = (a[i][j-1][0] + 1, a[i][j-1][1] + "+" + target[j - 1])
        return a[-1][-1][1]

    @staticmethod
    def _gen_lemma_rule(form, lemma, allow_copy):
        """Generate a lemma rule producing `lemma` from `form`.

        The rule has the form `CASING;EDIT`. CASING is a `¦`-separated list of
        `↑N`/`↓N` items, each switching to upper/lower case from offset `N`
        (negative offsets count from the end of the lemma). EDIT is either
        `a` followed by the lowercased lemma, if the lowercased form and lemma
        share no character, or `dPREFIX¦SUFFIX` with the edit scripts (see
        `_min_edit_script`) of the parts before and after their longest
        common substring.
        """
        form = form.lower()

        previous_case = -1
        lemma_casing = ""
        for i, c in enumerate(lemma):
            case = "↑" if c.lower() != c else "↓"
            if case != previous_case:
                lemma_casing += "{}{}{}".format("¦" if lemma_casing else "", case, i if i <= len(lemma) // 2 else i - len(lemma))
            previous_case = case
        lemma = lemma.lower()

        # Find the longest common substring of the form and the lemma
        best, best_form, best_lemma = 0, 0, 0
        for l in range(len(lemma)):
            for f in range(len(form)):
                cpl = 0
                while f + cpl < len(form) and l + cpl < len(lemma) and form[f + cpl] == lemma[l + cpl]: cpl += 1
                if cpl > best:
                    best = cpl
                    best_form = f
                    best_lemma = l

        rule = lemma_casing + ";"
        if not best:
            rule += "a" + lemma
        else:
            rule += "d{}¦{}".format(
                UDPipe2Dataset._min_edit_script(form[:best_form], lemma[:best_lemma], allow_copy),
                UDPipe2Dataset._min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy),
            )
        return rule

    @staticmethod
    def _apply_lemma_rule(form, lemma_rule):
        """Apply a rule generated by `_gen_lemma_rule` to `form` and return the lemma.

        If the edit scripts cannot be applied to the form, the lowercased form
        is used as the lemma (before the casing is applied). A rule without
        `;`, or without exactly two edit scripts, raises an exception.
        """
        casing, rule = lemma_rule.split(";", 1)
        if rule.startswith("a"):
            lemma = rule[1:]
        else:
            form = form.lower()
            rules, rule_sources = rule[1:].split("¦"), []
            assert len(rules) == 2
            # Compute how many characters of the form each script consumes
            for part in rules:
                source, i = 0, 0
                while i < len(part):
                    if part[i] == "→" or part[i] == "-":
                        source += 1
                    else:
                        assert part[i] == "+"
                        i += 1
                    i += 1
                rule_sources.append(source)

            try:
                lemma = ""
                for i in range(2):
                    # The first script applies to the start of the form, the
                    # second one to its end.
                    j, offset = 0, (0 if i == 0 else len(form) - rule_sources[1])
                    while j < len(rules[i]):
                        if rules[i][j] == "→":
                            lemma += form[offset]
                            offset += 1
                        elif rules[i][j] == "-":
                            offset += 1
                        else:
                            assert(rules[i][j] == "+")
                            lemma += rules[i][j + 1]
                            j += 1
                        j += 1
                    if i == 0:
                        # Copy the part of the form between the two scripts
                        lemma += form[rule_sources[0] : len(form) - rule_sources[1]]
            except Exception:
                lemma = form

        for rule in casing.split("¦"):
            if rule == "↓0": continue # The lemma is lowercased initially
            if not rule: continue # Empty lemma might generate empty casing rule
            case, offset = rule[0], int(rule[1:])
            lemma = lemma[:offset] + (lemma[offset:].upper() if case == "↑" else lemma[offset:].lower())

        return lemma