|
|
|
|
|
from __future__ import annotations |
|
|
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
|
|
import transformers |
|
from transformers import HfArgumentParser, TrainingArguments |
|
|
|
from flame.logging import get_logger |
|
|
|
# Module-level logger shared by this file (project-local flame.logging helper).
logger = get_logger(__name__)
|
|
|
|
|
@dataclass
# NOTE(review): this class intentionally shadows the imported
# ``transformers.TrainingArguments`` it inherits from; the base name is
# resolved at class-definition time, so the subclassing works, but consider
# renaming (e.g. ``FlameTrainingArguments``) to avoid reader confusion.
class TrainingArguments(TrainingArguments):
    """Training arguments extending ``transformers.TrainingArguments`` with
    model, tokenizer, and dataset options used by the flame training CLI."""

    # Fixed: annotated ``str`` but defaulted to ``None`` — now Optional[str],
    # consistent with ``dataset``/``dataset_name`` below.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="mistralai/Mistral-7B-v0.1",
        metadata={"help": "Name of the tokenizer to use."}
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    # Fixed: annotated ``str`` but defaulted to ``None`` — now Optional[str].
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )
|
|
|
|
|
def get_train_args():
    """Parse command-line arguments into a ``TrainingArguments`` instance.

    Returns:
        The populated ``TrainingArguments`` dataclass.

    Raises:
        ValueError: if any command-line argument is not recognized by the
            ``HfArgumentParser`` (e.g. a deprecated or misspelled flag).
    """
    parser = HfArgumentParser(TrainingArguments)
    args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    # Fail fast on unrecognized flags rather than silently ignoring them,
    # showing the full help text so the user can find the correct spelling.
    if unknown_args:
        print(parser.format_help())
        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

    # Configure the transformers logging stack only on processes that
    # should log (per TrainingArguments.should_log, typically rank 0).
    if args.should_log:
        transformers.utils.logging.set_verbosity(args.get_process_log_level())
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # Seed all RNGs for reproducible runs.
    transformers.set_seed(args.seed)
    return args
|
|