feat(model): support using JSON config
litgpt/config.py  CHANGED  (+8 -49)
@@ -1,12 +1,10 @@
-
-
+import json
 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Literal, Optional, Type, Union

 import torch
-import yaml
 from typing_extensions import Self

 import litgpt.model
@@ -30,33 +28,11 @@ class Config:
     parallel_residual: bool = True
     bias: bool = True
     lm_head_bias: bool = False
-    # to use multi-head attention (MHA), set this to `n_head` (default)
-    # to use multi-query attention (MQA), set this to 1
-    # to use grouped-query attention (GQA), set this to a value in between
-    # Example with `n_head=4`
-    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
-    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
-    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
-    #   │    │    │    │         │        │                 │
-    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
-    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
-    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
-    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
-    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
-    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
-    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
-    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
-    #         MHA                    GQA                   MQA
-    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
-    #
-    # credit https://arxiv.org/pdf/2305.13245.pdf
     n_query_groups: Optional[int] = None
     shared_attention_norm: bool = False
     norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
     norm_eps: float = 1e-5
-    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
-        "GptNeoxMLP"
-    )
+    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
     gelu_approximate: str = "none"
     intermediate_size: Optional[int] = None
     rope_condense_ratio: int = 1
@@ -90,27 +66,19 @@ class Config:
         assert self.n_embd % self.n_head == 0
         self.head_size = self.n_embd // self.n_head

-        # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
         if self.padded_vocab_size is None:
-            self.padded_vocab_size = find_multiple(
-                self.vocab_size, self.padding_multiple
-            )
+            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
         else:
-            # vocab size shouldn't be larger than padded vocab size
             self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

-        # compute the number of query groups
         if self.n_query_groups is not None:
             assert self.n_head % self.n_query_groups == 0
         else:
             self.n_query_groups = self.n_head

-        # compute the intermediate size for MLP if not set
         if self.intermediate_size is None:
             if self.mlp_class_name == "LLaMAMLP":
-                raise ValueError(
-                    f"The config {self.name!r}, needs to set the `intermediate_size`"
-                )
+                raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`")
             self.intermediate_size = 4 * self.n_embd

         self.rope_n_elem = int(self.rotary_percentage * self.head_size)
@@ -121,14 +89,12 @@ class Config:
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
         if name not in name_to_config:
-            # search through all `config['hf_config']['name']`
             try:
                 conf_dict = next(
                     config
                     for config in configs
                     if name == config["hf_config"]["name"]
-                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
-                    == name
+                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
                 )
             except StopIteration:
                 raise ValueError(f"{name!r} is not a supported config name")
@@ -142,7 +108,7 @@ class Config:
     @classmethod
     def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
         with open(path, encoding="utf-8") as fp:
-            file_kwargs = yaml.safe_load(fp)
+            file_kwargs = json.load(fp)
         if file_kwargs is None:
             raise ValueError(f"{path} is empty which is likely unexpected.")
         file_kwargs.update(kwargs)
@@ -150,28 +116,21 @@ class Config:

     @classmethod
     def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
-
-        if (config_path := path / "model_config.yaml").is_file():
+        if (config_path := path / "config.json").is_file():
             return cls.from_file(config_path, **kwargs)
         if (model_name := path.name) in name_to_config:
             return cls.from_name(model_name, **kwargs)
-        raise FileNotFoundError(
-            f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
-        )
+        raise FileNotFoundError(f"For {str(path)!r} neither 'config.json' nor matching config exists.")

     @property
     def mlp_class(self) -> Type:
-        # `self.mlp_class_name` cannot be the type to keep the config serializable
         return getattr(litgpt.model, self.mlp_class_name)

     @property
     def norm_class(self) -> Type:
-        # `self.norm_class_name` cannot be the type to keep the config serializable
         if self.norm_class_name == "RMSNorm":
             from functools import partial
-
             from litgpt.model import RMSNorm
-
             return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
         return getattr(torch.nn, self.norm_class_name)
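For context, a minimal usage sketch of the loading path this commit introduces: `Config.from_file` now parses JSON and `Config.from_checkpoint` looks for `config.json` instead of `model_config.yaml`. Everything below is illustrative only; the checkpoint directory, the field values, and the `pythia-14m` name are hypothetical, and it assumes this patched `litgpt` package is importable.

# Hypothetical example: write a config.json and load it back through the
# patched Config, which reads JSON instead of YAML.
import json
from pathlib import Path

from litgpt.config import Config

checkpoint_dir = Path("checkpoints/my-model")  # hypothetical directory
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Any subset of Config's dataclass fields; unspecified fields keep their defaults.
(checkpoint_dir / "config.json").write_text(
    json.dumps({"name": "pythia-14m", "block_size": 512}), encoding="utf-8"
)

# from_checkpoint prefers a 'config.json' in the directory and only falls back
# to a known model name from name_to_config when that file is missing.
config = Config.from_checkpoint(checkpoint_dir)
print(config.name, config.block_size)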