NohTow committed
Commit a243956 · 1 Parent(s): 2571cc4

New version

Files changed (11)
  1. __init__.py +66 -5
  2. activation.py +1 -1
  3. attention.py +5 -5
  4. embeddings.py +3 -3
  5. initialization.py +3 -3
  6. layers.py +6 -6
  7. loss.py +30 -0
  8. mlp.py +4 -4
  9. model.py +1684 -0
  10. normalization.py +1 -1
  11. options.py +6 -6
__init__.py CHANGED
@@ -1,7 +1,68 @@
 
- import os
- import sys
 
- # Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
- sys.path.append(os.path.dirname(os.path.realpath(__file__)))
- from modeling_flexbert import FlexBertModel
+ from .attention import (
+     BertAlibiUnpadAttention,
+     BertAlibiUnpadSelfAttention,
+     BertSelfOutput,
+     FlexBertPaddedAttention,
+     FlexBertUnpadAttention,
+ )
+ from .embeddings import (
+     BertAlibiEmbeddings,
+     FlexBertAbsoluteEmbeddings,
+     FlexBertSansPositionEmbeddings,
+ )
+ from .layers import (
+     BertAlibiEncoder,
+     BertAlibiLayer,
+     BertResidualGLU,
+     FlexBertPaddedPreNormLayer,
+     FlexBertPaddedPostNormLayer,
+     FlexBertUnpadPostNormLayer,
+     FlexBertUnpadPreNormLayer,
+ )
+ from .model import (
+     BertLMPredictionHead,
+     BertModel,
+     BertForMaskedLM,
+     BertForSequenceClassification,
+     BertForMultipleChoice,
+     BertOnlyMLMHead,
+     BertOnlyNSPHead,
+     BertPooler,
+     BertPredictionHeadTransform,
+     FlexBertModel,
+     FlexBertForMaskedLM,
+     FlexBertForSequenceClassification,
+     FlexBertForMultipleChoice,
+ )


+ __all__ = [
+     "BertAlibiEmbeddings",
+     "BertAlibiEncoder",
+     "BertForMaskedLM",
+     "BertForSequenceClassification",
+     "BertForMultipleChoice",
+     "BertResidualGLU",
+     "BertAlibiLayer",
+     "BertLMPredictionHead",
+     "BertModel",
+     "BertOnlyMLMHead",
+     "BertOnlyNSPHead",
+     "BertPooler",
+     "BertPredictionHeadTransform",
+     "BertSelfOutput",
+     "BertAlibiUnpadAttention",
+     "BertAlibiUnpadSelfAttention",
+     "FlexBertPaddedAttention",
+     "FlexBertUnpadAttention",
+     "FlexBertAbsoluteEmbeddings",
+     "FlexBertSansPositionEmbeddings",
+     "FlexBertPaddedPreNormLayer",
+     "FlexBertPaddedPostNormLayer",
+     "FlexBertUnpadPostNormLayer",
+     "FlexBertUnpadPreNormLayer",
+     "FlexBertModel",
+     "FlexBertForMaskedLM",
+     "FlexBertForSequenceClassification",
+     "FlexBertForMultipleChoice",
+ ]
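With `__init__.py` now re-exporting these classes through relative imports, downstream code no longer needs the old `sys.path.append` workaround and can import directly from the package. A minimal sketch of the intended usage, assuming the directory is importable as the `src.bert_layers` package (the path used by the absolute imports in the new `model.py`); the exact package name depends on your checkout layout:

```python
# Hedged sketch: package-style imports replacing the old sys.path.append approach.
# `src.bert_layers` is an assumption based on the absolute imports in model.py;
# adjust the package path to match where this folder lives in your project.
from src.bert_layers import FlexBertModel, FlexBertForMaskedLM

import src.bert_layers as bert_layers

# __all__ (defined above) lists every class the package re-exports.
print(bert_layers.__all__)
print(FlexBertModel.__name__, FlexBertForMaskedLM.__name__)
```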
activation.py CHANGED
@@ -7,7 +7,7 @@
  from collections import OrderedDict
  from typing import Union
  import torch.nn as nn
- from configuration_bert import FlexBertConfig
 
 
  class ClassInstantier(OrderedDict):

  from collections import OrderedDict
  from typing import Union
  import torch.nn as nn
+ from .configuration_bert import FlexBertConfig
 
 
  class ClassInstantier(OrderedDict):
attention.py CHANGED
@@ -22,10 +22,10 @@ import logging
  import math
 
  import bert_padding
- from configuration_bert import FlexBertConfig, maybe_add_padding
- from normalization import get_norm_layer
- from initialization import ModuleType, init_weights
- import utils # noqa: F401
 
  IMPL_USE_FLASH3 = False
  IMPL_USE_FLASH2 = False
@@ -48,7 +48,7 @@ except ImportError:
 
  try:
      from flash_attn.layers.rotary import RotaryEmbedding # type: ignore
-     from rotary import UnpaddedRotaryEmbedding # type: ignore
 
  except ImportError:
      RotaryEmbedding = None

  import math
 
  import bert_padding
+ from .configuration_bert import FlexBertConfig, maybe_add_padding
+ from .normalization import get_norm_layer
+ from .initialization import ModuleType, init_weights
+ import src.utils # noqa: F401
 
  IMPL_USE_FLASH3 = False
  IMPL_USE_FLASH2 = False
 
 
  try:
      from flash_attn.layers.rotary import RotaryEmbedding # type: ignore
+     from .rotary import UnpaddedRotaryEmbedding # type: ignore
 
  except ImportError:
      RotaryEmbedding = None
embeddings.py CHANGED
@@ -16,9 +16,9 @@ import torch
  import torch.nn as nn
  from typing import Optional
 
- from configuration_bert import FlexBertConfig
- from normalization import get_norm_layer
- from initialization import ModuleType, init_weights
 
 
  class BertAlibiEmbeddings(nn.Module):

  import torch.nn as nn
  from typing import Optional
 
+ from .configuration_bert import FlexBertConfig
+ from .normalization import get_norm_layer
+ from .initialization import ModuleType, init_weights
 
 
  class BertAlibiEmbeddings(nn.Module):
initialization.py CHANGED
@@ -14,10 +14,10 @@ from typing import Optional, Union
  import torch
  import torch.nn as nn
 
- from utils import StrEnum
 
- from configuration_bert import FlexBertConfig
- from normalization import RMSNorm
 
  __all__ = ["init_weights", "ModuleType", "InitFnType"]
 

  import torch
  import torch.nn as nn
 
+ from src.utils import StrEnum
 
+ from .configuration_bert import FlexBertConfig
+ from .normalization import RMSNorm
 
  __all__ = ["init_weights", "ModuleType", "InitFnType"]
 
layers.py CHANGED
@@ -22,12 +22,12 @@ import torch.nn as nn
 
  import bert_padding
 
- from activation import get_act_fn
- from attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer
- from mlp import FlexBertMLPBase, BertResidualGLU, get_mlp_layer
- from configuration_bert import FlexBertConfig, maybe_add_padding
- from normalization import get_norm_layer
- from initialization import ModuleType, init_weights
 
 
  class BertAlibiLayer(nn.Module):

 
  import bert_padding
 
+ from .activation import get_act_fn
+ from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer
+ from .mlp import FlexBertMLPBase, BertResidualGLU, get_mlp_layer
+ from .configuration_bert import FlexBertConfig, maybe_add_padding
+ from .normalization import get_norm_layer
+ from .initialization import ModuleType, init_weights
 
 
  class BertAlibiLayer(nn.Module):
loss.py ADDED
@@ -0,0 +1,30 @@
+ # Copyright 2024 **AUTHORS_TODO**
+ # License: Apache-2.0
+
+ import inspect
+ import torch.nn as nn
+ from .configuration_bert import FlexBertConfig
+
+ try:
+     from flash_attn.losses.cross_entropy import CrossEntropyLoss
+ except ImportError:
+     CrossEntropyLoss = None
+
+ LOSS2CLS = {
+     "cross_entropy": nn.CrossEntropyLoss,
+     "binary_cross_entropy": nn.BCEWithLogitsLoss,
+     "mean_squared_error": nn.MSELoss,
+ }
+
+ if CrossEntropyLoss is not None:
+     LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss
+
+
+ def get_loss_fn(config: FlexBertConfig) -> nn.Module:
+     try:
+         loss_class = LOSS2CLS[config.loss_function]
+         signature = inspect.signature(loss_class)
+         loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters}
+         return loss_class(**loss_kwargs)
+     except KeyError:
+         raise ValueError(f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}.")
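The new `get_loss_fn` resolves the loss class named by `config.loss_function` and forwards only the entries of `config.loss_kwargs` that appear in that class's constructor signature, so options aimed at a different loss (for example `return_z_loss`, which only the flash-attn cross entropy accepts) are dropped instead of raising a `TypeError`. A minimal sketch of that filtering, using a hypothetical stand-in object in place of a real `FlexBertConfig`:

```python
# Hypothetical stand-in config; the real code receives a FlexBertConfig.
import inspect
import torch.nn as nn

class DummyConfig:
    loss_function = "cross_entropy"
    loss_kwargs = {"ignore_index": -100, "return_z_loss": True}

loss_class = nn.CrossEntropyLoss  # LOSS2CLS[DummyConfig.loss_function]
signature = inspect.signature(loss_class)
# Keep only kwargs the constructor accepts: return_z_loss is silently discarded.
kwargs = {k: v for k, v in DummyConfig.loss_kwargs.items() if k in signature.parameters}
loss_fn = loss_class(**kwargs)
print(kwargs)                 # {'ignore_index': -100}
print(loss_fn.ignore_index)   # -100
```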
mlp.py CHANGED
@@ -16,10 +16,10 @@ from typing import Optional
  import torch
  import torch.nn as nn
 
- from configuration_bert import FlexBertConfig
- from activation import get_act_fn
- from normalization import get_norm_layer
- from initialization import ModuleType, init_weights
 
 
  class BertResidualGLU(nn.Module):

  import torch
  import torch.nn as nn
 
+ from .configuration_bert import FlexBertConfig
+ from .activation import get_act_fn
+ from .normalization import get_norm_layer
+ from .initialization import ModuleType, init_weights
 
 
  class BertResidualGLU(nn.Module):
model.py ADDED
@@ -0,0 +1,1684 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
5
+ # License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
6
+
7
+ # Copyright 2022 Jonas Geiping
8
+ # License: MIT
9
+
10
+ # Copyright 2022 MosaicML Examples authors
11
+ # SPDX-License-Identifier: Apache-2.0
12
+
13
+ # Copyright 2023 MosaicML Examples authors
14
+ # SPDX-License-Identifier: Apache-2.0
15
+
16
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
17
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
18
+ # Copyright (c) 2023, Tri Dao.
19
+
20
+ """Implements Mosaic BERT, with an eye towards the Hugging Face API.
21
+
22
+ Mosaic BERT improves performance over Hugging Face BERT through the following:
23
+
24
+ 1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
25
+ information through attention biases based on query-key position distance. It improves the effectiveness
26
+ of training with shorter sequence lengths by enabling extrapolation to longer sequences.
27
+
28
+ 2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
29
+ to improve overall expressiveness, providing better convergence properties.
30
+
31
+ 3. Flash Attention. The MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically
32
+ improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that
33
+ supports attention biases, which allows us to use Flash Attention with ALiBi.
34
+
35
+ 4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
36
+ implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation
37
+ and improve speed. It does this without changing how the user interfaces with the model, thereby
38
+ preserving the simple API of standard implementations.
39
+
40
+
41
+ Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
42
+ classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
43
+
44
+ See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage
45
+ of the core Mosaic BERT classes.
46
+ """
47
+
48
+ import logging
49
+ import os
50
+ import sys
51
+ import warnings
52
+ from dataclasses import dataclass
53
+ from typing import List, Optional, Tuple, Union
54
+
55
+ # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
56
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
57
+
58
+ import torch
59
+ import torch.nn as nn
60
+ from einops import rearrange
61
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
62
+ from transformers.modeling_outputs import (
63
+ MaskedLMOutput,
64
+ ModelOutput,
65
+ MultipleChoiceModelOutput,
66
+ SequenceClassifierOutput,
67
+ )
68
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
69
+
70
+ from bert_padding import index_put_first_axis
71
+
72
+ from src.bert_layers.activation import get_act_fn
73
+ from src.bert_layers.attention import (
74
+ FlexBertPaddedAttention,
75
+ FlexBertPaddedParallelAttention,
76
+ FlexBertPaddedRopeAttention,
77
+ FlexBertPaddedRopeParallelAttention,
78
+ FlexBertUnpadAttention,
79
+ FlexBertUnpadParallelAttention,
80
+ FlexBertUnpadRopeAttention,
81
+ FlexBertUnpadRopeParallelAttention,
82
+ )
83
+ from src.bert_layers.configuration_bert import FlexBertConfig
84
+ from src.bert_layers.embeddings import (
85
+ BertAlibiEmbeddings,
86
+ FlexBertAbsoluteEmbeddings,
87
+ FlexBertCompiledSansPositionEmbeddings,
88
+ FlexBertSansPositionEmbeddings,
89
+ get_embedding_layer,
90
+ )
91
+ from src.bert_layers.initialization import (
92
+ ModuleType,
93
+ TileLinear,
94
+ TileMode,
95
+ init_weights,
96
+ tile_embedding,
97
+ tile_linear,
98
+ tile_norm,
99
+ )
100
+ from src.bert_layers.layers import (
101
+ BertAlibiEncoder,
102
+ BertPooler,
103
+ BertPredictionHeadTransform,
104
+ FlexBertCompileUnpadPreNormLayer,
105
+ FlexBertPaddedEncoder,
106
+ FlexBertPaddedParallelPreNormLayer,
107
+ FlexBertPaddedPostNormLayer,
108
+ FlexBertPaddedPreNormLayer,
109
+ FlexBertUnpadEncoder,
110
+ FlexBertUnpadParallelPreNormLayer,
111
+ FlexBertUnpadPostNormLayer,
112
+ FlexBertUnpadPreNormLayer,
113
+ get_encoder_layer,
114
+ )
115
+ from src.bert_layers.loss import get_loss_fn
116
+ from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
117
+ from src.bert_layers.normalization import get_norm_layer
118
+ from src.bert_layers.padding import pad_input, unpad_input
119
+
120
+ logger = logging.getLogger(__name__)
121
+
122
+
123
+ def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
124
+ if trainable:
125
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
126
+ else:
127
+ return sum(p.numel() for p in model.parameters())
128
+
129
+
130
+ class BertModel(BertPreTrainedModel):
131
+ """Overall BERT model.
132
+
133
+ Args:
134
+ config: a BertConfig class instance with the configuration to build a new model
135
+
136
+ Inputs:
137
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
138
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
139
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
140
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
141
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
142
+ a `sentence B` token (see BERT paper for more details).
143
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
144
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
145
+ input sequence length in the current batch. It's the mask that we typically use for attention when
146
+ a batch has varying length sentences.
147
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
148
+
149
+ Outputs: Tuple of (encoded_layers, pooled_output)
150
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
151
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
152
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
153
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
154
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
155
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
156
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
157
+ classifier pretrained on top of the hidden state associated to the first character of the
158
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
159
+
160
+ Example usage:
161
+ ```python
162
+ # Already been converted into WordPiece token ids
163
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
164
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
165
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
166
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
167
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
168
+ model = BertModel(config=config)
169
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
170
+ ```
171
+ """
172
+
173
+ def __init__(
174
+ self,
175
+ config,
176
+ add_pooling_layer: bool = True,
177
+ ):
178
+ super(BertModel, self).__init__(config)
179
+ self.embeddings = BertAlibiEmbeddings(config)
180
+ self.encoder = BertAlibiEncoder(config)
181
+ self.pooler = BertPooler(config) if add_pooling_layer else None
182
+ self.post_init()
183
+
184
+ def get_input_embeddings(self):
185
+ return self.embeddings.word_embeddings
186
+
187
+ def set_input_embeddings(self, value):
188
+ self.embeddings.word_embeddings = value
189
+
190
+ def forward(
191
+ self,
192
+ input_ids: torch.Tensor,
193
+ token_type_ids: Optional[torch.Tensor] = None,
194
+ attention_mask: Optional[torch.Tensor] = None,
195
+ position_ids: Optional[torch.Tensor] = None,
196
+ output_all_encoded_layers: Optional[bool] = False,
197
+ masked_tokens_mask: Optional[torch.Tensor] = None,
198
+ **kwargs,
199
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
200
+ if attention_mask is None:
201
+ attention_mask = torch.ones_like(input_ids)
202
+ if token_type_ids is None:
203
+ token_type_ids = torch.zeros_like(input_ids)
204
+
205
+ embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
206
+
207
+ subset_mask = []
208
+ first_col_mask = []
209
+
210
+ if masked_tokens_mask is None:
211
+ subset_mask = None
212
+ else:
213
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
214
+ first_col_mask[:, 0] = True
215
+ subset_mask = masked_tokens_mask | first_col_mask
216
+
217
+ encoder_outputs = self.encoder(
218
+ embedding_output,
219
+ attention_mask,
220
+ output_all_encoded_layers=output_all_encoded_layers,
221
+ subset_mask=subset_mask,
222
+ )
223
+
224
+ if masked_tokens_mask is None:
225
+ sequence_output = encoder_outputs[-1]
226
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
227
+ else:
228
+ # TD [2022-03-01]: the indexing here is very tricky.
229
+ attention_mask_bool = attention_mask.bool()
230
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
231
+ sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]]
232
+ if self.pooler is not None:
233
+ pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]]
234
+ pooled_output = self.pooler(pool_input, pool=False)
235
+ else:
236
+ pooled_output = None
237
+
238
+ if not output_all_encoded_layers:
239
+ encoder_outputs = sequence_output
240
+
241
+ if self.pooler is not None:
242
+ return encoder_outputs, pooled_output
243
+
244
+ return encoder_outputs, None
245
+
246
+
247
+ ###################
248
+ # Bert Heads
249
+ ###################
250
+ class BertLMPredictionHead(nn.Module):
251
+ def __init__(self, config, bert_model_embedding_weights):
252
+ super().__init__()
253
+ self.transform = BertPredictionHeadTransform(config)
254
+ # The output weights are the same as the input embeddings, but there is
255
+ # an output-only bias for each token.
256
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0))
257
+ self.decoder.weight = bert_model_embedding_weights
258
+
259
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
260
+ hidden_states = self.transform(hidden_states)
261
+ hidden_states = self.decoder(hidden_states)
262
+ return hidden_states
263
+
264
+
265
+ class BertOnlyMLMHead(nn.Module):
266
+ def __init__(self, config, bert_model_embedding_weights):
267
+ super().__init__()
268
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
269
+
270
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
271
+ prediction_scores = self.predictions(sequence_output)
272
+ return prediction_scores
273
+
274
+
275
+ class BertOnlyNSPHead(nn.Module):
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
279
+
280
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
281
+ seq_relationship_score = self.seq_relationship(pooled_output)
282
+ return seq_relationship_score
283
+
284
+
285
+ #####################
286
+ # Various Bert models
287
+ #####################
288
+
289
+
290
+ class BertForPreTraining(BertPreTrainedModel):
291
+ # TBD: Coming in Future Commit
292
+ pass
293
+
294
+
295
+ class BertLMHeadModel(BertPreTrainedModel):
296
+ # TBD: Coming in Future Commit
297
+ pass
298
+
299
+
300
+ class BertForMaskedLM(BertPreTrainedModel):
301
+ def __init__(self, config):
302
+ super().__init__(config)
303
+
304
+ if config.is_decoder:
305
+ warnings.warn(
306
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
307
+ "bi-directional self-attention."
308
+ )
309
+
310
+ self.bert = BertModel(config, add_pooling_layer=False)
311
+ self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
312
+
313
+ # Initialize weights and apply final processing
314
+ self.post_init()
315
+
316
+ @classmethod
317
+ def from_composer(
318
+ cls,
319
+ pretrained_checkpoint,
320
+ state_dict=None,
321
+ cache_dir=None,
322
+ from_tf=False,
323
+ config=None,
324
+ *inputs,
325
+ **kwargs,
326
+ ):
327
+ """Load from pre-trained."""
328
+ model = cls(config, *inputs, **kwargs)
329
+ if from_tf:
330
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
331
+
332
+ state_dict = torch.load(pretrained_checkpoint)
333
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
334
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
335
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
336
+
337
+ if len(missing_keys) > 0:
338
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
339
+ if len(unexpected_keys) > 0:
340
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
341
+
342
+ return model
343
+
344
+ def get_output_embeddings(self):
345
+ return self.cls.predictions.decoder
346
+
347
+ def set_output_embeddings(self, new_embeddings):
348
+ self.cls.predictions.decoder = new_embeddings
349
+
350
+ def forward(
351
+ self,
352
+ input_ids: Optional[torch.Tensor] = None,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ token_type_ids: Optional[torch.Tensor] = None,
355
+ position_ids: Optional[torch.Tensor] = None,
356
+ head_mask: Optional[torch.Tensor] = None,
357
+ inputs_embeds: Optional[torch.Tensor] = None,
358
+ encoder_hidden_states: Optional[torch.Tensor] = None,
359
+ encoder_attention_mask: Optional[torch.Tensor] = None,
360
+ labels: Optional[torch.Tensor] = None,
361
+ output_attentions: Optional[bool] = None,
362
+ output_hidden_states: Optional[bool] = None,
363
+ return_dict: Optional[bool] = None,
364
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
365
+ # labels should be a `torch.LongTensor` of shape
366
+ # `(batch_size, sequence_length)`. These are used for computing the
367
+ # masked language modeling loss.
368
+ #
369
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
370
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
371
+ # (masked), the loss is only computed for the tokens with labels in `[0,
372
+ # ..., config.vocab_size]`
373
+ #
374
+ # Prediction scores are only computed for masked tokens and the (bs,
375
+ # seqlen) dimensions are flattened
376
+ if (input_ids is not None) == (inputs_embeds is not None):
377
+ raise ValueError("Must specify either input_ids or input_embeds!")
378
+
379
+ if labels is None:
380
+ masked_tokens_mask = None
381
+ else:
382
+ masked_tokens_mask = labels > 0
383
+
384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
385
+
386
+ outputs = self.bert(
387
+ input_ids,
388
+ attention_mask=attention_mask,
389
+ token_type_ids=token_type_ids,
390
+ position_ids=position_ids,
391
+ head_mask=head_mask,
392
+ inputs_embeds=inputs_embeds,
393
+ encoder_hidden_states=encoder_hidden_states,
394
+ encoder_attention_mask=encoder_attention_mask,
395
+ output_attentions=output_attentions,
396
+ output_hidden_states=output_hidden_states,
397
+ return_dict=return_dict,
398
+ masked_tokens_mask=masked_tokens_mask,
399
+ )
400
+
401
+ sequence_output = outputs[0]
402
+ prediction_scores = self.cls(sequence_output)
403
+
404
+ loss = None
405
+ if labels is not None:
406
+ # Compute loss
407
+ loss_fct = nn.CrossEntropyLoss()
408
+ masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
409
+ loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx])
410
+
411
+ assert input_ids is not None, "Coding error; please open an issue"
412
+ batch, seqlen = input_ids.shape[:2]
413
+ prediction_scores = rearrange(
414
+ index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen),
415
+ "(b s) d -> b s d",
416
+ b=batch,
417
+ )
418
+
419
+ if not return_dict:
420
+ output = (prediction_scores,) + outputs[2:]
421
+ return ((loss,) + output) if loss is not None else output
422
+
423
+ return MaskedLMOutput(
424
+ loss=loss,
425
+ logits=prediction_scores,
426
+ hidden_states=None,
427
+ attentions=None,
428
+ )
429
+
430
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
431
+ input_shape = input_ids.shape
432
+ effective_batch_size = input_shape[0]
433
+
434
+ # add a dummy token
435
+ if self.config.pad_token_id is None:
436
+ raise ValueError("The PAD token should be defined for generation")
437
+
438
+ attention_mask = torch.cat(
439
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
440
+ dim=-1,
441
+ )
442
+ dummy_token = torch.full(
443
+ (effective_batch_size, 1),
444
+ self.config.pad_token_id,
445
+ dtype=torch.long,
446
+ device=input_ids.device,
447
+ )
448
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
449
+
450
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
451
+
452
+
453
+ class BertForNextSentencePrediction(BertPreTrainedModel):
454
+ # TBD: Push in future commit
455
+ pass
456
+
457
+
458
+ class BertForSequenceClassification(BertPreTrainedModel):
459
+ """Bert Model transformer with a sequence classification/regression head.
460
+
461
+ This head is just a linear layer on top of the pooled output. Used for,
462
+ e.g., GLUE tasks.
463
+ """
464
+
465
+ def __init__(self, config):
466
+ super().__init__(config)
467
+ self.num_labels = config.num_labels
468
+ self.config = config
469
+
470
+ self.bert = BertModel(config)
471
+ classifier_dropout = (
472
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
473
+ )
474
+ self.dropout = nn.Dropout(classifier_dropout)
475
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
476
+
477
+ # Initialize weights and apply final processing
478
+ self.post_init()
479
+
480
+ @classmethod
481
+ def from_composer(
482
+ cls,
483
+ pretrained_checkpoint,
484
+ state_dict=None,
485
+ cache_dir=None,
486
+ from_tf=False,
487
+ config=None,
488
+ *inputs,
489
+ **kwargs,
490
+ ):
491
+ """Load from pre-trained."""
492
+ model = cls(config, *inputs, **kwargs)
493
+ if from_tf:
494
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
495
+
496
+ state_dict = torch.load(pretrained_checkpoint)
497
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
498
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
499
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
500
+
501
+ if len(missing_keys) > 0:
502
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
503
+ if len(unexpected_keys) > 0:
504
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
505
+
506
+ return model
507
+
508
+ def forward(
509
+ self,
510
+ input_ids: Optional[torch.Tensor] = None,
511
+ attention_mask: Optional[torch.Tensor] = None,
512
+ token_type_ids: Optional[torch.Tensor] = None,
513
+ position_ids: Optional[torch.Tensor] = None,
514
+ head_mask: Optional[torch.Tensor] = None,
515
+ inputs_embeds: Optional[torch.Tensor] = None,
516
+ labels: Optional[torch.Tensor] = None,
517
+ output_attentions: Optional[bool] = None,
518
+ output_hidden_states: Optional[bool] = None,
519
+ return_dict: Optional[bool] = None,
520
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
521
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
522
+ # Labels for computing the sequence classification/regression loss.
523
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
524
+ # If `config.num_labels == 1` a regression loss is computed
525
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
526
+ # is computed (cross-entropy).
527
+
528
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
529
+
530
+ outputs = self.bert(
531
+ input_ids,
532
+ attention_mask=attention_mask,
533
+ token_type_ids=token_type_ids,
534
+ position_ids=position_ids,
535
+ head_mask=head_mask,
536
+ inputs_embeds=inputs_embeds,
537
+ output_attentions=output_attentions,
538
+ output_hidden_states=output_hidden_states,
539
+ return_dict=return_dict,
540
+ )
541
+
542
+ pooled_output = outputs[1]
543
+
544
+ pooled_output = self.dropout(pooled_output)
545
+ logits = self.classifier(pooled_output)
546
+
547
+ loss = None
548
+ if labels is not None:
549
+ # Compute loss
550
+ if self.config.problem_type is None:
551
+ if self.num_labels == 1:
552
+ self.config.problem_type = "regression"
553
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
554
+ self.config.problem_type = "single_label_classification"
555
+ else:
556
+ self.config.problem_type = "multi_label_classification"
557
+
558
+ if self.config.problem_type == "regression":
559
+ loss_fct = nn.MSELoss()
560
+ if self.num_labels == 1:
561
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
562
+ else:
563
+ loss = loss_fct(logits, labels)
564
+ elif self.config.problem_type == "single_label_classification":
565
+ loss_fct = nn.CrossEntropyLoss()
566
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
567
+ elif self.config.problem_type == "multi_label_classification":
568
+ loss_fct = nn.BCEWithLogitsLoss()
569
+ loss = loss_fct(logits, labels)
570
+
571
+ if not return_dict:
572
+ output = (logits,) + outputs[2:]
573
+ return ((loss,) + output) if loss is not None else output
574
+
575
+ return SequenceClassifierOutput(
576
+ loss=loss,
577
+ logits=logits,
578
+ hidden_states=None,
579
+ attentions=None,
580
+ )
581
+
582
+
583
+ class BertForMultipleChoice(BertPreTrainedModel):
584
+ """
585
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
586
+ softmax) e.g. for RocStories/SWAG tasks.
587
+ """
588
+
589
+ def __init__(self, config):
590
+ super().__init__(config)
591
+ self.num_labels = config.num_labels
592
+ self.config = config
593
+
594
+ self.bert = BertModel(config)
595
+ classifier_dropout = (
596
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
597
+ )
598
+ self.dropout = nn.Dropout(classifier_dropout)
599
+
600
+ # In multiple choice tasks, all choices are submitted in a batch, and
601
+ # we compute a logit for each option independently. The logits are then
602
+ # normalized in the forward pass to get a probability distribution over
603
+ # the choices.
604
+ self.classifier = nn.Linear(config.hidden_size, 1)
605
+
606
+ # Initialize weights and apply final processing
607
+ self.post_init()
608
+
609
+ @classmethod
610
+ def from_composer(
611
+ cls,
612
+ pretrained_checkpoint,
613
+ state_dict=None,
614
+ cache_dir=None,
615
+ from_tf=False,
616
+ config=None,
617
+ *inputs,
618
+ **kwargs,
619
+ ):
620
+ """Load from pre-trained."""
621
+ model = cls(config, *inputs, **kwargs)
622
+ if from_tf:
623
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
624
+
625
+ state_dict = torch.load(pretrained_checkpoint)
626
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
627
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
628
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
629
+
630
+ if len(missing_keys) > 0:
631
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
632
+ if len(unexpected_keys) > 0:
633
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
634
+
635
+ return model
636
+
637
+ def forward(
638
+ self,
639
+ input_ids: Optional[torch.Tensor] = None,
640
+ attention_mask: Optional[torch.Tensor] = None,
641
+ token_type_ids: Optional[torch.Tensor] = None,
642
+ position_ids: Optional[torch.Tensor] = None,
643
+ head_mask: Optional[torch.Tensor] = None,
644
+ inputs_embeds: Optional[torch.Tensor] = None,
645
+ labels: Optional[torch.Tensor] = None,
646
+ output_attentions: Optional[bool] = None,
647
+ output_hidden_states: Optional[bool] = None,
648
+ return_dict: Optional[bool] = None,
649
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
650
+ r"""
651
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
652
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
653
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
654
+ `input_ids` above)
655
+ """
656
+
657
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
658
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
659
+
660
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
661
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
662
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
663
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
664
+ inputs_embeds = (
665
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
666
+ if inputs_embeds is not None
667
+ else None
668
+ )
669
+
670
+ outputs = self.bert(
671
+ input_ids,
672
+ attention_mask=attention_mask,
673
+ token_type_ids=token_type_ids,
674
+ position_ids=position_ids,
675
+ head_mask=head_mask,
676
+ inputs_embeds=inputs_embeds,
677
+ output_attentions=output_attentions,
678
+ output_hidden_states=output_hidden_states,
679
+ return_dict=return_dict,
680
+ )
681
+
682
+ pooled_output = outputs[1]
683
+
684
+ pooled_output = self.dropout(pooled_output)
685
+ logits = self.classifier(pooled_output)
686
+ reshaped_logits = logits.view(-1, num_choices)
687
+
688
+ loss = None
689
+ if labels is not None:
690
+ loss_fct = nn.CrossEntropyLoss()
691
+ loss = loss_fct(reshaped_logits, labels)
692
+
693
+ if not return_dict:
694
+ output = (reshaped_logits,) + outputs[2:]
695
+ return ((loss,) + output) if loss is not None else output
696
+
697
+ return MultipleChoiceModelOutput(
698
+ loss=loss,
699
+ logits=reshaped_logits,
700
+ hidden_states=None,
701
+ attentions=None,
702
+ )
703
+
704
+
705
+ class BertForTokenClassification(BertPreTrainedModel):
706
+ # TBD: Push in future commit
707
+ pass
708
+
709
+
710
+ class BertForQuestionAnswering(BertPreTrainedModel):
711
+ """Bert Model with a span classification head.
712
+
713
+ This is used for extractive question-answering tasks like SQuAD (a linear
714
+ layers on top of the hidden states' output to compute `span start logits`
715
+ and `span end logits`).
716
+ """
717
+
718
+ # TBD: Push in future commit
719
+
720
+
721
+ ###################
722
+ # FlexBert Heads
723
+ ###################
724
+
725
+
726
+ class FlexBertPredictionHead(nn.Module):
727
+ def __init__(self, config: FlexBertConfig):
728
+ super().__init__()
729
+ self.config = config
730
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_pred_bias)
731
+ self.act = get_act_fn(config.head_pred_act) if config.head_pred_act else nn.Identity()
732
+ self.norm = (
733
+ get_norm_layer(config, compiled_norm=config.compile_model) if config.head_pred_norm else nn.Identity()
734
+ )
735
+
736
+ def _init_weights(self, reset_params: bool = False):
737
+ if reset_params:
738
+ self.norm.reset_parameters()
739
+ init_weights(self.config, self.dense, layer_dim=self.config.hidden_size, type_of_module=ModuleType.in_module)
740
+
741
+ def reset_parameters(self):
742
+ self._init_weights(reset_params=True)
743
+
744
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
745
+ return self.norm(self.act(self.dense(hidden_states)))
746
+
747
+
748
+ class FlexBertPoolingHead(nn.Module):
749
+ def __init__(self, config: FlexBertConfig):
750
+ super().__init__()
751
+ self.config = config
752
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_class_bias)
753
+ self.act = get_act_fn(config.head_class_act) if config.head_class_act else nn.Identity()
754
+ self.norm = get_norm_layer(config) if config.head_class_norm else nn.Identity()
755
+ self.drop = torch.nn.Dropout(config.head_class_dropout) if config.head_class_dropout > 0 else nn.Identity()
756
+ self.pooling_type = config.pooling_type
757
+
758
+ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
759
+ if pool:
760
+ if self.pooling_type == "cls":
761
+ output = hidden_states[:, 0]
762
+ elif self.pooling_type == "mean":
763
+ output = hidden_states.mean(dim=1)
764
+ elif self.pooling_type == "max":
765
+ output = hidden_states.max(dim=1)[0]
766
+ else:
767
+ output = hidden_states
768
+
769
+ return self.drop(self.norm(self.act(self.dense(output))))
770
+
771
+ def _init_weights(self, reset_params: bool = False):
772
+ init_weights(self.config, self.dense, self.config.hidden_size, type_of_module=ModuleType.out_module)
773
+ if reset_params and hasattr(self.norm, "reset_parameters"):
774
+ self.norm.reset_parameters()
775
+
776
+ def reset_parameters(self):
777
+ self._init_weights(reset_params=True)
778
+
779
+
780
+ ###################
781
+ # FlexBert Models
782
+ ###################
783
+
784
+
785
+ @dataclass
786
+ class MaskedLMOutput(ModelOutput):
787
+ """
788
+ Base class for masked language models outputs.
789
+
790
+ Args:
791
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
792
+ Masked language modeling (MLM) loss.
793
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
794
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
795
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
796
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
797
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
798
+
799
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
800
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
801
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
802
+ sequence_length)`.
803
+
804
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
805
+ heads.
806
+ """
807
+
808
+ loss: Optional[torch.FloatTensor] = None
809
+ logits: torch.FloatTensor = None
810
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
811
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
812
+ indices: Optional[torch.LongTensor] = None
813
+ cu_seqlens: Optional[torch.LongTensor] = None
814
+ max_seqlen: Optional[int] = None
815
+ batch_size: Optional[int] = None
816
+ seq_len: Optional[int] = None
817
+ labels: Optional[torch.LongTensor] = None
818
+
819
+
820
+ @dataclass
821
+ class MaskedLMOutputZLoss(ModelOutput):
822
+ """
823
+ Base class for masked language models outputs.
824
+
825
+ Args:
826
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
827
+ Masked language modeling (MLM) loss.
828
+ ce_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
829
+ Cross entropy loss.
830
+ z_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
831
+ Z loss.
832
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
833
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
834
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
835
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
836
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
837
+
838
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
839
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
840
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
841
+ sequence_length)`.
842
+
843
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
844
+ heads.
845
+ indices (`torch.LongTensor` of shape `(batch_size,)`):
846
+ Indices of the tokens to be masked.
847
+ """
848
+
849
+ loss: Optional[torch.FloatTensor] = None
850
+ ce_loss: Optional[torch.FloatTensor] = None
851
+ z_loss: Optional[torch.FloatTensor] = None
852
+ logits: torch.FloatTensor = None
853
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
854
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
855
+ indices: Optional[torch.LongTensor] = None
856
+ cu_seqlens: Optional[torch.LongTensor] = None
857
+ max_seqlen: Optional[int] = None
858
+ batch_size: Optional[int] = None
859
+ seq_len: Optional[int] = None
860
+ labels: Optional[torch.LongTensor] = None
861
+
862
+
863
+ class FlexBertPreTrainedModel(BertPreTrainedModel):
864
+ """
865
+ An abstract class to handle custom weights initialization of modules
866
+ """
867
+
868
+ def _init_module_weights(self, module: nn.Module):
869
+ """
870
+ Custom weight init of modules using src.bert_layers.initialization.init_weights
871
+ Currently only supports init of embedding modules
872
+ """
873
+ assert isinstance(module, nn.Module)
874
+ if isinstance(module, nn.Embedding):
875
+ init_weights(self.config, module, type_of_module=ModuleType.emb)
876
+ else:
877
+ raise NotImplementedError("Custom weight init for the given module is not supported")
878
+
879
+
880
+ class FlexBertModel(FlexBertPreTrainedModel):
881
+ """Overall BERT model.
882
+
883
+ Args:
884
+ config: a BertConfig class instance with the configuration to build a new model
885
+
886
+ Inputs:
887
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
888
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
889
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
890
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
891
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
892
+ a `sentence B` token (see BERT paper for more details).
893
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
894
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
895
+ input sequence length in the current batch. It's the mask that we typically use for attention when
896
+ a batch has varying length sentences.
897
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
898
+
899
+ Outputs: Tuple of (encoded_layers, pooled_output)
900
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
901
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
902
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
903
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
904
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
905
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
906
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
907
+ classifier pretrained on top of the hidden state associated to the first character of the
908
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
909
+
910
+ Example usage:
911
+ ```python
912
+ # Already been converted into WordPiece token ids
913
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
914
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
915
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
916
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
917
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
918
+ model = BertModel(config=config)
919
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
920
+ ```
921
+ """
922
+
923
+ def __init__(self, config: FlexBertConfig):
924
+ super().__init__(config)
925
+ self.embeddings = get_embedding_layer(config)
926
+ self.encoder = get_encoder_layer(config)
927
+ if config.final_norm:
928
+ # if we use prenorm attention we need to add a final norm
929
+ self.final_norm = get_norm_layer(config)
930
+ else:
931
+ self.final_norm = None
932
+ self.unpad_embeddings = config.unpad_embeddings
933
+
934
+ def post_init(self):
935
+ self._init_weights(reset_params=False)
936
+ self._backward_compatibility_gradient_checkpointing()
937
+
938
+ def get_input_embeddings(self):
939
+ return self.embeddings.tok_embeddings
940
+
941
+ def set_input_embeddings(self, value):
942
+ self.embeddings.tok_embeddings = value
943
+
944
+ def forward(
945
+ self,
946
+ input_ids: torch.Tensor,
947
+ attention_mask: Optional[torch.Tensor] = None,
948
+ position_ids: Optional[torch.Tensor] = None,
949
+ indices: Optional[torch.Tensor] = None,
950
+ cu_seqlens: Optional[torch.Tensor] = None,
951
+ max_seqlen: Optional[int] = None,
952
+ **kwargs,
953
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
954
+ if attention_mask is None:
955
+ attention_mask = torch.ones_like(input_ids)
956
+
957
+ embedding_output = self.embeddings(input_ids, position_ids)
958
+
959
+ encoder_outputs = self.encoder(
960
+ hidden_states=embedding_output,
961
+ attention_mask=attention_mask,
962
+ indices=indices,
963
+ cu_seqlens=cu_seqlens,
964
+ max_seqlen=max_seqlen,
965
+ )
966
+
967
+ if self.final_norm is not None:
968
+ encoder_outputs = self.final_norm(encoder_outputs)
969
+ return encoder_outputs
970
+
971
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
972
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
973
+ if module:
974
+ self._init_module_weights(module)
975
+ else:
976
+ assert isinstance(reset_params, bool)
977
+ self.embeddings._init_weights(reset_params=reset_params)
978
+ self.encoder._init_weights(reset_params=reset_params)
979
+
980
+ if reset_params and self.config.final_norm:
981
+ self.final_norm.reset_parameters()
982
+
983
+ def reset_parameters(self):
984
+ self._init_weights(reset_params=True)
985
+
986
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
987
+ """Returns the number of parameters in the model.
988
+
989
+ Args:
990
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
991
+ trainable: only count trainable parameters.
992
+ """
993
+ params = sum([_count_parameters(layer, trainable) for layer in self.encoder.layers])
994
+ if count_embeddings:
995
+ params += _count_parameters(self.embeddings, trainable)
996
+ if hasattr(self.embeddings, "position_embeddings"):
997
+ params -= _count_parameters(self.embeddings.position_embeddings, trainable)
998
+ return params
999
+
1000
+
1001
+ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
1002
+ def __init__(self, config: FlexBertConfig):
1003
+ super().__init__(config)
1004
+ self.bert = FlexBertModel(config)
1005
+ self.head = FlexBertPredictionHead(config)
1006
+
1007
+ if config.tie_word_embeddings:
1008
+ decoder_weights = self.bert.embeddings.tok_embeddings.weight
1009
+ else:
1010
+ decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
1011
+ self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
1012
+ self.decoder.weight = decoder_weights
1013
+
1014
+ self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
1015
+ self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
1016
+ self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
1017
+ self.unpad_embeddings = config.unpad_embeddings
1018
+ self.pad_logits = config.pad_logits
1019
+ self.compile_model = config.compile_model
1020
+ self.masked_prediction = config.masked_prediction
1021
+
1022
+ # Initialize weights and apply final processing
1023
+ self._init_weights(reset_params=False)
1024
+
1025
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1026
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1027
+ if module:
1028
+ self._init_module_weights(module)
1029
+ else:
1030
+ assert isinstance(reset_params, bool)
1031
+ self.bert._init_weights(reset_params=reset_params)
1032
+ self.head._init_weights(reset_params=reset_params)
1033
+
1034
+ # Output weights.
1035
+ if not self.config.tie_word_embeddings:
1036
+ init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1037
+
1038
+ @classmethod
1039
+ def from_composer(
1040
+ cls,
1041
+ pretrained_checkpoint,
1042
+ state_dict=None,
1043
+ cache_dir=None,
1044
+ from_tf=False,
1045
+ config=None,
1046
+ *inputs,
1047
+ **kwargs,
1048
+ ):
1049
+ """Load from pre-trained."""
1050
+ model = cls(config, *inputs, **kwargs)
1051
+ if from_tf:
1052
+ raise ValueError("FlexBERT does not support loading TensorFlow weights.")
1053
+
1054
+ state_dict = torch.load(pretrained_checkpoint)
1055
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1056
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1057
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1058
+
1059
+ if len(missing_keys) > 0:
1060
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1061
+ if len(unexpected_keys) > 0:
1062
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1063
+
1064
+ return model
1065
+
1066
+ def get_output_embeddings(self):
1067
+ return self.decoder
1068
+
1069
+ def set_output_embeddings(self, new_embeddings):
1070
+ self.decoder = new_embeddings
1071
+
1072
+ @torch.no_grad()
1073
+ def unpad_inputs(
1074
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor
1075
+ ):
1076
+ return unpad_input(input_ids, attention_mask, position_ids, labels)
1077
+
1078
+ @torch.no_grad()
1079
+ def pad_inputs(
1080
+ self,
1081
+ inputs: torch.Tensor,
1082
+ indices: torch.Tensor,
1083
+ batch_size: int,
1084
+ seqlen: int,
1085
+ labels: Optional[torch.Tensor] = None,
1086
+ ignore_index: int = -100,
1087
+ ):
1088
+ return pad_input(
1089
+ inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index
1090
+ )
1091
+
1092
+ @torch.compile(dynamic=True)
1093
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1094
+ return self.decoder(self.head(output))
1095
+
1096
+ def forward(
1097
+ self,
1098
+ input_ids: Optional[torch.Tensor],
1099
+ attention_mask: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.Tensor] = None,
1101
+ labels: Optional[torch.Tensor] = None,
1102
+ return_dict: Optional[bool] = None,
1103
+ indices: Optional[torch.Tensor] = None,
1104
+ cu_seqlens: Optional[torch.Tensor] = None,
1105
+ max_seqlen: Optional[int] = None,
1106
+ batch_size: Optional[int] = None,
1107
+ seq_len: Optional[int] = None,
1108
+ **kwargs,
1109
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1110
+ # labels should be a `torch.LongTensor` of shape
1111
+ # `(batch_size, sequence_length)`. These are used for computing the
1112
+ # masked language modeling loss.
1113
+ #
1114
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
1115
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
1116
+ # (masked), the loss is only computed for the tokens with labels in `[0,
1117
+ # ..., config.vocab_size]`
1118
+ #
1119
+ # Prediction scores are only computed for masked tokens and the (bs,
1120
+ # seqlen) dimensions are flattened
1121
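+ #
+ # For example (illustrative only, not part of this module), MLM labels are typically
+ # built by cloning the unmasked input ids and setting every non-masked position to the
+ # ignore index:
+ #   labels = input_ids.clone()
+ #   labels[~masked_positions] = -100  # only masked tokens contribute to the loss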
+
1122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
+
1124
+ if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1125
+ batch_size, seq_len = input_ids.shape[:2]
1126
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1127
+ input_ids, attention_mask, position_ids, labels
1128
+ )
1129
+
1130
+ output = self.bert(
1131
+ input_ids,
1132
+ attention_mask=attention_mask,
1133
+ position_ids=position_ids,
1134
+ indices=indices,
1135
+ cu_seqlens=cu_seqlens,
1136
+ max_seqlen=max_seqlen,
1137
+ )
1138
+
1139
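+ # When `masked_prediction` is enabled, the prediction head and decoder are only run on
+ # the masked positions (labels != ignore_index), skipping the vocabulary projection for
+ # the unmasked majority of tokens.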
+ if self.masked_prediction and labels is not None:
1140
+ # flatten labels and output first
1141
+ labels = labels.view(-1)
1142
+ output = output.view(labels.shape[0], -1)
1143
+
1144
+ # then filter out the non-masked tokens
1145
+ mask_tokens = labels != self.loss_fn.ignore_index
1146
+ output = output[mask_tokens]
1147
+ labels = labels[mask_tokens]
1148
+
1149
+ if self.compile_model:
1150
+ logits = self.compiled_head(output)
1151
+ else:
1152
+ logits = self.decoder(self.head(output))
1153
+
1154
+ loss = None
1155
+ if labels is not None:
1156
+ if not self.masked_prediction:
1157
+ labels = labels.view(-1)
1158
+ logits = logits.view(labels.shape[0], -1)
1159
+
1160
+ if self.return_z_loss:
1161
+ loss, z_loss = self.loss_fn(logits, labels)
1162
+ if self.pad_logits:
1163
+ return MaskedLMOutputZLoss(
1164
+ loss=loss,
1165
+ ce_loss=loss.detach().clone() - z_loss,
1166
+ z_loss=z_loss,
1167
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1168
+ hidden_states=None,
1169
+ attentions=None,
1170
+ )
1171
+ else:
1172
+ return MaskedLMOutputZLoss(
1173
+ loss=loss,
1174
+ ce_loss=loss.detach().clone() - z_loss,
1175
+ z_loss=z_loss,
1176
+ logits=logits,
1177
+ hidden_states=None,
1178
+ attentions=None,
1179
+ indices=indices,
1180
+ cu_seqlens=cu_seqlens,
1181
+ max_seqlen=max_seqlen,
1182
+ batch_size=batch_size,
1183
+ seq_len=seq_len,
1184
+ labels=labels,
1185
+ )
1186
+ else:
1187
+ loss = self.loss_fn(logits, labels)
1188
+
1189
+ if self.pad_logits:
1190
+ return MaskedLMOutput(
1191
+ loss=loss,
1192
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1193
+ hidden_states=None,
1194
+ attentions=None,
1195
+ )
1196
+ else:
1197
+ return MaskedLMOutput(
1198
+ loss=loss,
1199
+ logits=logits,
1200
+ hidden_states=None,
1201
+ attentions=None,
1202
+ indices=indices,
1203
+ cu_seqlens=cu_seqlens,
1204
+ max_seqlen=max_seqlen,
1205
+ batch_size=batch_size,
1206
+ seq_len=seq_len,
1207
+ labels=labels,
1208
+ )
1209
+
1210
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
1211
+ input_shape = input_ids.shape
1212
+ effective_batch_size = input_shape[0]
1213
+
1214
+ # add a dummy token
1215
+ if self.config.pad_token_id is None:
1216
+ raise ValueError("The PAD token should be defined for generation")
1217
+
1218
+ attention_mask = torch.cat(
1219
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
1220
+ dim=-1,
1221
+ )
1222
+ dummy_token = torch.full(
1223
+ (effective_batch_size, 1),
1224
+ self.config.pad_token_id,
1225
+ dtype=torch.long,
1226
+ device=input_ids.device,
1227
+ )
1228
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1229
+
1230
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1231
+
1232
+ def get_number_parameters(
1233
+ self, count_embeddings: bool = True, count_decoder: bool = False, trainable: bool = True
1234
+ ) -> int:
1235
+ """Returns the number of parameters in the model.
1236
+
1237
+ Args:
1238
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1239
+ count_decoder: count the parameters in the decoder layer if weights are not tied.
1240
+ trainable: only count trainable parameters.
1241
+ """
1242
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1243
+ params += _count_parameters(self.head, trainable)
1244
+ if count_decoder and not self.config.tie_word_embeddings:
1245
+ params += _count_parameters(self.decoder, trainable)
1246
+ return params
1247
+
1248
+
1249
+ class FlexBertForSequenceClassification(FlexBertPreTrainedModel):
1250
+ """Bert Model transformer with a sequence classification/regression head.
1251
+
1252
+ This head is just a linear layer on top of the pooled output. Used for,
1253
+ e.g., GLUE tasks.
1254
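+ 
+ Example (illustrative sketch; assumes a `FlexBertConfig` with `num_labels` set):
+ 
+ config = FlexBertConfig(num_labels=3)
+ model = FlexBertForSequenceClassification(config)
+ logits = model(input_ids, attention_mask=attention_mask).logits  # (batch_size, 3)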
+ """
1255
+
1256
+ def __init__(self, config: FlexBertConfig):
1257
+ super().__init__(config)
1258
+ self.num_labels = config.num_labels
1259
+ self.config = config
1260
+
1261
+ self.bert = FlexBertModel(config)
1262
+ self.head = FlexBertPoolingHead(config)
1263
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1264
+
1265
+ # Initialize weights and apply final processing
1266
+ self._init_weights(reset_params=False)
1267
+
1268
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1269
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1270
+ if module:
1271
+ self._init_module_weights(module)
1272
+ else:
1273
+ assert isinstance(reset_params, bool)
1274
+ self.bert._init_weights(reset_params=reset_params)
1275
+ self.head._init_weights(reset_params=reset_params)
1276
+ init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1277
+
1278
+ @classmethod
1279
+ def from_composer(
1280
+ cls,
1281
+ pretrained_checkpoint,
1282
+ state_dict=None,
1283
+ cache_dir=None,
1284
+ from_tf=False,
1285
+ config=None,
1286
+ *inputs,
1287
+ **kwargs,
1288
+ ):
1289
+ """Load from pre-trained."""
1290
+ model = cls(config, *inputs, **kwargs)
1291
+ if from_tf:
1292
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1293
+
1294
+ state_dict = torch.load(pretrained_checkpoint)
1295
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1296
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1297
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1298
+
1299
+ if len(missing_keys) > 0:
1300
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1301
+ if len(unexpected_keys) > 0:
1302
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1303
+
1304
+ return model
1305
+
1306
+ def forward(
1307
+ self,
1308
+ input_ids: Optional[torch.Tensor] = None,
1309
+ attention_mask: Optional[torch.Tensor] = None,
1310
+ position_ids: Optional[torch.Tensor] = None,
1311
+ labels: Optional[torch.Tensor] = None,
1312
+ return_dict: Optional[bool] = None,
1313
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1314
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1315
+ # Labels for computing the sequence classification/regression loss.
1316
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
1317
+ # If `config.num_labels == 1` a regression loss is computed
1318
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
1319
+ # is computed (cross-entropy).
1320
+
1321
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1322
+
1323
+ output = self.bert(
1324
+ input_ids,
1325
+ attention_mask=attention_mask,
1326
+ position_ids=position_ids,
1327
+ )
1328
+
1329
+ pooled_output = self.head(output)
1330
+ logits = self.classifier(pooled_output)
1331
+
1332
+ loss = None
1333
+ if labels is not None:
1334
+ # Compute loss
1335
+ if self.config.problem_type is None:
1336
+ if self.num_labels == 1:
1337
+ self.config.problem_type = "regression"
1338
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1339
+ self.config.problem_type = "single_label_classification"
1340
+ else:
1341
+ self.config.problem_type = "multi_label_classification"
1342
+
1343
+ if self.config.problem_type == "regression":
1344
+ loss_fct = nn.MSELoss()
1345
+ if self.num_labels == 1:
1346
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1347
+ else:
1348
+ loss = loss_fct(logits, labels)
1349
+ elif self.config.problem_type == "single_label_classification":
1350
+ loss_fct = nn.CrossEntropyLoss()
1351
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1352
+ elif self.config.problem_type == "multi_label_classification":
1353
+ loss_fct = nn.BCEWithLogitsLoss()
1354
+ loss = loss_fct(logits, labels)
1355
+
1356
+ if not return_dict:
1357
+ output = (logits,) + (output,)
1358
+ return ((loss,) + output) if loss is not None else output
1359
+
1360
+ return SequenceClassifierOutput(
1361
+ loss=loss,
1362
+ logits=logits,
1363
+ hidden_states=None,
1364
+ attentions=None,
1365
+ )
1366
+
1367
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1368
+ """Returns the number of parameters in the model.
1369
+
1370
+ Args:
1371
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1372
+ trainable: only count trainable parameters.
1373
+ """
1374
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1375
+ params += _count_parameters(self.head, trainable)
1376
+ params += _count_parameters(self.classifier, trainable)
1377
+ return params
1378
+
1379
+
1380
+ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
1381
+ """
1382
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1383
+ softmax) e.g. for RocStories/SWAG tasks.
1384
+ """
1385
+
1386
+ def __init__(self, config: FlexBertConfig):
1387
+ super().__init__(config)
1388
+ self.num_labels = config.num_labels
1389
+ self.config = config
1390
+
1391
+ self.bert = FlexBertModel(config)
1392
+ self.head = FlexBertPoolingHead(config)
1393
+
1394
+ # In multiple choice tasks, all choices are submitted in a batch, and
1395
+ # we compute a logit for each option independently. The logits are then
1396
+ # normalized in the forward pass to get a probability distribution over
1397
+ # the choices.
1398
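+ # Roughly: `input_ids` arrive as (batch_size, num_choices, seq_len), are flattened to
+ # (batch_size * num_choices, seq_len) in `forward`, scored to one logit per choice, and
+ # reshaped back to (batch_size, num_choices) before the cross-entropy loss.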
+ self.classifier = nn.Linear(config.hidden_size, 1)
1399
+
1400
+ # Initialize weights and apply final processing
1401
+ self._init_weights(reset_params=False)
1402
+
1403
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1404
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1405
+ if module:
1406
+ self._init_module_weights(module)
1407
+ else:
1408
+ assert isinstance(reset_params, bool)
1409
+ self.bert._init_weights(reset_params=reset_params)
1410
+ self.head._init_weights(reset_params=reset_params)
1411
+ init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1412
+
1413
+ @classmethod
1414
+ def from_composer(
1415
+ cls,
1416
+ pretrained_checkpoint,
1417
+ state_dict=None,
1418
+ cache_dir=None,
1419
+ from_tf=False,
1420
+ config=None,
1421
+ *inputs,
1422
+ **kwargs,
1423
+ ):
1424
+ """Load from pre-trained."""
1425
+ model = cls(config, *inputs, **kwargs)
1426
+ if from_tf:
1427
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1428
+
1429
+ state_dict = torch.load(pretrained_checkpoint)
1430
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1431
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1432
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1433
+
1434
+ if len(missing_keys) > 0:
1435
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1436
+ if len(unexpected_keys) > 0:
1437
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1438
+
1439
+ return model
1440
+
1441
+ def forward(
1442
+ self,
1443
+ input_ids: Optional[torch.Tensor] = None,
1444
+ attention_mask: Optional[torch.Tensor] = None,
1445
+ position_ids: Optional[torch.Tensor] = None,
1446
+ labels: Optional[torch.Tensor] = None,
1447
+ return_dict: Optional[bool] = None,
1448
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1449
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1450
+ # Labels for computing the multiple choice classification loss.
+ # Indices should be in `[0, ..., num_choices - 1]` where `num_choices`
+ # is the size of the second dimension of the input tensors. The loss
+ # is a cross-entropy over the choices.
1455
+
1456
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1457
+ num_choices = input_ids.shape[1]
1458
+
1459
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1460
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1461
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1462
+
1463
+ output = self.bert(
1464
+ input_ids,
1465
+ attention_mask=attention_mask,
1466
+ position_ids=position_ids,
1467
+ )
1468
+
1469
+ pooled_output = self.head(output)
1470
+ logits = self.classifier(pooled_output)
1471
+ reshaped_logits = logits.view(-1, num_choices)
1472
+
1473
+ loss = None
1474
+ if labels is not None:
1475
+ loss_fct = nn.CrossEntropyLoss()
1476
+ loss = loss_fct(reshaped_logits, labels)
1477
+
1478
+ if not return_dict:
1479
+ output = (reshaped_logits,) + (output,)
1480
+ return ((loss,) + output) if loss is not None else output
1481
+
1482
+ return MultipleChoiceModelOutput(
1483
+ loss=loss,
1484
+ logits=reshaped_logits,
1485
+ hidden_states=None,
1486
+ attentions=None,
1487
+ )
1488
+
1489
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1490
+ """Returns the number of parameters in the model.
1491
+
1492
+ Args:
1493
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1494
+ trainable: only count trainable parameters.
1495
+ """
1496
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1497
+ params += _count_parameters(self.head, trainable)
1498
+ params += _count_parameters(self.classifier, trainable)
1499
+ return params
1500
+
1501
+
1502
+ def init_model_from_pretrained(
1503
+ pretrained_model: FlexBertModel,
1504
+ new_model: FlexBertModel,
1505
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
1506
+ ):
1507
+ """
1508
+ Initialize the new model from the pretrained model.
1509
+
1510
+ This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
1511
+ The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
1512
+
1513
+ Args:
1514
+ pretrained_model (FlexBertModel): The smaller, pre-trained model
1515
+ new_model (FlexBertModel): The larger model to be initialized
1516
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
1517
+
1518
+ This function assumes that the new_model has more layers and a larger hidden size
1519
+ than the pretrained_model, but the same vocabulary size.
1520
+ """
1521
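+ # Illustrative use (assuming both models are already constructed):
+ #   init_model_from_pretrained(small_model, large_model, mode=TileMode.tile_weights_from_middle)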
+
1522
+ # Tile embeddings
1523
+ assert isinstance(
1524
+ new_model.embeddings, type(pretrained_model.embeddings)
1525
+ ), f"Pretrained and new_model layers must be the same type, got {type(new_model.embeddings)} and {type(pretrained_model.embeddings)}"
1526
+ assert isinstance(
1527
+ new_model.embeddings,
1528
+ (FlexBertAbsoluteEmbeddings, FlexBertSansPositionEmbeddings, FlexBertCompiledSansPositionEmbeddings),
1529
+ ), f"Unsupported embedding layer type: {type(new_model.embeddings)}"
1530
+
1531
+ tile_embedding(pretrained_model.embeddings.tok_embeddings, new_model.embeddings.tok_embeddings, mode=mode)
1532
+ if isinstance(pretrained_model.embeddings, FlexBertAbsoluteEmbeddings):
1533
+ tile_embedding(pretrained_model.embeddings.pos_embeddings, new_model.embeddings.pos_embeddings, mode=mode)
1534
+
1535
+ if hasattr(pretrained_model.embeddings, "norm"):
1536
+ tile_norm(pretrained_model.embeddings.norm, new_model.embeddings.norm, mode=mode)
1537
+
1538
+ # Tile encoder layers
1539
+ assert isinstance(
1540
+ pretrained_model.encoder, (FlexBertUnpadEncoder, FlexBertPaddedEncoder)
1541
+ ), f"Unsupported encoder layer type: {type(pretrained_model.encoder)}"
1542
+ assert isinstance(
1543
+ new_model.encoder, type(pretrained_model.encoder)
1544
+ ), f"Pretrained and new_model encoder layers must be the same type, got {type(new_model.encoder)} and {type(pretrained_model.encoder)}"
1545
+
1546
+ # Calculate the layer mapping
1547
+ pretrained_layers = len(pretrained_model.encoder.layers)
1548
+ new_layers = len(new_model.encoder.layers)
1549
+ layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
1550
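+ # For example, growing 6 pretrained layers into 8 new layers gives
+ # layer_mapping = [round(i * 6 / 8) for i in range(8)] == [0, 1, 2, 2, 3, 4, 4, 5],
+ # so some pretrained layers initialize more than one new layer.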
+
1551
+ # Initialize layers
1552
+ for new_model_idx, pretrained_idx in enumerate(layer_mapping):
1553
+ new_model_layer = new_model.encoder.layers[new_model_idx]
1554
+ pretrained_layer = pretrained_model.encoder.layers[pretrained_idx]
1555
+
1556
+ # first tile the PreNorm/PostNorm layers
1557
+ assert isinstance(
1558
+ new_model_layer, type(pretrained_layer)
1559
+ ), f"Pretrained and new_model prenorm/postnorm layers must be the same type, got {type(new_model_layer)} and {type(pretrained_layer)}"
1560
+ assert isinstance(
1561
+ new_model_layer,
1562
+ (
1563
+ FlexBertUnpadPreNormLayer,
1564
+ FlexBertCompileUnpadPreNormLayer,
1565
+ FlexBertUnpadParallelPreNormLayer,
1566
+ FlexBertUnpadPostNormLayer,
1567
+ FlexBertPaddedPreNormLayer,
1568
+ FlexBertPaddedParallelPreNormLayer,
1569
+ FlexBertPaddedPostNormLayer,
1570
+ ),
1571
+ ), f"Unsupported prenorm/postnorm layer type: {type(new_model_layer)}"
1572
+
1573
+ # First tile the normalization layers
1574
+ if hasattr(pretrained_layer, "attn_norm"):
1575
+ tile_norm(pretrained_layer.attn_norm, new_model_layer.attn_norm, mode=mode)
1576
+ if hasattr(pretrained_layer, "norm"):
1577
+ tile_norm(pretrained_layer.norm, new_model_layer.norm, mode=mode)
1578
+ if hasattr(pretrained_layer, "mlp_norm"):
1579
+ tile_norm(pretrained_layer.mlp_norm, new_model_layer.mlp_norm, mode=mode)
1580
+
1581
+ # Then tile the attention & mlp layers
1582
+ assert isinstance(
1583
+ new_model_layer.attn, type(pretrained_layer.attn)
1584
+ ), f"Pretrained and new_model attention layers must be the same type, got {type(new_model_layer.attn)} and {type(pretrained_layer.attn)}"
1585
+
1586
+ # first try the parallel attention layers
1587
+ if isinstance(pretrained_layer, (FlexBertUnpadParallelPreNormLayer, FlexBertPaddedParallelPreNormLayer)):
1588
+ assert isinstance(
1589
+ pretrained_layer.attn,
1590
+ (
1591
+ FlexBertUnpadParallelAttention,
1592
+ FlexBertPaddedParallelAttention,
1593
+ FlexBertUnpadRopeParallelAttention,
1594
+ FlexBertPaddedRopeParallelAttention,
1595
+ ),
1596
+ ), f"Parallel prenorm layer must have parallel attention layer: {type(pretrained_layer.attn)}"
1597
+ if not isinstance(pretrained_layer.mlp, FlexBertParallelGLU):
1598
+ raise ValueError(f"Parallel prenorm layer must have parallel MLP layer: {type(pretrained_layer.mlp)}")
1599
+ tile_linear(
1600
+ pretrained_layer.Wqkvff,
1601
+ new_model_layer.Wqkvff,
1602
+ linear_type=TileLinear.wqkvff,
1603
+ mode=mode,
1604
+ pretrained_attn_size=pretrained_layer.attn_size,
1605
+ pretrained_mlp_size=pretrained_layer.mlp_size,
1606
+ new_attn_size=new_model_layer.attn_size,
1607
+ new_mlp_size=new_model_layer.mlp_size,
1608
+ wqkvff_is_glu=True,
1609
+ )
1610
+
1611
+ # then try the fused attention layers
1612
+ elif isinstance(
1613
+ pretrained_layer.attn,
1614
+ (
1615
+ FlexBertUnpadAttention,
1616
+ FlexBertPaddedAttention,
1617
+ FlexBertUnpadRopeAttention,
1618
+ FlexBertPaddedRopeAttention,
1619
+ ),
1620
+ ):
1621
+ tile_linear(pretrained_layer.attn.Wqkv, new_model_layer.attn.Wqkv, linear_type=TileLinear.wqkv, mode=mode)
1622
+ else:
1623
+ raise ValueError(f"Unsupported attention layer type: {type(pretrained_layer.attn)}")
1624
+
1625
+ # finally, tile the attention output layer
1626
+ tile_linear(pretrained_layer.attn.Wo, new_model_layer.attn.Wo, linear_type=TileLinear.default, mode=mode)
1627
+
1628
+ # tile the mlp layer if the model is not using parallel attention layers
1629
+ if not isinstance(pretrained_layer.mlp, (FlexBertMLP, FlexBertGLU, FlexBertParallelGLU)):
1630
+ raise ValueError(f"Unsupported MLP layer type: {type(pretrained_layer.mlp)}")
1631
+ assert isinstance(
1632
+ new_model_layer.mlp, type(pretrained_layer.mlp)
1633
+ ), f"Pretrained and new_model mlp layers must be the same type, got {type(new_model_layer.mlp)} and {type(pretrained_layer.mlp)}"
1634
+
1635
+ # already tiled the parallel glu layer if it exists, so only need to handle mlp & glu Wi
1636
+ if isinstance(pretrained_layer.mlp, FlexBertGLU):
1637
+ tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.glu, mode=mode)
1638
+ elif isinstance(pretrained_layer.mlp, FlexBertMLP):
1639
+ tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.default, mode=mode)
1640
+ # tile the output for both ParallelGLU and MLP/GLU
1641
+ tile_linear(pretrained_layer.mlp.Wo, new_model_layer.mlp.Wo, linear_type=TileLinear.default, mode=mode)
1642
+
1643
+
1644
+ def init_mlm_model_from_pretrained(
1645
+ config: FlexBertConfig,
1646
+ pretrained_model: FlexBertForMaskedLM,
1647
+ new_model: FlexBertForMaskedLM,
1648
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
1649
+ ):
1650
+ """
1651
+ Initialize the new model from the pretrained model.
1652
+
1653
+ This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
1654
+ The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
1655
+
1656
+ Args:
1657
+ config (FlexBertConfig): The configuration of the new_model
1658
+ pretrained_model (FlexBertForMaskedLM): The smaller, pre-trained model
1659
+ new_model (FlexBertForMaskedLM): The larger model to be initialized from the pretrained model
1660
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
1661
+
1662
+ This function assumes that the new_model has more layers and a larger hidden size
1663
+ than the pretrained_model, but the same vocabulary size.
1664
+ """
1665
+ init_model_from_pretrained(pretrained_model.bert, new_model.bert, mode=mode)
1666
+
1667
+ # TODO: uncomment this when the repo is turned into a pip installable package
1668
+ # if not isinstance(pretrained_model.head, FlexBertPredictionHead):
1669
+ # raise ValueError(f"Pretrained model must have a prediction head: {type(pretrained_model.head)}")
1670
+ # if not isinstance(new_model.head, FlexBertPredictionHead):
1671
+ # raise ValueError(f"New model must have a prediction head: {type(new_model.head)}")
1672
+
1673
+ # tile the prediction head
1674
+ tile_linear(pretrained_model.head.dense, new_model.head.dense, linear_type=TileLinear.default, mode=mode)
1675
+ tile_norm(pretrained_model.head.norm, new_model.head.norm, mode=mode)
1676
+
1677
+ # setup weight tying
1678
+ if config.tie_word_embeddings:
1679
+ new_model.decoder.weight = new_model.bert.embeddings.tok_embeddings.weight
1680
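+ # With tied embeddings the decoder weight already points at the (tiled) token embedding
+ # matrix, so only the decoder bias still needs to be tiled below.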
+ tile_linear(
1681
+ pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
1682
+ )
1683
+ else:
1684
+ tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
normalization.py CHANGED
@@ -10,7 +10,7 @@ import torch
10
  import torch.nn as nn
11
  from torch.nn import init
12
 
13
- from configuration_bert import FlexBertConfig
14
 
15
  try:
16
  from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm
 
10
  import torch.nn as nn
11
  from torch.nn import init
12
 
13
+ from .configuration_bert import FlexBertConfig
14
 
15
  try:
16
  from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm
options.py CHANGED
@@ -1,9 +1,9 @@
1
- from normalization import NORM2CLS
2
- from embeddings import EBB2CLS
3
- from activation import ACT2CLS
4
- from attention import ATTN2CLS
5
- from mlp import MLP2CLS
6
- from layers import LAYER2CLS
7
 
8
 
9
  def print_layer_options():
 
1
+ from .normalization import NORM2CLS
2
+ from .embeddings import EBB2CLS
3
+ from .activation import ACT2CLS
4
+ from .attention import ATTN2CLS
5
+ from .mlp import MLP2CLS
6
+ from .layers import LAYER2CLS
7
 
8
 
9
  def print_layer_options():