# NOTE: scraped page header, not part of the module source -- "Spaces: Runtime error".
import json | |
from copy import deepcopy | |
from dataclasses import dataclass, field | |
from pathlib import Path | |
from typing import Any, Literal, Optional, Type, Union | |
import torch | |
from typing_extensions import Self | |
import litgpt.model | |
from litgpt.utils import find_multiple | |
@dataclass
class Config:
    """Model hyperparameters plus the values derived from them.

    Instances are normally built through one of the alternate constructors
    (`from_name`, `from_file`, `from_checkpoint`); any keyword argument passed
    to those overrides the value found in the registry entry or JSON file.

    `__post_init__` fills in every field that is left as ``None`` and sets the
    extra attribute ``rope_n_elem``.
    """

    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    # Derived from `vocab_size` and `padding_multiple` when left as None.
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    # Defaults to `n_embd // n_head` when left as None.
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    # Must divide `n_head`; defaults to `n_head` when left as None.
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
    gelu_approximate: str = "none"
    # Defaults to `4 * n_embd` when left as None (not allowed for "LLaMAMLP").
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0
    # Defaults to `bias` when left as None.
    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    # setting for mini-omni
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768

    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        """Resolve every ``None`` default and compute ``rope_n_elem``.

        Raises:
            AssertionError: if `n_embd` is not divisible by `n_head` (only
                checked when `head_size` is None), or `n_head` is not
                divisible by an explicitly given `n_query_groups`.
            ValueError: if `mlp_class_name` is "LLaMAMLP" and no
                `intermediate_size` was given.
        """
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
            self.head_size = self.n_embd // self.n_head

        if self.padded_vocab_size is None:
            # presumably rounds vocab_size up to a multiple of padding_multiple
            # -- find_multiple is defined in litgpt.utils, confirm there
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        else:
            # An explicit padded size caps the usable vocabulary.
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
        else:
            self.n_query_groups = self.n_head

        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        """Build a config from a registered entry, overridden by `kwargs`.

        `name` is looked up in `name_to_config` first. If absent, `configs` is
        scanned for an entry whose ``hf_config["name"]`` equals `name`, or
        whose ``hf_config["org"] + "/" + hf_config["name"]`` equals `name`.

        Raises:
            ValueError: if no entry matches `name`.
        """
        if name not in name_to_config:
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
                )
            except StopIteration:
                # The StopIteration is an implementation detail of the search.
                raise ValueError(f"{name!r} is not a supported config name") from None
        else:
            conf_dict = name_to_config[name]

        # Shallow copy so the overrides do not leak into the registry entry.
        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        """Build a config from a JSON file, overridden by `kwargs`.

        Raises:
            ValueError: if the file's JSON content decodes to ``None``.
        """
        with open(path, encoding="utf-8") as fp:
            file_kwargs = json.load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Build a config for the checkpoint directory `path`.

        Uses ``path / "config.json"`` when that file exists; otherwise falls
        back to a registry entry named after the directory (``path.name``).

        Raises:
            FileNotFoundError: if neither source is available.
        """
        if (config_path := path / "config.json").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(f"For {str(path)!r} neither 'config.json' nor matching config exists.")

    @property
    def mlp_class(self) -> Type:
        """The attribute of `litgpt.model` named by `mlp_class_name`."""
        # `self.mlp_class_name` cannot be the type to keep the config serializable
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        """The normalization layer factory named by `norm_class_name`.

        For "RMSNorm" this is `litgpt.model.RMSNorm` with `add_unit_offset`
        pre-bound to whether the config name contains "Gemma"; any other name
        is looked up on `torch.nn`.
        """
        # `self.norm_class_name` cannot be the type to keep the config serializable
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)
# Registry of known model configurations. `Config.from_name` unpacks an entry
# as keyword arguments, so each entry is a dict keyed by `Config` field names.
configs = []

# Lookup table from each entry's "name" value to the entry itself.
name_to_config = {config["name"]: config for config in configs}