leafspark's picture
feat(model): support using JSON config
641ee6f verified
raw
history blame
4.75 kB
import json
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union
import torch
from typing_extensions import Self
import litgpt.model
from litgpt.utils import find_multiple
@dataclass
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
scale_embeddings: bool = False
block_size: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
n_head: int = 32
head_size: Optional[int] = None
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
n_expert: int = 0
n_expert_per_token: int = 0
add_qkv_bias: Optional[bool] = None
prompt_vocab_size: Optional[int] = None
attn_dropout: float = 0.0
pos_type: str = "rope"
force_align: bool = False
use_pretrain_phoneme_emb: bool = False
tie_word_embeddings: bool = False
# setting for mini-omni
text_vocab_size:int = 152000
cat_audio_vocab_size: int = 29120
audio_vocab_size: int = 4160
whisper_adapter_dim: int = 768
post_adapter: bool = False
post_adapter_layers: int = 6
asr_adapter: str = "llamamlp"
def __post_init__(self):
if not self.name:
self.name = self.hf_config.get("name", self.name)
if self.head_size is None:
assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head
if self.padded_vocab_size is None:
self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
else:
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
if self.n_query_groups is not None:
assert self.n_head % self.n_query_groups == 0
else:
self.n_query_groups = self.n_head
if self.intermediate_size is None:
if self.mlp_class_name == "LLaMAMLP":
raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`")
self.intermediate_size = 4 * self.n_embd
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
if self.add_qkv_bias is None:
self.add_qkv_bias = self.bias
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
if name not in name_to_config:
try:
conf_dict = next(
config
for config in configs
if name == config["hf_config"]["name"]
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
)
except StopIteration:
raise ValueError(f"{name!r} is not a supported config name")
else:
conf_dict = name_to_config[name]
conf_dict = conf_dict.copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
@classmethod
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
with open(path, encoding="utf-8") as fp:
file_kwargs = json.load(fp)
if file_kwargs is None:
raise ValueError(f"{path} is empty which is likely unexpected.")
file_kwargs.update(kwargs)
return cls(**file_kwargs)
@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
if (config_path := path / "config.json").is_file():
return cls.from_file(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(f"For {str(path)!r} neither 'config.json' nor matching config exists.")
@property
def mlp_class(self) -> Type:
return getattr(litgpt.model, self.mlp_class_name)
@property
def norm_class(self) -> Type:
if self.norm_class_name == "RMSNorm":
from functools import partial
from litgpt.model import RMSNorm
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
return getattr(torch.nn, self.norm_class_name)
# Registry of model configuration dicts that Config.from_name searches.
# As written here it starts empty, so it must be populated before lookups succeed.
configs: list = []
# Lookup table keyed by each registry entry's "name"; a later duplicate name wins.
name_to_config: dict = {entry["name"]: entry for entry in configs}