Spaces:
Runtime error
Runtime error
from typing import List, Optional | |
# String identifiers for the configurable settings of a fine-tuning run.
# Many values coincide with keyword names used by Hugging Face libraries
# (e.g. "load_in_4bit", "num_train_epochs", "optim", "adam_beta1");
# NOTE(review): presumably so collected values can be forwarded as keyword
# arguments — confirm against the callers before relying on that.

# --- Model selection and 4-bit quantization ---
MODEL_SELECTION_ID: str = "model_selection"
MODEL_VERSION_SELECTION_ID: str = "model_version_selection"
LOAD_IN_4_BIT_ID: str = "load_in_4bit"
BNB_4BIT_QUANT_TYPE: str = "bnb_4bit_quant_type"
BNB_4BIT_COMPUTE_DTYPE: str = "bnb_4bit_compute_dtype"
BNB_4BIT_USE_DOUBLE_QUANT: str = "bnb_4bit_use_double_quant"

# --- Dataset ---
DATASET_SELECTION_ID: str = "dataset_selection"
DATASET_SHUFFLING_SEED: str = "dataset_seed"

# --- Attention and padding ---
FLASH_ATTENTION_ID: str = "flash_attention"
PAD_SIDE_ID: str = "pad_side"
PAD_VALUE_ID: str = "pad_value"

# --- LoRA ---
LORA_R_ID: str = "lora_r"
LORA_ALPHA_ID: str = "lora_alpha"
LORA_DROPOUT_ID: str = "lora_dropout"
LORA_BIAS_ID: str = "lora_bias"

# --- Training schedule ---
NUM_TRAIN_EPOCHS_ID: str = "num_train_epochs"
# NOTE(review): unlike its neighbours this value carries an "_id" suffix
# ("max_steps_id", not "max_steps"). If these strings are ever used as
# keyword names, this one will not match — confirm with callers.
MAX_STEPS_ID: str = "max_steps_id"
LOGGING_STEPS_ID: str = "logging_steps"
PER_DEVICE_TRAIN_BATCH_SIZE: str = "per_device_train_batch_size"
SAVE_STRATEGY_ID: str = "save_strategy"
GRADIENT_ACCUMULATION_STEPS_ID: str = "gradient_accumulation_steps"
GRADIENT_CHECKPOINTING_ID: str = "gradient_checkpointing"
LEARNING_RATE_ID: str = "learning_rate"
MAX_GRAD_NORM_ID: str = "max_grad_norm"
WARMUP_RATIO_ID: str = "warmup_ratio"
LR_SCHEDULER_TYPE_ID: str = "lr_scheduler_type"

# --- Output and publishing ---
OUTPUT_DIR_ID: str = "output_dir"
PUSH_TO_HUB_ID: str = "push_to_hub"
REPOSITORY_NAME_ID: str = "repo_id"
REPORT_TO_ID: str = "report_to"
README_ID: str = "readme"

# --- Sequence handling ---
MAX_SEQ_LENGTH_ID: str = "max_seq_length"
PACKING_ID: str = "packing"

# --- Optimizer ---
OPTIMIZER_ID: str = "optim"
BETA1_ID: str = "adam_beta1"
BETA2_ID: str = "adam_beta2"
EPSILON_ID: str = "adam_epsilon"
WEIGHT_DECAY_ID: str = "weight_decay"
class FTDataSet:
    """Reference to a fine-tuning dataset: a dataset path plus an optional split.

    Args:
        path: Dataset identifier, e.g. a Hub repo id such as
            "HuggingFaceH4/ultrachat_200k".
        dataset_split: Name of the split to use (e.g. "train"), or None when
            no split is specified.
    """

    def __init__(self, path: str, dataset_split: Optional[str] = None):
        self.path = path
        self.dataset_split = dataset_split

    def __str__(self) -> str:
        # Only the path is used as the human-readable label.
        return self.path

    def __repr__(self) -> str:
        # Unambiguous form for debugging/logging; includes the split.
        return (
            f"{type(self).__name__}(path={self.path!r}, "
            f"dataset_split={self.dataset_split!r})"
        )
# Datasets available for fine-tuning, each paired with the split to load.
deita_dataset = FTDataSet(path="HuggingFaceH4/deita-10k-v0-sft", dataset_split="train_sft")
dolly = FTDataSet(path="philschmid/dolly-15k-oai-style", dataset_split="train")
ultrachat_200k = FTDataSet(path="HuggingFaceH4/ultrachat_200k", dataset_split="train_sft")

ft_datasets: List[FTDataSet] = [deita_dataset, dolly, ultrachat_200k]
class Model:
    """A base model family together with the version suffixes offered for it.

    Args:
        name: Repository prefix of the model family, e.g. "google/gemma".
            NOTE(review): presumably combined with one of ``versions`` by the
            caller to form a full repo id — confirm.
        versions: Version suffixes available for this family, e.g.
            ["7b", "2b"]. The list is stored as given (not copied), so later
            mutation by the caller is visible here.
    """

    def __init__(self, name: str, versions: List[str]):
        self.name = name
        self.versions = versions

    def __str__(self) -> str:
        # Only the family name is used as the human-readable label.
        return self.name

    def __repr__(self) -> str:
        # Unambiguous form for debugging/logging; includes the versions.
        return f"{type(self).__name__}(name={self.name!r}, versions={self.versions!r})"
# Base model families available for fine-tuning, collected in ``models`` in
# the order defined below. Versions listed in trailing "omitted" comments
# belong to the same family but are not included in ``versions``.
gemma = Model(name="google/gemma", versions=["7b", "2b"])
falcon = Model(name="tiiuae/falcon", versions=["7b"])  # omitted: "7b-instruct"
phi = Model(name="microsoft/phi", versions=["1_5", "1", "2"])
llama = Model(name="meta-llama/Llama-2", versions=["7b", "7b-hf"])  # omitted: "7b-chat", "7b-chat-hf"
mistral = Model(name="mistralai/Mistral", versions=["7B-v0.1"])  # omitted: "7B-Instruct-v0.1"
tinyLlama = Model(
    name="TinyLlama/TinyLlama-1.1B",
    versions=[
        "intermediate-step-1431k-3T",
        "step-50K-105b",
        "intermediate-step-240k-503b",
        "intermediate-step-715k-1.5T",
        "intermediate-step-1195k-token-2.5T",
    ],
)

models: List[Model] = [gemma, falcon, phi, llama, mistral, tinyLlama]