import json
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
from typing_extensions import Self

import litgpt.model
from litgpt.utils import find_multiple


@dataclass
class Config:
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0
    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    # settings for mini-omni
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768

    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            self.head_size = self.n_embd // self.n_head

        # pad the vocabulary to a multiple of `padding_multiple`, or clamp it
        # if an explicit padded size was given
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        else:
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # grouped-query attention: the heads must divide evenly into the query groups;
        # default to one group per head (standard multi-head attention)
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(f"The config {self.name!r} needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        if name not in name_to_config:
            # fall back to matching the Hugging Face name, with or without the org prefix
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        with open(path, encoding="utf-8") as fp:
            file_kwargs = json.load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Load `config.json` from the checkpoint directory, falling back to a named config."""
        if (config_path := path / "config.json").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(f"For {str(path)!r} neither 'config.json' nor matching config exists.")

    @property
    def mlp_class(self) -> Type:
        # resolved lazily from the class name so the config stays serializable
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            # Gemma variants add a unit offset to the RMSNorm weight
            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)


configs = []

name_to_config = {config["name"]: config for config in configs}
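

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the upstream module): constructing a
    # `Config` directly shows the fields derived in `__post_init__`. The hyperparameter
    # values below are arbitrary assumptions chosen for demonstration only.
    demo = Config(name="demo", n_layer=2, n_head=4, n_embd=128, rotary_percentage=1.0)
    # head_size = n_embd // n_head, padded_vocab_size = find_multiple(vocab_size, padding_multiple),
    # rope_n_elem = int(rotary_percentage * head_size)
    print(demo.head_size, demo.padded_vocab_size, demo.rope_n_elem)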