File size: 3,903 Bytes
a0806ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-

from __future__ import annotations

import argparse
import logging
from itertools import chain
from typing import Any, Dict, List, Optional

from datasets import load_dataset
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: AutoTokenizer,
    context_length: int
) -> Dict[str, List[List[int]]]:
    """
    Tokenize a batch of texts and split the concatenated ids into fixed-size chunks.

    Every text in the batch is tokenized, the per-text id lists are concatenated in
    order (no separator is inserted beyond whatever the tokenizer itself emits), and
    the resulting stream is cut into consecutive, non-overlapping chunks of exactly
    ``context_length`` ids. Trailing ids that do not fill a whole chunk are dropped.
    When used with ``datasets.map(batched=True)`` this happens per batch, so up to
    ``context_length - 1`` ids can be discarded from each batch.

    Args:
        examples:
            Batch dictionary; only its ``'text'`` entry (a list of strings) is read.
        tokenizer:
            Initialized tokenizer. It is called on the list of texts and its result
            must provide an ``'input_ids'`` entry holding one id list per text.
        context_length:
            Number of ids in each chunk. Must be a positive integer.

    Returns:
        Dictionary with a single ``'input_ids'`` key mapping to a list of chunks,
        each a list of exactly ``context_length`` ids. The list is empty when the
        batch yields fewer than ``context_length`` ids in total.

    Raises:
        ValueError: If ``context_length`` is not positive.
    """
    if context_length <= 0:
        raise ValueError(f'context_length must be a positive integer, got {context_length}')
    # Flatten the per-text id lists into one stream, preserving text order.
    input_ids = list(chain.from_iterable(tokenizer(examples['text'])['input_ids']))
    # Round down to a multiple of context_length so every emitted chunk is full-size;
    # the last chunk smaller than context_length is discarded.
    total_length = (len(input_ids) // context_length) * context_length
    return {
        'input_ids': [input_ids[i:i + context_length] for i in range(0, total_length, context_length)]
    }


def preprocess(
    dataset: str,
    name: Optional[str] = None,
    split: str = 'train',
    output: str = 'data',
    model: str = 'mistralai/Mistral-7B-v0.1',
    num_proc: int = 64,
    context_length: int = 8192
) -> None:
    """
    Load a dataset split, tokenize it into fixed-length chunks, and save it to disk.

    The processed split is written with ``save_to_disk`` to
    ``{output}/{dataset}/{name}/{split}``, or ``{output}/{dataset}/{split}`` when
    ``name`` is None. All original columns are removed, so the saved dataset
    contains only the ``input_ids`` column produced by ``tokenize``.

    Args:
        dataset:
            Path or name of the dataset, passed to ``load_dataset``.
        name:
            Name of the dataset configuration.
        split:
            Dataset split to process.
        output:
            Output root directory.
        model:
            Model name or path whose tokenizer is used.
        num_proc:
            Number of processes for ``map`` and ``save_to_disk``.
        context_length:
            Number of tokens per chunk. Must be a positive integer.

    Raises:
        ValueError: If ``context_length`` is not positive. This is checked before the
            tokenizer or dataset is loaded, so bad arguments fail fast.
    """
    if context_length <= 0:
        raise ValueError(f'context_length must be a positive integer, got {context_length}')

    tokenized_path = f'{output}/{dataset}/{name}/{split}' if name is not None else f'{output}/{dataset}/{split}'

    logging.info('Initializing tokenizer of %s', model)
    # NOTE(review): trust_remote_code=True executes code shipped with the model repo;
    # only point --model at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    logging.info('Tokenizer initialized: %s', tokenizer)

    logging.info('Loading dataset: %s', dataset)
    ds = load_dataset(dataset, name=name, split=split)

    logging.info('Tokenizing and processing dataset')
    ds = ds.map(
        lambda examples: tokenize(examples, tokenizer, context_length),
        batched=True,
        # Drop every source column: a batch yields a different number of chunks than
        # it had documents, so the original columns cannot be kept alongside them.
        # column_names (unlike peeking at the first row) also works on an empty split.
        remove_columns=ds.column_names,
        num_proc=num_proc,
        desc="Running tokenizer on dataset"
    )

    logging.info('Saving processed dataset to %s', tokenized_path)
    ds.save_to_disk(tokenized_path, num_proc=num_proc)


if __name__ == "__main__":
    # Command-line entry point. Each flag's destination (e.g. --num_proc -> num_proc)
    # is exactly a keyword parameter of preprocess(), so the parsed namespace is
    # forwarded unchanged.
    cli = argparse.ArgumentParser(description="Preprocess and tokenize dataset")
    cli_options = (
        ("--dataset", {"default": "HuggingFaceFW/fineweb-edu", "help": "Path or name of the dataset"}),
        ("--name", {"default": None, "help": "Name of the dataset configuration"}),
        ("--split", {"default": "train", "help": "Dataset split to process"}),
        ("--output", {"default": "data", "help": "Output directory"}),
        ("--model", {"default": "mistralai/Mistral-7B-v0.1", "help": "Model name for tokenizer"}),
        ("--num_proc", {"type": int, "default": 64, "help": "Number of processes for parallel processing"}),
        ("--context_length", {"type": int, "default": 8192, "help": "Context length for tokenization"}),
    )
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    preprocess(**vars(cli.parse_args()))