|
from __future__ import annotations |
|
from .mecab_tokenizer import MeCabTokenizer |
|
import os |
|
|
|
# Maps the tokenizer's vocabulary-file role to the file name used on disk.
VOCAB_FILES_NAMES = dict(vocab_file="vocab.txt")
|
|
|
|
|
def save_stoi(stoi: dict[str, int], vocab_file: str):
    """Write a token -> index vocabulary to ``vocab_file``, one token per line.

    Tokens are written in ascending index order, so line ``i`` (0-based) holds
    the token whose index is ``i``.

    Args:
        stoi: Mapping from token to index. The indices must be exactly
            ``0, 1, ..., len(stoi) - 1``.
        vocab_file: Destination path. An existing file is overwritten.

    Raises:
        ValueError: If the indices are not consecutive starting from 0. The
            check runs before the file is opened, so a corrupted vocabulary
            never truncates or partially writes ``vocab_file``.
    """
    ordered = sorted(stoi.items(), key=lambda kv: kv[1])
    # Validate everything up front; the original checked while writing and
    # left a truncated/partial file behind on failure.
    for expected, (_, token_index) in enumerate(ordered):
        if token_index != expected:
            raise ValueError(
                f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                " Please check that the vocabulary is not corrupted!")
    with open(vocab_file, "w", encoding="utf-8") as writer:
        writer.writelines(token + "\n" for token, _ in ordered)
|
|
|
|
|
def load_stoi(vocab_file: str) -> dict[str, int]:
    """Read a vocabulary file into a token -> index mapping.

    Each line holds one token and its zero-based line number becomes its
    index. Only the trailing ``"\\n"`` is stripped, so any other whitespace in
    a token is kept. If a token occurs on several lines, the index of its last
    occurrence wins.
    """
    with open(vocab_file, "r", encoding="utf-8") as handle:
        return {line.rstrip("\n"): position for position, line in enumerate(handle)}
|
|
|
|
|
class FastTextJpTokenizer(MeCabTokenizer):
    """Tokenizer mapping MeCab-produced tokens to ids via a plain-text vocabulary.

    Tokenization itself is delegated to the ``MeCabTokenizer`` base class
    (which receives ``hinshi`` and ``mecab_dicdir``). This class adds the
    token <-> id lookup, loaded from a one-token-per-line vocabulary file, and
    saves that vocabulary back to disk.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self,
                 vocab_file: str,
                 hinshi: list[str] | None = None,
                 mecab_dicdir: str | None = None,
                 **kwargs):
        """Load the vocabulary and initialise the MeCab tokenizer.

        Args:
            vocab_file (str): Path to the vocabulary file. It holds one token
                per line, and the 0-based line number is the token id.
            hinshi (list[str] | None, optional): Parts of speech to extract.
            mecab_dicdir (str | None, optional): Directory containing ``dicrc``.
            **kwargs: Forwarded to ``MeCabTokenizer.__init__``.

        Raises:
            ValueError: If ``vocab_file`` is not an existing file.
        """
        # Validate first so a bad path fails before MeCab is initialised.
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a pretrained"
                f" model use `tokenizer = {type(self).__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # Load the vocabulary BEFORE the base initialiser so it is available if
        # the base class queries vocab_size / get_vocab() during __init__
        # (recent transformers versions do; confirm against the pinned version).
        self.stoi = load_stoi(vocab_file)
        self.itos = {index: token for token, index in self.stoi.items()}
        self.v_size = len(self.stoi)

        super().__init__(hinshi, mecab_dicdir, **kwargs)

    @property
    def vocab_size(self) -> int:
        """`int`: Size of the base vocabulary (without the added tokens)."""
        return self.v_size

    def get_vocab(self) -> dict[str, int]:
        """Return the token -> id mapping, including tokens added after loading.

        Added tokens come from ``added_tokens_encoder``, provided by the
        transformers tokenizer base class. They take precedence on key clashes.
        """
        return {**self.stoi, **self.added_tokens_encoder}

    def _convert_token_to_id(self, token: str) -> int:
        """Map ``token`` to its vocabulary id.

        An out-of-vocabulary token maps to the id of ``unk_token`` when one is
        configured and present in the vocabulary. Otherwise the ``KeyError`` is
        re-raised, as before.
        """
        try:
            return self.stoi[token]
        except KeyError:
            unk_token = self.unk_token
            if unk_token is not None and unk_token in self.stoi:
                return self.stoi[unk_token]
            raise

    def _convert_id_to_token(self, index: int) -> str:
        """Map a vocabulary id back to its token.

        An unknown id maps to ``unk_token`` when one is configured. Otherwise
        the ``KeyError`` is re-raised, as before.
        """
        try:
            return self.itos[index]
        except KeyError:
            unk_token = self.unk_token
            if unk_token is not None:
                return unk_token
            raise

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: str | None = None) -> tuple[str]:
        """Write the vocabulary to disk, one token per line in id order.

        Args:
            save_directory: Either an existing directory, which receives
                ``[<prefix>-]vocab.txt``, or a file path to write to.
            filename_prefix: Optional prefix, prepended to the file name as
                ``"<prefix>-"``.

        Returns:
            A one-element tuple holding the path that was written.

        Raises:
            ValueError: Propagated from ``save_stoi`` if the vocabulary ids are
                not consecutive.
        """
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        elif prefix:
            # save_directory is a file path: apply the prefix to the file name
            # only, keeping its parent directory (previously the prefix was
            # glued onto the front of the whole path).
            head, tail = os.path.split(save_directory)
            vocab_file = os.path.join(head, prefix + tail)
        else:
            vocab_file = save_directory
        save_stoi(self.stoi, vocab_file)
        return (vocab_file,)
|
|
|
|
|
# Register this tokenizer with the transformers "AutoTokenizer" auto class
# (custom-code mechanism) so AutoTokenizer can resolve it when it is saved/loaded.
FastTextJpTokenizer.register_for_auto_class("AutoTokenizer")
|
|