sanchit-gandhi HF staff commited on
Commit
5f537ba
1 Parent(s): b8910de

Add training scripts and weights

Browse files
.gitattributes CHANGED
@@ -21,7 +21,6 @@
21
  *.pt filter=lfs diff=lfs merge=lfs -text
22
  *.pth filter=lfs diff=lfs merge=lfs -text
23
  *.rar filter=lfs diff=lfs merge=lfs -text
24
- *.safetensors filter=lfs diff=lfs merge=lfs -text
25
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.tar.* filter=lfs diff=lfs merge=lfs -text
27
  *.tflite filter=lfs diff=lfs merge=lfs -text
 
21
  *.pt filter=lfs diff=lfs merge=lfs -text
22
  *.pth filter=lfs diff=lfs merge=lfs -text
23
  *.rar filter=lfs diff=lfs merge=lfs -text
 
24
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.tar.* filter=lfs diff=lfs merge=lfs -text
26
  *.tflite filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - esb
6
+ datasets:
7
+ - esb/datasets
8
+ - facebook/voxpopuli
9
+ ---
10
+
11
+ To reproduce this run, first call `get_ctc_tokenizer.py` to train the CTC tokenizer and then execute the following command to train the CTC system:
12
+ ```python
13
+ #!/usr/bin/env bash
14
+ python run_flax_speech_recognition_ctc.py \
15
+ --model_name_or_path="esb/wav2vec2-ctc-pretrained" \
16
+ --tokenizer_name="wav2vec2-ctc-voxpopuli-tokenizer" \
17
+ --dataset_name="esb/datasets" \
18
+ --dataset_config_name="voxpopuli" \
19
+ --output_dir="./" \
20
+ --wandb_project="wav2vec2-ctc" \
21
+ --wandb_name="wav2vec2-ctc-voxpopuli" \
22
+ --max_steps="50000" \
23
+ --save_steps="10000" \
24
+ --eval_steps="10000" \
25
+ --learning_rate="3e-4" \
26
+ --logging_steps="25" \
27
+ --warmup_steps="5000" \
28
+ --preprocessing_num_workers="1" \
29
+ --per_device_eval_batch_size="1" \
30
+ --do_train \
31
+ --do_eval \
32
+ --do_predict \
33
+ --overwrite_output_dir \
34
+ --gradient_checkpointing \
35
+ --freeze_feature_encoder \
36
+ --push_to_hub \
37
+ --use_auth_token
38
+ ```
config.json ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.1,
3
+ "adapter_kernel_size": 3,
4
+ "adapter_stride": 2,
5
+ "add_adapter": false,
6
+ "apply_spec_augment": true,
7
+ "architectures": [
8
+ "Wav2Vec2ForCTC"
9
+ ],
10
+ "attention_dropout": 0.1,
11
+ "bos_token_id": 1,
12
+ "classifier_proj_size": 256,
13
+ "codevector_dim": 768,
14
+ "contrastive_logits_temperature": 0.1,
15
+ "conv_bias": true,
16
+ "conv_dim": [
17
+ 512,
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512
24
+ ],
25
+ "conv_kernel": [
26
+ 10,
27
+ 3,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 2,
32
+ 2
33
+ ],
34
+ "conv_stride": [
35
+ 5,
36
+ 2,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2
42
+ ],
43
+ "ctc_loss_reduction": "sum",
44
+ "ctc_zero_infinity": false,
45
+ "diversity_loss_weight": 0.1,
46
+ "do_stable_layer_norm": true,
47
+ "eos_token_id": 2,
48
+ "feat_extract_activation": "gelu",
49
+ "feat_extract_dropout": 0.0,
50
+ "feat_extract_norm": "layer",
51
+ "feat_proj_dropout": 0.0,
52
+ "feat_quantizer_dropout": 0.0,
53
+ "final_dropout": 0.0,
54
+ "fuse_matmuls": false,
55
+ "gradient_checkpointing": true,
56
+ "hidden_act": "gelu",
57
+ "hidden_dropout": 0.1,
58
+ "hidden_dropout_prob": 0.1,
59
+ "hidden_size": 1024,
60
+ "initializer_range": 0.02,
61
+ "intermediate_size": 4096,
62
+ "layer_norm_eps": 1e-05,
63
+ "layerdrop": 0.0,
64
+ "mask_feature_length": 10,
65
+ "mask_feature_min_masks": 0,
66
+ "mask_feature_prob": 0.0,
67
+ "mask_time_length": 10,
68
+ "mask_time_min_masks": 2,
69
+ "mask_time_prob": 0.1,
70
+ "model_type": "wav2vec2",
71
+ "num_adapter_layers": 3,
72
+ "num_attention_heads": 16,
73
+ "num_codevector_groups": 2,
74
+ "num_codevectors_per_group": 320,
75
+ "num_conv_pos_embedding_groups": 16,
76
+ "num_conv_pos_embeddings": 128,
77
+ "num_feat_extract_layers": 7,
78
+ "num_hidden_layers": 24,
79
+ "num_negatives": 100,
80
+ "output_hidden_size": 1024,
81
+ "pad_token_id": 0,
82
+ "proj_codevector_dim": 768,
83
+ "tdnn_dilation": [
84
+ 1,
85
+ 2,
86
+ 3,
87
+ 1,
88
+ 1
89
+ ],
90
+ "tdnn_dim": [
91
+ 512,
92
+ 512,
93
+ 512,
94
+ 512,
95
+ 1500
96
+ ],
97
+ "tdnn_kernel": [
98
+ 5,
99
+ 3,
100
+ 3,
101
+ 1,
102
+ 1
103
+ ],
104
+ "transformers_version": "4.18.0.dev0",
105
+ "use_scan": true,
106
+ "use_weighted_layer_sum": false,
107
+ "vocab_size": 35,
108
+ "xvector_output_dim": 512
109
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ffca1e1c2f8cbf19b33e85549fa8a64fba702365e266900d2e03e38e72ab305
3
+ size 1261900450
get_ctc_tokenizer.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from datasets import load_dataset
3
+ from collections import Counter
4
+ import json
5
+ import os
6
+ import tempfile
7
+ from transformers import Wav2Vec2CTCTokenizer
8
+
9
+ # which dataset
10
+ dataset_name = "voxpopuli"
11
+ # which split -> we should only use train to train our tokenizer
12
+ split = "train"
13
+ # in case the dataset requires access
14
+ use_auth_token = True
15
+ # name of tok to upload to the Hub
16
+ tokenizer_name = f"wav2vec2-ctc-{dataset_name}-tokenizer"
17
+
18
+ # FIX the cutoff freq for all datasets -> an entirely dataset-agnostic approach
19
+ cutoff_freq = 0.01
20
+
21
+ dataset = load_dataset(
22
+ "esb/datasets",
23
+ dataset_name,
24
+ split=split,
25
+ use_auth_token=use_auth_token,
26
+ )
27
+
28
+ # remove all data that is unnecessary to save RAM
29
+ dataset = dataset.remove_columns(list(set(dataset.column_names) - {"text"}))
30
+
31
+ # define function to see stats about letters and to create vocab
32
+ def create_vocabulary_from_data(dataset, word_delimiter_token="|", cutoff_freq=0.0):
33
+ def extract_all_chars(batch):
34
+ all_text = " ".join(batch["text"])
35
+
36
+ count_chars_dict = Counter(list(all_text))
37
+ # sort by freq
38
+ count_chars_dict = sorted(count_chars_dict.items(), key=lambda item: (-item[1], item[0]))
39
+ # retrieve dict, freq
40
+ vocab, freqs = zip(*count_chars_dict)
41
+
42
+ return {"vocab": list(vocab), "freqs": list(freqs)}
43
+
44
+ dataset = dataset.map(
45
+ extract_all_chars,
46
+ batched=True,
47
+ batch_size=-1,
48
+ remove_columns=dataset.column_names,
49
+ )
50
+
51
+ vocab, freqs = dataset["vocab"], dataset["freqs"]
52
+ total_num_chars = sum(freqs)
53
+ chars_to_remove = []
54
+
55
+ print("Character Occurences")
56
+ print(f"Total characters in dataset: {total_num_chars}")
57
+ print(50 * "-")
58
+ print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
59
+ print(50 * "-")
60
+ for char, freq in zip(vocab, freqs):
61
+ freq_in_percent = freq / total_num_chars * 100
62
+ print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")
63
+ if freq_in_percent < cutoff_freq:
64
+ chars_to_remove.append(char)
65
+ print(50 * "-")
66
+
67
+ vocab = list(set(vocab) - set(chars_to_remove))
68
+
69
+ # Wav2Vec2CTC Tokenizers always have those as the first tokens (important for CTC)
70
+ vocab = ["<pad>", "<s>", "</s>", "<unk>"] + vocab
71
+
72
+ # create json dict
73
+ vocab_dict = {v: k for k, v in enumerate(list(vocab))}
74
+
75
+ # replace white space with delimiter token
76
+ if word_delimiter_token is not None:
77
+ vocab_dict[word_delimiter_token] = vocab_dict[" "]
78
+ del vocab_dict[" "]
79
+
80
+ return vocab_dict
81
+
82
+ # Note that the functions accepts the following important args
83
+ # --cutoff_freq
84
+ # => This is very important! Lots of datasets will contain "wrong" characters in the training set, e.g.
85
+ # characters that just occur a couple of times.
86
+ # By default, the CTC vocab creation would just add them to the vocab even if their occurance is neglectible # compared to the "super frequent" letters. We can see such characters as "errors" or irrelevant in the
87
+ # dataset, so that we should delete them from the vocab. During training, they would then just be classified
88
+ # unknown <unk> tokens which the model can handle.
89
+ # In this script, we deploy a mechanism to remove all chars whose freq in % is below a certain threshold.
90
+ # We FIX this threshold for all datasets (i.e. dataset-agnostic)
91
+
92
+ vocab_dict = create_vocabulary_from_data(dataset, cutoff_freq=cutoff_freq)
93
+
94
+ # save vocab dict to be loaded into tokenizer
95
+ with tempfile.TemporaryDirectory() as tmp:
96
+ with open(os.path.join(tmp, "vocab.json"), "w") as file:
97
+ json.dump(vocab_dict, file)
98
+
99
+ tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp)
100
+
101
+ # push tokenizer to the Hub
102
+ tokenizer.push_to_hub(tokenizer_name)
models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from models.configuration_bart import BartConfig
2
+ from models.configuration_wav2vec2 import Wav2Vec2Config
3
+ from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
4
+ from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
5
+ from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
6
+ from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
models/configuration_bart.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ BART model configuration"""
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
26
+ # See all BART models at https://huggingface.co/models?filter=bart
27
+ }
28
+
29
+
30
+ class BartConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
33
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34
+ defaults will yield a similar configuration to that of the BART
35
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 50265):
43
+ Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
45
+ d_model (`int`, *optional*, defaults to 1024):
46
+ Dimensionality of the layers and the pooler layer.
47
+ encoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of encoder layers.
49
+ decoder_layers (`int`, *optional*, defaults to 12):
50
+ Number of decoder layers.
51
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
54
+ Number of attention heads for each attention layer in the Transformer decoder.
55
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
58
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
59
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
60
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
61
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
62
+ dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
64
+ attention_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for the attention probabilities.
66
+ activation_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for activations inside the fully connected layer.
68
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for classifier.
70
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
71
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
72
+ just in case (e.g., 512 or 1024 or 2048).
73
+ init_std (`float`, *optional*, defaults to 0.02):
74
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75
+ encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
76
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
77
+ for more details.
78
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
79
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
80
+ for more details.
81
+ scale_embedding (`bool`, *optional*, defaults to `False`):
82
+ Scale embeddings by diving by sqrt(d_model).
83
+ use_cache (`bool`, *optional*, defaults to `True`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models).
85
+ num_labels: (`int`, *optional*, defaults to 3):
86
+ The number of labels to use in [`BartForSequenceClassification`].
87
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
88
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
89
+ `eos_token_id`.
90
+ use_scan (`bool`, *optional*, defaults to `False`):
91
+ Whether or not to use nn.scan in the Flax Bart attention layers.
92
+
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import BartModel, BartConfig
97
+
98
+ >>> # Initializing a BART facebook/bart-large style configuration
99
+ >>> configuration = BartConfig()
100
+
101
+ >>> # Initializing a model from the facebook/bart-large style configuration
102
+ >>> model = BartModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+ model_type = "bart"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=50265,
114
+ max_position_embeddings=1024,
115
+ encoder_layers=12,
116
+ encoder_ffn_dim=4096,
117
+ encoder_attention_heads=16,
118
+ decoder_layers=12,
119
+ decoder_ffn_dim=4096,
120
+ decoder_attention_heads=16,
121
+ encoder_layerdrop=0.0,
122
+ decoder_layerdrop=0.0,
123
+ activation_function="gelu",
124
+ d_model=1024,
125
+ dropout=0.1,
126
+ attention_dropout=0.0,
127
+ activation_dropout=0.0,
128
+ init_std=0.02,
129
+ classifier_dropout=0.0,
130
+ scale_embedding=False,
131
+ use_cache=True,
132
+ use_scan=False,
133
+ fuse_matmuls=False,
134
+ num_labels=3,
135
+ pad_token_id=1,
136
+ bos_token_id=0,
137
+ eos_token_id=2,
138
+ is_encoder_decoder=True,
139
+ decoder_start_token_id=2,
140
+ forced_eos_token_id=2,
141
+ **kwargs
142
+ ):
143
+ self.vocab_size = vocab_size
144
+ self.max_position_embeddings = max_position_embeddings
145
+ self.d_model = d_model
146
+ self.encoder_ffn_dim = encoder_ffn_dim
147
+ self.encoder_layers = encoder_layers
148
+ self.encoder_attention_heads = encoder_attention_heads
149
+ self.decoder_ffn_dim = decoder_ffn_dim
150
+ self.decoder_layers = decoder_layers
151
+ self.decoder_attention_heads = decoder_attention_heads
152
+ self.dropout = dropout
153
+ self.attention_dropout = attention_dropout
154
+ self.activation_dropout = activation_dropout
155
+ self.activation_function = activation_function
156
+ self.init_std = init_std
157
+ self.encoder_layerdrop = encoder_layerdrop
158
+ self.decoder_layerdrop = decoder_layerdrop
159
+ self.classifier_dropout = classifier_dropout
160
+ self.use_cache = use_cache
161
+ self.use_scan = use_scan
162
+ self.fuse_matmuls = fuse_matmuls
163
+ self.num_hidden_layers = encoder_layers
164
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
165
+
166
+ super().__init__(
167
+ num_labels=num_labels,
168
+ pad_token_id=pad_token_id,
169
+ bos_token_id=bos_token_id,
170
+ eos_token_id=eos_token_id,
171
+ is_encoder_decoder=is_encoder_decoder,
172
+ decoder_start_token_id=decoder_start_token_id,
173
+ forced_eos_token_id=forced_eos_token_id,
174
+ **kwargs,
175
+ )
176
+
177
+ # ensure backward compatibility for BART CNN models
178
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
179
+ self.forced_bos_token_id = self.bos_token_id
180
+ warnings.warn(
181
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
182
+ "The config can simply be saved and uploaded again to be fixed."
183
+ )
models/configuration_speech_encoder_decoder.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+ from models.configuration_wav2vec2 import Wav2Vec2Config
22
+ from models.configuration_bart import BartConfig
23
+ from transformers import AutoConfig
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class SpeechEncoderDecoderConfig(PretrainedConfig):
30
+ r"""
31
+ [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
32
+ [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
33
+ arguments, defining the encoder and decoder configs.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+ Args:
39
+ kwargs (*optional*):
40
+ Dictionary of keyword arguments. Notably:
41
+
42
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
43
+ the encoder config.
44
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
45
+ the decoder config.
46
+
47
+ Examples:
48
+
49
+ ```python
50
+ >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
51
+
52
+ >>> # Initializing a Wav2Vec2 & BERT style configuration
53
+ >>> config_encoder = Wav2Vec2Config()
54
+ >>> config_decoder = BertConfig()
55
+
56
+ >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
57
+
58
+ >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
59
+ >>> model = SpeechEncoderDecoderModel(config=config)
60
+
61
+ >>> # Accessing the model configuration
62
+ >>> config_encoder = model.config.encoder
63
+ >>> config_decoder = model.config.decoder
64
+ >>> # set decoder config to causal lm
65
+ >>> config_decoder.is_decoder = True
66
+ >>> config_decoder.add_cross_attention = True
67
+
68
+ >>> # Saving the model, including its configuration
69
+ >>> model.save_pretrained("my-model")
70
+
71
+ >>> # loading model and config from pretrained folder
72
+ >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
73
+ >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
74
+ ```"""
75
+ model_type = "speech-encoder-decoder"
76
+ is_composition = True
77
+
78
+ def __init__(self, **kwargs):
79
+ super().__init__(**kwargs)
80
+ if "encoder" not in kwargs or "decoder" not in kwargs:
81
+ raise ValueError(
82
+ f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
83
+ )
84
+
85
+ encoder_config = kwargs.pop("encoder")
86
+ decoder_config = kwargs.pop("decoder")
87
+
88
+ # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
89
+ self.encoder = Wav2Vec2Config(**encoder_config)
90
+ self.decoder = BartConfig(**decoder_config)
91
+ self.is_encoder_decoder = True
92
+
93
+ @classmethod
94
+ def from_encoder_decoder_configs(
95
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
96
+ ) -> PretrainedConfig:
97
+ r"""
98
+ Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
99
+ configuration and decoder model configuration.
100
+
101
+ Returns:
102
+ [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
103
+ """
104
+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
105
+ decoder_config.is_decoder = True
106
+ decoder_config.add_cross_attention = True
107
+
108
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
109
+
110
+ def to_dict(self):
111
+ """
112
+ Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
113
+
114
+ Returns:
115
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
116
+ """
117
+ output = copy.deepcopy(self.__dict__)
118
+ output["encoder"] = self.encoder.to_dict()
119
+ output["decoder"] = self.decoder.to_dict()
120
+ output["model_type"] = self.__class__.model_type
121
+ return output
models/configuration_wav2vec2.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Wav2Vec2 model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
28
+ # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
29
+ }
30
+
31
+
32
+ class Wav2Vec2Config(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
35
+ Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
36
+ with the defaults will yield a similar configuration to that of the Wav2Vec2
37
+ [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 32):
45
+ Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
46
+ the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
47
+ model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
48
+ method of [`Wav2Vec2Model`].
49
+ hidden_size (`int`, *optional*, defaults to 768):
50
+ Dimensionality of the encoder layers and the pooler layer.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (`int`, *optional*, defaults to 3072):
56
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
60
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention probabilities.
64
+ final_dropout (`float`, *optional*, defaults to 0.1):
65
+ The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
+ The epsilon used by the layer normalization layers.
70
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
+ convolutional layers.
74
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout probability for output of the feature encoder.
76
+ feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
77
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout probabilitiy for quantized feature encoder states.
81
+ conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84
+ conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
87
+ conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
88
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90
+ *conv_dim*.
91
+ conv_bias (`bool`, *optional*, defaults to `False`):
92
+ Whether the 1D convolutional layers have a bias.
93
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95
+ embeddings layer.
96
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97
+ Number of groups of 1D convolutional positional embeddings layer.
98
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99
+ Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101
+ False` corresponds to applying layer norm after the attention layer.
102
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
103
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105
+ Recognition](https://arxiv.org/abs/1904.08779).
106
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
107
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108
+ procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
109
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
110
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112
+ mask_time_length (`int`, *optional*, defaults to 10):
113
+ Length of vector span along the time axis.
114
+ mask_time_min_masks (`int`, *optional*, defaults to 2),:
115
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
116
+ irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
117
+ mask_time_min_masks''
118
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
119
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120
+ masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
121
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
122
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124
+ True`.
125
+ mask_feature_length (`int`, *optional*, defaults to 10):
126
+ Length of vector span along the feature axis.
127
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
128
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129
+ step, irrespectively of `mask_feature_prob`. Only relevant if
130
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
131
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
132
+ Number of entries in each quantization codebook (group).
133
+ num_codevector_groups (`int`, *optional*, defaults to 2):
134
+ Number of codevector groups for product codevector quantization.
135
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136
+ The temperature *kappa* in the contrastive loss.
137
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138
+ The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
139
+ num_negatives (`int`, *optional*, defaults to 100):
140
+ Number of negative samples for the contrastive loss.
141
+ codevector_dim (`int`, *optional*, defaults to 256):
142
+ Dimensionality of the quantized feature vectors.
143
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
144
+ Dimensionality of the final projection of both the quantized and the transformer features.
145
+ diversity_loss_weight (`int`, *optional*, defaults to 0.1):
146
+ The weight of the codebook diversity loss component.
147
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149
+ instance of [`Wav2Vec2ForCTC`].
150
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153
+ of [`Wav2Vec2ForCTC`].
154
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156
+ instance of [`Wav2Vec2ForSequenceClassification`].
157
+ classifier_proj_size (`int`, *optional*, defaults to 256):
158
+ Dimensionality of the projection before token mean-pooling for classification.
159
+ tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162
+ tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165
+ tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168
+ xvector_output_dim (`int`, *optional*, defaults to 512):
169
+ Dimensionality of the *XVector* embedding vectors.
170
+ add_adapter (`bool`, *optional*, defaults to `False`):
171
+ Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172
+ warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
174
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175
+ adapter_stride (`int`, *optional*, defaults to 2):
176
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177
+ num_adapter_layers (`int`, *optional*, defaults to 3):
178
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179
+ True`.
180
+ output_hidden_size (`int`, *optional*):
181
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
182
+ if `add_adapter is True`.
183
+ use_scan (`bool`, *optional*, defaults to `False`):
184
+ Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
185
+
186
+ Example:
187
+
188
+ ```python
189
+ >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
190
+
191
+ >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
192
+ >>> configuration = Wav2Vec2Config()
193
+
194
+ >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
195
+ >>> model = Wav2Vec2Model(configuration)
196
+
197
+ >>> # Accessing the model configuration
198
+ >>> configuration = model.config
199
+ ```"""
200
+ model_type = "wav2vec2"
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=32,
205
+ hidden_size=768,
206
+ num_hidden_layers=12,
207
+ num_attention_heads=12,
208
+ intermediate_size=3072,
209
+ hidden_act="gelu",
210
+ hidden_dropout=0.1,
211
+ activation_dropout=0.1,
212
+ attention_dropout=0.1,
213
+ feat_proj_dropout=0.0,
214
+ feat_quantizer_dropout=0.0,
215
+ final_dropout=0.1,
216
+ layerdrop=0.1,
217
+ initializer_range=0.02,
218
+ layer_norm_eps=1e-5,
219
+ feat_extract_norm="group",
220
+ feat_extract_activation="gelu",
221
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
222
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
223
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
224
+ conv_bias=False,
225
+ num_conv_pos_embeddings=128,
226
+ num_conv_pos_embedding_groups=16,
227
+ do_stable_layer_norm=False,
228
+ apply_spec_augment=True,
229
+ mask_time_prob=0.05,
230
+ mask_time_length=10,
231
+ mask_time_min_masks=2,
232
+ mask_feature_prob=0.0,
233
+ mask_feature_length=10,
234
+ mask_feature_min_masks=0,
235
+ num_codevectors_per_group=320,
236
+ num_codevector_groups=2,
237
+ contrastive_logits_temperature=0.1,
238
+ num_negatives=100,
239
+ codevector_dim=256,
240
+ proj_codevector_dim=256,
241
+ diversity_loss_weight=0.1,
242
+ ctc_loss_reduction="sum",
243
+ ctc_zero_infinity=False,
244
+ use_weighted_layer_sum=False,
245
+ classifier_proj_size=256,
246
+ tdnn_dim=(512, 512, 512, 512, 1500),
247
+ tdnn_kernel=(5, 3, 3, 1, 1),
248
+ tdnn_dilation=(1, 2, 3, 1, 1),
249
+ xvector_output_dim=512,
250
+ pad_token_id=0,
251
+ bos_token_id=1,
252
+ eos_token_id=2,
253
+ add_adapter=False,
254
+ adapter_kernel_size=3,
255
+ adapter_stride=2,
256
+ num_adapter_layers=3,
257
+ output_hidden_size=None,
258
+ use_scan=False,
259
+ fuse_matmuls=False,
260
+ **kwargs
261
+ ):
262
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
263
+ self.hidden_size = hidden_size
264
+ self.feat_extract_norm = feat_extract_norm
265
+ self.feat_extract_activation = feat_extract_activation
266
+ self.conv_dim = list(conv_dim)
267
+ self.conv_stride = list(conv_stride)
268
+ self.conv_kernel = list(conv_kernel)
269
+ self.conv_bias = conv_bias
270
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
271
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
272
+ self.num_feat_extract_layers = len(self.conv_dim)
273
+ self.num_hidden_layers = num_hidden_layers
274
+ self.intermediate_size = intermediate_size
275
+ self.hidden_act = hidden_act
276
+ self.num_attention_heads = num_attention_heads
277
+ self.hidden_dropout = hidden_dropout
278
+ self.attention_dropout = attention_dropout
279
+ self.activation_dropout = activation_dropout
280
+ self.feat_proj_dropout = feat_proj_dropout
281
+ self.final_dropout = final_dropout
282
+ self.layerdrop = layerdrop
283
+ self.layer_norm_eps = layer_norm_eps
284
+ self.initializer_range = initializer_range
285
+ self.vocab_size = vocab_size
286
+ self.do_stable_layer_norm = do_stable_layer_norm
287
+ self.use_weighted_layer_sum = use_weighted_layer_sum
288
+ self.use_scan = use_scan
289
+ self.fuse_matmuls = fuse_matmuls
290
+
291
+ if (
292
+ (len(self.conv_stride) != self.num_feat_extract_layers)
293
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
294
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
295
+ ):
296
+ raise ValueError(
297
+ "Configuration for convolutional layers is incorrect. "
298
+ "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
299
+ f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
300
+ f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
301
+ )
302
+
303
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
304
+ self.apply_spec_augment = apply_spec_augment
305
+ self.mask_time_prob = mask_time_prob
306
+ self.mask_time_length = mask_time_length
307
+ self.mask_time_min_masks = mask_time_min_masks
308
+ self.mask_feature_prob = mask_feature_prob
309
+ self.mask_feature_length = mask_feature_length
310
+ self.mask_feature_min_masks = mask_feature_min_masks
311
+
312
+ # parameters for pretraining with codevector quantized representations
313
+ self.num_codevectors_per_group = num_codevectors_per_group
314
+ self.num_codevector_groups = num_codevector_groups
315
+ self.contrastive_logits_temperature = contrastive_logits_temperature
316
+ self.feat_quantizer_dropout = feat_quantizer_dropout
317
+ self.num_negatives = num_negatives
318
+ self.codevector_dim = codevector_dim
319
+ self.proj_codevector_dim = proj_codevector_dim
320
+ self.diversity_loss_weight = diversity_loss_weight
321
+
322
+ # ctc loss
323
+ self.ctc_loss_reduction = ctc_loss_reduction
324
+ self.ctc_zero_infinity = ctc_zero_infinity
325
+
326
+ # adapter
327
+ self.add_adapter = add_adapter
328
+ self.adapter_kernel_size = adapter_kernel_size
329
+ self.adapter_stride = adapter_stride
330
+ self.num_adapter_layers = num_adapter_layers
331
+ self.output_hidden_size = output_hidden_size or hidden_size
332
+
333
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
334
+ self.classifier_proj_size = classifier_proj_size
335
+
336
+ # XVector-specific parameters. Feel free to ignore for other classes.
337
+ self.tdnn_dim = list(tdnn_dim)
338
+ self.tdnn_kernel = list(tdnn_kernel)
339
+ self.tdnn_dilation = list(tdnn_dilation)
340
+ self.xvector_output_dim = xvector_output_dim
341
+
342
+ @property
343
+ def inputs_to_logits_ratio(self):
344
+ return functools.reduce(operator.mul, self.conv_stride, 1)
models/modeling_flax_bart.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Bart model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen import partitioning as nn_partitioning
30
+ from flax.linen.attention import dot_product_attention_weights
31
+ from jax import lax
32
+ from jax.random import PRNGKey
33
+
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ )
38
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
39
+
40
+ from models import BartConfig
41
+
42
+
43
+ scan_with_axes = nn_partitioning.scan_with_axes
44
+ remat = nn_partitioning.remat
45
+
46
+
47
+ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
48
+ """
49
+ Shift input ids one token to the right.
50
+ """
51
+ shifted_input_ids = np.zeros_like(input_ids)
52
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
53
+ shifted_input_ids[:, 0] = decoder_start_token_id
54
+
55
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
56
+ return shifted_input_ids
57
+
58
+
59
+ class FlaxBartAttention(nn.Module):
60
+ config: BartConfig
61
+ embed_dim: int
62
+ num_heads: int
63
+ dropout: float = 0.0
64
+ causal: bool = False
65
+ bias: bool = True
66
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
67
+
68
+ def setup(self) -> None:
69
+ self.head_dim = self.embed_dim // self.num_heads
70
+ if self.head_dim * self.num_heads != self.embed_dim:
71
+ raise ValueError(
72
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
73
+ f" and `num_heads`: {self.num_heads})."
74
+ )
75
+
76
+ dense = partial(
77
+ nn.Dense,
78
+ self.embed_dim,
79
+ use_bias=self.bias,
80
+ dtype=self.dtype,
81
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
82
+ )
83
+
84
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
85
+
86
+ self.fused_proj = nn.Dense(
87
+ self.embed_dim * 3,
88
+ use_bias=self.bias,
89
+ dtype=self.dtype,
90
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
91
+ )
92
+
93
+ self.fused_key_value = nn.Dense(
94
+ self.embed_dim * 2,
95
+ use_bias=self.bias,
96
+ dtype=self.dtype,
97
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
98
+ )
99
+
100
+ self.out_proj = dense()
101
+
102
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
103
+
104
+ if self.causal:
105
+ self.causal_mask = make_causal_mask(
106
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
107
+ )
108
+
109
+ def _split_heads(self, hidden_states):
110
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
111
+
112
+ def _merge_heads(self, hidden_states):
113
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
114
+
115
+ @nn.compact
116
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
117
+ """
118
+ This function takes projected key, value states from a single input token and concatenates the states to cached
119
+ states from previous steps. This function is slighly adapted from the official Flax repository:
120
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
121
+ """
122
+ # detect if we're initializing by absence of existing cache data.
123
+ is_initialized = self.has_variable("cache", "cached_key")
124
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
125
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
126
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
127
+
128
+ if is_initialized:
129
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
130
+ # update key, value caches with our new 1d spatial slices
131
+ cur_index = cache_index.value
132
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
133
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
134
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
135
+ cached_key.value = key
136
+ cached_value.value = value
137
+ num_updated_cache_vectors = query.shape[1]
138
+ cache_index.value = cache_index.value + num_updated_cache_vectors
139
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
140
+ pad_mask = jnp.broadcast_to(
141
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
142
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
143
+ )
144
+ attention_mask = combine_masks(pad_mask, attention_mask)
145
+ return key, value, attention_mask
146
+
147
+ def __call__(
148
+ self,
149
+ hidden_states: jnp.ndarray,
150
+ key_value_states: Optional[jnp.ndarray] = None,
151
+ attention_mask: Optional[jnp.ndarray] = None,
152
+ init_cache: bool = False,
153
+ deterministic: bool = True,
154
+ ) -> Tuple[jnp.ndarray]:
155
+ """Input shape: Batch x Time x Channel"""
156
+
157
+ # if key_value_states are provided this layer is used as a cross-attention layer
158
+ # for the decoder
159
+ is_cross_attention = key_value_states is not None
160
+ batch_size = hidden_states.shape[0]
161
+
162
+ if self.config.fuse_matmuls:
163
+ # get key, value proj
164
+ if is_cross_attention:
165
+ # get query proj
166
+ query_states = self.q_proj(hidden_states)
167
+ # cross_attentions
168
+ attention_states = self.fused_key_value(key_value_states)
169
+ key_states, value_states = jnp.split(attention_states, 2, axis=-1)
170
+ else:
171
+ attention_states = self.fused_proj(hidden_states)
172
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
173
+
174
+ else:
175
+ # get query proj
176
+ query_states = self.q_proj(hidden_states)
177
+ # get key, value proj
178
+ if is_cross_attention:
179
+ # cross_attentions
180
+ key_states = self.k_proj(key_value_states)
181
+ value_states = self.v_proj(key_value_states)
182
+ else:
183
+ # self_attention
184
+ key_states = self.k_proj(hidden_states)
185
+ value_states = self.v_proj(hidden_states)
186
+
187
+ query_states = self._split_heads(query_states)
188
+ key_states = self._split_heads(key_states)
189
+ value_states = self._split_heads(value_states)
190
+
191
+ # handle cache prepare causal attention mask
192
+ if self.causal:
193
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
194
+ if self.has_variable("cache", "cached_key"):
195
+ mask_shift = self.variables["cache"]["cache_index"]
196
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
197
+ causal_mask = lax.dynamic_slice(
198
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
199
+ )
200
+ else:
201
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
202
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
203
+
204
+ # combine masks if needed
205
+ if attention_mask is not None and self.causal:
206
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
207
+ attention_mask = combine_masks(attention_mask, causal_mask)
208
+ elif self.causal:
209
+ attention_mask = causal_mask
210
+ elif attention_mask is not None:
211
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
212
+
213
+ # During fast autoregressive decoding, we feed one position at a time,
214
+ # and cache the keys and values step by step.
215
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
216
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
217
+ key_states, value_states, query_states, attention_mask
218
+ )
219
+
220
+ # Convert the boolean attention mask to an attention bias.
221
+ if attention_mask is not None:
222
+ # attention mask in the form of attention bias
223
+ attention_bias = lax.select(
224
+ attention_mask > 0,
225
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
226
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
227
+ )
228
+ else:
229
+ attention_bias = None
230
+
231
+ dropout_rng = None
232
+ if not deterministic and self.dropout > 0.0:
233
+ dropout_rng = self.make_rng("dropout")
234
+
235
+ attn_weights = dot_product_attention_weights(
236
+ query_states,
237
+ key_states,
238
+ bias=attention_bias,
239
+ dropout_rng=dropout_rng,
240
+ dropout_rate=self.dropout,
241
+ broadcast_dropout=True,
242
+ deterministic=deterministic,
243
+ dtype=self.dtype,
244
+ precision=None,
245
+ )
246
+
247
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
248
+ attn_output = self._merge_heads(attn_output)
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights
252
+
253
+
254
+ class FlaxBartDecoderLayer(nn.Module):
255
+ config: BartConfig
256
+ dtype: jnp.dtype = jnp.float32
257
+
258
+ def setup(self) -> None:
259
+ self.embed_dim = self.config.d_model
260
+ self.self_attn = FlaxBartAttention(
261
+ config=self.config,
262
+ embed_dim=self.embed_dim,
263
+ num_heads=self.config.decoder_attention_heads,
264
+ dropout=self.config.attention_dropout,
265
+ causal=True,
266
+ dtype=self.dtype,
267
+ )
268
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
269
+ self.activation_fn = ACT2FN[self.config.activation_function]
270
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
271
+
272
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
273
+ self.encoder_attn = FlaxBartAttention(
274
+ config=self.config,
275
+ embed_dim=self.embed_dim,
276
+ num_heads=self.config.decoder_attention_heads,
277
+ dropout=self.config.attention_dropout,
278
+ dtype=self.dtype,
279
+ )
280
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
281
+ self.fc1 = nn.Dense(
282
+ self.config.encoder_ffn_dim,
283
+ dtype=self.dtype,
284
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
285
+ )
286
+ self.fc2 = nn.Dense(
287
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
288
+ )
289
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
290
+
291
+ def __call__(
292
+ self,
293
+ hidden_states: jnp.ndarray,
294
+ attention_mask: jnp.ndarray,
295
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
296
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
297
+ init_cache: bool = False,
298
+ output_attentions: bool = True,
299
+ deterministic: bool = True,
300
+ ) -> Tuple[jnp.ndarray]:
301
+
302
+ if self.config.use_scan:
303
+ hidden_states = hidden_states[0]
304
+
305
+ residual = hidden_states
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
310
+ )
311
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
312
+ hidden_states = residual + hidden_states
313
+ hidden_states = self.self_attn_layer_norm(hidden_states)
314
+
315
+ # Cross-Attention Block
316
+ cross_attn_weights = None
317
+ if encoder_hidden_states is not None:
318
+ residual = hidden_states
319
+
320
+ hidden_states, cross_attn_weights = self.encoder_attn(
321
+ hidden_states=hidden_states,
322
+ key_value_states=encoder_hidden_states,
323
+ attention_mask=encoder_attention_mask,
324
+ )
325
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
326
+ hidden_states = residual + hidden_states
327
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
328
+
329
+ # Fully Connected
330
+ residual = hidden_states
331
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
332
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
333
+ hidden_states = self.fc2(hidden_states)
334
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
335
+ hidden_states = residual + hidden_states
336
+ hidden_states = self.final_layer_norm(hidden_states)
337
+
338
+ outputs = (hidden_states,)
339
+
340
+ if output_attentions:
341
+ outputs += (self_attn_weights, cross_attn_weights)
342
+
343
+ if self.config.use_scan:
344
+ outputs = (outputs, None)
345
+
346
+ return outputs
347
+
348
+
349
+ class FlaxBartDecoderLayerCollection(nn.Module):
350
+ config: BartConfig
351
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
+
353
+ @nn.compact
354
+ def __call__(
355
+ self,
356
+ hidden_states,
357
+ attention_mask,
358
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
359
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
360
+ deterministic: bool = True,
361
+ init_cache: bool = False,
362
+ output_attentions: bool = False,
363
+ output_hidden_states: bool = False,
364
+ return_dict: bool = True,
365
+ ):
366
+ # decoder layers
367
+ all_hidden_states = () if output_hidden_states else None
368
+ all_self_attns = () if output_attentions else None
369
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
370
+
371
+ num_decoder_layers = self.config.decoder_layers
372
+ BlockDecoderLayer = (
373
+ remat(
374
+ FlaxBartDecoderLayer,
375
+ static_argnums=(4, 5, 6),
376
+ prevent_cse=not self.config.use_scan,
377
+ )
378
+ if self.config.gradient_checkpointing
379
+ else FlaxBartDecoderLayer
380
+ )
381
+
382
+ if self.config.use_scan:
383
+ # since all decoder layers are the same, we use nn.scan directly
384
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
385
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
386
+ hidden_states = (hidden_states,)
387
+
388
+ # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
389
+ hidden_states, _ = scan_with_axes(
390
+ BlockDecoderLayer,
391
+ variable_axes={"params": 0, "cache": 0},
392
+ split_rngs={"params": True, "dropout": True},
393
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
394
+ length=num_decoder_layers,
395
+ )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
396
+ hidden_states,
397
+ attention_mask,
398
+ encoder_hidden_states,
399
+ encoder_attention_mask,
400
+ init_cache,
401
+ output_attentions,
402
+ deterministic,
403
+ )
404
+ hidden_states = hidden_states[0]
405
+
406
+ else:
407
+ for layer in range(num_decoder_layers):
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
411
+ dropout_probability = random.uniform(0, 1)
412
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
413
+ layer_outputs = (None, None, None)
414
+ else:
415
+ layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
416
+ hidden_states,
417
+ attention_mask,
418
+ encoder_hidden_states,
419
+ encoder_attention_mask,
420
+ init_cache,
421
+ output_attentions,
422
+ deterministic,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attns += (layer_outputs[1],)
428
+
429
+ if encoder_hidden_states is not None:
430
+ all_cross_attentions += (layer_outputs[2],)
431
+
432
+ # add hidden states from the last decoder layer
433
+ if output_hidden_states:
434
+ all_hidden_states += (hidden_states,)
435
+
436
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
437
+
438
+ if not return_dict:
439
+ return tuple(v for v in outputs if v is not None)
440
+
441
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
442
+ last_hidden_state=hidden_states,
443
+ hidden_states=all_hidden_states,
444
+ attentions=all_self_attns,
445
+ cross_attentions=all_cross_attentions,
446
+ )
447
+
448
+
449
+ class FlaxBartDecoder(nn.Module):
450
+ config: BartConfig
451
+ embed_tokens: nn.Embed
452
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
453
+
454
+ def setup(self):
455
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
456
+
457
+ embed_dim = self.config.d_model
458
+ self.padding_idx = self.config.pad_token_id
459
+ self.max_target_positions = self.config.max_position_embeddings
460
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
461
+
462
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
463
+ # and adjust num_embeddings appropriately. Other models don't have this hack
464
+ self.offset = 2
465
+ self.embed_positions = nn.Embed(
466
+ self.config.max_position_embeddings + self.offset,
467
+ embed_dim,
468
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
469
+ )
470
+
471
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
472
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
473
+
474
+ def __call__(
475
+ self,
476
+ input_ids,
477
+ attention_mask,
478
+ position_ids,
479
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
480
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
481
+ init_cache: bool = False,
482
+ output_attentions: bool = False,
483
+ output_hidden_states: bool = False,
484
+ return_dict: bool = True,
485
+ deterministic: bool = True,
486
+ ):
487
+ input_shape = input_ids.shape
488
+ input_ids = input_ids.reshape(-1, input_shape[-1])
489
+
490
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
491
+
492
+ # embed positions
493
+ positions = self.embed_positions(position_ids + self.offset)
494
+
495
+ hidden_states = inputs_embeds + positions
496
+ hidden_states = self.layernorm_embedding(hidden_states)
497
+
498
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
499
+
500
+ outputs = self.layers(
501
+ hidden_states,
502
+ attention_mask,
503
+ encoder_hidden_states,
504
+ encoder_attention_mask,
505
+ deterministic=deterministic,
506
+ init_cache=init_cache,
507
+ output_attentions=output_attentions,
508
+ output_hidden_states=output_hidden_states,
509
+ return_dict=return_dict,
510
+ )
511
+
512
+ if not return_dict:
513
+ return outputs
514
+
515
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
516
+ last_hidden_state=outputs.last_hidden_state,
517
+ hidden_states=outputs.hidden_states,
518
+ attentions=outputs.attentions,
519
+ cross_attentions=outputs.cross_attentions,
520
+ )
521
+
522
+
523
+ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
524
+ config_class = BartConfig
525
+ base_model_prefix: str = "model"
526
+ module_class: nn.Module = None
527
+
528
+ def __init__(
529
+ self,
530
+ config: BartConfig,
531
+ input_shape: Tuple[int] = (1, 1),
532
+ seed: int = 0,
533
+ dtype: jnp.dtype = jnp.float32,
534
+ _do_init: bool = True,
535
+ **kwargs
536
+ ):
537
+ config.is_decoder = True
538
+ config.is_encoder_decoder = False
539
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
540
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
541
+
542
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
543
+ # init input tensors
544
+ input_ids = jnp.zeros(input_shape, dtype="i4")
545
+ attention_mask = jnp.ones_like(input_ids)
546
+
547
+ batch_size, sequence_length = input_ids.shape
548
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
549
+
550
+ params_rng, dropout_rng = jax.random.split(rng)
551
+ rngs = {"params": params_rng, "dropout": dropout_rng}
552
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
553
+ encoder_attention_mask = attention_mask
554
+ module_init_outputs = self.module.init(
555
+ rngs,
556
+ input_ids,
557
+ attention_mask,
558
+ position_ids,
559
+ encoder_hidden_states,
560
+ encoder_attention_mask,
561
+ return_dict=False,
562
+ )
563
+ return module_init_outputs["params"]
564
+
565
+ def init_cache(self, batch_size, max_length):
566
+ r"""
567
+ Args:
568
+ batch_size (`int`):
569
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
570
+ max_length (`int`):
571
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
572
+ cache.
573
+ """
574
+ # init input variables to retrieve cache
575
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
576
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
577
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
578
+
579
+ init_variables = self.module.init(
580
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
581
+ )
582
+ return unfreeze(init_variables["cache"])
583
+
584
+ def __call__(
585
+ self,
586
+ input_ids: jnp.ndarray,
587
+ attention_mask: Optional[jnp.ndarray] = None,
588
+ position_ids: Optional[jnp.ndarray] = None,
589
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
590
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
591
+ output_attentions: Optional[bool] = None,
592
+ output_hidden_states: Optional[bool] = None,
593
+ return_dict: Optional[bool] = None,
594
+ train: bool = False,
595
+ params: dict = None,
596
+ past_key_values: dict = None,
597
+ dropout_rng: PRNGKey = None,
598
+ ):
599
+ """
600
+ Args:
601
+ input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
602
+ Indices of decoder input sequence tokens in the vocabulary.
603
+
604
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
605
+ [`PreTrainedTokenizer.__call__`] for details.
606
+
607
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
608
+
609
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
610
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
611
+ for denoising pre-training following the paper.
612
+ attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
613
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
614
+ be used by default.
615
+
616
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
617
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
618
+ position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
619
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
620
+ range `[0, config.max_position_embeddings - 1]`.
621
+ encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
622
+ A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
623
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
624
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
625
+
626
+ - 1 for tokens that are **not masked**,
627
+ - 0 for tokens that are **masked**.
628
+
629
+ [What are attention masks?](../glossary#attention-mask)
630
+ output_attentions (`bool`, *optional*):
631
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
632
+ tensors for more detail.
633
+ output_hidden_states (`bool`, *optional*):
634
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
635
+ more detail.
636
+ return_dict (`bool`, *optional*):
637
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
638
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
639
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
640
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
641
+ """
642
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
643
+ output_hidden_states = (
644
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
645
+ )
646
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
647
+
648
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
649
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
650
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
651
+
652
+ # prepare decoder inputs
653
+ if attention_mask is None:
654
+ attention_mask = jnp.ones_like(input_ids)
655
+ if position_ids is None:
656
+ batch_size, sequence_length = input_ids.shape
657
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
658
+
659
+ # Handle any PRNG if needed
660
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
661
+
662
+ inputs = {"params": params or self.params}
663
+
664
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
665
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
666
+ # changed by FlaxBartAttention module
667
+ if past_key_values:
668
+ inputs["cache"] = past_key_values
669
+ mutable = ["cache"]
670
+ else:
671
+ mutable = False
672
+
673
+ outputs = self.module.apply(
674
+ inputs,
675
+ input_ids=jnp.array(input_ids, dtype="i4"),
676
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
677
+ position_ids=jnp.array(position_ids, dtype="i4"),
678
+ encoder_hidden_states=encoder_hidden_states,
679
+ encoder_attention_mask=encoder_attention_mask,
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ deterministic=not train,
684
+ rngs=rngs,
685
+ mutable=mutable,
686
+ )
687
+
688
+ # add updated cache to model output
689
+ if past_key_values is not None and return_dict:
690
+ outputs, past_key_values = outputs
691
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
692
+ return outputs
693
+ elif past_key_values is not None and not return_dict:
694
+ outputs, past_key_values = outputs
695
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
696
+
697
+ return outputs
698
+
699
+
700
+ class FlaxBartDecoderWrapper(nn.Module):
701
+ """
702
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
703
+ used in combination with the [`EncoderDecoderModel`] framework.
704
+ """
705
+
706
+ config: BartConfig
707
+ dtype: jnp.dtype = jnp.float32
708
+
709
+ def setup(self):
710
+ embed_dim = self.config.d_model
711
+ embed_tokens = nn.Embed(
712
+ self.config.vocab_size,
713
+ embed_dim,
714
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
715
+ )
716
+ self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
717
+
718
+ def __call__(self, *args, **kwargs):
719
+ return self.decoder(*args, **kwargs)
720
+
721
+
722
+ class FlaxBartForCausalLMModule(nn.Module):
723
+ """Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input embeddings)
724
+ e.g. for autoregressive tasks.
725
+ """
726
+
727
+ config: BartConfig
728
+ dtype: jnp.dtype = jnp.float32
729
+
730
+ def setup(self):
731
+ self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
732
+ self.lm_head = nn.Dense(
733
+ self.config.vocab_size,
734
+ use_bias=False,
735
+ dtype=self.dtype,
736
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
737
+ )
738
+
739
+ def __call__(
740
+ self,
741
+ input_ids,
742
+ attention_mask,
743
+ position_ids,
744
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
745
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
746
+ init_cache: bool = False,
747
+ output_attentions: bool = False,
748
+ output_hidden_states: bool = False,
749
+ return_dict: bool = True,
750
+ deterministic: bool = True,
751
+ ):
752
+
753
+ outputs = self.model(
754
+ input_ids,
755
+ attention_mask,
756
+ position_ids,
757
+ encoder_hidden_states,
758
+ encoder_attention_mask,
759
+ deterministic=deterministic,
760
+ init_cache=init_cache,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ return_dict=return_dict,
764
+ )
765
+
766
+ hidden_states = outputs[0]
767
+
768
+ if self.config.tie_word_embeddings:
769
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
770
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
771
+ else:
772
+ lm_logits = self.lm_head(hidden_states)
773
+
774
+ if not return_dict:
775
+ return (lm_logits,) + outputs[1:]
776
+
777
+ return FlaxCausalLMOutputWithCrossAttentions(
778
+ logits=lm_logits,
779
+ hidden_states=outputs.hidden_states,
780
+ attentions=outputs.attentions,
781
+ cross_attentions=outputs.cross_attentions,
782
+ )
783
+
784
+
785
+ class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
786
+ """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
787
+ e.g. for autoregressive tasks.
788
+ """
789
+
790
+ module_class = FlaxBartForCausalLMModule
791
+
792
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
793
+ # initializing the cache
794
+ batch_size, seq_length = input_ids.shape
795
+
796
+ past_key_values = self.init_cache(batch_size, max_length)
797
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
798
+ # But since the decoder uses a causal mask, those positions are masked anyway.
799
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
800
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
801
+ if attention_mask is not None:
802
+ position_ids = attention_mask.cumsum(axis=-1) - 1
803
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
804
+ else:
805
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
806
+
807
+ return {
808
+ "past_key_values": past_key_values,
809
+ "attention_mask": extended_attention_mask,
810
+ "position_ids": position_ids,
811
+ }
812
+
813
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
814
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
815
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
816
+ return model_kwargs
models/modeling_flax_speech_encoder_decoder.py ADDED
@@ -0,0 +1,1245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Classes to support Flax Speech-Encoder-Decoder architectures"""
16
+
17
+ import os
18
+ from functools import partial
19
+ from typing import Optional, Tuple, Union, Dict
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, unfreeze
26
+ from jax import lax
27
+ from jax.random import PRNGKey
28
+ import numpy as np
29
+
30
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
31
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
33
+ from transformers.generation_flax_utils import FlaxLogitsProcessorList
34
+ from models import (
35
+ FlaxWav2Vec2Model,
36
+ FlaxWav2Vec2Module,
37
+ FlaxBartForCausalLM,
38
+ FlaxBartForCausalLMModule,
39
+ BartConfig,
40
+ Wav2Vec2Config,
41
+ SpeechEncoderDecoderConfig,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
47
+
48
+ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
49
+ This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
50
+ autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
51
+ loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
52
+ [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
53
+ and should be fine-tuned on a downstream generative task, like summarization.
54
+
55
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
56
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
57
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
58
+ Zhou, Wei Li, Peter J. Liu.
59
+
60
+ Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
61
+ Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
62
+ translation yields a significant performance improvement.
63
+
64
+ After such an Speech-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
65
+ models (see the examples for more information).
66
+
67
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
68
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
69
+ etc.)
70
+
71
+ This model is also a Flax Linen
72
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
74
+
75
+ Parameters:
76
+ config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
77
+ Initializing with a config file does not load the weights associated with the model, only the
78
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
79
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
80
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
81
+ `jax.numpy.bfloat16` (on TPUs).
82
+
83
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
84
+ specified all the computation will be performed with the given `dtype`.
85
+
86
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
87
+ parameters.**
88
+
89
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
90
+ [`~FlaxPreTrainedModel.to_bf16`].
91
+ """
92
+
93
+ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
94
+ Args:
95
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
96
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
97
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
98
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
99
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
100
+ *torch.FloatTensor*.
101
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
103
+
104
+ - 1 for tokens that are **not masked**,
105
+ - 0 for tokens that are **masked**.
106
+
107
+ [What are attention masks?](../glossary#attention-mask)
108
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
109
+ Indices of decoder input sequence tokens in the vocabulary.
110
+
111
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
112
+ [`PreTrainedTokenizer.__call__`] for details.
113
+
114
+ [What are input IDs?](../glossary#input-ids)
115
+
116
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
117
+ `past_key_values`).
118
+
119
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
120
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
121
+ and prepending them with the `decoder_start_token_id`.
122
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
123
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
124
+ be used by default.
125
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
126
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
127
+ range `[0, config.decoder.max_position_embeddings - 1]`.
128
+ output_hidden_states (`bool`, *optional*):
129
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
130
+ more detail.
131
+ return_dict (`bool`, *optional*):
132
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
133
+ """
134
+
135
+ SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
136
+ Args:
137
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
138
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
139
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
140
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
141
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
142
+ *torch.FloatTensor*.
143
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
144
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
145
+
146
+ - 1 for tokens that are **not masked**,
147
+ - 0 for tokens that are **masked**.
148
+
149
+ [What are attention masks?](../glossary#attention-mask)
150
+ output_attentions (`bool`, *optional*):
151
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
152
+ tensors for more detail.
153
+ output_hidden_states (`bool`, *optional*):
154
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
155
+ more detail.
156
+ return_dict (`bool`, *optional*):
157
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
158
+ """
159
+
160
+ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
161
+ Args:
162
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
163
+ Indices of decoder input sequence tokens in the vocabulary.
164
+
165
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
166
+ [`PreTrainedTokenizer.__call__`] for details.
167
+
168
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
169
+
170
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
171
+ `past_key_values`).
172
+
173
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
174
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
175
+ and prepending them with the `decoder_start_token_id`.
176
+ encoder_outputs (`tuple(tuple(jnp.ndarray)`):
177
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
178
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
179
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
180
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
181
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
182
+
183
+ - 1 for tokens that are **not masked**,
184
+ - 0 for tokens that are **masked**.
185
+
186
+ [What are attention masks?](../glossary#attention-mask)
187
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
188
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
189
+ be used by default.
190
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
191
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
192
+ range `[0, config.decoder.max_position_embeddings - 1]`.
193
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
194
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
195
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
196
+ output_attentions (`bool`, *optional*):
197
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
198
+ tensors for more detail.
199
+ output_hidden_states (`bool`, *optional*):
200
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
201
+ more detail.
202
+ return_dict (`bool`, *optional*):
203
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
204
+ plain tuple.
205
+ """
206
+
207
+ @flax.struct.dataclass
208
+ class FlaxBeamSearchOutput(ModelOutput):
209
+ """
210
+ Flax Base class for outputs of decoder-only generation models using greedy search.
211
+
212
+
213
+ Args:
214
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
215
+ The generated sequences.
216
+ scores (`jnp.ndarray` of shape `(batch_size,)`):
217
+ The scores (log probabilites) of the generated sequences.
218
+ """
219
+
220
+ sequences: jnp.ndarray = None
221
+ scores: jnp.ndarray = None
222
+
223
+
224
+ @flax.struct.dataclass
225
+ class BeamSearchState:
226
+ cur_len: jnp.ndarray
227
+ running_sequences: jnp.ndarray
228
+ running_scores: jnp.ndarray
229
+ sequences: jnp.ndarray
230
+ scores: jnp.ndarray
231
+ is_sent_finished: jnp.ndarray
232
+ model_kwargs: Dict[str, jnp.ndarray]
233
+
234
+
235
+
236
+
237
+ class FlaxSpeechEncoderDecoderModule(nn.Module):
238
+ config: SpeechEncoderDecoderConfig
239
+ dtype: jnp.dtype = jnp.float32
240
+
241
+ def setup(self):
242
+ encoder_config = self.config.encoder
243
+ decoder_config = self.config.decoder
244
+
245
+ # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
246
+ encoder_module = FlaxWav2Vec2Module
247
+ decoder_module = FlaxBartForCausalLMModule
248
+
249
+ self.encoder = encoder_module(encoder_config, dtype=self.dtype)
250
+ self.decoder = decoder_module(decoder_config, dtype=self.dtype)
251
+
252
+ # encoder outputs might need to be projected to different dimension for decoder
253
+ if (
254
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
255
+ and self.decoder.config.cross_attention_hidden_size is None
256
+ ):
257
+ self.enc_to_dec_proj = nn.Dense(
258
+ self.decoder.config.hidden_size,
259
+ kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
260
+ dtype=self.dtype,
261
+ )
262
+ else:
263
+ self.enc_to_dec_proj = None
264
+
265
+ def _get_feat_extract_output_lengths(
266
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
267
+ ):
268
+ """
269
+ Computes the output length of the convolutional layers
270
+ """
271
+
272
+ add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
273
+
274
+ def _conv_out_length(input_length, kernel_size, stride):
275
+ # 1D convolutional layer output length formula taken
276
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
277
+ return (input_length - kernel_size) // stride + 1
278
+
279
+ for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
280
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
281
+
282
+ if add_adapter:
283
+ for _ in range(self.config.encoder.num_adapter_layers):
284
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
285
+
286
+ return input_lengths
287
+
288
+ def _get_encoder_module(self):
289
+ return self.encoder
290
+
291
+ def _get_projection_module(self):
292
+ return self.enc_to_dec_proj
293
+
294
+ def _get_decoder_module(self):
295
+ return self.decoder
296
+
297
+ def __call__(
298
+ self,
299
+ inputs,
300
+ attention_mask,
301
+ decoder_input_ids,
302
+ decoder_attention_mask,
303
+ decoder_position_ids,
304
+ encoder_outputs=None,
305
+ extract_features=None,
306
+ output_attentions: bool = False,
307
+ output_hidden_states: bool = False,
308
+ output_features: bool = False,
309
+ return_dict: bool = True,
310
+ deterministic: bool = True,
311
+ freeze_feature_encoder: bool = False,
312
+ ):
313
+ if encoder_outputs is None:
314
+ encoder_outputs = self.encoder(
315
+ inputs,
316
+ attention_mask=attention_mask,
317
+ extract_features=extract_features,
318
+ output_attentions=output_attentions,
319
+ output_hidden_states=output_hidden_states,
320
+ output_features=output_features,
321
+ return_dict=return_dict,
322
+ deterministic=deterministic,
323
+ freeze_feature_encoder=freeze_feature_encoder,
324
+ )
325
+
326
+ if output_features:
327
+ return encoder_outputs
328
+
329
+ encoder_hidden_states = encoder_outputs[0]
330
+
331
+ # optionally project encoder_hidden_states
332
+ if self.enc_to_dec_proj is not None:
333
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
334
+
335
+ # compute correct encoder attention mask
336
+ if attention_mask is not None:
337
+ encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
338
+ encoder_hidden_states.shape[1], attention_mask
339
+ )
340
+ else:
341
+ encoder_attention_mask = None
342
+
343
+ # flax script modeling_flax_wav2vec2.py
344
+ decoder_outputs = self.decoder(
345
+ input_ids=decoder_input_ids,
346
+ attention_mask=decoder_attention_mask,
347
+ position_ids=decoder_position_ids,
348
+ encoder_hidden_states=encoder_hidden_states,
349
+ encoder_attention_mask=encoder_attention_mask,
350
+ output_attentions=output_attentions,
351
+ output_hidden_states=output_hidden_states,
352
+ return_dict=return_dict,
353
+ deterministic=deterministic,
354
+ )
355
+
356
+ if not return_dict:
357
+ return decoder_outputs + encoder_outputs
358
+
359
+ return FlaxSeq2SeqLMOutput(
360
+ logits=decoder_outputs.logits,
361
+ decoder_hidden_states=decoder_outputs.hidden_states,
362
+ decoder_attentions=decoder_outputs.attentions,
363
+ cross_attentions=decoder_outputs.cross_attentions,
364
+ encoder_last_hidden_state=encoder_hidden_states,
365
+ encoder_hidden_states=encoder_outputs.hidden_states,
366
+ encoder_attentions=encoder_outputs.attentions,
367
+ )
368
+
369
+
370
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
371
+ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
372
+ r"""
373
+ [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
374
+ with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
375
+ as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the
376
+ encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
377
+ """
378
+
379
+ config_class = SpeechEncoderDecoderConfig
380
+ base_model_prefix: str = "speech_encoder_decoder"
381
+ module_class = FlaxSpeechEncoderDecoderModule
382
+
383
+ def __init__(
384
+ self,
385
+ config: SpeechEncoderDecoderConfig,
386
+ input_shape: Optional[Tuple] = None,
387
+ seed: int = 0,
388
+ dtype: jnp.dtype = jnp.float32,
389
+ _do_init: bool = True,
390
+ **kwargs
391
+ ):
392
+
393
+ if not _do_init:
394
+ raise ValueError(
395
+ "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
396
+ )
397
+
398
+ if config.decoder.cross_attention_hidden_size is not None:
399
+ # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
400
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
401
+ raise ValueError(
402
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
403
+ "it has to be equal to the encoder's `hidden_size`. "
404
+ f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
405
+ f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
406
+ )
407
+
408
+ # make sure input & output embeddings are not tied
409
+ config.tie_word_embeddings = False
410
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
411
+
412
+ if input_shape is None:
413
+ # speech encoders almost always downsample the sequence length dimension
414
+ encoder_input_length = 1024
415
+ decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
416
+ input_shape = ((1, encoder_input_length), (1, decoder_input_length))
417
+
418
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
419
+
420
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
421
+ encoder_input_shape, decoder_input_shape = input_shape
422
+
423
+ # init input DeviceArrays
424
+ inputs = jnp.zeros(encoder_input_shape, dtype="f4")
425
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
426
+ decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
427
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
428
+
429
+ batch_size, sequence_length = inputs.shape
430
+
431
+ decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
432
+ if not decoder_batch_size == batch_size:
433
+ raise ValueError(
434
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
435
+ )
436
+ decoder_position_ids = jnp.broadcast_to(
437
+ jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
438
+ )
439
+
440
+ params_rng, dropout_rng = jax.random.split(rng)
441
+ rngs = {"params": params_rng, "dropout": dropout_rng}
442
+
443
+ return self.module.init(
444
+ rngs,
445
+ inputs,
446
+ attention_mask,
447
+ decoder_input_ids,
448
+ decoder_attention_mask,
449
+ decoder_position_ids,
450
+ )["params"]
451
+
452
+ def init_cache(self, batch_size, max_length, encoder_outputs):
453
+ r"""
454
+ Args:
455
+ batch_size (`int`):
456
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
457
+ max_length (`int`):
458
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
459
+ cache.
460
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
461
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
462
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
463
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
464
+ cross-attention of the decoder.
465
+ """
466
+ # init input variables to retrieve cache
467
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
468
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
469
+ decoder_position_ids = jnp.broadcast_to(
470
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
471
+ )
472
+
473
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
474
+ decoder_module = module._get_decoder_module()
475
+ return decoder_module(
476
+ input_ids=decoder_input_ids,
477
+ attention_mask=decoder_attention_mask,
478
+ position_ids=decoder_position_ids,
479
+ **kwargs,
480
+ )
481
+
482
+ init_variables = self.module.init(
483
+ jax.random.PRNGKey(0),
484
+ decoder_input_ids=decoder_input_ids,
485
+ decoder_attention_mask=decoder_attention_mask,
486
+ decoder_position_ids=decoder_position_ids,
487
+ encoder_hidden_states=encoder_outputs[0],
488
+ init_cache=True,
489
+ method=_decoder_forward, # we only need to call the decoder to init the cache
490
+ )
491
+ return unfreeze(init_variables["cache"])
492
+
493
+ def _get_feat_extract_output_lengths(
494
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
495
+ ):
496
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
497
+
498
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
499
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
500
+ def encode(
501
+ self,
502
+ inputs: jnp.ndarray,
503
+ attention_mask: Optional[jnp.ndarray] = None,
504
+ extract_features: Optional[jnp.ndarray] = None,
505
+ output_attentions: Optional[bool] = None,
506
+ output_hidden_states: Optional[bool] = None,
507
+ output_features: Optional[bool] = None,
508
+ return_dict: Optional[bool] = None,
509
+ train: bool = False,
510
+ freeze_feature_encoder: bool = False,
511
+ params: dict = None,
512
+ dropout_rng: PRNGKey = None,
513
+ ):
514
+ r"""
515
+ Returns:
516
+
517
+ Example:
518
+
519
+ ```python
520
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
521
+
522
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
523
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
524
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
525
+ ... )
526
+
527
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
528
+ >>> encoder_outputs = model.encode(inputs)
529
+ ```"""
530
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
531
+ output_hidden_states = (
532
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
533
+ )
534
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
535
+
536
+ if attention_mask is None:
537
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
538
+
539
+ if extract_features is not None:
540
+ extract_features = jnp.array(extract_features, dtype="f4")
541
+
542
+ # Handle any PRNG if needed
543
+ rngs = {}
544
+ if dropout_rng is not None:
545
+ rngs["dropout"] = dropout_rng
546
+
547
+ def _encoder_forward(module, inputs, attention_mask, **kwargs):
548
+ encode_module = module._get_encoder_module()
549
+ return encode_module(inputs, attention_mask, **kwargs)
550
+
551
+ outputs = self.module.apply(
552
+ {"params": params or self.params},
553
+ inputs=jnp.array(inputs, dtype="f4"),
554
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
555
+ extract_features=extract_features,
556
+ output_attentions=output_attentions,
557
+ output_hidden_states=output_hidden_states,
558
+ output_features=output_features,
559
+ return_dict=return_dict,
560
+ deterministic=not train,
561
+ freeze_feature_encoder=freeze_feature_encoder,
562
+ rngs=rngs,
563
+ method=_encoder_forward,
564
+ )
565
+
566
+ if return_dict and not output_features:
567
+ outputs = FlaxBaseModelOutput(
568
+ last_hidden_state=outputs.last_hidden_state,
569
+ hidden_states=outputs.hidden_states,
570
+ attentions=outputs.attentions,
571
+ )
572
+
573
+ return outputs
574
+
575
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
576
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
577
+ def decode(
578
+ self,
579
+ decoder_input_ids,
580
+ encoder_outputs,
581
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
582
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
583
+ decoder_position_ids: Optional[jnp.ndarray] = None,
584
+ past_key_values: dict = None,
585
+ output_attentions: Optional[bool] = None,
586
+ output_hidden_states: Optional[bool] = None,
587
+ return_dict: Optional[bool] = None,
588
+ train: bool = False,
589
+ params: dict = None,
590
+ dropout_rng: PRNGKey = None,
591
+ ):
592
+ r"""
593
+ Returns:
594
+
595
+ Example:
596
+
597
+ ```python
598
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
599
+ >>> import jax.numpy as jnp
600
+
601
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
602
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
603
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
604
+ ... )
605
+
606
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
607
+ >>> encoder_outputs = model.encode(inputs)
608
+
609
+ >>> decoder_start_token_id = model.config.decoder.bos_token_id
610
+ >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
611
+
612
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
613
+ >>> logits = outputs.logits
614
+ ```"""
615
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
616
+ output_hidden_states = (
617
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
618
+ )
619
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
620
+
621
+ encoder_hidden_states = encoder_outputs[0]
622
+ if encoder_attention_mask is None:
623
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
624
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
625
+
626
+ batch_size, sequence_length = decoder_input_ids.shape
627
+ if decoder_attention_mask is None:
628
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
629
+
630
+ if decoder_position_ids is None:
631
+ if past_key_values is not None:
632
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
633
+
634
+ decoder_position_ids = jnp.broadcast_to(
635
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
636
+ )
637
+
638
+ # Handle any PRNG if needed
639
+ rngs = {}
640
+ if dropout_rng is not None:
641
+ rngs["dropout"] = dropout_rng
642
+
643
+ params = {"params": params or self.params}
644
+
645
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
646
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
647
+ # it can be changed by FlaxBartAttention module
648
+ if past_key_values:
649
+ params["cache"] = past_key_values
650
+ mutable = ["cache"]
651
+ else:
652
+ mutable = False
653
+
654
+ def _decoder_forward(
655
+ module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
656
+ ):
657
+
658
+ projection_module = module._get_projection_module()
659
+ decoder_module = module._get_decoder_module()
660
+
661
+ # optionally project encoder_hidden_states
662
+ if projection_module is not None:
663
+ encoder_hidden_states = projection_module(encoder_hidden_states)
664
+
665
+ return decoder_module(
666
+ decoder_input_ids,
667
+ decoder_attention_mask,
668
+ decoder_position_ids,
669
+ encoder_hidden_states,
670
+ **kwargs,
671
+ )
672
+
673
+ outputs = self.module.apply(
674
+ params,
675
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
676
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
677
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
678
+ encoder_hidden_states=encoder_hidden_states,
679
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ deterministic=not train,
684
+ rngs=rngs,
685
+ mutable=mutable,
686
+ method=_decoder_forward,
687
+ )
688
+
689
+ # add updated cache to model output
690
+ if past_key_values is not None and return_dict:
691
+ outputs, past = outputs
692
+ outputs["past_key_values"] = unfreeze(past["cache"])
693
+ return outputs
694
+ elif past_key_values is not None and not return_dict:
695
+ outputs, past = outputs
696
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
697
+
698
+ return outputs
699
+
700
+ @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
701
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
702
+ def __call__(
703
+ self,
704
+ inputs: jnp.ndarray,
705
+ attention_mask: Optional[jnp.ndarray] = None,
706
+ extract_features: Optional[jnp.ndarray] = None,
707
+ decoder_input_ids: Optional[jnp.ndarray] = None,
708
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
709
+ decoder_position_ids: Optional[jnp.ndarray] = None,
710
+ output_attentions: Optional[bool] = None,
711
+ output_hidden_states: Optional[bool] = None,
712
+ output_features: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ train: bool = False,
715
+ freeze_feature_encoder: bool = False,
716
+ params: dict = None,
717
+ dropout_rng: PRNGKey = None,
718
+ ):
719
+ r"""
720
+ Returns:
721
+
722
+ Examples:
723
+
724
+ ```python
725
+ >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
726
+
727
+ >>> # load a fine-tuned wav2vec2-2-bart model
728
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
729
+ >>> # load output tokenizer
730
+ >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
731
+
732
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
733
+
734
+ >>> # use bart's special bos, pad and eos tokens
735
+ >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
736
+ >>> model.config.pad_token_id = model.decoder.config.pad_token_id
737
+ >>> model.config.eos_token_id = model.decoder.config.eos_token_id
738
+
739
+ >>> outputs = model.generate(inputs)
740
+ # Assert something? More interesting input? dtype correct?
741
+ ```
742
+ """
743
+
744
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
745
+ output_hidden_states = (
746
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
747
+ )
748
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
749
+
750
+ # prepare encoder inputs
751
+ if attention_mask is None:
752
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
753
+
754
+ if extract_features is not None:
755
+ inputs = None # we can omit passing the inputs to the model to save memory
756
+ extract_features = jnp.array(extract_features, dtype="f4")
757
+ else:
758
+ inputs = jnp.array(inputs, dtype="f4")
759
+
760
+ # prepare decoder inputs
761
+ if decoder_input_ids is None:
762
+ raise ValueError(
763
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
764
+ )
765
+ if decoder_attention_mask is None:
766
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
767
+ if decoder_position_ids is None:
768
+ batch_size, sequence_length = decoder_input_ids.shape
769
+ decoder_position_ids = jnp.broadcast_to(
770
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
771
+ )
772
+
773
+ # Handle any PRNG if needed
774
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
775
+
776
+ return self.module.apply(
777
+ {"params": params or self.params},
778
+ inputs=inputs,
779
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
780
+ extract_features=extract_features,
781
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
782
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
783
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
784
+ output_attentions=output_attentions,
785
+ output_hidden_states=output_hidden_states,
786
+ output_features=output_features,
787
+ return_dict=return_dict,
788
+ deterministic=not train,
789
+ freeze_feature_encoder=freeze_feature_encoder,
790
+ rngs=rngs,
791
+ )
792
+
793
+ def prepare_inputs_for_generation(
794
+ self,
795
+ decoder_input_ids,
796
+ max_length,
797
+ attention_mask: Optional[jnp.DeviceArray] = None,
798
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
799
+ encoder_outputs=None,
800
+ **kwargs
801
+ ):
802
+ # initializing the cache
803
+ batch_size, seq_length = decoder_input_ids.shape
804
+
805
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
806
+ # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
807
+ # But since the decoder uses a causal mask, those positions are masked anyways.
808
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
809
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
810
+ if decoder_attention_mask is not None:
811
+ decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
812
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
813
+ else:
814
+ decoder_position_ids = jnp.broadcast_to(
815
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
816
+ )
817
+
818
+ return {
819
+ "past_key_values": past_key_values,
820
+ "encoder_outputs": encoder_outputs,
821
+ "encoder_attention_mask": attention_mask,
822
+ "decoder_attention_mask": extended_attention_mask,
823
+ "decoder_position_ids": decoder_position_ids,
824
+ }
825
+
826
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
827
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
828
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
829
+ return model_kwargs
830
+
831
+ @classmethod
832
+ def from_encoder_decoder_pretrained(
833
+ cls,
834
+ encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
835
+ decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
836
+ *model_args,
837
+ **kwargs
838
+ ) -> FlaxPreTrainedModel:
839
+ r"""
840
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
841
+ checkpoints.
842
+
843
+ Params:
844
+ encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
845
+ Information necessary to initiate the encoder. Can be either:
846
+
847
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
848
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
849
+ user or organization name, like `dbmdz/bert-base-german-cased`.
850
+ - A path to a *directory* containing model weights saved using
851
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
852
+
853
+ decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
854
+ Information necessary to initiate the decoder. Can be either:
855
+
856
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
857
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
858
+ user or organization name, like `dbmdz/bert-base-german-cased`.
859
+ - A path to a *directory* containing model weights saved using
860
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
861
+
862
+ model_args (remaining positional arguments, *optional*):
863
+ All remaning positional arguments will be passed to the underlying model's `__init__` method.
864
+
865
+ kwargs (remaining dictionary of keyword arguments, *optional*):
866
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
867
+ `output_attentions=True`).
868
+
869
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
870
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
871
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
872
+
873
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
874
+
875
+ Example:
876
+
877
+ ```python
878
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
879
+
880
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
881
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
882
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
883
+ ... )
884
+ >>> # saving model after fine-tuning
885
+ >>> model.save_pretrained("./wav2vec2-2-bart-large")
886
+ >>> # load fine-tuned model
887
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
888
+ ```"""
889
+
890
+ kwargs_encoder = {
891
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
892
+ }
893
+
894
+ kwargs_decoder = {
895
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
896
+ }
897
+
898
+ # remove encoder, decoder kwargs from kwargs
899
+ for key in kwargs_encoder.keys():
900
+ del kwargs["encoder_" + key]
901
+ for key in kwargs_decoder.keys():
902
+ del kwargs["decoder_" + key]
903
+
904
+ # Load and initialize the encoder and decoder
905
+ # The distinction between encoder and decoder at the model level is made
906
+ # by the value of the flag `is_decoder` that we need to set correctly.
907
+ encoder = kwargs_encoder.pop("model", None)
908
+ if encoder is None:
909
+ if encoder_pretrained_model_name_or_path is None:
910
+ raise ValueError(
911
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
912
+ "to be defined."
913
+ )
914
+
915
+ if "config" not in kwargs_encoder:
916
+ # TODO: AutoConfig .from_pretrained
917
+ encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
918
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
919
+ )
920
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
921
+ logger.info(
922
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
923
+ "from a decoder model. Cross-attention and casual mask are disabled."
924
+ )
925
+ encoder_config.is_decoder = False
926
+ encoder_config.add_cross_attention = False
927
+
928
+ kwargs_encoder["config"] = encoder_config
929
+
930
+ # TODO: FlaxAutoModel .from_pretrained
931
+ encoder = FlaxWav2Vec2Model.from_pretrained(
932
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
933
+ )
934
+
935
+ decoder = kwargs_decoder.pop("model", None)
936
+ if decoder is None:
937
+ if decoder_pretrained_model_name_or_path is None:
938
+ raise ValueError(
939
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
940
+ "to be defined."
941
+ )
942
+
943
+ if "config" not in kwargs_decoder:
944
+ # TODO: AutoConfig .from_pretrained
945
+ decoder_config, kwargs_decoder = BartConfig.from_pretrained(
946
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
947
+ )
948
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
949
+ logger.info(
950
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
951
+ f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
952
+ f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
953
+ "cross attention layers."
954
+ )
955
+ decoder_config.is_decoder = True
956
+ decoder_config.add_cross_attention = True
957
+
958
+ kwargs_decoder["config"] = decoder_config
959
+
960
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
961
+ logger.warning(
962
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
963
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
964
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
965
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
966
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
967
+ )
968
+
969
+ # TODO: FlaxAutoModelForCausalLM .from_pretrained
970
+ decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
971
+
972
+ # instantiate config with corresponding kwargs
973
+ dtype = kwargs.pop("dtype", jnp.float32)
974
+ config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
975
+
976
+ # make sure input & output word embeddings are not tied
977
+ config.tie_word_embeddings = False
978
+
979
+ # init model
980
+ model = cls(config, dtype=dtype)
981
+ model.params["encoder"] = encoder.params
982
+ model.params["decoder"] = decoder.params
983
+
984
+ return model
985
+
986
+ def _beam_search(
987
+ self,
988
+ input_ids: None,
989
+ max_length: Optional[int] = None,
990
+ pad_token_id: Optional[int] = None,
991
+ eos_token_id: Optional[int] = None,
992
+ length_penalty: Optional[float] = None,
993
+ early_stopping: Optional[bool] = None,
994
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
995
+ trace: bool = True,
996
+ params: Optional[Dict[str, jnp.ndarray]] = None,
997
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
998
+ ):
999
+ """
1000
+ This beam search function is heavily inspired by Flax's official example:
1001
+ https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
1002
+ """
1003
+
1004
+ def flatten_beam_dim(tensor):
1005
+ """Flattens the first two dimensions of a non-scalar array."""
1006
+ # ignore scalars (e.g. cache index)
1007
+ if tensor.ndim == 0 or tensor.ndim == 1:
1008
+ return tensor
1009
+ elif tensor.ndim == 6:
1010
+ return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
1011
+ return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
1012
+
1013
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
1014
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
1015
+ # ignore scalars (e.g. cache index)
1016
+ if tensor.ndim == 0 or tensor.ndim == 1:
1017
+ return tensor
1018
+ if tensor.ndim == 5:
1019
+ return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
1020
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
1021
+
1022
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
1023
+ """
1024
+ Gathers the beam slices indexed by beam_indices into new beam array.
1025
+ """
1026
+ batch_indices = jnp.reshape(
1027
+ jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
1028
+ )
1029
+
1030
+ def gather_fn(tensor):
1031
+ # ignore scalars (e.g. cache index)
1032
+ if tensor.ndim == 0 or tensor.ndim == 1:
1033
+ return tensor
1034
+ if tensor.ndim == 6:
1035
+ return tensor[:, batch_indices, beam_indices]
1036
+ return tensor[batch_indices, beam_indices]
1037
+
1038
+ return jax.tree_map(gather_fn, nested)
1039
+
1040
+ # init values
1041
+ max_length = max_length if max_length is not None else self.config.max_length
1042
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
1043
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
1044
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
1045
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
1046
+
1047
+ batch_size, num_beams, cur_len = input_ids.shape
1048
+
1049
+ eos_token_id = jnp.array(eos_token_id)
1050
+ pad_token_id = jnp.array(pad_token_id)
1051
+ cur_len = jnp.array(cur_len)
1052
+
1053
+ # per batch,beam-item holding current token in loop.
1054
+ sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1055
+ running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1056
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
1057
+
1058
+ # per batch,beam-item state bit indicating if sentence has finished.
1059
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
1060
+
1061
+ # per batch,beam-item score, logprobs
1062
+ running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
1063
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
1064
+
1065
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1066
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1067
+ model = self.decode if self.config.is_encoder_decoder else self
1068
+
1069
+ # flatten beam dim
1070
+ if "encoder_outputs" in model_kwargs:
1071
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
1072
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
1073
+ )
1074
+ if "attention_mask" in model_kwargs:
1075
+ model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
1076
+
1077
+ # initialize model specific kwargs
1078
+ model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
1079
+
1080
+ # initialize state
1081
+ state = BeamSearchState(
1082
+ cur_len=cur_len,
1083
+ running_sequences=running_sequences,
1084
+ running_scores=running_scores,
1085
+ sequences=sequences,
1086
+ scores=scores,
1087
+ is_sent_finished=is_sent_finished,
1088
+ model_kwargs=model_kwargs,
1089
+ )
1090
+
1091
+ def beam_search_cond_fn(state):
1092
+ """beam search state termination condition fn."""
1093
+
1094
+ # 1. is less than max length?
1095
+ not_max_length_yet = state.cur_len < max_length
1096
+
1097
+ # 2. can the new beams still improve?
1098
+ best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
1099
+ worst_finished_score = jnp.where(
1100
+ state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
1101
+ )
1102
+ improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
1103
+
1104
+ # 3. is there still a beam that has not finished?
1105
+ still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
1106
+
1107
+ return not_max_length_yet & still_open_beam & improvement_still_possible
1108
+
1109
+ def beam_search_body_fn(state, input_ids_length=1):
1110
+ """beam search state update fn."""
1111
+ # 1. Forward current tokens
1112
+ # Collect the current position slice along length to feed the fast
1113
+ # autoregressive decoder model. Flatten the beam dimension into batch
1114
+ # dimension for feeding into the model.
1115
+ # unflatten beam dimension
1116
+ # Unflatten beam dimension in attention cache arrays
1117
+ input_token = flatten_beam_dim(
1118
+ lax.dynamic_slice(
1119
+ state.running_sequences,
1120
+ (0, 0, state.cur_len - input_ids_length),
1121
+ (batch_size, num_beams, input_ids_length),
1122
+ )
1123
+ )
1124
+ model_outputs = model(input_token, params=params, **state.model_kwargs)
1125
+
1126
+ logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
1127
+ cache = jax.tree_map(
1128
+ lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
1129
+ )
1130
+
1131
+ # adapt logits for FlaxMarianMTModel
1132
+ logits = self._adapt_logits_for_beam_search(logits)
1133
+
1134
+ # 2. Compute log probs
1135
+ # get log probabilities from logits,
1136
+ # process logits with processors (*e.g.* min_length, ...), and
1137
+ # add new logprobs to existing running logprobs scores.
1138
+ log_probs = jax.nn.log_softmax(logits)
1139
+ log_probs = logits_processor(
1140
+ flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
1141
+ )
1142
+ log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
1143
+ log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
1144
+ vocab_size = log_probs.shape[2]
1145
+ log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
1146
+
1147
+ # 3. Retrieve top-K
1148
+ # Each item in batch has num_beams * vocab_size candidate sequences.
1149
+ # For each item, get the top 2*k candidates with the highest log-
1150
+ # probabilities. We gather the top 2*K beams here so that even if the best
1151
+ # K sequences reach EOS simultaneously, we have another K sequences
1152
+ # remaining to continue the live beam search.
1153
+ # Gather the top 2*K scores from _all_ beams.
1154
+ # Gather 2*k top beams.
1155
+ # Recover the beam index by floor division.
1156
+ # Recover token id by modulo division and expand Id array for broadcasting.
1157
+ # Update sequences for the 2*K top-k new sequences.
1158
+ beams_to_keep = 2 * num_beams
1159
+ topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
1160
+ topk_beam_indices = topk_indices // vocab_size
1161
+ topk_running_sequences = gather_beams(
1162
+ state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
1163
+ )
1164
+ topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
1165
+ topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
1166
+
1167
+ # 4. Check which sequences have ended
1168
+ # Update current sequences:
1169
+ # Did any of these sequences reach an end marker?
1170
+ # To prevent these just finished sequences from being added to the current sequences
1171
+ # set of active beam search sequences, set their log probs to a very large
1172
+ # negative value.
1173
+ did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
1174
+ running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
1175
+ # 5. Get running sequences scores for next
1176
+ # Determine the top k beam indices (from top 2*k beams) from log probs
1177
+ # and gather top k beams (from top 2*k beams).
1178
+ next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
1179
+ next_running_sequences, next_running_scores = gather_beams(
1180
+ [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
1181
+ )
1182
+
1183
+ # 6. Process topk logits
1184
+ # Further process log probs:
1185
+ # - add length penalty
1186
+ # - make sure no scores can be added anymore if beam is full
1187
+ # - make sure still running sequences cannot be chosen as finalized beam
1188
+ topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
1189
+ beams_in_batch_are_full = (
1190
+ jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
1191
+ & early_stopping
1192
+ )
1193
+ add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
1194
+ topk_log_probs += add_penalty * np.array(-1.0e7)
1195
+
1196
+ # 7. Get scores, sequences, is sentence finished for next.
1197
+ # Combine sequences, scores, and flags along the beam dimension and compare
1198
+ # new finished sequence scores to existing finished scores and select the
1199
+ # best from the new set of beams
1200
+ merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
1201
+ merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
1202
+ merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
1203
+ topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
1204
+ next_sequences, next_scores, next_is_sent_finished = gather_beams(
1205
+ [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
1206
+ )
1207
+
1208
+ # 8. Update model kwargs.
1209
+ # Determine the top k beam indices from the original set of all beams.
1210
+ # With these, gather the top k beam-associated caches.
1211
+ next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
1212
+ next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
1213
+ model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
1214
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
1215
+
1216
+ return BeamSearchState(
1217
+ cur_len=state.cur_len + 1,
1218
+ running_scores=next_running_scores,
1219
+ running_sequences=next_running_sequences,
1220
+ scores=next_scores,
1221
+ sequences=next_sequences,
1222
+ is_sent_finished=next_is_sent_finished,
1223
+ model_kwargs=next_model_kwargs,
1224
+ )
1225
+
1226
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1227
+ if input_ids.shape[-1] > 1:
1228
+ state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1229
+
1230
+ if not trace:
1231
+ state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1232
+ else:
1233
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1234
+
1235
+ # Account for the edge-case where there are no finished sequences for a
1236
+ # particular batch item. If so, return running sequences for that batch item.
1237
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
1238
+ sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1239
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1240
+
1241
+ # return all beams for each batch and the best score
1242
+ sequences = sequences[:, :]
1243
+ scores = scores[:, -1]
1244
+
1245
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
models/modeling_flax_wav2vec2.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Wav2Vec2 model."""
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from jax import lax
28
+
29
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
31
+ from transformers.utils import ModelOutput
32
+
33
+ from models import Wav2Vec2Config
34
+
35
+ scan_with_axes = nn_partitioning.scan_with_axes
36
+ remat = nn_partitioning.remat
37
+
38
+
39
+ @flax.struct.dataclass
40
+ class FlaxWav2Vec2BaseModelOutput(ModelOutput):
41
+ """
42
+ Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions.
43
+
44
+ Args:
45
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
46
+ Sequence of hidden-states at the output of the last layer of the model.
47
+ extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
48
+ Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
49
+ being the dimension of the last convolutional layer.
50
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
51
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
52
+ `(batch_size, sequence_length, hidden_size)`.
53
+
54
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
55
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
56
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
57
+ sequence_length)`.
58
+
59
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
60
+ heads.
61
+ """
62
+
63
+ last_hidden_state: jnp.ndarray = None
64
+ extract_features: jnp.ndarray = None
65
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
66
+ attentions: Optional[Tuple[jnp.ndarray]] = None
67
+
68
+
69
+ WAV_2_VEC_2_START_DOCSTRING = r"""
70
+ Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
71
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
72
+ Auli.
73
+
74
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
75
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
76
+ etc.)
77
+
78
+ This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
81
+
82
+ Finally, this model supports inherent JAX features such as:
83
+
84
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
85
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
86
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
87
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
88
+
89
+ Parameters:
90
+ config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
91
+ Initializing with a config file does not load the weights associated with the model, only the
92
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
93
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
94
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
95
+ `jax.numpy.bfloat16` (on TPUs).
96
+
97
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
98
+ specified all the computation will be performed with the given `dtype`.
99
+
100
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
101
+ parameters.**
102
+
103
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
104
+ [`~FlaxPreTrainedModel.to_bf16`].
105
+ """
106
+
107
+
108
+ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
111
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
112
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
113
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
114
+ and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
115
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
117
+ 1]`:
118
+
119
+ - 1 for tokens that are **not masked**,
120
+ - 0 for tokens that are **masked**.
121
+
122
+ [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
123
+ if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
124
+ has `config.return_attention_mask == False`, such as
125
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
126
+ passed to avoid degraded performance when doing batched inference. For such models `input_values` should
127
+ simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
128
+ different results depending on whether `input_values` is padded or not.
129
+ mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
130
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
131
+ masked extracted features in *config.proj_codevector_dim* space.
132
+ output_attentions (`bool`, *optional*):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
134
+ tensors for more detail.
135
+ output_hidden_states (`bool`, *optional*):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
137
+ more detail.
138
+ return_dict (`bool`, *optional*):
139
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
140
+ """
141
+
142
+
143
+ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
144
+ config: Wav2Vec2Config
145
+ layer_id: int = 0
146
+ dtype: jnp.dtype = jnp.float32
147
+
148
+ def setup(self):
149
+ self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
150
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
151
+
152
+ self.conv = nn.Conv(
153
+ features=self.config.conv_dim[self.layer_id],
154
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
155
+ strides=(self.config.conv_stride[self.layer_id],),
156
+ use_bias=self.config.conv_bias,
157
+ kernel_init=jax.nn.initializers.he_normal(),
158
+ padding="VALID",
159
+ dtype=self.dtype,
160
+ )
161
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
162
+ self.activation = ACT2FN[self.config.feat_extract_activation]
163
+
164
+ def __call__(self, hidden_states):
165
+ hidden_states = self.conv(hidden_states)
166
+ hidden_states = self.layer_norm(hidden_states)
167
+ hidden_states = self.activation(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class FlaxConvWithWeightNorm(nn.Module):
172
+ config: Wav2Vec2Config
173
+ dtype: jnp.dtype = jnp.float32
174
+
175
+ def setup(self):
176
+ self.conv = nn.Conv(
177
+ features=self.config.hidden_size,
178
+ kernel_size=(self.config.num_conv_pos_embeddings,),
179
+ kernel_init=jax.nn.initializers.he_normal(),
180
+ padding="VALID",
181
+ feature_group_count=self.config.num_conv_pos_embedding_groups,
182
+ dtype=self.dtype,
183
+ )
184
+ weight_shape = (
185
+ self.conv.features,
186
+ self.conv.features // self.conv.feature_group_count,
187
+ self.conv.kernel_size[0],
188
+ )
189
+ self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
190
+ self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
191
+ self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
192
+ self.prev_padding = self.conv.kernel_size[0] // 2
193
+
194
+ def _get_normed_weights(self):
195
+ weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
196
+ normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
197
+ normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
198
+ return normed_kernel
199
+
200
+ def __call__(self, hidden_states):
201
+ kernel = self._get_normed_weights()
202
+ hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
203
+ hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
204
+ return hidden_states
205
+
206
+
207
+ class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
208
+ config: Wav2Vec2Config
209
+ dtype: jnp.dtype = jnp.float32
210
+
211
+ def setup(self):
212
+ self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
213
+ self.activation = ACT2FN[self.config.feat_extract_activation]
214
+ self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
215
+
216
+ def __call__(self, hidden_states):
217
+ hidden_states = hidden_states.transpose((0, 1, 2))
218
+
219
+ hidden_states = self.conv(hidden_states)
220
+
221
+ if self.num_pad_remove > 0:
222
+ hidden_states = hidden_states[:, : -self.num_pad_remove, :]
223
+ hidden_states = self.activation(hidden_states)
224
+
225
+ hidden_states = hidden_states.transpose((0, 1, 2))
226
+ return hidden_states
227
+
228
+
229
+ class FlaxConvLayersCollection(nn.Module):
230
+ config: Wav2Vec2Config
231
+ dtype: jnp.dtype = jnp.float32
232
+
233
+ def setup(self):
234
+ if self.config.feat_extract_norm == "layer":
235
+ # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
236
+ BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
237
+ self.layers = [
238
+ BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
239
+ for i in range(self.config.num_feat_extract_layers)
240
+ ]
241
+ elif self.config.feat_extract_norm == "group":
242
+ raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
243
+ else:
244
+ raise ValueError(
245
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
246
+ )
247
+
248
+ def __call__(self, hidden_states):
249
+ for i, conv_layer in enumerate(self.layers):
250
+ hidden_states = conv_layer(hidden_states)
251
+ return hidden_states
252
+
253
+
254
+ class FlaxWav2Vec2FeatureEncoder(nn.Module):
255
+ """Construct the features from raw audio waveform"""
256
+
257
+ config: Wav2Vec2Config
258
+ dtype: jnp.dtype = jnp.float32
259
+
260
+ def setup(self):
261
+ self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
262
+
263
+ def __call__(self, input_values, freeze_feature_encoder=False):
264
+ hidden_states = input_values[:, :, None]
265
+ hidden_states = self.conv_layers(hidden_states)
266
+ if freeze_feature_encoder:
267
+ hidden_states = jax.lax.stop_gradient(hidden_states)
268
+ return hidden_states
269
+
270
+
271
+ class FlaxWav2Vec2FeatureProjection(nn.Module):
272
+ config: Wav2Vec2Config
273
+ dtype: jnp.dtype = jnp.float32
274
+
275
+ def setup(self):
276
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
277
+ self.projection = nn.Dense(
278
+ self.config.hidden_size,
279
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
280
+ dtype=self.dtype,
281
+ )
282
+ self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
283
+
284
+ def __call__(self, hidden_states, deterministic=True):
285
+ norm_hidden_states = self.layer_norm(hidden_states)
286
+ hidden_states = self.projection(norm_hidden_states)
287
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
288
+ return hidden_states, norm_hidden_states
289
+
290
+
291
+ class FlaxWav2Vec2Attention(nn.Module):
292
+ config: Wav2Vec2Config
293
+ embed_dim: int
294
+ num_heads: int
295
+ dropout: float = 0.0
296
+ bias: bool = True
297
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
298
+
299
+ def setup(self) -> None:
300
+ self.head_dim = self.embed_dim // self.num_heads
301
+ if self.head_dim * self.num_heads != self.embed_dim:
302
+ raise ValueError(
303
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
304
+ )
305
+
306
+ dense = partial(
307
+ nn.Dense,
308
+ self.embed_dim,
309
+ use_bias=self.bias,
310
+ dtype=self.dtype,
311
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
312
+ )
313
+
314
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
315
+
316
+ self.fused_proj = nn.Dense(
317
+ self.embed_dim * 3,
318
+ use_bias=self.bias,
319
+ dtype=self.dtype,
320
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
321
+ )
322
+
323
+ self.out_proj = dense()
324
+
325
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
326
+
327
+ def _split_heads(self, hidden_states):
328
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
329
+
330
+ def _merge_heads(self, hidden_states):
331
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
332
+
333
+ def __call__(
334
+ self,
335
+ hidden_states: jnp.ndarray,
336
+ key_value_states: Optional[jnp.ndarray] = None,
337
+ attention_mask: Optional[jnp.ndarray] = None,
338
+ deterministic: bool = True,
339
+ ) -> Tuple[jnp.ndarray]:
340
+ """Input shape: Batch x Time x Channel"""
341
+
342
+ if self.config.fuse_matmuls:
343
+ attention_states = self.fused_proj(hidden_states)
344
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
345
+
346
+ else:
347
+ # get query proj
348
+ query_states = self.q_proj(hidden_states)
349
+
350
+ key_states = self.k_proj(hidden_states)
351
+ value_states = self.v_proj(hidden_states)
352
+
353
+ query_states = self._split_heads(query_states)
354
+ key_states = self._split_heads(key_states)
355
+ value_states = self._split_heads(value_states)
356
+
357
+ if attention_mask is not None:
358
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
359
+
360
+ # Convert the boolean attention mask to an attention bias.
361
+ if attention_mask is not None:
362
+ # attention mask in the form of attention bias
363
+ attention_bias = lax.select(
364
+ attention_mask > 0,
365
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
366
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
367
+ )
368
+ else:
369
+ attention_bias = None
370
+
371
+ dropout_rng = None
372
+ if not deterministic and self.dropout > 0.0:
373
+ dropout_rng = self.make_rng("dropout")
374
+
375
+ attn_weights = dot_product_attention_weights(
376
+ query_states,
377
+ key_states,
378
+ bias=attention_bias,
379
+ dropout_rng=dropout_rng,
380
+ dropout_rate=self.dropout,
381
+ broadcast_dropout=True,
382
+ deterministic=deterministic,
383
+ dtype=self.dtype,
384
+ precision=None,
385
+ )
386
+
387
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
388
+ attn_output = self._merge_heads(attn_output)
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights
392
+
393
+
394
+ class FlaxWav2Vec2FeedForward(nn.Module):
395
+ config: Wav2Vec2Config
396
+ dtype: jnp.dtype = jnp.float32
397
+
398
+ def setup(self):
399
+ self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
400
+
401
+ self.intermediate_dense = nn.Dense(
402
+ self.config.intermediate_size,
403
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
404
+ dtype=self.dtype,
405
+ )
406
+ if isinstance(self.config.hidden_act, str):
407
+ self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
408
+ else:
409
+ self.intermediate_act_fn = self.config.hidden_act
410
+
411
+ self.output_dense = nn.Dense(
412
+ self.config.hidden_size,
413
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
414
+ dtype=self.dtype,
415
+ )
416
+ self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
417
+
418
+ def __call__(self, hidden_states, deterministic=True):
419
+ hidden_states = self.intermediate_dense(hidden_states)
420
+ hidden_states = self.intermediate_act_fn(hidden_states)
421
+ hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
422
+
423
+ hidden_states = self.output_dense(hidden_states)
424
+ hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
425
+ return hidden_states
426
+
427
+
428
+ class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
429
+ config: Wav2Vec2Config
430
+ dtype: jnp.dtype = jnp.float32
431
+
432
+ def setup(self):
433
+ self.attention = FlaxWav2Vec2Attention(
434
+ config=self.config,
435
+ embed_dim=self.config.hidden_size,
436
+ num_heads=self.config.num_attention_heads,
437
+ dropout=self.config.attention_dropout,
438
+ dtype=self.dtype,
439
+ )
440
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
441
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
442
+ self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
443
+ self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
444
+
445
+ def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
446
+ if self.config.use_scan:
447
+ hidden_states = hidden_states[0]
448
+ attn_residual = hidden_states
449
+ hidden_states = self.layer_norm(hidden_states)
450
+ hidden_states, attn_weights = self.attention(
451
+ hidden_states, attention_mask=attention_mask, deterministic=deterministic
452
+ )
453
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
454
+ hidden_states = attn_residual + hidden_states
455
+ hidden_states = hidden_states + self.feed_forward(
456
+ self.final_layer_norm(hidden_states), deterministic=deterministic
457
+ )
458
+
459
+ outputs = (hidden_states,)
460
+
461
+ if output_attentions:
462
+ outputs += (attn_weights,)
463
+
464
+ if self.config.use_scan:
465
+ outputs = (outputs, None)
466
+
467
+ return outputs
468
+
469
+
470
+ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
471
+ config: Wav2Vec2Config
472
+ dtype: jnp.dtype = jnp.float32
473
+
474
+ @nn.compact
475
+ def __call__(
476
+ self,
477
+ hidden_states,
478
+ attention_mask=None,
479
+ deterministic: bool = True,
480
+ output_attentions: bool = False,
481
+ output_hidden_states: bool = False,
482
+ return_dict: bool = True,
483
+ ):
484
+ all_attentions = () if output_attentions else None
485
+ all_hidden_states = () if output_hidden_states else None
486
+
487
+ num_layers = self.config.num_hidden_layers
488
+ BlockEncoderLayer = (
489
+ remat(
490
+ FlaxWav2Vec2EncoderLayerStableLayerNorm,
491
+ static_argnums=(2, 3),
492
+ prevent_cse=not self.config.use_scan,
493
+ )
494
+ if self.config.gradient_checkpointing
495
+ else FlaxWav2Vec2EncoderLayerStableLayerNorm
496
+ )
497
+
498
+ if self.config.use_scan:
499
+ # since all decoder layers are the same, we use nn.scan directly
500
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
501
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
502
+ hidden_states = (hidden_states,)
503
+
504
+ hidden_states, _ = scan_with_axes(
505
+ BlockEncoderLayer,
506
+ variable_axes={"params": 0, "cache": 0},
507
+ split_rngs={"params": True, "dropout": True},
508
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
509
+ length=num_layers,
510
+ )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers",)(
511
+ hidden_states, attention_mask, deterministic, output_attentions
512
+ )
513
+ hidden_states = hidden_states[0]
514
+
515
+ else:
516
+ for layer in range(num_layers):
517
+ if output_hidden_states:
518
+ all_hidden_states += (hidden_states,)
519
+
520
+ layer_outputs = BlockEncoderLayer(
521
+ self.config,
522
+ dtype=self.dtype,
523
+ name=str(layer),
524
+ )(hidden_states, attention_mask, deterministic, output_attentions)
525
+
526
+ hidden_states = layer_outputs[0]
527
+
528
+ if output_attentions:
529
+ all_attentions += (layer_outputs[1],)
530
+
531
+ if output_hidden_states:
532
+ all_hidden_states += (hidden_states,)
533
+
534
+ outputs = (hidden_states, all_hidden_states, all_attentions)
535
+
536
+ if not return_dict:
537
+ return tuple(v for v in outputs if v is not None)
538
+
539
+ return FlaxBaseModelOutput(
540
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
541
+ )
542
+
543
+
544
+ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
545
+ config: Wav2Vec2Config
546
+ dtype: jnp.dtype = jnp.float32
547
+
548
+ def setup(self):
549
+ self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
550
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
551
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
552
+ self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
553
+
554
+ def __call__(
555
+ self,
556
+ hidden_states,
557
+ attention_mask=None,
558
+ deterministic=True,
559
+ output_attentions=False,
560
+ output_hidden_states=False,
561
+ return_dict=True,
562
+ ):
563
+
564
+ if attention_mask is not None:
565
+ # make sure padded tokens are not attended to
566
+ hidden_states = jnp.where(
567
+ jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
568
+ )
569
+
570
+ position_embeddings = self.pos_conv_embed(hidden_states)
571
+
572
+ hidden_states = hidden_states + position_embeddings
573
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
574
+
575
+ outputs = self.layers(
576
+ hidden_states,
577
+ attention_mask,
578
+ output_attentions=output_attentions,
579
+ output_hidden_states=output_hidden_states,
580
+ return_dict=return_dict,
581
+ )
582
+
583
+ last_hidden_state = self.layer_norm(outputs[0])
584
+
585
+ # update the last element in `hidden_states` after applying `layernorm` above
586
+ hidden_states = None
587
+ if output_hidden_states:
588
+ hidden_states = outputs[1]
589
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
590
+
591
+ if not return_dict:
592
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
593
+ return tuple(v for v in outputs if v is not None)
594
+
595
+ return FlaxBaseModelOutput(
596
+ last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
597
+ )
598
+
599
+
600
+ class FlaxWav2Vec2Adapter(nn.Module):
601
+ config: Wav2Vec2Config
602
+ dtype: jnp.dtype = jnp.float32
603
+
604
+ def setup(self):
605
+ # hidden_states require down-projection if feature dims don't match
606
+ if self.config.output_hidden_size != self.config.hidden_size:
607
+ self.proj = nn.Dense(
608
+ self.config.output_hidden_size,
609
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
610
+ dtype=self.dtype,
611
+ )
612
+ self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
613
+ else:
614
+ self.proj = self.proj_layer_norm = None
615
+
616
+ self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
617
+
618
+ def __call__(self, hidden_states, deterministic=True):
619
+ # down-project hidden_states if required
620
+ if self.proj is not None and self.proj_layer_norm is not None:
621
+ hidden_states = self.proj(hidden_states)
622
+ hidden_states = self.proj_layer_norm(hidden_states)
623
+
624
+ hidden_states = self.layers(hidden_states)
625
+
626
+ return hidden_states
627
+
628
+
629
+ class FlaxWav2Vec2AdapterLayer(nn.Module):
630
+ config: Wav2Vec2Config
631
+ dtype: jnp.dtype = jnp.float32
632
+
633
+ def setup(self):
634
+ self.conv = nn.Conv(
635
+ features=2 * self.config.output_hidden_size,
636
+ kernel_size=(self.config.adapter_kernel_size,),
637
+ strides=(self.config.adapter_stride,),
638
+ padding=((1, 1),),
639
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
640
+ dtype=self.dtype,
641
+ )
642
+
643
+ def __call__(self, hidden_states):
644
+ hidden_states = self.conv(hidden_states)
645
+ hidden_states = nn.glu(hidden_states, axis=2)
646
+
647
+ return hidden_states
648
+
649
+
650
+ class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
651
+ config: Wav2Vec2Config
652
+ dtype: jnp.dtype = jnp.float32
653
+
654
+ def setup(self):
655
+ BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
656
+ self.layers = [
657
+ BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
658
+ for i in range(self.config.num_adapter_layers)
659
+ ]
660
+
661
+ def __call__(self, hidden_states):
662
+ for conv_layer in self.layers:
663
+ hidden_states = conv_layer(hidden_states)
664
+
665
+ return hidden_states
666
+
667
+
668
+ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
669
+ """
670
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
671
+ models.
672
+ """
673
+
674
+ config_class = Wav2Vec2Config
675
+ base_model_prefix: str = "wav2vec2"
676
+ main_input_name = "input_values"
677
+ module_class: nn.Module = None
678
+
679
+ def __init__(
680
+ self,
681
+ config: Wav2Vec2Config,
682
+ input_shape: Tuple = (1, 1024),
683
+ seed: int = 0,
684
+ dtype: jnp.dtype = jnp.float32,
685
+ _do_init: bool = True,
686
+ **kwargs,
687
+ ):
688
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
689
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
690
+
691
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
692
+ # init input tensors
693
+ input_values = jnp.zeros(input_shape, dtype="i4")
694
+ attention_mask = jnp.ones_like(input_values)
695
+ params_rng, dropout_rng = jax.random.split(rng, 2)
696
+ rngs = {"params": params_rng, "dropout": dropout_rng}
697
+
698
+ return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
699
+
700
+ def __call__(
701
+ self,
702
+ input_values,
703
+ attention_mask=None,
704
+ mask_time_indices=None,
705
+ extract_features=None,
706
+ params: dict = None,
707
+ dropout_rng: jax.random.PRNGKey = None,
708
+ train: bool = False,
709
+ output_attentions: Optional[bool] = None,
710
+ output_hidden_states: Optional[bool] = None,
711
+ output_features: Optional[bool] = None,
712
+ freeze_feature_encoder: bool = False,
713
+ return_dict: Optional[bool] = None,
714
+ ):
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
720
+
721
+ if attention_mask is None:
722
+ batch_size, sequence_length = input_values.shape
723
+ attention_mask = jnp.ones((batch_size, sequence_length))
724
+
725
+ if extract_features is not None:
726
+ extract_features = jnp.array(extract_features, dtype="f4")
727
+
728
+ # Handle any PRNG if needed
729
+ rngs = {}
730
+ if dropout_rng is not None:
731
+ rngs["dropout"] = dropout_rng
732
+
733
+ inputs = {"params": params or self.params}
734
+
735
+ return self.module.apply(
736
+ inputs,
737
+ jnp.array(input_values, dtype="f4"),
738
+ jnp.array(attention_mask, dtype="i4"),
739
+ mask_time_indices,
740
+ extract_features,
741
+ not train,
742
+ output_attentions,
743
+ output_hidden_states,
744
+ output_features,
745
+ freeze_feature_encoder,
746
+ return_dict,
747
+ rngs=rngs,
748
+ )
749
+
750
+ def _get_feat_extract_output_lengths(
751
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
752
+ ):
753
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
754
+
755
+ def _get_feature_vector_attention_mask(
756
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
757
+ ):
758
+ return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
759
+
760
+
761
+ class FlaxWav2Vec2Module(nn.Module):
762
+ config: Wav2Vec2Config
763
+ dtype: jnp.dtype = jnp.float32
764
+
765
+ def setup(self):
766
+ self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
767
+ self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
768
+ self.masked_spec_embed = self.param(
769
+ "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
770
+ )
771
+
772
+ if self.config.do_stable_layer_norm:
773
+ self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
774
+ else:
775
+ raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
776
+
777
+ self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
778
+
779
+ def __call__(
780
+ self,
781
+ input_values,
782
+ attention_mask=None,
783
+ mask_time_indices=None,
784
+ extract_features=None,
785
+ deterministic=True,
786
+ output_attentions=None,
787
+ output_hidden_states=None,
788
+ output_features=False,
789
+ freeze_feature_encoder=False,
790
+ return_dict=None,
791
+ ):
792
+
793
+ # forward pass through the feature extractor if features not specified
794
+ if extract_features is None:
795
+ extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
796
+
797
+ if output_features:
798
+ return extract_features
799
+
800
+ # make sure that no loss is computed on padded inputs
801
+ if attention_mask is not None:
802
+ # compute reduced attention_mask corresponding to feature vectors
803
+ attention_mask = self._get_feature_vector_attention_mask(
804
+ extract_features.shape[1], attention_mask, add_adapter=False
805
+ )
806
+
807
+ hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
808
+ if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
809
+ hidden_states = jnp.where(
810
+ jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
811
+ jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
812
+ hidden_states,
813
+ )
814
+
815
+ encoder_outputs = self.encoder(
816
+ hidden_states,
817
+ attention_mask=attention_mask,
818
+ deterministic=deterministic,
819
+ output_attentions=output_attentions,
820
+ output_hidden_states=output_hidden_states,
821
+ return_dict=return_dict,
822
+ )
823
+
824
+ hidden_states = encoder_outputs[0]
825
+
826
+ if self.adapter is not None:
827
+ hidden_states = self.adapter(hidden_states)
828
+
829
+ if not return_dict:
830
+ return (hidden_states, extract_features) + encoder_outputs[1:]
831
+
832
+ return FlaxWav2Vec2BaseModelOutput(
833
+ last_hidden_state=hidden_states,
834
+ extract_features=extract_features,
835
+ hidden_states=encoder_outputs.hidden_states,
836
+ attentions=encoder_outputs.attentions,
837
+ )
838
+
839
+ def _get_feat_extract_output_lengths(
840
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
841
+ ):
842
+ """
843
+ Computes the output length of the convolutional layers
844
+ """
845
+
846
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
847
+
848
+ def _conv_out_length(input_length, kernel_size, stride):
849
+ # 1D convolutional layer output length formula taken
850
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
851
+ return (input_length - kernel_size) // stride + 1
852
+
853
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
854
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
855
+
856
+ if add_adapter:
857
+ for _ in range(self.config.num_adapter_layers):
858
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
859
+
860
+ return input_lengths
861
+
862
+ def _get_feature_vector_attention_mask(
863
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
864
+ ):
865
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
866
+ # on inference mode.
867
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
868
+
869
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
870
+
871
+ batch_size = attention_mask.shape[0]
872
+
873
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
874
+ # these two operations makes sure that all values
875
+ # before the output lengths indices are attended to
876
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
877
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
878
+ return attention_mask
879
+
880
+
881
+ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
882
+ module_class = FlaxWav2Vec2Module
883
+
884
+
885
+ class FlaxWav2Vec2ForCTCModule(nn.Module):
886
+ config: Wav2Vec2Config
887
+ dtype: jnp.dtype = jnp.float32
888
+
889
+ def setup(self):
890
+ self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
891
+ self.dropout = nn.Dropout(rate=self.config.final_dropout)
892
+ self.lm_head = nn.Dense(
893
+ self.config.vocab_size,
894
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
895
+ dtype=self.dtype,
896
+ )
897
+
898
+ def __call__(
899
+ self,
900
+ input_values,
901
+ attention_mask=None,
902
+ mask_time_indices=None,
903
+ extract_features=None,
904
+ deterministic=True,
905
+ output_attentions=None,
906
+ output_hidden_states=None,
907
+ output_features=False,
908
+ freeze_feature_encoder=False,
909
+ return_dict=None,
910
+ ):
911
+ outputs = self.wav2vec2(
912
+ input_values,
913
+ attention_mask=attention_mask,
914
+ mask_time_indices=mask_time_indices,
915
+ deterministic=deterministic,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ freeze_feature_encoder=freeze_feature_encoder,
919
+ return_dict=return_dict,
920
+ )
921
+
922
+ hidden_states = outputs[0]
923
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
924
+
925
+ logits = self.lm_head(hidden_states)
926
+
927
+ if not return_dict:
928
+ return (logits,) + outputs[2:]
929
+
930
+ return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
931
+
932
+ def _get_feat_extract_output_lengths(
933
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
934
+ ):
935
+ """
936
+ Computes the output length of the convolutional layers
937
+ """
938
+
939
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
940
+
941
+ def _conv_out_length(input_length, kernel_size, stride):
942
+ # 1D convolutional layer output length formula taken
943
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
944
+ return (input_length - kernel_size) // stride + 1
945
+
946
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
947
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
948
+
949
+ if add_adapter:
950
+ for _ in range(self.config.num_adapter_layers):
951
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
952
+
953
+ return input_lengths
954
+
955
+ def _get_feature_vector_attention_mask(
956
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
957
+ ):
958
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
959
+ # on inference mode.
960
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
961
+
962
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
963
+
964
+ batch_size = attention_mask.shape[0]
965
+
966
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
967
+ # these two operations makes sure that all values
968
+ # before the output lengths indices are attended to
969
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
970
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
971
+ return attention_mask
972
+
973
+
974
+ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
975
+ module_class = FlaxWav2Vec2ForCTCModule
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "processor_class": "Wav2Vec2Processor",
8
+ "return_attention_mask": true,
9
+ "sampling_rate": 16000
10
+ }
run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
17
+ """
18
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
19
+
20
+ import logging
21
+ import math
22
+ import os
23
+ import sys
24
+ import time
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any, Callable, Dict, List, Optional, Union
28
+
29
+ import datasets
30
+ import numpy as np
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+ from tqdm import tqdm
33
+
34
+ import flax
35
+ import jax
36
+ import jax.numpy as jnp
37
+ import optax
38
+ import transformers
39
+ import wandb as wandb
40
+ from flax import core, jax_utils, struct, traverse_util
41
+ from flax.jax_utils import unreplicate, pad_shard_unpad
42
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
43
+ from huggingface_hub import Repository
44
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
45
+ from optax._src import linear_algebra
46
+ from transformers import (
47
+ AutoFeatureExtractor,
48
+ AutoProcessor,
49
+ AutoTokenizer,
50
+ HfArgumentParser,
51
+ TrainingArguments,
52
+ is_tensorboard_available,
53
+ )
54
+ from transformers.file_utils import get_full_repo_name
55
+ from transformers.utils import check_min_version
56
+ from transformers.utils.versions import require_version
57
+
58
+
59
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
60
+ check_min_version("4.17.0.dev0")
61
+
62
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+
67
+ @flax.struct.dataclass
68
+ class ModelArguments:
69
+ """
70
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
71
+ """
72
+
73
+ model_name_or_path: str = field(
74
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
75
+ )
76
+ config_name: Optional[str] = field(
77
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
78
+ )
79
+ tokenizer_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
81
+ )
82
+ feature_extractor_name: Optional[str] = field(
83
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
84
+ )
85
+ cache_dir: Optional[str] = field(
86
+ default=None,
87
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
88
+ )
89
+ use_fast_tokenizer: bool = field(
90
+ default=True,
91
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
92
+ )
93
+ model_revision: str = field(
94
+ default="main",
95
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
96
+ )
97
+ use_auth_token: bool = field(
98
+ default=False,
99
+ metadata={
100
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
101
+ "with private models)."
102
+ },
103
+ )
104
+ freeze_feature_encoder: bool = field(
105
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
106
+ )
107
+ activation_dropout: float = field(
108
+ default=0.1,
109
+ metadata={
110
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
111
+ },
112
+ )
113
+ hidden_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ feat_proj_dropout: float = field(
120
+ default=0.0,
121
+ metadata={
122
+ "help": "The feat proj dropout probability for feature encoder representations."
123
+ },
124
+ )
125
+ mask_time_prob: float = field(
126
+ default=0.1,
127
+ metadata={
128
+ "help": "The spec aug dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+
132
+
133
+ @flax.struct.dataclass
134
+ class DataTrainingArguments:
135
+ """
136
+ Arguments pertaining to what data we are going to input our model for training and eval.
137
+ """
138
+
139
+ dataset_name: str = field(
140
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
141
+ )
142
+ dataset_config_name: Optional[str] = field(
143
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
144
+ )
145
+ text_column: Optional[str] = field(
146
+ default=None,
147
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
148
+ )
149
+ dataset_cache_dir: Optional[str] = field(
150
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
151
+ )
152
+ overwrite_cache: bool = field(
153
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
154
+ )
155
+ preprocessing_num_workers: Optional[int] = field(
156
+ default=None,
157
+ metadata={"help": "The number of processes to use for the preprocessing."},
158
+ )
159
+ max_train_samples: Optional[int] = field(
160
+ default=None,
161
+ metadata={
162
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
163
+ "value if set."
164
+ },
165
+ )
166
+ max_eval_samples: Optional[int] = field(
167
+ default=None,
168
+ metadata={
169
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
170
+ "value if set."
171
+ },
172
+ )
173
+ max_test_samples: Optional[int] = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
177
+ "value if set."
178
+ },
179
+ )
180
+ audio_column_name: str = field(
181
+ default="audio",
182
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
183
+ )
184
+ text_column_name: str = field(
185
+ default="text",
186
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
187
+ )
188
+ max_duration_in_seconds: float = field(
189
+ default=20.0,
190
+ metadata={
191
+ "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
192
+ },
193
+ )
194
+ min_duration_in_seconds: float = field(
195
+ default=0.0, metadata={"help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"}
196
+ )
197
+ max_label_length: Optional[int] = field(
198
+ default=512,
199
+ metadata={
200
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
201
+ "than this will be filtered."
202
+ },
203
+ )
204
+ min_label_length: Optional[int] = field(
205
+ default=0,
206
+ metadata={
207
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
208
+ "than this will be filtered."
209
+ },
210
+ )
211
+ max_eval_duration_in_seconds: float = field(
212
+ default=None,
213
+ metadata={
214
+ "help": "Filter audio files in the eval/test set that are longer than `max_duration_in_seconds` seconds"
215
+ },
216
+ )
217
+ pad_input_to_multiple_of: Optional[int] = field(
218
+ default=32000,
219
+ metadata={
220
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
221
+ "This is important to avoid triggering recompilations on TPU."
222
+ },
223
+ )
224
+ pad_target_to_multiple_of: Optional[int] = field(
225
+ default=None,
226
+ metadata={
227
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
228
+ "This is important to avoid triggering recompilations on TPU."
229
+ },
230
+ )
231
+ preprocessing_only: bool = field(
232
+ default=False,
233
+ metadata={
234
+ "help": "Whether to only do data preprocessing and skip training. "
235
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
236
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
237
+ "so that the cached datasets can consequently be loaded in distributed training"
238
+ },
239
+ )
240
+ train_split_name: str = field(
241
+ default="train",
242
+ metadata={
243
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
244
+ },
245
+ )
246
+ eval_split_name: str = field(
247
+ default="validation",
248
+ metadata={
249
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
250
+ },
251
+ )
252
+ wandb_project: str = field(
253
+ default="flax-speech-recognition-ctc",
254
+ metadata={"help": "The name of the wandb project."},
255
+ )
256
+ wandb_name: str = field(
257
+ default=None,
258
+ metadata={"help": "The name of the wandb run."},
259
+ )
260
+ wandb_job_type: str = field(
261
+ default="CTC",
262
+ metadata={"help": "The name of the wandb job type."},
263
+ )
264
+ test_split_name: str = field(
265
+ default="test",
266
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
267
+ )
268
+
269
+
270
+ # @flax.struct.dataclass
271
+ @dataclass
272
+ class FlaxTrainingArguments(TrainingArguments):
273
+ precision: str = field(
274
+ default="full",
275
+ metadata={
276
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
277
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
278
+ },
279
+ )
280
+ matmul_precision: str = field(
281
+ default="default",
282
+ metadata={
283
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
284
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
285
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
286
+ "it only changes the behaviors of calls with no such argument provided. "
287
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
288
+ },
289
+ )
290
+ multisteps: bool = field(
291
+ default=False,
292
+ metadata={
293
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
294
+ "a custom gradient accumulation implementation will be employed."
295
+ },
296
+ )
297
+
298
+
299
+ def to_fp32(t):
300
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
301
+
302
+
303
+ def to_bf16(t):
304
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
305
+
306
+
307
+ class MixedPrecisionTrainState(struct.PyTreeNode):
308
+ """Train state for use with a single Optax optimizer.
309
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
310
+
311
+ Synopsis::
312
+
313
+ state = TrainState.create(
314
+ apply_fn=model.apply,
315
+ params=variables['params'],
316
+ tx=tx)
317
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
318
+ for batch in data:
319
+ grads = grad_fn(state.params, batch)
320
+ state = state.apply_gradients(grads=grads)
321
+
322
+ Args:
323
+ step: Counter starts at 0 and is incremented by every call to
324
+ `.apply_gradients()`.
325
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
326
+ convenience to have a shorter params list for the `train_step()` function
327
+ in your training loop.
328
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
329
+ tx: An Optax gradient transformation.
330
+ opt_state: The state for `tx`.
331
+ dropout_rng: PRNG key for stochastic operations.
332
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
333
+ """
334
+
335
+ step: int
336
+ apply_fn: Callable = struct.field(pytree_node=False)
337
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
338
+ params: core.FrozenDict[str, Any]
339
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
340
+ opt_state: optax.OptState
341
+ dropout_rng: jnp.ndarray
342
+ max_grad_norm: Optional[float] = 1.0
343
+
344
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
345
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
346
+
347
+ Note that internally this function calls `.tx.update()` followed by a call
348
+ to `optax.apply_updates()` to update `params` and `opt_state`.
349
+
350
+ Args:
351
+ grads: Gradients that have the same pytree structure as `.params`.
352
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
353
+
354
+ Returns:
355
+ An updated instance of `self` with `step` incremented by one, `params`
356
+ and `opt_state` updated by applying `grads`, and additional attributes
357
+ replaced as specified by `kwargs`.
358
+ """
359
+
360
+ # clip gradients by global l2 norm
361
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
362
+ g_norm = linear_algebra.global_norm(grads)
363
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
364
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
365
+
366
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
367
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
368
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
369
+
370
+ new_params = optax.apply_updates(self.params, updates)
371
+ return self.replace(
372
+ step=self.step + 1,
373
+ params=new_params,
374
+ opt_state=to_dtype(new_opt_state),
375
+ **kwargs,
376
+ )
377
+
378
+ @classmethod
379
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
380
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
381
+ # downcast optimizer state to bf16 if mixed-precision training
382
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
383
+ return cls(
384
+ step=0,
385
+ apply_fn=apply_fn,
386
+ params=params,
387
+ tx=tx,
388
+ opt_state=opt_state,
389
+ **kwargs,
390
+ )
391
+
392
+ def replicate(self):
393
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
394
+
395
+
396
+ @flax.struct.dataclass
397
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
398
+ """
399
+ Data collator that will dynamically pad the inputs received.
400
+ Args:
401
+ processor ([`Wav2Vec2Processor`])
402
+ The processor used for proccessing the data.
403
+ decoder_start_token_id (:obj: `int`)
404
+ The begin-of-sentence of the decoder.
405
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
406
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
407
+ among:
408
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
409
+ sequence if provided).
410
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
411
+ maximum acceptable input length for the model if that argument is not provided.
412
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
413
+ different lengths).
414
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
415
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
416
+ See above for details.
417
+ max_input_length (:obj:`float`, `optional`):
418
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
419
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
420
+ If set will pad the input sequence to a multiple of the provided value.
421
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
422
+ 7.5 (Volta).
423
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
424
+ If set will pad the target sequence to a multiple of the provided value.
425
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
426
+ 7.5 (Volta).
427
+ """
428
+
429
+ processor: Any
430
+ input_padding: Union[bool, str] = "longest"
431
+ label_padding: Union[bool, str] = "max_length"
432
+ pad_input_to_multiple_of: Optional[int] = None
433
+ pad_to_multiple_of_label: Optional[int] = None
434
+ max_input_length: Optional[float] = None
435
+ max_label_length: Optional[float] = None
436
+
437
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
438
+ # split inputs and labels since they have to be of different lengths and need
439
+ # different padding methods
440
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
441
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
442
+
443
+ # reformat list to dict and set to pytorch format
444
+ batch = self.processor.feature_extractor.pad(
445
+ input_features,
446
+ max_length=self.max_input_length,
447
+ padding=self.input_padding,
448
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
449
+ return_tensors="np",
450
+ )
451
+
452
+ labels_batch = self.processor.tokenizer.pad(
453
+ label_features,
454
+ max_length=self.max_label_length,
455
+ padding=self.label_padding,
456
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
457
+ return_tensors="np",
458
+ )
459
+
460
+ labels = labels_batch["input_ids"]
461
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
462
+ labels = labels.filled(fill_value=-100)
463
+
464
+ batch["labels"] = labels
465
+
466
+ return batch
467
+
468
+
469
+ def get_grouped_indices(
470
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
471
+ ) -> np.array:
472
+ """
473
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
474
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
475
+ lengths. To do this, the indices are:
476
+
477
+ - randomly permuted (if a JAX rng is specified)
478
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
479
+ - sorted by length in each mega-batch
480
+
481
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
482
+ maximum length placed first, so that an OOM happens sooner rather than later.
483
+ """
484
+ lengths = dataset["input_length"]
485
+
486
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
487
+ if mega_batch_mult is None:
488
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
489
+ # Just in case, for tiny datasets
490
+ if mega_batch_mult == 0:
491
+ mega_batch_mult = 1
492
+
493
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
494
+ num_samples = len(lengths)
495
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
496
+
497
+ megabatch_size = mega_batch_mult * batch_size
498
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
499
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
500
+
501
+ # The rest is to get the biggest batch first.
502
+ # Since each megabatch is sorted by descending length, the longest element is the first
503
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
504
+ max_idx = np.argmax(megabatch_maximums).item()
505
+ # Switch to put the longest batch in first position
506
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
507
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
508
+
509
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
510
+
511
+ return megabatches
512
+
513
+
514
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
515
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
516
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
517
+ num_samples = len(samples_idx)
518
+ if drop_last:
519
+ samples_to_remove = num_samples % batch_size
520
+ if samples_to_remove != 0:
521
+ samples_idx = samples_idx[:-samples_to_remove]
522
+ sections_split = num_samples // batch_size
523
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
524
+ else:
525
+ sections_split = math.ceil(num_samples / batch_size)
526
+ samples_idx = np.array_split(samples_idx, sections_split)
527
+ return samples_idx
528
+
529
+
530
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
531
+ summary_writer.scalar("train_time", train_time, step)
532
+
533
+ train_metrics = get_metrics(train_metrics)
534
+ for key, vals in train_metrics.items():
535
+ tag = f"train_{key}"
536
+ for i, val in enumerate(vals):
537
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
538
+
539
+
540
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
541
+ for metric_name, value in eval_metrics.items():
542
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
543
+
544
+ if pred_str is not None:
545
+ # write output actual predictions for debugging
546
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
547
+
548
+
549
+ def write_wandb_log(metrics, step, prefix=None):
550
+ if jax.process_index() == 0:
551
+ log_metrics = {}
552
+ for k, v in metrics.items():
553
+ if "layer" in k:
554
+ log_metrics[f"{k}/"] = v
555
+ elif prefix is not None:
556
+ log_metrics[f"{prefix}/{k}"] = v
557
+ else:
558
+ log_metrics[k] = v
559
+ wandb.log(log_metrics, step)
560
+
561
+
562
+ def write_wandb_pred(pred_str, label_str, step, final_step=False, prefix="eval"):
563
+ if jax.process_index() == 0:
564
+ # convert str data to a wandb compatible format
565
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
566
+ if not final_step:
567
+ # we'll log the first 50 predictions for each intermediate epoch
568
+ wandb.log(
569
+ {
570
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
571
+ columns=["label_str", "pred_str"], data=str_data[:50]
572
+ )
573
+ },
574
+ step,
575
+ )
576
+ else:
577
+ # we'll log all predictions for the last epoch
578
+ wandb.log(
579
+ {
580
+ f"{prefix}/step_{int(step / 1000)}k_all": wandb.Table(
581
+ columns=["label_str", "pred_str"], data=str_data
582
+ )
583
+ },
584
+ step,
585
+ )
586
+
587
+
588
+ def create_learning_rate_fn(
589
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
590
+ ) -> Callable[[int], jnp.array]:
591
+ """Returns a linear warmup, linear_decay learning rate function."""
592
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
593
+ decay_fn = optax.linear_schedule(
594
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
595
+ )
596
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
597
+ return schedule_fn
598
+
599
+
600
+ def ctc_loss(
601
+ logits,
602
+ logits_attention_mask,
603
+ labels,
604
+ blank_id,
605
+ loss_reduction="mean",
606
+ output_emission_dict=False,
607
+ log_epsilon=-100000.0,
608
+ ):
609
+ """Computes CTC loss.
610
+ This function performs forward computation over an FSA with `N * 2` states
611
+ where `N` is the max number of labels. The states are split into two groups:
612
+ Phi states and emission states. a phi-state accepts repetition of
613
+ phi (blank)-symbols and transits to emission state when the correct label is
614
+ observed. An emission state accepts repetition of the label and transits to
615
+ the next phi states at any time (so called epsilon-transition).
616
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
617
+ and `N` denotes the time steps in `labels`.
618
+ Args:
619
+ logits: (B, T, K)-array containing log-probabilities of each class.
620
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
621
+ labels: (B, N)-array containing reference integer labels.
622
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
623
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
624
+ repetition of zeroes, followed by repetition of ones.
625
+ blank_id: Id for blank token.
626
+ loss_reduction: one of "mean", "sum", "default"
627
+ - "none": no reduction is applied.
628
+ - "mean": output loss will be divided by target lengths and then the
629
+ mean over the batch is taken.
630
+ - "sum": output loss are summed over batch
631
+ output_emission_dict: whether to output additional information about the emission probs
632
+ Returns:
633
+ A pair of `(per_seq_loss, aux)`.
634
+ per_seq_loss:
635
+ (B,)-array containing loss values for each sequence in the batch.
636
+ aux: Dictionary containing interim variables used for computing losses.
637
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
638
+ phi-state corresponding to the n-th label.
639
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
640
+ emission-state corresponding to the n-th label.
641
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
642
+ corresponding to each time frame.
643
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
644
+ corresponding to each time frame.
645
+ """
646
+ # label paddings are indicated by -100
647
+ labelpaddings = labels < 0
648
+ # logit paddings are the inverse of attention_mask
649
+ logitpaddings = ~logits_attention_mask
650
+
651
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
652
+ batchsize, unused_maxinputlen, num_classes = logits.shape
653
+ batchsize_, maxlabellen = labels.shape
654
+
655
+ logprobs = jax.nn.log_softmax(logits)
656
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
657
+
658
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
659
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
660
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
661
+
662
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
663
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
664
+
665
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
666
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
667
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
668
+
669
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
670
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
671
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
672
+
673
+ def loop_body(prev, x):
674
+ prev_phi, prev_emit = prev
675
+ # emit-to-phi epsilon transition, except if the next label is repetition
676
+ prev_phi_orig = prev_phi
677
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
678
+
679
+ logprob_emit, logprob_phi, pad = x
680
+
681
+ # phi-to-emit transition
682
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
683
+ # self-loop transition
684
+ next_phi = prev_phi + logprob_phi
685
+ # emit-to-phi blank transition only when the next label is repetition
686
+ next_phi = next_phi.at[:, 1:].set(
687
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
688
+ )
689
+
690
+ pad = pad.reshape((batchsize, 1))
691
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
692
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
693
+
694
+ return (next_phi, next_emit), (next_phi, next_emit)
695
+
696
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
697
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
698
+
699
+ # last row needs to be updated with the last epsilon transition
700
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
701
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
702
+
703
+ # extract per_seq_loss
704
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
705
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
706
+
707
+ if loss_reduction == "mean":
708
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
709
+ loss = (per_seq_loss / target_lengths).mean()
710
+ elif loss_reduction == "sum":
711
+ loss = per_seq_loss.sum()
712
+ else:
713
+ loss = per_seq_loss
714
+
715
+ if not output_emission_dict:
716
+ return loss
717
+
718
+ return loss, {
719
+ "logalpha_phi": logalpha_phi,
720
+ "logalpha_emit": logalpha_emit,
721
+ "logprobs_phi": logprobs_phi,
722
+ "logprobs_emit": logprobs_emit,
723
+ }
724
+
725
+
726
+ def main():
727
+ # 1. Parse input arguments
728
+ # See all possible arguments in src/transformers/training_args.py
729
+ # or by passing the --help flag to this script.
730
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
731
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
732
+
733
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
734
+ # If we pass only one argument to the script and it's the path to a json file,
735
+ # let's parse it to get our arguments.
736
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
737
+ else:
738
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
739
+
740
+ # 2. Setup logging
741
+ # Make one log on every process with the configuration for debugging.
742
+ logging.basicConfig(
743
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
744
+ datefmt="%m/%d/%Y %H:%M:%S",
745
+ handlers=[logging.StreamHandler(sys.stdout)],
746
+ )
747
+ # Set the verbosity to info of the Transformers logger.
748
+ # We only want one process per machine to log things on the screen.
749
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
750
+ if jax.process_index() == 0:
751
+ datasets.utils.logging.set_verbosity_warning()
752
+ transformers.utils.logging.set_verbosity_info()
753
+ else:
754
+ datasets.utils.logging.set_verbosity_error()
755
+ transformers.utils.logging.set_verbosity_error()
756
+
757
+ # Set up wandb run
758
+ if jax.process_index() == 0:
759
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
760
+
761
+ logger.info("Training/evaluation parameters %s", training_args)
762
+
763
+ # Set the default TPU matmul precision and display the number of devices
764
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
765
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
766
+
767
+ # 4. Load dataset
768
+ raw_datasets = DatasetDict()
769
+
770
+ if training_args.do_train:
771
+ raw_datasets["train"] = load_dataset(
772
+ data_args.dataset_name,
773
+ data_args.dataset_config_name,
774
+ split=data_args.train_split_name,
775
+ cache_dir=data_args.dataset_cache_dir,
776
+ use_auth_token=True if model_args.use_auth_token else None,
777
+ )
778
+
779
+ if training_args.do_eval:
780
+ raw_datasets["eval"] = load_dataset(
781
+ data_args.dataset_name,
782
+ data_args.dataset_config_name,
783
+ split=data_args.eval_split_name,
784
+ cache_dir=data_args.dataset_cache_dir,
785
+ use_auth_token=True if model_args.use_auth_token else None,
786
+ )
787
+
788
+ if training_args.do_predict:
789
+ test_split = data_args.test_split_name.split("+")
790
+ for split in test_split:
791
+ raw_datasets[split] = load_dataset(
792
+ data_args.dataset_name,
793
+ data_args.dataset_config_name,
794
+ split=split,
795
+ cache_dir=data_args.dataset_cache_dir,
796
+ use_auth_token=True if model_args.use_auth_token else None,
797
+ )
798
+
799
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
800
+ raise ValueError(
801
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
802
+ "training, evaluation or prediction has to be done."
803
+ )
804
+
805
+ # if not training, there is no need to run multiple epochs
806
+ if not training_args.do_train:
807
+ training_args.num_train_epochs = 1
808
+
809
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
810
+ raise ValueError(
811
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
812
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
813
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
814
+ )
815
+
816
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
817
+ raise ValueError(
818
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
819
+ "Make sure to set `--text_column_name` to the correct text column - one of "
820
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
821
+ )
822
+
823
+ # 5. Load pretrained model, tokenizer, and feature extractor
824
+ #
825
+ # Distributed training:
826
+ # The .from_pretrained methods guarantee that only one local process can concurrently
827
+ config = Wav2Vec2Config.from_pretrained(
828
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
829
+ cache_dir=model_args.cache_dir,
830
+ revision=model_args.model_revision,
831
+ use_auth_token=True if model_args.use_auth_token else None,
832
+ )
833
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
834
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
835
+ cache_dir=model_args.cache_dir,
836
+ revision=model_args.model_revision,
837
+ use_auth_token=True if model_args.use_auth_token else None,
838
+ )
839
+ tokenizer = AutoTokenizer.from_pretrained(
840
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
841
+ cache_dir=model_args.cache_dir,
842
+ revision=model_args.model_revision,
843
+ use_auth_token=True if model_args.use_auth_token else None,
844
+ )
845
+ # update config according to training args, model args, and tokenizer attributes
846
+ config.update(
847
+ {
848
+ "gradient_checkpointing": training_args.gradient_checkpointing,
849
+ "activation_dropout": model_args.activation_dropout,
850
+ "hidden_dropout": model_args.hidden_dropout,
851
+ "feat_proj_dropout": model_args.feat_proj_dropout,
852
+ "mask_time_prob": model_args.mask_time_prob,
853
+ "vocab_size": tokenizer.vocab_size,
854
+ }
855
+ )
856
+
857
+ if training_args.precision == "full_mixed":
858
+ dtype = jnp.bfloat16
859
+ training_args.mixed_precision = True
860
+ elif training_args.precision == "half_mixed":
861
+ dtype = jnp.bfloat16
862
+ training_args.mixed_precision = False
863
+ else:
864
+ dtype = jnp.float32
865
+ training_args.mixed_precision = False
866
+
867
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
868
+ model_args.model_name_or_path,
869
+ config=config,
870
+ dtype=dtype,
871
+ cache_dir=model_args.cache_dir,
872
+ revision=model_args.model_revision,
873
+ use_auth_token=True if model_args.use_auth_token else None,
874
+ )
875
+
876
+ # 6. Resample speech dataset ALWAYS
877
+ raw_datasets = raw_datasets.cast_column(
878
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
879
+ )
880
+
881
+ # 7. Preprocessing the datasets.
882
+ # We need to read the audio files as arrays and tokenize the targets.
883
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
884
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
885
+ max_eval_input_length = int(data_args.max_eval_duration_in_seconds * feature_extractor.sampling_rate) if data_args.max_eval_duration_in_seconds else None
886
+ max_target_length = data_args.max_label_length
887
+ min_target_length = data_args.min_label_length
888
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
889
+ audio_column_name = data_args.audio_column_name
890
+ num_workers = data_args.preprocessing_num_workers
891
+ text_column_name = data_args.text_column_name
892
+ model_input_name = feature_extractor.model_input_names[0]
893
+
894
+ if training_args.do_train and data_args.max_train_samples is not None:
895
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
896
+
897
+ if training_args.do_eval and data_args.max_eval_samples is not None:
898
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
899
+
900
+ if training_args.do_predict and data_args.max_test_samples is not None:
901
+ for split in test_split:
902
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
903
+
904
+ def prepare_dataset(batch):
905
+ # Pre-process audio
906
+ sample = batch[audio_column_name]
907
+ # normalise audio (mean, std) to (0, 1)
908
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
909
+ # process audio length
910
+ batch[model_input_name] = inputs.input_values[0]
911
+ batch["input_length"] = len(batch["input_values"])
912
+
913
+ input_str = batch[text_column_name]
914
+ batch["labels"] = tokenizer(input_str).input_ids
915
+ batch["labels_length"] = len(batch["labels"])
916
+ return batch
917
+
918
+ vectorized_datasets = raw_datasets.map(
919
+ prepare_dataset,
920
+ remove_columns=next(iter(raw_datasets.values())).column_names,
921
+ num_proc=num_workers,
922
+ desc="preprocess dataset",
923
+ )
924
+
925
+ # filter training data with inputs longer than max_input_length
926
+ def is_audio_in_length_range(length):
927
+ return min_input_length < length < max_input_length
928
+
929
+ if training_args.do_train:
930
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
931
+ is_audio_in_length_range,
932
+ num_proc=num_workers,
933
+ input_columns=["input_length"],
934
+ )
935
+
936
+ # filter data with targets shorter than min_target_length or longer than max_target_length
937
+ def is_labels_in_length_range(length):
938
+ return min_target_length < length < max_target_length
939
+
940
+ if training_args.do_train:
941
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
942
+ is_labels_in_length_range,
943
+ num_proc=num_workers,
944
+ input_columns=["labels_length"],
945
+ )
946
+
947
+
948
+ if max_eval_input_length is not None:
949
+ # filter training data with inputs longer than max_input_length
950
+ def is_eval_audio_in_length_range(length):
951
+ return min_input_length < length < max_eval_input_length
952
+
953
+ if training_args.do_eval:
954
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
955
+ is_eval_audio_in_length_range,
956
+ num_proc=num_workers,
957
+ input_columns=["input_length"],
958
+ )
959
+
960
+ if training_args.do_predict:
961
+ for split in test_split:
962
+ vectorized_datasets[split] = vectorized_datasets[split].filter(
963
+ is_eval_audio_in_length_range,
964
+ num_proc=num_workers,
965
+ input_columns=["input_length"],
966
+ )
967
+
968
+ # for large datasets it is advised to run the preprocessing on a
969
+ # single machine first with `args.preprocessing_only` since there will mostly likely
970
+ # be a timeout when running the script in distributed mode.
971
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
972
+ # cached dataset
973
+ if data_args.preprocessing_only:
974
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
975
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
976
+ return
977
+
978
+ # 8. Load Metrics
979
+ wer_metric = load_metric("wer")
980
+ cer_metric = load_metric("cer")
981
+
982
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
983
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
984
+
985
+ pred_str = tokenizer.batch_decode(pred_ids)
986
+ # we do not want to group tokens when computing the metrics
987
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
988
+
989
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
990
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
991
+
992
+ return {"wer": wer, "cer": cer}, pred_str, label_str
993
+
994
+ # 9. save feature extractor, tokenizer and config
995
+ feature_extractor.save_pretrained(training_args.output_dir)
996
+ tokenizer.save_pretrained(training_args.output_dir)
997
+ config.save_pretrained(training_args.output_dir)
998
+
999
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1000
+
1001
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1002
+ processor=processor,
1003
+ input_padding="longest",
1004
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1005
+ max_label_length=data_args.max_label_length,
1006
+ )
1007
+
1008
+ # Enable tensorboard only on the master node
1009
+ has_tensorboard = is_tensorboard_available()
1010
+ if has_tensorboard and jax.process_index() == 0:
1011
+ try:
1012
+ from flax.metrics.tensorboard import SummaryWriter
1013
+
1014
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1015
+ except ImportError as ie:
1016
+ has_tensorboard = False
1017
+ logger.warning(
1018
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1019
+ )
1020
+ else:
1021
+ logger.warning(
1022
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1023
+ "Please run `pip install tensorboard` to enable."
1024
+ )
1025
+
1026
+ # 10. Handle the repository creation
1027
+ if training_args.push_to_hub:
1028
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1029
+ git_lfs_extensions = f.read()
1030
+ if "*.wandb" not in git_lfs_extensions:
1031
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1032
+ if training_args.hub_model_id is None:
1033
+ repo_name = get_full_repo_name(
1034
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1035
+ )
1036
+ else:
1037
+ repo_name = training_args.hub_model_id
1038
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1039
+
1040
+ # 11. Initialize our training
1041
+ rng = jax.random.PRNGKey(training_args.seed)
1042
+ rng, dropout_rng = jax.random.split(rng)
1043
+
1044
+ # Store some constants
1045
+ max_steps = int(training_args.max_steps)
1046
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1047
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1048
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1049
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1050
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1051
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1052
+
1053
+ if training_args.do_train:
1054
+ num_train_samples = len(vectorized_datasets["train"])
1055
+ steps_per_epoch = num_train_samples // batch_size_per_update
1056
+ if max_steps > 0:
1057
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1058
+ total_train_steps = max_steps
1059
+ else:
1060
+ num_epochs = int(training_args.num_train_epochs)
1061
+ total_train_steps = steps_per_epoch * num_epochs
1062
+
1063
+ # Create learning rate schedule
1064
+ # Create learning rate schedule
1065
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1066
+ total_train_steps,
1067
+ training_args.warmup_steps,
1068
+ training_args.learning_rate,
1069
+ )
1070
+
1071
+ # We use Optax's "masking" functionality to not apply weight decay
1072
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1073
+ # mask boolean with the same structure as the parameters.
1074
+ # The mask is True for parameters that should be decayed.
1075
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1076
+ # For FlaxT5, one should correct the layer norm parameter naming
1077
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1078
+ def decay_mask_fn(params):
1079
+ flat_params = traverse_util.flatten_dict(params)
1080
+ layer_norm_params = [
1081
+ (name, "scale")
1082
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1083
+ ]
1084
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1085
+ return traverse_util.unflatten_dict(flat_mask)
1086
+
1087
+ if training_args.adafactor:
1088
+ # Create Adafactor optimizer
1089
+ optim = optax.adafactor(
1090
+ learning_rate=linear_decay_lr_schedule_fn,
1091
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1092
+ weight_decay_rate=training_args.weight_decay,
1093
+ weight_decay_mask=decay_mask_fn,
1094
+ )
1095
+ else:
1096
+ # Create AdamW optimizer
1097
+ optim = optax.adamw(
1098
+ learning_rate=linear_decay_lr_schedule_fn,
1099
+ b1=training_args.adam_beta1,
1100
+ b2=training_args.adam_beta2,
1101
+ eps=training_args.adam_epsilon,
1102
+ weight_decay=training_args.weight_decay,
1103
+ mask=decay_mask_fn,
1104
+ )
1105
+
1106
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1107
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1108
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1109
+ else:
1110
+ num_epochs = 0
1111
+ total_train_steps = 0
1112
+ num_train_samples = 0
1113
+ optim = None
1114
+
1115
+ # Setup train state
1116
+ state = MixedPrecisionTrainState.create(
1117
+ apply_fn=model.__call__,
1118
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1119
+ params=model.params,
1120
+ tx=optim,
1121
+ to_dtype=to_dtype,
1122
+ dropout_rng=dropout_rng,
1123
+ max_grad_norm=training_args.max_grad_norm,
1124
+ )
1125
+
1126
+ # Replicate the train state on each device
1127
+ state = state.replicate()
1128
+ blank_id = model.config.pad_token_id
1129
+
1130
+ # Define gradient update step fn
1131
+ def train_step(state, batch):
1132
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1133
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1134
+
1135
+ def compute_loss(params, minibatch):
1136
+ labels = minibatch.pop("labels")
1137
+ logits = state.apply_fn(
1138
+ **minibatch,
1139
+ params=params,
1140
+ dropout_rng=dropout_rng,
1141
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1142
+ train=True,
1143
+ )[0]
1144
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1145
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1146
+
1147
+ return loss
1148
+
1149
+ grad_fn = jax.value_and_grad(compute_loss)
1150
+
1151
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1152
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1153
+
1154
+ # Custom gradient accumulation
1155
+ else:
1156
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1157
+ batch = jax.tree_map(
1158
+ lambda x: x.reshape(
1159
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1160
+ ),
1161
+ batch,
1162
+ )
1163
+
1164
+ def accum_minibatch_step(accum_grad, minibatch):
1165
+ # compute loss, num labels and grad over minibatch and accumulate
1166
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1167
+ return jax.tree_map(jnp.add, accum_grad, grad), loss
1168
+
1169
+ # create an initial state for accumulating losses, num labels and gradients
1170
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1171
+ # loop accum minibatch step over the number of gradient accumulation steps
1172
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1173
+
1174
+ # update state
1175
+ new_state = state.apply_gradients(
1176
+ grads=grad,
1177
+ dropout_rng=new_dropout_rng,
1178
+ to_dtype=to_dtype,
1179
+ )
1180
+
1181
+ # compute gradient norms over all layers and globally for detailed monitoring
1182
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1183
+ logs = {
1184
+ "layer_grad_norm": layer_grad_norm,
1185
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1186
+ }
1187
+
1188
+ # compute parameter norms over all layers and globally for detailed monitoring
1189
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1190
+ logs["layer_param_norm"] = layer_param_norm
1191
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1192
+
1193
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1194
+ metrics.update(logs)
1195
+
1196
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1197
+ # metrics = to_fp32(metrics)
1198
+
1199
+ return new_state, metrics
1200
+
1201
+ # Define eval fn
1202
+ def eval_step(params, batch):
1203
+ labels = batch.pop("labels")
1204
+ logits = model(**batch, params=params, train=False)[0]
1205
+
1206
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1207
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1208
+
1209
+ pred_ids = jnp.argmax(logits, axis=-1)
1210
+
1211
+ # summarize metrics
1212
+ metrics = {"loss": loss}
1213
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1214
+ # metrics = to_fp32(metrics)
1215
+ return metrics, pred_ids
1216
+
1217
+ # Create parallel version of the train and eval step
1218
+ if training_args.do_train:
1219
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1220
+
1221
+ if training_args.do_eval or training_args.do_predict:
1222
+ p_eval_step = jax.pmap(eval_step, "batch")
1223
+
1224
+ def run_evaluation(step, final_step=False):
1225
+ if training_args.do_eval:
1226
+ # ======================== Evaluating ==============================
1227
+ eval_metrics = []
1228
+ eval_preds = []
1229
+ eval_labels = []
1230
+
1231
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1232
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1233
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1234
+
1235
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1236
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1237
+ batch = data_collator(samples)
1238
+ labels = batch["labels"]
1239
+
1240
+ try:
1241
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1242
+ except TypeError:
1243
+ continue
1244
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1245
+ eval_metrics.append(metrics)
1246
+
1247
+ eval_labels.extend(labels)
1248
+
1249
+ # normalize eval metrics
1250
+ eval_metrics = get_metrics(eval_metrics)
1251
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1252
+ eval_metrics = to_fp32(eval_metrics)
1253
+
1254
+ # always run compute metrics
1255
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1256
+ eval_metrics.update(error_rate_metric)
1257
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1258
+
1259
+ # Print metrics and update progress bar
1260
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1261
+ epochs.write(desc)
1262
+ epochs.desc = desc
1263
+
1264
+ # Save metrics
1265
+ write_wandb_log(eval_metrics, step, prefix="eval")
1266
+ write_wandb_pred(pred_str, label_str, step, final_step=final_step)
1267
+ # if has_tensorboard and jax.process_index() == 0:
1268
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1269
+
1270
+ def save_checkpoint(step):
1271
+ # save and push checkpoint to the hub
1272
+ if jax.process_index() == 0:
1273
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1274
+ model.save_pretrained(training_args.output_dir, params=params)
1275
+ tokenizer.save_pretrained(training_args.output_dir)
1276
+ if training_args.push_to_hub:
1277
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1278
+
1279
+ logger.info("***** Running training *****")
1280
+ logger.info(f" Num examples = {num_train_samples}")
1281
+ logger.info(f" Num Epochs = {num_epochs}")
1282
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1283
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1284
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1285
+ logger.info(f" Total optimization steps = {total_train_steps}")
1286
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1287
+ logger.info(f" Use scan: {config.use_scan}")
1288
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1289
+
1290
+ train_time = cur_step = 0
1291
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1292
+ for epoch in epochs:
1293
+ if training_args.do_train:
1294
+ # ======================== Training ================================
1295
+ train_start = time.time()
1296
+
1297
+ # Create sampling rng
1298
+ rng, input_rng = jax.random.split(rng)
1299
+
1300
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1301
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1302
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1303
+
1304
+ # Gather the indices for creating the batch and do a training step
1305
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1306
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1307
+ batch = data_collator(samples)
1308
+ batch = shard(batch.data)
1309
+ try:
1310
+ state, train_metric = p_train_step(state, batch)
1311
+ except TypeError as e:
1312
+ logger.warning("Encountered following error: \n", e)
1313
+
1314
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1315
+
1316
+ if cur_step % training_args.logging_steps == 0:
1317
+ # Save metrics
1318
+ train_metric = unreplicate(train_metric)
1319
+ train_time += time.time() - train_start
1320
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1321
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1322
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1323
+ # if has_tensorboard and jax.process_index() == 0:
1324
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1325
+
1326
+ epochs.write(
1327
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1328
+ )
1329
+
1330
+ if cur_step % total_train_steps == 0:
1331
+ break
1332
+
1333
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1334
+ run_evaluation(cur_step, final_step=False)
1335
+
1336
+ if cur_step % training_args.save_steps == 0:
1337
+ save_checkpoint(cur_step)
1338
+
1339
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1340
+ # run evaluation at the end of the epoch if eval steps are not specified
1341
+ run_evaluation(cur_step, final_step=False)
1342
+ save_checkpoint(cur_step)
1343
+
1344
+ if training_args.do_train:
1345
+ save_checkpoint(cur_step)
1346
+
1347
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1348
+
1349
+ if training_args.do_eval:
1350
+ run_evaluation(cur_step, final_step=True)
1351
+
1352
+ # TODO: collapse 'do_predict' into the run_evaluation function
1353
+ if training_args.do_predict:
1354
+ for split in test_split:
1355
+ # ======================== Evaluating ==============================
1356
+ eval_metrics = []
1357
+ eval_preds = []
1358
+ eval_labels = []
1359
+
1360
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1361
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1362
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1363
+
1364
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1365
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1366
+ batch = data_collator(samples)
1367
+ labels = batch["labels"]
1368
+
1369
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1370
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1371
+ eval_metrics.append(metrics)
1372
+
1373
+ eval_labels.extend(labels)
1374
+
1375
+ # normalize eval metrics
1376
+ eval_metrics = get_metrics(eval_metrics)
1377
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1378
+ eval_metrics = to_fp32(eval_metrics)
1379
+
1380
+ # always run compute metrics
1381
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1382
+ eval_metrics.update(error_rate_metric)
1383
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1384
+
1385
+ # Print metrics and update progress bar
1386
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1387
+ epochs.write(desc)
1388
+ epochs.desc = desc
1389
+
1390
+ # Save metrics
1391
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1392
+ write_wandb_pred(pred_str, label_str, cur_step, final_step=True, prefix=split)
1393
+ # if has_tensorboard and jax.process_index() == 0:
1394
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1395
+
1396
+
1397
+ if __name__ == "__main__":
1398
+ main()
run_voxpopuli.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ python run_flax_speech_recognition_ctc.py \
3
+ --model_name_or_path="esb/wav2vec2-ctc-pretrained" \
4
+ --tokenizer_name="wav2vec2-ctc-voxpopuli-tokenizer" \
5
+ --dataset_name="esb/datasets" \
6
+ --dataset_config_name="voxpopuli" \
7
+ --output_dir="./" \
8
+ --wandb_project="wav2vec2-ctc" \
9
+ --wandb_name="wav2vec2-ctc-voxpopuli" \
10
+ --max_steps="50000" \
11
+ --save_steps="10000" \
12
+ --eval_steps="10000" \
13
+ --learning_rate="3e-4" \
14
+ --logging_steps="25" \
15
+ --warmup_steps="5000" \
16
+ --preprocessing_num_workers="1" \
17
+ --per_device_eval_batch_size="1" \
18
+ --do_train \
19
+ --do_eval \
20
+ --do_predict \
21
+ --overwrite_output_dir \
22
+ --gradient_checkpointing \
23
+ --freeze_feature_encoder \
24
+ --push_to_hub \
25
+ --use_auth_token
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "word_delimiter_token": "|", "replace_word_delimiter_char": " ", "special_tokens_map_file": null, "name_or_path": "sanchit-gandhi/wav2vec2_ctc_voxpopuli_tokenizer", "tokenizer_class": "Wav2Vec2CTCTokenizer"}
vocab.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "r": 4, "u": 5, "c": 6, "y": 7, "f": 8, "j": 9, ";": 10, "m": 11, "a": 12, ".": 13, "q": 14, "e": 15, "i": 16, "s": 17, "w": 18, "p": 19, "n": 20, "z": 21, "o": 22, "l": 23, "'": 24, "h": 25, "v": 26, "t": 27, "g": 28, "b": 30, "x": 31, "d": 32, "k": 33, "?": 34, "|": 29}