juancopi81 committed
Commit b100e1c · 1 Parent(s): 8f8dcb6

Add t5x and mt3 models
This view is limited to 50 files because it contains too many changes. See raw diff.
- LICENSE +202 -0
- README.md +0 -2
- mt3/__init__.py +33 -0
- mt3/datasets.py +325 -0
- mt3/event_codec.py +112 -0
- mt3/event_codec_test.py +55 -0
- mt3/gin/eval.gin +72 -0
- mt3/gin/infer.gin +92 -0
- mt3/gin/ismir2021.gin +9 -0
- mt3/gin/ismir2022/base.gin +10 -0
- mt3/gin/ismir2022/finetune.gin +25 -0
- mt3/gin/ismir2022/pretrain.gin +13 -0
- mt3/gin/ismir2022/small.gin +2 -0
- mt3/gin/local_tiny.gin +63 -0
- mt3/gin/model.gin +60 -0
- mt3/gin/mt3.gin +9 -0
- mt3/gin/train.gin +148 -0
- mt3/inference.py +138 -0
- mt3/layers.py +830 -0
- mt3/layers_test.py +545 -0
- mt3/metrics.py +392 -0
- mt3/metrics_utils.py +196 -0
- mt3/metrics_utils_test.py +259 -0
- mt3/mixing.py +91 -0
- mt3/models.py +152 -0
- mt3/network.py +409 -0
- mt3/note_sequences.py +446 -0
- mt3/note_sequences_test.py +505 -0
- mt3/preprocessors.py +669 -0
- mt3/pytest.ini +3 -0
- mt3/run_length_encoding.py +423 -0
- mt3/run_length_encoding_test.py +107 -0
- mt3/scripts/dump_task.py +80 -0
- mt3/scripts/extract_monophonic_examples.py +251 -0
- mt3/spectrograms.py +82 -0
- mt3/summaries.py +471 -0
- mt3/tasks.py +402 -0
- mt3/version.py +16 -0
- mt3/vocabularies.py +282 -0
- mt3/vocabularies_test.py +114 -0
- pytest.ini +3 -0
- setup.cfg +2 -0
- setup.py +67 -0
- t5x/__init__.py +34 -0
- t5x/adafactor.py +608 -0
- t5x/adafactor_test.py +527 -0
- t5x/checkpoint_importer.py +485 -0
- t5x/checkpoint_importer_test.py +81 -0
- t5x/checkpoint_utils.py +91 -0
- t5x/checkpoint_utils_test.py +149 -0
LICENSE
ADDED
@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -10,5 +10,3 @@ app_file: app.py
 pinned: false
 license: apache-2.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
mt3/__init__.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base module for MT3."""

from mt3 import datasets
from mt3 import event_codec
from mt3 import inference
from mt3 import layers
from mt3 import metrics
from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import run_length_encoding
from mt3 import spectrograms
from mt3 import summaries
from mt3 import tasks
from mt3 import vocabularies

from mt3.version import __version__
mt3/datasets.py
ADDED
@@ -0,0 +1,325 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset configurations."""

import dataclasses
from typing import Mapping, Sequence, Union

from mt3 import note_sequences
import tensorflow as tf


@dataclasses.dataclass
class InferEvalSplit:
  # key in dictionary containing all dataset splits
  name: str
  # task name suffix (each eval split is a separate task)
  suffix: str
  # whether or not to include in the mixture of all eval tasks
  include_in_mixture: bool = True


@dataclasses.dataclass
class DatasetConfig:
  """Configuration for a transcription dataset."""
  # dataset name
  name: str
  # mapping from split name to path
  paths: Mapping[str, str]
  # mapping from feature name to feature
  features: Mapping[str, Union[tf.io.FixedLenFeature,
                               tf.io.FixedLenSequenceFeature]]
  # training split name
  train_split: str
  # training eval split name
  train_eval_split: str
  # list of infer eval split specs
  infer_eval_splits: Sequence[InferEvalSplit]
  # list of track specs to be used for metrics
  track_specs: Sequence[note_sequences.TrackSpec] = dataclasses.field(
      default_factory=list)


MAESTROV1_CONFIG = DatasetConfig(
    name='maestrov1',
    paths={
        'train':
            'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_train.tfrecord-?????-of-00010',
        'train_subset':
            'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_train.tfrecord-00002-of-00010',
        'validation':
            'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_validation.tfrecord-?????-of-00010',
        'validation_subset':
            'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_validation.tfrecord-0000[06]-of-00010',
        'test':
            'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_test.tfrecord-?????-of-00010'
    },
    features={
        'audio': tf.io.FixedLenFeature([], dtype=tf.string),
        'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
        'id': tf.io.FixedLenFeature([], dtype=tf.string)
    },
    train_split='train',
    train_eval_split='validation_subset',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train_full',
                       include_in_mixture=False),
        InferEvalSplit(name='train_subset', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation_full',
                       include_in_mixture=False),
        InferEvalSplit(name='validation_subset', suffix='validation'),
        InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
    ])


MAESTROV3_CONFIG = DatasetConfig(
    name='maestrov3',
    paths={
        'train':
            'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_train.tfrecord-?????-of-00025',
        'train_subset':
            'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_train.tfrecord-00004-of-00025',
        'validation':
            'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_validation.tfrecord-?????-of-00025',
        'validation_subset':
            'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_validation.tfrecord-0002?-of-00025',
        'test':
            'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_test.tfrecord-?????-of-00025'
    },
    features={
        'audio': tf.io.FixedLenFeature([], dtype=tf.string),
        'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
        'id': tf.io.FixedLenFeature([], dtype=tf.string)
    },
    train_split='train',
    train_eval_split='validation_subset',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train_full',
                       include_in_mixture=False),
        InferEvalSplit(name='train_subset', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation_full',
                       include_in_mixture=False),
        InferEvalSplit(name='validation_subset', suffix='validation'),
        InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
    ])


GUITARSET_CONFIG = DatasetConfig(
    name='guitarset',
    paths={
        'train':
            'gs://mt3/data/datasets/guitarset/train.tfrecord-?????-of-00019',
        'validation':
            'gs://mt3/data/datasets/guitarset/validation.tfrecord-?????-of-00006',
    },
    features={
        'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
        'audio': tf.io.FixedLenFeature([], dtype=tf.string),
        'velocity_range': tf.io.FixedLenFeature([], dtype=tf.string),
        'id': tf.io.FixedLenFeature([], dtype=tf.string),
    },
    train_split='train',
    train_eval_split='validation',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation'),
    ])


URMP_CONFIG = DatasetConfig(
    name='urmp',
    paths={
        'train': 'gs://mt3/data/datasets/urmp/train.tfrecord',
        'validation': 'gs://mt3/data/datasets/urmp/validation.tfrecord',
    },
    features={
        'id': tf.io.FixedLenFeature([], dtype=tf.string),
        'tracks': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.int64, allow_missing=True),
        'inst_names': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.string, allow_missing=True),
        'audio': tf.io.FixedLenFeature([], dtype=tf.string),
        'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
        'instrument_sequences': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.string, allow_missing=True),
    },
    train_split='train',
    train_eval_split='validation',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation')
    ])


MUSICNET_CONFIG = DatasetConfig(
    name='musicnet',
    paths={
        'train':
            'gs://mt3/data/datasets/musicnet/musicnet-train.tfrecord-?????-of-00036',
        'validation':
            'gs://mt3/data/datasets/musicnet/musicnet-validation.tfrecord-?????-of-00005',
        'test':
            'gs://mt3/data/datasets/musicnet/musicnet-test.tfrecord-?????-of-00003'
    },
    features={
        'id': tf.io.FixedLenFeature([], dtype=tf.string),
        'sample_rate': tf.io.FixedLenFeature([], dtype=tf.float32),
        'audio': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.float32, allow_missing=True),
        'sequence': tf.io.FixedLenFeature([], dtype=tf.string)
    },
    train_split='train',
    train_eval_split='validation',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation'),
        InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
    ])


MUSICNET_EM_CONFIG = DatasetConfig(
    name='musicnet_em',
    paths={
        'train':
            'gs://mt3/data/datasets/musicnet_em/train.tfrecord-?????-of-00103',
        'validation':
            'gs://mt3/data/datasets/musicnet_em/validation.tfrecord-?????-of-00005',
        'test':
            'gs://mt3/data/datasets/musicnet_em/test.tfrecord-?????-of-00006'
    },
    features={
        'id': tf.io.FixedLenFeature([], dtype=tf.string),
        'sample_rate': tf.io.FixedLenFeature([], dtype=tf.float32),
        'audio': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.float32, allow_missing=True),
        'sequence': tf.io.FixedLenFeature([], dtype=tf.string)
    },
    train_split='train',
    train_eval_split='validation',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation'),
        InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
    ])


CERBERUS4_CONFIG = DatasetConfig(
    name='cerberus4',
    paths={
        'train':
            'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_train_bass:drums:guitar:piano.tfrecord-?????-of-00286',
        'train_subset':
            'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_train_bass:drums:guitar:piano.tfrecord-00000-of-00286',
        'validation':
            'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_validation_bass:drums:guitar:piano.tfrecord-?????-of-00212',
        'validation_subset':
            'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_validation_bass:drums:guitar:piano.tfrecord-0000?-of-00212',
        'test':
            'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_test_bass:drums:guitar:piano.tfrecord-?????-of-00106'
    },
    features={
        'audio_sample_rate': tf.io.FixedLenFeature([], dtype=tf.int64),
        'inst_names': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.string, allow_missing=True),
        'midi_class': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.int64, allow_missing=True),
        'mix': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.float32, allow_missing=True),
        'note_sequences': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.string, allow_missing=True),
        'plugin_name': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.int64, allow_missing=True),
        'program_num': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.int64, allow_missing=True),
        'slakh_class': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.int64, allow_missing=True),
        'src_ids': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.string, allow_missing=True),
        'stems': tf.io.FixedLenSequenceFeature(
            [], dtype=tf.float32, allow_missing=True),
        'stems_shape': tf.io.FixedLenFeature([2], dtype=tf.int64),
        'target_type': tf.io.FixedLenFeature([], dtype=tf.string),
        'track_id': tf.io.FixedLenFeature([], dtype=tf.string),
    },
    train_split='train',
    train_eval_split='validation_subset',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train_full',
                       include_in_mixture=False),
        InferEvalSplit(name='train_subset', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation_full',
                       include_in_mixture=False),
        InferEvalSplit(name='validation_subset', suffix='validation'),
        InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
    ],
    track_specs=[
        note_sequences.TrackSpec('bass', program=32),
        note_sequences.TrackSpec('drums', is_drum=True),
        note_sequences.TrackSpec('guitar', program=24),
        note_sequences.TrackSpec('piano', program=0)
    ])


SLAKH_CONFIG = DatasetConfig(
    name='slakh',
    paths={
        'train':
            'gs://mt3/data/datasets/slakh/slakh_multi_full_subsets_10_train_all_inst.tfrecord-?????-of-02307',
        'train_subset':
            'gs://mt3/data/datasets/slakh/slakh_multi_full_subsets_10_train_all_inst.tfrecord-00000-of-02307',
        'validation':
            'gs://mt3/data/datasets/slakh/slakh_multi_full_validation_all_inst.tfrecord-?????-of-00168',
        'validation_subset':
            'gs://mt3/data/datasets/slakh/slakh_multi_full_validation_all_inst.tfrecord-0000?-of-00168',
        'test':
            'gs://mt3/data/datasets/slakh/slakh_multi_full_test_all_inst.tfrecord-?????-of-00109'
    },
    features={
        'audio_sample_rate': tf.io.FixedLenFeature([], dtype=tf.int64),
        'inst_names': tf.io.FixedLenSequenceFeature([], dtype=tf.string,
                                                    allow_missing=True),
        'midi_class': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
                                                    allow_missing=True),
        'mix': tf.io.FixedLenSequenceFeature([], dtype=tf.float32,
                                             allow_missing=True),
        'note_sequences': tf.io.FixedLenSequenceFeature([], dtype=tf.string,
                                                        allow_missing=True),
        'plugin_name': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
                                                     allow_missing=True),
        'program_num': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
                                                     allow_missing=True),
        'slakh_class': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
                                                     allow_missing=True),
        'src_ids': tf.io.FixedLenSequenceFeature([], dtype=tf.string,
                                                 allow_missing=True),
        'stems': tf.io.FixedLenSequenceFeature([], dtype=tf.float32,
                                               allow_missing=True),
        'stems_shape': tf.io.FixedLenFeature([2], dtype=tf.int64),
        'target_type': tf.io.FixedLenFeature([], dtype=tf.string),
        'track_id': tf.io.FixedLenFeature([], dtype=tf.string),
    },
    train_split='train',
    train_eval_split='validation_subset',
    infer_eval_splits=[
        InferEvalSplit(name='train', suffix='eval_train_full',
                       include_in_mixture=False),
        InferEvalSplit(name='train_subset', suffix='eval_train'),
        InferEvalSplit(name='validation', suffix='validation_full',
                       include_in_mixture=False),
        InferEvalSplit(name='validation_subset', suffix='validation'),
        InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
    ])
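As an aside, the features mapping in each DatasetConfig is exactly the spec dict that tf.io.parse_single_example expects, so a single record can be decoded directly. A minimal sketch, not part of this commit, assuming a locally downloaded shard at a hypothetical path:

# Hedged sketch: decode one MAESTRO v3 record using the feature spec above.
# The local path below is hypothetical.
import tensorflow as tf
from mt3 import datasets

path = '/tmp/maestro-v3.0.0_ns_wav_validation.tfrecord-00000-of-00025'
for raw_record in tf.data.TFRecordDataset(path).take(1):
  parsed = tf.io.parse_single_example(
      raw_record, datasets.MAESTROV3_CONFIG.features)
  # 'audio' and 'sequence' are serialized bytes; 'id' identifies the example.
  print(parsed['id'])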
mt3/event_codec.py
ADDED
@@ -0,0 +1,112 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Encode and decode events."""

import dataclasses
from typing import List, Tuple


@dataclasses.dataclass
class EventRange:
  type: str
  min_value: int
  max_value: int


@dataclasses.dataclass
class Event:
  type: str
  value: int


class Codec:
  """Encode and decode events.

  Useful for declaring what certain ranges of a vocabulary should be used for.
  This is intended to be used from Python before encoding or after decoding
  with GenericTokenVocabulary. This class is more lightweight and does not
  include things like EOS or UNK token handling.

  To ensure that 'shift' events are always the first block of the vocab and
  start at 0, that event type is required and specified separately.
  """

  def __init__(self, max_shift_steps: int, steps_per_second: float,
               event_ranges: List[EventRange]):
    """Define Codec.

    Args:
      max_shift_steps: Maximum number of shift steps that can be encoded.
      steps_per_second: Shift steps will be interpreted as having a duration
        of 1 / steps_per_second.
      event_ranges: Other supported event types and their ranges.
    """
    self.steps_per_second = steps_per_second
    self._shift_range = EventRange(
        type='shift', min_value=0, max_value=max_shift_steps)
    self._event_ranges = [self._shift_range] + event_ranges
    # Ensure all event types have unique names.
    assert len(self._event_ranges) == len(
        set([er.type for er in self._event_ranges]))

  @property
  def num_classes(self) -> int:
    return sum(er.max_value - er.min_value + 1 for er in self._event_ranges)

  # The next couple methods are simplified special case methods just for shift
  # events that are intended to be used from within autograph functions.

  def is_shift_event_index(self, index: int) -> bool:
    return (self._shift_range.min_value <= index) and (
        index <= self._shift_range.max_value)

  @property
  def max_shift_steps(self) -> int:
    return self._shift_range.max_value

  def encode_event(self, event: Event) -> int:
    """Encode an event to an index."""
    offset = 0
    for er in self._event_ranges:
      if event.type == er.type:
        if not er.min_value <= event.value <= er.max_value:
          raise ValueError(
              f'Event value {event.value} is not within valid range '
              f'[{er.min_value}, {er.max_value}] for type {event.type}')
        return offset + event.value - er.min_value
      offset += er.max_value - er.min_value + 1

    raise ValueError(f'Unknown event type: {event.type}')

  def event_type_range(self, event_type: str) -> Tuple[int, int]:
    """Return [min_id, max_id] for an event type."""
    offset = 0
    for er in self._event_ranges:
      if event_type == er.type:
        return offset, offset + (er.max_value - er.min_value)
      offset += er.max_value - er.min_value + 1

    raise ValueError(f'Unknown event type: {event_type}')

  def decode_event_index(self, index: int) -> Event:
    """Decode an event index to an Event."""
    offset = 0
    for er in self._event_ranges:
      if offset <= index <= offset + er.max_value - er.min_value:
        return Event(
            type=er.type, value=er.min_value + index - offset)
      offset += er.max_value - er.min_value + 1

    raise ValueError(f'Unknown event index: {index}')
mt3/event_codec_test.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for event_codec."""

from absl.testing import absltest
from mt3 import event_codec

Event = event_codec.Event
EventRange = event_codec.EventRange


class EventCodecTest(absltest.TestCase):

  def test_encode_decode(self):
    ec = event_codec.Codec(
        max_shift_steps=100,
        steps_per_second=100,
        event_ranges=[EventRange('pitch', min_value=0, max_value=127)])
    events = [
        Event(type='pitch', value=60),
        Event(type='shift', value=5),
        Event(type='pitch', value=62),
    ]
    encoded = [ec.encode_event(e) for e in events]
    self.assertSequenceEqual([161, 5, 163], encoded)

    decoded = [ec.decode_event_index(idx) for idx in encoded]
    self.assertSequenceEqual(events, decoded)

  def test_shift_steps(self):
    ec = event_codec.Codec(
        max_shift_steps=100,
        steps_per_second=100,
        event_ranges=[EventRange('pitch', min_value=0, max_value=127)])

    self.assertEqual(100, ec.max_shift_steps)
    self.assertFalse(ec.is_shift_event_index(-1))
    self.assertTrue(ec.is_shift_event_index(0))
    self.assertTrue(ec.is_shift_event_index(100))
    self.assertFalse(ec.is_shift_event_index(101))


if __name__ == '__main__':
  absltest.main()
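The expected ids in test_encode_decode follow directly from the codec layout above: with max_shift_steps=100 the shift block occupies indices 0 through 100 (101 tokens), so a pitch event with value p encodes to 101 + p, giving 161 for pitch 60 and 163 for pitch 62.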
mt3/gin/eval.gin
ADDED
@@ -0,0 +1,72 @@
# Defaults for eval.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - CHECKPOINT_PATH
# - EVAL_OUTPUT_DIR
#
# Commonly overridden options:
#
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - DatasetConfig.use_cached
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions

from __gin__ import dynamic_registration

import __main__ as eval_script
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import partitioning
from t5x import utils

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
EVAL_OUTPUT_DIR = %gin.REQUIRED

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY

TASK_SUFFIX = 'test'
tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TASK_SUFFIX

eval_script.evaluate:
  model = %MODEL  # imported from separate gin file
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  output_dir = %EVAL_OUTPUT_DIR

utils.DatasetConfig:
  mixture_or_task_name = @tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = 32
  shuffle = False
  seed = 42
  use_cached = True
  pack = False
  use_custom_packing_ops = False

partitioning.PjitPartitioner.num_partitions = 1

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
mt3/gin/infer.gin
ADDED
@@ -0,0 +1,92 @@
# Defaults for infer.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - CHECKPOINT_PATH
# - INFER_OUTPUT_DIR
#
# Commonly overridden options:
#
# - infer.mode
# - infer.checkpoint_period
# - infer.shard_id
# - infer.num_shards
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - DatasetConfig.use_cached
# - RestoreCheckpointConfig.is_tensorflow
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions

from __gin__ import dynamic_registration

import __main__ as infer_script
from mt3 import inference
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import partitioning
from t5x import utils

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY

TASK_SUFFIX = 'test'
tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TASK_SUFFIX

ONSETS_ONLY = %gin.REQUIRED
USE_TIES = %gin.REQUIRED
inference.write_inferences_to_file:
  vocab_config = %VOCAB_CONFIG
  onsets_only = %ONSETS_ONLY
  use_ties = %USE_TIES

infer_script.infer:
  mode = 'predict'
  model = %MODEL  # imported from separate gin file
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  # This is a hack, but pass an extremely large value here to make sure the
  # entire dataset fits in a single epoch. Otherwise, segments from a single
  # example may end up in different epochs after splitting.
  checkpoint_period = 1000000
  shard_id = 0
  num_shards = 1
  write_fn = @inference.write_inferences_to_file

utils.DatasetConfig:
  mixture_or_task_name = @tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  use_cached = True
  split = 'eval'
  batch_size = 32
  shuffle = False
  seed = 0
  pack = False

partitioning.PjitPartitioner.num_partitions = 1

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
mt3/gin/ismir2021.gin
ADDED
@@ -0,0 +1,9 @@
# Configuration for ISMIR 2021 piano-only model.

TASK_PREFIX = 'maestrov3_notes'
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 1024}
TRAIN_STEPS = 400000
NUM_VELOCITY_BINS = 127
PROGRAM_GRANULARITY = 'flat'
ONSETS_ONLY = False
USE_TIES = False
mt3/gin/ismir2022/base.gin
ADDED
@@ -0,0 +1,10 @@
# T5.1.1 Base model.
include 'model.gin'

network.T5Config:
  emb_dim = 768
  num_heads = 12
  num_encoder_layers = 12
  num_decoder_layers = 12
  head_dim = 64
  mlp_dim = 2048
mt3/gin/ismir2022/finetune.gin
ADDED
@@ -0,0 +1,25 @@
from __gin__ import dynamic_registration

from mt3 import network
from t5x import utils

include 'train.gin'

TASK_PREFIX = 'mega_notes_ties'
TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
TRAIN_STEPS = 150000
BATCH_SIZE = 256
LABEL_SMOOTHING = 0.0
NUM_VELOCITY_BINS = 1
PROGRAM_GRANULARITY = 'full'
ONSETS_ONLY = False
USE_TIES = True
MAX_EXAMPLES_PER_MIX = None

network.T5Config.dropout_rate = 0.1

CHECKPOINT_PATH = %gin.REQUIRED
utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
mt3/gin/ismir2022/pretrain.gin
ADDED
@@ -0,0 +1,13 @@
include 'train.gin'

TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
TRAIN_STEPS = 500000
BATCH_SIZE = 1024
LABEL_SMOOTHING = 0.1
NUM_VELOCITY_BINS = 1
PROGRAM_GRANULARITY = 'full'
ONSETS_ONLY = False
USE_TIES = True
MAX_EXAMPLES_PER_MIX = 8

network.T5Config.dropout_rate = 0.0
mt3/gin/ismir2022/small.gin
ADDED
@@ -0,0 +1,2 @@
# T5.1.1 Small model.
include 'model.gin'
mt3/gin/local_tiny.gin
ADDED
@@ -0,0 +1,63 @@
# A gin file to make the Transformer models tiny for faster local testing.
#
# When testing locally with CPU, there are a few things that we need.
# - tiny model size
# - small enough batch size
# - small sequence length
# - deterministic dataset pipeline
#
# This gin file adds such configs. To use this gin file, add it on top of the
# existing full-scale gin files. The ordering of the gin file matters. So this
# should be added after all the other files are added to override the same
# configurables.

from __gin__ import dynamic_registration

from t5x import partitioning
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network

import __main__ as train_script

train_script.train.random_seed = 42  # dropout seed
train/utils.DatasetConfig.seed = 42  # dataset seed

TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 16}
LABEL_SMOOTHING = 0.0

# Network specification overrides
network.Transformer.config = @network.T5Config()
network.T5Config:
  dtype = 'float32'
  emb_dim = 8
  num_heads = 4
  num_encoder_layers = 2
  num_decoder_layers = 2
  head_dim = 3
  mlp_dim = 16
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.0
  logits_via_embedding = False

TRAIN_STEPS = 3

train/utils.DatasetConfig:
  batch_size = 8
  shuffle = False

train_eval/utils.DatasetConfig.batch_size = 8

train_script.train:
  eval_period = 3
  eval_steps = 3

trainer.Trainer.num_microbatches = 0
partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

utils.CheckpointConfig:
  restore = None

infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
mt3/gin/model.gin
ADDED
@@ -0,0 +1,60 @@
# T5.1.1 Small model.
from __gin__ import dynamic_registration

from mt3 import models
from mt3 import network
from mt3 import spectrograms
from mt3 import vocabularies
import seqio
from t5x import adafactor

# ------------------- Loss HParam ----------------------------------------------
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None
models.ContinuousInputsEncoderDecoderModel:
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Output vocabulary
VOCAB_CONFIG = %gin.REQUIRED
OUTPUT_VOCABULARY = @vocabularies.vocabulary_from_codec()
vocabularies.vocabulary_from_codec.codec = @vocabularies.build_codec()
vocabularies.build_codec.vocab_config = %VOCAB_CONFIG

# ------------------- Optimizer ------------------------------------------------
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()

# ------------------- Model ----------------------------------------------------
SPECTROGRAM_CONFIG = @spectrograms.SpectrogramConfig()
MODEL = @models.ContinuousInputsEncoderDecoderModel()
models.ContinuousInputsEncoderDecoderModel:
  module = @network.Transformer()
  input_vocabulary = @seqio.vocabularies.PassThroughVocabulary()
  output_vocabulary = %OUTPUT_VOCABULARY
  optimizer_def = %OPTIMIZER
  input_depth = @spectrograms.input_depth()
seqio.vocabularies.PassThroughVocabulary.size = 0
spectrograms.input_depth.spectrogram_config = %SPECTROGRAM_CONFIG

# ------------------- Network specification ------------------------------------
network.Transformer.config = @network.T5Config()
network.T5Config:
  vocab_size = @vocabularies.num_embeddings()
  dtype = 'float32'
  emb_dim = 512
  num_heads = 6
  num_encoder_layers = 8
  num_decoder_layers = 8
  head_dim = 64
  mlp_dim = 1024
  mlp_activations = ('gelu', 'linear')
  dropout_rate = 0.1
  logits_via_embedding = False
vocabularies.num_embeddings.vocabulary = %OUTPUT_VOCABULARY
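The vocabulary wiring above can be read as ordinary Python. A hedged sketch of the equivalent call chain, with signatures inferred from the gin bindings themselves rather than verified against the library:

# Hedged sketch: Python equivalent of the vocabulary bindings in model.gin.
# Signatures are inferred from the gin file, not verified.
from mt3 import vocabularies

vocab_config = vocabularies.VocabularyConfig(num_velocity_bins=1)  # %VOCAB_CONFIG
codec = vocabularies.build_codec(vocab_config=vocab_config)
output_vocabulary = vocabularies.vocabulary_from_codec(codec)
# This value feeds network.T5Config.vocab_size:
vocab_size = vocabularies.num_embeddings(vocabulary=output_vocabulary)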
mt3/gin/mt3.gin
ADDED
@@ -0,0 +1,9 @@
# Configuration for MT3 multi-task multitrack model.

TASK_PREFIX = 'mega_notes_ties'
TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
TRAIN_STEPS = 1000000
NUM_VELOCITY_BINS = 1
PROGRAM_GRANULARITY = 'full'
ONSETS_ONLY = False
USE_TIES = True
mt3/gin/train.gin
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Defaults for training with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TASK_PREFIX
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS
# - MODEL_DIR
#
# Commonly overridden options:
# - BATCH_SIZE
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
#   on the fly.

from __gin__ import dynamic_registration

import __main__ as train_script
import seqio
from mt3 import mixing
from mt3 import preprocessors
from mt3 import tasks
from mt3 import vocabularies
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

# Must be overridden
TASK_PREFIX = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED

# Commonly overridden
TRAIN_TASK_SUFFIX = 'train'
EVAL_TASK_SUFFIX = 'eval'
USE_CACHED_TASKS = True
BATCH_SIZE = 256

# Sometimes overridden
EVAL_STEPS = 20

# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = 0  # Don't write any inferences.

# Number of velocity bins: set to 1 (no velocity) or 127
NUM_VELOCITY_BINS = %gin.REQUIRED
VOCAB_CONFIG = @vocabularies.VocabularyConfig()
vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS

# Program granularity: set to 'flat', 'midi_class', or 'full'
PROGRAM_GRANULARITY = %gin.REQUIRED
preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY

# Maximum number of examples per mix, or None for no mixing
MAX_EXAMPLES_PER_MIX = None
mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX

train/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %TRAIN_TASK_SUFFIX

eval/tasks.construct_task_name:
  task_prefix = %TASK_PREFIX
  vocab_config = %VOCAB_CONFIG
  task_suffix = %EVAL_TASK_SUFFIX

train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 5000
  random_seed = None  # use faster, hardware RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator

seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE

seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS

train/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = @train/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = @eval/tasks.construct_task_name()
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'eval'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False

utils.CheckpointConfig:
  restore = None
  save = @utils.SaveCheckpointConfig()
utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
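
As a usage sketch (not part of this commit), these gin files can be composed programmatically; the file paths and override values below are hypothetical, and t5x's own launch flags (--gin_file / --gin.BINDING) achieve the same from the command line:

import gin

gin.parse_config_files_and_bindings(
    config_files=[
        'mt3/gin/model.gin',  # defines %MODEL and the network architecture
        'mt3/gin/train.gin',  # the defaults above
        'mt3/gin/mt3.gin',    # fills in TASK_PREFIX, TRAIN_STEPS, etc.
    ],
    bindings=[
        "MODEL_DIR = '/tmp/mt3'",  # hypothetical output directory
        'BATCH_SIZE = 64',         # override the default of 256
    ])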
mt3/inference.py
ADDED
@@ -0,0 +1,138 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for MT3 inference."""

import functools
import json

from typing import Any, Optional, Sequence

import gin

from mt3 import metrics_utils
from mt3 import note_sequences
from mt3 import tasks
from mt3 import vocabularies

import note_seq
import seqio
import tensorflow as tf


def write_inferences_to_file(
    path: str,
    inferences: Sequence[Any],
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    vocab_config=gin.REQUIRED,
    onsets_only=gin.REQUIRED,
    use_ties=gin.REQUIRED) -> None:
  """Writes model predictions, ground truth transcriptions, and input audio.

  For now this only works for transcription tasks with ties.

  Args:
    path: File path to write to.
    inferences: Model inferences, output of predict_batch.
    task_ds: Original task dataset.
    mode: Prediction mode; must be 'predict' as 'score' is not supported.
    vocabulary: Task output vocabulary.
    vocab_config: Vocabulary config object.
    onsets_only: If True, only predict onsets.
    use_ties: If True, use "tie" representation.
  """
  if mode == 'score':
    raise ValueError('`score` mode currently not supported in MT3')
  if not vocabulary:
    raise ValueError('`vocabulary` parameter required in `predict` mode')

  if onsets_only and use_ties:
    raise ValueError('ties not compatible with onset-only transcription')
  if onsets_only:
    encoding_spec = note_sequences.NoteOnsetEncodingSpec
  elif not use_ties:
    encoding_spec = note_sequences.NoteEncodingSpec
  else:
    encoding_spec = note_sequences.NoteEncodingWithTiesSpec

  codec = vocabularies.build_codec(vocab_config)

  targets = []
  predictions = []

  for inp, output in zip(task_ds.as_numpy_iterator(), inferences):
    tokens = tasks.trim_eos(vocabulary.decode_tf(output).numpy())

    start_time = inp['input_times'][0]
    # Round down to nearest symbolic token step.
    start_time -= start_time % (1 / codec.steps_per_second)

    targets.append({
        'unique_id': inp['unique_id'][0],
        'ref_ns': inp['sequence'][0] if inp['sequence'][0] else None,
    })

    predictions.append({
        'unique_id': inp['unique_id'][0],
        'est_tokens': tokens,
        'start_time': start_time,
        # Input audio is not part of the "prediction" but the below call to
        # metrics_utils.event_predictions_to_ns handles the concatenation.
        'raw_inputs': inp['raw_inputs']
    })

  # The first target for each full example contains the NoteSequence; just
  # organize by ID.
  full_targets = {}
  for target in targets:
    if target['ref_ns']:
      full_targets[target['unique_id']] = {
          'ref_ns': note_seq.NoteSequence.FromString(target['ref_ns'])
      }

  full_predictions = metrics_utils.combine_predictions_by_id(
      predictions=predictions,
      combine_predictions_fn=functools.partial(
          metrics_utils.event_predictions_to_ns,
          codec=codec,
          encoding_spec=encoding_spec))

  assert sorted(full_targets.keys()) == sorted(full_predictions.keys())

  full_target_prediction_pairs = [
      (full_targets[id], full_predictions[id])
      for id in sorted(full_targets.keys())
  ]

  def note_to_dict(note):
    return {
        'start_time': note.start_time,
        'end_time': note.end_time,
        'pitch': note.pitch,
        'velocity': note.velocity,
        'program': note.program,
        'is_drum': note.is_drum
    }

  with tf.io.gfile.GFile(path, 'w') as f:
    for target, prediction in full_target_prediction_pairs:
      json_dict = {
          'id': target['ref_ns'].id,
          'est_notes':
              [note_to_dict(note) for note in prediction['est_ns'].notes]
      }
      json_str = json.dumps(json_dict, cls=seqio.TensorAndNumpyEncoder)
      f.write(json_str + '\n')
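
The function writes one JSON object per line. A minimal sketch (not part of this file) of reading the output back, with a hypothetical path:

import json

with open('/tmp/inferences.json') as f:
  for line in f:
    record = json.loads(line)
    # 'id' is the reference NoteSequence id; 'est_notes' is a list of dicts
    # with the start_time/end_time/pitch/velocity/program/is_drum fields
    # produced by note_to_dict above.
    print(record['id'], len(record['est_notes']))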
mt3/layers.py
ADDED
@@ -0,0 +1,830 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dense attention classes and mask/weighting functions."""

# pylint: disable=attribute-defined-outside-init,g-bare-generic

import dataclasses
import functools
import operator
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union

from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np


# from flax.linen.partitioning import param_with_axes, with_sharding_constraint
param_with_axes = nn_partitioning.param_with_axes
with_sharding_constraint = nn_partitioning.with_sharding_constraint


# Type annotations
Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray
Shape = Iterable[int]
Activation = Callable[..., Array]
# Parameter initializers.
Initializer = Callable[[PRNGKey, Shape, DType], Array]

default_embed_init = nn.initializers.variance_scaling(
    1.0, 'fan_in', 'normal', out_axis=0)


def sinusoidal(min_scale: float = 1.0,
               max_scale: float = 10000.0,
               dtype: DType = jnp.float32) -> Initializer:
  """Creates 1D Sinusoidal Position Embedding Initializer.

  Args:
    min_scale: Minimum frequency-scale in sine grating.
    max_scale: Maximum frequency-scale in sine grating.
    dtype: The DType of the returned values.

  Returns:
    The sinusoidal initialization function.
  """

  def init(key: PRNGKey, shape: Shape, dtype: DType = dtype) -> Array:
    """Sinusoidal init."""
    del key
    if dtype != np.float32:
      raise ValueError('The sinusoidal initializer only supports float32.')
    if len(list(shape)) != 2:
      raise ValueError(
          f'Expected a 2D shape (max_len, features), but got {shape}.')
    max_len, features = shape
    pe = np.zeros((max_len, features), dtype=dtype)
    position = np.arange(0, max_len)[:, np.newaxis]
    scale_factor = -np.log(max_scale / min_scale) / (features // 2 - 1)
    div_term = min_scale * np.exp(np.arange(0, features // 2) * scale_factor)
    pe[:, :features // 2] = np.sin(position * div_term)
    pe[:, features // 2:2 * (features // 2)] = np.cos(position * div_term)
    return jnp.array(pe)

  return init

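# Example (not part of this file): the sinusoidal initializer is
# deterministic, so it can be exercised without a PRNG key; sizes arbitrary.
import jax.numpy as jnp
from mt3 import layers

init_fn = layers.sinusoidal()
table = init_fn(None, (2048, 512), jnp.float32)  # key argument is unused
assert table.shape == (2048, 512)
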
def dot_product_attention(query: Array,
                          key: Array,
                          value: Array,
                          bias: Optional[Array] = None,
                          dropout_rng: Optional[PRNGKey] = None,
                          dropout_rate: float = 0.,
                          deterministic: bool = False,
                          dtype: DType = jnp.float32,
                          float32_logits: bool = False):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  Args:
    query: queries for calculating attention with shape of `[batch, q_length,
      num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of `[batch, kv_length,
      num_heads, qk_depth_per_head]`.
    value: values to be used in attention with shape of `[batch, kv_length,
      num_heads, v_depth_per_head]`.
    bias: bias for the attention weights. This should be broadcastable to the
      shape `[batch, num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    dtype: the dtype of the computation (default: float32).
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.

  Returns:
    Output of shape `[batch, length, num_heads, v_depth_per_head]`.
  """
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
      'q, k, v batch dims must match.')
  assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
      'q, k, v num_heads must match.')
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # Cast logits and the softmax computation to float32 for model stability.
  if float32_logits:
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)

  # `attn_weights`: [batch, num_heads, q_length, kv_length]
  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

  # Apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias.astype(attn_weights.dtype)

  # Normalize the attention weights across `kv_length` dimension.
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  # Apply attention dropout.
  if not deterministic and dropout_rate > 0.:
    keep_prob = 1.0 - dropout_rate
    # T5 broadcasts along the "length" dim, but unclear which one that
    # corresponds to in positional dimensions here, assuming query dim.
    dropout_shape = list(attn_weights.shape)
    dropout_shape[-2] = 1
    keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    keep = jnp.broadcast_to(keep, attn_weights.shape)
    multiplier = (
        keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
    attn_weights = attn_weights * multiplier

  # Take the linear combination of `value`.
  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)

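# Example (not part of this file): a minimal shape check of the function
# above; deterministic, so no dropout RNG is needed. Sizes are arbitrary.
import jax.numpy as jnp
from mt3 import layers

batch, q_len, kv_len, heads, depth = 2, 5, 7, 4, 16
q = jnp.ones((batch, q_len, heads, depth))
k = jnp.ones((batch, kv_len, heads, depth))
v = jnp.ones((batch, kv_len, heads, depth))
out = layers.dot_product_attention(q, k, v, deterministic=True)
assert out.shape == (batch, q_len, heads, depth)
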
dynamic_vector_slice_in_dim = jax.vmap(
    lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


class MultiHeadDotProductAttention(nn.Module):
  """Multi-head dot-product attention.

  Attributes:
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    head_dim: dimension of each head.
    dtype: the dtype of the computation.
    dropout_rate: dropout rate
    kernel_init: initializer for the kernel of the Dense layers.
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.
  """

  num_heads: int
  head_dim: int
  dtype: DType = jnp.float32
  dropout_rate: float = 0.
  kernel_init: Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'normal')
  float32_logits: bool = False  # computes logits in float32 for stability.

  @nn.compact
  def __call__(self,
               inputs_q: Array,
               inputs_kv: Array,
               mask: Optional[Array] = None,
               bias: Optional[Array] = None,
               *,
               decode: bool = False,
               deterministic: bool = False) -> Array:
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output
    vector.

    There are two modes: decoding and non-decoding (e.g., training). The mode
    is determined by the `decode` argument. For decoding, this method is
    called twice, first to initialize the cache and then for the actual
    decoding process. The two calls are differentiated by the presence of
    'cached_key' in the variable dict. In the cache initialization stage, the
    cache variables are initialized as zeros and will be filled in the
    subsequent decoding process.

    In the cache initialization call, `inputs_q` has a shape [batch, length,
    q_features] and `inputs_kv`: [batch, length, kv_features]. During the
    incremental decoding stage, query, key and value all have the shape
    [batch, 1, qkv_features] corresponding to a single step.

    Args:
      inputs_q: input queries of shape `[batch, q_length, q_features]`.
      inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
      mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
      bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
      decode: Whether to prepare and use an autoregressive cache.
      deterministic: Disables dropout if set to True.

    Returns:
      output of shape `[batch, length, q_features]`.
    """
    projection = functools.partial(
        DenseGeneral,
        axis=-1,
        features=(self.num_heads, self.head_dim),
        kernel_axes=('embed', 'joined_kv'),
        dtype=self.dtype)

    # NOTE: T5 does not explicitly rescale the attention logits by
    # 1/sqrt(depth_kq)! This is folded into the initializers of the
    # linear transformations, which is equivalent under Adafactor.
    depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
    query_init = lambda *args: self.kernel_init(*args) / depth_scaling

    # Project inputs_q to multi-headed q/k/v
    # dimensions are then [batch, length, num_heads, head_dim]
    query = projection(kernel_init=query_init, name='query')(inputs_q)
    key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
    value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

    query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv'))
    key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
    value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv'))

    if decode:
      # Detect if we're initializing by absence of existing cache data.
      is_initialized = self.has_variable('cache', 'cached_key')
      # The key and value have dimension [batch, length, num_heads, head_dim],
      # but we cache them as [batch, num_heads, head_dim, length] as a TPU
      # fusion optimization. This also enables the "scatter via one-hot
      # broadcast" trick, which means we do a one-hot broadcast instead of a
      # scatter/gather operation, resulting in a 3-4x speedup in practice.
      swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
      cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                 swap_dims(key.shape), key.dtype)
      cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                   swap_dims(value.shape), value.dtype)
      cache_index = self.variable('cache', 'cache_index',
                                  lambda: jnp.array(0, dtype=jnp.int32))
      if is_initialized:
        batch, num_heads, head_dim, length = cached_key.value.shape
        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        # Sanity shape check of cached key against input query.
        expected_shape = (batch, 1, num_heads, head_dim)
        if expected_shape != query.shape:
          raise ValueError('Autoregressive cache shape error, '
                           'expected query shape %s instead got %s.' %
                           (expected_shape, query.shape))

        # Create a one-hot encoding of the current index.
        # NOTE: the index is increased below.
        cur_index = cache_index.value
        one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
        # In order to update the key, value caches with the current key and
        # value, we move the length axis to the back, similar to what we did
        # for the cached ones above.
        # Note these are currently the key and value of a single position,
        # since we feed one position at a time.
        one_token_key = jnp.moveaxis(key, -3, -1)
        one_token_value = jnp.moveaxis(value, -3, -1)
        # Update key, value caches with our new 1d spatial slices.
        # We implement an efficient scatter into the cache via one-hot
        # broadcast and addition.
        key = cached_key.value + one_token_key * one_hot_indices
        value = cached_value.value + one_token_value * one_hot_indices
        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + 1
        # Move the keys and values back to their original shapes.
        key = jnp.moveaxis(key, -1, -3)
        value = jnp.moveaxis(value, -1, -3)

        # Causal mask for cached decoder self-attention: our single query
        # position should only attend to those key positions that have already
        # been generated and cached, not the remaining zero elements.
        mask = combine_masks(
            mask,
            jnp.broadcast_to(
                jnp.arange(length) <= cur_index,
                # (1, 1, length) represent (head dim, query length, key length)
                # query length is 1 because during decoding we deal with one
                # index.
                # The same mask is applied to all batch elements and heads.
                (batch, 1, 1, length)))

        # Grab the correct relative attention bias during decoding. This is
        # only required during single step decoding.
        if bias is not None:
          # The bias is a full attention matrix, but during decoding we only
          # have to take a slice of it.
          # This is equivalent to bias[..., cur_index:cur_index+1, :].
          bias = dynamic_vector_slice_in_dim(
              jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)

    # Convert the boolean attention mask to an attention bias.
    if mask is not None:
      # attention mask in the form of attention bias
      attention_bias = lax.select(
          mask > 0,
          jnp.full(mask.shape, 0.).astype(self.dtype),
          jnp.full(mask.shape, -1e10).astype(self.dtype))
    else:
      attention_bias = None

    # Add provided bias term (e.g. relative position embedding).
    if bias is not None:
      attention_bias = combine_biases(attention_bias, bias)

    dropout_rng = None
    if not deterministic and self.dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')

    # Apply attention.
    x = dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        deterministic=deterministic,
        dtype=self.dtype,
        float32_logits=self.float32_logits)

    # Back to the original inputs dimensions.
    out = DenseGeneral(
        features=inputs_q.shape[-1],  # output dim is set to the input dim.
        axis=(-2, -1),
        kernel_init=self.kernel_init,
        kernel_axes=('joined_kv', 'embed'),
        dtype=self.dtype,
        name='out')(
            x)
    return out

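# Example (not part of this file): the "scatter via one-hot broadcast" cache
# update used above, in isolation with toy sizes.
import jax
import jax.numpy as jnp

length, depth = 4, 2
cache = jnp.zeros((depth, length))           # cached with length last
new_token = jnp.ones((depth, 1))             # this step's key/value slice
cur_index = 2
one_hot = jax.nn.one_hot(cur_index, length)  # [0., 0., 1., 0.]
# Broadcasting writes the new slice into column `cur_index` with a plain
# add, avoiding a scatter op.
cache = cache + new_token * one_hot
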
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def _canonicalize_tuple(x):
  if isinstance(x, Iterable):
    return tuple(x)
  else:
    return (x,)


#------------------------------------------------------------------------------
# DenseGeneral for attention layers.
#------------------------------------------------------------------------------
class DenseGeneral(nn.Module):
  """A linear transformation (without bias) with flexible axes.

  Attributes:
    features: tuple with numbers of output features.
    axis: tuple with axes to apply the transformation on.
    dtype: the dtype of the computation (default: float32).
    kernel_init: initializer function for the weight matrix.
  """
  features: Union[Iterable[int], int]
  axis: Union[Iterable[int], int] = -1
  dtype: DType = jnp.float32
  kernel_init: Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'truncated_normal')
  kernel_axes: Tuple[str, ...] = ()

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    features = _canonicalize_tuple(self.features)
    axis = _canonicalize_tuple(self.axis)

    inputs = jnp.asarray(inputs, self.dtype)
    axis = _normalize_axes(axis, inputs.ndim)

    kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
    kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),
                          np.prod(features))
    kernel = param_with_axes(
        'kernel',
        self.kernel_init,
        kernel_param_shape,
        jnp.float32,
        axes=self.kernel_axes)
    kernel = jnp.asarray(kernel, self.dtype)
    kernel = jnp.reshape(kernel, kernel_shape)

    contract_ind = tuple(range(0, len(axis)))
    return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))

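# Example (not part of this file): using DenseGeneral to produce a
# multi-headed projection, as the attention layer above does.
import jax
import jax.numpy as jnp
from mt3 import layers

proj = layers.DenseGeneral(
    features=(4, 16), axis=-1, kernel_axes=('embed', 'joined_kv'))
x = jnp.ones((2, 5, 64))  # [batch, length, embed]
variables = proj.init(jax.random.PRNGKey(0), x)
y = proj.apply(variables, x)  # [batch, length, num_heads=4, head_dim=16]
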
def _convert_to_activation_function(
    fn_or_string: Union[str, Callable]) -> Callable:
  """Convert a string to an activation function."""
  if fn_or_string == 'linear':
    return lambda x: x
  elif isinstance(fn_or_string, str):
    return getattr(nn, fn_or_string)
  elif callable(fn_or_string):
    return fn_or_string
  else:
    raise ValueError("don't know how to convert %s to an activation function" %
                     (fn_or_string,))


class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block.

  Attributes:
    intermediate_dim: Shared dimension of hidden layers.
    activations: Type of activations for each layer. Each element is either
      'linear', a string function name in flax.linen, or a function.
    kernel_init: Kernel function, passed to the dense layers.
    deterministic: Whether the dropout layers should be deterministic.
    intermediate_dropout_rate: Dropout rate used after the intermediate layers.
    dtype: Type for the dense layer.
  """
  intermediate_dim: int = 2048
  activations: Sequence[Union[str, Callable]] = ('relu',)
  kernel_init: Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'truncated_normal')
  intermediate_dropout_rate: float = 0.1
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
    """Applies Transformer MlpBlock module."""
    # Iterate over specified MLP input activation functions.
    # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
    activations = []
    for idx, act_fn in enumerate(self.activations):
      dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}'
      x = DenseGeneral(
          self.intermediate_dim,
          dtype=self.dtype,
          kernel_init=self.kernel_init,
          kernel_axes=('embed', 'mlp'),
          name=dense_name)(
              inputs)
      x = _convert_to_activation_function(act_fn)(x)
      activations.append(x)

    # Take elementwise product of above intermediate activations.
    x = functools.reduce(operator.mul, activations)
    # Apply dropout and final dense output projection.
    x = nn.Dropout(
        rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)  # Broadcast along length.
    x = with_sharding_constraint(x, ('batch', 'length', 'mlp'))
    output = DenseGeneral(
        inputs.shape[-1],
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        kernel_axes=('mlp', 'embed'),
        name='wo')(
            x)
    return output

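# Example (not part of this file): with activations=('gelu', 'linear') the
# block computes a gated-GELU feed-forward; the equivalent math, ignoring
# dropout and sharding.
import jax.numpy as jnp
from flax import linen as nn

def gated_gelu_mlp(x, wi_0, wi_1, wo):
  # Two parallel input projections; the GELU branch gates the linear branch.
  return (nn.gelu(x @ wi_0) * (x @ wi_1)) @ wo

y = gated_gelu_mlp(jnp.ones((2, 5, 8)),
                   jnp.ones((8, 16)), jnp.ones((8, 16)), jnp.ones((16, 8)))
assert y.shape == (2, 5, 8)
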
class Embed(nn.Module):
  """A parameterized function from integers [0, n) to d-dimensional vectors.

  Attributes:
    num_embeddings: number of embeddings.
    features: number of feature dimensions for each embedding.
    dtype: the dtype of the embedding vectors (default: float32).
    embedding_init: embedding initializer.
    one_hot: performs the gather with a one-hot contraction rather than a true
      gather. This is currently needed for SPMD partitioning.
  """
  num_embeddings: int
  features: int
  cast_input_dtype: Optional[DType] = None
  dtype: DType = jnp.float32
  attend_dtype: Optional[DType] = None
  embedding_init: Initializer = default_embed_init
  one_hot: bool = False
  embedding: Array = dataclasses.field(init=False)

  def setup(self):
    self.embedding = param_with_axes(
        'embedding',
        self.embedding_init, (self.num_embeddings, self.features),
        jnp.float32,
        axes=('vocab', 'embed'))

  def __call__(self, inputs: Array) -> Array:
    """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.

    Returns:
      Output which is embedded input data. The output shape follows the input,
      with an additional `features` dimension appended.
    """
    if self.cast_input_dtype:
      inputs = inputs.astype(self.cast_input_dtype)
    if not jnp.issubdtype(inputs.dtype, jnp.integer):
      raise ValueError('Input type must be an integer or unsigned integer.')
    if self.one_hot:
      iota = lax.iota(jnp.int32, self.num_embeddings)
      one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
      output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
    else:
      output = jnp.asarray(self.embedding, self.dtype)[inputs]
      output = with_sharding_constraint(output, ('batch', 'length', 'embed'))
    return output

  def attend(self, query: Array) -> Array:
    """Attend over the embedding using a query array.

    Args:
      query: array with last dimension equal the feature depth `features` of
        the embedding.

    Returns:
      An array with final dim `num_embeddings` corresponding to the batched
      inner-product of the array of query vectors against each embedding.
      Commonly used for weight-sharing between embeddings and logit transform
      in NLP models.
    """
    dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
    return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)

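# Example (not part of this file): the one_hot=True path computes the same
# gather as plain indexing, via a one-hot contraction.
import jax.numpy as jnp
from jax import lax

embedding = jnp.arange(12.0).reshape(4, 3)  # [num_embeddings, features]
inputs = jnp.array([2, 0])

gathered = embedding[inputs]
iota = lax.iota(jnp.int32, 4)
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=jnp.float32)
assert jnp.allclose(jnp.dot(one_hot, embedding), gathered)
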
class FixedEmbed(nn.Module):
  """Fixed (not learnable) embeddings specified by the initializer function.

  Attributes:
    features: number of feature dimensions for each embedding.
    max_length: The maximum supported length.
    embedding_init: The initializer function that defines the embeddings.
    dtype: The DType to use for the embeddings.
  """
  features: int
  max_length: int = 2048
  embedding_init: Initializer = sinusoidal()
  dtype: jnp.dtype = jnp.float32

  def setup(self):
    # The key is set to None because sinusoid init is deterministic.
    shape = (self.max_length, self.features)
    self.embedding = self.embedding_init(None, shape, self.dtype)  # pylint: disable=too-many-function-args

  @nn.compact
  def __call__(self,
               inputs,
               *,
               decode: bool = False):
    """Returns the fixed position embeddings specified by the initializer.

    Args:
      inputs: <int>[batch_size, seq_len] input position indices.
      decode: True if running in single-position autoregressive decode mode.

    Returns:
      The fixed position embeddings <float32>[batch_size, seq_len, features].
    """
    # We use a cache position index for tracking decoding position.
    if decode:
      position_embedder_index = self.variable(
          'cache', 'position_embedder_index',
          lambda: jnp.array(-1, dtype=jnp.uint32))
      i = position_embedder_index.value
      position_embedder_index.value = i + 1
      return jax.lax.dynamic_slice(self.embedding, jnp.array((i, 0)),
                                   np.array((1, self.features)))

    return jnp.take(self.embedding, inputs, axis=0)


#------------------------------------------------------------------------------
# T5 Layernorm - no subtraction of mean or bias.
#------------------------------------------------------------------------------
class LayerNorm(nn.Module):
  """T5 Layer normalization operating on the last axis of the input data."""
  epsilon: float = 1e-6
  dtype: Any = jnp.float32
  scale_init: Initializer = nn.initializers.ones

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies layer normalization on the input."""
    x = jnp.asarray(x, jnp.float32)
    features = x.shape[-1]
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
    scale = param_with_axes(
        'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))

    scale = jnp.asarray(scale, self.dtype)
    return y * scale

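# Example (not part of this file): the class above is RMS normalization; a
# functional sketch of the same computation, omitting dtype casts.
import jax.numpy as jnp

def rms_norm(x, scale, epsilon=1e-6):
  # No mean subtraction and no bias, unlike standard layer norm.
  mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x / jnp.sqrt(mean2 + epsilon) * scale
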
#------------------------------------------------------------------------------
# Mask-making utility functions.
#------------------------------------------------------------------------------
def make_attention_mask(query_input: Array,
                        key_input: Array,
                        pairwise_fn: Callable = jnp.multiply,
                        extra_batch_dims: int = 0,
                        dtype: DType = jnp.float32) -> Array:
  """Mask-making helper for attention weights.

  In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`), the
  attention weights will be `[batch, heads, len_q, len_kv]` and this
  function will produce `[batch, 1, len_q, len_kv]`.

  Args:
    query_input: a batched, flat input of query_length size
    key_input: a batched, flat input of key_length size
    pairwise_fn: broadcasting elementwise comparison function
    extra_batch_dims: number of extra batch dims to add singleton axes for,
      none by default
    dtype: mask return dtype

  Returns:
    A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
  """
  # [batch, len_q, len_kv]
  mask = pairwise_fn(
      # [batch, len_q] -> [batch, len_q, 1]
      jnp.expand_dims(query_input, axis=-1),
      # [batch, len_kv] -> [batch, 1, len_kv]
      jnp.expand_dims(key_input, axis=-2))

  # [batch, 1, len_q, len_kv]. This creates the head dim.
  mask = jnp.expand_dims(mask, axis=-3)
  mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
  return mask.astype(dtype)


def make_causal_mask(x: Array,
                     extra_batch_dims: int = 0,
                     dtype: DType = jnp.float32) -> Array:
  """Make a causal mask for self-attention.

  In case of 1d inputs (i.e., `[batch, len]`), the self-attention weights
  will be `[batch, heads, len, len]` and this function will produce a
  causal mask of shape `[batch, 1, len, len]`.

  Note that a causal mask does not depend on the values of x; it only depends
  on the shape. If x has padding elements, they will not be treated in a
  special manner.

  Args:
    x: input array of shape `[batch, len]`
    extra_batch_dims: number of batch dims to add singleton axes for, none by
      default
    dtype: mask return dtype

  Returns:
    A `[batch, 1, len, len]` shaped causal mask for 1d attention.
  """
  idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
  return make_attention_mask(
      idxs,
      idxs,
      jnp.greater_equal,
      extra_batch_dims=extra_batch_dims,
      dtype=dtype)


def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
  """Combine attention masks.

  Args:
    *masks: set of attention mask arguments to combine, some can be None.
    dtype: final mask dtype

  Returns:
    Combined mask, reduced by logical and; returns None if no masks given.
  """
  masks = [m for m in masks if m is not None]
  if not masks:
    return None
  assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), (
      f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
  mask, *other_masks = masks
  for other_mask in other_masks:
    mask = jnp.logical_and(mask, other_mask)
  return mask.astype(dtype)


def combine_biases(*masks: Optional[Array]):
  """Combine attention biases.

  Args:
    *masks: set of attention bias arguments to combine, some can be None.

  Returns:
    Combined bias, reduced by summation; returns None if no biases given.
  """
  masks = [m for m in masks if m is not None]
  if not masks:
    return None
  assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), (
      f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
  mask, *other_masks = masks
  for other_mask in other_masks:
    mask = mask + other_mask
  return mask


def make_decoder_mask(decoder_target_tokens: Array,
                      dtype: DType,
                      decoder_causal_attention: Optional[Array] = None,
                      decoder_segment_ids: Optional[Array] = None) -> Array:
  """Compute the self-attention mask for a decoder.

  The decoder mask is formed by combining a causal mask, a padding mask and an
  optional packing mask. If decoder_causal_attention is passed, it makes the
  masking non-causal for positions that have a value of 1.

  A prefix LM is applied to a dataset which has a notion of "inputs" and
  "targets", e.g., a machine translation task. The inputs and targets are
  concatenated to form a new target. `decoder_target_tokens` is the
  concatenated decoder output tokens.

  The "inputs" portion of the concatenated sequence can attend to other
  "inputs" tokens, even those at later time steps. In order to control this
  behavior, `decoder_causal_attention` is necessary. This is a binary mask
  with a value of 1 indicating that the position belonged to the "inputs"
  portion of the original dataset.

  Example:

    Suppose we have a dataset with two examples.

    ds = [{"inputs": [6, 7], "targets": [8]},
          {"inputs": [3, 4], "targets": [5]}]

    After the data preprocessing with packing, the two examples are packed
    into one example with the following three fields (some fields are skipped
    for simplicity).

       decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
         decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
    decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]

    where each array has [batch, length] shape with batch size being 1. Then,
    this function computes the following mask.

    mask = [[[[1, 1, 0, 0, 0, 0, 0],
              [1, 1, 0, 0, 0, 0, 0],
              [1, 1, 1, 0, 0, 0, 0],
              [0, 0, 0, 1, 1, 0, 0],
              [0, 0, 0, 1, 1, 0, 0],
              [0, 0, 0, 1, 1, 1, 0],
              [0, 0, 0, 0, 0, 0, 0]]]]

    mask[b, 0, :, :] represents the mask for the example `b` in the batch.
    Because the mask is for a self-attention layer, the mask's shape is a
    square of shape [query length, key length].

    mask[b, 0, i, j] = 1 means that the query token at position i can attend
    to the key token at position j.

  Args:
    decoder_target_tokens: decoder output tokens. [batch, length]
    dtype: dtype of the output mask.
    decoder_causal_attention: a binary mask where a value of 1 marks "inputs"
      positions that may attend bidirectionally; all other positions attend
      only to earlier positions in the sequence. [batch, length]
    decoder_segment_ids: decoder segmentation info for packed examples.
      [batch, length]

  Returns:
    the combined decoder mask.
  """
  masks = []
  # The same mask is applied to all attention heads. So the head dimension is
  # 1, i.e., the mask will be broadcast along the heads dim.
  # [batch, 1, length, length]
  causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)

  # Positions with value 1 in `decoder_causal_attention` can attend
  # bidirectionally.
  if decoder_causal_attention is not None:
    # [batch, 1, length, length]
    inputs_mask = make_attention_mask(
        decoder_causal_attention,
        decoder_causal_attention,
        jnp.logical_and,
        dtype=dtype)
    masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
  else:
    masks.append(causal_mask)

  # Padding mask.
  masks.append(
      make_attention_mask(
          decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))

  # Packing mask
  if decoder_segment_ids is not None:
    masks.append(
        make_attention_mask(
            decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))

  return combine_masks(*masks, dtype=dtype)
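
A minimal sketch (not part of this file) exercising the decoder-mask helper with the packed example from the docstring above:

import jax.numpy as jnp
from mt3 import layers

mask = layers.make_decoder_mask(
    decoder_target_tokens=jnp.array([[6, 7, 8, 3, 4, 5, 0]]),
    dtype=jnp.float32,
    decoder_causal_attention=jnp.array([[1, 1, 0, 1, 1, 0, 0]]),
    decoder_segment_ids=jnp.array([[1, 1, 1, 2, 2, 2, 0]]))
assert mask.shape == (1, 1, 7, 7)  # [batch, 1, length, length]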
mt3/layers_test.py
ADDED
@@ -0,0 +1,545 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for attention classes."""

import dataclasses
from typing import Optional
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from flax import linen as nn
from flax.core import freeze
from flax.linen import partitioning as nn_partitioning
import jax
from jax import random
from jax.nn import initializers
import jax.numpy as jnp
from mt3 import layers
import numpy as np

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()

Array = jnp.ndarray
AxisMetadata = nn_partitioning.AxisMetadata  # pylint: disable=invalid-name


class SelfAttention(layers.MultiHeadDotProductAttention):
  """Self-attention special case of multi-head dot-product attention."""

  @nn.compact
  def __call__(self,
               inputs_q: Array,
               mask: Optional[Array] = None,
               bias: Optional[Array] = None,
               deterministic: bool = False):
    return super().__call__(
        inputs_q, inputs_q, mask, bias, deterministic=deterministic)


@dataclasses.dataclass(frozen=True)
class SelfAttentionArgs:
  num_heads: int = 1
  batch_size: int = 2
  # qkv_features: int = 3
  head_dim: int = 3
  # out_features: int = 4
  q_len: int = 5
  features: int = 6
  dropout_rate: float = 0.1
  deterministic: bool = False
  decode: bool = False
  float32_logits: bool = False

  def __post_init__(self):
    # If we are decoding, the query length should be 1, because we are doing
    # autoregressive decoding where we feed one position at a time.
    assert not self.decode or self.q_len == 1

  def init_args(self):
    return dict(
        num_heads=self.num_heads,
        head_dim=self.head_dim,
        dropout_rate=self.dropout_rate,
        float32_logits=self.float32_logits)

  def apply_args(self):
    inputs_q = jnp.ones((self.batch_size, self.q_len, self.features))
    mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len))
    bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len))
    return {
        'inputs_q': inputs_q,
        'mask': mask,
        'bias': bias,
        'deterministic': self.deterministic
    }


class AttentionTest(parameterized.TestCase):

  def test_dot_product_attention_shape(self):
    # This test only checks for shape but tries to make sure all code paths
    # are reached.
    dropout_rng = random.PRNGKey(0)
    batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6

    query = jnp.ones((batch_size, q_len, num_heads, qk_depth))
    key = jnp.ones((batch_size, kv_len, num_heads, qk_depth))
    value = jnp.ones((batch_size, kv_len, num_heads, v_depth))
    bias = jnp.ones((batch_size, num_heads, q_len, kv_len))

    args = dict(
        query=query,
        key=key,
        value=value,
        bias=bias,
        dropout_rng=dropout_rng,
        dropout_rate=0.5,
        deterministic=False,
    )

    output = layers.dot_product_attention(**args)
    self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth))

  def test_make_attention_mask_multiply_pairwise_fn(self):
    decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]])
    attention_mask = layers.make_attention_mask(
        decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32)
    expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]])
    expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]])
    self.assertEqual(attention_mask.shape, (2, 1, 3, 3))
    np.testing.assert_array_equal(attention_mask[0, 0], expected0)
    np.testing.assert_array_equal(attention_mask[1, 0], expected1)

  def test_make_attention_mask_equal_pairwise_fn(self):
    segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]])
    attention_mask = layers.make_attention_mask(
        segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32)
    # Padding is not treated in a special way, so it needs to be zeroed out
    # separately.
    expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                           [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0],
                           [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]])
    expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]])
    self.assertEqual(attention_mask.shape, (2, 1, 6, 6))
    np.testing.assert_array_equal(attention_mask[0, 0], expected0)
    np.testing.assert_array_equal(attention_mask[1, 0], expected1)

  def test_make_causal_mask_with_padding(self):
    x = jnp.array([[7, 0, 0], [8, 5, 0]])
    y = layers.make_causal_mask(x)
    self.assertEqual(y.shape, (2, 1, 3, 3))
    # Padding is not treated in a special way, so it needs to be zeroed out
    # separately.
    expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]],
                           jnp.float32)
    np.testing.assert_allclose(y[0], expected_y)
    np.testing.assert_allclose(y[1], expected_y)

  def test_make_causal_mask_extra_batch_dims(self):
    x = jnp.ones((3, 3, 5))
    y = layers.make_causal_mask(x, extra_batch_dims=2)
    self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5))

  def test_make_causal_mask(self):
    x = jnp.ones((1, 3))
    y = layers.make_causal_mask(x)
    self.assertEqual(y.shape, (1, 1, 3, 3))
    expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]],
                           jnp.float32)
    np.testing.assert_allclose(y, expected_y)

  def test_combine_masks(self):
    masks = [
        jnp.array([0, 1, 0, 1], jnp.float32), None,
        jnp.array([1, 1, 1, 1], jnp.float32),
        jnp.array([1, 1, 1, 0], jnp.float32)
    ]
    y = layers.combine_masks(*masks)
    np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32))

  def test_combine_biases(self):
    masks = [
        jnp.array([0, 1, 0, 1], jnp.float32), None,
        jnp.array([0, 1, 1, 1], jnp.float32),
        jnp.array([0, 1, 1, 0], jnp.float32)
    ]
    y = layers.combine_biases(*masks)
    np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32))

  def test_make_decoder_mask_lm_unpacked(self):
    decoder_target_tokens = jnp.array([6, 7, 3, 0])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32)
    expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0],
                                [0, 0, 0, 0]]])
    np.testing.assert_array_equal(mask, expected_mask)

  def test_make_decoder_mask_lm_packed(self):
    decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]])
    decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_segment_ids=decoder_segment_ids)
    expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                                 [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0],
                                 [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]])
    np.testing.assert_array_equal(mask, expected_mask)
+
|
205 |
+
def test_make_decoder_mask_prefix_lm_unpacked(self):
|
206 |
+
decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]])
|
207 |
+
decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]])
|
208 |
+
mask = layers.make_decoder_mask(
|
209 |
+
decoder_target_tokens=decoder_target_tokens,
|
210 |
+
dtype=jnp.float32,
|
211 |
+
decoder_causal_attention=decoder_causal_attention)
|
212 |
+
expected_mask = jnp.array(
|
213 |
+
[[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
|
214 |
+
[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]],
|
215 |
+
dtype=jnp.float32)
|
216 |
+
np.testing.assert_array_equal(mask, expected_mask)
|
217 |
+
|
218 |
+
def test_make_decoder_mask_prefix_lm_packed(self):
|
219 |
+
decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]])
|
220 |
+
decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]])
|
221 |
+
decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]])
|
222 |
+
mask = layers.make_decoder_mask(
|
223 |
+
decoder_target_tokens=decoder_target_tokens,
|
224 |
+
dtype=jnp.float32,
|
225 |
+
decoder_causal_attention=decoder_causal_attention,
|
226 |
+
decoder_segment_ids=decoder_segment_ids)
|
227 |
+
expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0],
|
228 |
+
[1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0],
|
229 |
+
[0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0],
|
230 |
+
[0, 0, 0, 0, 0, 0, 0]]]])
|
231 |
+
np.testing.assert_array_equal(mask, expected_mask)
|
232 |
+
|
233 |
+
def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self):
|
234 |
+
decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]])
|
235 |
+
decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]])
|
236 |
+
mask = layers.make_decoder_mask(
|
237 |
+
decoder_target_tokens=decoder_target_tokens,
|
238 |
+
dtype=jnp.float32,
|
239 |
+
decoder_causal_attention=decoder_causal_attention)
|
240 |
+
expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0],
|
241 |
+
[0, 0, 0, 0]])
|
242 |
+
expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0],
|
243 |
+
[0, 0, 0, 0]])
|
244 |
+
self.assertEqual(mask.shape, (2, 1, 4, 4))
|
245 |
+
np.testing.assert_array_equal(mask[0, 0], expected_mask0)
|
246 |
+
np.testing.assert_array_equal(mask[1, 0], expected_mask1)
|
247 |
+
|
248 |
+
def test_make_decoder_mask_composite_causal_attention(self):
|
249 |
+
decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]])
|
250 |
+
decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]])
|
251 |
+
mask = layers.make_decoder_mask(
|
252 |
+
decoder_target_tokens=decoder_target_tokens,
|
253 |
+
dtype=jnp.float32,
|
254 |
+
decoder_causal_attention=decoder_causal_attention)
|
255 |
+
expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0],
|
256 |
+
[1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0],
|
257 |
+
[1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0],
|
258 |
+
[0, 0, 0, 0, 0, 0, 0]])
|
259 |
+
|
260 |
+
self.assertEqual(mask.shape, (1, 1, 7, 7))
|
261 |
+
np.testing.assert_array_equal(mask[0, 0], expected_mask0)
|
262 |
+
|
263 |
+
def test_make_decoder_mask_composite_causal_attention_packed(self):
|
264 |
+
decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]])
|
265 |
+
decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]])
|
266 |
+
decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]])
|
267 |
+
mask = layers.make_decoder_mask(
|
268 |
+
decoder_target_tokens=decoder_target_tokens,
|
269 |
+
dtype=jnp.float32,
|
270 |
+
decoder_causal_attention=decoder_causal_attention,
|
271 |
+
decoder_segment_ids=decoder_segment_ids)
|
272 |
+
expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0],
|
273 |
+
[1, 1, 0, 0, 1, 1, 0, 0, 0],
|
274 |
+
[1, 1, 1, 0, 0, 0, 0, 0, 0],
|
275 |
+
[1, 1, 1, 1, 0, 0, 0, 0, 0],
|
276 |
+
[1, 1, 1, 1, 1, 1, 0, 0, 0],
|
277 |
+
[1, 1, 1, 1, 1, 1, 0, 0, 0],
|
278 |
+
[0, 0, 0, 0, 0, 0, 1, 1, 0],
|
279 |
+
[0, 0, 0, 0, 0, 0, 1, 1, 0],
|
280 |
+
[0, 0, 0, 0, 0, 0, 1, 1, 1]])
|
281 |
+
|
282 |
+
self.assertEqual(mask.shape, (1, 1, 9, 9))
|
283 |
+
np.testing.assert_array_equal(mask[0, 0], expected_mask0)
|
284 |
+
|
285 |
+
@parameterized.parameters({'f': 20}, {'f': 22})
|
286 |
+
def test_multihead_dot_product_attention(self, f):
|
287 |
+
# b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim
|
288 |
+
b, q, h, d, k = 2, 3, 4, 5, 6
|
289 |
+
|
290 |
+
base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0)
|
291 |
+
args = base_args.init_args()
|
292 |
+
|
293 |
+
np.random.seed(0)
|
294 |
+
inputs_q = np.random.randn(b, q, f)
|
295 |
+
inputs_kv = np.random.randn(b, k, f)
|
296 |
+
|
297 |
+
# Projection: [b, q, f] -> [b, q, h, d]
|
298 |
+
# So the kernels have to be [f, h, d]
|
299 |
+
query_kernel = np.random.randn(f, h, d)
|
300 |
+
key_kernel = np.random.randn(f, h, d)
|
301 |
+
value_kernel = np.random.randn(f, h, d)
|
302 |
+
# `out` calculation: [b, q, h, d] -> [b, q, f]
|
303 |
+
# So kernel has to be [h, d, f]
|
304 |
+
out_kernel = np.random.randn(h, d, f)
|
305 |
+
|
306 |
+
params = {
|
307 |
+
'query': {
|
308 |
+
'kernel': query_kernel.reshape(f, -1)
|
309 |
+
},
|
310 |
+
'key': {
|
311 |
+
'kernel': key_kernel.reshape(f, -1)
|
312 |
+
},
|
313 |
+
'value': {
|
314 |
+
'kernel': value_kernel.reshape(f, -1)
|
315 |
+
},
|
316 |
+
'out': {
|
317 |
+
'kernel': out_kernel.reshape(-1, f)
|
318 |
+
}
|
319 |
+
}
|
320 |
+
y = layers.MultiHeadDotProductAttention(**args).apply(
|
321 |
+
{'params': freeze(params)}, inputs_q, inputs_kv)
|
322 |
+
|
323 |
+
query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel)
|
324 |
+
key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel)
|
325 |
+
value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel)
|
326 |
+
logits = np.einsum('bqhd,bkhd->bhqk', query, key)
|
327 |
+
weights = nn.softmax(logits, axis=-1)
|
328 |
+
combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value)
|
329 |
+
y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel)
|
330 |
+
np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5)
|
331 |
+
|
332 |
+
def test_multihead_dot_product_attention_caching(self):
|
333 |
+
# b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim
|
334 |
+
b, h, d, k = 2, 3, 4, 5
|
335 |
+
f = h * d
|
336 |
+
|
337 |
+
base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0)
|
338 |
+
args = base_args.init_args()
|
339 |
+
|
340 |
+
cache = {
|
341 |
+
'cached_key': np.zeros((b, h, d, k)),
|
342 |
+
'cached_value': np.zeros((b, h, d, k)),
|
343 |
+
'cache_index': np.array(0)
|
344 |
+
}
|
345 |
+
inputs_q = np.random.randn(b, 1, f)
|
346 |
+
inputs_kv = np.random.randn(b, 1, f)
|
347 |
+
|
348 |
+
# Mock dense general such that q, k, v projections are replaced by simple
|
349 |
+
# reshaping.
|
350 |
+
def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument
|
351 |
+
return x.reshape(b, -1, h, d)
|
352 |
+
|
353 |
+
with mock.patch.object(
|
354 |
+
layers.DenseGeneral, '__call__', new=mock_dense_general):
|
355 |
+
_, mutated = layers.MultiHeadDotProductAttention(**args).apply(
|
356 |
+
{'cache': freeze(cache)},
|
357 |
+
inputs_q,
|
358 |
+
inputs_kv,
|
359 |
+
decode=True,
|
360 |
+
mutable=['cache'])
|
361 |
+
updated_cache = mutated['cache']
|
362 |
+
|
363 |
+
# Perform the same mocked projection to generate the expected cache.
|
364 |
+
# (key|value): [b, 1, h, d]
|
365 |
+
key = mock_dense_general(None, inputs_kv)
|
366 |
+
value = mock_dense_general(None, inputs_kv)
|
367 |
+
|
368 |
+
# cached_(key|value): [b, h, d, k]
|
369 |
+
cache['cached_key'][:, :, :, 0] = key[:, 0, :, :]
|
370 |
+
cache['cached_value'][:, :, :, 0] = value[:, 0, :, :]
|
371 |
+
cache['cache_index'] = np.array(1)
|
372 |
+
for name, array in cache.items():
|
373 |
+
np.testing.assert_allclose(array, updated_cache[name])
|
374 |
+
|
375 |
+
def test_dot_product_attention(self):
|
376 |
+
# b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim
|
377 |
+
b, q, h, d, k = 2, 3, 4, 5, 6
|
378 |
+
np.random.seed(0)
|
379 |
+
query = np.random.randn(b, q, h, d)
|
380 |
+
key = np.random.randn(b, k, h, d)
|
381 |
+
value = np.random.randn(b, k, h, d)
|
382 |
+
bias = np.random.randn(b, h, q, k)
|
383 |
+
attn_out = layers.dot_product_attention(query, key, value, bias=bias)
|
384 |
+
logits = np.einsum('bqhd,bkhd->bhqk', query, key)
|
385 |
+
weights = jax.nn.softmax(logits + bias, axis=-1)
|
386 |
+
expected = np.einsum('bhqk,bkhd->bqhd', weights, value)
|
387 |
+
np.testing.assert_allclose(attn_out, expected, atol=1e-6)
|
388 |
+
|
389 |
+
|
390 |
+
class EmbeddingTest(parameterized.TestCase):
|
391 |
+
|
392 |
+
def test_embedder_raises_exception_for_incorrect_input_type(self):
|
393 |
+
"""Tests that inputs are integers and that an exception is raised if not."""
|
394 |
+
embed = layers.Embed(num_embeddings=10, features=5)
|
395 |
+
inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1)
|
396 |
+
variables = embed.init(jax.random.PRNGKey(0), inputs)
|
397 |
+
bad_inputs = inputs.astype(np.float32)
|
398 |
+
with self.assertRaisesRegex(
|
399 |
+
ValueError, 'Input type must be an integer or unsigned integer.'):
|
400 |
+
_ = embed.apply(variables, bad_inputs)
|
401 |
+
|
402 |
+
@parameterized.named_parameters(
|
403 |
+
{
|
404 |
+
'testcase_name': 'with_ones',
|
405 |
+
'init_fn': jax.nn.initializers.ones,
|
406 |
+
'num_embeddings': 10,
|
407 |
+
'features': 5,
|
408 |
+
'matrix_sum': 5 * 10,
|
409 |
+
}, {
|
410 |
+
'testcase_name': 'with_zeros',
|
411 |
+
'init_fn': jax.nn.initializers.zeros,
|
412 |
+
'num_embeddings': 10,
|
413 |
+
'features': 5,
|
414 |
+
'matrix_sum': 0,
|
415 |
+
})
|
416 |
+
def test_embedding_initializes_correctly(self, init_fn, num_embeddings,
|
417 |
+
features, matrix_sum):
|
418 |
+
"""Tests if the Embed class initializes with the requested initializer."""
|
419 |
+
embed = layers.Embed(
|
420 |
+
num_embeddings=num_embeddings,
|
421 |
+
features=features,
|
422 |
+
embedding_init=init_fn)
|
423 |
+
inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1)
|
424 |
+
variables = embed.init(jax.random.PRNGKey(0), inputs)
|
425 |
+
embedding_matrix = variables['params']['embedding']
|
426 |
+
self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum)
|
427 |
+
|
428 |
+
def test_embedding_matrix_shape(self):
|
429 |
+
"""Tests that the embedding matrix has the right shape."""
|
430 |
+
num_embeddings = 10
|
431 |
+
features = 5
|
432 |
+
embed = layers.Embed(num_embeddings=num_embeddings, features=features)
|
433 |
+
inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1)
|
434 |
+
variables = embed.init(jax.random.PRNGKey(0), inputs)
|
435 |
+
embedding_matrix = variables['params']['embedding']
|
436 |
+
self.assertEqual((num_embeddings, features), embedding_matrix.shape)
|
437 |
+
|
438 |
+
def test_embedding_attend(self):
|
439 |
+
"""Tests that attending with ones returns sum of embedding vectors."""
|
440 |
+
features = 5
|
441 |
+
embed = layers.Embed(num_embeddings=10, features=features)
|
442 |
+
inputs = np.array([[1]], dtype=np.int64)
|
443 |
+
variables = embed.init(jax.random.PRNGKey(0), inputs)
|
444 |
+
query = np.ones(features, dtype=np.float32)
|
445 |
+
result = embed.apply(variables, query, method=embed.attend)
|
446 |
+
expected = np.sum(variables['params']['embedding'], -1)
|
447 |
+
np.testing.assert_array_almost_equal(result, expected)
|
448 |
+
|
449 |
+
|
450 |
+
class DenseTest(parameterized.TestCase):
|
451 |
+
|
452 |
+
def test_dense_general_no_bias(self):
|
453 |
+
rng = random.PRNGKey(0)
|
454 |
+
x = jnp.ones((1, 3))
|
455 |
+
model = layers.DenseGeneral(
|
456 |
+
features=4,
|
457 |
+
kernel_init=initializers.ones,
|
458 |
+
)
|
459 |
+
y, _ = model.init_with_output(rng, x)
|
460 |
+
self.assertEqual(y.shape, (1, 4))
|
461 |
+
np.testing.assert_allclose(y, np.full((1, 4), 3.))
|
462 |
+
|
463 |
+
def test_dense_general_two_features(self):
|
464 |
+
rng = random.PRNGKey(0)
|
465 |
+
x = jnp.ones((1, 3))
|
466 |
+
model = layers.DenseGeneral(
|
467 |
+
features=(2, 2),
|
468 |
+
kernel_init=initializers.ones,
|
469 |
+
)
|
470 |
+
y, _ = model.init_with_output(rng, x)
|
471 |
+
# We transform the last input dimension to two output dimensions (2, 2).
|
472 |
+
np.testing.assert_allclose(y, np.full((1, 2, 2), 3.))
|
473 |
+
|
474 |
+
def test_dense_general_two_axes(self):
|
475 |
+
rng = random.PRNGKey(0)
|
476 |
+
x = jnp.ones((1, 2, 2))
|
477 |
+
model = layers.DenseGeneral(
|
478 |
+
features=3,
|
479 |
+
axis=(-2, 2), # Note: this is the same as (1, 2).
|
480 |
+
kernel_init=initializers.ones,
|
481 |
+
)
|
482 |
+
y, _ = model.init_with_output(rng, x)
|
483 |
+
# We transform the last two input dimensions (2, 2) to one output dimension.
|
484 |
+
np.testing.assert_allclose(y, np.full((1, 3), 4.))
|
485 |
+
|
486 |
+
def test_mlp_same_out_dim(self):
|
487 |
+
module = layers.MlpBlock(
|
488 |
+
intermediate_dim=4,
|
489 |
+
activations=('relu',),
|
490 |
+
kernel_init=nn.initializers.xavier_uniform(),
|
491 |
+
dtype=jnp.float32,
|
492 |
+
)
|
493 |
+
inputs = np.array(
|
494 |
+
[
|
495 |
+
# Batch 1.
|
496 |
+
[[1, 1], [1, 1], [1, 2]],
|
497 |
+
# Batch 2.
|
498 |
+
[[2, 2], [3, 1], [2, 2]],
|
499 |
+
],
|
500 |
+
dtype=np.float32)
|
501 |
+
params = module.init(random.PRNGKey(0), inputs, deterministic=True)
|
502 |
+
self.assertEqual(
|
503 |
+
jax.tree_map(lambda a: a.tolist(), params), {
|
504 |
+
'params': {
|
505 |
+
'wi': {
|
506 |
+
'kernel': [[
|
507 |
+
-0.8675811290740967, 0.08417510986328125,
|
508 |
+
0.022586345672607422, -0.9124102592468262
|
509 |
+
],
|
510 |
+
[
|
511 |
+
-0.19464373588562012, 0.49809837341308594,
|
512 |
+
0.7808468341827393, 0.9267289638519287
|
513 |
+
]],
|
514 |
+
},
|
515 |
+
'wo': {
|
516 |
+
'kernel': [[0.01154780387878418, 0.1397249698638916],
|
517 |
+
[0.974980354309082, 0.5903260707855225],
|
518 |
+
[-0.05997943878173828, 0.616570234298706],
|
519 |
+
[0.2934272289276123, 0.8181164264678955]],
|
520 |
+
},
|
521 |
+
},
|
522 |
+
'params_axes': {
|
523 |
+
'wi': {
|
524 |
+
'kernel_axes': AxisMetadata(names=('embed', 'mlp')),
|
525 |
+
},
|
526 |
+
'wo': {
|
527 |
+
'kernel_axes': AxisMetadata(names=('mlp', 'embed')),
|
528 |
+
},
|
529 |
+
},
|
530 |
+
})
|
531 |
+
result = module.apply(params, inputs, deterministic=True)
|
532 |
+
np.testing.assert_allclose(
|
533 |
+
result.tolist(),
|
534 |
+
[[[0.5237172245979309, 0.8508185744285583],
|
535 |
+
[0.5237172245979309, 0.8508185744285583],
|
536 |
+
[1.2344461679458618, 2.3844780921936035]],
|
537 |
+
[[1.0474344491958618, 1.7016371488571167],
|
538 |
+
[0.6809444427490234, 0.9663378596305847],
|
539 |
+
[1.0474344491958618, 1.7016371488571167]]],
|
540 |
+
rtol=1e-6,
|
541 |
+
)
|
542 |
+
|
543 |
+
|
544 |
+
if __name__ == '__main__':
|
545 |
+
absltest.main()
|
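For reference, the attention math these tests check against fits in a few numpy lines. The sketch below is not part of the committed files; it mirrors the einsum-based expected values in test_dot_product_attention (query/key logits per head, softmax over the key axis, then a weighted sum of values), with a locally defined softmax so it runs standalone.

import numpy as np

def softmax(x, axis=-1):
  # Numerically stable softmax.
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

def reference_attention(query, key, value, bias=None):
  # query: [b, q, h, d]; key, value: [b, k, h, d]; bias: [b, h, q, k].
  logits = np.einsum('bqhd,bkhd->bhqk', query, key)
  if bias is not None:
    logits = logits + bias
  weights = softmax(logits, axis=-1)  # normalize over the key axis
  return np.einsum('bhqk,bkhd->bqhd', weights, value)  # [b, q, h, d]

# e.g. reference_attention(np.ones((2, 3, 4, 5)), np.ones((2, 6, 4, 5)),
#                          np.ones((2, 6, 4, 5))).shape == (2, 3, 4, 5)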
mt3/metrics.py
ADDED
@@ -0,0 +1,392 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transcription metrics."""

import collections
import copy
import functools
from typing import Any, Iterable, Mapping, Optional, Sequence

import mir_eval

from mt3 import event_codec
from mt3 import metrics_utils
from mt3 import note_sequences
from mt3 import spectrograms
from mt3 import summaries
from mt3 import vocabularies

import note_seq
import numpy as np
import seqio


def _program_aware_note_scores(
    ref_ns: note_seq.NoteSequence,
    est_ns: note_seq.NoteSequence,
    granularity_type: str
) -> Mapping[str, float]:
  """Compute precision/recall/F1 for notes taking program into account.

  For non-drum tracks, uses onsets and offsets. For drum tracks, uses onsets
  only. Applies MIDI program map of specified granularity type.

  Args:
    ref_ns: Reference NoteSequence with ground truth labels.
    est_ns: Estimated NoteSequence.
    granularity_type: String key in vocabularies.PROGRAM_GRANULARITIES dict.

  Returns:
    A dictionary containing precision, recall, and F1 score.
  """
  program_map_fn = vocabularies.PROGRAM_GRANULARITIES[
      granularity_type].program_map_fn

  ref_ns = copy.deepcopy(ref_ns)
  for note in ref_ns.notes:
    if not note.is_drum:
      note.program = program_map_fn(note.program)

  est_ns = copy.deepcopy(est_ns)
  for note in est_ns.notes:
    if not note.is_drum:
      note.program = program_map_fn(note.program)

  program_and_is_drum_tuples = (
      set((note.program, note.is_drum) for note in ref_ns.notes) |
      set((note.program, note.is_drum) for note in est_ns.notes)
  )

  drum_precision_sum = 0.0
  drum_precision_count = 0
  drum_recall_sum = 0.0
  drum_recall_count = 0

  nondrum_precision_sum = 0.0
  nondrum_precision_count = 0
  nondrum_recall_sum = 0.0
  nondrum_recall_count = 0

  for program, is_drum in program_and_is_drum_tuples:
    est_track = note_sequences.extract_track(est_ns, program, is_drum)
    ref_track = note_sequences.extract_track(ref_ns, program, is_drum)

    est_intervals, est_pitches, unused_est_velocities = (
        note_seq.sequences_lib.sequence_to_valued_intervals(est_track))
    ref_intervals, ref_pitches, unused_ref_velocities = (
        note_seq.sequences_lib.sequence_to_valued_intervals(ref_track))

    args = {
        'ref_intervals': ref_intervals, 'ref_pitches': ref_pitches,
        'est_intervals': est_intervals, 'est_pitches': est_pitches
    }
    if is_drum:
      args['offset_ratio'] = None

    precision, recall, unused_f_measure, unused_avg_overlap_ratio = (
        mir_eval.transcription.precision_recall_f1_overlap(**args))

    if is_drum:
      drum_precision_sum += precision * len(est_intervals)
      drum_precision_count += len(est_intervals)
      drum_recall_sum += recall * len(ref_intervals)
      drum_recall_count += len(ref_intervals)
    else:
      nondrum_precision_sum += precision * len(est_intervals)
      nondrum_precision_count += len(est_intervals)
      nondrum_recall_sum += recall * len(ref_intervals)
      nondrum_recall_count += len(ref_intervals)

  precision_sum = drum_precision_sum + nondrum_precision_sum
  precision_count = drum_precision_count + nondrum_precision_count
  recall_sum = drum_recall_sum + nondrum_recall_sum
  recall_count = drum_recall_count + nondrum_recall_count

  precision = (precision_sum / precision_count) if precision_count else 0
  recall = (recall_sum / recall_count) if recall_count else 0
  f_measure = mir_eval.util.f_measure(precision, recall)

  drum_precision = ((drum_precision_sum / drum_precision_count)
                    if drum_precision_count else 0)
  drum_recall = ((drum_recall_sum / drum_recall_count)
                 if drum_recall_count else 0)
  drum_f_measure = mir_eval.util.f_measure(drum_precision, drum_recall)

  nondrum_precision = ((nondrum_precision_sum / nondrum_precision_count)
                       if nondrum_precision_count else 0)
  nondrum_recall = ((nondrum_recall_sum / nondrum_recall_count)
                    if nondrum_recall_count else 0)
  nondrum_f_measure = mir_eval.util.f_measure(nondrum_precision, nondrum_recall)

  return {
      f'Onset + offset + program precision ({granularity_type})': precision,
      f'Onset + offset + program recall ({granularity_type})': recall,
      f'Onset + offset + program F1 ({granularity_type})': f_measure,
      f'Drum onset precision ({granularity_type})': drum_precision,
      f'Drum onset recall ({granularity_type})': drum_recall,
      f'Drum onset F1 ({granularity_type})': drum_f_measure,
      f'Nondrum onset + offset + program precision ({granularity_type})':
          nondrum_precision,
      f'Nondrum onset + offset + program recall ({granularity_type})':
          nondrum_recall,
      f'Nondrum onset + offset + program F1 ({granularity_type})':
          nondrum_f_measure
  }


def _note_onset_tolerance_sweep(
    ref_ns: note_seq.NoteSequence, est_ns: note_seq.NoteSequence,
    tolerances: Iterable[float] = (0.01, 0.02, 0.05, 0.1, 0.2, 0.5)
) -> Mapping[str, float]:
  """Compute note precision/recall/F1 across a range of tolerances."""
  est_intervals, est_pitches, unused_est_velocities = (
      note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
  ref_intervals, ref_pitches, unused_ref_velocities = (
      note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))

  scores = {}

  for tol in tolerances:
    precision, recall, f_measure, _ = (
        mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals=ref_intervals, ref_pitches=ref_pitches,
            est_intervals=est_intervals, est_pitches=est_pitches,
            onset_tolerance=tol, offset_min_tolerance=tol))

    scores[f'Onset + offset precision ({tol})'] = precision
    scores[f'Onset + offset recall ({tol})'] = recall
    scores[f'Onset + offset F1 ({tol})'] = f_measure

  return scores


def transcription_metrics(
    targets: Sequence[Mapping[str, Any]],
    predictions: Sequence[Mapping[str, Any]],
    codec: event_codec.Codec,
    spectrogram_config: spectrograms.SpectrogramConfig,
    onsets_only: bool,
    use_ties: bool,
    track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None,
    num_summary_examples: int = 5,
    frame_fps: float = 62.5,
    frame_velocity_threshold: int = 30,
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Compute mir_eval transcription metrics."""
  if onsets_only and use_ties:
    raise ValueError('Ties not compatible with onset-only transcription.')
  if onsets_only:
    encoding_spec = note_sequences.NoteOnsetEncodingSpec
  elif not use_ties:
    encoding_spec = note_sequences.NoteEncodingSpec
  else:
    encoding_spec = note_sequences.NoteEncodingWithTiesSpec

  # The first target for each full example contains the NoteSequence; just
  # organize by ID.
  full_targets = {}
  for target in targets:
    if target['ref_ns']:
      full_targets[target['unique_id']] = {'ref_ns': target['ref_ns']}

  # Gather all predictions for the same ID and concatenate them in time order,
  # to construct full-length predictions.
  full_predictions = metrics_utils.combine_predictions_by_id(
      predictions=predictions,
      combine_predictions_fn=functools.partial(
          metrics_utils.event_predictions_to_ns,
          codec=codec,
          encoding_spec=encoding_spec))

  assert sorted(full_targets.keys()) == sorted(full_predictions.keys())

  full_target_prediction_pairs = [
      (full_targets[id], full_predictions[id])
      for id in sorted(full_targets.keys())
  ]

  scores = collections.defaultdict(list)
  all_track_pianorolls = collections.defaultdict(list)
  for target, prediction in full_target_prediction_pairs:
    scores['Invalid events'].append(prediction['est_invalid_events'])
    scores['Dropped events'].append(prediction['est_dropped_events'])

    def remove_drums(ns):
      ns_drumless = note_seq.NoteSequence()
      ns_drumless.CopyFrom(ns)
      del ns_drumless.notes[:]
      ns_drumless.notes.extend([note for note in ns.notes if not note.is_drum])
      return ns_drumless

    est_ns_drumless = remove_drums(prediction['est_ns'])
    ref_ns_drumless = remove_drums(target['ref_ns'])

    # Whether or not there are separate tracks, compute metrics for the full
    # NoteSequence minus drums.
    est_tracks = [est_ns_drumless]
    ref_tracks = [ref_ns_drumless]
    use_track_offsets = [not onsets_only]
    use_track_velocities = [not onsets_only]
    track_instrument_names = ['']

    if track_specs is not None:
      # Compute transcription metrics separately for each track.
      for spec in track_specs:
        est_tracks.append(note_sequences.extract_track(
            prediction['est_ns'], spec.program, spec.is_drum))
        ref_tracks.append(note_sequences.extract_track(
            target['ref_ns'], spec.program, spec.is_drum))
        use_track_offsets.append(not onsets_only and not spec.is_drum)
        use_track_velocities.append(not onsets_only)
        track_instrument_names.append(spec.name)

    for est_ns, ref_ns, use_offsets, use_velocities, instrument_name in zip(
        est_tracks, ref_tracks, use_track_offsets, use_track_velocities,
        track_instrument_names):
      track_scores = {}

      est_intervals, est_pitches, est_velocities = (
          note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))

      ref_intervals, ref_pitches, ref_velocities = (
          note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))

      # Precision / recall / F1 using onsets (and pitches) only.
      precision, recall, f_measure, avg_overlap_ratio = (
          mir_eval.transcription.precision_recall_f1_overlap(
              ref_intervals=ref_intervals,
              ref_pitches=ref_pitches,
              est_intervals=est_intervals,
              est_pitches=est_pitches,
              offset_ratio=None))
      del avg_overlap_ratio
      track_scores['Onset precision'] = precision
      track_scores['Onset recall'] = recall
      track_scores['Onset F1'] = f_measure

      if use_offsets:
        # Precision / recall / F1 using onsets and offsets.
        precision, recall, f_measure, avg_overlap_ratio = (
            mir_eval.transcription.precision_recall_f1_overlap(
                ref_intervals=ref_intervals,
                ref_pitches=ref_pitches,
                est_intervals=est_intervals,
                est_pitches=est_pitches))
        del avg_overlap_ratio
        track_scores['Onset + offset precision'] = precision
        track_scores['Onset + offset recall'] = recall
        track_scores['Onset + offset F1'] = f_measure

      if use_velocities:
        # Precision / recall / F1 using onsets and velocities (no offsets).
        precision, recall, f_measure, avg_overlap_ratio = (
            mir_eval.transcription_velocity.precision_recall_f1_overlap(
                ref_intervals=ref_intervals,
                ref_pitches=ref_pitches,
                ref_velocities=ref_velocities,
                est_intervals=est_intervals,
                est_pitches=est_pitches,
                est_velocities=est_velocities,
                offset_ratio=None))
        track_scores['Onset + velocity precision'] = precision
        track_scores['Onset + velocity recall'] = recall
        track_scores['Onset + velocity F1'] = f_measure

      if use_offsets and use_velocities:
        # Precision / recall / F1 using onsets, offsets, and velocities.
        precision, recall, f_measure, avg_overlap_ratio = (
            mir_eval.transcription_velocity.precision_recall_f1_overlap(
                ref_intervals=ref_intervals,
                ref_pitches=ref_pitches,
                ref_velocities=ref_velocities,
                est_intervals=est_intervals,
                est_pitches=est_pitches,
                est_velocities=est_velocities))
        track_scores['Onset + offset + velocity precision'] = precision
        track_scores['Onset + offset + velocity recall'] = recall
        track_scores['Onset + offset + velocity F1'] = f_measure

      # Calculate framewise metrics.
      is_drum = all([n.is_drum for n in ref_ns.notes])
      ref_pr = metrics_utils.get_prettymidi_pianoroll(
          ref_ns, frame_fps, is_drum=is_drum)
      est_pr = metrics_utils.get_prettymidi_pianoroll(
          est_ns, frame_fps, is_drum=is_drum)
      all_track_pianorolls[instrument_name].append((est_pr, ref_pr))
      frame_precision, frame_recall, frame_f1 = metrics_utils.frame_metrics(
          ref_pr, est_pr, velocity_threshold=frame_velocity_threshold)
      track_scores['Frame Precision'] = frame_precision
      track_scores['Frame Recall'] = frame_recall
      track_scores['Frame F1'] = frame_f1

      for metric_name, metric_value in track_scores.items():
        if instrument_name:
          scores[f'{instrument_name}/{metric_name}'].append(metric_value)
        else:
          scores[metric_name].append(metric_value)

    # Add program-aware note metrics for all program granularities.
    # Note that this interacts with the training program granularity; in
    # particular granularities *higher* than the training granularity are
    # likely to have poor metrics.
    for granularity_type in vocabularies.PROGRAM_GRANULARITIES:
      for name, score in _program_aware_note_scores(
          target['ref_ns'], prediction['est_ns'],
          granularity_type=granularity_type).items():
        scores[name].append(score)

    # Add (non-program-aware) note metrics across a range of onset/offset
    # tolerances.
    for name, score in _note_onset_tolerance_sweep(
        ref_ns=ref_ns_drumless, est_ns=est_ns_drumless).items():
      scores[name].append(score)

  mean_scores = {k: np.mean(v) for k, v in scores.items()}

  score_histograms = {'%s (hist)' % k: seqio.metrics.Histogram(np.array(v))
                      for k, v in scores.items()}

  # Pick several examples to summarize.
  targets_to_summarize, predictions_to_summarize = zip(
      *full_target_prediction_pairs[:num_summary_examples])

  # Compute audio summaries.
  audio_summaries = summaries.audio_summaries(
      targets=targets_to_summarize,
      predictions=predictions_to_summarize,
      spectrogram_config=spectrogram_config)

  # Compute transcription summaries.
  transcription_summaries = summaries.transcription_summaries(
      targets=targets_to_summarize,
      predictions=predictions_to_summarize,
      spectrogram_config=spectrogram_config,
      ns_feature_suffix='ns',
      track_specs=track_specs)

  pianorolls_to_summarize = {
      k: v[:num_summary_examples] for k, v in all_track_pianorolls.items()
  }

  prettymidi_pianoroll_summaries = summaries.prettymidi_pianoroll(
      pianorolls_to_summarize, fps=frame_fps)

  return {
      **mean_scores,
      **score_histograms,
      **audio_summaries,
      **transcription_summaries,
      **prettymidi_pianoroll_summaries,
  }
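_program_aware_note_scores above micro-averages per-track scores: each track's precision is weighted by its number of estimated notes, its recall by its number of reference notes, and F1 is then the harmonic mean of the two. A minimal standalone sketch of that weighting (the name weighted_prf is illustrative, not part of the module):

def weighted_prf(track_stats):
  """track_stats: list of (precision, n_est, recall, n_ref) per track."""
  precision_sum = sum(p * n_est for p, n_est, _, _ in track_stats)
  precision_count = sum(n_est for _, n_est, _, _ in track_stats)
  recall_sum = sum(r * n_ref for _, _, r, n_ref in track_stats)
  recall_count = sum(n_ref for _, _, _, n_ref in track_stats)
  precision = precision_sum / precision_count if precision_count else 0
  recall = recall_sum / recall_count if recall_count else 0
  # Same formula as mir_eval.util.f_measure: F1 = 2PR / (P + R).
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
  return precision, recall, f1

# e.g. weighted_prf([(1.0, 10, 0.5, 20), (0.0, 5, 0.0, 0)])
# -> (0.666..., 0.5, 0.571...)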
mt3/metrics_utils.py
ADDED
@@ -0,0 +1,196 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for transcription metrics."""

import collections
import functools

from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar

from mt3 import event_codec
from mt3 import note_sequences
from mt3 import run_length_encoding

import note_seq
import numpy as np
import pretty_midi
import sklearn

S = TypeVar('S')
T = TypeVar('T')

CombineExamplesFunctionType = Callable[[Sequence[Mapping[str, Any]]],
                                       Mapping[str, Any]]


def _group_predictions_by_id(
    predictions: Sequence[Mapping[str, T]]
) -> Mapping[str, Sequence[T]]:
  predictions_by_id = collections.defaultdict(list)
  for pred in predictions:
    predictions_by_id[pred['unique_id']].append(pred)
  return predictions_by_id


def combine_predictions_by_id(
    predictions: Sequence[Mapping[str, Any]],
    combine_predictions_fn: CombineExamplesFunctionType
) -> Mapping[str, Mapping[str, Any]]:
  """Concatenate predicted examples, grouping by ID and sorting by time."""
  predictions_by_id = _group_predictions_by_id(predictions)
  return {
      id: combine_predictions_fn(preds)
      for id, preds in predictions_by_id.items()
  }


def decode_and_combine_predictions(
    predictions: Sequence[Mapping[str, Any]],
    init_state_fn: Callable[[], S],
    begin_segment_fn: Callable[[S], None],
    decode_tokens_fn: Callable[[S, Sequence[int], int, Optional[int]],
                               Tuple[int, int]],
    flush_state_fn: Callable[[S], T]
) -> Tuple[T, int, int]:
  """Decode and combine a sequence of predictions to a full result.

  For time-based events, this usually means concatenation.

  Args:
    predictions: List of predictions, each of which is a dictionary containing
      estimated tokens ('est_tokens') and start time ('start_time') fields.
    init_state_fn: Function that takes no arguments and returns an initial
      decoding state.
    begin_segment_fn: Function that updates the decoding state at the beginning
      of a segment.
    decode_tokens_fn: Function that takes a decoding state, estimated tokens
      (for a single segment), start time, and max time, and processes the
      tokens, updating the decoding state in place. Also returns the number of
      invalid and dropped events for the segment.
    flush_state_fn: Function that flushes the final decoding state into the
      result.

  Returns:
    result: The full combined decoding.
    total_invalid_events: Total number of invalid event tokens across all
      predictions.
    total_dropped_events: Total number of dropped event tokens across all
      predictions.
  """
  sorted_predictions = sorted(predictions, key=lambda pred: pred['start_time'])

  state = init_state_fn()
  total_invalid_events = 0
  total_dropped_events = 0

  for pred_idx, pred in enumerate(sorted_predictions):
    begin_segment_fn(state)

    # Depending on the audio token hop length, each symbolic token could be
    # associated with multiple audio frames. Since we split up the audio frames
    # into segments for prediction, this could lead to overlap. To prevent
    # overlap issues, ensure that the current segment does not make any
    # predictions for the time period covered by the subsequent segment.
    max_decode_time = None
    if pred_idx < len(sorted_predictions) - 1:
      max_decode_time = sorted_predictions[pred_idx + 1]['start_time']

    invalid_events, dropped_events = decode_tokens_fn(
        state, pred['est_tokens'], pred['start_time'], max_decode_time)

    total_invalid_events += invalid_events
    total_dropped_events += dropped_events

  return flush_state_fn(state), total_invalid_events, total_dropped_events


def event_predictions_to_ns(
    predictions: Sequence[Mapping[str, Any]], codec: event_codec.Codec,
    encoding_spec: note_sequences.NoteEncodingSpecType
) -> Mapping[str, Any]:
  """Convert a sequence of predictions to a combined NoteSequence."""
  ns, total_invalid_events, total_dropped_events = decode_and_combine_predictions(
      predictions=predictions,
      init_state_fn=encoding_spec.init_decoding_state_fn,
      begin_segment_fn=encoding_spec.begin_decoding_segment_fn,
      decode_tokens_fn=functools.partial(
          run_length_encoding.decode_events,
          codec=codec,
          decode_event_fn=encoding_spec.decode_event_fn),
      flush_state_fn=encoding_spec.flush_decoding_state_fn)

  # Also concatenate raw inputs from all predictions.
  sorted_predictions = sorted(predictions, key=lambda pred: pred['start_time'])
  raw_inputs = np.concatenate(
      [pred['raw_inputs'] for pred in sorted_predictions], axis=0)
  start_times = [pred['start_time'] for pred in sorted_predictions]

  return {
      'raw_inputs': raw_inputs,
      'start_times': start_times,
      'est_ns': ns,
      'est_invalid_events': total_invalid_events,
      'est_dropped_events': total_dropped_events,
  }


def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float,
                             is_drum: bool):
  """Convert NoteSequence to pianoroll through pretty_midi."""
  for note in ns.notes:
    if is_drum or note.end_time - note.start_time < 0.05:
      # Give all drum notes a fixed length, and all others a min length
      note.end_time = note.start_time + 0.05

  pm = note_seq.note_sequence_to_pretty_midi(ns)
  end_time = pm.get_end_time()
  cc = [
      # all sound off
      pretty_midi.ControlChange(number=120, value=0, time=end_time),
      # all notes off
      pretty_midi.ControlChange(number=123, value=0, time=end_time)
  ]
  pm.instruments[0].control_changes = cc
  if is_drum:
    # If inst.is_drum is set, pretty_midi will return an all zero pianoroll.
    for inst in pm.instruments:
      inst.is_drum = False
  pianoroll = pm.get_piano_roll(fs=fps)
  return pianoroll


def frame_metrics(ref_pianoroll: np.ndarray,
                  est_pianoroll: np.ndarray,
                  velocity_threshold: int) -> Tuple[float, float, float]:
  """Frame Precision, Recall, and F1."""
  # Pad to same length
  if ref_pianoroll.shape[1] > est_pianoroll.shape[1]:
    diff = ref_pianoroll.shape[1] - est_pianoroll.shape[1]
    est_pianoroll = np.pad(est_pianoroll, [(0, 0), (0, diff)], mode='constant')
  elif est_pianoroll.shape[1] > ref_pianoroll.shape[1]:
    diff = est_pianoroll.shape[1] - ref_pianoroll.shape[1]
    ref_pianoroll = np.pad(ref_pianoroll, [(0, 0), (0, diff)], mode='constant')

  # For ref, remove any notes that are too quiet (consistent with Cerberus.)
  ref_frames_bool = ref_pianoroll > velocity_threshold
  # For est, keep all predicted notes.
  est_frames_bool = est_pianoroll > 0

  precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support(
      ref_frames_bool.flatten(),
      est_frames_bool.flatten(),
      labels=[True, False])

  return precision[0], recall[0], f1[0]
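A toy usage sketch of decode_and_combine_predictions (assuming the mt3 package from this commit is importable): the decoding state here is just a list, and each segment's tokens are appended tagged with the segment's start time, which makes the start-time ordering of segments visible.

from mt3 import metrics_utils

predictions = [
    {'est_tokens': [3, 4], 'start_time': 1.0},
    {'est_tokens': [1, 2], 'start_time': 0.0},
]

def decode_tokens(state, tokens, start_time, max_decode_time):
  # Append each token tagged with its segment start time; report zero
  # invalid and zero dropped events.
  del max_decode_time  # unused in this toy decoder
  state.extend((start_time, t) for t in tokens)
  return 0, 0

result, n_invalid, n_dropped = metrics_utils.decode_and_combine_predictions(
    predictions=predictions,
    init_state_fn=list,
    begin_segment_fn=lambda state: None,
    decode_tokens_fn=decode_tokens,
    flush_state_fn=lambda state: state)

# Segments are visited in start-time order:
# result == [(0.0, 1), (0.0, 2), (1.0, 3), (1.0, 4)]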
mt3/metrics_utils_test.py
ADDED
@@ -0,0 +1,259 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for metrics_utils."""

from mt3 import event_codec
from mt3 import metrics_utils
from mt3 import note_sequences

import note_seq
import numpy as np
import tensorflow as tf


class MetricsUtilsTest(tf.test.TestCase):

  def test_event_predictions_to_ns(self):
    predictions = [
        {
            'raw_inputs': [0, 0],
            'start_time': 0.0,
            'est_tokens': [20, 160],
        },
        {
            'raw_inputs': [1, 1],
            'start_time': 0.4,
            # These last 2 events should be dropped.
            'est_tokens': [20, 161, 50, 162],
        },
        {
            'raw_inputs': [2, 2],
            'start_time': 0.8,
            'est_tokens': [163, 20, 164]
        },
    ]
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=59,
        velocity=100,
        start_time=0.20,
        end_time=0.21)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=0.60,
        end_time=0.61)
    expected_ns.notes.add(
        pitch=62,
        velocity=100,
        start_time=0.80,
        end_time=0.81)
    expected_ns.notes.add(
        pitch=63,
        velocity=100,
        start_time=1.00,
        end_time=1.01)
    expected_ns.total_time = 1.01

    codec = event_codec.Codec(
        max_shift_steps=100,
        steps_per_second=100,
        event_ranges=[
            event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                                   note_seq.MAX_MIDI_PITCH)])
    res = metrics_utils.event_predictions_to_ns(
        predictions, codec=codec,
        encoding_spec=note_sequences.NoteOnsetEncodingSpec)
    self.assertProtoEquals(expected_ns, res['est_ns'])
    self.assertEqual(0, res['est_invalid_events'])
    self.assertEqual(2, res['est_dropped_events'])
    np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])

  def test_event_predictions_to_ns_with_offsets(self):
    predictions = [
        {
            'raw_inputs': [0, 0],
            'start_time': 0.0,
            'est_tokens': [20, 356, 160],
        },
        {
            'raw_inputs': [1, 1],
            'start_time': 0.4,
            'est_tokens': [20, 292, 161],
        },
        {
            'raw_inputs': [2, 2],
            'start_time': 0.8,
            'est_tokens': [20, 229, 160, 161]
        },
    ]
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=59,
        velocity=127,
        start_time=0.20,
        end_time=1.00)
    expected_ns.notes.add(
        pitch=60,
        velocity=63,
        start_time=0.60,
        end_time=1.00)
    expected_ns.total_time = 1.00

    codec = event_codec.Codec(
        max_shift_steps=100,
        steps_per_second=100,
        event_ranges=[
            event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                                   note_seq.MAX_MIDI_PITCH),
            event_codec.EventRange('velocity', 0, 127)
        ])
    res = metrics_utils.event_predictions_to_ns(
        predictions, codec=codec, encoding_spec=note_sequences.NoteEncodingSpec)
    self.assertProtoEquals(expected_ns, res['est_ns'])
    self.assertEqual(0, res['est_invalid_events'])
    self.assertEqual(0, res['est_dropped_events'])
    np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])

  def test_event_predictions_to_ns_multitrack(self):
    predictions = [
        {
            'raw_inputs': [0, 0],
            'start_time': 0.0,
            'est_tokens': [20, 517, 356, 160],
        },
        {
            'raw_inputs': [1, 1],
            'start_time': 0.4,
            'est_tokens': [20, 356, 399],
        },
        {
            'raw_inputs': [2, 2],
            'start_time': 0.8,
            'est_tokens': [20, 517, 229, 160]
        },
    ]
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=42,
        velocity=127,
        start_time=0.60,
        end_time=0.61,
        is_drum=True,
        instrument=9)
    expected_ns.notes.add(
        pitch=59,
        velocity=127,
        start_time=0.20,
        end_time=1.00,
        program=32)
    expected_ns.total_time = 1.00

    codec = event_codec.Codec(
        max_shift_steps=100,
        steps_per_second=100,
        event_ranges=[
            event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                                   note_seq.MAX_MIDI_PITCH),
            event_codec.EventRange('velocity', 0, 127),
            event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                                   note_seq.MAX_MIDI_PITCH),
            event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                                   note_seq.MAX_MIDI_PROGRAM)
        ])
    res = metrics_utils.event_predictions_to_ns(
        predictions, codec=codec, encoding_spec=note_sequences.NoteEncodingSpec)
    self.assertProtoEquals(expected_ns, res['est_ns'])
    self.assertEqual(0, res['est_invalid_events'])
    self.assertEqual(0, res['est_dropped_events'])
    np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])

  def test_event_predictions_to_ns_multitrack_ties(self):
    predictions = [
        {
            'raw_inputs': [0, 0],
            'start_time': 0.0,
            'est_tokens': [613,  # no tied notes
                           20, 517, 356, 160],
        },
        {
            'raw_inputs': [1, 1],
            'start_time': 0.4,
            'est_tokens': [517, 160, 613,  # tied note
                           20, 356, 399],
        },
        {
            'raw_inputs': [2, 2],
            'start_time': 0.8,
            'est_tokens': [613]  # no tied notes, causing active note to end
        },
    ]
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=42,
        velocity=127,
        start_time=0.60,
        end_time=0.61,
        is_drum=True,
        instrument=9)
    expected_ns.notes.add(
        pitch=59,
        velocity=127,
        start_time=0.20,
        end_time=0.80,
        program=32)
    expected_ns.total_time = 0.80

    codec = event_codec.Codec(
        max_shift_steps=100,
        steps_per_second=100,
        event_ranges=[
            event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                                   note_seq.MAX_MIDI_PITCH),
            event_codec.EventRange('velocity', 0, 127),
            event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                                   note_seq.MAX_MIDI_PITCH),
            event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                                   note_seq.MAX_MIDI_PROGRAM),
            event_codec.EventRange('tie', 0, 0)
        ])
    res = metrics_utils.event_predictions_to_ns(
        predictions, codec=codec,
        encoding_spec=note_sequences.NoteEncodingWithTiesSpec)
    self.assertProtoEquals(expected_ns, res['est_ns'])
    self.assertEqual(0, res['est_invalid_events'])
    self.assertEqual(0, res['est_dropped_events'])
    np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])

  def test_frame_metrics(self):
    ref = np.zeros(shape=(128, 5))
    est = np.zeros(shape=(128, 5))

    # one overlapping note, two false positives, two false negatives
    ref[10, 0] = 127
    ref[10, 1] = 127
    ref[10, 2] = 127

    est[10, 2] = 127
    est[10, 3] = 127
    est[10, 4] = 127

    prec, rec, _ = metrics_utils.frame_metrics(ref, est, velocity_threshold=1)
    np.testing.assert_approx_equal(prec, 1/3)
    np.testing.assert_approx_equal(rec, 1/3)


if __name__ == '__main__':
  tf.test.main()
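The frame metrics exercised above can also be called directly; a minimal sketch (again assuming mt3 is importable) with two toy pianorolls that overlap in exactly one frame:

import numpy as np
from mt3 import metrics_utils

ref = np.zeros((128, 4))
est = np.zeros((128, 4))
ref[60, :2] = 100   # reference note active in frames 0-1
est[60, 1:3] = 100  # estimate active in frames 1-2 -> one overlapping frame

precision, recall, f1 = metrics_utils.frame_metrics(
    ref, est, velocity_threshold=30)
# One true positive, one false positive, one false negative:
# precision == recall == f1 == 0.5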
mt3/mixing.py
ADDED
@@ -0,0 +1,91 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for mixing (in the audio sense) multiple transcription examples."""

from typing import Callable, Optional, Sequence

import gin

from mt3 import event_codec
from mt3 import run_length_encoding

import numpy as np
import seqio
import tensorflow as tf


@gin.configurable
def mix_transcription_examples(
    ds: tf.data.Dataset,
    sequence_length: seqio.preprocessors.SequenceLengthType,
    output_features: seqio.preprocessors.OutputFeaturesType,
    codec: event_codec.Codec,
    inputs_feature_key: str = 'inputs',
    targets_feature_keys: Sequence[str] = ('targets',),
    max_examples_per_mix: Optional[int] = None,
    shuffle_buffer_size: int = seqio.SHUFFLE_BUFFER_SIZE
) -> Callable[..., tf.data.Dataset]:
  """Preprocessor that mixes together "batches" of transcription examples.

  Args:
    ds: Dataset of individual transcription examples, each of which should
      have an 'inputs' field containing 1D audio samples (currently only
      audio encoders that use raw samples as an intermediate representation
      are supported), and a 'targets' field containing run-length encoded
      note events.
    sequence_length: Dictionary mapping feature key to length.
    output_features: Dictionary mapping feature key to spec.
    codec: An event_codec.Codec used to interpret the target events.
    inputs_feature_key: Feature key for inputs which will be mixed as audio.
    targets_feature_keys: List of feature keys for targets, each of which will
      be merged (separately) as run-length encoded note events.
    max_examples_per_mix: Maximum number of individual examples to mix
      together.
    shuffle_buffer_size: Size of shuffle buffer to use for shuffle prior to
      mixing.

  Returns:
    Dataset containing mixed examples.
  """
  if max_examples_per_mix is None:
    return ds

  # TODO(iansimon): is there a way to use seqio's seed?
  ds = tf.data.Dataset.sample_from_datasets([
      ds.shuffle(
          buffer_size=shuffle_buffer_size // max_examples_per_mix
      ).padded_batch(batch_size=i) for i in range(1, max_examples_per_mix + 1)
  ])

  def mix_inputs(ex):
    samples = tf.reduce_sum(ex[inputs_feature_key], axis=0)
    norm = tf.linalg.norm(samples, ord=np.inf)
    ex[inputs_feature_key] = tf.math.divide_no_nan(samples, norm)
    return ex
  ds = ds.map(mix_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  max_tokens = sequence_length['targets']
  if output_features['targets'].add_eos:
    # Leave room to insert an EOS token.
    max_tokens -= 1

  def mix_targets(ex):
    for k in targets_feature_keys:
      ex[k] = run_length_encoding.merge_run_length_encoded_targets(
          targets=ex[k],
          codec=codec)
    return ex
  ds = ds.map(mix_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return ds
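
To make the audio-mixing step in mix_transcription_examples concrete: the padded examples in a batch are summed along the batch axis and then rescaled so the peak absolute amplitude is 1. A minimal sketch with toy arrays (not part of this commit):

import numpy as np
import tensorflow as tf

batch = tf.constant([[0.5, -0.25, 0.0],     # two already-padded "audio" clips
                     [0.25, 0.25, 0.0]])
samples = tf.reduce_sum(batch, axis=0)      # -> [0.75, 0.0, 0.0]
norm = tf.linalg.norm(samples, ord=np.inf)  # peak amplitude, 0.75
mixed = tf.math.divide_no_nan(samples, norm)
print(mixed.numpy())                        # -> [1.0, 0.0, 0.0]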
mt3/models.py
ADDED
@@ -0,0 +1,152 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Feature converter and model for continuous inputs."""

from typing import Mapping
import seqio
from t5x import decoding
from t5x import models
import tensorflow as tf


class ContinuousInputsEncDecFeatureConverter(seqio.FeatureConverter):
  """Feature converter for an encoder-decoder with continuous inputs."""

  TASK_FEATURES = {
      "inputs": seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2),
      "targets": seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
  }
  MODEL_FEATURES = {
      "encoder_input_tokens":
          seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2),
      "decoder_target_tokens":
          seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
      "decoder_input_tokens":
          seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
      "decoder_loss_weights":
          seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
  }
  PACKING_FEATURE_DTYPES = {
      "encoder_segment_ids": tf.int32,
      "decoder_segment_ids": tf.int32,
      "encoder_positions": tf.int32,
      "decoder_positions": tf.int32
  }

  def _convert_features(
      self, ds: tf.data.Dataset,
      task_feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
    """Convert the dataset to be fed to the encoder-decoder model.

    The conversion process involves two steps:

    1. Each feature in the `task_feature_lengths` is trimmed/padded and
       optionally packed depending on the value of self.pack.
    2. "inputs" fields are mapped to the encoder input and "targets" are
       mapped to decoder input (after being shifted) and target.

    All the keys in the `task_feature_lengths` should be present in the input
    dataset, which may contain some extra features that are not in the
    `task_feature_lengths`. They will not be included in the output dataset.
    One common scenario is the "inputs_pretokenized" and "targets_pretokenized"
    fields.

    Args:
      ds: an input tf.data.Dataset to be converted.
      task_feature_lengths: a mapping from feature to its length.

    Returns:
      ds: the converted dataset.
    """

    def convert_example(
        features: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
      # targets_segment_id is present only for a packed dataset.
      decoder_input_tokens = seqio.autoregressive_inputs(
          features["targets"],
          sequence_id=features.get("targets_segment_ids", None))

      d = {"encoder_input_tokens": features["inputs"],
           "decoder_target_tokens": features["targets"],
           "decoder_input_tokens": decoder_input_tokens,
           # Loss is computed for all but the padding positions.
           "decoder_loss_weights":
               seqio.non_padding_position(features["targets"])}

      if self.pack:
        d["encoder_segment_ids"] = features["inputs_segment_ids"]
        d["decoder_segment_ids"] = features["targets_segment_ids"]
        d["encoder_positions"] = features["inputs_positions"]
        d["decoder_positions"] = features["targets_positions"]

      return d

    ds = self._pack_or_pad(ds, task_feature_lengths)
    return ds.map(
        convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def get_model_feature_lengths(
      self, task_feature_lengths: Mapping[str, int]) -> Mapping[str, int]:
    """Define the length relationship between input and output features."""
    encoder_length = task_feature_lengths["inputs"]
    decoder_length = task_feature_lengths["targets"]

    model_feature_lengths = {
        "encoder_input_tokens": encoder_length,
        "decoder_target_tokens": decoder_length,
        "decoder_input_tokens": decoder_length,
        "decoder_loss_weights": decoder_length
    }
    if self.pack:
      model_feature_lengths["encoder_segment_ids"] = encoder_length
      model_feature_lengths["decoder_segment_ids"] = decoder_length
      model_feature_lengths["encoder_positions"] = encoder_length
      model_feature_lengths["decoder_positions"] = decoder_length

    return model_feature_lengths


class ContinuousInputsEncoderDecoderModel(models.EncoderDecoderModel):
  """Encoder-decoder model with continuous inputs."""

  FEATURE_CONVERTER_CLS = ContinuousInputsEncDecFeatureConverter

  def __init__(self, module, input_vocabulary, output_vocabulary, optimizer_def,
               input_depth, decode_fn=decoding.beam_search, label_smoothing=0.0,
               z_loss=0.0, loss_normalizing_factor=None):
    super().__init__(
        module=module,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        optimizer_def=optimizer_def,
        decode_fn=decode_fn,
        label_smoothing=label_smoothing,
        z_loss=z_loss,
        loss_normalizing_factor=loss_normalizing_factor)
    self._input_depth = input_depth

  def get_initial_variables(self, rng, input_shapes, input_types=None):
    """Hacky override to bypass eval/infer inability to handle rank-3 inputs."""
    encoder_shape = input_shapes["encoder_input_tokens"]
    if len(encoder_shape) == 2:
      input_shapes = {
          "encoder_input_tokens": (*encoder_shape, self._input_depth),
          **{k: v for k, v in input_shapes.items()
             if k != "encoder_input_tokens"}
      }
    else:
      assert encoder_shape[-1] == self._input_depth
    return super().get_initial_variables(
        rng=rng, input_shapes=input_shapes, input_types=input_types)
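
As a usage sketch (not from this commit), the feature converter above can be applied to a single unpacked example as follows; the toy shapes and values are assumptions, but they show the shifted decoder inputs and padding-aware loss weights it produces.

import seqio
import tensorflow as tf
from mt3 import models as mt3_models

ds = tf.data.Dataset.from_tensors({
    'inputs': tf.zeros([4, 2], tf.float32),             # [frames, depth]
    'targets': tf.constant([11, 12, 13, 1], tf.int32),  # ends in EOS id 1
})
converter = mt3_models.ContinuousInputsEncDecFeatureConverter(pack=False)
converted = converter(ds, task_feature_lengths={'inputs': 4, 'targets': 6})
ex = next(iter(converted))
print(ex['decoder_target_tokens'].numpy())  # [11 12 13  1  0  0]
print(ex['decoder_input_tokens'].numpy())   # [ 0 11 12 13  1  0]
print(ex['decoder_loss_weights'].numpy())   # [ 1  1  1  1  0  0]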
mt3/network.py
ADDED
@@ -0,0 +1,409 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5.1.1 Transformer model."""

from typing import Any, Sequence

from flax import linen as nn
from flax import struct
import jax.numpy as jnp
from mt3 import layers


@struct.dataclass
class T5Config:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  vocab_size: int
  # Activation dtypes.
  dtype: Any = jnp.float32
  emb_dim: int = 512
  num_heads: int = 8
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  head_dim: int = 64
  mlp_dim: int = 2048
  # Activation functions are retrieved from Flax.
  mlp_activations: Sequence[str] = ('relu',)
  dropout_rate: float = 0.1
  # If `True`, the embedding weights are used in the decoder output layer.
  logits_via_embedding: bool = False


class EncoderLayer(nn.Module):
  """Transformer encoder layer."""
  config: T5Config

  @nn.compact
  def __call__(self, inputs, encoder_mask=None, deterministic=False):
    cfg = self.config

    # Attention block.
    assert inputs.ndim == 3
    x = layers.LayerNorm(
        dtype=cfg.dtype, name='pre_attention_layer_norm')(
            inputs)
    # [batch, length, emb_dim] -> [batch, length, emb_dim]
    x = layers.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        name='attention')(
            x, x, encoder_mask, deterministic=deterministic)
    x = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x)
    # [batch, length, emb_dim] -> [batch, length, emb_dim]
    y = layers.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        name='mlp',
    )(y, deterministic=deterministic)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)
    y = y + x

    return y


class DecoderLayer(nn.Module):
  """Transformer decoder layer that attends to the encoder."""
  config: T5Config

  @nn.compact
  def __call__(self,
               inputs,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None,
               deterministic=False,
               decode=False,
               max_decode_length=None):
    cfg = self.config

    # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
    x = layers.LayerNorm(
        dtype=cfg.dtype, name='pre_self_attention_layer_norm')(
            inputs)

    # Self-attention block
    x = layers.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        name='self_attention')(
            x,
            x,
            decoder_mask,
            deterministic=deterministic,
            decode=decode)
    x = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)
    x = x + inputs

    # Encoder-Decoder block.
    y = layers.LayerNorm(
        dtype=cfg.dtype, name='pre_cross_attention_layer_norm')(
            x)
    y = layers.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        name='encoder_decoder_attention')(
            y, encoded, encoder_decoder_mask, deterministic=deterministic)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)
    y = y + x

    # MLP block.
    z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y)
    z = layers.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        name='mlp',
    )(z, deterministic=deterministic)
    z = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            z, deterministic=deterministic)
    z = z + y

    return z


class Encoder(nn.Module):
  """A stack of encoder layers."""
  config: T5Config

  @nn.compact
  def __call__(self,
               encoder_input_tokens,
               encoder_mask=None,
               deterministic=False):
    cfg = self.config
    assert encoder_input_tokens.ndim == 3  # [batch, length, depth]

    seq_length = encoder_input_tokens.shape[-2]
    inputs_positions = jnp.arange(seq_length)[None, :]

    # [batch, length, depth] -> [batch, length, emb_dim]
    x = layers.DenseGeneral(
        cfg.emb_dim,
        dtype=cfg.dtype,
        kernel_init=nn.linear.default_kernel_init,
        kernel_axes=('vocab', 'embed'),
        name='continuous_inputs_projection')(encoder_input_tokens)
    x = x + layers.FixedEmbed(features=cfg.emb_dim)(inputs_positions)
    x = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            x, deterministic=deterministic)
    x = x.astype(cfg.dtype)

    for lyr in range(cfg.num_encoder_layers):
      # [batch, length, emb_dim] -> [batch, length, emb_dim]
      x = EncoderLayer(
          config=cfg,
          name=f'layers_{lyr}')(x, encoder_mask, deterministic)

    x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
    return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)


class Decoder(nn.Module):
  """A stack of decoder layers as a part of an encoder-decoder architecture."""
  config: T5Config

  @nn.compact
  def __call__(self,
               encoded,
               decoder_input_tokens,
               decoder_positions=None,
               decoder_mask=None,
               encoder_decoder_mask=None,
               deterministic=False,
               decode=False,
               max_decode_length=None):
    cfg = self.config
    assert decoder_input_tokens.ndim == 2  # [batch, len]

    seq_length = decoder_input_tokens.shape[-1]
    decoder_positions = jnp.arange(seq_length)[None, :]

    # [batch, length] -> [batch, length, emb_dim]
    y = layers.Embed(
        num_embeddings=cfg.vocab_size,
        features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        one_hot=True,
        name='token_embedder')(decoder_input_tokens.astype('int32'))
    y = y + layers.FixedEmbed(features=cfg.emb_dim)(
        decoder_positions, decode=decode)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)
    y = y.astype(cfg.dtype)

    for lyr in range(cfg.num_decoder_layers):
      # [batch, length, emb_dim] -> [batch, length, emb_dim]
      y = DecoderLayer(
          config=cfg, name=f'layers_{lyr}')(
              y,
              encoded,
              decoder_mask=decoder_mask,
              encoder_decoder_mask=encoder_decoder_mask,
              deterministic=deterministic,
              decode=decode,
              max_decode_length=max_decode_length)

    y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y)
    y = nn.Dropout(
        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
            y, deterministic=deterministic)

    # [batch, length, emb_dim] -> [batch, length, vocab_size]
    if cfg.logits_via_embedding:
      # Use the transpose of embedding matrix for logit transform.
      logits = self.shared_embedding.attend(y)
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])
    else:
      logits = layers.DenseGeneral(
          cfg.vocab_size,
          dtype=jnp.float32,  # Use float32 for stability.
          kernel_axes=('embed', 'vocab'),
          name='logits_dense')(
              y)
    return logits


class Transformer(nn.Module):
  """An encoder-decoder Transformer model."""
  config: T5Config

  def setup(self):
    cfg = self.config

    self.encoder = Encoder(config=cfg)
    self.decoder = Decoder(config=cfg)

  def encode(self,
             encoder_input_tokens,
             encoder_segment_ids=None,
             enable_dropout=True):
    """Applies Transformer encoder-branch on the inputs."""
    cfg = self.config
    assert encoder_input_tokens.ndim == 3  # (batch, length, depth)

    # Make padding attention mask; we don't actually mask out any input
    # positions, letting the model potentially attend to the zero vector used
    # as padding.
    encoder_mask = layers.make_attention_mask(
        jnp.ones(encoder_input_tokens.shape[:-1]),
        jnp.ones(encoder_input_tokens.shape[:-1]),
        dtype=cfg.dtype)
    # Add segmentation block-diagonal attention mask if using segmented data.
    if encoder_segment_ids is not None:
      encoder_mask = layers.combine_masks(
          encoder_mask,
          layers.make_attention_mask(
              encoder_segment_ids,
              encoder_segment_ids,
              jnp.equal,
              dtype=cfg.dtype))

    return self.encoder(
        encoder_input_tokens, encoder_mask, deterministic=not enable_dropout)

  def decode(
      self,
      encoded,
      encoder_input_tokens,  # only needed for masks
      decoder_input_tokens,
      decoder_target_tokens,
      encoder_segment_ids=None,
      decoder_segment_ids=None,
      decoder_positions=None,
      enable_dropout=True,
      decode=False,
      max_decode_length=None):
    """Applies Transformer decoder-branch on encoded-input and target."""
    cfg = self.config

    # Make padding attention masks.
    if decode:
      # Do not mask decoder attention based on targets padding at
      # decoding/inference time.
      decoder_mask = None
      encoder_decoder_mask = layers.make_attention_mask(
          jnp.ones_like(decoder_target_tokens),
          jnp.ones(encoder_input_tokens.shape[:-1]),
          dtype=cfg.dtype)
    else:
      decoder_mask = layers.make_decoder_mask(
          decoder_target_tokens=decoder_target_tokens,
          dtype=cfg.dtype,
          decoder_segment_ids=decoder_segment_ids)
      encoder_decoder_mask = layers.make_attention_mask(
          decoder_target_tokens > 0,
          jnp.ones(encoder_input_tokens.shape[:-1]),
          dtype=cfg.dtype)

    # Add segmentation block-diagonal attention masks if using segmented data.
    if encoder_segment_ids is not None:
      if decode:
        raise ValueError(
            'During decoding, packing should not be used but '
            '`encoder_segment_ids` was passed to `Transformer.decode`.')

      encoder_decoder_mask = layers.combine_masks(
          encoder_decoder_mask,
          layers.make_attention_mask(
              decoder_segment_ids,
              encoder_segment_ids,
              jnp.equal,
              dtype=cfg.dtype))

    logits = self.decoder(
        encoded,
        decoder_input_tokens=decoder_input_tokens,
        decoder_positions=decoder_positions,
        decoder_mask=decoder_mask,
        encoder_decoder_mask=encoder_decoder_mask,
        deterministic=not enable_dropout,
        decode=decode,
        max_decode_length=max_decode_length)
    return logits.astype(self.config.dtype)

  def __call__(self,
               encoder_input_tokens,
               decoder_input_tokens,
               decoder_target_tokens,
               encoder_segment_ids=None,
               decoder_segment_ids=None,
               encoder_positions=None,
               decoder_positions=None,
               *,
               enable_dropout: bool = True,
               decode: bool = False):
    """Applies Transformer model on the inputs.

    This method requires both decoder_target_tokens and decoder_input_tokens,
    which is a shifted version of the former. For a packed dataset, it usually
    has additional processing applied. For example, the first element of each
    sequence has id 0 instead of the shifted EOS id from the previous sequence.

    Args:
      encoder_input_tokens: input data to the encoder.
      decoder_input_tokens: input token to the decoder.
      decoder_target_tokens: target token to the decoder.
      encoder_segment_ids: encoder segmentation info for packed examples.
      decoder_segment_ids: decoder segmentation info for packed examples.
      encoder_positions: encoder subsequence positions for packed examples.
      decoder_positions: decoder subsequence positions for packed examples.
      enable_dropout: Enables dropout if set to True.
      decode: Whether to prepare and use an autoregressive cache.

    Returns:
      logits array from full transformer.
    """
    encoded = self.encode(
        encoder_input_tokens,
        encoder_segment_ids=encoder_segment_ids,
        enable_dropout=enable_dropout)

    return self.decode(
        encoded,
        encoder_input_tokens,  # only used for masks
        decoder_input_tokens,
        decoder_target_tokens,
        encoder_segment_ids=encoder_segment_ids,
        decoder_segment_ids=decoder_segment_ids,
        decoder_positions=decoder_positions,
        enable_dropout=enable_dropout,
        decode=decode)
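
For orientation, a hypothetical shape check (not part of this commit) for the continuous-inputs Transformer: the encoder consumes [batch, length, depth] floats (e.g. spectrogram frames) rather than token IDs. The config values here are deliberately tiny, not the released MT3 sizes.

import jax
import jax.numpy as jnp
from mt3 import network

cfg = network.T5Config(vocab_size=1536, emb_dim=32, num_heads=2,
                       num_encoder_layers=1, num_decoder_layers=1,
                       head_dim=16, mlp_dim=64)
model = network.Transformer(config=cfg)
enc_in = jnp.zeros([1, 8, 4])          # [batch, frames, depth]
dec_in = jnp.zeros([1, 6], jnp.int32)  # shifted targets
dec_tgt = jnp.ones([1, 6], jnp.int32)
variables = model.init(jax.random.PRNGKey(0),
                       encoder_input_tokens=enc_in,
                       decoder_input_tokens=dec_in,
                       decoder_target_tokens=dec_tgt,
                       enable_dropout=False)
logits = model.apply(variables,
                     encoder_input_tokens=enc_in,
                     decoder_input_tokens=dec_in,
                     decoder_target_tokens=dec_tgt,
                     enable_dropout=False)
print(logits.shape)  # (1, 6, 1536)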
mt3/note_sequences.py
ADDED
@@ -0,0 +1,446 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions that operate on NoteSequence protos."""

import dataclasses
import itertools

from typing import MutableMapping, MutableSet, Optional, Sequence, Tuple

from mt3 import event_codec
from mt3 import run_length_encoding
from mt3 import vocabularies

import note_seq

DEFAULT_VELOCITY = 100
DEFAULT_NOTE_DURATION = 0.01

# Quantization can result in zero-length notes; enforce a minimum duration.
MIN_NOTE_DURATION = 0.01


@dataclasses.dataclass
class TrackSpec:
  name: str
  program: int = 0
  is_drum: bool = False


def extract_track(ns, program, is_drum):
  track = note_seq.NoteSequence(ticks_per_quarter=220)
  track_notes = [note for note in ns.notes
                 if note.program == program and note.is_drum == is_drum]
  track.notes.extend(track_notes)
  track.total_time = (max(note.end_time for note in track.notes)
                      if track.notes else 0.0)
  return track


def trim_overlapping_notes(ns: note_seq.NoteSequence) -> note_seq.NoteSequence:
  """Trim overlapping notes from a NoteSequence, dropping zero-length notes."""
  ns_trimmed = note_seq.NoteSequence()
  ns_trimmed.CopyFrom(ns)
  channels = set((note.pitch, note.program, note.is_drum)
                 for note in ns_trimmed.notes)
  for pitch, program, is_drum in channels:
    notes = [note for note in ns_trimmed.notes if note.pitch == pitch
             and note.program == program and note.is_drum == is_drum]
    sorted_notes = sorted(notes, key=lambda note: note.start_time)
    for i in range(1, len(sorted_notes)):
      if sorted_notes[i - 1].end_time > sorted_notes[i].start_time:
        sorted_notes[i - 1].end_time = sorted_notes[i].start_time
  valid_notes = [note for note in ns_trimmed.notes
                 if note.start_time < note.end_time]
  del ns_trimmed.notes[:]
  ns_trimmed.notes.extend(valid_notes)
  return ns_trimmed


def assign_instruments(ns: note_seq.NoteSequence) -> None:
  """Assign instrument numbers to notes; modifies NoteSequence in place."""
  program_instruments = {}
  for note in ns.notes:
    if note.program not in program_instruments and not note.is_drum:
      num_instruments = len(program_instruments)
      note.instrument = (num_instruments if num_instruments < 9
                         else num_instruments + 1)
      program_instruments[note.program] = note.instrument
    elif note.is_drum:
      note.instrument = 9
    else:
      note.instrument = program_instruments[note.program]


def validate_note_sequence(ns: note_seq.NoteSequence) -> None:
  """Raise ValueError if NoteSequence contains invalid notes."""
  for note in ns.notes:
    if note.start_time >= note.end_time:
      raise ValueError('note has start time >= end time: %f >= %f' %
                       (note.start_time, note.end_time))
    if note.velocity == 0:
      raise ValueError('note has zero velocity')


def note_arrays_to_note_sequence(
    onset_times: Sequence[float],
    pitches: Sequence[int],
    offset_times: Optional[Sequence[float]] = None,
    velocities: Optional[Sequence[int]] = None,
    programs: Optional[Sequence[int]] = None,
    is_drums: Optional[Sequence[bool]] = None
) -> note_seq.NoteSequence:
  """Convert note onset / offset / pitch / velocity arrays to NoteSequence."""
  ns = note_seq.NoteSequence(ticks_per_quarter=220)
  for onset_time, offset_time, pitch, velocity, program, is_drum in itertools.zip_longest(
      onset_times, [] if offset_times is None else offset_times,
      pitches, [] if velocities is None else velocities,
      [] if programs is None else programs,
      [] if is_drums is None else is_drums):
    if offset_time is None:
      offset_time = onset_time + DEFAULT_NOTE_DURATION
    if velocity is None:
      velocity = DEFAULT_VELOCITY
    if program is None:
      program = 0
    if is_drum is None:
      is_drum = False
    ns.notes.add(
        start_time=onset_time,
        end_time=offset_time,
        pitch=pitch,
        velocity=velocity,
        program=program,
        is_drum=is_drum)
    ns.total_time = max(ns.total_time, offset_time)
  assign_instruments(ns)
  return ns


@dataclasses.dataclass
class NoteEventData:
  pitch: int
  velocity: Optional[int] = None
  program: Optional[int] = None
  is_drum: Optional[bool] = None
  instrument: Optional[int] = None


def note_sequence_to_onsets(
    ns: note_seq.NoteSequence
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract note onsets and pitches from NoteSequence proto."""
  # Sort by pitch to use as a tiebreaker for subsequent stable sort.
  notes = sorted(ns.notes, key=lambda note: note.pitch)
  return ([note.start_time for note in notes],
          [NoteEventData(pitch=note.pitch) for note in notes])


def note_sequence_to_onsets_and_offsets(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract onset & offset times and pitches from a NoteSequence proto.

  The onset & offset times will not necessarily be in sorted order.

  Args:
    ns: NoteSequence from which to extract onsets and offsets.

  Returns:
    times: A list of note onset and offset times.
    values: A list of NoteEventData objects where velocity is zero for note
      offsets.
  """
  # Sort by pitch and put offsets before onsets as a tiebreaker for subsequent
  # stable sort.
  notes = sorted(ns.notes, key=lambda note: note.pitch)
  times = ([note.end_time for note in notes] +
           [note.start_time for note in notes])
  values = ([NoteEventData(pitch=note.pitch, velocity=0) for note in notes] +
            [NoteEventData(pitch=note.pitch, velocity=note.velocity)
             for note in notes])
  return times, values


def note_sequence_to_onsets_and_offsets_and_programs(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
  """Extract onset & offset times and pitches & programs from a NoteSequence.

  The onset & offset times will not necessarily be in sorted order.

  Args:
    ns: NoteSequence from which to extract onsets and offsets.

  Returns:
    times: A list of note onset and offset times.
    values: A list of NoteEventData objects where velocity is zero for note
      offsets.
  """
  # Sort by program and pitch and put offsets before onsets as a tiebreaker for
  # subsequent stable sort.
  notes = sorted(ns.notes,
                 key=lambda note: (note.is_drum, note.program, note.pitch))
  times = ([note.end_time for note in notes if not note.is_drum] +
           [note.start_time for note in notes])
  values = ([NoteEventData(pitch=note.pitch, velocity=0,
                           program=note.program, is_drum=False)
             for note in notes if not note.is_drum] +
            [NoteEventData(pitch=note.pitch, velocity=note.velocity,
                           program=note.program, is_drum=note.is_drum)
             for note in notes])
  return times, values


@dataclasses.dataclass
class NoteEncodingState:
  """Encoding state for note transcription, keeping track of active pitches."""
  # velocity bin for active pitches and programs
  active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(
      default_factory=dict)


def note_event_data_to_events(
    state: Optional[NoteEncodingState],
    value: NoteEventData,
    codec: event_codec.Codec,
) -> Sequence[event_codec.Event]:
  """Convert note event data to a sequence of events."""
  if value.velocity is None:
    # onsets only, no program or velocity
    return [event_codec.Event('pitch', value.pitch)]
  else:
    num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
    velocity_bin = vocabularies.velocity_to_bin(
        value.velocity, num_velocity_bins)
    if value.program is None:
      # onsets + offsets + velocities only, no programs
      if state is not None:
        state.active_pitches[(value.pitch, 0)] = velocity_bin
      return [event_codec.Event('velocity', velocity_bin),
              event_codec.Event('pitch', value.pitch)]
    else:
      if value.is_drum:
        # drum events use a separate vocabulary
        return [event_codec.Event('velocity', velocity_bin),
                event_codec.Event('drum', value.pitch)]
      else:
        # program + velocity + pitch
        if state is not None:
          state.active_pitches[(value.pitch, value.program)] = velocity_bin
        return [event_codec.Event('program', value.program),
                event_codec.Event('velocity', velocity_bin),
                event_codec.Event('pitch', value.pitch)]


def note_encoding_state_to_events(
    state: NoteEncodingState
) -> Sequence[event_codec.Event]:
  """Output program and pitch events for active notes plus a final tie event."""
  events = []
  for pitch, program in sorted(
      state.active_pitches.keys(), key=lambda k: k[::-1]):
    if state.active_pitches[(pitch, program)]:
      events += [event_codec.Event('program', program),
                 event_codec.Event('pitch', pitch)]
  events.append(event_codec.Event('tie', 0))
  return events


@dataclasses.dataclass
class NoteDecodingState:
  """Decoding state for note transcription."""
  current_time: float = 0.0
  # velocity to apply to subsequent pitch events (zero for note-off)
  current_velocity: int = DEFAULT_VELOCITY
  # program to apply to subsequent pitch events
  current_program: int = 0
  # onset time and velocity for active pitches and programs
  active_pitches: MutableMapping[Tuple[int, int],
                                 Tuple[float, int]] = dataclasses.field(
                                     default_factory=dict)
  # pitches (with programs) to continue from previous segment
  tied_pitches: MutableSet[Tuple[int, int]] = dataclasses.field(
      default_factory=set)
  # whether or not we are in the tie section at the beginning of a segment
  is_tie_section: bool = False
  # partially-decoded NoteSequence
  note_sequence: note_seq.NoteSequence = dataclasses.field(
      default_factory=lambda: note_seq.NoteSequence(ticks_per_quarter=220))


def decode_note_onset_event(
    state: NoteDecodingState,
    time: float,
    event: event_codec.Event,
    codec: event_codec.Codec,
) -> None:
  """Process note onset event and update decoding state."""
  if event.type == 'pitch':
    state.note_sequence.notes.add(
        start_time=time, end_time=time + DEFAULT_NOTE_DURATION,
        pitch=event.value, velocity=DEFAULT_VELOCITY)
    state.note_sequence.total_time = max(state.note_sequence.total_time,
                                         time + DEFAULT_NOTE_DURATION)
  else:
    raise ValueError('unexpected event type: %s' % event.type)


def _add_note_to_sequence(
    ns: note_seq.NoteSequence,
    start_time: float, end_time: float, pitch: int, velocity: int,
    program: int = 0, is_drum: bool = False
) -> None:
  end_time = max(end_time, start_time + MIN_NOTE_DURATION)
  ns.notes.add(
      start_time=start_time, end_time=end_time,
      pitch=pitch, velocity=velocity, program=program, is_drum=is_drum)
  ns.total_time = max(ns.total_time, end_time)


def decode_note_event(
    state: NoteDecodingState,
    time: float,
    event: event_codec.Event,
    codec: event_codec.Codec
) -> None:
  """Process note event and update decoding state."""
  if time < state.current_time:
    raise ValueError('event time < current time, %f < %f' % (
        time, state.current_time))
  state.current_time = time
  if event.type == 'pitch':
    pitch = event.value
    if state.is_tie_section:
      # "tied" pitch
      if (pitch, state.current_program) not in state.active_pitches:
        raise ValueError('inactive pitch/program in tie section: %d/%d' %
                         (pitch, state.current_program))
      if (pitch, state.current_program) in state.tied_pitches:
        raise ValueError('pitch/program is already tied: %d/%d' %
                         (pitch, state.current_program))
      state.tied_pitches.add((pitch, state.current_program))
    elif state.current_velocity == 0:
      # note offset
      if (pitch, state.current_program) not in state.active_pitches:
        raise ValueError('note-off for inactive pitch/program: %d/%d' %
                         (pitch, state.current_program))
      onset_time, onset_velocity = state.active_pitches.pop(
          (pitch, state.current_program))
      _add_note_to_sequence(
          state.note_sequence, start_time=onset_time, end_time=time,
          pitch=pitch, velocity=onset_velocity, program=state.current_program)
    else:
      # note onset
      if (pitch, state.current_program) in state.active_pitches:
        # The pitch is already active; this shouldn't really happen but we'll
        # try to handle it gracefully by ending the previous note and starting
        # a new one.
        onset_time, onset_velocity = state.active_pitches.pop(
            (pitch, state.current_program))
        _add_note_to_sequence(
            state.note_sequence, start_time=onset_time, end_time=time,
            pitch=pitch, velocity=onset_velocity,
            program=state.current_program)
      state.active_pitches[(pitch, state.current_program)] = (
          time, state.current_velocity)
  elif event.type == 'drum':
    # drum onset (drums have no offset)
    if state.current_velocity == 0:
      raise ValueError('velocity cannot be zero for drum event')
    offset_time = time + DEFAULT_NOTE_DURATION
    _add_note_to_sequence(
        state.note_sequence, start_time=time, end_time=offset_time,
        pitch=event.value, velocity=state.current_velocity, is_drum=True)
  elif event.type == 'velocity':
    # velocity change
    num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
    velocity = vocabularies.bin_to_velocity(event.value, num_velocity_bins)
    state.current_velocity = velocity
  elif event.type == 'program':
    # program change
    state.current_program = event.value
  elif event.type == 'tie':
    # end of tie section; end active notes that weren't declared tied
    if not state.is_tie_section:
      raise ValueError('tie section end event when not in tie section')
    for (pitch, program) in list(state.active_pitches.keys()):
      if (pitch, program) not in state.tied_pitches:
        onset_time, onset_velocity = state.active_pitches.pop((pitch, program))
        _add_note_to_sequence(
            state.note_sequence,
            start_time=onset_time, end_time=state.current_time,
            pitch=pitch, velocity=onset_velocity, program=program)
    state.is_tie_section = False
  else:
    raise ValueError('unexpected event type: %s' % event.type)


def begin_tied_pitches_section(state: NoteDecodingState) -> None:
  """Begin the tied pitches section at the start of a segment."""
  state.tied_pitches = set()
  state.is_tie_section = True


def flush_note_decoding_state(
    state: NoteDecodingState
) -> note_seq.NoteSequence:
  """End all active notes and return resulting NoteSequence."""
  for onset_time, _ in state.active_pitches.values():
    state.current_time = max(state.current_time, onset_time + MIN_NOTE_DURATION)
  for (pitch, program) in list(state.active_pitches.keys()):
    onset_time, onset_velocity = state.active_pitches.pop((pitch, program))
    _add_note_to_sequence(
        state.note_sequence, start_time=onset_time,
        end_time=state.current_time,
        pitch=pitch, velocity=onset_velocity, program=program)
  assign_instruments(state.note_sequence)
  return state.note_sequence


class NoteEncodingSpecType(run_length_encoding.EventEncodingSpec):
  pass


# encoding spec for modeling note onsets only
NoteOnsetEncodingSpec = NoteEncodingSpecType(
    init_encoding_state_fn=lambda: None,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=None,
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=lambda state: None,
    decode_event_fn=decode_note_onset_event,
    flush_decoding_state_fn=lambda state: state.note_sequence)


# encoding spec for modeling onsets and offsets
NoteEncodingSpec = NoteEncodingSpecType(
    init_encoding_state_fn=lambda: None,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=None,
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=lambda state: None,
    decode_event_fn=decode_note_event,
    flush_decoding_state_fn=flush_note_decoding_state)


# encoding spec for modeling onsets and offsets, with a "tie" section at the
# beginning of each segment listing already-active notes
NoteEncodingWithTiesSpec = NoteEncodingSpecType(
    init_encoding_state_fn=NoteEncodingState,
    encode_event_fn=note_event_data_to_events,
    encoding_state_to_events_fn=note_encoding_state_to_events,
    init_decoding_state_fn=NoteDecodingState,
    begin_decoding_segment_fn=begin_tied_pitches_section,
    decode_event_fn=decode_note_event,
    flush_decoding_state_fn=flush_note_decoding_state)
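
A small illustrative example (not part of this commit) of trim_overlapping_notes: when two notes on the same pitch/program overlap, the earlier note is trimmed back to the later note's onset.

import note_seq
from mt3 import note_sequences

ns = note_seq.NoteSequence(ticks_per_quarter=220)
ns.notes.add(pitch=60, velocity=100, start_time=0.0, end_time=1.0)
ns.notes.add(pitch=60, velocity=100, start_time=0.5, end_time=1.5)
trimmed = note_sequences.trim_overlapping_notes(ns)
for note in trimmed.notes:
  print(note.start_time, note.end_time)
# 0.0 0.5
# 0.5 1.5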
mt3/note_sequences_test.py
ADDED
@@ -0,0 +1,505 @@
1 |
+
# Copyright 2022 The MT3 Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for note_sequences."""
|
16 |
+
|
17 |
+
from mt3 import event_codec
|
18 |
+
from mt3 import note_sequences
|
19 |
+
from mt3 import run_length_encoding
|
20 |
+
|
21 |
+
import note_seq
|
22 |
+
import numpy as np
|
23 |
+
import tensorflow as tf
|
24 |
+
|
25 |
+
codec = event_codec.Codec(
|
26 |
+
max_shift_steps=100,
|
27 |
+
steps_per_second=100,
|
28 |
+
event_ranges=[
|
29 |
+
event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
|
30 |
+
note_seq.MAX_MIDI_PITCH),
|
31 |
+
event_codec.EventRange('velocity', 0, 127),
|
32 |
+
event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
|
33 |
+
note_seq.MAX_MIDI_PITCH),
|
34 |
+
event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
|
35 |
+
note_seq.MAX_MIDI_PROGRAM),
|
36 |
+
event_codec.EventRange('tie', 0, 0)
|
37 |
+
])
|
38 |
+
|
39 |
+
|


class RunLengthEncodingTest(tf.test.TestCase):

  def test_encode_and_index_note_sequence(self):
    ns = note_seq.NoteSequence()
    ns.notes.add(start_time=1.0,
                 end_time=1.1,
                 pitch=61,
                 velocity=100)
    ns.notes.add(start_time=2.0,
                 end_time=2.1,
                 pitch=62,
                 velocity=100)
    ns.notes.add(start_time=3.0,
                 end_time=3.1,
                 pitch=63,
                 velocity=100)
    ns.total_time = ns.notes[-1].end_time

    frame_times = np.arange(0, 4, step=.001)

    event_times, event_values = note_sequences.note_sequence_to_onsets(ns)
    events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events(
        state=None, event_times=event_times, event_values=event_values,
        encode_event_fn=note_sequences.note_event_data_to_events,
        codec=codec, frame_times=frame_times)

    self.assertEqual(len(frame_times), len(event_start_indices))
    self.assertEqual(len(frame_times), len(event_end_indices))
    self.assertLen(events, 403)
    expected_events = ([1] * 100 +
                       [162] +
                       [1] * 100 +
                       [163] +
                       [1] * 100 +
                       [164] +
                       [1] * 100)
    np.testing.assert_array_equal(expected_events, events)

    self.assertEqual(event_start_indices[0], 0)
    self.assertEqual(event_end_indices[0], 0)

    self.assertEqual(162, events[100])
    self.assertEqual(1.0, frame_times[1000])
    self.assertEqual(event_start_indices[1000], 100)
    self.assertEqual(event_end_indices[1000], 100)

    self.assertEqual(163, events[201])
    self.assertEqual(2.0, frame_times[2000])
    self.assertEqual(event_start_indices[2000], 201)
    self.assertEqual(event_end_indices[2000], 201)

    self.assertEqual(164, events[302])
    self.assertEqual(3.0, frame_times[3000])
    self.assertEqual(event_start_indices[3000], 302)
    self.assertEqual(event_end_indices[3000], 302)

    self.assertEqual(1, events[-1])
    self.assertEqual(3.999, frame_times[-1])
    self.assertEqual(event_start_indices[-1], 402)
    self.assertEqual(event_end_indices[-1], len(expected_events))

  def test_encode_and_index_note_sequence_velocity(self):
    ns = note_seq.NoteSequence()
    ns.notes.add(start_time=1.0,
                 end_time=3.0,
                 pitch=61,
                 velocity=1)
    ns.notes.add(start_time=2.0,
                 end_time=4.0,
                 pitch=62,
                 velocity=127)
    ns.total_time = ns.notes[-1].end_time

    frame_times = np.arange(0, 4, step=.001)

    event_times, event_values = (
        note_sequences.note_sequence_to_onsets_and_offsets(ns))
    events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events(
        state=None, event_times=event_times, event_values=event_values,
        encode_event_fn=note_sequences.note_event_data_to_events,
        codec=codec, frame_times=frame_times)

    self.assertEqual(len(frame_times), len(event_start_indices))
    self.assertEqual(len(frame_times), len(event_end_indices))
    self.assertLen(events, 408)
    expected_events = ([1] * 100 +
                       [230, 162] +
                       [1] * 100 +
                       [356, 163] +
                       [1] * 100 +
                       [229, 162] +
                       [1] * 100 +
                       [229, 163])
    np.testing.assert_array_equal(expected_events, events)

    self.assertEqual(event_start_indices[0], 0)
    self.assertEqual(event_end_indices[0], 0)

    self.assertEqual(230, events[100])
    self.assertEqual(162, events[101])
    self.assertEqual(1.0, frame_times[1000])
    self.assertEqual(event_start_indices[1000], 100)
    self.assertEqual(event_end_indices[1000], 100)

    self.assertEqual(356, events[202])
    self.assertEqual(163, events[203])
    self.assertEqual(2.0, frame_times[2000])
    self.assertEqual(event_start_indices[2000], 202)
    self.assertEqual(event_end_indices[2000], 202)

    self.assertEqual(229, events[304])
    self.assertEqual(162, events[305])
    self.assertEqual(3.0, frame_times[3000])
    self.assertEqual(event_start_indices[3000], 304)
    self.assertEqual(event_end_indices[3000], 304)

    self.assertEqual(229, events[406])
    self.assertEqual(163, events[407])
    self.assertEqual(3.999, frame_times[-1])
    self.assertEqual(event_start_indices[-1], 405)
    self.assertEqual(event_end_indices[-1], len(expected_events))

  def test_encode_and_index_note_sequence_multitrack(self):
    ns = note_seq.NoteSequence()
    ns.notes.add(start_time=0.0,
                 end_time=1.0,
                 pitch=37,
                 velocity=127,
                 is_drum=True)
    ns.notes.add(start_time=1.0,
                 end_time=3.0,
                 pitch=61,
                 velocity=127,
                 program=0)
    ns.notes.add(start_time=2.0,
                 end_time=4.0,
                 pitch=62,
                 velocity=127,
                 program=40)
    ns.total_time = ns.notes[-1].end_time

    frame_times = np.arange(0, 4, step=.001)

    event_times, event_values = (
        note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))
    (tokens, event_start_indices, event_end_indices, state_tokens,
     state_event_indices) = run_length_encoding.encode_and_index_events(
         state=note_sequences.NoteEncodingState(),
         event_times=event_times, event_values=event_values,
         encode_event_fn=note_sequences.note_event_data_to_events,
         codec=codec, frame_times=frame_times,
         encoding_state_to_events_fn=(
             note_sequences.note_encoding_state_to_events))

    self.assertEqual(len(frame_times), len(event_start_indices))
    self.assertEqual(len(frame_times), len(event_end_indices))
    self.assertEqual(len(frame_times), len(state_event_indices))
    self.assertLen(tokens, 414)

    expected_events = (
        [event_codec.Event('velocity', 127), event_codec.Event('drum', 37)] +
        [event_codec.Event('shift', 1)] * 100 +
        [event_codec.Event('program', 0),
         event_codec.Event('velocity', 127), event_codec.Event('pitch', 61)] +
        [event_codec.Event('shift', 1)] * 100 +
        [event_codec.Event('program', 40),
         event_codec.Event('velocity', 127), event_codec.Event('pitch', 62)] +
        [event_codec.Event('shift', 1)] * 100 +
        [event_codec.Event('program', 0),
         event_codec.Event('velocity', 0), event_codec.Event('pitch', 61)] +
        [event_codec.Event('shift', 1)] * 100 +
        [event_codec.Event('program', 40),
         event_codec.Event('velocity', 0), event_codec.Event('pitch', 62)])
    expected_tokens = [codec.encode_event(e) for e in expected_events]
    np.testing.assert_array_equal(expected_tokens, tokens)

    expected_state_events = [
        event_codec.Event('tie', 0),       # state prior to first drum
        event_codec.Event('tie', 0),       # state prior to first onset
        event_codec.Event('program', 0),   # state prior to second onset
        event_codec.Event('pitch', 61),    # |
        event_codec.Event('tie', 0),       # |
        event_codec.Event('program', 0),   # state prior to first offset
        event_codec.Event('pitch', 61),    # |
        event_codec.Event('program', 40),  # |
        event_codec.Event('pitch', 62),    # |
        event_codec.Event('tie', 0),       # |
        event_codec.Event('program', 40),  # state prior to second offset
        event_codec.Event('pitch', 62),    # |
        event_codec.Event('tie', 0)        # |
    ]
    expected_state_tokens = [codec.encode_event(e)
                             for e in expected_state_events]
    np.testing.assert_array_equal(expected_state_tokens, state_tokens)

    self.assertEqual(event_start_indices[0], 0)
    self.assertEqual(event_end_indices[0], 0)
    self.assertEqual(state_event_indices[0], 0)

    self.assertEqual(1.0, frame_times[1000])
    self.assertEqual(event_start_indices[1000], 102)
    self.assertEqual(event_end_indices[1000], 102)
    self.assertEqual(state_event_indices[1000], 1)

    self.assertEqual(2.0, frame_times[2000])
    self.assertEqual(event_start_indices[2000], 205)
    self.assertEqual(event_end_indices[2000], 205)
    self.assertEqual(state_event_indices[2000], 2)

    self.assertEqual(3.0, frame_times[3000])
    self.assertEqual(event_start_indices[3000], 308)
    self.assertEqual(event_end_indices[3000], 308)
    self.assertEqual(state_event_indices[3000], 5)

    self.assertEqual(3.999, frame_times[-1])
    self.assertEqual(event_start_indices[-1], 410)
    self.assertEqual(event_end_indices[-1], len(expected_events))
    self.assertEqual(state_event_indices[-1], 10)

  def test_encode_and_index_note_sequence_last_token_alignment(self):
    ns = note_seq.NoteSequence()
    ns.notes.add(start_time=0.0,
                 end_time=0.1,
                 pitch=60,
                 velocity=100)
    ns.total_time = ns.notes[-1].end_time

    frame_times = np.arange(0, 1.008, step=.008)

    event_times, event_values = note_sequences.note_sequence_to_onsets(ns)
    events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events(
        state=None,
        event_times=event_times,
        event_values=event_values,
        encode_event_fn=note_sequences.note_event_data_to_events,
        codec=codec,
        frame_times=frame_times)

    self.assertEqual(len(frame_times), len(event_start_indices))
    self.assertEqual(len(frame_times), len(event_end_indices))
    self.assertLen(events, 102)
    expected_events = [161] + [1] * 101

    np.testing.assert_array_equal(expected_events, events)

    self.assertEqual(event_start_indices[0], 0)
    self.assertEqual(event_end_indices[0], 0)
    self.assertEqual(event_start_indices[125], 101)
    self.assertEqual(event_end_indices[125], 102)

  def test_decode_note_sequence_events(self):
    events = [25, 161, 50, 162]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=0.25,
        end_time=0.26)
    expected_ns.notes.add(
        pitch=61,
        velocity=100,
        start_time=0.50,
        end_time=0.51)
    expected_ns.total_time = 0.51
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_onsets_only(self):
    events = [5, 161, 25, 162]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=0.05,
        end_time=0.06)
    expected_ns.notes.add(
        pitch=61,
        velocity=100,
        start_time=0.25,
        end_time=0.26)
    expected_ns.total_time = 0.26
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_velocity(self):
    events = [5, 356, 161, 25, 229, 161]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=127,
        start_time=0.05,
        end_time=0.25)
    expected_ns.total_time = 0.25
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_missing_offset(self):
    events = [5, 356, 161, 10, 161, 25, 229, 161]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=127,
        start_time=0.05,
        end_time=0.10)
    expected_ns.notes.add(
        pitch=60,
        velocity=127,
        start_time=0.10,
        end_time=0.25)
    expected_ns.total_time = 0.25
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_multitrack(self):
    events = [5, 525, 356, 161, 15, 356, 394, 25, 525, 229, 161]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=37,
        velocity=127,
        start_time=0.15,
        end_time=0.16,
        instrument=9,
        is_drum=True)
    expected_ns.notes.add(
        pitch=60,
        velocity=127,
        start_time=0.05,
        end_time=0.25,
        program=40)
    expected_ns.total_time = 0.25
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_invalid_tokens(self):
    events = [5, -1, 161, -2, 25, 162, 9999]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_events, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(3, invalid_events)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=0.05,
        end_time=0.06)
    expected_ns.notes.add(
        pitch=61,
        velocity=100,
        start_time=0.25,
        end_time=0.26)
    expected_ns.total_time = 0.26
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_allow_event_at_exactly_max_time(self):
    events = [161, 25, 162]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=1.0, max_time=1.25,
        codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=1.00,
        end_time=1.01)
    expected_ns.notes.add(
        pitch=61,
        velocity=100,
        start_time=1.25,
        end_time=1.26)
    expected_ns.total_time = 1.26
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_dropped_events(self):
    events = [5, 161, 30, 162]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=1.0, max_time=1.25,
        codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(0, invalid_ids)
    self.assertEqual(2, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=1.05,
        end_time=1.06)
    expected_ns.total_time = 1.06
    self.assertProtoEquals(expected_ns, ns)

  def test_decode_note_sequence_events_invalid_events(self):
    events = [25, 230, 50, 161]

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=events, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

    self.assertEqual(1, invalid_ids)
    self.assertEqual(0, dropped_events)
    expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
    expected_ns.notes.add(
        pitch=60,
        velocity=100,
        start_time=0.50,
        end_time=0.51)
    expected_ns.total_time = 0.51
    self.assertProtoEquals(expected_ns, ns)


if __name__ == '__main__':
  tf.test.main()
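
For reference, the decode tests above all share one pattern; a minimal
round-trip sketch using the names from this test module (`tokens` here is a
hypothetical sequence of codec token ids, e.g. model output):

    decoding_state = note_sequences.NoteDecodingState()
    invalid_ids, dropped_events = run_length_encoding.decode_events(
        state=decoding_state, tokens=tokens, start_time=0, max_time=None,
        codec=codec, decode_event_fn=note_sequences.decode_note_event)
    ns = note_sequences.flush_note_decoding_state(decoding_state)

This yields a note_seq.NoteSequence along with counts of invalid and dropped
tokens.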
mt3/preprocessors.py
ADDED
@@ -0,0 +1,669 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transcription preprocessors."""

from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

from absl import logging
import gin
from immutabledict import immutabledict
import librosa

from mt3 import event_codec
from mt3 import note_sequences
from mt3 import run_length_encoding
from mt3 import spectrograms
from mt3 import vocabularies

import note_seq
import numpy as np
import seqio
import tensorflow as tf


def add_unique_id(ds: tf.data.Dataset) -> tf.data.Dataset:
  """Add unique integer ID to each example in a dataset."""
  def add_id_field(i, ex):
    ex['unique_id'] = [i]
    return ex
  return ds.enumerate().map(
      add_id_field, num_parallel_calls=tf.data.experimental.AUTOTUNE)


@seqio.map_over_dataset
def pad_notesequence_array(ex):
  """Pad the NoteSequence array so that it can later be "split"."""
  ex['sequence'] = tf.pad(tf.expand_dims(ex['sequence'], 0),
                          [[0, len(ex['input_times']) - 1]])
  return ex


@seqio.map_over_dataset
def add_dummy_targets(ex):
  """Add dummy targets; used in eval when targets are not actually used."""
  ex['targets'] = np.array([], dtype=np.int32)
  return ex


def _audio_to_frames(
    samples: Sequence[float],
    spectrogram_config: spectrograms.SpectrogramConfig,
) -> Tuple[Sequence[Sequence[int]], np.ndarray]:
  """Convert audio samples to non-overlapping frames and frame times."""
  frame_size = spectrogram_config.hop_width
  logging.info('Padding %d samples to multiple of %d', len(samples), frame_size)
  samples = np.pad(samples,
                   [0, frame_size - len(samples) % frame_size],
                   mode='constant')

  frames = spectrograms.split_audio(samples, spectrogram_config)

  num_frames = len(samples) // frame_size
  logging.info('Encoded %d samples to %d frames (%d samples each)',
               len(samples), num_frames, frame_size)

  times = np.arange(num_frames) / spectrogram_config.frames_per_second
  return frames, times
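
# Worked example for _audio_to_frames (illustrative numbers, not from this
# codebase): with a hypothetical hop_width of 128, 1000 input samples are
# padded to 1024 and split into 8 frames of 128 samples each; frame i is
# stamped with time i / frames_per_second.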


def _include_inputs(ds, input_record, fields_to_omit=('audio',)):
  """Include fields from input record (other than audio) in dataset records."""
  def include_inputs_fn(output_record):
    for key in set(input_record.keys()) - set(output_record.keys()):
      output_record[key] = input_record[key]
    for key in fields_to_omit:
      del output_record[key]
    return output_record
  return ds.map(include_inputs_fn,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)


def tokenize_transcription_example(
    ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec, is_training_data: bool,
    onsets_only: bool, include_ties: bool, audio_is_samples: bool,
    id_feature_key: Optional[str] = None
) -> tf.data.Dataset:
  """Tokenize a note transcription example for run-length encoding.

  Outputs include:
    inputs: audio sample frames, num_frames-by-frame_size
    input_times: timestamp for each frame
    targets: symbolic sequence of note-related events
    input_event_start_indices: start target index for every input index
    input_event_end_indices: end target index for every input index

  Args:
    ds: Input dataset.
    spectrogram_config: Spectrogram configuration.
    codec: Event vocabulary codec.
    is_training_data: Unused.
    onsets_only: If True, include only onset events (not offset, velocity, or
      program).
    include_ties: If True, also write state events containing active notes to
      support a "tie" section after run-length encoding.
    audio_is_samples: If True, audio is floating-point samples instead of
      serialized WAV.
    id_feature_key: If not None, replace sequence ID with specified key field
      from the dataset.

  Returns:
    Dataset with the outputs described above.
  """
  del is_training_data

  if onsets_only and include_ties:
    raise ValueError('Ties not supported when only modeling onsets.')

  def tokenize(sequence, audio, sample_rate, example_id=None):
    ns = note_seq.NoteSequence.FromString(sequence)
    note_sequences.validate_note_sequence(ns)

    if example_id is not None:
      ns.id = example_id

    if audio_is_samples:
      samples = audio
      if sample_rate != spectrogram_config.sample_rate:
        samples = librosa.resample(
            samples, sample_rate, spectrogram_config.sample_rate)
    else:
      samples = note_seq.audio_io.wav_data_to_samples_librosa(
          audio, sample_rate=spectrogram_config.sample_rate)

    logging.info('Got samples for %s::%s with length %d',
                 ns.id, ns.filename, len(samples))

    frames, frame_times = _audio_to_frames(samples, spectrogram_config)

    if onsets_only:
      times, values = note_sequences.note_sequence_to_onsets(ns)
    else:
      ns = note_seq.apply_sustain_control_changes(ns)
      times, values = (
          note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))

    # The original NoteSequence can have a lot of control changes we don't
    # need; delete them.
    del ns.control_changes[:]

    (events, event_start_indices, event_end_indices,
     state_events, state_event_indices) = (
         run_length_encoding.encode_and_index_events(
             state=note_sequences.NoteEncodingState() if include_ties else None,
             event_times=times,
             event_values=values,
             encode_event_fn=note_sequences.note_event_data_to_events,
             codec=codec,
             frame_times=frame_times,
             encoding_state_to_events_fn=(
                 note_sequences.note_encoding_state_to_events
                 if include_ties else None)))

    yield {
        'inputs': frames,
        'input_times': frame_times,
        'targets': events,
        'input_event_start_indices': event_start_indices,
        'input_event_end_indices': event_end_indices,
        'state_events': state_events,
        'input_state_event_indices': state_event_indices,
        'sequence': ns.SerializeToString()
    }

  def process_record(input_record):
    if audio_is_samples and 'sample_rate' not in input_record:
      raise ValueError('Must provide sample rate when audio is samples.')

    args = [
        input_record['sequence'],
        input_record['audio'],
        input_record['sample_rate'] if 'sample_rate' in input_record else 0
    ]
    if id_feature_key is not None:
      args.append(input_record[id_feature_key])

    ds = tf.data.Dataset.from_generator(
        tokenize,
        output_signature={
            'inputs':
                tf.TensorSpec(
                    shape=(None, spectrogram_config.hop_width),
                    dtype=tf.float32),
            'input_times':
                tf.TensorSpec(shape=(None,), dtype=tf.float32),
            'targets':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_start_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_end_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'state_events':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_state_event_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'sequence':
                tf.TensorSpec(shape=(), dtype=tf.string)
        },
        args=args)

    ds = _include_inputs(ds, input_record)
    return ds

  tokenized_records = ds.flat_map(process_record)
  return tokenized_records
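
# Usage sketch for tokenize_transcription_example (hypothetical names other
# than the function's own): given records of the form
# {'sequence': <serialized NoteSequence>, 'audio': <WAV bytes>},
#
#   ds = tokenize_transcription_example(
#       raw_ds, spectrogram_config=spectrograms.SpectrogramConfig(),
#       codec=my_codec, is_training_data=True, onsets_only=False,
#       include_ties=True, audio_is_samples=False)
#
# yields one variable-length record per example; the start/end index features
# exist so that later preprocessors can split these into fixed-size segments.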


def tokenize_guitarset_example(
    ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec, is_training_data: bool,
    onsets_only: bool, include_ties: bool
) -> tf.data.Dataset:
  """Tokenize a GuitarSet transcription example."""
  def _preprocess_example(ex, name):
    assert 'inst_names' not in ex, 'Key `inst_names` is already populated.'
    ex['inst_names'] = [name]
    ex['instrument_sequences'] = [ex.pop('sequence')]
    return ex

  ds = ds.map(
      lambda x: _preprocess_example(x, 'Clean Guitar'),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = tokenize_example_with_program_lookup(
      ds,
      spectrogram_config=spectrogram_config,
      codec=codec,
      is_training_data=is_training_data,
      inst_name_to_program_fn=guitarset_instrument_to_program,
      onsets_only=onsets_only,
      include_ties=include_ties,
      id_feature_key='id')
  return ds


def guitarset_instrument_to_program(instrument: str) -> int:
  """GuitarSet is all guitar, return the first MIDI guitar program."""
  if instrument == 'Clean Guitar':
    return 24
  else:
    raise ValueError('Unknown GuitarSet instrument: %s' % instrument)


def tokenize_example_with_program_lookup(
    ds: tf.data.Dataset,
    spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec,
    is_training_data: bool,
    onsets_only: bool,
    include_ties: bool,
    inst_name_to_program_fn: Callable[[str], int],
    id_feature_key: Optional[str] = None
) -> tf.data.Dataset:
  """Tokenize an example, optionally looking up and assigning program numbers.

  This can be used by any dataset where a mapping function can be used to
  map from the inst_names feature to a set of program numbers.

  Args:
    ds: Input dataset.
    spectrogram_config: Spectrogram configuration.
    codec: Event vocabulary codec.
    is_training_data: Unused.
    onsets_only: If True, include only onset events (not offset & velocity).
    include_ties: If True, include tie events.
    inst_name_to_program_fn: A function used to map the instrument names
      in the `inst_names` feature of each example to a MIDI program number.
    id_feature_key: If not None, replace sequence ID with specified key field
      from the dataset.

  Returns:
    Dataset with the outputs described above.
  """
  del is_training_data

  def tokenize(sequences, inst_names, audio, example_id=None):
    # Add all the notes from the tracks to a single NoteSequence.
    ns = note_seq.NoteSequence(ticks_per_quarter=220)
    tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences]
    assert len(tracks) == len(inst_names)
    for track, inst_name in zip(tracks, inst_names):
      program = inst_name_to_program_fn(
          inst_name.decode())

      # Note that there are no pitch bends in URMP data; the below block will
      # raise PitchBendError if one is encountered.
      add_track_to_notesequence(ns, track, program=program, is_drum=False,
                                ignore_pitch_bends=False)

    note_sequences.assign_instruments(ns)
    note_sequences.validate_note_sequence(ns)

    if example_id is not None:
      ns.id = example_id

    samples = note_seq.audio_io.wav_data_to_samples_librosa(
        audio, sample_rate=spectrogram_config.sample_rate)

    logging.info('Got samples for %s::%s with length %d',
                 ns.id, ns.filename, len(samples))

    frames, frame_times = _audio_to_frames(samples, spectrogram_config)

    if onsets_only:
      times, values = note_sequences.note_sequence_to_onsets(ns)
    else:
      times, values = (
          note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))

    # The original NoteSequence can have a lot of control changes we don't
    # need; delete them.
    del ns.control_changes[:]

    (events, event_start_indices, event_end_indices,
     state_events, state_event_indices) = (
         run_length_encoding.encode_and_index_events(
             state=note_sequences.NoteEncodingState() if include_ties else None,
             event_times=times,
             event_values=values,
             encode_event_fn=note_sequences.note_event_data_to_events,
             codec=codec,
             frame_times=frame_times,
             encoding_state_to_events_fn=(
                 note_sequences.note_encoding_state_to_events
                 if include_ties else None)))

    yield {
        'inputs': frames,
        'input_times': frame_times,
        'targets': events,
        'input_event_start_indices': event_start_indices,
        'input_event_end_indices': event_end_indices,
        'state_events': state_events,
        'input_state_event_indices': state_event_indices,
        'sequence': ns.SerializeToString()
    }

  def process_record(input_record):
    args = [
        input_record['instrument_sequences'],
        input_record['inst_names'],
        input_record['audio'],
    ]
    if id_feature_key is not None:
      args.append(input_record[id_feature_key])

    ds = tf.data.Dataset.from_generator(
        tokenize,
        output_signature={
            'inputs':
                tf.TensorSpec(
                    shape=(None, spectrogram_config.hop_width),
                    dtype=tf.float32),
            'input_times':
                tf.TensorSpec(shape=(None,), dtype=tf.float32),
            'targets':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_start_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_end_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'state_events':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_state_event_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'sequence':
                tf.TensorSpec(shape=(), dtype=tf.string)
        },
        args=args)

    ds = _include_inputs(ds, input_record)
    return ds

  tokenized_records = ds.flat_map(process_record)
  return tokenized_records


_URMP_INSTRUMENT_PROGRAMS = immutabledict({
    'vn': 40,   # violin
    'va': 41,   # viola
    'vc': 42,   # cello
    'db': 43,   # double bass
    'tpt': 56,  # trumpet
    'tbn': 57,  # trombone
    'tba': 58,  # tuba
    'hn': 60,   # French horn
    'sax': 64,  # saxophone
    'ob': 68,   # oboe
    'bn': 70,   # bassoon
    'cl': 71,   # clarinet
    'fl': 73    # flute
})


def urmp_instrument_to_program(urmp_instrument: str) -> int:
  """Fetch the program number associated with a given URMP instrument code."""
  if urmp_instrument not in _URMP_INSTRUMENT_PROGRAMS:
    raise ValueError('unknown URMP instrument: %s' % urmp_instrument)
  return _URMP_INSTRUMENT_PROGRAMS[urmp_instrument]


_SLAKH_CLASS_PROGRAMS = immutabledict({
    'Acoustic Piano': 0,
    'Electric Piano': 4,
    'Chromatic Percussion': 8,
    'Organ': 16,
    'Acoustic Guitar': 24,
    'Clean Electric Guitar': 26,
    'Distorted Electric Guitar': 29,
    'Acoustic Bass': 32,
    'Electric Bass': 33,
    'Violin': 40,
    'Viola': 41,
    'Cello': 42,
    'Contrabass': 43,
    'Orchestral Harp': 46,
    'Timpani': 47,
    'String Ensemble': 48,
    'Synth Strings': 50,
    'Choir and Voice': 52,
    'Orchestral Hit': 55,
    'Trumpet': 56,
    'Trombone': 57,
    'Tuba': 58,
    'French Horn': 60,
    'Brass Section': 61,
    'Soprano/Alto Sax': 64,
    'Tenor Sax': 66,
    'Baritone Sax': 67,
    'Oboe': 68,
    'English Horn': 69,
    'Bassoon': 70,
    'Clarinet': 71,
    'Pipe': 73,
    'Synth Lead': 80,
    'Synth Pad': 88
})


def slakh_class_to_program_and_is_drum(slakh_class: str) -> Tuple[int, bool]:
  """Map Slakh class string to program number and boolean indicating drums."""
  if slakh_class == 'Drums':
    return 0, True
  elif slakh_class not in _SLAKH_CLASS_PROGRAMS:
    raise ValueError('unknown Slakh class: %s' % slakh_class)
  else:
    return _SLAKH_CLASS_PROGRAMS[slakh_class], False


class PitchBendError(Exception):
  pass


def add_track_to_notesequence(ns: note_seq.NoteSequence,
                              track: note_seq.NoteSequence,
                              program: int, is_drum: bool,
                              ignore_pitch_bends: bool):
  """Add a track to a NoteSequence."""
  if track.pitch_bends and not ignore_pitch_bends:
    raise PitchBendError
  track_sus = note_seq.apply_sustain_control_changes(track)
  for note in track_sus.notes:
    note.program = program
    note.is_drum = is_drum
    ns.notes.extend([note])
    ns.total_time = max(ns.total_time, note.end_time)


def tokenize_slakh_example(
    ds: tf.data.Dataset,
    spectrogram_config: spectrograms.SpectrogramConfig,
    codec: event_codec.Codec,
    is_training_data: bool,
    onsets_only: bool,
    include_ties: bool,
    track_specs: Optional[Sequence[note_sequences.TrackSpec]],
    ignore_pitch_bends: bool
) -> tf.data.Dataset:
  """Tokenize a Slakh multitrack note transcription example."""
  def tokenize(sequences, samples, sample_rate, inst_names, example_id):
    if sample_rate != spectrogram_config.sample_rate:
      samples = librosa.resample(
          samples, sample_rate, spectrogram_config.sample_rate)

    frames, frame_times = _audio_to_frames(samples, spectrogram_config)

    # Add all the notes from the tracks to a single NoteSequence.
    ns = note_seq.NoteSequence(ticks_per_quarter=220)
    tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences]
    assert len(tracks) == len(inst_names)
    if track_specs:
      # Specific tracks expected.
      assert len(tracks) == len(track_specs)
      for track, spec, inst_name in zip(tracks, track_specs, inst_names):
        # Make sure the instrument name matches what we expect.
        assert inst_name.decode() == spec.name
        try:
          add_track_to_notesequence(ns, track,
                                    program=spec.program, is_drum=spec.is_drum,
                                    ignore_pitch_bends=ignore_pitch_bends)
        except PitchBendError:
          # TODO(iansimon): is there a way to count these?
          return
    else:
      for track, inst_name in zip(tracks, inst_names):
        # Instrument name should be Slakh class.
        program, is_drum = slakh_class_to_program_and_is_drum(
            inst_name.decode())
        try:
          add_track_to_notesequence(ns, track, program=program,
                                    is_drum=is_drum,
                                    ignore_pitch_bends=ignore_pitch_bends)
        except PitchBendError:
          # TODO(iansimon): is there a way to count these?
          return

    note_sequences.assign_instruments(ns)
    note_sequences.validate_note_sequence(ns)
    if is_training_data:
      # Trim overlapping notes in training (as our event vocabulary cannot
      # represent them), but preserve original NoteSequence for eval.
      ns = note_sequences.trim_overlapping_notes(ns)

    ns.id = example_id

    if onsets_only:
      times, values = note_sequences.note_sequence_to_onsets(ns)
    else:
      times, values = (
          note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))

    (events, event_start_indices, event_end_indices,
     state_events, state_event_indices) = (
         run_length_encoding.encode_and_index_events(
             state=note_sequences.NoteEncodingState() if include_ties else None,
             event_times=times,
             event_values=values,
             encode_event_fn=note_sequences.note_event_data_to_events,
             codec=codec,
             frame_times=frame_times,
             encoding_state_to_events_fn=(
                 note_sequences.note_encoding_state_to_events
                 if include_ties else None)))

    yield {
        'inputs': frames,
        'input_times': frame_times,
        'targets': events,
        'input_event_start_indices': event_start_indices,
        'input_event_end_indices': event_end_indices,
        'state_events': state_events,
        'input_state_event_indices': state_event_indices,
        'sequence': ns.SerializeToString()
    }

  def process_record(input_record):
    ds = tf.data.Dataset.from_generator(
        tokenize,
        output_signature={
            'inputs':
                tf.TensorSpec(
                    shape=(None, spectrogram_config.hop_width),
                    dtype=tf.float32),
            'input_times':
                tf.TensorSpec(shape=(None,), dtype=tf.float32),
            'targets':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_start_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_event_end_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'state_events':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'input_state_event_indices':
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'sequence':
                tf.TensorSpec(shape=(), dtype=tf.string)
        },
        args=[
            input_record['note_sequences'], input_record['mix'],
            input_record['audio_sample_rate'], input_record['inst_names'],
            input_record['track_id']
        ])

    ds = _include_inputs(ds, input_record, fields_to_omit=['mix', 'stems'])
    return ds

  tokenized_records = ds.flat_map(process_record)
  return tokenized_records


@seqio.map_over_dataset
def compute_spectrograms(ex, spectrogram_config):
  samples = spectrograms.flatten_frames(ex['inputs'])
  ex['inputs'] = spectrograms.compute_spectrogram(samples, spectrogram_config)
  ex['raw_inputs'] = samples
  return ex


def handle_too_long(dataset: tf.data.Dataset,
                    output_features: seqio.preprocessors.OutputFeaturesType,
                    sequence_length: seqio.preprocessors.SequenceLengthType,
                    skip: bool = False) -> tf.data.Dataset:
  """Handle sequences that are too long, by either failing or skipping them."""
  def max_length_for_key(key):
    max_length = sequence_length[key]
    if output_features[key].add_eos:
      max_length -= 1
    return max_length

  if skip:
    # Drop examples where one of the features is longer than its maximum
    # sequence length.
    def is_not_too_long(ex):
      return not tf.reduce_any(
          [k in output_features and len(v) > max_length_for_key(k)
           for k, v in ex.items()])
    dataset = dataset.filter(is_not_too_long)

  def assert_not_too_long(key: str, value: tf.Tensor) -> tf.Tensor:
    if key in output_features:
      max_length = max_length_for_key(key)
      tf.debugging.assert_less_equal(
          tf.shape(value)[0], max_length,
          f'Value for "{key}" field exceeds maximum length')
    return value

  # Assert that no examples have features longer than their maximum sequence
  # length.
  return dataset.map(
      lambda ex: {k: assert_not_too_long(k, v) for k, v in ex.items()},
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
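
# Usage sketch for handle_too_long (hypothetical wiring, not shown in this
# file): in a seqio task this would typically sit near the end of the
# preprocessor list, e.g. functools.partial(handle_too_long, skip=True) to
# silently drop over-long training examples, or skip=False to fail loudly
# instead of truncating.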


@gin.configurable
def map_midi_programs(
    ds: tf.data.Dataset,
    codec: event_codec.Codec,
    granularity_type: str = 'full',
    feature_key: str = 'targets'
) -> Mapping[str, Any]:
  """Apply MIDI program map to token sequences."""
  granularity = vocabularies.PROGRAM_GRANULARITIES[granularity_type]
  def _map_program_tokens(ex):
    ex[feature_key] = granularity.tokens_map_fn(ex[feature_key], codec)
    return ex
  return ds.map(_map_program_tokens,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
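
# Granularity note (an assumption based on the MT3 vocabularies module, whose
# contents are not shown here): granularity_type='full' passes program tokens
# through unchanged, while coarser granularities such as 'midi_class' or
# 'flat' collapse or drop program identity; the actual mappings live in
# vocabularies.PROGRAM_GRANULARITIES.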
mt3/pytest.ini
ADDED
@@ -0,0 +1,3 @@
[pytest]
python_files = *_test.py
log_level = INFO
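
With this configuration, running `pytest` inside `mt3/` collects every
`*_test.py` module (including the run-length encoding tests above) and logs
at INFO level.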
mt3/run_length_encoding.py
ADDED
@@ -0,0 +1,423 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for run length encoding."""

import dataclasses
from typing import Any, Callable, Mapping, MutableMapping, Tuple, Optional, Sequence, TypeVar

from absl import logging
from mt3 import event_codec

import numpy as np
import seqio
import tensorflow as tf

Event = event_codec.Event

# These should be type variables, but unfortunately those are incompatible with
# dataclasses.
EventData = Any
EncodingState = Any
DecodingState = Any
DecodeResult = Any

T = TypeVar('T', bound=EventData)
ES = TypeVar('ES', bound=EncodingState)
DS = TypeVar('DS', bound=DecodingState)


@dataclasses.dataclass
class EventEncodingSpec:
  """Spec for encoding events."""
  # initialize encoding state
  init_encoding_state_fn: Callable[[], EncodingState]
  # convert EventData into zero or more events, updating encoding state
  encode_event_fn: Callable[[EncodingState, EventData, event_codec.Codec],
                            Sequence[event_codec.Event]]
  # convert encoding state (at beginning of segment) into events
  encoding_state_to_events_fn: Optional[Callable[[EncodingState],
                                                 Sequence[event_codec.Event]]]
  # create empty decoding state
  init_decoding_state_fn: Callable[[], DecodingState]
  # update decoding state when entering new segment
  begin_decoding_segment_fn: Callable[[DecodingState], None]
  # consume time and Event and update decoding state
  decode_event_fn: Callable[
      [DecodingState, float, event_codec.Event, event_codec.Codec], None]
  # flush decoding state into result
  flush_decoding_state_fn: Callable[[DecodingState], DecodeResult]


def encode_and_index_events(
    state: ES,
    event_times: Sequence[float],
    event_values: Sequence[T],
    encode_event_fn: Callable[[ES, T, event_codec.Codec],
                              Sequence[event_codec.Event]],
    codec: event_codec.Codec,
    frame_times: Sequence[float],
    encoding_state_to_events_fn: Optional[
        Callable[[ES], Sequence[event_codec.Event]]] = None,
) -> Tuple[Sequence[int], Sequence[int], Sequence[int],
           Sequence[int], Sequence[int]]:
  """Encode a sequence of timed events and index to audio frame times.

  Encodes time shifts as repeated single step shifts for later run length
  encoding.

  Optionally, also encodes a sequence of "state events", keeping track of the
  current encoding state at each audio frame. This can be used e.g. to prepend
  events representing the current state to a targets segment.

  Args:
    state: Initial event encoding state.
    event_times: Sequence of event times.
    event_values: Sequence of event values.
    encode_event_fn: Function that transforms event value into a sequence of
      one or more event_codec.Event objects.
    codec: An event_codec.Codec object that maps Event objects to indices.
    frame_times: Time for every audio frame.
    encoding_state_to_events_fn: Function that transforms encoding state into a
      sequence of one or more event_codec.Event objects.

  Returns:
    events: Encoded events and shifts.
    event_start_indices: Corresponding start event index for every audio frame.
      Note: one event can correspond to multiple audio indices due to sampling
      rate differences. This makes splitting sequences tricky because the same
      event can appear at the end of one sequence and the beginning of
      another.
    event_end_indices: Corresponding end event index for every audio frame.
      Used to ensure when slicing that one chunk ends where the next begins.
      Should always be true that
      event_end_indices[i] = event_start_indices[i + 1].
    state_events: Encoded "state" events representing the encoding state before
      each event.
    state_event_indices: Corresponding state event index for every audio frame.
  """
  indices = np.argsort(event_times, kind='stable')
  event_steps = [round(event_times[i] * codec.steps_per_second)
                 for i in indices]
  event_values = [event_values[i] for i in indices]

  events = []
  state_events = []
  event_start_indices = []
  state_event_indices = []

  cur_step = 0
  cur_event_idx = 0
  cur_state_event_idx = 0

  def fill_event_start_indices_to_cur_step():
    while (len(event_start_indices) < len(frame_times) and
           frame_times[len(event_start_indices)] <
           cur_step / codec.steps_per_second):
      event_start_indices.append(cur_event_idx)
      state_event_indices.append(cur_state_event_idx)

  for event_step, event_value in zip(event_steps, event_values):
    while event_step > cur_step:
      events.append(codec.encode_event(Event(type='shift', value=1)))
      cur_step += 1
      fill_event_start_indices_to_cur_step()
      cur_event_idx = len(events)
      cur_state_event_idx = len(state_events)
    if encoding_state_to_events_fn:
      # Dump state to state events *before* processing the next event, because
      # we want to capture the state prior to the occurrence of the event.
      for e in encoding_state_to_events_fn(state):
        state_events.append(codec.encode_event(e))
    for e in encode_event_fn(state, event_value, codec):
      events.append(codec.encode_event(e))

  # After the last event, continue filling out the event_start_indices array.
  # The inequality is not strict because if our current step lines up exactly
  # with (the start of) an audio frame, we need to add an additional shift
  # event to "cover" that frame.
  while cur_step / codec.steps_per_second <= frame_times[-1]:
    events.append(codec.encode_event(Event(type='shift', value=1)))
    cur_step += 1
    fill_event_start_indices_to_cur_step()
    cur_event_idx = len(events)

  # Now fill in event_end_indices. We need this extra array to make sure that
  # when we slice events, each slice ends exactly where the subsequent slice
  # begins.
  event_end_indices = event_start_indices[1:] + [len(events)]

  events = np.array(events)
  state_events = np.array(state_events)
  event_start_indices = np.array(event_start_indices)
  event_end_indices = np.array(event_end_indices)
  state_event_indices = np.array(state_event_indices)

  return (events, event_start_indices, event_end_indices,
          state_events, state_event_indices)
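
# Worked example (mirrors test_encode_and_index_note_sequence in
# run_length_encoding_test.py): with steps_per_second=100, a note onset at
# t=1.0s within 4s of audio encodes as 100 shift(1) tokens, then the onset
# token, then further shift(1) tokens out to the final frame time; the frame
# at t=1.0 gets event_start_index 100, pointing just past the shifts that
# precede the onset.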


@seqio.map_over_dataset
def extract_target_sequence_with_indices(features, state_events_end_token=None):
  """Extract target sequence corresponding to audio token segment."""
  target_start_idx = features['input_event_start_indices'][0]
  target_end_idx = features['input_event_end_indices'][-1]

  features['targets'] = features['targets'][target_start_idx:target_end_idx]

  if state_events_end_token is not None:
    # Extract the state events corresponding to the audio start token, and
    # prepend them to the targets array.
    state_event_start_idx = features['input_state_event_indices'][0]
    state_event_end_idx = state_event_start_idx + 1
    while features['state_events'][
        state_event_end_idx - 1] != state_events_end_token:
      state_event_end_idx += 1
    features['targets'] = tf.concat([
        features['state_events'][state_event_start_idx:state_event_end_idx],
        features['targets']
    ], axis=0)

  return features


def remove_redundant_state_changes_fn(
    codec: event_codec.Codec,
    feature_key: str = 'targets',
    state_change_event_types: Sequence[str] = ()
) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
  """Return preprocessing function that removes redundant state change events.

  Args:
    codec: The event_codec.Codec used to interpret the events.
    feature_key: The feature key for which to remove redundant state changes.
    state_change_event_types: A list of event types that represent state
      changes; tokens corresponding to these event types will be interpreted
      as state changes and redundant ones will be removed.

  Returns:
    A preprocessing function that removes redundant state change events.
  """
  state_change_event_ranges = [codec.event_type_range(event_type)
                               for event_type in state_change_event_types]

  def remove_redundant_state_changes(
      features: MutableMapping[str, Any],
  ) -> Mapping[str, Any]:
    """Remove redundant tokens e.g. duplicate velocity changes from sequence."""
    current_state = tf.zeros(len(state_change_event_ranges), dtype=tf.int32)
    output = tf.constant([], dtype=tf.int32)

    for event in features[feature_key]:
      # Let autograph know that the shape of 'output' will change during the
      # loop.
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[(output, tf.TensorShape([None]))])
      is_redundant = False
      for i, (min_index, max_index) in enumerate(state_change_event_ranges):
        if (min_index <= event) and (event <= max_index):
          if current_state[i] == event:
            is_redundant = True
          current_state = tf.tensor_scatter_nd_update(
              current_state, indices=[[i]], updates=[event])
      if not is_redundant:
        output = tf.concat([output, [event]], axis=0)

    features[feature_key] = output
    return features

  return seqio.map_over_dataset(remove_redundant_state_changes)
242 |
+
def run_length_encode_shifts_fn(
|
243 |
+
codec: event_codec.Codec,
|
244 |
+
feature_key: str = 'targets'
|
245 |
+
) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
|
246 |
+
"""Return a function that run-length encodes shifts for a given codec.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
codec: The Codec to use for shift events.
|
250 |
+
feature_key: The feature key for which to run-length encode shifts.
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
A preprocessing function that run-length encodes single-step shifts.
|
254 |
+
"""
|
255 |
+
def run_length_encode_shifts(
|
256 |
+
features: MutableMapping[str, Any]
|
257 |
+
) -> Mapping[str, Any]:
|
258 |
+
"""Combine leading/interior shifts, trim trailing shifts.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
features: Dict of features to process.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
A dict of features.
|
265 |
+
"""
|
266 |
+
events = features[feature_key]
|
267 |
+
|
268 |
+
shift_steps = 0
|
269 |
+
total_shift_steps = 0
|
270 |
+
output = tf.constant([], dtype=tf.int32)
|
271 |
+
|
272 |
+
for event in events:
|
273 |
+
# Let autograph know that the shape of 'output' will change during the
|
274 |
+
# loop.
|
275 |
+
tf.autograph.experimental.set_loop_options(
|
276 |
+
shape_invariants=[(output, tf.TensorShape([None]))])
|
277 |
+
if codec.is_shift_event_index(event):
|
278 |
+
shift_steps += 1
|
279 |
+
total_shift_steps += 1
|
280 |
+
|
281 |
+
else:
|
282 |
+
# Once we've reached a non-shift event, RLE all previous shift events
|
283 |
+
# before outputting the non-shift event.
|
284 |
+
if shift_steps > 0:
|
285 |
+
shift_steps = total_shift_steps
|
286 |
+
while shift_steps > 0:
|
287 |
+
output_steps = tf.minimum(codec.max_shift_steps, shift_steps)
|
288 |
+
output = tf.concat([output, [output_steps]], axis=0)
|
289 |
+
shift_steps -= output_steps
|
290 |
+
output = tf.concat([output, [event]], axis=0)
|
291 |
+
|
292 |
+
features[feature_key] = output
|
293 |
+
return features
|
294 |
+
|
295 |
+
return seqio.map_over_dataset(run_length_encode_shifts)
|
296 |
+
|
297 |
+
|
298 |
+
def merge_run_length_encoded_targets(
|
299 |
+
targets: np.ndarray,
|
300 |
+
codec: event_codec.Codec
|
301 |
+
) -> Sequence[int]:
|
302 |
+
"""Merge multiple tracks of target events into a single stream.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
targets: A 2D array (# tracks by # events) of integer event values.
|
306 |
+
codec: The event_codec.Codec used to interpret the events.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
A 1D array of merged events.
|
310 |
+
"""
|
311 |
+
num_tracks = tf.shape(targets)[0]
|
312 |
+
targets_length = tf.shape(targets)[1]
|
313 |
+
|
314 |
+
current_step = 0
|
315 |
+
current_offsets = tf.zeros(num_tracks, dtype=tf.int32)
|
316 |
+
|
317 |
+
output = tf.constant([], dtype=tf.int32)
|
318 |
+
done = tf.constant(False)
|
319 |
+
|
320 |
+
while not done:
|
321 |
+
# Let autograph know that the shape of 'output' will change during the loop.
|
322 |
+
tf.autograph.experimental.set_loop_options(
|
323 |
+
shape_invariants=[(output, tf.TensorShape([None]))])
|
324 |
+
|
325 |
+
# Determine which targets track has the earliest next step.
|
326 |
+
next_step = codec.max_shift_steps + 1
|
327 |
+
next_track = -1
|
328 |
+
for i in range(num_tracks):
|
329 |
+
if (current_offsets[i] == targets_length or
|
330 |
+
targets[i][current_offsets[i]] == 0):
|
331 |
+
# Already reached the end of this targets track.
|
332 |
+
# (Zero is technically a valid shift event but we never actually use it;
|
333 |
+
# it is always padding.)
|
334 |
+
continue
|
335 |
+
if not codec.is_shift_event_index(targets[i][current_offsets[i]]):
|
336 |
+
# The only way we would be at a non-shift event is if we have not yet
|
337 |
+
# reached the first shift event, which means we're at step zero.
|
338 |
+
next_step = 0
|
339 |
+
next_track = i
|
340 |
+
elif targets[i][current_offsets[i]] < next_step:
|
341 |
+
next_step = targets[i][current_offsets[i]]
|
342 |
+
next_track = i
|
343 |
+
|
344 |
+
if next_track == -1:
|
345 |
+
# We've already merged all of the target tracks in their entirety.
|
346 |
+
done = tf.constant(True)
|
347 |
+
break
|
348 |
+
|
349 |
+
if next_step == current_step and next_step > 0:
|
350 |
+
# We don't need to include the shift event itself as it's the same step as
|
351 |
+
# the previous shift.
|
352 |
+
start_offset = current_offsets[next_track] + 1
|
353 |
+
else:
|
354 |
+
start_offset = current_offsets[next_track]
|
355 |
+
|
356 |
+
# Merge in events up to but not including the next shift.
|
357 |
+
end_offset = start_offset + 1
|
358 |
+
while end_offset < targets_length and not codec.is_shift_event_index(
|
359 |
+
targets[next_track][end_offset]):
|
360 |
+
end_offset += 1
|
361 |
+
output = tf.concat(
|
362 |
+
[output, targets[next_track][start_offset:end_offset]], axis=0)
|
363 |
+
|
364 |
+
current_step = next_step
|
365 |
+
current_offsets = tf.tensor_scatter_nd_update(
|
366 |
+
current_offsets, indices=[[next_track]], updates=[end_offset])
|
367 |
+
|
368 |
+
return output
|
369 |
+
|
370 |
+
|
371 |
+
def decode_events(
|
372 |
+
state: DS,
|
373 |
+
tokens: np.ndarray,
|
374 |
+
start_time: int,
|
375 |
+
max_time: Optional[int],
|
376 |
+
codec: event_codec.Codec,
|
377 |
+
decode_event_fn: Callable[[DS, float, event_codec.Event, event_codec.Codec],
|
378 |
+
None],
|
379 |
+
) -> Tuple[int, int]:
|
380 |
+
"""Decode a series of tokens, maintaining a decoding state object.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
state: Decoding state object; will be modified in-place.
|
384 |
+
tokens: event tokens to convert.
|
385 |
+
start_time: offset start time if decoding in the middle of a sequence.
|
386 |
+
max_time: Events at or beyond this time will be dropped.
|
387 |
+
codec: An event_codec.Codec object that maps indices to Event objects.
|
388 |
+
decode_event_fn: Function that consumes an Event (and the current time) and
|
389 |
+
updates the decoding state.
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
invalid_events: number of events that could not be decoded.
|
393 |
+
dropped_events: number of events dropped due to max_time restriction.
|
394 |
+
"""
|
395 |
+
invalid_events = 0
|
396 |
+
dropped_events = 0
|
397 |
+
cur_steps = 0
|
398 |
+
cur_time = start_time
|
399 |
+
token_idx = 0
|
400 |
+
for token_idx, token in enumerate(tokens):
|
401 |
+
try:
|
402 |
+
event = codec.decode_event_index(token)
|
403 |
+
except ValueError:
|
404 |
+
invalid_events += 1
|
405 |
+
continue
|
406 |
+
if event.type == 'shift':
|
407 |
+
cur_steps += event.value
|
408 |
+
cur_time = start_time + cur_steps / codec.steps_per_second
|
409 |
+
if max_time and cur_time > max_time:
|
410 |
+
dropped_events = len(tokens) - token_idx
|
411 |
+
break
|
412 |
+
else:
|
413 |
+
cur_steps = 0
|
414 |
+
try:
|
415 |
+
decode_event_fn(state, cur_time, event, codec)
|
416 |
+
except ValueError:
|
417 |
+
invalid_events += 1
|
418 |
+
logging.info(
|
419 |
+
'Got invalid event when decoding event %s at time %f. '
|
420 |
+
'Invalid event counter now at %d.',
|
421 |
+
event, cur_time, invalid_events, exc_info=True)
|
422 |
+
continue
|
423 |
+
return invalid_events, dropped_events
|
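
A quick illustration of what run_length_encode_shifts_fn produces (a minimal sketch, not part of the commit; the codec mirrors the one built in the test file below, so indices 1-100 are single-step shift events and 161 happens to be a pitch event):

import tensorflow as tf
from mt3 import event_codec
from mt3 import run_length_encoding

codec = event_codec.Codec(
    max_shift_steps=100, steps_per_second=100,
    event_ranges=[event_codec.EventRange('pitch', 0, 127)])
rle = run_length_encoding.run_length_encode_shifts_fn(codec=codec)
ds = tf.data.Dataset.from_tensors({'targets': [1, 1, 1, 161, 1, 1, 1]})
for ex in rle(ds):
  # Three leading single-step shifts collapse into one '3' token and the
  # trailing shifts are trimmed entirely: [3, 161]. Runs longer than
  # max_shift_steps are emitted in chunks, e.g. 202 shifts -> 100, 100, 2.
  print(ex['targets'].numpy())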
mt3/run_length_encoding_test.py
ADDED
@@ -0,0 +1,107 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for run_length_encoding."""

from mt3 import event_codec
from mt3 import run_length_encoding

import note_seq
import numpy as np
import seqio
import tensorflow as tf

assert_dataset = seqio.test_utils.assert_dataset
codec = event_codec.Codec(
    max_shift_steps=100,
    steps_per_second=100,
    event_ranges=[
        event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                               note_seq.MAX_MIDI_PITCH),
        event_codec.EventRange('velocity', 0, 127),
        event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                               note_seq.MAX_MIDI_PITCH),
        event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                               note_seq.MAX_MIDI_PROGRAM),
        event_codec.EventRange('tie', 0, 0)
    ])
run_length_encode_shifts = run_length_encoding.run_length_encode_shifts_fn(
    codec=codec)


class RunLengthEncodingTest(tf.test.TestCase):

  def test_remove_redundant_state_changes(self):
    og_dataset = tf.data.Dataset.from_tensors({
        'targets': [3, 525, 356, 161, 2, 525, 356, 161, 355, 394]
    })

    assert_dataset(
        run_length_encoding.remove_redundant_state_changes_fn(
            codec=codec,
            state_change_event_types=['velocity', 'program'])(og_dataset),
        {
            'targets': [3, 525, 356, 161, 2, 161, 355, 394],
        })

  def test_run_length_encode_shifts(self):
    og_dataset = tf.data.Dataset.from_tensors({
        'targets': [1, 1, 1, 161, 1, 1, 1, 162, 1, 1, 1]
    })

    assert_dataset(
        run_length_encode_shifts(og_dataset),
        {
            'targets': [3, 161, 6, 162],
        })

  def test_run_length_encode_shifts_beyond_max_length(self):
    og_dataset = tf.data.Dataset.from_tensors({
        'targets': [1] * 202 + [161, 1, 1, 1]
    })

    assert_dataset(
        run_length_encode_shifts(og_dataset),
        {
            'targets': [100, 100, 2, 161],
        })

  def test_run_length_encode_shifts_simultaneous(self):
    og_dataset = tf.data.Dataset.from_tensors({
        'targets': [1, 1, 1, 161, 162, 1, 1, 1]
    })

    assert_dataset(
        run_length_encode_shifts(og_dataset),
        {
            'targets': [3, 161, 162],
        })

  def test_merge_run_length_encoded_targets(self):
    # pylint: disable=bad-whitespace
    targets = np.array([
        [  3, 161, 162,   5, 163],
        [160, 164,   3, 165,   0]
    ])
    # pylint: enable=bad-whitespace
    merged_targets = run_length_encoding.merge_run_length_encoded_targets(
        targets=targets, codec=codec)
    expected_merged_targets = [
        160, 164, 3, 161, 162, 165, 5, 163
    ]
    np.testing.assert_array_equal(expected_merged_targets, merged_targets)


if __name__ == '__main__':
  tf.test.main()
mt3/scripts/dump_task.py
ADDED
@@ -0,0 +1,80 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple debugging utility for printing out task contents."""

import re

from absl import app
from absl import flags

import mt3.tasks  # pylint: disable=unused-import

import seqio
import tensorflow as tf


FLAGS = flags.FLAGS

flags.DEFINE_string("task", None, "A registered Task.")
flags.DEFINE_string("task_cache_dir", None, "Directory to use for task cache.")
flags.DEFINE_integer("max_examples", 10,
                     "Maximum number of examples (-1 for no limit).")
flags.DEFINE_string("format_string", "targets = {targets}",
                    "Format for printing examples.")
flags.DEFINE_string("split", "train",
                    "Which split of the dataset, e.g. train or validation.")
flags.DEFINE_integer("sequence_length_inputs", 256,
                     "Sequence length for inputs.")
flags.DEFINE_integer("sequence_length_targets", 1024,
                     "Sequence length for targets.")


def main(_):
  if FLAGS.task_cache_dir:
    seqio.add_global_cache_dirs([FLAGS.task_cache_dir])

  task = seqio.get_mixture_or_task(FLAGS.task)

  ds = task.get_dataset(
      sequence_length={
          "inputs": FLAGS.sequence_length_inputs,
          "targets": FLAGS.sequence_length_targets,
      },
      split=FLAGS.split,
      use_cached=bool(FLAGS.task_cache_dir),
      shuffle=False)

  keys = re.findall(r"{([\w+]+)}", FLAGS.format_string)
  def _example_to_string(ex):
    key_to_string = {}
    for k in keys:
      if k in ex:
        v = ex[k].numpy().tolist()
        key_to_string[k] = task.output_features[k].vocabulary.decode(v)
      else:
        key_to_string[k] = ""
    return FLAGS.format_string.format(**key_to_string)

  for ex in ds.take(FLAGS.max_examples):
    for k, v in ex.items():
      print(f"{k}: {tf.shape(v)}")
    print(_example_to_string(ex))
    print()


if __name__ == "__main__":
  flags.mark_flags_as_required(["task"])

  app.run(main)
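
For reference, a hypothetical invocation of this script plus a sketch of the equivalent seqio calls (the task name placeholder is illustrative only; real names are registered by tasks.py below):

# python -m mt3.scripts.dump_task --task=<registered_task_name> --max_examples=2
import mt3.tasks  # registers transcription tasks  # pylint: disable=unused-import
import seqio

task = seqio.get_mixture_or_task('<registered_task_name>')  # hypothetical name
ds = task.get_dataset(
    sequence_length={'inputs': 256, 'targets': 1024},
    split='train', shuffle=False)
for ex in ds.take(1):
  # Decode the target tokens back into a readable event string.
  print(task.output_features['targets'].vocabulary.decode(
      ex['targets'].numpy().tolist()))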
mt3/scripts/extract_monophonic_examples.py
ADDED
@@ -0,0 +1,251 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Detect monophonic tracks and extract notes."""

import collections
import os

from absl import app
from absl import flags
from absl import logging

import ddsp
import librosa
import note_seq
import numpy as np
import scipy
import tensorflow as tf


_INPUT_DIR = flags.DEFINE_string(
    'input_dir', None,
    'Input directory containing WAV files.')
_OUTPUT_TFRECORD_PATH = flags.DEFINE_string(
    'output_tfrecord_path', None,
    'Path to the output TFRecord containing tf.train.Example protos with '
    'monophonic tracks and inferred NoteSequence protos.')


CREPE_SAMPLE_RATE = 16000
CREPE_FRAME_RATE = 100

MONOPHONIC_CONFIDENCE_THRESHOLD = 0.95  # confidence must be greater than this
MONOPHONIC_CONFIDENCE_FRAC = 0.2  # for this fraction of frames

# split input audio into clips
CLIP_LENGTH_SECONDS = 5


def is_monophonic_heuristic(f0_confidence):
  """Heuristic to check for monophonicity using f0 confidence."""
  return (np.sum(f0_confidence >= MONOPHONIC_CONFIDENCE_THRESHOLD) /
          len(f0_confidence) >= MONOPHONIC_CONFIDENCE_FRAC)


# HMM parameters for modeling notes and F0 tracks.
F0_MIDI_SIGMA = 0.2
OCTAVE_ERROR_PROB = 0.05
NOTES_PER_SECOND = 2
NOTE_CHANGE_PROB = NOTES_PER_SECOND / CREPE_FRAME_RATE
F0_CONFIDENCE_EXP = 7.5


def f0_hmm_matrices(f0_hz, f0_confidence):
  """Observation and transition matrices for hidden Markov model of F0."""
  f0_midi = librosa.hz_to_midi(f0_hz)
  f0_midi_diff = f0_midi[:, np.newaxis] - np.arange(128)[np.newaxis, :]

  # Compute the probability of each pitch at each frame, taking octave errors
  # into account.
  f0_midi_prob_octave_correct = scipy.stats.norm.pdf(
      f0_midi_diff, scale=F0_MIDI_SIGMA)
  f0_midi_prob_octave_low = scipy.stats.norm.pdf(
      f0_midi_diff + 12, scale=F0_MIDI_SIGMA)
  f0_midi_prob_octave_high = scipy.stats.norm.pdf(
      f0_midi_diff - 12, scale=F0_MIDI_SIGMA)

  # distribution of pitch values given note
  f0_midi_loglik = ((1 - OCTAVE_ERROR_PROB) * f0_midi_prob_octave_correct +
                    0.5 * OCTAVE_ERROR_PROB * f0_midi_prob_octave_low +
                    0.5 * OCTAVE_ERROR_PROB * f0_midi_prob_octave_high)
  # (uniform) distribution of pitch values given rest
  f0_midi_rest_loglik = -np.log(128)

  # Here we interpret confidence, after adjusting by exponent, as P(not rest).
  f0_confidence_prob = np.power(f0_confidence, F0_CONFIDENCE_EXP)[:, np.newaxis]

  obs_loglik = np.concatenate([
      # probability of note (normalized by number of possible notes)
      f0_midi_loglik + np.log(f0_confidence_prob) - np.log(128),
      # probability of rest
      f0_midi_rest_loglik + np.log(1.0 - f0_confidence_prob)
  ], axis=1)

  # Normalize to adjust P(confidence | note) by uniform P(note).
  # TODO(iansimon): Not sure how correct this is but it doesn't affect the path.
  obs_loglik += np.log(129)

  trans_prob = ((NOTE_CHANGE_PROB / 128) * np.ones(129) +
                (1 - NOTE_CHANGE_PROB - NOTE_CHANGE_PROB / 128) * np.eye(129))
  trans_loglik = np.log(trans_prob)

  return obs_loglik, trans_loglik


def hmm_forward(obs_loglik, trans_loglik):
  """Forward algorithm for a hidden Markov model."""
  n, k = obs_loglik.shape
  trans = np.exp(trans_loglik)

  loglik = 0.0

  l = obs_loglik[0] - np.log(k)
  c = scipy.special.logsumexp(l)
  loglik += c

  for i in range(1, n):
    p = np.exp(l - c)
    l = np.log(np.dot(p, trans)) + obs_loglik[i]
    c = scipy.special.logsumexp(l)
    loglik += c

  return loglik


def hmm_viterbi(obs_loglik, trans_loglik):
  """Viterbi algorithm for a hidden Markov model."""
  n, k = obs_loglik.shape

  loglik_matrix = np.zeros_like(obs_loglik)
  path_matrix = np.zeros_like(obs_loglik, dtype=np.int32)

  loglik_matrix[0, :] = obs_loglik[0, :] - np.log(k)

  for i in range(1, n):
    mat = np.tile(loglik_matrix[i - 1][:, np.newaxis], [1, 129]) + trans_loglik
    path_matrix[i, :] = mat.argmax(axis=0)
    loglik_matrix[i, :] = mat[path_matrix[i, :], range(129)] + obs_loglik[i]

  path = [np.argmax(loglik_matrix[-1])]
  for i in range(n, 1, -1):
    path.append(path_matrix[i - 1, path[-1]])

  return [(pitch if pitch < 128 else None) for pitch in path[::-1]]


def pitches_to_notesequence(pitches):
  """Convert sequence of pitches output by Viterbi to NoteSequence proto."""
  ns = note_seq.NoteSequence(ticks_per_quarter=220)
  current_pitch = None
  start_time = None
  for frame, pitch in enumerate(pitches):
    time = frame / CREPE_FRAME_RATE
    if pitch != current_pitch:
      if current_pitch is not None:
        ns.notes.add(
            pitch=current_pitch, velocity=100,
            start_time=start_time, end_time=time)
      current_pitch = pitch
      start_time = time
  if current_pitch is not None:
    ns.notes.add(
        pitch=current_pitch, velocity=100,
        start_time=start_time, end_time=len(pitches) / CREPE_FRAME_RATE)
  if ns.notes:
    ns.total_time = ns.notes[-1].end_time
  return ns


# Per-frame log likelihood threshold below which an F0 track will be discarded.
# Note that this is dependent on the HMM parameters specified above, so if those
# change then this threshold should also change.
PER_FRAME_LOGLIK_THRESHOLD = 0.3


def extract_note_sequence(crepe, samples, counters):
  """Use CREPE to attempt to extract a monophonic NoteSequence from audio."""
  f0_hz, f0_confidence = crepe.predict_f0_and_confidence(
      samples[np.newaxis, :], viterbi=False)

  f0_hz = f0_hz[0].numpy()
  f0_confidence = f0_confidence[0].numpy()

  if not is_monophonic_heuristic(f0_confidence):
    counters['not_monophonic'] += 1
    return None

  obs_loglik, trans_loglik = f0_hmm_matrices(f0_hz, f0_confidence)

  loglik = hmm_forward(obs_loglik, trans_loglik)
  if loglik / len(obs_loglik) < PER_FRAME_LOGLIK_THRESHOLD:
    counters['low_likelihood'] += 1
    return None

  pitches = hmm_viterbi(obs_loglik, trans_loglik)
  ns = pitches_to_notesequence(pitches)

  counters['extracted_monophonic_sequence'] += 1
  return ns


def process_wav_file(wav_filename, crepe, counters):
  """Extract monophonic transcription examples from a WAV file."""
  wav_data = tf.io.gfile.GFile(wav_filename, 'rb').read()
  samples = note_seq.audio_io.wav_data_to_samples_librosa(
      wav_data, sample_rate=CREPE_SAMPLE_RATE)
  clip_length_samples = int(CREPE_SAMPLE_RATE * CLIP_LENGTH_SECONDS)
  for start_sample in range(0, len(samples), clip_length_samples):
    clip_samples = samples[start_sample:start_sample + clip_length_samples]
    if len(clip_samples) < clip_length_samples:
      clip_samples = np.pad(
          clip_samples, [(0, clip_length_samples - len(clip_samples))])
    ns = extract_note_sequence(crepe, clip_samples, counters)
    if ns:
      feature = {
          'audio': tf.train.Feature(
              float_list=tf.train.FloatList(value=clip_samples.tolist())),
          'filename': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[wav_filename.encode()])),
          'offset': tf.train.Feature(
              int64_list=tf.train.Int64List(value=[start_sample])),
          'sampling_rate': tf.train.Feature(
              float_list=tf.train.FloatList(value=[CREPE_SAMPLE_RATE])),
          'sequence': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[ns.SerializeToString()]))
      }
      yield tf.train.Example(features=tf.train.Features(feature=feature))


def main(unused_argv):
  flags.mark_flags_as_required(['input_dir', 'output_tfrecord_path'])
  crepe = ddsp.spectral_ops.PretrainedCREPE('full')
  counters = collections.defaultdict(int)
  with tf.io.TFRecordWriter(_OUTPUT_TFRECORD_PATH.value) as writer:
    for filename in tf.io.gfile.listdir(_INPUT_DIR.value):
      if not filename.endswith('.wav'):
        logging.info('skipping %s...', filename)
        counters['non_wav_files_skipped'] += 1
        continue
      logging.info('processing %s...', filename)
      for ex in process_wav_file(
          os.path.join(_INPUT_DIR.value, filename), crepe, counters):
        writer.write(ex.SerializeToString())
      counters['wav_files_processed'] += 1
  for k, v in counters.items():
    logging.info('COUNTER: %s = %d', k, v)


if __name__ == '__main__':
  app.run(main)
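
As a toy example of the note extraction step (a sketch, assuming the definitions above are in scope): pitches_to_notesequence turns a Viterbi pitch path into notes, with None frames treated as rests and each frame lasting 1/CREPE_FRAME_RATE = 10 ms.

pitches = [None, 60, 60, 60, None, 62, 62]
ns = pitches_to_notesequence(pitches)
for note in ns.notes:
  print(note.pitch, note.start_time, note.end_time)
# Prints:
#   60 0.01 0.04
#   62 0.05 0.07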
mt3/spectrograms.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Audio spectrogram functions."""

import dataclasses

from ddsp import spectral_ops
import tensorflow as tf

# defaults for spectrogram config
DEFAULT_SAMPLE_RATE = 16000
DEFAULT_HOP_WIDTH = 128
DEFAULT_NUM_MEL_BINS = 512

# fixed constants; add these to SpectrogramConfig before changing
FFT_SIZE = 2048
MEL_LO_HZ = 20.0


@dataclasses.dataclass
class SpectrogramConfig:
  """Spectrogram configuration parameters."""
  sample_rate: int = DEFAULT_SAMPLE_RATE
  hop_width: int = DEFAULT_HOP_WIDTH
  num_mel_bins: int = DEFAULT_NUM_MEL_BINS

  @property
  def abbrev_str(self):
    s = ''
    if self.sample_rate != DEFAULT_SAMPLE_RATE:
      s += 'sr%d' % self.sample_rate
    if self.hop_width != DEFAULT_HOP_WIDTH:
      s += 'hw%d' % self.hop_width
    if self.num_mel_bins != DEFAULT_NUM_MEL_BINS:
      s += 'mb%d' % self.num_mel_bins
    return s

  @property
  def frames_per_second(self):
    return self.sample_rate / self.hop_width


def split_audio(samples, spectrogram_config):
  """Split audio into frames."""
  return tf.signal.frame(
      samples,
      frame_length=spectrogram_config.hop_width,
      frame_step=spectrogram_config.hop_width,
      pad_end=True)


def compute_spectrogram(samples, spectrogram_config):
  """Compute a mel spectrogram."""
  overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE)
  return spectral_ops.compute_logmel(
      samples,
      bins=spectrogram_config.num_mel_bins,
      lo_hz=MEL_LO_HZ,
      overlap=overlap,
      fft_size=FFT_SIZE,
      sample_rate=spectrogram_config.sample_rate)


def flatten_frames(frames):
  """Convert frames back into a flat array of samples."""
  return tf.reshape(frames, [-1])


def input_depth(spectrogram_config):
  return spectrogram_config.num_mel_bins
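
With the defaults above, each frame covers hop_width=128 samples at 16 kHz, i.e. 125 frames per second with 512 mel bins each. A minimal sketch (shapes assume ddsp's default end-padding):

import tensorflow as tf
from mt3 import spectrograms

config = spectrograms.SpectrogramConfig()
print(config.frames_per_second)  # 16000 / 128 = 125.0
samples = tf.zeros(16000)  # one second of audio
print(spectrograms.split_audio(samples, config).shape)          # (125, 128)
print(spectrograms.compute_spectrogram(samples, config).shape)  # (125, 512)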
mt3/summaries.py
ADDED
@@ -0,0 +1,471 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorBoard summaries and utilities."""

from typing import Any, Mapping, Optional, Sequence, Tuple

import librosa

from mt3 import note_sequences
from mt3 import spectrograms

import note_seq
from note_seq import midi_synth
from note_seq import sequences_lib
from note_seq.protobuf import music_pb2

import numpy as np
import seqio


_DEFAULT_AUDIO_SECONDS = 30.0
_DEFAULT_PIANOROLL_FRAMES_PER_SECOND = 15

# TODO(iansimon): pick a SoundFont; for some reason the default is all organ


def _extract_example_audio(
    examples: Sequence[Mapping[str, Any]],
    sample_rate: float,
    num_seconds: float,
    audio_key: str = 'raw_inputs'
) -> np.ndarray:
  """Extract audio from examples.

  Args:
    examples: List of examples containing raw audio.
    sample_rate: Number of samples per second.
    num_seconds: Number of seconds of audio to include.
    audio_key: Dictionary key for the raw audio.

  Returns:
    An n-by-num_samples numpy array of samples.
  """
  n = len(examples)
  num_samples = round(num_seconds * sample_rate)
  all_samples = np.zeros([n, num_samples])
  for i, ex in enumerate(examples):
    samples = ex[audio_key][:num_samples]
    all_samples[i, :len(samples)] = samples
  return all_samples


def _example_to_note_sequence(
    example: Mapping[str, Sequence[float]],
    ns_feature_name: str,
    note_onset_feature_name: str,
    note_offset_feature_name: str,
    note_frequency_feature_name: str,
    note_confidence_feature_name: str,
    num_seconds: float
) -> music_pb2.NoteSequence:
  """Extract NoteSequence from example."""
  if ns_feature_name:
    ns = example[ns_feature_name]

  else:
    onset_times = np.array(example[note_onset_feature_name])
    pitches = librosa.hz_to_midi(
        example[note_frequency_feature_name]).round().astype(int)
    assert len(onset_times) == len(pitches)

    if note_offset_feature_name or note_confidence_feature_name:
      offset_times = (
          example[note_offset_feature_name]
          if note_offset_feature_name
          else onset_times + note_sequences.DEFAULT_NOTE_DURATION
      )
      assert len(onset_times) == len(offset_times)

      confidences = (np.array(example[note_confidence_feature_name])
                     if note_confidence_feature_name else None)
      velocities = np.ceil(
          note_seq.MAX_MIDI_VELOCITY * confidences if confidences is not None
          else note_sequences.DEFAULT_VELOCITY * np.ones_like(onset_times)
      ).astype(int)
      assert len(onset_times) == len(velocities)

      ns = note_sequences.note_arrays_to_note_sequence(
          onset_times=onset_times, offset_times=offset_times,
          pitches=pitches, velocities=velocities)

    else:
      ns = note_sequences.note_arrays_to_note_sequence(
          onset_times=onset_times, pitches=pitches)

  return sequences_lib.trim_note_sequence(ns, 0, num_seconds)


def _synthesize_example_notes(
    examples: Sequence[Mapping[str, Sequence[float]]],
    ns_feature_name: str,
    note_onset_feature_name: str,
    note_offset_feature_name: str,
    note_frequency_feature_name: str,
    note_confidence_feature_name: str,
    sample_rate: float,
    num_seconds: float,
) -> np.ndarray:
  """Synthesize example notes to audio.

  Args:
    examples: List of example dictionaries, containing either serialized
      NoteSequence protos or note onset times and pitches.
    ns_feature_name: Name of serialized NoteSequence feature.
    note_onset_feature_name: Name of note onset times feature.
    note_offset_feature_name: Name of note offset times feature.
    note_frequency_feature_name: Name of note frequencies feature.
    note_confidence_feature_name: Name of note confidences (velocities) feature.
    sample_rate: Sample rate at which to synthesize.
    num_seconds: Number of seconds to synthesize for each example.

  Returns:
    An n-by-num_samples numpy array of samples.
  """
  if (ns_feature_name is not None) == (note_onset_feature_name is not None):
    raise ValueError(
        'must specify exactly one of NoteSequence feature and onset feature')

  n = len(examples)
  num_samples = round(num_seconds * sample_rate)

  all_samples = np.zeros([n, num_samples])

  for i, ex in enumerate(examples):
    ns = _example_to_note_sequence(
        ex,
        ns_feature_name=ns_feature_name,
        note_onset_feature_name=note_onset_feature_name,
        note_offset_feature_name=note_offset_feature_name,
        note_frequency_feature_name=note_frequency_feature_name,
        note_confidence_feature_name=note_confidence_feature_name,
        num_seconds=num_seconds)
    fluidsynth = midi_synth.fluidsynth
    samples = fluidsynth(ns, sample_rate=sample_rate)
    if len(samples) > num_samples:
      samples = samples[:num_samples]
    all_samples[i, :len(samples)] = samples

  return all_samples


def _examples_to_pianorolls(
    targets: Sequence[Mapping[str, Sequence[float]]],
    predictions: Sequence[Mapping[str, Sequence[float]]],
    ns_feature_suffix: str,
    note_onset_feature_suffix: str,
    note_offset_feature_suffix: str,
    note_frequency_feature_suffix: str,
    note_confidence_feature_suffix: str,
    track_specs: Optional[Sequence[note_sequences.TrackSpec]],
    num_seconds: float,
    frames_per_second: float
) -> Tuple[np.ndarray, np.ndarray]:
  """Generate pianoroll images from example notes.

  Args:
    targets: List of target dictionaries, containing either serialized
      NoteSequence protos or note onset times and pitches.
    predictions: List of prediction dictionaries, containing either serialized
      NoteSequence protos or note onset times and pitches.
    ns_feature_suffix: Suffix of serialized NoteSequence feature.
    note_onset_feature_suffix: Suffix of note onset times feature.
    note_offset_feature_suffix: Suffix of note offset times feature.
    note_frequency_feature_suffix: Suffix of note frequencies feature.
    note_confidence_feature_suffix: Suffix of note confidences (velocities)
      feature.
    track_specs: Optional list of TrackSpec objects to indicate a set of tracks
      into which each NoteSequence should be split. Tracks will be stacked
      vertically in the pianorolls
    num_seconds: Number of seconds to show for each example.
    frames_per_second: Number of pianoroll frames per second.

  Returns:
    onset_pianorolls: An n-by-num_pitches-by-num_frames-by-4 numpy array of
      pianoroll images showing only onsets.
    full_pianorolls: An n-by-num_pitches-by-num_frames-by-4 numpy array of
      pianoroll images.
  """
  if (ns_feature_suffix is not None) == (note_onset_feature_suffix is not None):
    raise ValueError(
        'must specify exactly one of NoteSequence feature and onset feature')

  def ex_to_ns(example, prefix):
    return _example_to_note_sequence(
        example=example,
        ns_feature_name=(prefix + ns_feature_suffix
                         if ns_feature_suffix else None),
        note_onset_feature_name=(prefix + note_onset_feature_suffix
                                 if note_onset_feature_suffix else None),
        note_offset_feature_name=(prefix + note_offset_feature_suffix
                                  if note_offset_feature_suffix else None),
        note_frequency_feature_name=(
            prefix + note_frequency_feature_suffix
            if note_frequency_feature_suffix else None),
        note_confidence_feature_name=(
            prefix + note_confidence_feature_suffix
            if note_confidence_feature_suffix else None),
        num_seconds=num_seconds)

  n = len(targets)
  num_pitches = note_seq.MAX_MIDI_PITCH - note_seq.MIN_MIDI_PITCH + 1
  num_frames = round(num_seconds * frames_per_second)
  num_tracks = len(track_specs) if track_specs else 1
  pianoroll_height = num_tracks * num_pitches + (num_tracks - 1)

  onset_images = np.zeros([n, pianoroll_height, num_frames, 3])
  full_images = np.zeros([n, pianoroll_height, num_frames, 3])

  for i, (target, pred) in enumerate(zip(targets, predictions)):
    target_ns, pred_ns = [
        ex_to_ns(ex, prefix)
        for (ex, prefix) in [(target, 'ref_'), (pred, 'est_')]
    ]

    # Show lines at frame boundaries. To ensure that these lines are drawn with
    # the same downsampling and frame selection logic as the real NoteSequences,
    # use this hack to draw the lines with a NoteSequence that contains notes
    # across all pitches at all frame start times.
    start_times_ns = note_seq.NoteSequence()
    start_times_ns.CopyFrom(target_ns)
    del start_times_ns.notes[:]
    for start_time in pred['start_times']:
      if start_time < target_ns.total_time:
        for pitch in range(
            note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH + 1):
          start_times_ns.notes.add(
              pitch=pitch,
              velocity=100,
              start_time=start_time,
              end_time=start_time + (1 / frames_per_second))

    start_time_roll = sequences_lib.sequence_to_pianoroll(
        start_times_ns,
        frames_per_second=frames_per_second,
        min_pitch=note_seq.MIN_MIDI_PITCH,
        max_pitch=note_seq.MAX_MIDI_PITCH,
        onset_mode='length_ms')
    num_start_time_frames = min(len(start_time_roll.onsets), num_frames)

    if track_specs is not None:
      target_tracks = [note_sequences.extract_track(target_ns,
                                                    spec.program, spec.is_drum)
                       for spec in track_specs]
      pred_tracks = [note_sequences.extract_track(pred_ns,
                                                  spec.program, spec.is_drum)
                     for spec in track_specs]
    else:
      target_tracks = [target_ns]
      pred_tracks = [pred_ns]

    for j, (target_track, pred_track) in enumerate(zip(target_tracks[::-1],
                                                       pred_tracks[::-1])):
      target_roll = sequences_lib.sequence_to_pianoroll(
          target_track,
          frames_per_second=frames_per_second,
          min_pitch=note_seq.MIN_MIDI_PITCH,
          max_pitch=note_seq.MAX_MIDI_PITCH,
          onset_mode='length_ms')
      pred_roll = sequences_lib.sequence_to_pianoroll(
          pred_track,
          frames_per_second=frames_per_second,
          min_pitch=note_seq.MIN_MIDI_PITCH,
          max_pitch=note_seq.MAX_MIDI_PITCH,
          onset_mode='length_ms')

      num_target_frames = min(len(target_roll.onsets), num_frames)
      num_pred_frames = min(len(pred_roll.onsets), num_frames)

      start_offset = j * (num_pitches + 1)
      end_offset = (j + 1) * (num_pitches + 1) - 1

      # Onsets
      onset_images[
          i, start_offset:end_offset, :num_start_time_frames, 0
      ] = start_time_roll.onsets[:num_start_time_frames, :].T
      onset_images[
          i, start_offset:end_offset, :num_target_frames, 1
      ] = target_roll.onsets[:num_target_frames, :].T
      onset_images[
          i, start_offset:end_offset, :num_pred_frames, 2
      ] = pred_roll.onsets[:num_pred_frames, :].T

      # Full notes
      full_images[
          i, start_offset:end_offset, :num_start_time_frames, 0
      ] = start_time_roll.onsets[:num_start_time_frames, :].T
      full_images[
          i, start_offset:end_offset, :num_target_frames, 1
      ] = target_roll.active[:num_target_frames, :].T
      full_images[
          i, start_offset:end_offset, :num_pred_frames, 2
      ] = pred_roll.active[:num_pred_frames, :].T

      # Add separator between tracks.
      if j < num_tracks - 1:
        onset_images[i, end_offset, :, 0] = 1
        full_images[i, end_offset, :, 0] = 1

  return onset_images[:, ::-1, :, :], full_images[:, ::-1, :, :]


def prettymidi_pianoroll(
    track_pianorolls: Mapping[str, Sequence[Tuple[np.ndarray, np.ndarray]]],
    fps: float,
    num_seconds=_DEFAULT_AUDIO_SECONDS
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Create summary from given pianorolls."""
  max_len = int(num_seconds * fps)
  summaries = {}
  for inst_name, all_prs in track_pianorolls.items():

    est_prs, ref_prs = zip(*all_prs)

    bs = len(ref_prs)
    pianoroll_image_batch = np.zeros(shape=(bs, 128, max_len, 3))
    for i in range(bs):
      ref_pr = ref_prs[i][:, :max_len]
      est_pr = est_prs[i][:, :max_len]

      pianoroll_image_batch[i, :, :est_pr.shape[1], 2] = est_pr
      pianoroll_image_batch[i, :, :ref_pr.shape[1], 1] = ref_pr
    if not inst_name:
      inst_name = 'all instruments'

    summaries[f'{inst_name} pretty_midi pianoroll'] = seqio.metrics.Image(
        image=pianoroll_image_batch, max_outputs=bs)

  return summaries


def audio_summaries(
    targets: Sequence[Mapping[str, Sequence[float]]],
    predictions: Sequence[Mapping[str, Sequence[float]]],
    spectrogram_config: spectrograms.SpectrogramConfig,
    num_seconds: float = _DEFAULT_AUDIO_SECONDS
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Compute audio summaries for a list of examples.

  Args:
    targets: List of targets, unused as we pass the input audio tokens via
      predictions.
    predictions: List of predictions, including input audio tokens.
    spectrogram_config: Spectrogram configuration.
    num_seconds: Number of seconds of audio to include in the summaries.
      Longer audio will be cropped (from the beginning), shorter audio will be
      padded with silence (at the end).

  Returns:
    A dictionary mapping "audio" to the audio summaries.
  """
  del targets
  samples = _extract_example_audio(
      examples=predictions,
      sample_rate=spectrogram_config.sample_rate,
      num_seconds=num_seconds)
  return {
      'audio': seqio.metrics.Audio(
          audiodata=samples[:, :, np.newaxis],
          sample_rate=spectrogram_config.sample_rate,
          max_outputs=samples.shape[0])
  }


def transcription_summaries(
    targets: Sequence[Mapping[str, Sequence[float]]],
    predictions: Sequence[Mapping[str, Sequence[float]]],
    spectrogram_config: spectrograms.SpectrogramConfig,
    ns_feature_suffix: Optional[str] = None,
    note_onset_feature_suffix: Optional[str] = None,
    note_offset_feature_suffix: Optional[str] = None,
    note_frequency_feature_suffix: Optional[str] = None,
    note_confidence_feature_suffix: Optional[str] = None,
    track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None,
    num_seconds: float = _DEFAULT_AUDIO_SECONDS,
    pianoroll_frames_per_second: float = _DEFAULT_PIANOROLL_FRAMES_PER_SECOND,
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Compute note transcription summaries for multiple examples.

  Args:
    targets: List of targets containing ground truth.
    predictions: List of predictions, including raw input audio.
    spectrogram_config: The spectrogram configuration.
    ns_feature_suffix: Suffix of serialized NoteSequence feature.
    note_onset_feature_suffix: Suffix of note onset times feature.
    note_offset_feature_suffix: Suffix of note offset times feature.
    note_frequency_feature_suffix: Suffix of note frequencies feature.
    note_confidence_feature_suffix: Suffix of note confidences (velocities)
      feature.
    track_specs: Optional list of TrackSpec objects to indicate a set of tracks
      into which each NoteSequence should be split.
    num_seconds: Number of seconds of audio to include in the summaries.
      Longer audio will be cropped (from the beginning), shorter audio will be
      padded with silence (at the end).
    pianoroll_frames_per_second: Temporal resolution of pianoroll images.

  Returns:
    A dictionary of input, ground truth, and transcription summaries.
  """
  audio_samples = _extract_example_audio(
      examples=predictions,
      sample_rate=spectrogram_config.sample_rate,
      num_seconds=num_seconds)

  def synthesize(examples, prefix):
    return _synthesize_example_notes(
        examples=examples,
        ns_feature_name=(prefix + ns_feature_suffix
                         if ns_feature_suffix else None),
        note_onset_feature_name=(prefix + note_onset_feature_suffix
                                 if note_onset_feature_suffix else None),
        note_offset_feature_name=(prefix + note_offset_feature_suffix
                                  if note_offset_feature_suffix else None),
        note_frequency_feature_name=(
            prefix + note_frequency_feature_suffix
            if note_frequency_feature_suffix else None),
        note_confidence_feature_name=(
            prefix + note_confidence_feature_suffix
            if note_confidence_feature_suffix else None),
        sample_rate=spectrogram_config.sample_rate,
        num_seconds=num_seconds)

  synthesized_predictions = synthesize(predictions, 'est_')

  onset_pianoroll_images, full_pianoroll_images = _examples_to_pianorolls(
      targets=targets,
      predictions=predictions,
      ns_feature_suffix=ns_feature_suffix,
      note_onset_feature_suffix=note_onset_feature_suffix,
      note_offset_feature_suffix=note_offset_feature_suffix,
      note_frequency_feature_suffix=note_frequency_feature_suffix,
      note_confidence_feature_suffix=note_confidence_feature_suffix,
      track_specs=track_specs,
      num_seconds=num_seconds,
      frames_per_second=pianoroll_frames_per_second)

  return {
      'input_with_transcription': seqio.metrics.Audio(
          audiodata=np.stack([audio_samples, synthesized_predictions], axis=2),
          sample_rate=spectrogram_config.sample_rate,
          max_outputs=audio_samples.shape[0]),

      'pianoroll': seqio.metrics.Image(
          image=full_pianoroll_images,
          max_outputs=full_pianoroll_images.shape[0]),

      'onset_pianoroll': seqio.metrics.Image(
          image=onset_pianoroll_images,
          max_outputs=onset_pianoroll_images.shape[0]),
  }
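
To make the pianoroll packing concrete, a small sketch (the random rolls are placeholders, not real transcriptions): prettymidi_pianoroll puts the estimated roll in the blue channel and the reference in green, so agreement renders as cyan.

import numpy as np
from mt3 import summaries

ref = np.random.rand(128, 200)  # (pitches, frames) at the fps below
est = np.random.rand(128, 180)
s = summaries.prettymidi_pianoroll(
    {'piano': [(est, ref)]}, fps=62.5, num_seconds=4.0)
img = s['piano pretty_midi pianoroll'].image
print(img.shape)  # (1, 128, 250, 3); max_len = int(4.0 * 62.5) = 250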
mt3/tasks.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transcription task definitions."""

import functools
from typing import Optional, Sequence

from mt3 import datasets
from mt3 import event_codec
from mt3 import metrics
from mt3 import mixing
from mt3 import preprocessors
from mt3 import run_length_encoding
from mt3 import spectrograms
from mt3 import vocabularies

import note_seq
import numpy as np
import seqio
import t5
import tensorflow as tf

# Split audio frame sequences into this length before the cache placeholder.
MAX_NUM_CACHED_FRAMES = 2000

seqio.add_global_cache_dirs(['gs://mt3/data/cache_tasks/'])


def construct_task_name(
    task_prefix: str,
    spectrogram_config=spectrograms.SpectrogramConfig(),
    vocab_config=vocabularies.VocabularyConfig(),
    task_suffix: Optional[str] = None
) -> str:
  """Construct task name from prefix, config, and optional suffix."""
  fields = [task_prefix]
  if spectrogram_config.abbrev_str:
    fields.append(spectrogram_config.abbrev_str)
  if vocab_config.abbrev_str:
    fields.append(vocab_config.abbrev_str)
  if task_suffix:
    fields.append(task_suffix)
  return '_'.join(fields)


def trim_eos(tokens: Sequence[int]) -> np.ndarray:
  """If EOS is present, remove it and everything after."""
  tokens = np.array(tokens, np.int32)
  if vocabularies.DECODED_EOS_ID in tokens:
    tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
  return tokens


def postprocess(tokens, example, is_target, codec):
  """Transcription postprocessing function."""
  tokens = trim_eos(tokens)

  if is_target:
    return {
        'unique_id': example['unique_id'][0],
        'ref_ns': (note_seq.NoteSequence.FromString(example['sequence'][0])
                   if example['sequence'][0] else None),
        'ref_tokens': tokens,
    }

  start_time = example['input_times'][0]
  # Round down to nearest symbolic token step.
  start_time -= start_time % (1 / codec.steps_per_second)

  return {
      'unique_id': example['unique_id'][0],
      'raw_inputs': example['raw_inputs'],
      'est_tokens': tokens,
      'start_time': start_time
  }


def add_transcription_task_to_registry(
    dataset_config: datasets.DatasetConfig,
    spectrogram_config: spectrograms.SpectrogramConfig,
    vocab_config: vocabularies.VocabularyConfig,
    tokenize_fn,  # TODO(iansimon): add type signature
    onsets_only: bool,
    include_ties: bool,
    skip_too_long: bool = False
) -> None:
  """Add note transcription task to seqio.TaskRegistry."""
  codec = vocabularies.build_codec(vocab_config)
  vocabulary = vocabularies.vocabulary_from_codec(codec)

  output_features = {
      'targets': seqio.Feature(vocabulary=vocabulary),
      'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2)
  }

  task_name = 'onsets' if onsets_only else 'notes'
  if include_ties:
    task_name += '_ties'
  task_prefix = f'{dataset_config.name}_{task_name}'

  train_task_name = construct_task_name(
      task_prefix=task_prefix,
      spectrogram_config=spectrogram_config,
      vocab_config=vocab_config,
      task_suffix='train')

  mixture_task_names = []

  tie_token = codec.encode_event(event_codec.Event('tie', 0))
  track_specs = (dataset_config.track_specs
                 if dataset_config.track_specs else None)

  # Add transcription training task.
  seqio.TaskRegistry.add(
      train_task_name,
      source=seqio.TFExampleDataSource(
          split_to_filepattern={
              'train': dataset_config.paths[dataset_config.train_split],
              'eval': dataset_config.paths[dataset_config.train_eval_split]
          },
          feature_description=dataset_config.features),
      output_features=output_features,
      preprocessors=[
          functools.partial(
              tokenize_fn,
              spectrogram_config=spectrogram_config, codec=codec,
              is_training_data=True, onsets_only=onsets_only,
              include_ties=include_ties),
          functools.partial(
              t5.data.preprocessors.split_tokens,
              max_tokens_per_segment=MAX_NUM_CACHED_FRAMES,
              feature_key='inputs',
              additional_feature_keys=[
                  'input_event_start_indices', 'input_event_end_indices',
                  'input_state_event_indices'
              ],
              passthrough_feature_keys=['targets', 'state_events']),
          seqio.CacheDatasetPlaceholder(),
          functools.partial(
              t5.data.preprocessors.select_random_chunk,
              feature_key='inputs',
              additional_feature_keys=[
                  'input_event_start_indices', 'input_event_end_indices',
                  'input_state_event_indices'
              ],
              passthrough_feature_keys=['targets', 'state_events'],
              uniform_random_start=True),
          functools.partial(
              run_length_encoding.extract_target_sequence_with_indices,
              state_events_end_token=tie_token if include_ties else None),
          functools.partial(preprocessors.map_midi_programs, codec=codec),
          run_length_encoding.run_length_encode_shifts_fn(
              codec,
              feature_key='targets'),
          functools.partial(
              mixing.mix_transcription_examples,
              codec=codec,
              targets_feature_keys=['targets']),
          run_length_encoding.remove_redundant_state_changes_fn(
              feature_key='targets', codec=codec,
              state_change_event_types=['velocity', 'program']),
          functools.partial(
              preprocessors.compute_spectrograms,
              spectrogram_config=spectrogram_config),
          functools.partial(preprocessors.handle_too_long, skip=skip_too_long),
          functools.partial(
              seqio.preprocessors.tokenize_and_append_eos,
              copy_pretokenized=False)
      ],
      postprocess_fn=None,
      metric_fns=[],
  )

  # Add transcription eval tasks.
  for split in dataset_config.infer_eval_splits:
    eval_task_name = construct_task_name(
        task_prefix=task_prefix,
        spectrogram_config=spectrogram_config,
        vocab_config=vocab_config,
        task_suffix=split.suffix)

    if split.include_in_mixture:
      mixture_task_names.append(eval_task_name)

    seqio.TaskRegistry.add(
        eval_task_name,
        source=seqio.TFExampleDataSource(
            split_to_filepattern={'eval': dataset_config.paths[split.name]},
            feature_description=dataset_config.features),
        output_features=output_features,
        preprocessors=[
            functools.partial(
                tokenize_fn,
                spectrogram_config=spectrogram_config, codec=codec,
                is_training_data='train' in split.name, onsets_only=onsets_only,
                include_ties=include_ties),
            seqio.CacheDatasetPlaceholder(),
            preprocessors.add_unique_id,
            preprocessors.pad_notesequence_array,
            functools.partial(
                t5.data.preprocessors.split_tokens_to_inputs_length,
                feature_key='inputs',
                additional_feature_keys=['input_times', 'sequence'],
                passthrough_feature_keys=['unique_id']),
            # Add dummy targets as they are dropped during the above split to
            # avoid memory blowups, but expected to be present by seqio; the
            # evaluation metrics currently only use the target NoteSequence.
            preprocessors.add_dummy_targets,
            functools.partial(
                preprocessors.compute_spectrograms,
                spectrogram_config=spectrogram_config),
            functools.partial(preprocessors.handle_too_long, skip=False),
            functools.partial(
                seqio.preprocessors.tokenize_and_append_eos,
                copy_pretokenized=False)
        ],
        postprocess_fn=functools.partial(postprocess, codec=codec),
        metric_fns=[
            functools.partial(
                metrics.transcription_metrics,
                codec=codec,
                spectrogram_config=spectrogram_config,
                onsets_only=onsets_only,
                use_ties=include_ties,
                track_specs=track_specs)
        ],
    )

  seqio.MixtureRegistry.add(
      construct_task_name(
          task_prefix=task_prefix, spectrogram_config=spectrogram_config,
          vocab_config=vocab_config, task_suffix='eval'),
      mixture_task_names,
      default_rate=1)


# Just use default spectrogram config.
SPECTROGRAM_CONFIG = spectrograms.SpectrogramConfig()

# Create two vocabulary configs, one default and one with only on-off velocity.
VOCAB_CONFIG_FULL = vocabularies.VocabularyConfig()
VOCAB_CONFIG_NOVELOCITY = vocabularies.VocabularyConfig(num_velocity_bins=1)

# Transcribe MAESTRO v1.
add_transcription_task_to_registry(
    dataset_config=datasets.MAESTROV1_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_FULL,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=False,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=False)

# Transcribe MAESTRO v3.
add_transcription_task_to_registry(
    dataset_config=datasets.MAESTROV3_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_FULL,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=False,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=False)

# Transcribe MAESTRO v3 without velocities, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.MAESTROV3_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=False,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe GuitarSet, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.GUITARSET_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=preprocessors.tokenize_guitarset_example,
    onsets_only=False,
    include_ties=True)

# Transcribe URMP mixes, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.URMP_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_example_with_program_lookup,
        inst_name_to_program_fn=preprocessors.urmp_instrument_to_program,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe MusicNet, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.MUSICNET_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=True,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe MusicNetEM, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.MUSICNET_EM_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_transcription_example,
        audio_is_samples=True,
        id_feature_key='id'),
    onsets_only=False,
    include_ties=True)

# Transcribe Cerberus4 (piano-guitar-bass-drums quartets), with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.CERBERUS4_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_slakh_example,
        track_specs=datasets.CERBERUS4_CONFIG.track_specs,
        ignore_pitch_bends=True),
    onsets_only=False,
    include_ties=True)

# Transcribe 10 random sub-mixes of each song from Slakh, with ties.
add_transcription_task_to_registry(
    dataset_config=datasets.SLAKH_CONFIG,
    spectrogram_config=SPECTROGRAM_CONFIG,
    vocab_config=VOCAB_CONFIG_NOVELOCITY,
    tokenize_fn=functools.partial(
        preprocessors.tokenize_slakh_example,
        track_specs=None,
        ignore_pitch_bends=True),
    onsets_only=False,
    include_ties=True)


# Construct task names to include in transcription mixture.
MIXTURE_DATASET_NAMES = [
    'maestrov3', 'guitarset', 'urmp', 'musicnet_em', 'cerberus4', 'slakh'
]
MIXTURE_TRAIN_TASK_NAMES = []
MIXTURE_EVAL_TASK_NAMES = []
MIXTURE_TEST_TASK_NAMES = []
for dataset_name in MIXTURE_DATASET_NAMES:
  MIXTURE_TRAIN_TASK_NAMES.append(
      construct_task_name(task_prefix=f'{dataset_name}_notes_ties',
                          spectrogram_config=SPECTROGRAM_CONFIG,
                          vocab_config=VOCAB_CONFIG_NOVELOCITY,
                          task_suffix='train'))
  MIXTURE_EVAL_TASK_NAMES.append(
      construct_task_name(task_prefix=f'{dataset_name}_notes_ties',
                          spectrogram_config=SPECTROGRAM_CONFIG,
                          vocab_config=VOCAB_CONFIG_NOVELOCITY,
                          task_suffix='validation'))
MIXING_TEMPERATURE = 10 / 3

# Add the mixture of all transcription tasks, with ties.
seqio.MixtureRegistry.add(
    construct_task_name(
        task_prefix='mega_notes_ties',
        spectrogram_config=SPECTROGRAM_CONFIG,
        vocab_config=VOCAB_CONFIG_NOVELOCITY,
        task_suffix='train'),
    MIXTURE_TRAIN_TASK_NAMES,
    default_rate=functools.partial(
        seqio.mixing_rate_num_examples,
        temperature=MIXING_TEMPERATURE))
seqio.MixtureRegistry.add(
    construct_task_name(
        task_prefix='mega_notes_ties',
        spectrogram_config=SPECTROGRAM_CONFIG,
        vocab_config=VOCAB_CONFIG_NOVELOCITY,
        task_suffix='eval'),
    MIXTURE_EVAL_TASK_NAMES,
    default_rate=functools.partial(
        seqio.mixing_rate_num_examples,
        temperature=MIXING_TEMPERATURE))
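For context, the tasks and mixtures registered above are consumed through the standard seqio API. The following is a minimal sketch, not part of the commit; it assumes the TFRecord paths referenced by `mt3/datasets.py` are readable and uses illustrative (not production) sequence lengths:

import seqio
from mt3 import tasks  # registering tasks/mixtures is an import side effect

task_name = tasks.construct_task_name(
    task_prefix='maestrov3_notes_ties',
    spectrogram_config=tasks.SPECTROGRAM_CONFIG,
    vocab_config=tasks.VOCAB_CONFIG_NOVELOCITY,
    task_suffix='train')
task = seqio.get_mixture_or_task(task_name)

ds = task.get_dataset(
    sequence_length={'inputs': 256, 'targets': 1024},  # illustrative only
    split='train',
    use_cached=False)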
mt3/version.py
ADDED
@@ -0,0 +1,16 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MT3 version."""
__version__ = '0.0.1'
mt3/vocabularies.py
ADDED
@@ -0,0 +1,282 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model vocabulary."""

import dataclasses
import math

from typing import Callable, Optional, Sequence
from mt3 import event_codec

import note_seq
import seqio
import t5.data
import tensorflow as tf


DECODED_EOS_ID = -1
DECODED_INVALID_ID = -2

# defaults for vocabulary config
DEFAULT_STEPS_PER_SECOND = 100
DEFAULT_MAX_SHIFT_SECONDS = 10
DEFAULT_NUM_VELOCITY_BINS = 127


@dataclasses.dataclass
class VocabularyConfig:
  """Vocabulary configuration parameters."""
  steps_per_second: int = DEFAULT_STEPS_PER_SECOND
  max_shift_seconds: int = DEFAULT_MAX_SHIFT_SECONDS
  num_velocity_bins: int = DEFAULT_NUM_VELOCITY_BINS

  @property
  def abbrev_str(self):
    s = ''
    if self.steps_per_second != DEFAULT_STEPS_PER_SECOND:
      s += 'ss%d' % self.steps_per_second
    if self.max_shift_seconds != DEFAULT_MAX_SHIFT_SECONDS:
      s += 'ms%d' % self.max_shift_seconds
    if self.num_velocity_bins != DEFAULT_NUM_VELOCITY_BINS:
      s += 'vb%d' % self.num_velocity_bins
    return s


def num_velocity_bins_from_codec(codec: event_codec.Codec):
  """Get number of velocity bins from event codec."""
  lo, hi = codec.event_type_range('velocity')
  return hi - lo


def velocity_to_bin(velocity, num_velocity_bins):
  if velocity == 0:
    return 0
  else:
    return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)


def bin_to_velocity(velocity_bin, num_velocity_bins):
  if velocity_bin == 0:
    return 0
  else:
    return int(note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins)


def drop_programs(tokens, codec: event_codec.Codec):
  """Drops program change events from a token sequence."""
  min_program_id, max_program_id = codec.event_type_range('program')
  return tokens[(tokens < min_program_id) | (tokens > max_program_id)]


def programs_to_midi_classes(tokens, codec):
  """Modifies program events to be the first program in the MIDI class."""
  min_program_id, max_program_id = codec.event_type_range('program')
  is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
  return tf.where(
      is_program,
      min_program_id + 8 * ((tokens - min_program_id) // 8),
      tokens)


@dataclasses.dataclass
class ProgramGranularity:
  # both tokens_map_fn and program_map_fn should be idempotent
  tokens_map_fn: Callable[[Sequence[int], event_codec.Codec], Sequence[int]]
  program_map_fn: Callable[[int], int]


PROGRAM_GRANULARITIES = {
    # "flat" granularity; drop program change tokens and set NoteSequence
    # programs to zero
    'flat': ProgramGranularity(
        tokens_map_fn=drop_programs,
        program_map_fn=lambda program: 0),

    # map each program to the first program in its MIDI class
    'midi_class': ProgramGranularity(
        tokens_map_fn=programs_to_midi_classes,
        program_map_fn=lambda program: 8 * (program // 8)),

    # leave programs as is
    'full': ProgramGranularity(
        tokens_map_fn=lambda tokens, codec: tokens,
        program_map_fn=lambda program: program)
}


def build_codec(vocab_config: VocabularyConfig):
  """Build event codec."""
  event_ranges = [
      event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                             note_seq.MAX_MIDI_PITCH),
      # velocity bin 0 is used for note-off
      event_codec.EventRange('velocity', 0, vocab_config.num_velocity_bins),
      # used to indicate that a pitch is present at the beginning of a segment
      # (only has an "off" event as when using ties all pitch events until the
      # "tie" event belong to the tie section)
      event_codec.EventRange('tie', 0, 0),
      event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                             note_seq.MAX_MIDI_PROGRAM),
      event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                             note_seq.MAX_MIDI_PITCH),
  ]

  return event_codec.Codec(
      max_shift_steps=(vocab_config.steps_per_second *
                       vocab_config.max_shift_seconds),
      steps_per_second=vocab_config.steps_per_second,
      event_ranges=event_ranges)


def vocabulary_from_codec(codec: event_codec.Codec) -> seqio.Vocabulary:
  return GenericTokenVocabulary(
      codec.num_classes, extra_ids=t5.data.DEFAULT_EXTRA_IDS)


class GenericTokenVocabulary(seqio.Vocabulary):
  """Vocabulary with pass-through encoding of tokens."""

  def __init__(self, regular_ids: int, extra_ids: int = 0):
    # The special tokens: 0=PAD, 1=EOS, and 2=UNK
    self._num_special_tokens = 3
    self._num_regular_tokens = regular_ids
    super().__init__(extra_ids=extra_ids)

  @property
  def eos_id(self) -> Optional[int]:
    return 1

  @property
  def unk_id(self) -> Optional[int]:
    return 2

  @property
  def _base_vocab_size(self) -> int:
    """Number of ids.

    Returns:
      an integer, the vocabulary size
    """
    return self._num_special_tokens + self._num_regular_tokens

  def _encode(self, token_ids: Sequence[int]) -> Sequence[int]:
    """Encode a list of token ids as a list of integers.

    To keep the first few ids for special tokens, increase ids by the number
    of special tokens.

    Args:
      token_ids: array of token ids.

    Returns:
      a list of integers (not terminated by EOS)
    """
    encoded = []
    for token_id in token_ids:
      if not 0 <= token_id < self._num_regular_tokens:
        raise ValueError(
            f'token_id {token_id} does not fall within valid range of '
            f'[0, {self._num_regular_tokens})')
      encoded.append(token_id + self._num_special_tokens)

    return encoded

  def _decode(self, ids: Sequence[int]) -> Sequence[int]:
    """Decode a list of integers to a list of token ids.

    The special tokens of PAD and UNK as well as extra_ids will be
    replaced with DECODED_INVALID_ID in the output. If EOS is present, it will
    be the final token in the decoded output and will be represented by
    DECODED_EOS_ID.

    Args:
      ids: a list of integers

    Returns:
      a list of token ids.
    """
    # convert all the extra ids to INVALID_ID
    def _decode_id(encoded_id):
      if encoded_id == self.eos_id:
        return DECODED_EOS_ID
      elif encoded_id < self._num_special_tokens:
        return DECODED_INVALID_ID
      elif encoded_id >= self._base_vocab_size:
        return DECODED_INVALID_ID
      else:
        return encoded_id - self._num_special_tokens
    ids = [_decode_id(int(i)) for i in ids]
    return ids

  def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
    """Encode a list of tokens to a tf.Tensor.

    Args:
      token_ids: array of audio token ids.

    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    with tf.control_dependencies(
        [tf.debugging.assert_less(
            token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
         tf.debugging.assert_greater_equal(
             token_ids, tf.cast(0, token_ids.dtype))
        ]):
      tf_ids = token_ids + self._num_special_tokens
    return tf_ids

  def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
    """Decode in TensorFlow.

    The special tokens of PAD and UNK as well as extra_ids will be
    replaced with DECODED_INVALID_ID in the output. If EOS is present, it and
    all following tokens in the decoded output will be represented by
    DECODED_EOS_ID.

    Args:
      ids: a 1d tf.Tensor with dtype tf.int32

    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    # Create a mask that is true from the first EOS position onward.
    # First, create an array that is True whenever there is an EOS, then cumsum
    # that array so that every position after and including the first True is
    # >1, then cast back to bool for the final mask.
    eos_and_after = tf.cumsum(
        tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
    eos_and_after = tf.cast(eos_and_after, tf.bool)

    return tf.where(
        eos_and_after,
        DECODED_EOS_ID,
        tf.where(
            tf.logical_and(
                tf.greater_equal(ids, self._num_special_tokens),
                tf.less(ids, self._base_vocab_size)),
            ids - self._num_special_tokens,
            DECODED_INVALID_ID))

  def __eq__(self, other):
    their_extra_ids = other.extra_ids
    their_num_regular_tokens = other._num_regular_tokens
    return (self.extra_ids == their_extra_ids and
            self._num_regular_tokens == their_num_regular_tokens)


def num_embeddings(vocabulary: GenericTokenVocabulary) -> int:
  """Vocabulary size as a multiple of 128 for TPU efficiency."""
  return 128 * math.ceil(vocabulary.vocab_size / 128)
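To make the codec/vocabulary relationship above concrete, here is a small round-trip sketch, grounded in the tests below: `GenericTokenVocabulary` shifts each codec token up by the three special ids (0=PAD, 1=EOS, 2=UNK) on encode and shifts it back on decode.

from mt3 import vocabularies

config = vocabularies.VocabularyConfig(num_velocity_bins=1)
codec = vocabularies.build_codec(config)
vocab = vocabularies.vocabulary_from_codec(codec)

encoded = vocab.encode([0, 1, 2])  # -> [3, 4, 5], shifted past the specials
assert vocab.decode(encoded) == [0, 1, 2]

# Embedding table size is rounded up to a multiple of 128 for TPU efficiency.
assert vocabularies.num_embeddings(vocab) % 128 == 0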
mt3/vocabularies_test.py
ADDED
@@ -0,0 +1,114 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for vocabularies."""

from absl.testing import absltest
from mt3 import vocabularies

import numpy as np
import tensorflow.compat.v2 as tf

tf.compat.v1.enable_eager_execution()


class VocabulariesTest(absltest.TestCase):

  def test_velocity_quantization(self):
    self.assertEqual(0, vocabularies.velocity_to_bin(0, num_velocity_bins=1))
    self.assertEqual(0, vocabularies.velocity_to_bin(0, num_velocity_bins=127))
    self.assertEqual(0, vocabularies.bin_to_velocity(0, num_velocity_bins=1))
    self.assertEqual(0, vocabularies.bin_to_velocity(0, num_velocity_bins=127))

    self.assertEqual(
        1,
        vocabularies.velocity_to_bin(
            vocabularies.bin_to_velocity(1, num_velocity_bins=1),
            num_velocity_bins=1))

    for velocity_bin in range(1, 128):
      self.assertEqual(
          velocity_bin,
          vocabularies.velocity_to_bin(
              vocabularies.bin_to_velocity(velocity_bin, num_velocity_bins=127),
              num_velocity_bins=127))

  def test_encode_decode(self):
    vocab = vocabularies.GenericTokenVocabulary(32)
    input_tokens = [1, 2, 3]
    expected_encoded = [4, 5, 6]

    # Encode
    self.assertSequenceEqual(vocab.encode(input_tokens), expected_encoded)
    np.testing.assert_array_equal(
        vocab.encode_tf(tf.convert_to_tensor(input_tokens)).numpy(),
        expected_encoded)

    # Decode
    self.assertSequenceEqual(vocab.decode(expected_encoded), input_tokens)
    np.testing.assert_array_equal(
        vocab.decode_tf(tf.convert_to_tensor(expected_encoded)).numpy(),
        input_tokens)

  def test_decode_invalid_ids(self):
    vocab = vocabularies.GenericTokenVocabulary(32, extra_ids=4)
    encoded = [0, 2, 3, 4, 34, 35]
    expected_decoded = [-2, -2, 0, 1, 31, -2]
    self.assertSequenceEqual(vocab.decode(encoded), expected_decoded)
    np.testing.assert_array_equal(
        vocab.decode_tf(tf.convert_to_tensor(encoded)).numpy(),
        expected_decoded)

  def test_decode_eos(self):
    vocab = vocabularies.GenericTokenVocabulary(32)
    encoded = [0, 2, 3, 4, 1, 0, 1, 0]
    # Python decode function truncates everything after first EOS.
    expected_decoded = [-2, -2, 0, 1, -1]
    self.assertSequenceEqual(vocab.decode(encoded), expected_decoded)
    # TF decode function preserves array length.
    expected_decoded_tf = [-2, -2, 0, 1, -1, -1, -1, -1]
    np.testing.assert_array_equal(
        vocab.decode_tf(tf.convert_to_tensor(encoded)).numpy(),
        expected_decoded_tf)

  def test_encode_invalid_id(self):
    vocab = vocabularies.GenericTokenVocabulary(32)
    inputs = [0, 15, 31]
    # No exception expected.
    vocab.encode(inputs)
    vocab.encode_tf(tf.convert_to_tensor(inputs))

    inputs_too_low = [-1, 15, 31]
    with self.assertRaises(ValueError):
      vocab.encode(inputs_too_low)
    with self.assertRaises(tf.errors.InvalidArgumentError):
      vocab.encode_tf(tf.convert_to_tensor(inputs_too_low))

    inputs_too_high = [0, 15, 32]
    with self.assertRaises(ValueError):
      vocab.encode(inputs_too_high)
    with self.assertRaises(tf.errors.InvalidArgumentError):
      vocab.encode_tf(tf.convert_to_tensor(inputs_too_high))

  def test_encode_dtypes(self):
    vocab = vocabularies.GenericTokenVocabulary(32)
    inputs = [0, 15, 31]
    encoded32 = vocab.encode_tf(tf.convert_to_tensor(inputs, tf.int32))
    self.assertEqual(tf.int32, encoded32.dtype)
    encoded64 = vocab.encode_tf(tf.convert_to_tensor(inputs, tf.int64))
    self.assertEqual(tf.int64, encoded64.dtype)


if __name__ == '__main__':
  absltest.main()
pytest.ini
ADDED
@@ -0,0 +1,3 @@
[pytest]
python_files = *_test.py
log_level = INFO
setup.cfg
ADDED
@@ -0,0 +1,2 @@
[aliases]
test=pytest
setup.py
ADDED
@@ -0,0 +1,67 @@
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Install mt3."""

import os
import sys
import setuptools

# To enable importing version.py directly, we add its path to sys.path.
version_path = os.path.join(os.path.dirname(__file__), 'mt3')
sys.path.append(version_path)
from version import __version__  # pylint: disable=g-import-not-at-top

setuptools.setup(
    name='mt3',
    version=__version__,
    description='Multi-Task Multitrack Music Transcription',
    author='Google Inc.',
    author_email='[email protected]',
    url='http://github.com/magenta/mt3',
    license='Apache 2.0',
    packages=setuptools.find_packages(),
    package_data={
        '': ['*.gin'],
    },
    scripts=[],
    install_requires=[
        'absl-py == 1.1.0',
        'ddsp == 3.4.4',
        'flax == 0.5.2',
        'gin-config == 0.5.0',
        'immutabledict == 2.2.1',
        'librosa == 0.9.2',
        'mir_eval == 0.7',
        'note_seq == 0.0.3',
        'numpy == 1.21.6',
        'pretty_midi == 0.2.9',
        'scikit-learn == 1.0.2',
        'scipy == 1.7.3',
        'seqio == 0.0.8',
        't5 == 0.9.3',
        'tensorflow',
        'tensorflow-datasets == 4.6.0',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    tests_require=['pytest'],
    setup_requires=['pytest-runner'],
    keywords='music transcription machinelearning audio',
)
t5x/__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Import API modules."""

import t5x.adafactor
import t5x.checkpoints
import t5x.decoding
import t5x.gin_utils
import t5x.losses
import t5x.models
import t5x.partitioning
import t5x.state_utils
import t5x.train_state
import t5x.trainer
import t5x.utils

# Version number.
from t5x.version import __version__

# TODO(adarob): Move clients to t5x.checkpointing and rename
# checkpoints.py to checkpointing.py
checkpointing = t5x.checkpoints
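A trivial sanity check of the alias defined at the bottom of this file (a sketch, assuming only that `t5x` imports cleanly in the environment):

import t5x

# Until the rename mentioned in the TODO happens, `checkpointing` is just
# another name for the `checkpoints` module.
assert t5x.checkpointing is t5x.checkpoints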
t5x/adafactor.py
ADDED
@@ -0,0 +1,608 @@
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Adafactor Optimizer.
|
16 |
+
|
17 |
+
Specialized Adafactor implementation for T5X with:
|
18 |
+
- custom factorization specification rules.
|
19 |
+
- support for stacked parameters from scanned layers and parameter fusions.
|
20 |
+
|
21 |
+
Why do we need custom factorization? In the Adafactor paper, scalar, vector and
|
22 |
+
matrix parameters are considered. This is sufficiently general because higher
|
23 |
+
dimensional parameters can be reshaped. In practice, there are situations where
|
24 |
+
higher dimensional parameters are desirable. For example, consider the
|
25 |
+
multi-headed attention. It has projection kernels. This is naturally
|
26 |
+
represented as 3-dimensional array [d_model, num_head, head_dim]. Keeping the
|
27 |
+
3-dimensional structure can be beneficial for performance optimization, e.g., by
|
28 |
+
giving compilers additional degree of freedom to do layout optimization.
|
29 |
+
|
30 |
+
The default heuristic behavior for the second-moment estimator can lead to an
|
31 |
+
unexpected result because it assumes that the parameters are matrices (vectors
|
32 |
+
and scalars are not factored). The dimensions are sorted and the smaller
|
33 |
+
dimension is assigned to the row dim and the larger dim to the col dim (unless
|
34 |
+
the two largest dims have an equal size and then the original ordering of the
|
35 |
+
dimensions is used). Then `v_row` (i.e., the optimizer state for the row) is
|
36 |
+
obtained by removing the col dim. In other words, `rank(v_row) = rank(v) - 1`.
|
37 |
+
If the parameter is higher dimensional, v_row and v_col are higher dimensional.
|
38 |
+
Therefore, the outer product of v_row and v_col do not necessarily corresponds
|
39 |
+
to the row rank approximation that minimizes the generalized Kullback-Leibler
|
40 |
+
divergence (the original Adafactor formulation).
|
41 |
+
|
42 |
+
This Adafactor implementation generalized the default behavior such that we
|
43 |
+
obtain the correct second moment estimator even for higher dimensional
|
44 |
+
parameters.
|
45 |
+
|
46 |
+
"""
|
47 |
+
import enum
|
48 |
+
import re
|
49 |
+
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
|
50 |
+
|
51 |
+
from absl import logging
|
52 |
+
from flax import struct
|
53 |
+
from flax.core import freeze
|
54 |
+
from flax.core import FrozenDict
|
55 |
+
from flax.core import unfreeze
|
56 |
+
from flax.serialization import from_state_dict
|
57 |
+
from flax.serialization import to_state_dict
|
58 |
+
from flax.traverse_util import flatten_dict
|
59 |
+
from flax.traverse_util import unflatten_dict
|
60 |
+
import jax
|
61 |
+
import jax.numpy as jnp
|
62 |
+
import numpy as np
|
63 |
+
from t5x import utils
|
64 |
+
from t5x.optimizers import OptimizerDef
|
65 |
+
from t5x.optimizers import OptimizerState
|
66 |
+
|
67 |
+
Dtype = Any
|
68 |
+
|
69 |
+
|
70 |
+
class FactorDim(enum.Enum):
|
71 |
+
# Don't factorize this dimension.
|
72 |
+
NONE = None
|
73 |
+
# A batch-like dimension that we should not average over.
|
74 |
+
BATCH = 1
|
75 |
+
ROW = 2
|
76 |
+
COLUMN = 3
|
77 |
+
|
78 |
+
|
79 |
+
# Sentinel value signifying the legacy heuristic factorization rule.
|
80 |
+
class HeuristicRule(enum.Enum):
|
81 |
+
token = 1
|
82 |
+
|
83 |
+
|
84 |
+
HEURISTIC_RULE = HeuristicRule.token
|
85 |
+
FactorRule = Union[HeuristicRule, Tuple[FactorDim]]
|
86 |
+
|
87 |
+
|
88 |
+
def _restore(target, flat):
|
89 |
+
state_dict = unflatten_dict({tuple(k.split('/')): v for k, v in flat.items()})
|
90 |
+
if isinstance(target, FrozenDict):
|
91 |
+
return freeze(state_dict)
|
92 |
+
else:
|
93 |
+
return state_dict
|
94 |
+
|
95 |
+
|
96 |
+
def _insert(tpl, idx, x):
|
97 |
+
tmp = list(tpl)
|
98 |
+
tmp.insert(idx, x)
|
99 |
+
return tuple(tmp)
|
100 |
+
|
101 |
+
|
102 |
+
def standard_logical_factor_rules():
|
103 |
+
return freeze({
|
104 |
+
'vocab': FactorDim.COLUMN,
|
105 |
+
'embed': FactorDim.ROW,
|
106 |
+
'mlp': FactorDim.COLUMN,
|
107 |
+
'heads': FactorDim.COLUMN,
|
108 |
+
'kv': FactorDim.COLUMN,
|
109 |
+
'joined_kv': FactorDim.COLUMN,
|
110 |
+
'relpos_buckets': FactorDim.NONE,
|
111 |
+
'layers': FactorDim.BATCH, # used in scanned layers
|
112 |
+
'stack': FactorDim.BATCH, # used in stacked params
|
113 |
+
# 'batch', 'length' should not occur in parameters
|
114 |
+
'q_wi_fused': FactorDim.COLUMN,
|
115 |
+
'o_wo_fused': FactorDim.COLUMN,
|
116 |
+
'multiquery_heads': FactorDim.COLUMN,
|
117 |
+
'kv_fused': FactorDim.COLUMN,
|
118 |
+
'layer_norm_scale': FactorDim.NONE,
|
119 |
+
'mlp_activations': FactorDim.COLUMN,
|
120 |
+
})
|
121 |
+
|
122 |
+
|
123 |
+
def factor_name_to_factordim(name):
|
124 |
+
if not isinstance(name, str):
|
125 |
+
return name
|
126 |
+
name = name.lower()
|
127 |
+
return {
|
128 |
+
'row': FactorDim.ROW,
|
129 |
+
'col': FactorDim.COLUMN,
|
130 |
+
'column': FactorDim.COLUMN,
|
131 |
+
'batch': FactorDim.BATCH,
|
132 |
+
'none': FactorDim.NONE,
|
133 |
+
'unfactorized': FactorDim.NONE
|
134 |
+
}[name]
|
135 |
+
|
136 |
+
|
137 |
+
class HParamMap:
|
138 |
+
"""Maps parameter path names to hparams.
|
139 |
+
|
140 |
+
Names of parameters nested in a PyTree (e.g., an Optimizer) are formed by
|
141 |
+
joining the names along the path to the parameter leaf with '/'.
|
142 |
+
"""
|
143 |
+
|
144 |
+
def __init__(self, rules):
|
145 |
+
self._rules = [(re.compile(r), p) for r, p in rules]
|
146 |
+
|
147 |
+
def __getitem__(self, key: str) -> Any:
|
148 |
+
for r, p in self._rules:
|
149 |
+
if r.search(key):
|
150 |
+
return p
|
151 |
+
raise KeyError(f'No factor rule found for parameter: {key}')
|
152 |
+
|
153 |
+
def __call__(self, params):
|
154 |
+
"""Returns a copy of the params with mapped hparams in leaves."""
|
155 |
+
flat_state_dict = flatten_dict(to_state_dict(params))
|
156 |
+
flat_rules_dict = {k: self['/'.join(k)] for k in flat_state_dict.keys()}
|
157 |
+
return from_state_dict(params, unflatten_dict(flat_rules_dict))
|
158 |
+
|
159 |
+
|
160 |
+
@struct.dataclass
|
161 |
+
class _AdafactorHyperParams:
|
162 |
+
"""Hparams for Adafactor optimizer."""
|
163 |
+
learning_rate: Optional[float]
|
164 |
+
factored: bool
|
165 |
+
multiply_by_parameter_scale: Union[bool, HParamMap]
|
166 |
+
beta1: Optional[float]
|
167 |
+
decay_rate: float
|
168 |
+
step_offset: int
|
169 |
+
clipping_threshold: Optional[float]
|
170 |
+
weight_decay_rate: Optional[float]
|
171 |
+
min_dim_size_to_factor: int
|
172 |
+
epsilon1: float
|
173 |
+
epsilon2: float
|
174 |
+
factor_map: Optional[HParamMap] = None
|
175 |
+
logical_factor_rules: Any = None
|
176 |
+
weight_decay_rate_lr_exponent: Optional[float] = None
|
177 |
+
global_norm_clip_threshold: Optional[float] = None
|
178 |
+
max_parameter_scale: Optional[float] = None
|
179 |
+
skip_nan_updates: Optional[bool] = False
|
180 |
+
|
181 |
+
|
182 |
+
@struct.dataclass
|
183 |
+
class _AdafactorParamState:
|
184 |
+
v_row: np.ndarray # used in normal factored version
|
185 |
+
v_col: np.ndarray
|
186 |
+
v: np.ndarray # only used without factoring
|
187 |
+
m: np.ndarray # only used with momentum
|
188 |
+
|
189 |
+
|
190 |
+
class Adafactor(OptimizerDef):
|
191 |
+
"""Adafactor optimizer.
|
192 |
+
|
193 |
+
Adafactor is described in https://arxiv.org/abs/1804.04235.
|
194 |
+
"""
|
195 |
+
|
196 |
+
def __init__(self,
|
197 |
+
learning_rate: Optional[float] = None,
|
198 |
+
factored: bool = True,
|
199 |
+
multiply_by_parameter_scale: Union[bool, HParamMap] = True,
|
200 |
+
beta1: Optional[float] = None,
|
201 |
+
decay_rate: float = 0.8,
|
202 |
+
step_offset: int = 0,
|
203 |
+
clipping_threshold: Optional[float] = 1.0,
|
204 |
+
weight_decay_rate: Optional[float] = None,
|
205 |
+
min_dim_size_to_factor: int = 128,
|
206 |
+
epsilon1: float = 1e-30,
|
207 |
+
epsilon2: float = 1e-3,
|
208 |
+
dtype_momentum: Dtype = jnp.float32,
|
209 |
+
factor_map: Optional[HParamMap] = None,
|
210 |
+
logical_factor_rules: Optional[Mapping[str, FactorDim]] = None,
|
211 |
+
weight_decay_rate_lr_exponent: Optional[float] = None,
|
212 |
+
global_norm_clip_threshold: Optional[float] = None,
|
213 |
+
max_parameter_scale: Optional[float] = None,
|
214 |
+
skip_nan_updates: Optional[bool] = False):
|
215 |
+
"""Constructor for the Adafactor optimizer.
|
216 |
+
|
217 |
+
|
218 |
+
Args:
|
219 |
+
learning_rate: float: learning rate. NB: the natural scale for adafactor
|
220 |
+
LR is markedly different from Adam, one doesn't use the 1/sqrt(hidden)
|
221 |
+
correction for this optimizer with attention-based models.
|
222 |
+
factored: boolean: whether to use factored second-moment estimator for 2d
|
223 |
+
variables.
|
224 |
+
multiply_by_parameter_scale: boolean: if True, then scale provided
|
225 |
+
learning_rate by parameter norm. if False, provided learning_rate is
|
226 |
+
absolute step size.
|
227 |
+
beta1: an optional float value between 0 and 1, enables momentum and uses
|
228 |
+
extra memory if non-None! None by default.
|
229 |
+
decay_rate: float: controls second-moment exponential decay schedule.
|
230 |
+
step_offset: for finetuning, one may optionally set this to the starting
|
231 |
+
step-number of the finetuning phase to reset the second moment
|
232 |
+
accumulators after pretraining. Does not affect the momentum even if it
|
233 |
+
was used during pretraining.
|
234 |
+
clipping_threshold: an optional float >= 1, if None no update clipping.
|
235 |
+
weight_decay_rate: optional rate at which to decay weights.
|
236 |
+
min_dim_size_to_factor: only factor accumulator if two array dimensions
|
237 |
+
are at least this size.
|
238 |
+
epsilon1: Regularization constant for squared gradient.
|
239 |
+
epsilon2: Regularization constant for parameter scale.
|
240 |
+
dtype_momentum: dtype of momentum buffers.
|
241 |
+
factor_map: hparam-map from key path to manual factorization rules.
|
242 |
+
logical_factor_rules: factorization rules provided as a set of mappings
|
243 |
+
from logical axis name to ROW, COLUMN, BATCH, or NONE. Supercedes
|
244 |
+
factor_map if `set_param_axes` is called.
|
245 |
+
weight_decay_rate_lr_exponent: If present, weight decay rate is computed
|
246 |
+
as (learning_rate ** weight_decay_rate_lr_exponent). If
|
247 |
+
weight_decay_rate is also present, then multiply by it.
|
248 |
+
global_norm_clip_threshold: If set, will clip gradients by global norm
|
249 |
+
before Adafactor stats are applied.
|
250 |
+
max_parameter_scale: If set, clips the parameter scale to a maximum value,
|
251 |
+
which helps prevent parameters from growing without bound.
|
252 |
+
skip_nan_updates: If set, any parameter that would have been updated by a
|
253 |
+
NaN value after a applying gradients will be kept with the earlier
|
254 |
+
value it had.
|
255 |
+
"""
|
256 |
+
if not factored and factor_map is not None:
|
257 |
+
raise ValueError('Adafactor factored is False but factorization rules '
|
258 |
+
'have been provided.')
|
259 |
+
if not isinstance(multiply_by_parameter_scale, (bool, HParamMap)):
|
260 |
+
raise TypeError(
|
261 |
+
'`multiply_by_parameter_scale` must be either bool or `HParamMap` '
|
262 |
+
f'type. Got {type(multiply_by_parameter_scale)}')
|
263 |
+
|
264 |
+
if not isinstance(factor_map, (type(None), HParamMap)):
|
265 |
+
raise TypeError(
|
266 |
+
'`factor_map` must be either None or `HParamMap` type. Got '
|
267 |
+
f'{type(factor_map)}')
|
268 |
+
|
269 |
+
hyper_params = _AdafactorHyperParams(
|
270 |
+
learning_rate, factored, multiply_by_parameter_scale, beta1, decay_rate,
|
271 |
+
step_offset, clipping_threshold, weight_decay_rate,
|
272 |
+
min_dim_size_to_factor, epsilon1, epsilon2, factor_map,
|
273 |
+
logical_factor_rules, weight_decay_rate_lr_exponent,
|
274 |
+
global_norm_clip_threshold, max_parameter_scale, skip_nan_updates)
|
275 |
+
self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum)
|
276 |
+
super().__init__(hyper_params)
|
277 |
+
|
278 |
+
@staticmethod
|
279 |
+
def _decay_rate_pow(i: int, exponent: float = 0.8) -> float:
|
280 |
+
"""Default Adafactor second-moment decay schedule."""
|
281 |
+
t = jnp.array(i, jnp.float32) + 1.0
|
282 |
+
return 1.0 - t**(-exponent)
|
283 |
+
|
284 |
+
@staticmethod
|
285 |
+
def _parse_rule(
|
286 |
+
rule: Optional[FactorRule],
|
287 |
+
shape: Sequence[int],
|
288 |
+
path: str,
|
289 |
+
fallback_to_heuristics=True
|
290 |
+
) -> Tuple[Tuple[int, ...], Optional[Union[HeuristicRule, Tuple[Tuple[
|
291 |
+
int, ...], Tuple[int, ...]]]]]:
|
292 |
+
"""Parses specification and return factored dims and dims for averaging.
|
293 |
+
|
294 |
+
Adafactor needs to know the two largest dimensions to factorize along.
|
295 |
+
Traditionally it used a heuristic, but we want finer control over these
|
296 |
+
factorization dimensions. Additionally, there are situations where
|
297 |
+
parameters are batched together for e.g. scanned layers and QKV fusion,
|
298 |
+
and we want to ensure that the scale updates and clipping thresholds are
|
299 |
+
calculated _within_ each array and not across the entire batched array.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
rule: the rule is either None (default to heuristic behavior) or a tuple
|
303 |
+
of the same rank as the `param` array containing a FactorDim.ROW or
|
304 |
+
FactorDim.COLUMN to mark dimensions to factorize in two row and column
|
305 |
+
sets, and optionally dimensions marked FactorDim.BATCH to denote batched
|
306 |
+
dimensions that should not be averaged over. e.g. (BATCH, ROW, COLUMN,
|
307 |
+
COLUMN)
|
308 |
+
shape: shape of the variable
|
309 |
+
path: '/' joined parameter path.
|
310 |
+
fallback_to_heuristics: whether to fallback to heuristic factorization
|
311 |
+
rule. For most cases this should be set to `True`.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
tuple of: tuple of dimensions to average over, 2-tuple of dimensions to
|
315 |
+
factorize over.
|
316 |
+
"""
|
317 |
+
param_ndim = len(shape)
|
318 |
+
|
319 |
+
if rule is None:
|
320 |
+
# No factorization.
|
321 |
+
return tuple(np.arange(param_ndim)), None
|
322 |
+
|
323 |
+
if rule is HEURISTIC_RULE:
|
324 |
+
if param_ndim > 2:
|
325 |
+
raise ValueError(
|
326 |
+
f'A parameter with rank strictly higher than 2 must have an '
|
327 |
+
f'explicit factorization rule: {path}, {shape}')
|
328 |
+
# Even if no explicit rule is provided for the param, we still want to
|
329 |
+
# average over all the dimensions for computing the RMS scale.
|
330 |
+
return tuple(np.arange(param_ndim)), HEURISTIC_RULE
|
331 |
+
|
332 |
+
if len(rule) != param_ndim:
|
333 |
+
raise ValueError(f'Factorization rule {rule} has incorrect rank '
|
334 |
+
f'for param of rank {param_ndim}: {path}, {shape}')
|
335 |
+
|
336 |
+
row_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.ROW)
|
337 |
+
col_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.COLUMN)
|
338 |
+
batched_dims = tuple(
|
339 |
+
idx for idx, d in enumerate(rule) if d == FactorDim.BATCH)
|
340 |
+
averaging_dims = tuple(np.delete(np.arange(param_ndim), batched_dims))
|
341 |
+
factor_dims = (row_dims, col_dims)
|
342 |
+
if factor_dims == ((), ()):
|
343 |
+
factor_dims = None
|
344 |
+
|
345 |
+
if fallback_to_heuristics and param_ndim <= 2 and not batched_dims:
|
346 |
+
logging.warning(
|
347 |
+
'Since rank of parameter %s %d is less than or equal to 2, the '
|
348 |
+
'factorization method falls back to heuristics and the provided '
|
349 |
+
'factor rule %s is ignored.', path, param_ndim, rule)
|
350 |
+
return tuple(np.arange(param_ndim)), HEURISTIC_RULE
|
351 |
+
|
352 |
+
return averaging_dims, factor_dims
|
353 |
+
|
  def _factored_dims(
      self, shape: Sequence[int]) -> Optional[Tuple[Tuple[int], Tuple[int]]]:
    """Whether to use a factored second moment estimator.

    If there are not two dimensions of size >= min_dim_size_to_factor, then we
    do not factor. If we do factor the accumulator, then this function returns
    a tuple of the two largest axes to reduce over.

    Args:
      shape: a Shape

    Returns:
      None or a tuple of ints
    """
    if not self.hyper_params.factored or len(shape) < 2:
      return None
    sorted_dims = np.argsort(shape)
    if shape[sorted_dims[-2]] < self.hyper_params.min_dim_size_to_factor:
      return None
    return (int(sorted_dims[-2]),), (int(sorted_dims[-1]),)

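  # Illustrative note (added commentary, not part of the original commit):
  # with min_dim_size_to_factor=128, a (512, 4096) kernel is factored over
  # axes (0,) and (1,) (np.argsort picks the two largest dims), while a
  # (512,) bias or a (32, 64) kernel returns None and keeps a full `v`
  # accumulator instead.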
  def init_param_state(self, param, path):
    shape = param.shape
    state = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']}
    if self.hyper_params.factored:
      factor_rule = (
          self.hyper_params.factor_map[path]
          if self.hyper_params.factor_map else HEURISTIC_RULE)
    else:
      factor_rule = None
    _, factored_dims = self._parse_rule(factor_rule, param.shape, path)
    if factored_dims is HEURISTIC_RULE:
      factored_dims = self._factored_dims(shape)
    if factored_dims is not None:
      d1, d0 = factored_dims
      vr_shape = np.delete(shape, d0)
      vc_shape = np.delete(shape, d1)
      state['v_row'] = jnp.zeros(vr_shape, dtype=jnp.float32)
      state['v_col'] = jnp.zeros(vc_shape, dtype=jnp.float32)
    else:
      state['v'] = jnp.zeros(param.shape, dtype=jnp.float32)
    if self.hyper_params.beta1 is not None:
      state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum)
    return _AdafactorParamState(**state)

  def init_state(self, params):
    params_flat = utils.flatten_dict_string_keys(params)
    param_states_flat = [
        self.init_param_state(param, path)
        for path, param in params_flat.items()
    ]
    param_states_flat = {
        k: v for k, v in zip(params_flat.keys(), param_states_flat)
    }
    param_states = _restore(params, param_states_flat)
    state = OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states)
    return state

  def apply_param_gradient(self, step, hyper_params, param, state, grad, path):
    assert hyper_params.learning_rate is not None, 'no learning rate provided.'
    learning_rate = hyper_params.learning_rate
    beta1 = hyper_params.beta1
    decay_rate = hyper_params.decay_rate
    step_offset = hyper_params.step_offset
    multiply_by_parameter_scale = hyper_params.multiply_by_parameter_scale
    max_parameter_scale = hyper_params.max_parameter_scale
    clipping_threshold = hyper_params.clipping_threshold
    weight_decay_rate = hyper_params.weight_decay_rate
    epsilon1 = hyper_params.epsilon1
    epsilon2 = hyper_params.epsilon2
    if hyper_params.weight_decay_rate_lr_exponent:
      weight_decay_rate = (
          (weight_decay_rate or 1.0) *
          learning_rate**hyper_params.weight_decay_rate_lr_exponent)

    if self.hyper_params.factored:
      factor_rule = (
          self.hyper_params.factor_map[path]
          if self.hyper_params.factor_map else HEURISTIC_RULE)
    else:
      factor_rule = None
    averaging_dims, factored_dims = self._parse_rule(factor_rule, param.shape,
                                                     path)

    grad = grad.astype(jnp.float32)

    updates = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']}
    decay_rate = self._decay_rate_pow(step - step_offset, exponent=decay_rate)
    update_scale = learning_rate

    if isinstance(multiply_by_parameter_scale, HParamMap):
      multiply_by_parameter_scale = multiply_by_parameter_scale[path]
    if multiply_by_parameter_scale:
      param_scale = jnp.sqrt(
          jnp.mean(param * param, axis=averaging_dims, keepdims=True))
      # Clip param_scale to a minimum value of epsilon2.
      param_scale = jnp.maximum(param_scale, epsilon2)
      # Clip param_scale to a maximum value, if specified.
      if max_parameter_scale is not None:
        param_scale = jnp.minimum(param_scale, max_parameter_scale)
      update_scale *= param_scale
    mixing_rate = 1.0 - decay_rate

    grad_sqr = grad * grad + epsilon1
    if factored_dims is HEURISTIC_RULE:
      factored_dims = self._factored_dims(param.shape)
    if factored_dims is not None:
      d1, d0 = factored_dims
      new_v_row = (
          decay_rate * state.v_row + mixing_rate * jnp.mean(grad_sqr, axis=d0))
      new_v_col = (
          decay_rate * state.v_col + mixing_rate * jnp.mean(grad_sqr, axis=d1))
      updates['v_row'] = new_v_row
      updates['v_col'] = new_v_col
      reduced_d1 = tuple(d - len([e for e in d0 if e < d]) for d in d1)

      row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
      row_factor = (new_v_row / row_col_mean)**-0.5
      col_factor = (new_v_col)**-0.5
      y = (
          grad * jnp.expand_dims(row_factor, axis=d0) *
          jnp.expand_dims(col_factor, axis=d1))
    else:
      new_v = decay_rate * state.v + mixing_rate * grad_sqr
      updates['v'] = new_v
      y = grad * (new_v)**-0.5

    if clipping_threshold is not None:
      clipping_denom = (
          jnp.maximum(
              1.0,
              jnp.sqrt(jnp.mean(y * y, axis=averaging_dims, keepdims=True)) /
              clipping_threshold))
      y /= clipping_denom

    subtrahend = update_scale * y
    if beta1 is not None:
      new_m = beta1 * state.m + (1.0 - beta1) * subtrahend
      subtrahend = new_m
      updates['m'] = new_m.astype(self.dtype_momentum)

    if weight_decay_rate is not None:
      new_param = (1.0 - weight_decay_rate) * param - subtrahend
    else:
      new_param = param - subtrahend

    if hyper_params.skip_nan_updates:
      updates['v_row'] = jnp.where(
          jnp.isnan(updates['v_row']), state.v_row, updates['v_row'])
      updates['v_col'] = jnp.where(
          jnp.isnan(updates['v_col']), state.v_col, updates['v_col'])
      updates['v'] = jnp.where(jnp.isnan(updates['v']), state.v, updates['v'])
      updates['m'] = jnp.where(jnp.isnan(updates['m']), state.m, updates['m'])
      new_param = jnp.where(jnp.isnan(new_param), param, new_param)
    new_state = _AdafactorParamState(**updates)

    return new_param.astype(param.dtype), new_state

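  # Summary of the factored update above (added commentary, not part of the
  # original commit; see Shazeer & Stern, "Adafactor: Adaptive Learning Rates
  # with Sublinear Memory Cost", 2018): the preconditioned update is
  #   y = g * (v_row / mean(v_row))**-0.5 * v_col**-0.5
  # which approximates g / sqrt(v) while storing only row and column
  # second-moment statistics. The RMS clipping and parameter-scale terms are
  # averaged over `averaging_dims` only, so batched (e.g. scanned) dimensions
  # each get their own scale.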
  def apply_gradient(self, hyper_params, params, state, grads):
    """Applies a gradient for a set of parameters.

    Args:
      hyper_params: a named tuple of hyper parameters.
      params: the parameters that should be updated.
      state: a named tuple containing the state of the optimizer.
      grads: the gradient tensors for the parameters.

    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    step = state.step
    # We assume that params, param_states, and grads are all dict-like here.
    params_flat_dict = utils.flatten_dict_string_keys(params)
    params_paths = params_flat_dict.keys()
    params_flat = params_flat_dict.values()
    # Extra paranoia to guarantee identical value ordering.
    states_flat = utils.flatten_dict_string_keys(state.param_states)
    states_flat = [states_flat[k] for k in params_paths]
    grads_flat = utils.flatten_dict_string_keys(grads)
    grads_flat = [grads_flat[k] for k in params_paths]

    if hyper_params.global_norm_clip_threshold:
      # Paper: http://proceedings.mlr.press/v28/pascanu13.pdf
      # TF: https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm
      squared_l2_norms = [jnp.sum(jnp.square(g)) for g in grads_flat]
      global_norm = jnp.sqrt(jnp.sum(jnp.array(squared_l2_norms)))
      scale = hyper_params.global_norm_clip_threshold * jnp.minimum(
          1.0 / hyper_params.global_norm_clip_threshold, 1.0 / global_norm)
      grads_flat = [g * scale for g in grads_flat]

    out = [
        self.apply_param_gradient(step, hyper_params, param, state, grad, path)
        for param, state, grad, path in zip(params_flat, states_flat,
                                            grads_flat, params_paths)
    ]

    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params_flat = {k: v for k, v in zip(params_paths, new_params_flat)}
    new_states_flat = {k: v for k, v in zip(params_paths, new_states_flat)}
    new_params = _restore(params, new_params_flat)
    new_param_states = _restore(params, new_states_flat)
    new_state = OptimizerState(step + 1, new_param_states)

    return new_params, new_state

  def set_param_axes(self, param_logical_axes):
    """Sets Adafactor factorization map from logical axis names tree."""
    logical_factor_rules = self.hyper_params.logical_factor_rules
    if logical_factor_rules is None:
      return

    # pylint:disable=invalid-name
    NONE = FactorDim.NONE
    COLUMN = FactorDim.COLUMN
    ROW = FactorDim.ROW
    # pylint:enable=invalid-name

    def apply_rules(axes):
      # Partially factorized params are marked as unfactorized, preserving
      # only BATCH axis annotations. We also check for incompletely factorized
      # params that have ROW, COLUMN but also accidental NONE dimensions and
      # raise an error in that case.
      axis_rules = tuple(logical_factor_rules[x] for x in axes)
      axis_rules = tuple(factor_name_to_factordim(x) for x in axis_rules)
      if ROW in axis_rules and COLUMN in axis_rules and NONE in axis_rules:
        raise ValueError(f'Incomplete adafactor spec {axis_rules} for {axes}!')
      if ROW not in axis_rules or COLUMN not in axis_rules:
        axis_rules = tuple(
            NONE if x in (ROW, COLUMN) else x for x in axis_rules)
      return axis_rules

    factor_map = jax.tree_map(apply_rules, param_logical_axes)
    factor_map = utils.flatten_dict_string_keys(factor_map)

    self.hyper_params = self.hyper_params.replace(factor_map=factor_map)

  def derive_logical_axes(self, optimizer_state, param_logical_axes):
    """Derives optimizer logical partitioning from model logical partitions."""
    optimizer_logical_axes = jax.tree_map(lambda x: None,
                                          optimizer_state.state_dict())
    optimizer_logical_axes['target'] = param_logical_axes

    def factor_rule(logical_axes, adafactor_leaf):
      return dict(
          v_row=None,
          v_col=None,
          v=logical_axes if adafactor_leaf['v'].shape != (1,) else None,
          m=logical_axes if self.hyper_params.beta1 else None)

    optimizer_logical_axes['state']['param_states'] = jax.tree_map(
        factor_rule, unfreeze(param_logical_axes),
        optimizer_state.state_dict()['state']['param_states'])

    return optimizer_state.restore_state(unfreeze(optimizer_logical_axes))
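The typical usage pattern, exercised repeatedly by the test file below, looks roughly like this (a minimal sketch using only names defined in this commit; the parameter name, shape, and rule are illustrative):

from jax import numpy as jnp
from t5x import adafactor

# Mark the leading (scanned) axis as BATCH so statistics stay per-layer.
rule = (adafactor.FactorDim.BATCH, adafactor.FactorDim.ROW,
        adafactor.FactorDim.COLUMN)
factor_map = adafactor.HParamMap(
    (('a', rule), ('.*', adafactor.HEURISTIC_RULE)))
opt_def = adafactor.Adafactor(
    learning_rate=0.1, decay_rate=0.8, factor_map=factor_map)
optimizer = opt_def.create({'a': jnp.ones((24, 4, 16))})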
t5x/adafactor_test.py
ADDED
@@ -0,0 +1,527 @@
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.adafactor."""

import functools
import operator
from typing import Sequence

from absl.testing import absltest
from absl.testing import parameterized

import flax
from flax import optim  # used for equivalence testing only
from flax import traverse_util
import jax
from jax import numpy as jnp
from jax import random
import numpy as np

from t5x import adafactor
from t5x import optimizers

OptimizerState = optimizers.OptimizerState

_AdafactorHyperParams = adafactor._AdafactorHyperParams
_AdafactorParamState = adafactor._AdafactorParamState

_BATCH = adafactor.FactorDim.BATCH
_ROW = adafactor.FactorDim.ROW
_COL = adafactor.FactorDim.COLUMN

# Testing helpers


def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  a, b = jnp.array(a), jnp.array(b)
  a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
  kw = {}
  if atol:
    kw['atol'] = atol
  if rtol:
    kw['rtol'] = rtol
  np.testing.assert_allclose(a, b, **kw)


def check_eq(xs, ys, atol=None, rtol=None):
  xs_leaves, xs_tree = jax.tree_flatten(xs)
  ys_leaves, ys_tree = jax.tree_flatten(ys)
  assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
  assert jax.tree_util.tree_all(
      jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape,
                        xs_leaves, ys_leaves)), "Leaves' shapes don't match."
  assert jax.tree_multimap(
      functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
      xs_leaves, ys_leaves)


def flattened_state_dict(x):
  s = flax.serialization.to_state_dict(x)
  return flax.traverse_util.flatten_dict(s, sep='/')


def tree_shape(x):
  return jax.tree_map(jnp.shape, x)


def tree_equals(x, y):
  return jax.tree_util.tree_all(jax.tree_multimap(operator.eq, x, y))


def _get_multi_adafactor(
    learning_rate: float, step_offset: int,
    adafactor_exclude_from_parameter_scale: Sequence[str]
) -> optim.MultiOptimizer:
  """Get adafactor with support for excluding some parameters from scaling."""

  def _should_not_scale(path):
    return any([s in path for s in adafactor_exclude_from_parameter_scale])

  scaled_vars = traverse_util.ModelParamTraversal(
      lambda path, _: not _should_not_scale(path))
  unscaled_vars = traverse_util.ModelParamTraversal(
      lambda path, _: _should_not_scale(path))
  scaled_opt = optim.Adafactor(
      learning_rate, decay_rate=0.8, step_offset=step_offset)
  unscaled_opt = optim.Adafactor(
      learning_rate,
      decay_rate=0.8,
      step_offset=step_offset,
      multiply_by_parameter_scale=False)
  return optim.MultiOptimizer((scaled_vars, scaled_opt),
                              (unscaled_vars, unscaled_opt))


# Inline test data

MODEL_SHAPE = {
    'decoder': {
        'decoder_norm': {'scale': [128]},
        'layers_0': {
            'encoder_decoder_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_cross_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]},
            'pre_self_attention_layer_norm': {'scale': [128]},
            'self_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}}},
        'layers_1': {
            'encoder_decoder_attention': {
                'key': {'kernel': [128, 128]},
                'out': {'kernel': [128, 128]},
                'query': {'kernel': [128, 128]},
                'value': {'kernel': [128, 128]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_cross_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]},
            'pre_self_attention_layer_norm': {'scale': [128]},
            'self_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}}},
        'relpos_bias': {'rel_embedding': [2, 32]}},
    'encoder': {
        'encoder_norm': {'scale': [128]},
        'layers_0': {
            'attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]}},
        'layers_1': {
            'attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]}},
        'relpos_bias': {'rel_embedding': [2, 32]}},
    'token_embedder': {'embedding': [32128, 128]}}  # pyformat: disable


class AdafactorTest(parameterized.TestCase):

  # Classic Adafactor Behavior Tests

  def test_2D_simple(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=8)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_2D_simple_nofactor(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=32)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_2D_simple_nofactor_momentum(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=32, beta1=0.1)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (24, 16), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_3D_simple(self):
    x = {'a': jnp.ones((24, 4, 16))}
    factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_init_state(self):
    params = {'x': np.zeros((3, 2))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None, 0.8, 0,
                                                  1.0, None, 0, 1e-30, 1e-3)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = OptimizerState(
        0, {
            'x':
                _AdafactorParamState(
                    np.zeros((2,)), np.zeros((3,)), np.zeros(
                        (1,)), np.zeros((1,)))
        })
    check_eq(state, expected_state)

    # unfactorized
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0, 0.8, 0,
                                                  1.0, None, 32, 1e-30, 1e-3)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = OptimizerState(
        0, {
            'x':
                _AdafactorParamState(
                    np.zeros((1,)), np.zeros((1,)), np.zeros(
                        (3, 2)), np.zeros((3, 2)))
        })
    check_eq(state, expected_state)

  def test_apply_gradient(self):
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.9574349, 0.9574349]),
                    np.array([0.6169143, 0.6169143, 0.6169143]), np.zeros(
                        (1,)), np.zeros((1,)))
        })
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

    # unfactored w momentum
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.zeros(1,), np.zeros(1,), 0.5 * np.ones(
                        (3, 2)), np.zeros((3, 2)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.0]), np.array([0.0]), 0.787174 * np.ones(
                        (3, 2)), 0.1 * np.ones((3, 2)))
        })
    check_eq(new_state, expected_new_state, rtol=1e-6)

  def test_apply_gradient_with_global_norm_clipping(self):
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        min_dim_size_to_factor=0,
        global_norm_clip_threshold=1.0)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.478811, 0.478811]),
                    np.array([0.13829, 0.13829, 0.13829]), np.zeros(
                        (1,)), np.zeros((1,)))
        })
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

  def test_factorizes(self):
    params = {'x': np.zeros((64, 64))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    self.assertEqual(state.param_states['x'].v.shape, (1,))
    self.assertEqual(state.param_states['x'].m.shape, (1,))
    self.assertEqual(state.param_states['x'].v_row.shape, (64,))
    self.assertEqual(state.param_states['x'].v_col.shape, (64,))

    params = {'x': np.zeros((31, 64))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    self.assertEqual(state.param_states['x'].v.shape, (31, 64))
    self.assertEqual(state.param_states['x'].m.shape, (1,))
    self.assertEqual(state.param_states['x'].v_row.shape, (1,))
    self.assertEqual(state.param_states['x'].v_col.shape, (1,))

  # Manually specified factorization rules tests.

  @parameterized.parameters(
      {'rule': (_ROW, _COL)},
      {'rule': (_COL, _ROW)},
  )
  def test_2D_ignore_specified_factor_rule(self, rule):
    x = {'a': jnp.ones((24, 16))}
    factor_map = adafactor.HParamMap((('a', rule),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    # Since param is 2D, the explicit factor rule should be ignored and falls
    # back to heuristics where v_row corresponds to the smaller dim.
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_3D_simple_manual_rules(self):
    x = {'a': jnp.ones((24, 4, 16))}

    factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_ROW, _BATCH, _COL)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (4, 16), 'a/v_row': (24, 4)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_COL, _ROW, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_COL, _COL, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_standard_factor_rules(self):
    # one-off test to double-check that we're following the previous
    # heuristic convention for rows/columns.
    def test_standard_factor_rules():
      token_embedding = (_COL, _ROW)
      attn_qkv = (_ROW, _COL)
      attn_out = (_COL, _ROW)
      mlp_in = (_ROW, _COL)
      mlp_out = (_COL, _ROW)
      return ((r'_layer_norm/(bias|scale)',
               None), (r'(encoder|decoder)_norm/(bias|scale)', None),
              (r'(encoder_decoder_|self_|\b)attention/(query|key|value)/kernel',
               attn_qkv), (r'(encoder_decoder_|self_|\b)attention/out/kernel',
                           attn_out), (r'mlp/DenseGeneral_\d+/bias', None),
              (r'mlp/wi(_\d+)?/kernel', mlp_in), (r'mlp/wo/kernel', mlp_out),
              (r'\brelpos_bias', None), (r'token_embedder', token_embedding),
              (r'.*', adafactor.HEURISTIC_RULE))

    # create fake model parameters
    k = jax.random.PRNGKey(0)
    params = jax.tree_map(
        lambda shape: jax.random.uniform(k, shape),
        MODEL_SHAPE,
        is_leaf=lambda x: isinstance(x, list))
    # make traditional adafactor state with heuristic
    factor_map1 = adafactor.HParamMap(((r'.*', adafactor.HEURISTIC_RULE),))
    optimizer_def1 = adafactor.Adafactor(
        0.1,
        decay_rate=0.8,
        step_offset=0,
        multiply_by_parameter_scale=True,
        factor_map=factor_map1)
    optimizer1 = optimizer_def1.create(params)
    # make traditional adafactor state with explicit rules
    factor_map2 = adafactor.HParamMap(test_standard_factor_rules())
    optimizer_def2 = adafactor.Adafactor(
        0.1,
        decay_rate=0.8,
        step_offset=0,
        multiply_by_parameter_scale=True,
        factor_map=factor_map2)
    optimizer2 = optimizer_def2.create(params)
    # are they the same?
    check_eq(optimizer1.state.param_states, optimizer2.state.param_states)

  @parameterized.parameters(
      {'shape': (64, 64)},
      {'shape': (64, 132)},
      {'shape': (132, 64)},
      {'shape': (132, 132)},
      {'shape': (132, 140)},
      {'shape': (140, 132)},
  )
  def test_no_factor_map_equivalence(self, shape):
    k = random.PRNGKey(0)
    k1, k2 = random.split(k)
    p = {'a': random.uniform(k1, shape)}
    g = {'a': random.uniform(k2, shape)}

    orig_opt = optim.Adafactor(0.1).create(p)
    new_opt = adafactor.Adafactor(0.1, factor_map=None).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())

    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  @parameterized.parameters({
      'shape': (128, 128),
      'rule': (_ROW, _COL)
  }, {
      'shape': (132, 128),
      'rule': (_COL, _ROW)
  }, {
      'shape': (128, 132),
      'rule': (_ROW, _COL)
  })
  def test_simple_equivalence(self, shape, rule):
    k = random.PRNGKey(0)
    k1, k2 = random.split(k)
    k3, k4 = random.split(k1)
    k5, k6 = random.split(k2)

    p = {'a': random.uniform(k3, shape), 'b': random.uniform(k4, shape)}
    g = {'a': random.uniform(k5, shape), 'b': random.uniform(k6, shape)}

    orig_opt = optim.Adafactor(0.1).create(p)
    factor_map = adafactor.HParamMap(
        rules=((('a'), rule), ('.*', adafactor.HEURISTIC_RULE)))
    new_opt = adafactor.Adafactor(0.1, factor_map=factor_map).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())

    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  @parameterized.parameters({'shape': (64, 64)}, {'shape': (132, 132)})
  def test_multiply_by_parameter_scale_equivalence(self, shape):
    # Use large parameter values to magnify the parameter scaling effect.
    p = {'a': np.random.randn(*shape) * 100, 'b': np.random.randn(*shape) * 100}
    g = {'a': np.random.randn(*shape), 'b': np.random.randn(*shape)}
    orig_opt = _get_multi_adafactor(
        3.0, 0, adafactor_exclude_from_parameter_scale=('a',)).create(p)
    scaling_map = adafactor.HParamMap([('a', False), ('.*', True)])
    new_opt = adafactor.Adafactor(
        3.0, multiply_by_parameter_scale=scaling_map).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())

    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  def test_3d_without_factor_map(self):
    x = {'a': jnp.ones((24, 4, 16))}
    opt_def = adafactor.Adafactor(factor_map=None)
    with self.assertRaises(ValueError):
      _ = opt_def.create(x)


if __name__ == '__main__':
  absltest.main()
t5x/checkpoint_importer.py
ADDED
@@ -0,0 +1,485 @@
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5 Checkpoint Importer."""

import asyncio
from concurrent.futures import thread
import re
from typing import Any, Callable, Mapping, MutableMapping, Optional, Union

from flax import traverse_util
import jax
from jax import numpy as jnp
import numpy as np
import orbax.checkpoint
import tensorflow as tf
import tensorstore as ts

# TODO(b/233659813): Cleanup clients depending on t5x.checkpoint_importer for
# LazyArray. Reconcile divergence in subclass implementation when possible.
LazyArray = orbax.checkpoint.lazy_array.LazyArray


# TODO(brianlester): The choice between using a `LazyThreadPoolArray` or a
# `LazyAwaitableArray` depends on whether the user-provided `get_fn` is
# blocking or async, respectively. If we can detect which it is, we can
# automatically proxy to the correct subclass. We cannot detect if `get_fn` is
# a lambda that wraps an async call, so this isn't possible yet. Add this
# dispatch once we are able to detect that; python3.8+ can detect async for
# partial'ed functions but not lambdas.
class LazyThreadPoolArray(LazyArray):
  """Lazily and asynchronously loads an array when the `get_fn` blocks."""

  # Uses a global threadpool to enable asynchronous loading.
  executor = thread.ThreadPoolExecutor()

  def get_async(self) -> asyncio.Future:
    return asyncio.wrap_future(self.executor.submit(self.get))

  def get(self) -> np.ndarray:
    arr = self._get_fn()
    if arr.dtype != self.dtype:
      arr = arr.astype(self.dtype)
    return arr


class LazyAwaitableArray(LazyArray):
  """Lazily and asynchronously loads an array when the `get_fn` is async.

  Note:
    The synchronous load method `.get` requires the asyncio event loop and
    calling `.run_until_complete`. This is not supported when the event loop
    is already running (for example, from inside another async function).

  Note:
    Currently, this class has a few helper methods for creating a
    LazyAwaitableArray when the input could be either an array, or a
    TensorStore spec. Most people use async code when dealing with TensorStore
    so the classmethods have been placed here. When someone eventually uses a
    blocking function to read from TensorStore they can be moved to the
    LazyArray base class.
  """

  def get_async(self) -> asyncio.Future:

    async def _get_and_cast():
      # Pytype has a false positive here, where it treats our _get_fn (_read_ts
      # in this case) as having a return type of `np.ndarray` instead of
      # wrapping it in an Awaitable. Related to this bug
      # https://github.com/google/pytype/issues/527
      arr = await self._get_fn()  # pytype: disable=bad-return-type
      if arr.dtype != self.dtype:
        arr = arr.astype(self.dtype)
      return arr

    return asyncio.ensure_future(_get_and_cast())

  def get(self) -> np.ndarray:
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(self.get_async())

  @classmethod
  def from_tensor_store_spec(
      cls,
      ts_spec: ts.Spec,
      get_fn: Callable[[], np.ndarray],
      dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
    """Create a LazyAwaitableArray based on a tensorstore.Spec."""
    ts_spec = ts_spec.to_json()
    shape = ts_spec['metadata']['shape']
    if dtype is None:
      dtype = jnp.dtype(ts_spec['dtype'])
    else:
      dtype = jnp.dtype(dtype)
    # v2 T5X checkpoints use uint16 as the TensorStore datatype and store the
    # raw bfloat16 bytes in the 16 bits uint16 has (no actual cast). When
    # reading the dtype from the TensorStore, if we keep the dtype of these v2
    # checkpoints as np.uint16 then the _get_fn (which has a possible cast to
    # support the `restore_dtype` parameter for the checkpointer) would
    # actually cast the bfloat16 values to uint16, generally resulting in an
    # array of all zeros. This check avoids the actual cast to uint16 by
    # replacing the dtype.
    if dtype == np.uint16:
      dtype = jnp.bfloat16
    return cls(shape, dtype, get_fn)

  @classmethod
  def from_array(cls,
                 array: np.ndarray,
                 get_fn: Callable[[], np.ndarray],
                 dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
    """Create a LazyAwaitableArray based on an array or python number."""
    if dtype is None:
      dtype = array.dtype
    else:
      dtype = jnp.dtype(dtype)
    return cls(array.shape, dtype, get_fn)

  @classmethod
  def from_tensor_store_spec_or_array(
      cls,
      maybe_ts_spec: Union[ts.Spec, np.ndarray],
      get_fn: Callable[[], np.ndarray],
      dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
    """Create a LazyAwaitableArray based on an array or a tensorstore.Spec."""
    if isinstance(maybe_ts_spec, ts.Spec):
      return cls.from_tensor_store_spec(maybe_ts_spec, get_fn, dtype=dtype)
    return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype)


class CheckpointTranslator:
  """Utility class for defining mapping rules from one flatdict to another.

  We assume a checkpoint is loaded as a dictionary with flattened keys of the
  form: 'name0/name1/name2/.../nameN'

  A rule is added with the 'add' decorator, which takes a regex matching rule
  and wraps a conversion function, feeding it (opts, key, val, **regex_groups)
  where opts is a dict containing apply-time keyword options for use by the
  conversion functions.
  """

  def __init__(self):
    self.rules = []

  def add(self, pattern):
    """Adds a new keyval conversion rule.

    Args:
      pattern: regex with capture groups for matching given sets of model
        variables. We terminate all regexes with '$' to force complete
        matches.

    Returns:
      Translation function decorator for associating with the provided
      pattern.
    """

    def register_translation_fn_decorator(fn):
      # We force a complete match by adding end-of-string match.
      self.rules.append((re.compile(pattern + '$'), fn))
      return fn

    return register_translation_fn_decorator

  def apply(self, flatdict, **opts):
    """Applies rules to a flattened dictionary.

    Args:
      flatdict: flat-key dictionary of variables.
      **opts: additional config options for translation rules supplied at
        application time.

    Returns:
      Checkpoint data with translated key/values in flat-key dict format.
    """
    new_dict = {}
    unmatched = {}
    for k, v in flatdict.items():
      matched = False
      for rule_pat, rule_fn in self.rules:
        if rule_pat.match(k):
          groups = rule_pat.match(k).groups()
          new_k, new_v = rule_fn(opts, k, v, *groups)
          if new_k is not None:
            new_dict[new_k] = new_v
          matched = True
          break
      if not matched:
        unmatched[k] = v

    # We force every key-value pair in checkpoint to have a rule associated
    # with it.
    if unmatched:
      raise ValueError('Unmapped tensor keys exist: %s' % unmatched)

    return new_dict

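# Illustrative example (added commentary, not part of the original commit):
# each rule registered below pairs a regex with a function that rewrites one
# key. For instance, the attention rule further down maps the TF checkpoint
# key
#   'encoder/block_003/layer_001/SelfAttention/q'
# to the T5X key
#   'target/encoder/layers_3/attention/query/kernel'
# and the same key with a '_slot_vr' suffix lands under
# 'state/param_states/encoder/layers_3/attention/query/kernel/v_row'.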
209 |
+
# Create a translation rule set for importing T5 & T5.1.1 model checkpoints.
|
210 |
+
# -----------------------------------------------------------------------------
|
211 |
+
t5_importer = CheckpointTranslator()
|
212 |
+
|
213 |
+
# Name mappings.
|
214 |
+
SLOT_MAP = {'_slot_vc': 'v_col', '_slot_vr': 'v_row', '_slot_v': 'v'}
|
215 |
+
TOWER_MAP = {'transformer': 'decoder'}
|
216 |
+
|
217 |
+
|
218 |
+
@t5_importer.add(r'global_step')
|
219 |
+
def global_step(opts, key, val):
|
220 |
+
del opts, key
|
221 |
+
return 'state/step', val.astype(np.int32).get() if isinstance(
|
222 |
+
val, LazyArray) else val
|
223 |
+
|
224 |
+
|
225 |
+
@t5_importer.add(r'shared/embedding(\w*)')
|
226 |
+
def shared_embeddings(opts, key, val, slot):
|
227 |
+
del opts, key
|
228 |
+
prefix = 'state/param_states' if slot else 'target'
|
229 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
230 |
+
newkey = f'{prefix}/token_embedder/embedding{suffix}'
|
231 |
+
return newkey, val
|
232 |
+
|
233 |
+
|
234 |
+
@t5_importer.add(r'(encoder|decoder|transformer)/embedding(\w*)')
|
235 |
+
def separate_embeddings(opts, key, val, encdec, slot):
|
236 |
+
del opts, key
|
237 |
+
prefix = 'state/param_states' if slot else 'target'
|
238 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
239 |
+
encdec = TOWER_MAP.get(encdec, encdec)
|
240 |
+
newkey = f'{prefix}/{encdec}/token_embedder/embedding{suffix}'
|
241 |
+
return newkey, val
|
242 |
+
|
243 |
+
|
244 |
+
# In the Mesh TensorFlow T5 code, relative_attention_bias always occurs in layer
|
245 |
+
# 0 because SelfAttention precedes other sublayers within the same block.
|
246 |
+
@t5_importer.add(
|
247 |
+
r'(encoder|decoder|transformer)/block_(\d+)/layer_000/SelfAttention/relative_attention_bias(\w*)'
|
248 |
+
)
|
249 |
+
def rel_embeddings(opts, key, val, encdec, blocknum, slot):
|
250 |
+
"""Process relpos bias assuming that they are not shared across layers."""
|
251 |
+
del opts, key
|
252 |
+
prefix = 'state/param_states' if slot else 'target'
|
253 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
254 |
+
blocknum = int(blocknum)
|
255 |
+
encdec = TOWER_MAP.get(encdec, encdec)
|
256 |
+
# At this point, we can't determine whether the relpos bias was shared across
|
257 |
+
# layers or not. We first assume that it was not shared. During post
|
258 |
+
# processing, we remove the layers_0 scope if it was shared.
|
259 |
+
newkey = f'{prefix}/{encdec}/layers_{blocknum}/relpos_bias/rel_embedding{suffix}'
|
260 |
+
return newkey, val
|
261 |
+
|
262 |
+
|
263 |
+
@t5_importer.add(
|
264 |
+
r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/(SelfAttention|EncDecAttention)/(q|k|v|o)(\w*)'
|
265 |
+
)
|
266 |
+
def attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot):
|
267 |
+
"""Process attention layers."""
|
268 |
+
del opts, key
|
269 |
+
prefix = 'state/param_states' if slot else 'target'
|
270 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
271 |
+
blocknum = int(blocknum)
|
272 |
+
encdec = TOWER_MAP.get(encdec, encdec)
|
273 |
+
matrix = {'q': 'query', 'k': 'key', 'v': 'value', 'o': 'out'}[qkvo]
|
274 |
+
|
275 |
+
if encdec == 'encoder':
|
276 |
+
attntype = 'attention'
|
277 |
+
else:
|
278 |
+
attntype = {
|
279 |
+
'SelfAttention': 'self_attention',
|
280 |
+
'EncDecAttention': 'encoder_decoder_attention'
|
281 |
+
}[attntype]
|
282 |
+
newkey = f'{prefix}/{encdec}/layers_{blocknum}/{attntype}/{matrix}/kernel{suffix}'
|
283 |
+
return newkey, val
|
284 |
+
|
285 |
+
|
286 |
+
@t5_importer.add(
|
287 |
+
r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/DenseReluDense/(wi|wo)(?:_(\d+))?/kernel(\w*)'
|
288 |
+
)
|
289 |
+
def mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot):
|
290 |
+
"""Process MLP blocks."""
|
291 |
+
del opts, key
|
292 |
+
prefix = 'state/param_states' if slot else 'target'
|
293 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
294 |
+
blocknum = int(blocknum)
|
295 |
+
encdec = TOWER_MAP.get(encdec, encdec)
|
296 |
+
io_num = f'_{io_num}' if io_num else ''
|
297 |
+
newkey = f'{prefix}/{encdec}/layers_{blocknum}/mlp/{io_name}{io_num}/kernel{suffix}'
|
298 |
+
return newkey, val
|
299 |
+
|
300 |
+
|
301 |
+
@t5_importer.add(
|
302 |
+
r'(encoder|decoder|transformer)/block_(\d+)/layer_(\d+)/(?:layer|rms)_norm/scale(\w*)'
|
303 |
+
)
|
304 |
+
def layernorms(opts, key, val, encdec, blocknum, lyrnum, slot):
|
305 |
+
"""Process layer norms assuming that they are pre-layernorms."""
|
306 |
+
del opts, key
|
307 |
+
prefix = 'state/param_states' if slot else 'target'
|
308 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
309 |
+
lyrnum = int(lyrnum)
|
310 |
+
|
311 |
+
if encdec == 'transformer':
|
312 |
+
layernorm_type = ['pre_self_attention_layer_norm',
|
313 |
+
'pre_mlp_layer_norm'][lyrnum]
|
314 |
+
|
315 |
+
elif encdec == 'encoder':
|
316 |
+
layernorm_type = ['pre_attention_layer_norm', 'pre_mlp_layer_norm'][lyrnum]
|
317 |
+
else: # decoder
|
318 |
+
layernorm_type = [
|
319 |
+
'pre_self_attention_layer_norm', 'pre_cross_attention_layer_norm',
|
320 |
+
'pre_mlp_layer_norm'
|
321 |
+
][lyrnum]
|
322 |
+
|
323 |
+
encdec = TOWER_MAP.get(encdec, encdec)
|
324 |
+
newkey = f'{prefix}/{encdec}/layers_{int(blocknum)}/{layernorm_type}/scale{suffix}'
|
325 |
+
return newkey, val
|
326 |
+
|
327 |
+
|
328 |
+
@t5_importer.add(
|
329 |
+
r'(encoder|decoder|transformer)/(?:final_layer|rms)_norm/scale(\w*)')
|
330 |
+
def final_layernorms(opts, key, val, encdec, slot):
|
331 |
+
"""Process final layer norms."""
|
332 |
+
del opts, key
|
333 |
+
prefix = 'state/param_states' if slot else 'target'
|
334 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
335 |
+
norm = {
|
336 |
+
'encoder': 'encoder_norm',
|
337 |
+
'decoder': 'decoder_norm',
|
338 |
+
'transformer': 'decoder_norm'
|
339 |
+
}[encdec]
|
340 |
+
encdec = TOWER_MAP.get(encdec, encdec)
|
341 |
+
newkey = f'{prefix}/{encdec}/{norm}/scale{suffix}'
|
342 |
+
return newkey, val
|
343 |
+
|
344 |
+
|
345 |
+
@t5_importer.add(r'(?:decoder|transformer)/logits/kernel(\w*)')
|
346 |
+
def final_logits(opts, key, val, slot):
|
347 |
+
del opts, key
|
348 |
+
prefix = 'state/param_states' if slot else 'target'
|
349 |
+
suffix = '/' + SLOT_MAP[slot] if slot else ''
|
350 |
+
newkey = f'{prefix}/decoder/logits_dense/kernel{suffix}'
|
351 |
+
return newkey, val
|
352 |
+
|
353 |
+
|
354 |
+
def _add_missing_param_states(t5_data):
|
355 |
+
"""Add dummy slots that Flax Adafactor requires but TF does not."""
|
356 |
+
updates = {}
|
357 |
+
for k in t5_data:
|
358 |
+
if k.startswith('target'):
|
359 |
+
state_leaf = 'state/param_states' + k[len('target'):]
|
360 |
+
updates[state_leaf + '/m'] = np.zeros((1,), np.float32)
|
361 |
+
if state_leaf + '/v' in t5_data:
|
362 |
+
+        updates[state_leaf + '/v_row'] = np.zeros((1,), np.float32)
+        updates[state_leaf + '/v_col'] = np.zeros((1,), np.float32)
+      elif state_leaf + '/v_row' in t5_data:
+        updates[state_leaf + '/v'] = np.zeros((1,), np.float32)
+  t5_data.update(**updates)
+  return t5_data
+
+
+def _maybe_correct_relpos_bias(t5_data):
+  """Correct the relpos_bias format if it is shared across layers."""
+  max_layer_ind = 0
+  for k, v in t5_data.items():
+    match = re.search(r'layers_(\d+)/relpos_bias', k)
+    if match:
+      layer_ind = int(match.groups()[0])
+      max_layer_ind = max(max_layer_ind, layer_ind)
+
+  modified_dict = {}
+  if max_layer_ind == 0:
+    # Relative position biases are shared across layers
+    for k, v in t5_data.items():
+      new_k = re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', k)
+      modified_dict[new_k] = v
+  else:
+    # Relative position biases are unique in each layer. No more processing is
+    # necessary.
+    modified_dict = t5_data
+
+  return modified_dict
+
+
+# Load checkpoint, translate, and update flax optimizer and model.
+# -----------------------------------------------------------------------------
+def load_tf_ckpt(path):
+  """Load a TF checkpoint as a flat dictionary of numpy arrays."""
+  ckpt_reader = tf.train.load_checkpoint(path)
+  ckpt_shape_map = ckpt_reader.get_variable_to_shape_map()
+  ckpt_dtype_map = ckpt_reader.get_variable_to_dtype_map()
+  datamap = {  # pylint: disable=g-complex-comprehension
+      k: LazyThreadPoolArray(
+          s,
+          jnp.dtype(ckpt_dtype_map[k].as_numpy_dtype),
+          lambda x=k: ckpt_reader.get_tensor(x))
+      for k, s in ckpt_shape_map.items()
+  }
+  return datamap
+
+
+def _update_state_dict(state_dict: Mapping[str, Any],
+                       t5_data: MutableMapping[str, LazyArray],
+                       strict: bool = True) -> Mapping[str, Any]:
+  """Update flax optimizer for T5 model.
+
+  Args:
+    state_dict: Optimizer to update with T5 parameters.
+    t5_data: T5 model parameters, typically loaded from a checkpoint.
+    strict: If True requires that optimizer and t5_data mappings contain the
+      same set of names (variables). If False, updating will succeed even if
+      t5_data contains variables not in the optimizer. If the optimizer has
+      variables not in t5_data, this function will still fail.
+
+  Returns:
+    Updated optimizer.
+  """
+  flat_state_dict = traverse_util.flatten_dict(state_dict, sep='/')
+
+  # Remove parameters from the checkpoint not found in the optimizer (this
+  # allows us to load checkpoints that contain more parameters than our current
+  # model).
+  if not strict:
+    for k in list(t5_data):
+      if k not in flat_state_dict:
+        t5_data.pop(k)
+
+  # Shape check.
+  for k, v in t5_data.items():
+    if flat_state_dict[k].shape != v.shape:
+      raise ValueError(
+          f'Variable {k} has shape {v.shape} != {flat_state_dict[k].shape}')
+  flat_state_dict = t5_data
+  state_dict = traverse_util.unflatten_dict(
+      {tuple(k.split('/')): v for k, v in flat_state_dict.items()})
+  return state_dict
+
+
+def restore_from_t5_checkpoint(
+    state_dict: Mapping[str, Any],
+    path: str,
+    lazy_parameters: bool = False,
+    strict: bool = True,
+    translator: Optional[CheckpointTranslator] = None) -> Mapping[str, Any]:
+  """Load T5 checkpoint and update Adafactor optimizer and T5 model from it.
+
+  We require that the final translated checkpoint structure exactly matches
+  that of the Flax Adafactor + Transformer data, up to shape agreement of
+  the leaves.
+
+  Args:
+    state_dict: Flax Adafactor Optimizer for T5 transformer encoder-decoder.
+    path: a path to checkpoint file or directory.
+    lazy_parameters: whether to leave the parameters as LazyArrays to preserve
+      memory.
+    strict: If True requires that optimizer and t5_data mappings contain the
+      same set of names (variables). If False, updating will succeed even if
+      t5_data contains variables not in the optimizer. If the optimizer has
+      variables not in t5_data, this function will still fail.
+    translator: The mapping rules for conversion. If None, then default T5
+      conversion rules will be used.
+
+  Returns:
+    Adafactor optimizer updated with parameters and optimizer state from
+    T5 checkpoint.
+  """
+  if translator is None:
+    translator = t5_importer
+  ckpt_data = load_tf_ckpt(path)
+  t5_data = translator.apply(ckpt_data)
+  t5_data = _add_missing_param_states(t5_data)
+  t5_data = _maybe_correct_relpos_bias(t5_data)
+  state_dict = _update_state_dict(state_dict, t5_data, strict=strict)
+  if not lazy_parameters:
+    state_dict = jax.tree_map(
+        lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict)
+  return state_dict
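
`restore_from_t5_checkpoint` is the entry point for the whole translation pipeline above. A minimal usage sketch follows; the checkpoint path is hypothetical and `state_dict` is assumed to come from an already-built t5x Adafactor optimizer, neither of which is part of this diff:

from t5x import checkpoint_importer

# Sketch only: `state_dict` is assumed to exist (the flax Adafactor state
# for the model), and the path below is a hypothetical TF checkpoint.
restored = checkpoint_importer.restore_from_t5_checkpoint(
    state_dict,
    '/path/to/tf_checkpoint',
    lazy_parameters=False,  # materialize LazyArrays eagerly
    strict=False)           # drop checkpoint variables the model lacks

With `lazy_parameters=True` the returned tree keeps LazyArray leaves, so tensors are only read from disk when `.get()` is called on them.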
t5x/checkpoint_importer_test.py
ADDED
@@ -0,0 +1,81 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for t5x.checkpoint_importer."""
+
+import json
+import os
+
+from absl import flags
+from absl.testing import absltest
+import jax
+import numpy as np
+from t5x import checkpoint_importer
+import tensorflow as tf
+
+
+class CheckpointImporterTest(absltest.TestCase):
+
+  def test_rel_embeddings_shared_layers(self):
+    # This represents a ckpt where the Mesh TensorFlow's
+    # transformer_layers.SelfAttention.relative_attention_type = "bias_shared",
+    # i.e., the same relative attention parameters are shared by all layers
+    # within the (en|de)coder.
+    ckpt_data = {
+        'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+            1,
+        'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+            2,
+        'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v':
+            3,
+    }
+    t5_data = checkpoint_importer.t5_importer.apply(ckpt_data)
+    t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data)
+    expected = {
+        'target/encoder/relpos_bias/rel_embedding': 1,
+        'target/decoder/relpos_bias/rel_embedding': 2,
+        'state/param_states/decoder/relpos_bias/rel_embedding/v': 3,
+    }
+    self.assertEqual(t5_data, expected)
+
+  def test_rel_embeddings_per_layer(self):
+    # This represents a ckpt where the Mesh TensorFlow's
+    # transformer_layers.SelfAttention.relative_attention_type = "bias", i.e.,
+    # each layer has its own relative attention parameters.
+    ckpt_data = {
+        'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+            1,
+        'encoder/block_001/layer_000/SelfAttention/relative_attention_bias':
+            2,
+        'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+            3,
+        'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v':
+            4,
+        'decoder/block_011/layer_000/SelfAttention/relative_attention_bias':
+            5
+    }
+    t5_data = checkpoint_importer.t5_importer.apply(ckpt_data)
+    t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data)
+    expected = {
+        'target/encoder/layers_0/relpos_bias/rel_embedding': 1,
+        'target/encoder/layers_1/relpos_bias/rel_embedding': 2,
+        'target/decoder/layers_0/relpos_bias/rel_embedding': 3,
+        'state/param_states/decoder/layers_0/relpos_bias/rel_embedding/v': 4,
+        'target/decoder/layers_11/relpos_bias/rel_embedding': 5,
+    }
+    self.assertEqual(t5_data, expected)
+
+
+if __name__ == '__main__':
+  absltest.main()
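
Both tests pivot on the layer-index scan in `_maybe_correct_relpos_bias`: if the highest layer index seen among `relpos_bias` keys is 0, the bias is treated as shared and the per-layer prefix is dropped. A standalone sketch of that collapse, using illustrative keys rather than ones taken from this diff:

import re

# Illustrative keys mirroring the shared-bias test case above.
keys = ['target/encoder/layers_0/relpos_bias/rel_embedding',
        'target/decoder/layers_0/relpos_bias/rel_embedding']

max_layer_ind = 0
for k in keys:
  match = re.search(r'layers_(\d+)/relpos_bias', k)
  if match:
    max_layer_ind = max(max_layer_ind, int(match.group(1)))

if max_layer_ind == 0:
  # All indices are 0, so the bias is shared: strip the per-layer prefix.
  keys = [re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', k) for k in keys]

print(keys)
# -> ['target/encoder/relpos_bias/rel_embedding',
#     'target/decoder/relpos_bias/rel_embedding']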
t5x/checkpoint_utils.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Checkpoint helper functions for managing checkpoints.
+
+Supports marking checkpoints as pinned to exclude them from the checkpointer
+removal process.
+"""
+
+import os
+
+from absl import logging
+
+from tensorflow.io import gfile
+
+# PINNED file in the checkpoint directory indicates that the checkpoint should
+# not be removed during the automatic pruning of old checkpoints.
+_PINNED_CHECKPOINT_FILENAME = 'PINNED'
+
+
+def pinned_checkpoint_filepath(ckpt_dir: str) -> str:
+  """Full path of the pinned checkpoint file."""
+  return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME)
+
+
+def is_pinned_checkpoint(ckpt_dir: str) -> bool:
+  """Returns whether the checkpoint is pinned, and should NOT be removed."""
+  pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
+  if gfile.exists(pinned_ckpt_file):
+    return True
+  return False
+
+
+def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None:
+  """Pin a checkpoint so it does not get deleted by the normal pruning process.
+
+  Creates a PINNED file in the checkpoint directory to indicate the checkpoint
+  should be excluded from the deletion of old checkpoints.
+
+  Args:
+    ckpt_dir: The checkpoint step dir that is to be always kept.
+    txt: Text to be written into the checkpoint's PINNED file.
+  """
+  pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
+  with gfile.GFile(pinned_ckpt_file, 'w') as f:
+    logging.debug('Write %s file : %s.', pinned_ckpt_file, txt)
+    f.write(txt)
+
+
+def unpin_checkpoint(ckpt_dir: str) -> None:
+  """Removes the pinned status of the checkpoint so it is open for deletion."""
+  if not is_pinned_checkpoint(ckpt_dir):
+    logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir)
+    return
+  try:
+    pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
+    logging.debug('Remove %s file.', pinned_ckpt_file)
+    gfile.rmtree(pinned_ckpt_file)
+  except IOError:
+    logging.exception('Failed to unpin %s', ckpt_dir)
+
+
+def remove_checkpoint_dir(ckpt_dir: str) -> None:
+  """Removes the checkpoint dir if it is not pinned."""
+  if not is_pinned_checkpoint(ckpt_dir):
+    logging.info('Deleting checkpoint: %s', ckpt_dir)
+    gfile.rmtree(ckpt_dir)
+  else:
+    logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
+
+
+def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None:
+  """Removes dataset checkpoints if the checkpoint is not pinned."""
+  if not is_pinned_checkpoint(ckpt_dir):
+    train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*')
+    logging.info('Deleting dataset checkpoint: %s', train_ds_pattern)
+    for file in gfile.glob(train_ds_pattern):
+      gfile.remove(file)
+  else:
+    logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
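
As a rough sketch of the intended lifecycle of these helpers (the checkpoint directory below is hypothetical): pinning writes the PINNED marker, which turns the removal functions into no-ops until the checkpoint is unpinned again.

from t5x import checkpoint_utils

ckpt_dir = '/tmp/model_dir/checkpoint_1000'  # hypothetical step directory

checkpoint_utils.pin_checkpoint(ckpt_dir)         # writes <ckpt_dir>/PINNED
checkpoint_utils.remove_checkpoint_dir(ckpt_dir)  # no-op while pinned
checkpoint_utils.unpin_checkpoint(ckpt_dir)       # removes the PINNED marker
checkpoint_utils.remove_checkpoint_dir(ckpt_dir)  # now deletes the directory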
t5x/checkpoint_utils_test.py
ADDED
@@ -0,0 +1,149 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for t5x.checkpoint_utils."""
+
+import os
+import traceback
+
+from absl.testing import absltest
+from t5x import checkpoint_utils
+from tensorflow.io import gfile
+
+TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+
+
+class CheckpointsUtilsTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.checkpoints_dir = self.create_tempdir()
+    self.ckpt_dir_path = self.checkpoints_dir.full_path
+    self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED")
+    self.checkpoints_dir.create_file("checkpoint")
+    # Create a `train_ds` file representing the dataset checkpoint.
+    train_ds_basename = "train_ds-00000-of-00001"
+    self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename)
+    self.checkpoints_dir.create_file(train_ds_basename)
+
+  def test_always_keep_checkpoint_file(self):
+    self.assertEqual(
+        "/path/to/ckpt/dir/PINNED",
+        checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir"))
+
+  def test_is_pinned_checkpoint_false_by_default(self):
+    # Ensure regular checkpoint without PINNED file.
+    self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
+
+    # Validate checkpoints are not pinned by default.
+    self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path))
+
+  def test_is_pinned_checkpoint(self):
+    # Ensure the checkpoint directory is marked as pinned.
+    pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir")
+    pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED")
+    self.assertTrue(gfile.exists(pinned_file))
+
+    # Test and validate.
+    self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata))
+
+  def test_is_pinned_missing_ckpt(self):
+    self.assertFalse(
+        checkpoint_utils.is_pinned_checkpoint(
+            os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist")))
+
+  def test_pin_checkpoint(self):
+    # Ensure directory isn't already pinned.
+    self.assertFalse(gfile.exists(self.pinned_ckpt_file))
+
+    # Test.
+    checkpoint_utils.pin_checkpoint(self.ckpt_dir_path)
+
+    # Validate.
+    self.assertTrue(gfile.exists(self.pinned_ckpt_file))
+    with open(self.pinned_ckpt_file) as f:
+      self.assertEqual("1", f.read())
+
+  def test_pin_checkpoint_txt(self):
+    checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED")
+    self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
+    with open(self.pinned_ckpt_file) as f:
+      self.assertEqual("TEXT_IN_PINNED", f.read())
+
+  def test_unpin_checkpoint(self):
+    # Mark the checkpoint directory as pinned.
+    self.checkpoints_dir.create_file("PINNED")
+    self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path))
+
+    # Test.
+    checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path)
+
+    # Validate the "PINNED" checkpoint file got removed.
+    self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
+
+  def test_unpin_checkpoint_does_not_exist(self):
+    missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist")
+    self.assertFalse(gfile.exists(missing_ckpt_path))
+
+    # Test. Assert it does not raise an error.
+    try:
+      checkpoint_utils.unpin_checkpoint(missing_ckpt_path)
+    except IOError:
+      # TODO(b/172262005): Remove traceback.format_exc() from the error message.
+      self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc())
+
+  def test_remove_checkpoint_dir(self):
+    # Ensure the checkpoint directory is set up.
+    assert gfile.exists(self.ckpt_dir_path)
+
+    # Test.
+    checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path)
+
+    # Validate the checkpoint directory got removed.
+    self.assertFalse(gfile.exists(self.ckpt_dir_path))
+
+  def test_remove_checkpoint_dir_pinned(self):
+    # Mark the checkpoint directory as pinned so it does not get removed.
+    self.checkpoints_dir.create_file("PINNED")
+
+    # Test.
+    checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path)
+
+    # Validate the checkpoint directory still exists.
+    self.assertTrue(gfile.exists(self.ckpt_dir_path))
+
+  def test_remove_dataset_checkpoint(self):
+    # Ensure the checkpoint directory is set up.
+    assert gfile.exists(self.ckpt_dir_path)
+
+    # Test.
+    checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds")
+
+    # Validate the dataset checkpoint got removed.
+    self.assertFalse(gfile.exists(self.train_ds_file))
+    self.assertTrue(gfile.exists(self.ckpt_dir_path))
+
+  def test_remove_dataset_checkpoint_pinned(self):
+    # Mark the checkpoint directory as pinned so it does not get removed.
+    self.checkpoints_dir.create_file("PINNED")
+
+    # Test.
+    checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds")
+
+    # Validate the checkpoint directory still exists.
+    self.assertTrue(gfile.exists(self.train_ds_file))
+    self.assertTrue(gfile.exists(self.ckpt_dir_path))
+
+if __name__ == "__main__":
+  absltest.main()
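
Note that `test_is_pinned_checkpoint` depends on a checked-in fixture at `testdata/pinned_ckpt_dir/PINNED`, which is not part of this commit. A hedged sketch of how such a fixture could be created locally (the layout is an assumption based on how `TESTDATA` is computed above):

import os

# Assumption: run from the repository root, so the fixture lands next to
# checkpoint_utils_test.py under t5x/testdata/.
fixture_dir = os.path.join('t5x', 'testdata', 'pinned_ckpt_dir')
os.makedirs(fixture_dir, exist_ok=True)
with open(os.path.join(fixture_dir, 'PINNED'), 'w') as f:
  f.write('1')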