stefan-it committed on
Commit caecb8c · 1 Parent(s): 49473bf

xlstm: add configuration and modeling (own one)

Files changed (2)
  1. configuration_xlstm.py +97 -0
  2. modeling_xlstm.py +214 -0
configuration_xlstm.py ADDED
@@ -0,0 +1,97 @@
+ import json
+ from typing import Any, Dict, Optional
+
+ from dacite import Config as DaciteConfig
+ from dacite import from_dict
+ from omegaconf import OmegaConf
+ from transformers.configuration_utils import PretrainedConfig
+ from xlstm import xLSTMLMModelConfig
+
+ # from .config_presets import xlstm_cfg_map
+
+
+ class xLSTMConfig(PretrainedConfig):
+     """xLSTM configuration class.
+     We separate the specific xLSTM model configuration
+     from the rest due to the heavy nesting of the configuration.
+     """
+
+     model_type = "xlstm"
+
+     def __init__(
+         self, vocab_size: int = 32000, config: Optional[Dict[str, Any]] = None, **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         cfg = OmegaConf.create(config)
+         cfg["vocab_size"] = vocab_size
+         for key, value in kwargs.items():
+             cfg[key] = value
+
+         self._xlstm_config = cfg
+         self.vocab_size = vocab_size
+         self.embedding_dim = cfg.get("embedding_dim")
+         self.context_length = cfg.get("context_length")
+
+     def to_xlstm_config(self):
+         return from_dict(
+             data_class=xLSTMLMModelConfig,
+             data=OmegaConf.to_container(self._xlstm_config),
+             config=DaciteConfig(strict=True),
+         )
+
+     def to_dict(self) -> Dict[str, Any]:
+         """
+         Converts the configuration to a dictionary for serialization.
+         """
+         output = super().to_dict()
+         output["_xlstm_config"] = OmegaConf.to_container(
+             self._xlstm_config, resolve=True
+         )
+         relevant_keys = [
+             "vocab_size",
+             "embedding_dim",
+             "context_length",
+             "torch_dtype",
+             "_xlstm_config",
+             "transformers_version",
+             "architectures",
+             "model_type",
+         ]
+         output_ = output.copy()
+         for key in output.keys():
+             if key not in relevant_keys:
+                 output_.pop(key)
+         return output_
+
+     @classmethod
+     def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
+         """
+         Creates a configuration instance from a dictionary.
+         """
+         xlstm_config = config_dict.pop("_xlstm_config")
+         vocab_size = config_dict.pop("vocab_size")
+         config = cls(vocab_size=vocab_size, config=xlstm_config)
+         if "auto_map" in config_dict and config_dict["auto_map"]:
+             setattr(config, "auto_map", config_dict.pop("auto_map"))
+
+         if "return_unused_kwargs" in kwargs and kwargs["return_unused_kwargs"]:
+             return config, {}
+
+         return config
+
+     def to_json_string(self, *args, **kwargs) -> str:
+         """
+         Serializes the instance to a JSON string.
+         """
+         return json.dumps(self.to_dict(), indent=2)
+
+     @classmethod
+     def from_json_string(cls, json_string: str):
+         """
+         Deserializes the instance from a JSON string.
+         """
+         config_dict = json.loads(json_string)
+         return cls.from_dict(config_dict)
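
For reference, a minimal usage sketch of this configuration class. The nested dictionary passed as config is parsed strictly into xLSTMLMModelConfig from the upstream xlstm package via dacite, so its keys must match those data classes; the concrete keys and values below are illustrative assumptions and not part of this commit.

# Hypothetical example, assuming configuration_xlstm.py is importable from the working directory.
from configuration_xlstm import xLSTMConfig

# Illustrative nested model configuration; keys must match the upstream
# xLSTMLMModelConfig data classes, since dacite runs with strict=True.
nested_cfg = {
    "embedding_dim": 512,
    "context_length": 256,
    "num_blocks": 6,
    "mlstm_block": {"mlstm": {"num_heads": 4}},
}

config = xLSTMConfig(vocab_size=32000, config=nested_cfg)
print(config.embedding_dim, config.context_length)  # 512 256
xlstm_config = config.to_xlstm_config()  # strict parse into xLSTMLMModelConfig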
modeling_xlstm.py ADDED
@@ -0,0 +1,214 @@
+ from typing import Optional, Sequence, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
+ from xlstm.components.init import small_init_init_
+ from xlstm.utils import WeightDecayOptimGroupMixin
+ from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
+
+ from .configuration_xlstm import xLSTMConfig
+
+
+ class xLSTMPreTrainedModel(PreTrainedModel):
+     """Base class for all xLSTM models."""
+
+     config_class = xLSTMConfig
+
+
+ class xLSTMBlockStack(_xLSTMBlockStack):
+     """Small wrapper to expose hidden states."""
+
+     def forward(
+         self, x: torch.Tensor, **kwargs
+     ) -> Tuple[torch.Tensor, Sequence[torch.Tensor]]:
+         hidden_states = ()
+         for block in self.blocks:
+             x = block(x, **kwargs)
+             hidden_states += (x,)
+
+         x = self.post_blocks_norm(x)
+
+         return x, hidden_states
+
+
+ class xLSTMModel(xLSTMPreTrainedModel):
+     def __init__(self, config: xLSTMConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.token_embedding = nn.Embedding(
+             num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim
+         )
+         _config = config.to_xlstm_config()
+
+         self.emb_dropout = (
+             nn.Dropout(_config.dropout)
+             if _config.add_embedding_dropout
+             else nn.Identity()
+         )
+
+         self.xlstm_block_stack = xLSTMBlockStack(config=_config)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         token_embedding = self.token_embedding(input_ids)
+         x = self.emb_dropout(token_embedding)
+         x, hidden_states = self.xlstm_block_stack(x)
+
+         if output_hidden_states:
+             hidden_states = (token_embedding,) + hidden_states
+
+         if not return_dict:
+             return x, hidden_states
+
+         return BaseModelOutput(
+             last_hidden_state=x,
+             hidden_states=hidden_states if output_hidden_states else None,
+         )
+
+
+ class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: xLSTMConfig, **kwargs):
+         super().__init__(config)
+         self.config = config
+         self.vocab_size = config.vocab_size
+
+         self.model = xLSTMModel(config)
+
+         self.lm_head = nn.Linear(
+             in_features=config.embedding_dim,
+             out_features=config.vocab_size,
+             bias=False,
+         )
+
+         self.post_init()
+         # TODO: Add option for up-projection
+
+     def get_input_embeddings(self):
+         return self.model.token_embedding
+
+     def set_input_embeddings(self, value: nn.Module):
+         self.model.token_embedding = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, value):
+         self.lm_head = value
+
+     def reset_parameters(self):
+         self.model.xlstm_block_stack.reset_parameters()
+
+         small_init_init_(
+             self.get_input_embeddings().weight, dim=self.config.embedding_dim
+         )
+
+         if not self.config.tie_word_embeddings:
+             small_init_init_(
+                 self.get_output_embeddings().weight, dim=self.config.embedding_dim
+             )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         # Always request a BaseModelOutput from the backbone so that
+         # `output.hidden_states` is available below.
+         output = self.model(
+             input_ids,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+
+         hidden_state = output[0]
+
+         logits = self.lm_head(hidden_state)
+         logits = logits.float()
+
+         loss = None
+
+         if labels is not None:
+             # Shift so that tokens < n predict token n.
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             loss_fct = nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             hidden_states=output.hidden_states,
+         )
+
+     def step(
+         self,
+         idx: torch.Tensor,
+         state: dict[str, dict[str, tuple[torch.Tensor, ...]]] = None,
+         **kwargs,
+     ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
+         x = self.model.token_embedding(idx)
+         x = self.model.emb_dropout(x)
+         x, state = self.model.xlstm_block_stack.step(x, state=state, **kwargs)
+         logits = self.lm_head(x)
+         return logits, state
+
+     def _create_weight_decay_optim_groups(
+         self, **kwargs
+     ) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
+         weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(
+             **kwargs
+         )
+         # Remove the token embedding and add it to the correct group,
+         # according to the config.
+         weight_decay = list(weight_decay)
+         removed = 0
+         for idx in range(len(weight_decay)):
+             if weight_decay[idx - removed] is self.get_input_embeddings().weight:
+                 weight_decay.pop(idx - removed)
+                 removed += 1
+         weight_decay = tuple(weight_decay)
+
+         # TODO: Fix this
+         # if self.config.weight_decay_on_embedding:
+         if True:
+             weight_decay += (self.get_input_embeddings().weight,)
+         else:
+             no_weight_decay += (self.get_input_embeddings().weight,)
+
+         return weight_decay, no_weight_decay
+
+     def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+         # NOTE: the existing embedding weights are not copied over.
+         new_embeddings = nn.Embedding(
+             new_num_tokens, self.model.token_embedding.embedding_dim
+         )
+         self.model.token_embedding = new_embeddings.to(self.device)
+         return new_embeddings
+
+     def tie_weights(self):
+         self.get_output_embeddings().weight = self.get_input_embeddings().weight
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         **kwargs,
+     ):
+         model_inputs = {
+             "input_ids": input_ids.to(self.device),
+         }
+         return model_inputs
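
Once both files are uploaded together with a config.json whose auto_map entries point at xLSTMConfig and xLSTMForCausalLM, the model should be loadable through the transformers Auto classes with remote code enabled. A rough usage sketch follows; the repository id and prompt are placeholders, not part of this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "stefan-it/xlstm-placeholder"  # placeholder repo id, for illustration only

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers resolve the auto_map entries to the
# xLSTMConfig / xLSTMForCausalLM classes added in this commit.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Heute ist ein", return_tensors="pt")
outputs = model(inputs["input_ids"], return_dict=True)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)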