leafspark committed on
Commit 641ee6f · verified · 1 Parent(s): 6eacc63

feat(model): support using JSON config
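In short, the commit swaps the YAML-based config loading (`model_config.yaml` read with `yaml.safe_load`) for JSON (`config.json` read with `json.load`) and strips the explanatory comments from litgpt/config.py. The snippet below is a minimal sketch, not part of the commit, of how the new loading path could be exercised; the checkpoint directory name and all field values are hypothetical, and it assumes the JSON file carries the same litgpt `Config` field names that the YAML file used to hold.

# Sketch only: hypothetical checkpoint directory and field values.
# Assumes config.json uses litgpt Config field names, as model_config.yaml did.
import json
from pathlib import Path

from litgpt.config import Config

ckpt_dir = Path("checkpoints/my-model")  # hypothetical path
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Write a litgpt-style config as JSON (previously read via yaml.safe_load).
(ckpt_dir / "config.json").write_text(
    json.dumps(
        {"name": "my-model", "block_size": 2048, "vocab_size": 32000,
         "n_layer": 4, "n_head": 4, "n_embd": 256}
    ),
    encoding="utf-8",
)

# from_checkpoint now resolves config.json first, then falls back to from_name.
config = Config.from_checkpoint(ckpt_dir)
print(config.padded_vocab_size, config.n_query_groups)  # derived in __post_init__

Note that the file is still fed straight into the `Config` dataclass by `from_file`, so `config.json` is expected to hold litgpt-style keys (`n_embd`, `n_head`, ...); a Hugging Face `config.json` with different key names would not map onto these fields.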

Files changed (1)
  1. litgpt/config.py +8 -49
litgpt/config.py CHANGED
@@ -1,12 +1,10 @@
- # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
+ import json
  from copy import deepcopy
  from dataclasses import dataclass, field
  from pathlib import Path
  from typing import Any, Literal, Optional, Type, Union

  import torch
- import yaml
  from typing_extensions import Self

  import litgpt.model
@@ -30,33 +28,11 @@ class Config:
      parallel_residual: bool = True
      bias: bool = True
      lm_head_bias: bool = False
-     # to use multi-head attention (MHA), set this to `n_head` (default)
-     # to use multi-query attention (MQA), set this to 1
-     # to use grouped-query attention (GQA), set this to a value in between
-     # Example with `n_head=4`
-     # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
-     # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
-     # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
-     #   │    │    │    │         │        │                 │
-     # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
-     # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
-     # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
-     #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
-     # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
-     # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
-     # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
-     # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
-     #         MHA                    GQA                   MQA
-     #   n_query_groups=4        n_query_groups=2      n_query_groups=1
-     #
-     # credit https://arxiv.org/pdf/2305.13245.pdf
      n_query_groups: Optional[int] = None
      shared_attention_norm: bool = False
      norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
      norm_eps: float = 1e-5
-     mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
-         "GptNeoxMLP"
-     )
+     mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
      gelu_approximate: str = "none"
      intermediate_size: Optional[int] = None
      rope_condense_ratio: int = 1
@@ -90,27 +66,19 @@ class Config:
          assert self.n_embd % self.n_head == 0
          self.head_size = self.n_embd // self.n_head

-         # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
          if self.padded_vocab_size is None:
-             self.padded_vocab_size = find_multiple(
-                 self.vocab_size, self.padding_multiple
-             )
+             self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
          else:
-             # vocab size shouldn't be larger than padded vocab size
              self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

-         # compute the number of query groups
          if self.n_query_groups is not None:
              assert self.n_head % self.n_query_groups == 0
          else:
              self.n_query_groups = self.n_head

-         # compute the intermediate size for MLP if not set
          if self.intermediate_size is None:
              if self.mlp_class_name == "LLaMAMLP":
-                 raise ValueError(
-                     f"The config {self.name!r}, needs to set the `intermediate_size`"
-                 )
+                 raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`")
              self.intermediate_size = 4 * self.n_embd

          self.rope_n_elem = int(self.rotary_percentage * self.head_size)
@@ -121,14 +89,12 @@ class Config:
      @classmethod
      def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
          if name not in name_to_config:
-             # search through all `config['hf_config']['name']`
              try:
                  conf_dict = next(
                      config
                      for config in configs
                      if name == config["hf_config"]["name"]
-                     or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
-                     == name
+                     or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
                  )
              except StopIteration:
                  raise ValueError(f"{name!r} is not a supported config name")
@@ -142,7 +108,7 @@ class Config:
      @classmethod
      def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
          with open(path, encoding="utf-8") as fp:
-             file_kwargs = yaml.safe_load(fp)
+             file_kwargs = json.load(fp)
          if file_kwargs is None:
              raise ValueError(f"{path} is empty which is likely unexpected.")
          file_kwargs.update(kwargs)
@@ -150,28 +116,21 @@ class Config:

      @classmethod
      def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
-         """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
-         if (config_path := path / "model_config.yaml").is_file():
+         if (config_path := path / "config.json").is_file():
              return cls.from_file(config_path, **kwargs)
          if (model_name := path.name) in name_to_config:
              return cls.from_name(model_name, **kwargs)
-         raise FileNotFoundError(
-             f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
-         )
+         raise FileNotFoundError(f"For {str(path)!r} neither 'config.json' nor matching config exists.")

      @property
      def mlp_class(self) -> Type:
-         # `self.mlp_class_name` cannot be the type to keep the config serializable
          return getattr(litgpt.model, self.mlp_class_name)

      @property
      def norm_class(self) -> Type:
-         # `self.norm_class_name` cannot be the type to keep the config serializable
          if self.norm_class_name == "RMSNorm":
              from functools import partial
-
              from litgpt.model import RMSNorm
-
              return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
          return getattr(torch.nn, self.norm_class_name)
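The name-based fallback in `from_checkpoint` and the `org/name` matching kept in `from_name` are untouched by this commit, so lookups like the sketch below should continue to work; the concrete config names are only illustrative and assumed to be registered in `name_to_config` / `configs`.

from litgpt.config import Config

# Plain litgpt config name (illustrative; assumed present in name_to_config).
cfg = Config.from_name("pythia-14m")

# Hugging Face style "org/name" string, matched against the registered hf_config entries.
cfg_hf = Config.from_name("EleutherAI/pythia-14m")

# Keyword arguments override fields from the looked-up config.
cfg_small = Config.from_name("pythia-14m", block_size=1024)
print(cfg.n_layer, cfg_hf.n_layer, cfg_small.block_size)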