File size: 4,873 Bytes
76b4794
 
 
 
 
 
 
 
6ee9897
 
 
 
 
 
 
 
 
 
76b4794
 
 
 
 
 
 
 
 
 
 
 
6ee9897
 
 
 
 
 
 
 
 
76b4794
6ee9897
76b4794
 
6ee9897
 
76b4794
 
 
 
 
 
 
6ee9897
 
 
3ba50ba
 
6ee9897
76b4794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ee9897
 
 
 
 
76b4794
6ee9897
76b4794
 
6ee9897
 
 
 
 
 
 
 
 
8acc99c
 
 
 
76b4794
 
6ee9897
 
 
 
 
 
 
 
 
76b4794
 
 
 
 
6ee9897
 
 
 
 
 
 
 
 
76b4794
 
 
 
6ee9897
76b4794
 
 
 
 
 
 
6ee9897
 
76b4794
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from __future__ import annotations
from .mecab_tokenizer import MeCabTokenizer
import os

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}


def save_stoi(stoi: dict[str, int], vocab_file: str):
    """単語IDの辞書を配列にしてvocab_fileに保存します。

    Args:
        stoi (dict[str, int]): 単語IDのマッピング
        vocab_file (str): 保存するパス

    Raises:
        ValueError: IDが途切れているとエラーを起こします。
    """

    with open(vocab_file, "w", encoding="utf-8") as writer:
        index = 0
        for token, token_index in sorted(stoi.items(), key=lambda kv: kv[1]):
            if index != token_index:
                raise ValueError(
                    "Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                    " Please check that the vocabulary is not corrupted!")
            writer.write(token + "\n")
            index += 1


def load_stoi(vocab_file: str) -> dict[str, int]:
    """ファイルから単語IDの辞書をロードします。

    Args:
        vocab_file (str): ファイルのパス

    Returns:
        dict[str, int]: 単語IDのマッピング
    """

    stoi: dict[str, int] = {}
    # ファイルから読み出し
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()

    # 単語IDのマッピングを生成します。
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        stoi[token] = index
    return stoi


class FastTextJpTokenizer(MeCabTokenizer):

    # Configが認識するのに必要です。
    # https://huggingface.co/docs/transformers/custom_models#writing-a-custom-configuration
    model_type = "fasttext_jp"

    # vocab.txtを認識するのにおそらく必要。
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self,
                 vocab_file: str,
                 hinshi: list[str] | None = None,
                 mecab_dicdir: str | None = None,
                 **kwargs):
        """初期化処理

        Args:
            vocab_file (str): vocab_fileのpath
            hinshi (list[str] | None, optional): 抽出する品詞
            mecab_dicdir (str | None, optional): dicrcのあるディレクトリ
        """
        super().__init__(hinshi, mecab_dicdir, **kwargs)

        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.stoi = load_stoi(vocab_file)
        self.itos = dict([(ids, tok) for tok, ids in self.stoi.items()])

    @property
    def vocab_size(self) -> int:
        """ボキャブラリのサイズ
        ※PreTrainedTokenizerで実装すべき必須の関数。

        Returns:
            int: ボキャブラリのサイズ
        """
        return len(self.stoi)

    def _convert_token_to_id(self, token: str) -> int:
        """単語からID
        ※PreTrainedTokenizerで実装すべき必須の関数。
        
        Args:
            token (str): 単語

        Returns:
            int: ID
        """
        id = self.stoi.get(token)
        if id is not None:
            return id
        return self.stoi[self.unk_token]

    def _convert_id_to_token(self, index: int) -> str:
        """IDから単語
        ※PreTrainedTokenizerで実装すべき必須の関数。
        
        Args:
            index (int): ID

        Returns:
            str: 単語
        """
        return self.itos[index]

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: str | None = None) -> tuple[str]:
        """ボキャブラリの保存

        Args:
            save_directory (str): 保存するディレクトリ。ファイル名はvocab.txtに固定
            filename_prefix (str | None, optional): ファイルのprefix

        Returns:
            tuple[str]: ファイル名を返す。
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory,
                (filename_prefix + "-" if filename_prefix else "") +
                VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = (filename_prefix +
                          "-" if filename_prefix else "") + save_directory
        save_stoi(self.stoi, vocab_file)
        return (vocab_file, )


# AutoTokenizerに登録が必要だが、いろいろやり方が変わっているようで定まっていない。(2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextJpTokenizer.register_for_auto_class("AutoTokenizer")