|
|
|
|
|
|
|
|
|
|
|
import os |
|
from collections import Counter |
|
|
|
import torch |
|
from fairseq.file_io import PathManager |
|
from fairseq.tokenizer import tokenize_line |
|
from typing import List, Dict |
|
|
|
|
|
def safe_readline(f):
    """Read one line from *f*, recovering if the stream is positioned
    inside a multi-byte character.

    If ``readline`` raises :class:`UnicodeDecodeError` (i.e. the caller
    seeked into the middle of an encoded code point), back the cursor up
    one byte at a time until a clean decode succeeds, then return that
    line.
    """
    position = f.tell()
    line = None
    while line is None:
        try:
            line = f.readline()
        except UnicodeDecodeError:
            # Mid-character seek: step back one byte and retry.
            position -= 1
            f.seek(position)
    return line
|
|
|
|
|
class Binarizer:
    """Static helpers that turn a text file into a stream of token-id
    tensors (or alignment structures), feeding each item to a consumer
    callback and returning summary statistics.

    All methods operate on a byte range ``[offset, end)`` of the file so
    that binarization can be parallelized across chunks produced by
    :meth:`find_offsets`.
    """

    @staticmethod
    def binarize(
        filename,
        dict,  # fairseq Dictionary; name shadows the builtin but is kept for API compatibility
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ) -> Dict[str, int]:
        """Encode each line of ``filename`` into a ``torch.IntTensor`` of
        token ids and pass it to ``consumer``.

        Args:
            filename: path to the text file (resolved via ``PathManager``).
            dict: dictionary used to map tokens to ids (``encode_line``,
                ``eos``, ``unk_index``, ``unk_word`` are used).
            consumer: callable invoked with each encoded tensor.
            tokenize: function splitting a line into tokens (ignored when
                ``already_numberized`` is True).
            append_eos: append the dictionary's EOS id to every line.
            reverse_order: reverse token order before (optionally) adding EOS.
            offset: byte offset at which to start reading.
            end: byte offset at which to stop (``-1`` means read to EOF).
            already_numberized: if True, lines are whitespace-separated
                integer ids and are parsed directly instead of tokenized.

        Returns:
            dict with ``nseq`` (lines processed), ``ntok`` (tokens emitted),
            ``nunk`` (tokens mapped to unk), and ``replaced`` (Counter of
            the unk-mapped words).
        """
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            # Record every word that got mapped to <unk> (but is not the
            # literal unk token itself) so callers can report OOV stats.
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            f.seek(offset)
            # Use safe_readline: ``offset`` may land inside a multi-byte
            # UTF-8 character, which would make a plain readline() raise.
            line = safe_readline(f)
            while line:
                # NOTE(review): on text streams f.tell() returns an opaque
                # cookie that can occasionally jump to a very large value;
                # the ``< end + 2 ** 32`` guard presumably filters out those
                # spurious readings so we don't stop a chunk early — confirm
                # against upstream history before simplifying.
                if end > 0 and f.tell() > end and f.tell() < end + 2 ** 32:
                    break
                if already_numberized:
                    # Lines are pre-encoded integer ids; parse them directly.
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }

    @staticmethod
    def binarize_alignments(
        filename, alignment_parser, consumer, offset=0, end=-1
    ) -> Dict[str, int]:
        """Parse each line of an alignment file with ``alignment_parser``
        and pass the result to ``consumer``.

        Args:
            filename: path to the alignment file.
            alignment_parser: callable mapping a text line to the parsed
                alignment object handed to ``consumer``.
            consumer: callable invoked with each parsed alignment.
            offset: byte offset at which to start reading.
            end: byte offset at which to stop (``-1`` means read to EOF).

        Returns:
            dict with ``nseq``, the number of lines processed.
        """
        nseq = 0

        # encoding="utf-8" keeps this consistent with binarize() and
        # find_offsets(): offsets computed there are only valid if the
        # same decoding is used here (the platform default may differ).
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            f.seek(offset)
            line = safe_readline(f)
            while line:
                if end > 0 and f.tell() > end:
                    break
                ids = alignment_parser(line)
                nseq += 1
                consumer(ids)
                line = f.readline()
        return {"nseq": nseq}

    @staticmethod
    def find_offsets(filename, num_chunks) -> List[int]:
        """Split ``filename`` into ``num_chunks`` roughly equal, line-aligned
        byte ranges.

        Returns:
            list of ``num_chunks + 1`` byte offsets; chunk ``i`` spans
            ``[offsets[i], offsets[i + 1])`` (the final entry is 0,
            conventionally treated as "read to EOF" by the binarizers).
        """
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_chunks
            offsets = [0 for _ in range(num_chunks + 1)]
            for i in range(1, num_chunks):
                # Seek to the nominal boundary, then advance past the
                # (possibly partial) current line so chunks start on a
                # clean line boundary.
                f.seek(chunk_size * i)
                safe_readline(f)
                offsets[i] = f.tell()
            return offsets
|
|