jiang commited on
Commit
650c5f6
·
1 Parent(s): 0e87e30

init commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. CODE_OF_CONDUCT.md +4 -0
  3. CONTRIBUTING.md +59 -0
  4. LICENSE +201 -0
  5. NOTICE +1 -0
  6. README.md +185 -13
  7. app.py +47 -0
  8. bert/activations.py +56 -0
  9. bert/configuration_bert.py +143 -0
  10. bert/configuration_utils.py +408 -0
  11. bert/file_utils.py +808 -0
  12. bert/generation_utils.py +993 -0
  13. bert/modeling_bert.py +1569 -0
  14. bert/modeling_utils.py +1269 -0
  15. bert/tokenization_bert.py +545 -0
  16. bert/tokenization_utils.py +723 -0
  17. bert/tokenization_utils_base.py +0 -0
  18. criterions/__init__.py +1 -0
  19. criterions/label_smoothed_cross_entropy.py +394 -0
  20. data/__init__.py +2 -0
  21. data/base_dataset.py +84 -0
  22. data/create_finetuning_data.py +123 -0
  23. data/create_pretraining_data.py +80 -0
  24. data/data_utils.py +606 -0
  25. data/file_dataset.py +112 -0
  26. data/poly_utils.py +294 -0
  27. data/refcoco_dataset.py +294 -0
  28. data/refcoco_pretrain_dataset.py +232 -0
  29. data/val_test_files.p +0 -0
  30. demo.py +410 -0
  31. evaluate.py +185 -0
  32. fairseq/.github/ISSUE_TEMPLATE.md +3 -0
  33. fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
  34. fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
  35. fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
  36. fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
  37. fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
  38. fairseq/.github/stale.yml +30 -0
  39. fairseq/.github/workflows/build.yml +55 -0
  40. fairseq/.github/workflows/build_wheels.yml +41 -0
  41. fairseq/.gitignore +136 -0
  42. fairseq/.gitmodules +4 -0
  43. fairseq/CODE_OF_CONDUCT.md +77 -0
  44. fairseq/CONTRIBUTING.md +28 -0
  45. fairseq/LICENSE +21 -0
  46. fairseq/README.md +229 -0
  47. fairseq/examples/.gitignore +2 -0
  48. fairseq/examples/__init__.py +9 -0
  49. fairseq/examples/adaptive_span/README.md +90 -0
  50. fairseq/examples/adaptive_span/__init__.py +19 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ## Code of Conduct
2
+ This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3
+ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4
+ [email protected] with any additional questions or comments.
CONTRIBUTING.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing Guidelines
2
+
3
+ Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4
+ documentation, we greatly value feedback and contributions from our community.
5
+
6
+ Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7
+ information to effectively respond to your bug report or contribution.
8
+
9
+
10
+ ## Reporting Bugs/Feature Requests
11
+
12
+ We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13
+
14
+ When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15
+ reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16
+
17
+ * A reproducible test case or series of steps
18
+ * The version of our code being used
19
+ * Any modifications you've made relevant to the bug
20
+ * Anything unusual about your environment or deployment
21
+
22
+
23
+ ## Contributing via Pull Requests
24
+ Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25
+
26
+ 1. You are working against the latest source on the *main* branch.
27
+ 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28
+ 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29
+
30
+ To send us a pull request, please:
31
+
32
+ 1. Fork the repository.
33
+ 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34
+ 3. Ensure local tests pass.
35
+ 4. Commit to your fork using clear commit messages.
36
+ 5. Send us a pull request, answering any default questions in the pull request interface.
37
+ 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38
+
39
+ GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40
+ [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41
+
42
+
43
+ ## Finding contributions to work on
44
+ Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45
+
46
+
47
+ ## Code of Conduct
48
+ This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49
+ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50
+ [email protected] with any additional questions or comments.
51
+
52
+
53
+ ## Security issue notifications
54
+ If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55
+
56
+
57
+ ## Licensing
58
+
59
+ See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 1999-2022 Alibaba Group Holding Ltd.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
NOTICE ADDED
@@ -0,0 +1 @@
 
 
1
+ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
README.md CHANGED
@@ -1,13 +1,185 @@
1
- ---
2
- title: PolyFormer
3
- emoji: 🔥
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 3.29.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PolyFormer: Referring Image Segmentation as Sequential Polygon Generation (CVPR 2023)
2
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polyformer-referring-image-segmentation-as/referring-expression-segmentation-on-refcocog)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcocog?p=polyformer-referring-image-segmentation-as)
3
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polyformer-referring-image-segmentation-as/referring-expression-segmentation-on-refcoco)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco?p=polyformer-referring-image-segmentation-as)
4
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polyformer-referring-image-segmentation-as/referring-expression-segmentation-on-refcoco-1)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-1?p=polyformer-referring-image-segmentation-as)
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polyformer-referring-image-segmentation-as/referring-expression-comprehension-on-refcoco)](https://paperswithcode.com/sota/referring-expression-comprehension-on-refcoco?p=polyformer-referring-image-segmentation-as)
6
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polyformer-referring-image-segmentation-as/referring-expression-comprehension-on-refcoco-1)](https://paperswithcode.com/sota/referring-expression-comprehension-on-refcoco-1?p=polyformer-referring-image-segmentation-as)
7
+
8
+
9
+ \[[Project Page](https://polyformer.github.io/)\] \[[Paper](https://arxiv.org/abs/2302.07387)\]
10
+
11
+ by [Jiang Liu*](https://joellliu.github.io/), [Hui Ding*](http://www.huiding.org/), [Zhaowei Cai](https://zhaoweicai.github.io/), [Yuting Zhang](https://scholar.google.com/citations?user=9UfZJskAAAAJ&hl=en), [Ravi Kumar Satzoda](https://scholar.google.com.sg/citations?user=4ngycwIAAAAJ&hl=en), [Vijay Mahadevan](https://scholar.google.com/citations?user=n9fRgvkAAAAJ&hl=en), [R. Manmatha](https://ciir.cs.umass.edu/~manmatha/).
12
+
13
+
14
+ ## :notes: Introduction
15
+ ![github_figure](pipeline.gif)
16
+ PolyFormer is a unified model for referring image segmentation (polygon vertex sequence) and referring expression comprehension (bounding box corner points). The polygons are converted to segmentation masks in the end.
17
+
18
+ **Contributions:**
19
+
20
+ * State-of-the-art results on referring image segmentation and referring expression comprehension on 6 datasets;
21
+ * A unified framework for referring image segmentation (RIS) and referring expression comprehension (REC) by formulating them as a sequence-to-sequence (seq2seq) prediction problem;
22
+ * A regression-based decoder for accurate coordinate prediction, which outputs continuous 2D coordinates directly without quantization error..
23
+
24
+
25
+
26
+ ## Getting Started
27
+ ### Installation
28
+ ```bash
29
+ conda create -n polyformer python=3.7.4
30
+ conda activate polyformer
31
+ python -m pip install -r requirements.txt
32
+ ```
33
+ Note: if you are getting import errors from `fairseq`, try the following:
34
+ ```bash
35
+ python -m pip install pip==21.2.4
36
+ pip uninstall fairseq
37
+ pip install -r requirements.txt
38
+ ```
39
+
40
+ ## Datasets
41
+ ### Prepare Pretraining Data
42
+ 1. Create the dataset folders
43
+ ```bash
44
+ mkdir datasets
45
+ mkdir datasets/images
46
+ mkdir datasets/annotations
47
+ ```
48
+ 2. Download the *2014 Train images [83K/13GB]* from [COCO](https://cocodataset.org/#download),
49
+ original [Flickr30K images](http://shannon.cs.illinois.edu/DenotationGraph/),
50
+ [ReferItGame images](https://drive.google.com/file/d/1R6Tm7tQTHCil6A_eOhjudK3rgaBxkD2t/view?usp=sharing),
51
+ and [Visual Genome images](http://visualgenome.org/api/v0/api_home.html), and extract them to `datasets/images`.
52
+ 3. Download the annotation file for pretraining datasets [instances.json](https://drive.google.com/drive/folders/1O4hzL8_s3aUsnj_JZnM3CwANd7TejcJO)
53
+ provided by [SeqTR](https://github.com/sean-zhuh/SeqTR) and store it in `datasets/annotations`.
54
+ The workspace directory should be organized like this:
55
+ ```
56
+ PolyFormer/
57
+ ├── datasets/
58
+ │   ├── images
59
+ │   │   ├── flickr30k/*.jpg
60
+ │   │   ├── mscoco/
61
+ │   │ │  └── train2014/*.jpg
62
+ │   │   ├── saiaprtc12/*.jpg
63
+ │   │   └── visual-genome/*.jpg
64
+ │   └── annotations
65
+ │      └── instances.json
66
+ └── ...
67
+ ```
68
+ 4. Generate the tsv files for pretraining
69
+ ```bash
70
+ python data/create_pretraining_data.py
71
+ ```
72
+ ### Prepare Finetuning Data
73
+ 1. Follow the instructions in the `./refer` directory to set up subdirectories
74
+ and download annotations.
75
+ This directory is based on the [refer](https://github.com/lichengunc/refer) API.
76
+
77
+ 2. Generate the tsv files for finetuning
78
+ ```bash
79
+ python data/create_finetuning_data.py
80
+ ```
81
+
82
+
83
+
84
+
85
+ ## Pretraining
86
+ 1. Create the checkpoints folder
87
+ ```bash
88
+ mkdir weights
89
+ ```
90
+ 2. Download pretrain weights of [Swin-base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth),
91
+ [Swin-large](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth),
92
+ [BERT-base](https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin)
93
+ and put the weight files in `./pretrained_weights`.
94
+ These weights are needed for training to initialize the model.
95
+
96
+
97
+ 3. Run the pretraining scripts for model pretraining on the referring expression comprehension task:
98
+ ```bash
99
+ cd run_scripts/pretrain
100
+ bash pretrain_polyformer_b.sh # for pretraining PolyFormer-B model
101
+ bash pretrain_polyformer_l.sh # for pretraining PolyFormer-L model
102
+ ```
103
+
104
+ ## Finetuning
105
+ Run the finetuning scripts for model pretraining on the referring image segmentation and referring expression comprehension tasks:
106
+ ```bash
107
+ cd run_scripts/finetune
108
+ bash train_polyformer_b.sh # for finetuning PolyFormer-B model
109
+ bash train_polyformer_l.sh # for finetuning PolyFormer-L model
110
+ ```
111
+ Please make sure to link the pretrain weight paths (Line 20) in the finetuning scripts to the best pretraining checkpoints.
112
+
113
+ ## Evaluation
114
+ Run the evaluation scripts for evaluating on the referring image segmentation and referring expression comprehension tasks:
115
+ ```bash
116
+ cd run_scripts/evaluation
117
+
118
+ # for evaluating PolyFormer-B model
119
+ bash evaluate_polyformer_b_refcoco.sh
120
+ bash evaluate_polyformer_b_refcoco+.sh
121
+ bash evaluate_polyformer_b_refcocog.sh
122
+
123
+ # for evaluating PolyFormer-L model
124
+ bash evaluate_polyformer_l_refcoco.sh
125
+ bash evaluate_polyformer_l_refcoco+.sh
126
+ bash evaluate_polyformer_l_refcocog.sh
127
+ ```
128
+
129
+ ## Model Zoo
130
+ Download the model weights to `./weights` if you want to use our trained models for finetuning and evaluation.
131
+
132
+ | | Refcoco val| | | Refcoco testA| | | Refcoco testB| ||
133
+ |-------------------------------------------------------------------------------------------------------|------|------|---------|------|-------|------|-----|------|------|
134
+ | Model | oIoU | mIoU | [email protected] | oIoU | mIoU |[email protected] | oIoU | mIoU |[email protected] |
135
+ | [PolyFormer-B](https://drive.google.com/file/d/1K0y-WBO6cL7gBzNnJaHAeNu3pgq4DbJ9/view?usp=share_link) | 74.82| 75.96 | 89.73 |76.64| 77.09 | 91.73| 71.06| 73.22 | 86.03 |
136
+ | [PolyFormer-L](https://drive.google.com/file/d/15P6m5RI6HAQE2QXQXMAjw_oBsaPii7b3/view?usp=share_link) | 75.96| 76.94 | 90.38 |78.29| 78.49 | 92.89| 73.25| 74.83 | 87.16|
137
+
138
+
139
+ | [test_demo.py](..%2F..%2FDownloads%2Ftest_demo.py) | Refcoco val| | | Refcoco testA| | | Refcoco testB| ||
140
+ |--------------------------------------------------------------------------------------------------------|------|------|------|------|------|------|------|------|------|
141
+ | Model | oIoU | mIoU |[email protected]| oIoU | mIoU |[email protected] | oIoU | mIoU |[email protected] |
142
+ | [PolyFormer-B ](https://drive.google.com/file/d/12_ylFhsbqGySxDqgeEByn8nKoJtT2n2w/view?usp=share_link) | 67.64| 70.65 | 83.73 | 72.89| 74.51 | 88.60 | 59.33| 64.64 | 76.38 | 67.76| 69.36 |
143
+ | [PolyFormer-L](https://drive.google.com/file/d/1lUCv7dUPctEz4vEpPr7aI8A8ZmfYCB8y/view?usp=share_link) | 69.33| 72.15 | 84.98 | 74.56| 75.71 | 89.77 | 61.87| 66.73 | 77.97 | 69.20| 71.15 |
144
+
145
+
146
+ | | Refcocog val| || | Refcocog test| |
147
+ |-------------------------------------------------------------------------------------------------------|------|------|------|------|------|------|
148
+ | Model | oIoU | mIoU |[email protected] | oIoU | mIoU |[email protected] |
149
+ | [PolyFormer-B](https://drive.google.com/file/d/12_ylFhsbqGySxDqgeEByn8nKoJtT2n2w/view?usp=share_link) | 67.76| 69.36 | 84.46| 69.05| 69.88 | 84.96 |
150
+ | [PolyFormer-L](https://drive.google.com/file/d/1lUCv7dUPctEz4vEpPr7aI8A8ZmfYCB8y/view?usp=share_link) | 69.20| 71.15 | 85.83 | 70.19| 71.17 | 85.91|
151
+
152
+ * Pretrained weights:
153
+ * [PolyFormer-B](https://drive.google.com/file/d/1sAzfChYDdHdaeatB2K14lrJjG4uiXAol/view?usp=share_link)
154
+ * [PolyFormer-L](https://drive.google.com/file/d/1knRxgM1lmEkuZZ-cOm_fmwKP1H0bJGU9/view?usp=share_link)
155
+
156
+ # Acknowlegement
157
+ This codebase is developed based on [OFA](https://github.com/OFA-Sys/OFA).
158
+ Other related codebases include:
159
+ * [Fairseq](https://github.com/pytorch/fairseq)
160
+ * [refer](https://github.com/lichengunc/refer)
161
+ * [LAVT-RIS](https://github.com/yz93/LAVT-RIS/)
162
+ * [SeqTR](https://github.com/sean-zhuh/SeqTR)
163
+
164
+
165
+
166
+ # Citation
167
+ Please cite our paper if you find this codebase helpful :)
168
+
169
+ ```
170
+ @inproceedings{liu2023polyformer,
171
+ title={PolyFormer: Referring Image Segmentation as Sequential Polygon Generation},
172
+ author={Liu, Jiang and Ding, Hui and Cai, Zhaowei and Zhang, Yuting and Satzoda, Ravi Kumar and Mahadevan, Vijay and Manmatha, R},
173
+ booktitle={CVPR},
174
+ year={2023}
175
+ }
176
+ ```
177
+
178
+ ## Security
179
+
180
+ See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
181
+
182
+ ## License
183
+
184
+ This project is licensed under the Apache-2.0 License.
185
+
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from fairseq import utils,tasks
5
+ from utils.checkpoint_utils import load_model_ensemble_and_task
6
+ from models.polyformer import PolyFormerModel
7
+ import cv2
8
+
9
+ import torch
10
+ import numpy as np
11
+ from fairseq import utils, tasks
12
+ from fairseq import checkpoint_utils
13
+ from utils.eval_utils import eval_step
14
+ from tasks.refcoco import RefcocoTask
15
+ from models.polyformer import PolyFormerModel
16
+ from PIL import Image
17
+ from torchvision import transforms
18
+ import cv2
19
+ import gradio as gr
20
+ import math
21
+ from io import BytesIO
22
+ import base64
23
+ import re
24
+ from demo import visual_grounding
25
+
26
+ title = "PolyFormer-Visual_Grounding"
27
+ description = "Gradio Demo for PolyFormer-Visual_Grounding. Upload your own image or click any one of the examples, " \
28
+ "and write a description about a certain object. " \
29
+ "Then click \"Submit\" and wait for the result of grounding. For help or to provide feedback, please contact: Hui Ding (@huidin)"
30
+ article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2302.07387.pdf' target='_blank'>PolyFormer CVPR2023" \
31
+ "</a></p> "
32
+ # examples = [['A bear astronaut in the space.jpeg', 'a bear astronaut in the space'],
33
+ # ['A unicorn doing computer vision research.jpeg', 'a unicorn doing computer vision research'],
34
+ # ['pig.jpeg', 'a pig robot preparing a delicious meal'],
35
+ # ['otta.png', 'a gentleman otter in a 19th century portrait'],
36
+ # ['pikachu.jpeg', 'a pikachu fine-dining with a view to the Eiffel Tower'],
37
+ # ['A small cabin on top of a snowy mountain in the style of Disney artstation.jpeg', 'a small cabin on top of a snowy mountain in the style of Disney artstation'],
38
+ #
39
+ # ]
40
+ examples = []
41
+ io = gr.Interface(fn=visual_grounding, inputs=[gr.inputs.Image(type='pil'), "textbox"],
42
+ outputs=[gr.outputs.Image(label="output", type='numpy'), gr.outputs.Image(label="predicted mask", type='numpy')],
43
+ title=title, description=description, article=article, examples=examples,
44
+ allow_flagging=False, allow_screenshot=False)
45
+ # io.launch(cache_examples=True)
46
+ io.launch(share=True)
47
+
bert/activations.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def swish(x):
12
+ return x * torch.sigmoid(x)
13
+
14
+
15
+ def _gelu_python(x):
16
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19
+ This is now written in C in torch.nn.functional
20
+ Also see https://arxiv.org/abs/1606.08415
21
+ """
22
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23
+
24
+
25
+ def gelu_new(x):
26
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27
+ Also see https://arxiv.org/abs/1606.08415
28
+ """
29
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30
+
31
+
32
+ if torch.__version__ < "1.4.0":
33
+ gelu = _gelu_python
34
+ else:
35
+ gelu = F.gelu
36
+
37
+
38
+ def gelu_fast(x):
39
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
40
+
41
+
42
+ ACT2FN = {
43
+ "relu": F.relu,
44
+ "swish": swish,
45
+ "gelu": gelu,
46
+ "tanh": torch.tanh,
47
+ "gelu_new": gelu_new,
48
+ "gelu_fast": gelu_fast,
49
+ }
50
+
51
+
52
+ def get_activation(activation_string):
53
+ if activation_string in ACT2FN:
54
+ return ACT2FN[activation_string]
55
+ else:
56
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
bert/configuration_bert.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+
19
+ import logging
20
+
21
+ from .configuration_utils import PretrainedConfig
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34
+ "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42
+ "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
43
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
44
+ "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
45
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
46
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49
+ # See all BERT models at https://huggingface.co/models?filter=bert
50
+ }
51
+
52
+
53
+ class BertConfig(PretrainedConfig):
54
+ r"""
55
+ This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
56
+ It is used to instantiate an BERT model according to the specified arguments, defining the model
57
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
58
+ the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
59
+
60
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
61
+ to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
62
+ for more information.
63
+
64
+
65
+ Args:
66
+ vocab_size (:obj:`int`, optional, defaults to 30522):
67
+ Vocabulary size of the BERT model. Defines the different tokens that
68
+ can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
69
+ hidden_size (:obj:`int`, optional, defaults to 768):
70
+ Dimensionality of the encoder layers and the pooler layer.
71
+ num_hidden_layers (:obj:`int`, optional, defaults to 12):
72
+ Number of hidden layers in the Transformer encoder.
73
+ num_attention_heads (:obj:`int`, optional, defaults to 12):
74
+ Number of attention heads for each attention layer in the Transformer encoder.
75
+ intermediate_size (:obj:`int`, optional, defaults to 3072):
76
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
77
+ hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
78
+ The non-linear activation function (function or string) in the encoder and pooler.
79
+ If string, "gelu", "relu", "swish" and "gelu_new" are supported.
80
+ hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
81
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
82
+ attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
83
+ The dropout ratio for the attention probabilities.
84
+ max_position_embeddings (:obj:`int`, optional, defaults to 512):
85
+ The maximum sequence length that this model might ever be used with.
86
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
87
+ type_vocab_size (:obj:`int`, optional, defaults to 2):
88
+ The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
89
+ initializer_range (:obj:`float`, optional, defaults to 0.02):
90
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91
+ layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
92
+ The epsilon used by the layer normalization layers.
93
+ gradient_checkpointing (:obj:`bool`, optional, defaults to False):
94
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
95
+
96
+ Example::
97
+
98
+ >>> from transformers import BertModel, BertConfig
99
+
100
+ >>> # Initializing a BERT bert-base-uncased style configuration
101
+ >>> configuration = BertConfig()
102
+
103
+ >>> # Initializing a model from the bert-base-uncased style configuration
104
+ >>> model = BertModel(configuration)
105
+
106
+ >>> # Accessing the model configuration
107
+ >>> configuration = model.config
108
+ """
109
+ model_type = "bert"
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=30522,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ intermediate_size=3072,
118
+ hidden_act="gelu",
119
+ hidden_dropout_prob=0.1,
120
+ attention_probs_dropout_prob=0.1,
121
+ max_position_embeddings=512,
122
+ type_vocab_size=2,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-12,
125
+ pad_token_id=0,
126
+ gradient_checkpointing=False,
127
+ **kwargs
128
+ ):
129
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
130
+
131
+ self.vocab_size = vocab_size
132
+ self.hidden_size = hidden_size
133
+ self.num_hidden_layers = num_hidden_layers
134
+ self.num_attention_heads = num_attention_heads
135
+ self.hidden_act = hidden_act
136
+ self.intermediate_size = intermediate_size
137
+ self.hidden_dropout_prob = hidden_dropout_prob
138
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
139
+ self.max_position_embeddings = max_position_embeddings
140
+ self.type_vocab_size = type_vocab_size
141
+ self.initializer_range = initializer_range
142
+ self.layer_norm_eps = layer_norm_eps
143
+ self.gradient_checkpointing = gradient_checkpointing
bert/configuration_utils.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+
19
+ import copy
20
+ import json
21
+ import logging
22
+ import os
23
+ from typing import Dict, Tuple
24
+
25
+ from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class PretrainedConfig(object):
32
+ r""" Base class for all configuration classes.
33
+ Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34
+
35
+ Note:
36
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37
+ It only affects the model's configuration.
38
+
39
+ Class attributes (overridden by derived classes):
40
+ - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
41
+
42
+ Args:
43
+ finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
44
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
45
+ num_labels (:obj:`int`, `optional`, defaults to `2`):
46
+ Number of classes to use when the model is a classification model (sequences/tokens)
47
+ output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
48
+ Should the model returns all hidden-states.
49
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
50
+ Should the model returns all attentions.
51
+ torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
52
+ Is the model used with Torchscript (for PyTorch models).
53
+ """
54
+ model_type: str = ""
55
+
56
+ def __init__(self, **kwargs):
57
+ # Attributes with defaults
58
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
59
+ self.output_attentions = kwargs.pop("output_attentions", False)
60
+ self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
61
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
62
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
63
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
64
+
65
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
66
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
67
+ self.is_decoder = kwargs.pop("is_decoder", False)
68
+
69
+ # Parameters for sequence generation
70
+ self.max_length = kwargs.pop("max_length", 20)
71
+ self.min_length = kwargs.pop("min_length", 0)
72
+ self.do_sample = kwargs.pop("do_sample", False)
73
+ self.early_stopping = kwargs.pop("early_stopping", False)
74
+ self.num_beams = kwargs.pop("num_beams", 1)
75
+ self.temperature = kwargs.pop("temperature", 1.0)
76
+ self.top_k = kwargs.pop("top_k", 50)
77
+ self.top_p = kwargs.pop("top_p", 1.0)
78
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
79
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
80
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
81
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
82
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
83
+
84
+ # Fine-tuning task arguments
85
+ self.architectures = kwargs.pop("architectures", None)
86
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
87
+ self.id2label = kwargs.pop("id2label", None)
88
+ self.label2id = kwargs.pop("label2id", None)
89
+ if self.id2label is not None:
90
+ kwargs.pop("num_labels", None)
91
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
92
+ # Keys are always strings in JSON so convert ids to int here.
93
+ else:
94
+ self.num_labels = kwargs.pop("num_labels", 2)
95
+
96
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
97
+ self.prefix = kwargs.pop("prefix", None)
98
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
99
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
100
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
101
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
102
+
103
+ # task specific arguments
104
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
105
+
106
+ # TPU arguments
107
+ self.xla_device = kwargs.pop("xla_device", None)
108
+
109
+ # Additional attributes without default values
110
+ for key, value in kwargs.items():
111
+ try:
112
+ setattr(self, key, value)
113
+ except AttributeError as err:
114
+ logger.error("Can't set {} with value {} for {}".format(key, value, self))
115
+ raise err
116
+
117
+ @property
118
+ def num_labels(self):
119
+ return len(self.id2label)
120
+
121
+ @num_labels.setter
122
+ def num_labels(self, num_labels):
123
+ self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
124
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
125
+
126
+ def save_pretrained(self, save_directory):
127
+ """
128
+ Save a configuration object to the directory `save_directory`, so that it
129
+ can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
130
+
131
+ Args:
132
+ save_directory (:obj:`string`):
133
+ Directory where the configuration JSON file will be saved.
134
+ """
135
+ if os.path.isfile(save_directory):
136
+ raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
137
+ os.makedirs(save_directory, exist_ok=True)
138
+ # If we save using the predefined names, we can load using `from_pretrained`
139
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
140
+
141
+ self.to_json_file(output_config_file, use_diff=True)
142
+ logger.info("Configuration saved in {}".format(output_config_file))
143
+
144
+ @classmethod
145
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
146
+ r"""
147
+
148
+ Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
149
+
150
+ Args:
151
+ pretrained_model_name_or_path (:obj:`string`):
152
+ either:
153
+ - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
154
+ download, e.g.: ``bert-base-uncased``.
155
+ - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
156
+ our S3, e.g.: ``dbmdz/bert-base-german-cased``.
157
+ - a path to a `directory` containing a configuration file saved using the
158
+ :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
159
+ - a path or url to a saved configuration JSON `file`, e.g.:
160
+ ``./my_model_directory/configuration.json``.
161
+ cache_dir (:obj:`string`, `optional`):
162
+ Path to a directory in which a downloaded pre-trained model
163
+ configuration should be cached if the standard cache should not be used.
164
+ kwargs (:obj:`Dict[str, any]`, `optional`):
165
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
166
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
167
+ controlled by the `return_unused_kwargs` keyword parameter.
168
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
169
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
170
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
171
+ Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
172
+ proxies (:obj:`Dict`, `optional`):
173
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.:
174
+ :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
175
+ The proxies are used on each request.
176
+ return_unused_kwargs: (`optional`) bool:
177
+ If False, then this function returns just the final configuration object.
178
+ If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
179
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
180
+ of kwargs which has not been used to update `config` and is otherwise ignored.
181
+
182
+ Returns:
183
+ :class:`PretrainedConfig`: An instance of a configuration object
184
+
185
+ Examples::
186
+
187
+ # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
188
+ # derived class: BertConfig
189
+ config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
190
+ config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
191
+ config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
192
+ config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
193
+ assert config.output_attention == True
194
+ config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
195
+ foo=False, return_unused_kwargs=True)
196
+ assert config.output_attention == True
197
+ assert unused_kwargs == {'foo': False}
198
+
199
+ """
200
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
201
+ return cls.from_dict(config_dict, **kwargs)
202
+
203
+ @classmethod
204
+ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
205
+ """
206
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
207
+ for instantiating a Config using `from_dict`.
208
+
209
+ Parameters:
210
+ pretrained_model_name_or_path (:obj:`string`):
211
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
212
+
213
+ Returns:
214
+ :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
215
+
216
+ """
217
+ cache_dir = kwargs.pop("cache_dir", None)
218
+ force_download = kwargs.pop("force_download", False)
219
+ resume_download = kwargs.pop("resume_download", False)
220
+ proxies = kwargs.pop("proxies", None)
221
+ local_files_only = kwargs.pop("local_files_only", False)
222
+
223
+ if os.path.isdir(pretrained_model_name_or_path):
224
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
225
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
226
+ config_file = pretrained_model_name_or_path
227
+ else:
228
+ config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
229
+
230
+ try:
231
+ # Load from URL or cache if already cached
232
+ resolved_config_file = cached_path(
233
+ config_file,
234
+ cache_dir=cache_dir,
235
+ force_download=force_download,
236
+ proxies=proxies,
237
+ resume_download=resume_download,
238
+ local_files_only=local_files_only,
239
+ )
240
+ # Load config dict
241
+ if resolved_config_file is None:
242
+ raise EnvironmentError
243
+ config_dict = cls._dict_from_json_file(resolved_config_file)
244
+
245
+ except EnvironmentError:
246
+ msg = (
247
+ f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
248
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
249
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
250
+ )
251
+ raise EnvironmentError(msg)
252
+
253
+ except json.JSONDecodeError:
254
+ msg = (
255
+ "Couldn't reach server at '{}' to download configuration file or "
256
+ "configuration file is not a valid JSON file. "
257
+ "Please check network or file content here: {}.".format(config_file, resolved_config_file)
258
+ )
259
+ raise EnvironmentError(msg)
260
+
261
+ if resolved_config_file == config_file:
262
+ logger.info("loading configuration file {}".format(config_file))
263
+ else:
264
+ logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
265
+
266
+ return config_dict, kwargs
267
+
268
+ @classmethod
269
+ def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
270
+ """
271
+ Constructs a `Config` from a Python dictionary of parameters.
272
+
273
+ Args:
274
+ config_dict (:obj:`Dict[str, any]`):
275
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
276
+ from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
277
+ method.
278
+ kwargs (:obj:`Dict[str, any]`):
279
+ Additional parameters from which to initialize the configuration object.
280
+
281
+ Returns:
282
+ :class:`PretrainedConfig`: An instance of a configuration object
283
+ """
284
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
285
+
286
+ config = cls(**config_dict)
287
+
288
+ if hasattr(config, "pruned_heads"):
289
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
290
+
291
+ # Update config with kwargs if needed
292
+ to_remove = []
293
+ for key, value in kwargs.items():
294
+ if hasattr(config, key):
295
+ setattr(config, key, value)
296
+ to_remove.append(key)
297
+ for key in to_remove:
298
+ kwargs.pop(key, None)
299
+
300
+ logger.info("Model config %s", str(config))
301
+ if return_unused_kwargs:
302
+ return config, kwargs
303
+ else:
304
+ return config
305
+
306
+ @classmethod
307
+ def from_json_file(cls, json_file: str) -> "PretrainedConfig":
308
+ """
309
+ Constructs a `Config` from the path to a json file of parameters.
310
+
311
+ Args:
312
+ json_file (:obj:`string`):
313
+ Path to the JSON file containing the parameters.
314
+
315
+ Returns:
316
+ :class:`PretrainedConfig`: An instance of a configuration object
317
+
318
+ """
319
+ config_dict = cls._dict_from_json_file(json_file)
320
+ return cls(**config_dict)
321
+
322
+ @classmethod
323
+ def _dict_from_json_file(cls, json_file: str):
324
+ with open(json_file, "r", encoding="utf-8") as reader:
325
+ text = reader.read()
326
+ return json.loads(text)
327
+
328
+ def __eq__(self, other):
329
+ return self.__dict__ == other.__dict__
330
+
331
+ def __repr__(self):
332
+ return "{} {}".format(self.__class__.__name__, self.to_json_string())
333
+
334
+ def to_diff_dict(self):
335
+ """
336
+ Removes all attributes from config which correspond to the default
337
+ config attributes for better readability and serializes to a Python
338
+ dictionary.
339
+
340
+ Returns:
341
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
342
+ """
343
+ config_dict = self.to_dict()
344
+
345
+ # get the default config dict
346
+ default_config_dict = PretrainedConfig().to_dict()
347
+
348
+ serializable_config_dict = {}
349
+
350
+ # only serialize values that differ from the default config
351
+ for key, value in config_dict.items():
352
+ if key not in default_config_dict or value != default_config_dict[key]:
353
+ serializable_config_dict[key] = value
354
+
355
+ return serializable_config_dict
356
+
357
+ def to_dict(self):
358
+ """
359
+ Serializes this instance to a Python dictionary.
360
+
361
+ Returns:
362
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
363
+ """
364
+ output = copy.deepcopy(self.__dict__)
365
+ if hasattr(self.__class__, "model_type"):
366
+ output["model_type"] = self.__class__.model_type
367
+ return output
368
+
369
+ def to_json_string(self, use_diff=True):
370
+ """
371
+ Serializes this instance to a JSON string.
372
+
373
+ Args:
374
+ use_diff (:obj:`bool`):
375
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
376
+
377
+ Returns:
378
+ :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
379
+ """
380
+ if use_diff is True:
381
+ config_dict = self.to_diff_dict()
382
+ else:
383
+ config_dict = self.to_dict()
384
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
385
+
386
+ def to_json_file(self, json_file_path, use_diff=True):
387
+ """
388
+ Save this instance to a json file.
389
+
390
+ Args:
391
+ json_file_path (:obj:`string`):
392
+ Path to the JSON file in which this configuration instance's parameters will be saved.
393
+ use_diff (:obj:`bool`):
394
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
395
+ """
396
+ with open(json_file_path, "w", encoding="utf-8") as writer:
397
+ writer.write(self.to_json_string(use_diff=use_diff))
398
+
399
+ def update(self, config_dict: Dict):
400
+ """
401
+ Updates attributes of this class
402
+ with attributes from `config_dict`.
403
+
404
+ Args:
405
+ :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
406
+ """
407
+ for key, value in config_dict.items():
408
+ setattr(self, key, value)
bert/file_utils.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+
7
+ import fnmatch
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import sys
13
+ import tarfile
14
+ import tempfile
15
+ from contextlib import contextmanager
16
+ from functools import partial, wraps
17
+ from hashlib import sha256
18
+ from pathlib import Path
19
+ from typing import Dict, Optional, Union
20
+ from urllib.parse import urlparse
21
+ from zipfile import ZipFile, is_zipfile
22
+
23
+ import requests
24
+ from filelock import FileLock
25
+ from tqdm.auto import tqdm
26
+
27
+ #from . import __version__
28
+ __version__ = "3.0.2"
29
+
30
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
31
+
32
+ try:
33
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
34
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
35
+ if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
36
+ import torch
37
+
38
+ _torch_available = True # pylint: disable=invalid-name
39
+ logger.info("PyTorch version {} available.".format(torch.__version__))
40
+ else:
41
+ logger.info("Disabling PyTorch because USE_TF is set")
42
+ _torch_available = False
43
+ except ImportError:
44
+ _torch_available = False # pylint: disable=invalid-name
45
+
46
+ try:
47
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
48
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
49
+
50
+ if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
51
+ import tensorflow as tf
52
+
53
+ assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
54
+ _tf_available = True # pylint: disable=invalid-name
55
+ logger.info("TensorFlow version {} available.".format(tf.__version__))
56
+ else:
57
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
58
+ _tf_available = False
59
+ except (ImportError, AssertionError):
60
+ _tf_available = False # pylint: disable=invalid-name
61
+
62
+
63
+ try:
64
+ from torch.hub import _get_torch_home
65
+
66
+ torch_cache_home = _get_torch_home()
67
+ except ImportError:
68
+ torch_cache_home = os.path.expanduser(
69
+ os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
70
+ )
71
+
72
+
73
+ try:
74
+ import torch_xla.core.xla_model as xm # noqa: F401
75
+
76
+ if _torch_available:
77
+ _torch_tpu_available = True # pylint: disable=
78
+ else:
79
+ _torch_tpu_available = False
80
+ except ImportError:
81
+ _torch_tpu_available = False
82
+
83
+
84
+ try:
85
+ import psutil # noqa: F401
86
+
87
+ _psutil_available = True
88
+
89
+ except ImportError:
90
+ _psutil_available = False
91
+
92
+
93
+ try:
94
+ import py3nvml # noqa: F401
95
+
96
+ _py3nvml_available = True
97
+
98
+ except ImportError:
99
+ _py3nvml_available = False
100
+
101
+
102
+ try:
103
+ from apex import amp # noqa: F401
104
+
105
+ _has_apex = True
106
+ except ImportError:
107
+ _has_apex = False
108
+
109
+ default_cache_path = os.path.join(torch_cache_home, "transformers")
110
+
111
+
112
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
113
+ PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
114
+ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
115
+
116
+ WEIGHTS_NAME = "pytorch_model.bin"
117
+ TF2_WEIGHTS_NAME = "tf_model.h5"
118
+ TF_WEIGHTS_NAME = "model.ckpt"
119
+ CONFIG_NAME = "config.json"
120
+ MODEL_CARD_NAME = "modelcard.json"
121
+
122
+
123
+ MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
124
+ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
125
+ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
126
+
127
+ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
128
+ CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
129
+
130
+
131
+ def is_torch_available():
132
+ return _torch_available
133
+
134
+
135
+ def is_tf_available():
136
+ return _tf_available
137
+
138
+
139
+ def is_torch_tpu_available():
140
+ return _torch_tpu_available
141
+
142
+
143
+ def is_psutil_available():
144
+ return _psutil_available
145
+
146
+
147
+ def is_py3nvml_available():
148
+ return _py3nvml_available
149
+
150
+
151
+ def is_apex_available():
152
+ return _has_apex
153
+
154
+
155
+ def add_start_docstrings(*docstr):
156
+ def docstring_decorator(fn):
157
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
158
+ return fn
159
+
160
+ return docstring_decorator
161
+
162
+
163
+ def add_start_docstrings_to_callable(*docstr):
164
+ def docstring_decorator(fn):
165
+ class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
166
+ intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
167
+ note = r"""
168
+
169
+ .. note::
170
+ Although the recipe for forward pass needs to be defined within
171
+ this function, one should call the :class:`Module` instance afterwards
172
+ instead of this since the former takes care of running the
173
+ pre and post processing steps while the latter silently ignores them.
174
+ """
175
+ fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
176
+ return fn
177
+
178
+ return docstring_decorator
179
+
180
+
181
+ def add_end_docstrings(*docstr):
182
+ def docstring_decorator(fn):
183
+ fn.__doc__ = fn.__doc__ + "".join(docstr)
184
+ return fn
185
+
186
+ return docstring_decorator
187
+
188
+
189
+ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
190
+ Example::
191
+
192
+ >>> from transformers import {tokenizer_class}, {model_class}
193
+ >>> import torch
194
+
195
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
196
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
197
+
198
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
199
+ >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
200
+
201
+ >>> outputs = model(**inputs, labels=labels)
202
+ >>> loss, scores = outputs[:2]
203
+ """
204
+
205
+ PT_QUESTION_ANSWERING_SAMPLE = r"""
206
+ Example::
207
+
208
+ >>> from transformers import {tokenizer_class}, {model_class}
209
+ >>> import torch
210
+
211
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
212
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
213
+
214
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
215
+ >>> start_positions = torch.tensor([1])
216
+ >>> end_positions = torch.tensor([3])
217
+
218
+ >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
219
+ >>> loss, start_scores, end_scores = outputs[:3]
220
+ """
221
+
222
+ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
223
+ Example::
224
+
225
+ >>> from transformers import {tokenizer_class}, {model_class}
226
+ >>> import torch
227
+
228
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
229
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
230
+
231
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
232
+ >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
233
+ >>> outputs = model(**inputs, labels=labels)
234
+ >>> loss, logits = outputs[:2]
235
+ """
236
+
237
+ PT_MASKED_LM_SAMPLE = r"""
238
+ Example::
239
+
240
+ >>> from transformers import {tokenizer_class}, {model_class}
241
+ >>> import torch
242
+
243
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
244
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
245
+
246
+ >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
247
+
248
+ >>> outputs = model(input_ids, labels=input_ids)
249
+ >>> loss, prediction_scores = outputs[:2]
250
+ """
251
+
252
+ PT_BASE_MODEL_SAMPLE = r"""
253
+ Example::
254
+
255
+ >>> from transformers import {tokenizer_class}, {model_class}
256
+ >>> import torch
257
+
258
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
259
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
260
+
261
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
262
+ >>> outputs = model(**inputs)
263
+
264
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
265
+ """
266
+
267
+ PT_MULTIPLE_CHOICE_SAMPLE = r"""
268
+ Example::
269
+
270
+ >>> from transformers import {tokenizer_class}, {model_class}
271
+ >>> import torch
272
+
273
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
274
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
275
+
276
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
277
+ >>> choice0 = "It is eaten with a fork and a knife."
278
+ >>> choice1 = "It is eaten while held in the hand."
279
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
280
+
281
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
282
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
283
+
284
+ >>> # the linear classifier still needs to be trained
285
+ >>> loss, logits = outputs[:2]
286
+ """
287
+
288
+ PT_CAUSAL_LM_SAMPLE = r"""
289
+ Example::
290
+
291
+ >>> import torch
292
+ >>> from transformers import {tokenizer_class}, {model_class}
293
+
294
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
295
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
296
+
297
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
298
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
299
+ >>> loss, logits = outputs[:2]
300
+ """
301
+
302
+ TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
303
+ Example::
304
+
305
+ >>> from transformers import {tokenizer_class}, {model_class}
306
+ >>> import tensorflow as tf
307
+
308
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
309
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
310
+
311
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
312
+ >>> input_ids = inputs["input_ids"]
313
+ >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
314
+
315
+ >>> outputs = model(inputs)
316
+ >>> loss, scores = outputs[:2]
317
+ """
318
+
319
+ TF_QUESTION_ANSWERING_SAMPLE = r"""
320
+ Example::
321
+
322
+ >>> from transformers import {tokenizer_class}, {model_class}
323
+ >>> import tensorflow as tf
324
+
325
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
326
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
327
+
328
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
329
+ >>> input_dict = tokenizer(question, text, return_tensors='tf')
330
+ >>> start_scores, end_scores = model(input_dict)
331
+
332
+ >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
333
+ >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
334
+ """
335
+
336
+ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
337
+ Example::
338
+
339
+ >>> from transformers import {tokenizer_class}, {model_class}
340
+ >>> import tensorflow as tf
341
+
342
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
343
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
344
+
345
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
346
+ >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
347
+
348
+ >>> outputs = model(inputs)
349
+ >>> loss, logits = outputs[:2]
350
+ """
351
+
352
+ TF_MASKED_LM_SAMPLE = r"""
353
+ Example::
354
+ >>> from transformers import {tokenizer_class}, {model_class}
355
+ >>> import tensorflow as tf
356
+
357
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
358
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
359
+
360
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
361
+
362
+ >>> outputs = model(input_ids)
363
+ >>> prediction_scores = outputs[0]
364
+ """
365
+
366
+ TF_BASE_MODEL_SAMPLE = r"""
367
+ Example::
368
+
369
+ >>> from transformers import {tokenizer_class}, {model_class}
370
+ >>> import tensorflow as tf
371
+
372
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
373
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
374
+
375
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
376
+ >>> outputs = model(inputs)
377
+
378
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
379
+ """
380
+
381
+ TF_MULTIPLE_CHOICE_SAMPLE = r"""
382
+ Example::
383
+
384
+ >>> from transformers import {tokenizer_class}, {model_class}
385
+ >>> import tensorflow as tf
386
+
387
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
388
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
389
+
390
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
391
+ >>> choice0 = "It is eaten with a fork and a knife."
392
+ >>> choice1 = "It is eaten while held in the hand."
393
+
394
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
395
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
396
+ >>> outputs = model(inputs) # batch size is 1
397
+
398
+ >>> # the linear classifier still needs to be trained
399
+ >>> logits = outputs[0]
400
+ """
401
+
402
+ TF_CAUSAL_LM_SAMPLE = r"""
403
+ Example::
404
+
405
+ >>> from transformers import {tokenizer_class}, {model_class}
406
+ >>> import tensorflow as tf
407
+
408
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
409
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
410
+
411
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
412
+ >>> outputs = model(inputs)
413
+ >>> logits = outputs[0]
414
+ """
415
+
416
+
417
+ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
418
+ def docstring_decorator(fn):
419
+ model_class = fn.__qualname__.split(".")[0]
420
+ is_tf_class = model_class[:2] == "TF"
421
+
422
+ if "SequenceClassification" in model_class:
423
+ code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
424
+ elif "QuestionAnswering" in model_class:
425
+ code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
426
+ elif "TokenClassification" in model_class:
427
+ code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
428
+ elif "MultipleChoice" in model_class:
429
+ code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
430
+ elif "MaskedLM" in model_class:
431
+ code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
432
+ elif "LMHead" in model_class:
433
+ code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
434
+ elif "Model" in model_class:
435
+ code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
436
+ else:
437
+ raise ValueError(f"Docstring can't be built for model {model_class}")
438
+
439
+ built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
440
+ fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
441
+ return fn
442
+
443
+ return docstring_decorator
444
+
445
+
446
+ def is_remote_url(url_or_filename):
447
+ parsed = urlparse(url_or_filename)
448
+ return parsed.scheme in ("http", "https")
449
+
450
+
451
+ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
452
+ """
453
+ Resolve a model identifier, and a file name, to a HF-hosted url
454
+ on either S3 or Cloudfront (a Content Delivery Network, or CDN).
455
+
456
+ Cloudfront is replicated over the globe so downloads are way faster
457
+ for the end user (and it also lowers our bandwidth costs). However, it
458
+ is more aggressively cached by default, so may not always reflect the
459
+ latest changes to the underlying file (default TTL is 24 hours).
460
+
461
+ In terms of client-side caching from this library, even though
462
+ Cloudfront relays the ETags from S3, using one or the other
463
+ (or switching from one to the other) will affect caching: cached files
464
+ are not shared between the two because the cached file's name contains
465
+ a hash of the url.
466
+ """
467
+ endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
468
+ legacy_format = "/" not in model_id
469
+ if legacy_format:
470
+ return f"{endpoint}/{model_id}-{filename}"
471
+ else:
472
+ return f"{endpoint}/{model_id}/{filename}"
473
+
474
+
475
+ def url_to_filename(url, etag=None):
476
+ """
477
+ Convert `url` into a hashed filename in a repeatable way.
478
+ If `etag` is specified, append its hash to the url's, delimited
479
+ by a period.
480
+ If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
481
+ so that TF 2.0 can identify it as a HDF5 file
482
+ (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
483
+ """
484
+ url_bytes = url.encode("utf-8")
485
+ url_hash = sha256(url_bytes)
486
+ filename = url_hash.hexdigest()
487
+
488
+ if etag:
489
+ etag_bytes = etag.encode("utf-8")
490
+ etag_hash = sha256(etag_bytes)
491
+ filename += "." + etag_hash.hexdigest()
492
+
493
+ if url.endswith(".h5"):
494
+ filename += ".h5"
495
+
496
+ return filename
497
+
498
+
499
+ def filename_to_url(filename, cache_dir=None):
500
+ """
501
+ Return the url and etag (which may be ``None``) stored for `filename`.
502
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
503
+ """
504
+ if cache_dir is None:
505
+ cache_dir = TRANSFORMERS_CACHE
506
+ if isinstance(cache_dir, Path):
507
+ cache_dir = str(cache_dir)
508
+
509
+ cache_path = os.path.join(cache_dir, filename)
510
+ if not os.path.exists(cache_path):
511
+ raise EnvironmentError("file {} not found".format(cache_path))
512
+
513
+ meta_path = cache_path + ".json"
514
+ if not os.path.exists(meta_path):
515
+ raise EnvironmentError("file {} not found".format(meta_path))
516
+
517
+ with open(meta_path, encoding="utf-8") as meta_file:
518
+ metadata = json.load(meta_file)
519
+ url = metadata["url"]
520
+ etag = metadata["etag"]
521
+
522
+ return url, etag
523
+
524
+
525
+ def cached_path(
526
+ url_or_filename,
527
+ cache_dir=None,
528
+ force_download=False,
529
+ proxies=None,
530
+ resume_download=False,
531
+ user_agent: Union[Dict, str, None] = None,
532
+ extract_compressed_file=False,
533
+ force_extract=False,
534
+ local_files_only=False,
535
+ ) -> Optional[str]:
536
+ """
537
+ Given something that might be a URL (or might be a local path),
538
+ determine which. If it's a URL, download the file and cache it, and
539
+ return the path to the cached file. If it's already a local path,
540
+ make sure the file exists and then return the path.
541
+ Args:
542
+ cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
543
+ force_download: if True, re-dowload the file even if it's already cached in the cache dir.
544
+ resume_download: if True, resume the download if incompletly recieved file is found.
545
+ user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
546
+ extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
547
+ file in a folder along the archive.
548
+ force_extract: if True when extract_compressed_file is True and the archive was already extracted,
549
+ re-extract the archive and overide the folder where it was extracted.
550
+
551
+ Return:
552
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
553
+ Local path (string) otherwise
554
+ """
555
+ if cache_dir is None:
556
+ cache_dir = TRANSFORMERS_CACHE
557
+ if isinstance(url_or_filename, Path):
558
+ url_or_filename = str(url_or_filename)
559
+ if isinstance(cache_dir, Path):
560
+ cache_dir = str(cache_dir)
561
+
562
+ if is_remote_url(url_or_filename):
563
+ # URL, so get it from the cache (downloading if necessary)
564
+ output_path = get_from_cache(
565
+ url_or_filename,
566
+ cache_dir=cache_dir,
567
+ force_download=force_download,
568
+ proxies=proxies,
569
+ resume_download=resume_download,
570
+ user_agent=user_agent,
571
+ local_files_only=local_files_only,
572
+ )
573
+ elif os.path.exists(url_or_filename):
574
+ # File, and it exists.
575
+ output_path = url_or_filename
576
+ elif urlparse(url_or_filename).scheme == "":
577
+ # File, but it doesn't exist.
578
+ raise EnvironmentError("file {} not found".format(url_or_filename))
579
+ else:
580
+ # Something unknown
581
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
582
+
583
+ if extract_compressed_file:
584
+ if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
585
+ return output_path
586
+
587
+ # Path where we extract compressed archives
588
+ # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
589
+ output_dir, output_file = os.path.split(output_path)
590
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
591
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
592
+
593
+ if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
594
+ return output_path_extracted
595
+
596
+ # Prevent parallel extractions
597
+ lock_path = output_path + ".lock"
598
+ with FileLock(lock_path):
599
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
600
+ os.makedirs(output_path_extracted)
601
+ if is_zipfile(output_path):
602
+ with ZipFile(output_path, "r") as zip_file:
603
+ zip_file.extractall(output_path_extracted)
604
+ zip_file.close()
605
+ elif tarfile.is_tarfile(output_path):
606
+ tar_file = tarfile.open(output_path)
607
+ tar_file.extractall(output_path_extracted)
608
+ tar_file.close()
609
+ else:
610
+ raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
611
+
612
+ return output_path_extracted
613
+
614
+ return output_path
615
+
616
+
617
+ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
618
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
619
+ if is_torch_available():
620
+ ua += "; torch/{}".format(torch.__version__)
621
+ if is_tf_available():
622
+ ua += "; tensorflow/{}".format(tf.__version__)
623
+ if isinstance(user_agent, dict):
624
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
625
+ elif isinstance(user_agent, str):
626
+ ua += "; " + user_agent
627
+ headers = {"user-agent": ua}
628
+ if resume_size > 0:
629
+ headers["Range"] = "bytes=%d-" % (resume_size,)
630
+ response = requests.get(url, stream=True, proxies=proxies, headers=headers)
631
+ if response.status_code == 416: # Range not satisfiable
632
+ return
633
+ content_length = response.headers.get("Content-Length")
634
+ total = resume_size + int(content_length) if content_length is not None else None
635
+ progress = tqdm(
636
+ unit="B",
637
+ unit_scale=True,
638
+ total=total,
639
+ initial=resume_size,
640
+ desc="Downloading",
641
+ disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
642
+ )
643
+ for chunk in response.iter_content(chunk_size=1024):
644
+ if chunk: # filter out keep-alive new chunks
645
+ progress.update(len(chunk))
646
+ temp_file.write(chunk)
647
+ progress.close()
648
+
649
+
650
+ def get_from_cache(
651
+ url,
652
+ cache_dir=None,
653
+ force_download=False,
654
+ proxies=None,
655
+ etag_timeout=10,
656
+ resume_download=False,
657
+ user_agent: Union[Dict, str, None] = None,
658
+ local_files_only=False,
659
+ ) -> Optional[str]:
660
+ """
661
+ Given a URL, look for the corresponding file in the local cache.
662
+ If it's not there, download it. Then return the path to the cached file.
663
+
664
+ Return:
665
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
666
+ Local path (string) otherwise
667
+ """
668
+ if cache_dir is None:
669
+ cache_dir = TRANSFORMERS_CACHE
670
+ if isinstance(cache_dir, Path):
671
+ cache_dir = str(cache_dir)
672
+
673
+ os.makedirs(cache_dir, exist_ok=True)
674
+
675
+ etag = None
676
+ if not local_files_only:
677
+ try:
678
+ response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
679
+ if response.status_code == 200:
680
+ etag = response.headers.get("ETag")
681
+ except (EnvironmentError, requests.exceptions.Timeout):
682
+ # etag is already None
683
+ pass
684
+
685
+ filename = url_to_filename(url, etag)
686
+
687
+ # get cache path to put the file
688
+ cache_path = os.path.join(cache_dir, filename)
689
+
690
+ # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
691
+ # try to get the last downloaded one
692
+ if etag is None:
693
+ if os.path.exists(cache_path):
694
+ return cache_path
695
+ else:
696
+ matching_files = [
697
+ file
698
+ for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
699
+ if not file.endswith(".json") and not file.endswith(".lock")
700
+ ]
701
+ if len(matching_files) > 0:
702
+ return os.path.join(cache_dir, matching_files[-1])
703
+ else:
704
+ # If files cannot be found and local_files_only=True,
705
+ # the models might've been found if local_files_only=False
706
+ # Notify the user about that
707
+ if local_files_only:
708
+ raise ValueError(
709
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
710
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
711
+ " to False."
712
+ )
713
+ return None
714
+
715
+ # From now on, etag is not None.
716
+ if os.path.exists(cache_path) and not force_download:
717
+ return cache_path
718
+
719
+ # Prevent parallel downloads of the same file with a lock.
720
+ lock_path = cache_path + ".lock"
721
+ with FileLock(lock_path):
722
+
723
+ # If the download just completed while the lock was activated.
724
+ if os.path.exists(cache_path) and not force_download:
725
+ # Even if returning early like here, the lock will be released.
726
+ return cache_path
727
+
728
+ if resume_download:
729
+ incomplete_path = cache_path + ".incomplete"
730
+
731
+ @contextmanager
732
+ def _resumable_file_manager():
733
+ with open(incomplete_path, "a+b") as f:
734
+ yield f
735
+
736
+ temp_file_manager = _resumable_file_manager
737
+ if os.path.exists(incomplete_path):
738
+ resume_size = os.stat(incomplete_path).st_size
739
+ else:
740
+ resume_size = 0
741
+ else:
742
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
743
+ resume_size = 0
744
+
745
+ # Download to temporary file, then copy to cache dir once finished.
746
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
747
+ with temp_file_manager() as temp_file:
748
+ logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
749
+
750
+ http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
751
+
752
+ logger.info("storing %s in cache at %s", url, cache_path)
753
+ os.replace(temp_file.name, cache_path)
754
+
755
+ logger.info("creating metadata file for %s", cache_path)
756
+ meta = {"url": url, "etag": etag}
757
+ meta_path = cache_path + ".json"
758
+ with open(meta_path, "w") as meta_file:
759
+ json.dump(meta, meta_file)
760
+
761
+ return cache_path
762
+
763
+
764
+ class cached_property(property):
765
+ """
766
+ Descriptor that mimics @property but caches output in member variable.
767
+
768
+ From tensorflow_datasets
769
+
770
+ Built-in in functools from Python 3.8.
771
+ """
772
+
773
+ def __get__(self, obj, objtype=None):
774
+ # See docs.python.org/3/howto/descriptor.html#properties
775
+ if obj is None:
776
+ return self
777
+ if self.fget is None:
778
+ raise AttributeError("unreadable attribute")
779
+ attr = "__cached_" + self.fget.__name__
780
+ cached = getattr(obj, attr, None)
781
+ if cached is None:
782
+ cached = self.fget(obj)
783
+ setattr(obj, attr, cached)
784
+ return cached
785
+
786
+
787
+ def torch_required(func):
788
+ # Chose a different decorator name than in tests so it's clear they are not the same.
789
+ @wraps(func)
790
+ def wrapper(*args, **kwargs):
791
+ if is_torch_available():
792
+ return func(*args, **kwargs)
793
+ else:
794
+ raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
795
+
796
+ return wrapper
797
+
798
+
799
+ def tf_required(func):
800
+ # Chose a different decorator name than in tests so it's clear they are not the same.
801
+ @wraps(func)
802
+ def wrapper(*args, **kwargs):
803
+ if is_tf_available():
804
+ return func(*args, **kwargs)
805
+ else:
806
+ raise ImportError(f"Method `{func.__name__}` requires TF.")
807
+
808
+ return wrapper
bert/generation_utils.py ADDED
@@ -0,0 +1,993 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from typing import Iterable, Optional, Tuple
19
+
20
+ import torch
21
+ from torch import Tensor
22
+ from torch.nn import functional as F
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class GenerationMixin:
29
+ """
30
+ A class contraining all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
31
+ """
32
+
33
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
34
+ return {"input_ids": input_ids}
35
+
36
+ def adjust_logits_during_generation(self, logits, **kwargs):
37
+ return logits
38
+
39
+ def _use_cache(self, outputs, use_cache):
40
+ """During generation, decide whether to pass the `past` variable to the next forward pass."""
41
+ if len(outputs) <= 1 or use_cache is False:
42
+ return False
43
+ if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
44
+ return False
45
+ return True
46
+
47
+ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
48
+ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
49
+ for i in range(batch_size * num_beams):
50
+ for previous_token in set(prev_output_tokens[i].tolist()):
51
+ # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
52
+ if lprobs[i, previous_token] < 0:
53
+ lprobs[i, previous_token] *= repetition_penalty
54
+ else:
55
+ lprobs[i, previous_token] /= repetition_penalty
56
+
57
+ def postprocess_next_token_scores(
58
+ self,
59
+ scores,
60
+ input_ids,
61
+ no_repeat_ngram_size,
62
+ bad_words_ids,
63
+ cur_len,
64
+ min_length,
65
+ max_length,
66
+ eos_token_id,
67
+ repetition_penalty,
68
+ batch_size,
69
+ num_beams,
70
+ ):
71
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
72
+ if repetition_penalty != 1.0:
73
+ self.enforce_repetition_penalty_(
74
+ scores, batch_size, num_beams, input_ids, repetition_penalty,
75
+ )
76
+
77
+ # set eos token prob to zero if min_length is not reached
78
+ if eos_token_id is not None and cur_len < min_length:
79
+ scores[:, eos_token_id] = -float("inf")
80
+
81
+ if no_repeat_ngram_size > 0:
82
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
83
+ num_batch_hypotheses = batch_size * num_beams
84
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
85
+ banned_batch_tokens = calc_banned_ngram_tokens(
86
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
87
+ )
88
+ for i, banned_tokens in enumerate(banned_batch_tokens):
89
+ scores[i, banned_tokens] = -float("inf")
90
+
91
+ if bad_words_ids is not None:
92
+ # calculate a list of banned tokens according to bad words
93
+ banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
94
+
95
+ for i, banned_tokens in enumerate(banned_tokens):
96
+ scores[i, banned_tokens] = -float("inf")
97
+
98
+ return scores
99
+
100
+ @torch.no_grad()
101
+ def generate(
102
+ self,
103
+ input_ids: Optional[torch.LongTensor] = None,
104
+ max_length: Optional[int] = None,
105
+ min_length: Optional[int] = None,
106
+ do_sample: Optional[bool] = None,
107
+ early_stopping: Optional[bool] = None,
108
+ num_beams: Optional[int] = None,
109
+ temperature: Optional[float] = None,
110
+ top_k: Optional[int] = None,
111
+ top_p: Optional[float] = None,
112
+ repetition_penalty: Optional[float] = None,
113
+ bad_words_ids: Optional[Iterable[int]] = None,
114
+ bos_token_id: Optional[int] = None,
115
+ pad_token_id: Optional[int] = None,
116
+ eos_token_id: Optional[int] = None,
117
+ length_penalty: Optional[float] = None,
118
+ no_repeat_ngram_size: Optional[int] = None,
119
+ num_return_sequences: Optional[int] = None,
120
+ attention_mask: Optional[torch.LongTensor] = None,
121
+ decoder_start_token_id: Optional[int] = None,
122
+ use_cache: Optional[bool] = None,
123
+ **model_specific_kwargs
124
+ ) -> torch.LongTensor:
125
+ r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
126
+
127
+ Adapted in part from `Facebook's XLM beam search code`_.
128
+
129
+ .. _`Facebook's XLM beam search code`:
130
+ https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
131
+
132
+
133
+ Parameters:
134
+
135
+ input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
136
+ The sequence used as a prompt for the generation. If `None` the method initializes
137
+ it as an empty `torch.LongTensor` of shape `(1,)`.
138
+
139
+ max_length: (`optional`) int
140
+ The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20.
141
+
142
+ min_length: (`optional`) int
143
+ The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
144
+
145
+ do_sample: (`optional`) bool
146
+ If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
147
+
148
+ early_stopping: (`optional`) bool
149
+ if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
150
+
151
+ num_beams: (`optional`) int
152
+ Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
153
+
154
+ temperature: (`optional`) float
155
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
156
+
157
+ top_k: (`optional`) int
158
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
159
+
160
+ top_p: (`optional`) float
161
+ The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
162
+
163
+ repetition_penalty: (`optional`) float
164
+ The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
165
+
166
+ pad_token_id: (`optional`) int
167
+ Padding token. Default to specicic model pad_token_id or None if it does not exist.
168
+
169
+ bos_token_id: (`optional`) int
170
+ BOS token. Defaults to `bos_token_id` as defined in the models config.
171
+
172
+ eos_token_id: (`optional`) int
173
+ EOS token. Defaults to `eos_token_id` as defined in the models config.
174
+
175
+ length_penalty: (`optional`) float
176
+ Exponential penalty to the length. Default to 1.
177
+
178
+ no_repeat_ngram_size: (`optional`) int
179
+ If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
180
+ bad_words_ids: (`optional`) list of lists of int
181
+ `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
182
+
183
+ num_return_sequences: (`optional`) int
184
+ The number of independently computed returned sequences for each element in the batch. Default to 1.
185
+
186
+ attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
187
+ Mask to avoid performing attention on padding token indices.
188
+ Mask values selected in ``[0, 1]``:
189
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
190
+ Defaults to `None`.
191
+
192
+ `What are attention masks? <../glossary.html#attention-mask>`__
193
+
194
+ decoder_start_token_id=None: (`optional`) int
195
+ If an encoder-decoder model starts decoding with a different token than BOS.
196
+ Defaults to `None` and is changed to `BOS` later.
197
+
198
+ use_cache: (`optional`) bool
199
+ If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
200
+
201
+ model_specific_kwargs: (`optional`) dict
202
+ Additional model specific kwargs will be forwarded to the `forward` function of the model.
203
+
204
+ Return:
205
+
206
+ output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
207
+ sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
208
+
209
+ Examples::
210
+
211
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
212
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
213
+ outputs = model.generate(max_length=40) # do greedy decoding
214
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
215
+
216
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
217
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
218
+ input_context = 'The dog'
219
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
220
+ outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
221
+ for i in range(3): # 3 output sequences were generated
222
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
223
+
224
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
225
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
226
+ input_context = 'The dog'
227
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
228
+ outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling
229
+ for i in range(3): # 3 output sequences were generated
230
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
231
+
232
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
233
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
234
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
235
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
236
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
237
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
238
+
239
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
240
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
241
+ input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
242
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
243
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
244
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
245
+ """
246
+
247
+ # We cannot generate if the model does not have a LM head
248
+ if self.get_output_embeddings() is None:
249
+ raise AttributeError(
250
+ "You tried to generate sequences with a model that does not have a LM Head."
251
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
252
+ )
253
+
254
+ max_length = max_length if max_length is not None else self.config.max_length
255
+ min_length = min_length if min_length is not None else self.config.min_length
256
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
257
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
258
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
259
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
260
+ temperature = temperature if temperature is not None else self.config.temperature
261
+ top_k = top_k if top_k is not None else self.config.top_k
262
+ top_p = top_p if top_p is not None else self.config.top_p
263
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
264
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
265
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
266
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
267
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
268
+ no_repeat_ngram_size = (
269
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
270
+ )
271
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
272
+ num_return_sequences = (
273
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
274
+ )
275
+ decoder_start_token_id = (
276
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
277
+ )
278
+
279
+ if input_ids is not None:
280
+ batch_size = input_ids.shape[0] # overriden by the input batch_size
281
+ else:
282
+ batch_size = 1
283
+
284
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
285
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
286
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
287
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
288
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
289
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
290
+ assert temperature > 0, "`temperature` should be strictly positive."
291
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
292
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
293
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
294
+ assert input_ids is not None or (
295
+ isinstance(bos_token_id, int) and bos_token_id >= 0
296
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
297
+ assert pad_token_id is None or (
298
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
299
+ ), "`pad_token_id` should be a positive integer."
300
+ assert (eos_token_id is None) or (
301
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
302
+ ), "`eos_token_id` should be a positive integer."
303
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
304
+ assert (
305
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
306
+ ), "`no_repeat_ngram_size` should be a positive integer."
307
+ assert (
308
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
309
+ ), "`num_return_sequences` should be a strictly positive integer."
310
+ assert (
311
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
312
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
313
+
314
+ if input_ids is None:
315
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
316
+ "you should either supply a context to complete as `input_ids` input "
317
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
318
+ )
319
+ input_ids = torch.full(
320
+ (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
321
+ )
322
+ else:
323
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
324
+
325
+ # not allow to duplicate outputs when greedy decoding
326
+ if do_sample is False:
327
+ if num_beams == 1:
328
+ # no_beam_search greedy generation conditions
329
+ assert (
330
+ num_return_sequences == 1
331
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
332
+
333
+ else:
334
+ # beam_search greedy generation conditions
335
+ assert (
336
+ num_beams >= num_return_sequences
337
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
338
+
339
+ # create attention mask if necessary
340
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
341
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
342
+ attention_mask = input_ids.ne(pad_token_id).long()
343
+ elif attention_mask is None:
344
+ attention_mask = input_ids.new_ones(input_ids.shape)
345
+
346
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
347
+ # attention_mask is created
348
+ if pad_token_id is None and eos_token_id is not None:
349
+ logger.warning(
350
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
351
+ )
352
+ pad_token_id = eos_token_id
353
+
354
+ # current position and vocab size
355
+ if hasattr(self.config, "vocab_size"):
356
+ vocab_size = self.config.vocab_size
357
+ elif (
358
+ self.config.is_encoder_decoder
359
+ and hasattr(self.config, "decoder")
360
+ and hasattr(self.config.decoder, "vocab_size")
361
+ ):
362
+ vocab_size = self.config.decoder.vocab_size
363
+
364
+ # set effective batch size and effective batch multiplier according to do_sample
365
+ if do_sample:
366
+ effective_batch_size = batch_size * num_return_sequences
367
+ effective_batch_mult = num_return_sequences
368
+ else:
369
+ effective_batch_size = batch_size
370
+ effective_batch_mult = 1
371
+
372
+ if self.config.is_encoder_decoder:
373
+ if decoder_start_token_id is None:
374
+ decoder_start_token_id = bos_token_id
375
+
376
+ assert (
377
+ decoder_start_token_id is not None
378
+ ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
379
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
380
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
381
+
382
+ # get encoder and store encoder outputs
383
+ encoder = self.get_encoder()
384
+
385
+ encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
386
+
387
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
388
+ if num_return_sequences > 1 or num_beams > 1:
389
+ input_ids_len = input_ids.shape[-1]
390
+ input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
391
+ attention_mask = attention_mask.unsqueeze(1).expand(
392
+ batch_size, effective_batch_mult * num_beams, input_ids_len
393
+ )
394
+
395
+ input_ids = input_ids.contiguous().view(
396
+ effective_batch_size * num_beams, input_ids_len
397
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
398
+ attention_mask = attention_mask.contiguous().view(
399
+ effective_batch_size * num_beams, input_ids_len
400
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
401
+
402
+ if self.config.is_encoder_decoder:
403
+ # create empty decoder_input_ids
404
+ input_ids = torch.full(
405
+ (effective_batch_size * num_beams, 1),
406
+ decoder_start_token_id,
407
+ dtype=torch.long,
408
+ device=next(self.parameters()).device,
409
+ )
410
+ cur_len = 1
411
+
412
+ assert (
413
+ batch_size == encoder_outputs[0].shape[0]
414
+ ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
415
+
416
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
417
+ expanded_batch_idxs = (
418
+ torch.arange(batch_size)
419
+ .view(-1, 1)
420
+ .repeat(1, num_beams * effective_batch_mult)
421
+ .view(-1)
422
+ .to(input_ids.device)
423
+ )
424
+ # expand encoder_outputs
425
+ encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
426
+
427
+ else:
428
+ encoder_outputs = None
429
+ cur_len = input_ids.shape[-1]
430
+
431
+ assert (
432
+ cur_len < max_length
433
+ ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
434
+
435
+ if num_beams > 1:
436
+ output = self._generate_beam_search(
437
+ input_ids,
438
+ cur_len=cur_len,
439
+ max_length=max_length,
440
+ min_length=min_length,
441
+ do_sample=do_sample,
442
+ early_stopping=early_stopping,
443
+ temperature=temperature,
444
+ top_k=top_k,
445
+ top_p=top_p,
446
+ repetition_penalty=repetition_penalty,
447
+ no_repeat_ngram_size=no_repeat_ngram_size,
448
+ bad_words_ids=bad_words_ids,
449
+ pad_token_id=pad_token_id,
450
+ eos_token_id=eos_token_id,
451
+ batch_size=effective_batch_size,
452
+ num_return_sequences=num_return_sequences,
453
+ length_penalty=length_penalty,
454
+ num_beams=num_beams,
455
+ vocab_size=vocab_size,
456
+ encoder_outputs=encoder_outputs,
457
+ attention_mask=attention_mask,
458
+ use_cache=use_cache,
459
+ model_specific_kwargs=model_specific_kwargs,
460
+ )
461
+ else:
462
+ output = self._generate_no_beam_search(
463
+ input_ids,
464
+ cur_len=cur_len,
465
+ max_length=max_length,
466
+ min_length=min_length,
467
+ do_sample=do_sample,
468
+ temperature=temperature,
469
+ top_k=top_k,
470
+ top_p=top_p,
471
+ repetition_penalty=repetition_penalty,
472
+ no_repeat_ngram_size=no_repeat_ngram_size,
473
+ bad_words_ids=bad_words_ids,
474
+ pad_token_id=pad_token_id,
475
+ eos_token_id=eos_token_id,
476
+ batch_size=effective_batch_size,
477
+ encoder_outputs=encoder_outputs,
478
+ attention_mask=attention_mask,
479
+ use_cache=use_cache,
480
+ model_specific_kwargs=model_specific_kwargs,
481
+ )
482
+
483
+ return output
484
+
485
+ def _generate_no_beam_search(
486
+ self,
487
+ input_ids,
488
+ cur_len,
489
+ max_length,
490
+ min_length,
491
+ do_sample,
492
+ temperature,
493
+ top_k,
494
+ top_p,
495
+ repetition_penalty,
496
+ no_repeat_ngram_size,
497
+ bad_words_ids,
498
+ pad_token_id,
499
+ eos_token_id,
500
+ batch_size,
501
+ encoder_outputs,
502
+ attention_mask,
503
+ use_cache,
504
+ model_specific_kwargs,
505
+ ):
506
+ """ Generate sequences for each example without beam search (num_beams == 1).
507
+ All returned sequence are generated independantly.
508
+ """
509
+ # length of generated sentences / unfinished sentences
510
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
511
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
512
+
513
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
514
+
515
+ while cur_len < max_length:
516
+ model_inputs = self.prepare_inputs_for_generation(
517
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
518
+ )
519
+
520
+ outputs = self(**model_inputs)
521
+ next_token_logits = outputs[0][:, -1, :]
522
+
523
+ scores = self.postprocess_next_token_scores(
524
+ scores=next_token_logits,
525
+ input_ids=input_ids,
526
+ no_repeat_ngram_size=no_repeat_ngram_size,
527
+ bad_words_ids=bad_words_ids,
528
+ cur_len=cur_len,
529
+ min_length=min_length,
530
+ max_length=max_length,
531
+ eos_token_id=eos_token_id,
532
+ repetition_penalty=repetition_penalty,
533
+ batch_size=batch_size,
534
+ num_beams=1,
535
+ )
536
+
537
+ # if model has past, then set the past variable to speed up decoding
538
+ if self._use_cache(outputs, use_cache):
539
+ past = outputs[1]
540
+
541
+ if do_sample:
542
+ # Temperature (higher temperature => more likely to sample low probability tokens)
543
+ if temperature != 1.0:
544
+ scores = scores / temperature
545
+ # Top-p/top-k filtering
546
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
547
+ # Sample
548
+ probs = F.softmax(next_token_logscores, dim=-1)
549
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
550
+ else:
551
+ # Greedy decoding
552
+ next_token = torch.argmax(next_token_logits, dim=-1)
553
+
554
+ # update generations and finished sentences
555
+ if eos_token_id is not None:
556
+ # pad finished sentences if eos_token_id exist
557
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
558
+ else:
559
+ tokens_to_add = next_token
560
+
561
+ # add token and increase length by one
562
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
563
+ cur_len = cur_len + 1
564
+
565
+ if eos_token_id is not None:
566
+ eos_in_sents = tokens_to_add == eos_token_id
567
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
568
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
569
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
570
+ # unfinished_sents is set to zero if eos in sentence
571
+ unfinished_sents.mul_((~eos_in_sents).long())
572
+
573
+ # stop when there is a </s> in each sentence, or if we exceed the maximul length
574
+ if unfinished_sents.max() == 0:
575
+ break
576
+
577
+ # extend attention_mask for new generated input if only decoder
578
+ if self.config.is_encoder_decoder is False:
579
+ attention_mask = torch.cat(
580
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
581
+ )
582
+
583
+ return input_ids
584
+
585
+ def _generate_beam_search(
586
+ self,
587
+ input_ids,
588
+ cur_len,
589
+ max_length,
590
+ min_length,
591
+ do_sample,
592
+ early_stopping,
593
+ temperature,
594
+ top_k,
595
+ top_p,
596
+ repetition_penalty,
597
+ no_repeat_ngram_size,
598
+ bad_words_ids,
599
+ pad_token_id,
600
+ eos_token_id,
601
+ batch_size,
602
+ num_return_sequences,
603
+ length_penalty,
604
+ num_beams,
605
+ vocab_size,
606
+ encoder_outputs,
607
+ attention_mask,
608
+ use_cache,
609
+ model_specific_kwargs,
610
+ ):
611
+ """ Generate sequences for each example with beam search.
612
+ """
613
+
614
+ # generated hypotheses
615
+ generated_hyps = [
616
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
617
+ for _ in range(batch_size)
618
+ ]
619
+
620
+ # scores for each sentence in the beam
621
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
622
+
623
+ # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
624
+ if do_sample is False:
625
+ beam_scores[:, 1:] = -1e9
626
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
627
+
628
+ # cache compute states
629
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
630
+
631
+ # done sentences
632
+ done = [False for _ in range(batch_size)]
633
+
634
+ while cur_len < max_length:
635
+ model_inputs = self.prepare_inputs_for_generation(
636
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
637
+ )
638
+ outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
639
+ next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
640
+
641
+ # if model has past, then set the past variable to speed up decoding
642
+ if self._use_cache(outputs, use_cache):
643
+ past = outputs[1]
644
+ if self.config.is_encoder_decoder and do_sample is False:
645
+ # TODO (PVP) still a bit hacky here - there might be a better solution
646
+ next_token_logits = self.adjust_logits_during_generation(
647
+ next_token_logits, cur_len=cur_len, max_length=max_length
648
+ )
649
+
650
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
651
+
652
+ scores = self.postprocess_next_token_scores(
653
+ scores=scores,
654
+ input_ids=input_ids,
655
+ no_repeat_ngram_size=no_repeat_ngram_size,
656
+ bad_words_ids=bad_words_ids,
657
+ cur_len=cur_len,
658
+ min_length=min_length,
659
+ max_length=max_length,
660
+ eos_token_id=eos_token_id,
661
+ repetition_penalty=repetition_penalty,
662
+ batch_size=batch_size,
663
+ num_beams=num_beams,
664
+ )
665
+
666
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
667
+ scores.shape, (batch_size * num_beams, vocab_size)
668
+ )
669
+
670
+ if do_sample:
671
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
672
+ # Temperature
673
+ if temperature != 1.0:
674
+ _scores = _scores / temperature
675
+ # Top-p/top-k filtering
676
+ _scores = top_k_top_p_filtering(
677
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
678
+ ) # (batch_size * num_beams, vocab_size)
679
+ # re-organize to group the beam together to sample from all beam_idxs
680
+ _scores = _scores.contiguous().view(
681
+ batch_size, num_beams * vocab_size
682
+ ) # (batch_size, num_beams * vocab_size)
683
+
684
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
685
+ probs = F.softmax(_scores, dim=-1)
686
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
687
+ # Compute next scores
688
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
689
+ # sort the sampled vector to make sure that the first num_beams samples are the best
690
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
691
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
692
+
693
+ else:
694
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
695
+
696
+ # re-organize to group the beam together (we are keeping top hypothesis accross beams)
697
+ next_scores = next_scores.view(
698
+ batch_size, num_beams * vocab_size
699
+ ) # (batch_size, num_beams * vocab_size)
700
+
701
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
702
+
703
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
704
+
705
+ # next batch beam content
706
+ next_batch_beam = []
707
+
708
+ # for each sentence
709
+ for batch_idx in range(batch_size):
710
+
711
+ # if we are done with this sentence, add a pad token
712
+ if done[batch_idx]:
713
+ assert (
714
+ len(generated_hyps[batch_idx]) >= num_beams
715
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
716
+ assert (
717
+ eos_token_id is not None and pad_token_id is not None
718
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
719
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
720
+ continue
721
+
722
+ # next sentence beam content, this will get added to next_batch_beam
723
+ next_sent_beam = []
724
+
725
+ # next tokens for this sentence
726
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
727
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
728
+ ):
729
+ # get beam and token IDs
730
+ beam_id = beam_token_id // vocab_size
731
+ token_id = beam_token_id % vocab_size
732
+
733
+ effective_beam_id = batch_idx * num_beams + beam_id
734
+ # add to generated hypotheses if end of sentence
735
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
736
+ # if beam_token does not belong to top num_beams tokens, it should not be added
737
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
738
+ if is_beam_token_worse_than_top_num_beams:
739
+ continue
740
+ generated_hyps[batch_idx].add(
741
+ input_ids[effective_beam_id].clone(), beam_token_score.item(),
742
+ )
743
+ else:
744
+ # add next predicted token since it is not eos_token
745
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
746
+
747
+ # once the beam for next step is full, don't add more tokens to it.
748
+ if len(next_sent_beam) == num_beams:
749
+ break
750
+
751
+ # Check if we are done so that we can save a pad step if all(done)
752
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
753
+ next_scores[batch_idx].max().item(), cur_len
754
+ )
755
+
756
+ # update next beam content
757
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
758
+ next_batch_beam.extend(next_sent_beam)
759
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
760
+
761
+ # stop when we are done with each sentence
762
+ if all(done):
763
+ break
764
+
765
+ # sanity check / prepare next batch
766
+ assert len(next_batch_beam) == batch_size * num_beams
767
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
768
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
769
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
770
+
771
+ # re-order batch and update current length
772
+ input_ids = input_ids[beam_idx, :]
773
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
774
+ cur_len = cur_len + 1
775
+
776
+ # re-order internal states
777
+ if past is not None:
778
+ past = self._reorder_cache(past, beam_idx)
779
+
780
+ # extend attention_mask for new generated input if only decoder
781
+ if self.config.is_encoder_decoder is False:
782
+ attention_mask = torch.cat(
783
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
784
+ )
785
+
786
+ # finalize all open beam hypotheses and add to generated hypotheses
787
+ for batch_idx in range(batch_size):
788
+ if done[batch_idx]:
789
+ continue
790
+
791
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
792
+ if eos_token_id is not None and all(
793
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
794
+ ):
795
+ assert torch.all(
796
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
797
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
798
+ next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
799
+ )
800
+
801
+ # need to add best num_beams hypotheses to generated hyps
802
+ for beam_id in range(num_beams):
803
+ effective_beam_id = batch_idx * num_beams + beam_id
804
+ final_score = beam_scores[effective_beam_id].item()
805
+ final_tokens = input_ids[effective_beam_id]
806
+ generated_hyps[batch_idx].add(final_tokens, final_score)
807
+
808
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
809
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
810
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
811
+
812
+ # select the best hypotheses
813
+ sent_lengths = input_ids.new(output_batch_size)
814
+ best = []
815
+
816
+ # retrieve best hypotheses
817
+ for i, hypotheses in enumerate(generated_hyps):
818
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
819
+ for j in range(output_num_return_sequences_per_batch):
820
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
821
+ best_hyp = sorted_hyps.pop()[1]
822
+ sent_lengths[effective_batch_idx] = len(best_hyp)
823
+ best.append(best_hyp)
824
+
825
+ # shorter batches are padded
826
+ if sent_lengths.min().item() != sent_lengths.max().item():
827
+ assert pad_token_id is not None, "`Pad_token_id` has to be defined"
828
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
829
+ decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
830
+
831
+ # fill with hypothesis and eos_token_id if necessary
832
+ for i, hypo in enumerate(best):
833
+ decoded[i, : sent_lengths[i]] = hypo
834
+ if sent_lengths[i] < max_length:
835
+ decoded[i, sent_lengths[i]] = eos_token_id
836
+ else:
837
+ # none of the hypotheses have an eos_token
838
+ assert (len(hypo) == max_length for hypo in best)
839
+ decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
840
+
841
+ return decoded
842
+
843
+ @staticmethod
844
+ def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
845
+ return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
846
+
847
+
848
+ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
849
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
850
+ if cur_len + 1 < no_repeat_ngram_size:
851
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
852
+ return [[] for _ in range(num_hypos)]
853
+ generated_ngrams = [{} for _ in range(num_hypos)]
854
+ for idx in range(num_hypos):
855
+ gen_tokens = prev_input_ids[idx].tolist()
856
+ generated_ngram = generated_ngrams[idx]
857
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
858
+ prev_ngram_tuple = tuple(ngram[:-1])
859
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
860
+
861
+ def _get_generated_ngrams(hypo_idx):
862
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
863
+ start_idx = cur_len + 1 - no_repeat_ngram_size
864
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
865
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
866
+
867
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
868
+ return banned_tokens
869
+
870
+
871
+ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
872
+ banned_tokens = []
873
+
874
+ def _tokens_match(prev_tokens, tokens):
875
+ if len(tokens) == 0:
876
+ # if bad word tokens is just one token always ban it
877
+ return True
878
+ if len(tokens) > len(prev_input_ids):
879
+ # if bad word tokens are longer then prev input_ids they can't be equal
880
+ return False
881
+
882
+ if prev_tokens[-len(tokens) :] == tokens:
883
+ # if tokens match
884
+ return True
885
+ else:
886
+ return False
887
+
888
+ for prev_input_ids_slice in prev_input_ids:
889
+ banned_tokens_slice = []
890
+
891
+ for banned_token_seq in bad_words_ids:
892
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
893
+ bad_words_ids
894
+ )
895
+
896
+ if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
897
+ # if tokens do not match continue
898
+ continue
899
+
900
+ banned_tokens_slice.append(banned_token_seq[-1])
901
+
902
+ banned_tokens.append(banned_tokens_slice)
903
+
904
+ return banned_tokens
905
+
906
+
907
+ def top_k_top_p_filtering(
908
+ logits: Tensor,
909
+ top_k: int = 0,
910
+ top_p: float = 1.0,
911
+ filter_value: float = -float("Inf"),
912
+ min_tokens_to_keep: int = 1,
913
+ ) -> Tensor:
914
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
915
+ Args:
916
+ logits: logits distribution shape (batch size, vocabulary size)
917
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
918
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
919
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
920
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
921
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
922
+ """
923
+ if top_k > 0:
924
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
925
+ # Remove all tokens with a probability less than the last token of the top-k
926
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
927
+ logits[indices_to_remove] = filter_value
928
+
929
+ if top_p < 1.0:
930
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
931
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
932
+
933
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
934
+ sorted_indices_to_remove = cumulative_probs > top_p
935
+ if min_tokens_to_keep > 1:
936
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
937
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
938
+ # Shift the indices to the right to keep also the first token above the threshold
939
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
940
+ sorted_indices_to_remove[..., 0] = 0
941
+
942
+ # scatter sorted tensors to original indexing
943
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
944
+ logits[indices_to_remove] = filter_value
945
+ return logits
946
+
947
+
948
+ class BeamHypotheses(object):
949
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
950
+ """
951
+ Initialize n-best list of hypotheses.
952
+ """
953
+ self.max_length = max_length - 1 # ignoring bos_token
954
+ self.length_penalty = length_penalty
955
+ self.early_stopping = early_stopping
956
+ self.num_beams = num_beams
957
+ self.beams = []
958
+ self.worst_score = 1e9
959
+
960
+ def __len__(self):
961
+ """
962
+ Number of hypotheses in the list.
963
+ """
964
+ return len(self.beams)
965
+
966
+ def add(self, hyp, sum_logprobs):
967
+ """
968
+ Add a new hypothesis to the list.
969
+ """
970
+ score = sum_logprobs / len(hyp) ** self.length_penalty
971
+ if len(self) < self.num_beams or score > self.worst_score:
972
+ self.beams.append((score, hyp))
973
+ if len(self) > self.num_beams:
974
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
975
+ del self.beams[sorted_scores[0][1]]
976
+ self.worst_score = sorted_scores[1][0]
977
+ else:
978
+ self.worst_score = min(score, self.worst_score)
979
+
980
+ def is_done(self, best_sum_logprobs, cur_len):
981
+ """
982
+ If there are enough hypotheses and that none of the hypotheses being generated
983
+ can become better than the worst one in the heap, then we are done with this sentence.
984
+ """
985
+
986
+ if len(self) < self.num_beams:
987
+ return False
988
+ elif self.early_stopping:
989
+ return True
990
+ else:
991
+ cur_score = best_sum_logprobs / cur_len ** self.length_penalty
992
+ ret = self.worst_score >= cur_score
993
+ return ret
bert/modeling_bert.py ADDED
@@ -0,0 +1,1569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+
19
+ import logging
20
+ import math
21
+ import os
22
+ import warnings
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss, MSELoss
28
+
29
+ from .activations import gelu, gelu_new, swish
30
+ from .configuration_bert import BertConfig
31
+ from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
32
+ from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
38
+
39
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
40
+ "bert-base-uncased",
41
+ "bert-large-uncased",
42
+ "bert-base-cased",
43
+ "bert-large-cased",
44
+ "bert-base-multilingual-uncased",
45
+ "bert-base-multilingual-cased",
46
+ "bert-base-chinese",
47
+ "bert-base-german-cased",
48
+ "bert-large-uncased-whole-word-masking",
49
+ "bert-large-cased-whole-word-masking",
50
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
51
+ "bert-large-cased-whole-word-masking-finetuned-squad",
52
+ "bert-base-cased-finetuned-mrpc",
53
+ "bert-base-german-dbmdz-cased",
54
+ "bert-base-german-dbmdz-uncased",
55
+ "cl-tohoku/bert-base-japanese",
56
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
57
+ "cl-tohoku/bert-base-japanese-char",
58
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
59
+ "TurkuNLP/bert-base-finnish-cased-v1",
60
+ "TurkuNLP/bert-base-finnish-uncased-v1",
61
+ "wietsedv/bert-base-dutch-cased",
62
+ # See all BERT models at https://huggingface.co/models?filter=bert
63
+ ]
64
+
65
+
66
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
67
+ """ Load tf checkpoints in a pytorch model.
68
+ """
69
+ try:
70
+ import re
71
+ import numpy as np
72
+ import tensorflow as tf
73
+ except ImportError:
74
+ logger.error(
75
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
76
+ "https://www.tensorflow.org/install/ for installation instructions."
77
+ )
78
+ raise
79
+ tf_path = os.path.abspath(tf_checkpoint_path)
80
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
81
+ # Load weights from TF model
82
+ init_vars = tf.train.list_variables(tf_path)
83
+ names = []
84
+ arrays = []
85
+ for name, shape in init_vars:
86
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
87
+ array = tf.train.load_variable(tf_path, name)
88
+ names.append(name)
89
+ arrays.append(array)
90
+
91
+ for name, array in zip(names, arrays):
92
+ name = name.split("/")
93
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
94
+ # which are not required for using pretrained model
95
+ if any(
96
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
97
+ for n in name
98
+ ):
99
+ logger.info("Skipping {}".format("/".join(name)))
100
+ continue
101
+ pointer = model
102
+ for m_name in name:
103
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
104
+ scope_names = re.split(r"_(\d+)", m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
108
+ pointer = getattr(pointer, "weight")
109
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
110
+ pointer = getattr(pointer, "bias")
111
+ elif scope_names[0] == "output_weights":
112
+ pointer = getattr(pointer, "weight")
113
+ elif scope_names[0] == "squad":
114
+ pointer = getattr(pointer, "classifier")
115
+ else:
116
+ try:
117
+ pointer = getattr(pointer, scope_names[0])
118
+ except AttributeError:
119
+ logger.info("Skipping {}".format("/".join(name)))
120
+ continue
121
+ if len(scope_names) >= 2:
122
+ num = int(scope_names[1])
123
+ pointer = pointer[num]
124
+ if m_name[-11:] == "_embeddings":
125
+ pointer = getattr(pointer, "weight")
126
+ elif m_name == "kernel":
127
+ array = np.transpose(array)
128
+ try:
129
+ assert pointer.shape == array.shape
130
+ except AssertionError as e:
131
+ e.args += (pointer.shape, array.shape)
132
+ raise
133
+ logger.info("Initialize PyTorch weight {}".format(name))
134
+ pointer.data = torch.from_numpy(array)
135
+ return model
136
+
137
+
138
+ def mish(x):
139
+ return x * torch.tanh(nn.functional.softplus(x))
140
+
141
+
142
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
143
+
144
+
145
+ BertLayerNorm = torch.nn.LayerNorm
146
+
147
+
148
+ class BertEmbeddings(nn.Module):
149
+ """Construct the embeddings from word, position and token_type embeddings.
150
+ """
151
+
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
155
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
156
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
157
+
158
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
159
+ # any TensorFlow checkpoint file
160
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
161
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
162
+
163
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
164
+ if input_ids is not None:
165
+ input_shape = input_ids.size()
166
+ else:
167
+ input_shape = inputs_embeds.size()[:-1]
168
+
169
+ seq_length = input_shape[1]
170
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
171
+ if position_ids is None:
172
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
173
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
174
+ if token_type_ids is None:
175
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
176
+
177
+ if inputs_embeds is None:
178
+ inputs_embeds = self.word_embeddings(input_ids)
179
+ position_embeddings = self.position_embeddings(position_ids)
180
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
181
+
182
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
183
+ embeddings = self.LayerNorm(embeddings)
184
+ embeddings = self.dropout(embeddings)
185
+ return embeddings
186
+
187
+
188
+ class BertSelfAttention(nn.Module):
189
+ def __init__(self, config):
190
+ super().__init__()
191
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
192
+ raise ValueError(
193
+ "The hidden size (%d) is not a multiple of the number of attention "
194
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
195
+ )
196
+
197
+ self.num_attention_heads = config.num_attention_heads
198
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
199
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
200
+
201
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
202
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
203
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
204
+
205
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
206
+
207
+ def transpose_for_scores(self, x):
208
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
209
+ x = x.view(*new_x_shape)
210
+ return x.permute(0, 2, 1, 3)
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states,
215
+ attention_mask=None,
216
+ head_mask=None,
217
+ encoder_hidden_states=None,
218
+ encoder_attention_mask=None,
219
+ output_attentions=False,
220
+ ):
221
+ mixed_query_layer = self.query(hidden_states)
222
+
223
+ # If this is instantiated as a cross-attention module, the keys
224
+ # and values come from an encoder; the attention mask needs to be
225
+ # such that the encoder's padding tokens are not attended to.
226
+ if encoder_hidden_states is not None:
227
+ mixed_key_layer = self.key(encoder_hidden_states)
228
+ mixed_value_layer = self.value(encoder_hidden_states)
229
+ attention_mask = encoder_attention_mask
230
+ else:
231
+ mixed_key_layer = self.key(hidden_states)
232
+ mixed_value_layer = self.value(hidden_states)
233
+
234
+ query_layer = self.transpose_for_scores(mixed_query_layer)
235
+ key_layer = self.transpose_for_scores(mixed_key_layer)
236
+ value_layer = self.transpose_for_scores(mixed_value_layer)
237
+
238
+ # Take the dot product between "query" and "key" to get the raw attention scores.
239
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
240
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
241
+ if attention_mask is not None:
242
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
243
+ attention_scores = attention_scores + attention_mask
244
+
245
+ # Normalize the attention scores to probabilities.
246
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
247
+
248
+ # This is actually dropping out entire tokens to attend to, which might
249
+ # seem a bit unusual, but is taken from the original Transformer paper.
250
+ attention_probs = self.dropout(attention_probs)
251
+
252
+ # Mask heads if we want to
253
+ if head_mask is not None:
254
+ attention_probs = attention_probs * head_mask
255
+
256
+ context_layer = torch.matmul(attention_probs, value_layer)
257
+
258
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
259
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
260
+ context_layer = context_layer.view(*new_context_layer_shape)
261
+
262
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
263
+ return outputs
264
+
265
+
266
+ class BertSelfOutput(nn.Module):
267
+ def __init__(self, config):
268
+ super().__init__()
269
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
270
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
271
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
272
+
273
+ def forward(self, hidden_states, input_tensor):
274
+ hidden_states = self.dense(hidden_states)
275
+ hidden_states = self.dropout(hidden_states)
276
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
277
+ return hidden_states
278
+
279
+
280
+ class BertAttention(nn.Module):
281
+ def __init__(self, config):
282
+ super().__init__()
283
+ self.self = BertSelfAttention(config)
284
+ self.output = BertSelfOutput(config)
285
+ self.pruned_heads = set()
286
+
287
+ def prune_heads(self, heads):
288
+ if len(heads) == 0:
289
+ return
290
+ heads, index = find_pruneable_heads_and_indices(
291
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
292
+ )
293
+
294
+ # Prune linear layers
295
+ self.self.query = prune_linear_layer(self.self.query, index)
296
+ self.self.key = prune_linear_layer(self.self.key, index)
297
+ self.self.value = prune_linear_layer(self.self.value, index)
298
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
299
+
300
+ # Update hyper params and store pruned heads
301
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
302
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
303
+ self.pruned_heads = self.pruned_heads.union(heads)
304
+
305
+ def forward(
306
+ self,
307
+ hidden_states,
308
+ attention_mask=None,
309
+ head_mask=None,
310
+ encoder_hidden_states=None,
311
+ encoder_attention_mask=None,
312
+ output_attentions=False,
313
+ ):
314
+ self_outputs = self.self(
315
+ hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
316
+ )
317
+ attention_output = self.output(self_outputs[0], hidden_states)
318
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
319
+ return outputs
320
+
321
+
322
+ class BertIntermediate(nn.Module):
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
326
+ if isinstance(config.hidden_act, str):
327
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
328
+ else:
329
+ self.intermediate_act_fn = config.hidden_act
330
+
331
+ def forward(self, hidden_states):
332
+ hidden_states = self.dense(hidden_states)
333
+ hidden_states = self.intermediate_act_fn(hidden_states)
334
+ return hidden_states
335
+
336
+
337
+ class BertOutput(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
341
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
342
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
343
+
344
+ def forward(self, hidden_states, input_tensor):
345
+ hidden_states = self.dense(hidden_states)
346
+ hidden_states = self.dropout(hidden_states)
347
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
348
+ return hidden_states
349
+
350
+
351
+ class BertLayer(nn.Module):
352
+ def __init__(self, config):
353
+ super().__init__()
354
+ self.attention = BertAttention(config)
355
+ self.is_decoder = config.is_decoder
356
+ if self.is_decoder:
357
+ self.crossattention = BertAttention(config)
358
+ self.intermediate = BertIntermediate(config)
359
+ self.output = BertOutput(config)
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states,
364
+ attention_mask=None,
365
+ head_mask=None,
366
+ encoder_hidden_states=None,
367
+ encoder_attention_mask=None,
368
+ output_attentions=False,
369
+ ):
370
+ self_attention_outputs = self.attention(
371
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
372
+ )
373
+ attention_output = self_attention_outputs[0]
374
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
375
+
376
+ if self.is_decoder and encoder_hidden_states is not None:
377
+ cross_attention_outputs = self.crossattention(
378
+ attention_output,
379
+ attention_mask,
380
+ head_mask,
381
+ encoder_hidden_states,
382
+ encoder_attention_mask,
383
+ output_attentions,
384
+ )
385
+ attention_output = cross_attention_outputs[0]
386
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
387
+
388
+ intermediate_output = self.intermediate(attention_output)
389
+ layer_output = self.output(intermediate_output, attention_output)
390
+ outputs = (layer_output,) + outputs
391
+ return outputs
392
+
393
+
394
+ class BertEncoder(nn.Module):
395
+ def __init__(self, config):
396
+ super().__init__()
397
+ self.config = config
398
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states,
403
+ attention_mask=None,
404
+ head_mask=None,
405
+ encoder_hidden_states=None,
406
+ encoder_attention_mask=None,
407
+ output_attentions=False,
408
+ output_hidden_states=False,
409
+ ):
410
+ all_hidden_states = ()
411
+ all_attentions = ()
412
+ for i, layer_module in enumerate(self.layer):
413
+ if output_hidden_states:
414
+ all_hidden_states = all_hidden_states + (hidden_states,)
415
+
416
+ if getattr(self.config, "gradient_checkpointing", False):
417
+
418
+ def create_custom_forward(module):
419
+ def custom_forward(*inputs):
420
+ return module(*inputs, output_attentions)
421
+
422
+ return custom_forward
423
+
424
+ layer_outputs = torch.utils.checkpoint.checkpoint(
425
+ create_custom_forward(layer_module),
426
+ hidden_states,
427
+ attention_mask,
428
+ head_mask[i],
429
+ encoder_hidden_states,
430
+ encoder_attention_mask,
431
+ )
432
+ else:
433
+ layer_outputs = layer_module(
434
+ hidden_states,
435
+ attention_mask,
436
+ head_mask[i],
437
+ encoder_hidden_states,
438
+ encoder_attention_mask,
439
+ output_attentions,
440
+ )
441
+ hidden_states = layer_outputs[0]
442
+
443
+ if output_attentions:
444
+ all_attentions = all_attentions + (layer_outputs[1],)
445
+
446
+ # Add last layer
447
+ if output_hidden_states:
448
+ all_hidden_states = all_hidden_states + (hidden_states,)
449
+
450
+ outputs = (hidden_states,)
451
+ if output_hidden_states:
452
+ outputs = outputs + (all_hidden_states,)
453
+ if output_attentions:
454
+ outputs = outputs + (all_attentions,)
455
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
456
+
457
+
458
+ class BertPooler(nn.Module):
459
+ def __init__(self, config):
460
+ super().__init__()
461
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
462
+ self.activation = nn.Tanh()
463
+
464
+ def forward(self, hidden_states):
465
+ # We "pool" the model by simply taking the hidden state corresponding
466
+ # to the first token.
467
+ first_token_tensor = hidden_states[:, 0]
468
+ pooled_output = self.dense(first_token_tensor)
469
+ pooled_output = self.activation(pooled_output)
470
+ return pooled_output
471
+
472
+
473
+ class BertPredictionHeadTransform(nn.Module):
474
+ def __init__(self, config):
475
+ super().__init__()
476
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
477
+ if isinstance(config.hidden_act, str):
478
+ self.transform_act_fn = ACT2FN[config.hidden_act]
479
+ else:
480
+ self.transform_act_fn = config.hidden_act
481
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
482
+
483
+ def forward(self, hidden_states):
484
+ hidden_states = self.dense(hidden_states)
485
+ hidden_states = self.transform_act_fn(hidden_states)
486
+ hidden_states = self.LayerNorm(hidden_states)
487
+ return hidden_states
488
+
489
+
490
+ class BertLMPredictionHead(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.transform = BertPredictionHeadTransform(config)
494
+
495
+ # The output weights are the same as the input embeddings, but there is
496
+ # an output-only bias for each token.
497
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
498
+
499
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
500
+
501
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
502
+ self.decoder.bias = self.bias
503
+
504
+ def forward(self, hidden_states):
505
+ hidden_states = self.transform(hidden_states)
506
+ hidden_states = self.decoder(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertOnlyMLMHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.predictions = BertLMPredictionHead(config)
514
+
515
+ def forward(self, sequence_output):
516
+ prediction_scores = self.predictions(sequence_output)
517
+ return prediction_scores
518
+
519
+
520
+ class BertOnlyNSPHead(nn.Module):
521
+ def __init__(self, config):
522
+ super().__init__()
523
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
524
+
525
+ def forward(self, pooled_output):
526
+ seq_relationship_score = self.seq_relationship(pooled_output)
527
+ return seq_relationship_score
528
+
529
+
530
+ class BertPreTrainingHeads(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
535
+
536
+ def forward(self, sequence_output, pooled_output):
537
+ prediction_scores = self.predictions(sequence_output)
538
+ seq_relationship_score = self.seq_relationship(pooled_output)
539
+ return prediction_scores, seq_relationship_score
540
+
541
+
542
+ class BertPreTrainedModel(PreTrainedModel):
543
+ """ An abstract class to handle weights initialization and
544
+ a simple interface for downloading and loading pretrained models.
545
+ """
546
+
547
+ config_class = BertConfig
548
+ load_tf_weights = load_tf_weights_in_bert
549
+ base_model_prefix = "bert"
550
+
551
+ def _init_weights(self, module):
552
+ """ Initialize the weights """
553
+ if isinstance(module, (nn.Linear, nn.Embedding)):
554
+ # Slightly different from the TF version which uses truncated_normal for initialization
555
+ # cf https://github.com/pytorch/pytorch/pull/5617
556
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
557
+ elif isinstance(module, BertLayerNorm):
558
+ module.bias.data.zero_()
559
+ module.weight.data.fill_(1.0)
560
+ if isinstance(module, nn.Linear) and module.bias is not None:
561
+ module.bias.data.zero_()
562
+
563
+
564
+ BERT_START_DOCSTRING = r"""
565
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
566
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
567
+ usage and behavior.
568
+
569
+ Parameters:
570
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
571
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
572
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
573
+ """
574
+
575
+ BERT_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
578
+ Indices of input sequence tokens in the vocabulary.
579
+
580
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
581
+ See :func:`transformers.PreTrainedTokenizer.encode` and
582
+ :func:`transformers.PreTrainedTokenizer.__call__` for details.
583
+
584
+ `What are input IDs? <../glossary.html#input-ids>`__
585
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
586
+ Mask to avoid performing attention on padding token indices.
587
+ Mask values selected in ``[0, 1]``:
588
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
589
+
590
+ `What are attention masks? <../glossary.html#attention-mask>`__
591
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
592
+ Segment token indices to indicate first and second portions of the inputs.
593
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
594
+ corresponds to a `sentence B` token
595
+
596
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
597
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
598
+ Indices of positions of each input sequence tokens in the position embeddings.
599
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
600
+
601
+ `What are position IDs? <../glossary.html#position-ids>`_
602
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
603
+ Mask to nullify selected heads of the self-attention modules.
604
+ Mask values selected in ``[0, 1]``:
605
+ :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
606
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
607
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
608
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
609
+ than the model's internal embedding lookup matrix.
610
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
611
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
612
+ if the model is configured as a decoder.
613
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
614
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
615
+ is used in the cross-attention if the model is configured as a decoder.
616
+ Mask values selected in ``[0, 1]``:
617
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
618
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
619
+ If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
620
+ """
621
+
622
+
623
+ @add_start_docstrings(
624
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
625
+ BERT_START_DOCSTRING,
626
+ )
627
+ class BertModel(BertPreTrainedModel):
628
+ """
629
+
630
+ The model can behave as an encoder (with only self-attention) as well
631
+ as a decoder, in which case a layer of cross-attention is added between
632
+ the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
633
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
634
+
635
+ To behave as an decoder the model needs to be initialized with the
636
+ :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
637
+ :obj:`encoder_hidden_states` is expected as an input to the forward pass.
638
+
639
+ .. _`Attention is all you need`:
640
+ https://arxiv.org/abs/1706.03762
641
+
642
+ """
643
+
644
+ def __init__(self, config):
645
+ super().__init__(config)
646
+ self.config = config
647
+
648
+ self.embeddings = BertEmbeddings(config)
649
+ self.encoder = BertEncoder(config)
650
+ self.pooler = BertPooler(config)
651
+
652
+ self.init_weights()
653
+
654
+ def get_input_embeddings(self):
655
+ return self.embeddings.word_embeddings
656
+
657
+ def set_input_embeddings(self, value):
658
+ self.embeddings.word_embeddings = value
659
+
660
+ def _prune_heads(self, heads_to_prune):
661
+ """ Prunes heads of the model.
662
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
663
+ See base class PreTrainedModel
664
+ """
665
+ for layer, heads in heads_to_prune.items():
666
+ self.encoder.layer[layer].attention.prune_heads(heads)
667
+
668
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
669
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ token_type_ids=None,
675
+ position_ids=None,
676
+ head_mask=None,
677
+ inputs_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ output_attentions=None,
681
+ output_hidden_states=None,
682
+ ):
683
+ r"""
684
+ Return:
685
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
686
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
687
+ Sequence of hidden-states at the output of the last layer of the model.
688
+ pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
689
+ Last layer hidden-state of the first token of the sequence (classification token)
690
+ further processed by a Linear layer and a Tanh activation function. The Linear
691
+ layer weights are trained from the next sentence prediction (classification)
692
+ objective during pre-training.
693
+
694
+ This output is usually *not* a good summary
695
+ of the semantic content of the input, you're often better with averaging or pooling
696
+ the sequence of hidden-states for the whole input sequence.
697
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
698
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
699
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
700
+
701
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
702
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
703
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
704
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
705
+
706
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
707
+ heads.
708
+ """
709
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
710
+ output_hidden_states = (
711
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
712
+ )
713
+
714
+ if input_ids is not None and inputs_embeds is not None:
715
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
716
+ elif input_ids is not None:
717
+ input_shape = input_ids.size()
718
+ elif inputs_embeds is not None:
719
+ input_shape = inputs_embeds.size()[:-1]
720
+ else:
721
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
722
+
723
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
724
+
725
+ if attention_mask is None:
726
+ attention_mask = torch.ones(input_shape, device=device)
727
+ if token_type_ids is None:
728
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
729
+
730
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
731
+ # ourselves in which case we just need to make it broadcastable to all heads.
732
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
733
+
734
+ # If a 2D ou 3D attention mask is provided for the cross-attention
735
+ # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
736
+ if self.config.is_decoder and encoder_hidden_states is not None:
737
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
738
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
739
+ if encoder_attention_mask is None:
740
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
741
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
742
+ else:
743
+ encoder_extended_attention_mask = None
744
+
745
+ # Prepare head mask if needed
746
+ # 1.0 in head_mask indicate we keep the head
747
+ # attention_probs has shape bsz x n_heads x N x N
748
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
749
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
750
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
751
+
752
+ embedding_output = self.embeddings(
753
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
754
+ )
755
+ encoder_outputs = self.encoder(
756
+ embedding_output,
757
+ attention_mask=extended_attention_mask,
758
+ head_mask=head_mask,
759
+ encoder_hidden_states=encoder_hidden_states,
760
+ encoder_attention_mask=encoder_extended_attention_mask,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ )
764
+ sequence_output = encoder_outputs[0]
765
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
766
+
767
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
768
+ 1:
769
+ ] # add hidden_states and attentions if they are here
770
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
771
+
772
+
773
+ @add_start_docstrings(
774
+ """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
775
+ a `next sentence prediction (classification)` head. """,
776
+ BERT_START_DOCSTRING,
777
+ )
778
+ class BertForPreTraining(BertPreTrainedModel):
779
+ def __init__(self, config):
780
+ super().__init__(config)
781
+
782
+ self.bert = BertModel(config)
783
+ self.cls = BertPreTrainingHeads(config)
784
+
785
+ self.init_weights()
786
+
787
+ def get_output_embeddings(self):
788
+ return self.cls.predictions.decoder
789
+
790
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
791
+ def forward(
792
+ self,
793
+ input_ids=None,
794
+ attention_mask=None,
795
+ token_type_ids=None,
796
+ position_ids=None,
797
+ head_mask=None,
798
+ inputs_embeds=None,
799
+ labels=None,
800
+ next_sentence_label=None,
801
+ output_attentions=None,
802
+ output_hidden_states=None,
803
+ **kwargs
804
+ ):
805
+ r"""
806
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
807
+ Labels for computing the masked language modeling loss.
808
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
809
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
810
+ in ``[0, ..., config.vocab_size]``
811
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
812
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
813
+ Indices should be in ``[0, 1]``.
814
+ ``0`` indicates sequence B is a continuation of sequence A,
815
+ ``1`` indicates sequence B is a random sequence.
816
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
817
+ Used to hide legacy arguments that have been deprecated.
818
+
819
+ Returns:
820
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
821
+ loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
822
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
823
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
824
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
825
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
826
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False
827
+ continuation before SoftMax).
828
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
829
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
830
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
831
+
832
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
833
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
834
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
835
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
836
+
837
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
838
+ heads.
839
+
840
+
841
+ Examples::
842
+
843
+ >>> from transformers import BertTokenizer, BertForPreTraining
844
+ >>> import torch
845
+
846
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
847
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
848
+
849
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
850
+ >>> outputs = model(**inputs)
851
+
852
+ >>> prediction_scores, seq_relationship_scores = outputs[:2]
853
+
854
+ """
855
+ if "masked_lm_labels" in kwargs:
856
+ warnings.warn(
857
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
858
+ DeprecationWarning,
859
+ )
860
+ labels = kwargs.pop("masked_lm_labels")
861
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
862
+
863
+ outputs = self.bert(
864
+ input_ids,
865
+ attention_mask=attention_mask,
866
+ token_type_ids=token_type_ids,
867
+ position_ids=position_ids,
868
+ head_mask=head_mask,
869
+ inputs_embeds=inputs_embeds,
870
+ output_attentions=output_attentions,
871
+ output_hidden_states=output_hidden_states,
872
+ )
873
+
874
+ sequence_output, pooled_output = outputs[:2]
875
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
876
+
877
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
878
+ 2:
879
+ ] # add hidden states and attention if they are here
880
+
881
+ if labels is not None and next_sentence_label is not None:
882
+ loss_fct = CrossEntropyLoss()
883
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
884
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
885
+ total_loss = masked_lm_loss + next_sentence_loss
886
+ outputs = (total_loss,) + outputs
887
+
888
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
889
+
890
+
891
+ @add_start_docstrings(
892
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
893
+ )
894
+ class BertLMHeadModel(BertPreTrainedModel):
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+ assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
898
+
899
+ self.bert = BertModel(config)
900
+ self.cls = BertOnlyMLMHead(config)
901
+
902
+ self.init_weights()
903
+
904
+ def get_output_embeddings(self):
905
+ return self.cls.predictions.decoder
906
+
907
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
908
+ def forward(
909
+ self,
910
+ input_ids=None,
911
+ attention_mask=None,
912
+ token_type_ids=None,
913
+ position_ids=None,
914
+ head_mask=None,
915
+ inputs_embeds=None,
916
+ labels=None,
917
+ encoder_hidden_states=None,
918
+ encoder_attention_mask=None,
919
+ output_attentions=None,
920
+ output_hidden_states=None,
921
+ **kwargs
922
+ ):
923
+ r"""
924
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
925
+ Labels for computing the left-to-right language modeling loss (next word prediction).
926
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
927
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
928
+ in ``[0, ..., config.vocab_size]``
929
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
930
+ Used to hide legacy arguments that have been deprecated.
931
+
932
+ Returns:
933
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
934
+ ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
935
+ Next token prediction loss.
936
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
937
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
938
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
939
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
940
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
941
+
942
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
943
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
944
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
945
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
946
+
947
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
948
+ heads.
949
+
950
+ Example::
951
+
952
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
953
+ >>> import torch
954
+
955
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
956
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
957
+ >>> config.is_decoder = True
958
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
959
+
960
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
961
+ >>> outputs = model(**inputs)
962
+
963
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
964
+ """
965
+
966
+ outputs = self.bert(
967
+ input_ids,
968
+ attention_mask=attention_mask,
969
+ token_type_ids=token_type_ids,
970
+ position_ids=position_ids,
971
+ head_mask=head_mask,
972
+ inputs_embeds=inputs_embeds,
973
+ encoder_hidden_states=encoder_hidden_states,
974
+ encoder_attention_mask=encoder_attention_mask,
975
+ output_attentions=output_attentions,
976
+ output_hidden_states=output_hidden_states,
977
+ )
978
+
979
+ sequence_output = outputs[0]
980
+ prediction_scores = self.cls(sequence_output)
981
+
982
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
983
+
984
+ if labels is not None:
985
+ # we are doing next-token prediction; shift prediction scores and input ids by one
986
+ prediction_scores = prediction_scores[:, :-1, :].contiguous()
987
+ labels = labels[:, 1:].contiguous()
988
+ loss_fct = CrossEntropyLoss()
989
+ ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
990
+ outputs = (ltr_lm_loss,) + outputs
991
+
992
+ return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
993
+
994
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
995
+ input_shape = input_ids.shape
996
+
997
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
998
+ if attention_mask is None:
999
+ attention_mask = input_ids.new_ones(input_shape)
1000
+
1001
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1002
+
1003
+
1004
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1005
+ class BertForMaskedLM(BertPreTrainedModel):
1006
+ def __init__(self, config):
1007
+ super().__init__(config)
1008
+ assert (
1009
+ not config.is_decoder
1010
+ ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
1011
+
1012
+ self.bert = BertModel(config)
1013
+ self.cls = BertOnlyMLMHead(config)
1014
+
1015
+ self.init_weights()
1016
+
1017
+ def get_output_embeddings(self):
1018
+ return self.cls.predictions.decoder
1019
+
1020
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1021
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1022
+ def forward(
1023
+ self,
1024
+ input_ids=None,
1025
+ attention_mask=None,
1026
+ token_type_ids=None,
1027
+ position_ids=None,
1028
+ head_mask=None,
1029
+ inputs_embeds=None,
1030
+ labels=None,
1031
+ encoder_hidden_states=None,
1032
+ encoder_attention_mask=None,
1033
+ output_attentions=None,
1034
+ output_hidden_states=None,
1035
+ **kwargs
1036
+ ):
1037
+ r"""
1038
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1039
+ Labels for computing the masked language modeling loss.
1040
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
1041
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
1042
+ in ``[0, ..., config.vocab_size]``
1043
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1044
+ Used to hide legacy arguments that have been deprecated.
1045
+
1046
+ Returns:
1047
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1048
+ masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1049
+ Masked language modeling loss.
1050
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
1051
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1052
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1053
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1054
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1055
+
1056
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1057
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1058
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1059
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1060
+
1061
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1062
+ heads.
1063
+ """
1064
+ if "masked_lm_labels" in kwargs:
1065
+ warnings.warn(
1066
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1067
+ DeprecationWarning,
1068
+ )
1069
+ labels = kwargs.pop("masked_lm_labels")
1070
+ assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
1071
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
1072
+
1073
+ outputs = self.bert(
1074
+ input_ids,
1075
+ attention_mask=attention_mask,
1076
+ token_type_ids=token_type_ids,
1077
+ position_ids=position_ids,
1078
+ head_mask=head_mask,
1079
+ inputs_embeds=inputs_embeds,
1080
+ encoder_hidden_states=encoder_hidden_states,
1081
+ encoder_attention_mask=encoder_attention_mask,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ )
1085
+
1086
+ sequence_output = outputs[0]
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
1090
+
1091
+ if labels is not None:
1092
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1093
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1094
+ outputs = (masked_lm_loss,) + outputs
1095
+
1096
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
1097
+
1098
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1099
+ input_shape = input_ids.shape
1100
+ effective_batch_size = input_shape[0]
1101
+
1102
+ # add a dummy token
1103
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1104
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1105
+ dummy_token = torch.full(
1106
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1107
+ )
1108
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1109
+
1110
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1111
+
1112
+
1113
+ @add_start_docstrings(
1114
+ """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
1115
+ )
1116
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1117
+ def __init__(self, config):
1118
+ super().__init__(config)
1119
+
1120
+ self.bert = BertModel(config)
1121
+ self.cls = BertOnlyNSPHead(config)
1122
+
1123
+ self.init_weights()
1124
+
1125
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1126
+ def forward(
1127
+ self,
1128
+ input_ids=None,
1129
+ attention_mask=None,
1130
+ token_type_ids=None,
1131
+ position_ids=None,
1132
+ head_mask=None,
1133
+ inputs_embeds=None,
1134
+ next_sentence_label=None,
1135
+ output_attentions=None,
1136
+ output_hidden_states=None,
1137
+ ):
1138
+ r"""
1139
+ next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1140
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
1141
+ Indices should be in ``[0, 1]``.
1142
+ ``0`` indicates sequence B is a continuation of sequence A,
1143
+ ``1`` indicates sequence B is a random sequence.
1144
+
1145
+ Returns:
1146
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1147
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
1148
+ Next sequence prediction (classification) loss.
1149
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
1150
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
1151
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1152
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1153
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1154
+
1155
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1156
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1157
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1158
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1159
+
1160
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1161
+ heads.
1162
+
1163
+ Examples::
1164
+
1165
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1166
+ >>> import torch
1167
+
1168
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1169
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1170
+
1171
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1172
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1173
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1174
+
1175
+ >>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
1176
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1177
+ """
1178
+
1179
+ outputs = self.bert(
1180
+ input_ids,
1181
+ attention_mask=attention_mask,
1182
+ token_type_ids=token_type_ids,
1183
+ position_ids=position_ids,
1184
+ head_mask=head_mask,
1185
+ inputs_embeds=inputs_embeds,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ )
1189
+
1190
+ pooled_output = outputs[1]
1191
+
1192
+ seq_relationship_score = self.cls(pooled_output)
1193
+
1194
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
1195
+ if next_sentence_label is not None:
1196
+ loss_fct = CrossEntropyLoss()
1197
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1198
+ outputs = (next_sentence_loss,) + outputs
1199
+
1200
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
1201
+
1202
+
1203
+ @add_start_docstrings(
1204
+ """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
1205
+ the pooled output) e.g. for GLUE tasks. """,
1206
+ BERT_START_DOCSTRING,
1207
+ )
1208
+ class BertForSequenceClassification(BertPreTrainedModel):
1209
+ def __init__(self, config):
1210
+ super().__init__(config)
1211
+ self.num_labels = config.num_labels
1212
+
1213
+ self.bert = BertModel(config)
1214
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1215
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1216
+
1217
+ self.init_weights()
1218
+
1219
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1220
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1221
+ def forward(
1222
+ self,
1223
+ input_ids=None,
1224
+ attention_mask=None,
1225
+ token_type_ids=None,
1226
+ position_ids=None,
1227
+ head_mask=None,
1228
+ inputs_embeds=None,
1229
+ labels=None,
1230
+ output_attentions=None,
1231
+ output_hidden_states=None,
1232
+ ):
1233
+ r"""
1234
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1235
+ Labels for computing the sequence classification/regression loss.
1236
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1237
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1238
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1239
+
1240
+ Returns:
1241
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1242
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
1243
+ Classification (or regression if config.num_labels==1) loss.
1244
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
1245
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
1246
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1247
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1248
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1249
+
1250
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1251
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1252
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1253
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1254
+
1255
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1256
+ heads.
1257
+ """
1258
+
1259
+ outputs = self.bert(
1260
+ input_ids,
1261
+ attention_mask=attention_mask,
1262
+ token_type_ids=token_type_ids,
1263
+ position_ids=position_ids,
1264
+ head_mask=head_mask,
1265
+ inputs_embeds=inputs_embeds,
1266
+ output_attentions=output_attentions,
1267
+ output_hidden_states=output_hidden_states,
1268
+ )
1269
+
1270
+ pooled_output = outputs[1]
1271
+
1272
+ pooled_output = self.dropout(pooled_output)
1273
+ logits = self.classifier(pooled_output)
1274
+
1275
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1276
+
1277
+ if labels is not None:
1278
+ if self.num_labels == 1:
1279
+ # We are doing regression
1280
+ loss_fct = MSELoss()
1281
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1282
+ else:
1283
+ loss_fct = CrossEntropyLoss()
1284
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1285
+ outputs = (loss,) + outputs
1286
+
1287
+ return outputs # (loss), logits, (hidden_states), (attentions)
1288
+
1289
+
1290
+ @add_start_docstrings(
1291
+ """Bert Model with a multiple choice classification head on top (a linear layer on top of
1292
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1293
+ BERT_START_DOCSTRING,
1294
+ )
1295
+ class BertForMultipleChoice(BertPreTrainedModel):
1296
+ def __init__(self, config):
1297
+ super().__init__(config)
1298
+
1299
+ self.bert = BertModel(config)
1300
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1301
+ self.classifier = nn.Linear(config.hidden_size, 1)
1302
+
1303
+ self.init_weights()
1304
+
1305
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
1306
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1307
+ def forward(
1308
+ self,
1309
+ input_ids=None,
1310
+ attention_mask=None,
1311
+ token_type_ids=None,
1312
+ position_ids=None,
1313
+ head_mask=None,
1314
+ inputs_embeds=None,
1315
+ labels=None,
1316
+ output_attentions=None,
1317
+ output_hidden_states=None,
1318
+ ):
1319
+ r"""
1320
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1321
+ Labels for computing the multiple choice classification loss.
1322
+ Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
1323
+ of the input tensors. (see `input_ids` above)
1324
+
1325
+ Returns:
1326
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1327
+ loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
1328
+ Classification loss.
1329
+ classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
1330
+ `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
1331
+
1332
+ Classification scores (before SoftMax).
1333
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1334
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1335
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1336
+
1337
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1338
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1339
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1340
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1341
+
1342
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1343
+ heads.
1344
+ """
1345
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1346
+
1347
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1348
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1349
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1350
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1351
+ inputs_embeds = (
1352
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1353
+ if inputs_embeds is not None
1354
+ else None
1355
+ )
1356
+
1357
+ outputs = self.bert(
1358
+ input_ids,
1359
+ attention_mask=attention_mask,
1360
+ token_type_ids=token_type_ids,
1361
+ position_ids=position_ids,
1362
+ head_mask=head_mask,
1363
+ inputs_embeds=inputs_embeds,
1364
+ output_attentions=output_attentions,
1365
+ output_hidden_states=output_hidden_states,
1366
+ )
1367
+
1368
+ pooled_output = outputs[1]
1369
+
1370
+ pooled_output = self.dropout(pooled_output)
1371
+ logits = self.classifier(pooled_output)
1372
+ reshaped_logits = logits.view(-1, num_choices)
1373
+
1374
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
1375
+
1376
+ if labels is not None:
1377
+ loss_fct = CrossEntropyLoss()
1378
+ loss = loss_fct(reshaped_logits, labels)
1379
+ outputs = (loss,) + outputs
1380
+
1381
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
1382
+
1383
+
1384
+ @add_start_docstrings(
1385
+ """Bert Model with a token classification head on top (a linear layer on top of
1386
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
1387
+ BERT_START_DOCSTRING,
1388
+ )
1389
+ class BertForTokenClassification(BertPreTrainedModel):
1390
+ def __init__(self, config):
1391
+ super().__init__(config)
1392
+ self.num_labels = config.num_labels
1393
+
1394
+ self.bert = BertModel(config)
1395
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1396
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1397
+
1398
+ self.init_weights()
1399
+
1400
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1401
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1402
+ def forward(
1403
+ self,
1404
+ input_ids=None,
1405
+ attention_mask=None,
1406
+ token_type_ids=None,
1407
+ position_ids=None,
1408
+ head_mask=None,
1409
+ inputs_embeds=None,
1410
+ labels=None,
1411
+ output_attentions=None,
1412
+ output_hidden_states=None,
1413
+ ):
1414
+ r"""
1415
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1416
+ Labels for computing the token classification loss.
1417
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
1418
+
1419
+ Returns:
1420
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1421
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
1422
+ Classification loss.
1423
+ scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
1424
+ Classification scores (before SoftMax).
1425
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1426
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1427
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1428
+
1429
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1430
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1431
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1432
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1433
+
1434
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1435
+ heads.
1436
+ """
1437
+
1438
+ outputs = self.bert(
1439
+ input_ids,
1440
+ attention_mask=attention_mask,
1441
+ token_type_ids=token_type_ids,
1442
+ position_ids=position_ids,
1443
+ head_mask=head_mask,
1444
+ inputs_embeds=inputs_embeds,
1445
+ output_attentions=output_attentions,
1446
+ output_hidden_states=output_hidden_states,
1447
+ )
1448
+
1449
+ sequence_output = outputs[0]
1450
+
1451
+ sequence_output = self.dropout(sequence_output)
1452
+ logits = self.classifier(sequence_output)
1453
+
1454
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1455
+ if labels is not None:
1456
+ loss_fct = CrossEntropyLoss()
1457
+ # Only keep active parts of the loss
1458
+ if attention_mask is not None:
1459
+ active_loss = attention_mask.view(-1) == 1
1460
+ active_logits = logits.view(-1, self.num_labels)
1461
+ active_labels = torch.where(
1462
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1463
+ )
1464
+ loss = loss_fct(active_logits, active_labels)
1465
+ else:
1466
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1467
+ outputs = (loss,) + outputs
1468
+
1469
+ return outputs # (loss), scores, (hidden_states), (attentions)
1470
+
1471
+
1472
+ @add_start_docstrings(
1473
+ """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1474
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
1475
+ BERT_START_DOCSTRING,
1476
+ )
1477
+ class BertForQuestionAnswering(BertPreTrainedModel):
1478
+ def __init__(self, config):
1479
+ super().__init__(config)
1480
+ self.num_labels = config.num_labels
1481
+
1482
+ self.bert = BertModel(config)
1483
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1484
+
1485
+ self.init_weights()
1486
+
1487
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1488
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1489
+ def forward(
1490
+ self,
1491
+ input_ids=None,
1492
+ attention_mask=None,
1493
+ token_type_ids=None,
1494
+ position_ids=None,
1495
+ head_mask=None,
1496
+ inputs_embeds=None,
1497
+ start_positions=None,
1498
+ end_positions=None,
1499
+ output_attentions=None,
1500
+ output_hidden_states=None,
1501
+ ):
1502
+ r"""
1503
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1504
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1505
+ Positions are clamped to the length of the sequence (`sequence_length`).
1506
+ Position outside of the sequence are not taken into account for computing the loss.
1507
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1508
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1509
+ Positions are clamped to the length of the sequence (`sequence_length`).
1510
+ Position outside of the sequence are not taken into account for computing the loss.
1511
+
1512
+ Returns:
1513
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1514
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
1515
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1516
+ start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1517
+ Span-start scores (before SoftMax).
1518
+ end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1519
+ Span-end scores (before SoftMax).
1520
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1521
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1522
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1523
+
1524
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1525
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1526
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1527
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1528
+
1529
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1530
+ heads.
1531
+ """
1532
+
1533
+ outputs = self.bert(
1534
+ input_ids,
1535
+ attention_mask=attention_mask,
1536
+ token_type_ids=token_type_ids,
1537
+ position_ids=position_ids,
1538
+ head_mask=head_mask,
1539
+ inputs_embeds=inputs_embeds,
1540
+ output_attentions=output_attentions,
1541
+ output_hidden_states=output_hidden_states,
1542
+ )
1543
+
1544
+ sequence_output = outputs[0]
1545
+
1546
+ logits = self.qa_outputs(sequence_output)
1547
+ start_logits, end_logits = logits.split(1, dim=-1)
1548
+ start_logits = start_logits.squeeze(-1)
1549
+ end_logits = end_logits.squeeze(-1)
1550
+
1551
+ outputs = (start_logits, end_logits,) + outputs[2:]
1552
+ if start_positions is not None and end_positions is not None:
1553
+ # If we are on multi-GPU, split add a dimension
1554
+ if len(start_positions.size()) > 1:
1555
+ start_positions = start_positions.squeeze(-1)
1556
+ if len(end_positions.size()) > 1:
1557
+ end_positions = end_positions.squeeze(-1)
1558
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1559
+ ignored_index = start_logits.size(1)
1560
+ start_positions.clamp_(0, ignored_index)
1561
+ end_positions.clamp_(0, ignored_index)
1562
+
1563
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1564
+ start_loss = loss_fct(start_logits, start_positions)
1565
+ end_loss = loss_fct(end_logits, end_positions)
1566
+ total_loss = (start_loss + end_loss) / 2
1567
+ outputs = (total_loss,) + outputs
1568
+
1569
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
bert/modeling_utils.py ADDED
@@ -0,0 +1,1269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import logging
19
+ import os
20
+ from typing import Callable, Dict, List, Optional, Tuple
21
+
22
+ import torch
23
+ from torch import Tensor, device, dtype, nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.nn import functional as F
26
+
27
+ from .activations import get_activation
28
+ from .configuration_utils import PretrainedConfig
29
+ from .file_utils import (
30
+ DUMMY_INPUTS,
31
+ TF2_WEIGHTS_NAME,
32
+ TF_WEIGHTS_NAME,
33
+ WEIGHTS_NAME,
34
+ cached_path,
35
+ hf_bucket_url,
36
+ is_remote_url,
37
+ )
38
+ from .generation_utils import GenerationMixin
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ try:
45
+ from torch.nn import Identity
46
+ except ImportError:
47
+ # Older PyTorch compatibility
48
+ class Identity(nn.Module):
49
+ r"""A placeholder identity operator that is argument-insensitive.
50
+ """
51
+
52
+ def __init__(self, *args, **kwargs):
53
+ super().__init__()
54
+
55
+ def forward(self, input):
56
+ return input
57
+
58
+
59
+ def find_pruneable_heads_and_indices(
60
+ heads: List, n_heads: int, head_size: int, already_pruned_heads: set
61
+ ) -> Tuple[set, "torch.LongTensor"]:
62
+ mask = torch.ones(n_heads, head_size)
63
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
64
+ for head in heads:
65
+ # Compute how many pruned heads are before the head and move the index accordingly
66
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
67
+ mask[head] = 0
68
+ mask = mask.view(-1).contiguous().eq(1)
69
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
70
+ return heads, index
71
+
72
+
73
+ class ModuleUtilsMixin:
74
+ """
75
+ A few utilities for torch.nn.Modules, to be used as a mixin.
76
+ """
77
+
78
+ def num_parameters(self, only_trainable: bool = False) -> int:
79
+ """
80
+ Get number of (optionally, trainable) parameters in the module.
81
+ """
82
+ params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
83
+ return sum(p.numel() for p in params)
84
+
85
+ @staticmethod
86
+ def _hook_rss_memory_pre_forward(module, *args, **kwargs):
87
+ try:
88
+ import psutil
89
+ except (ImportError):
90
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
91
+
92
+ process = psutil.Process(os.getpid())
93
+ mem = process.memory_info()
94
+ module.mem_rss_pre_forward = mem.rss
95
+ return None
96
+
97
+ @staticmethod
98
+ def _hook_rss_memory_post_forward(module, *args, **kwargs):
99
+ try:
100
+ import psutil
101
+ except (ImportError):
102
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
103
+
104
+ process = psutil.Process(os.getpid())
105
+ mem = process.memory_info()
106
+ module.mem_rss_post_forward = mem.rss
107
+ mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
108
+ module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
109
+ return None
110
+
111
+ def add_memory_hooks(self):
112
+ """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
113
+ Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
114
+ """
115
+ for module in self.modules():
116
+ module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
117
+ module.register_forward_hook(self._hook_rss_memory_post_forward)
118
+ self.reset_memory_hooks_state()
119
+
120
+ def reset_memory_hooks_state(self):
121
+ for module in self.modules():
122
+ module.mem_rss_diff = 0
123
+ module.mem_rss_post_forward = 0
124
+ module.mem_rss_pre_forward = 0
125
+
126
+ @property
127
+ def device(self) -> device:
128
+ """
129
+ Get torch.device from module, assuming that the whole module has one device.
130
+ """
131
+ try:
132
+ return next(self.parameters()).device
133
+ except StopIteration:
134
+ # For nn.DataParallel compatibility in PyTorch 1.5
135
+
136
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
137
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
138
+ return tuples
139
+
140
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
141
+ first_tuple = next(gen)
142
+ return first_tuple[1].device
143
+
144
+ @property
145
+ def dtype(self) -> dtype:
146
+ """
147
+ Get torch.dtype from module, assuming that the whole module has one dtype.
148
+ """
149
+ try:
150
+ return next(self.parameters()).dtype
151
+ except StopIteration:
152
+ # For nn.DataParallel compatibility in PyTorch 1.5
153
+
154
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
155
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
156
+ return tuples
157
+
158
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
159
+ first_tuple = next(gen)
160
+ return first_tuple[1].dtype
161
+
162
+ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
163
+ """type: torch.Tensor -> torch.Tensor"""
164
+ if encoder_attention_mask.dim() == 3:
165
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
166
+ if encoder_attention_mask.dim() == 2:
167
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
168
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
169
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
170
+ # /transformer/transformer_layers.py#L270
171
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
172
+ # encoder_extended_attention_mask.transpose(-1, -2))
173
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
174
+
175
+ if self.dtype == torch.float16:
176
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
177
+ elif self.dtype == torch.float32:
178
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
179
+ else:
180
+ raise ValueError(
181
+ "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
182
+ self.dtype
183
+ )
184
+ )
185
+
186
+ return encoder_extended_attention_mask
187
+
188
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
189
+ """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
190
+
191
+ Arguments:
192
+ attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
193
+ input_shape: tuple, shape of input_ids
194
+ device: torch.Device, usually self.device
195
+
196
+ Returns:
197
+ torch.Tensor with dtype of attention_mask.dtype
198
+ """
199
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
200
+ # ourselves in which case we just need to make it broadcastable to all heads.
201
+ if attention_mask.dim() == 3:
202
+ extended_attention_mask = attention_mask[:, None, :, :]
203
+ elif attention_mask.dim() == 2:
204
+ # Provided a padding mask of dimensions [batch_size, seq_length]
205
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
206
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
207
+ if self.config.is_decoder:
208
+ batch_size, seq_length = input_shape
209
+ seq_ids = torch.arange(seq_length, device=device)
210
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
211
+ # causal and attention masks must have same type with pytorch version < 1.3
212
+ causal_mask = causal_mask.to(attention_mask.dtype)
213
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
214
+ else:
215
+ extended_attention_mask = attention_mask[:, None, None, :]
216
+ else:
217
+ raise ValueError(
218
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
219
+ input_shape, attention_mask.shape
220
+ )
221
+ )
222
+
223
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
224
+ # masked positions, this operation will create a tensor which is 0.0 for
225
+ # positions we want to attend and -10000.0 for masked positions.
226
+ # Since we are adding it to the raw scores before the softmax, this is
227
+ # effectively the same as removing these entirely.
228
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
229
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
230
+ return extended_attention_mask
231
+
232
+ def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
233
+ """
234
+ # Prepare head mask if needed
235
+ # 1.0 in head_mask indicate we keep the head
236
+ attention_probs has shape bsz x n_heads x N x N
237
+ Arguments:
238
+ head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
239
+ num_hidden_layers: int
240
+ Returns:
241
+ Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
242
+ or list with [None] for each layer
243
+ """
244
+ if head_mask is not None:
245
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
246
+ if is_attention_chunked is True:
247
+ head_mask = head_mask.unsqueeze(-1)
248
+ else:
249
+ head_mask = [None] * num_hidden_layers
250
+
251
+ return head_mask
252
+
253
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
254
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
255
+ if head_mask.dim() == 1:
256
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
257
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
258
+ elif head_mask.dim() == 2:
259
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
260
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
261
+ head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
262
+ return head_mask
263
+
264
+
265
+ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
266
+ r""" Base class for all models.
267
+
268
+ :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
269
+ as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
270
+
271
+ Class attributes (overridden by derived classes):
272
+ - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
273
+ - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
274
+
275
+ - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
276
+ - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
277
+ - ``path``: a path (string) to the TensorFlow checkpoint.
278
+
279
+ - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
280
+ """
281
+ config_class = None
282
+ base_model_prefix = ""
283
+
284
+ @property
285
+ def dummy_inputs(self):
286
+ """ Dummy inputs to do a forward pass in the network.
287
+
288
+ Returns:
289
+ torch.Tensor with dummy inputs
290
+ """
291
+ return {"input_ids": torch.tensor(DUMMY_INPUTS)}
292
+
293
+ def __init__(self, config, *inputs, **kwargs):
294
+ super().__init__()
295
+ if not isinstance(config, PretrainedConfig):
296
+ raise ValueError(
297
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
298
+ "To create a model from a pretrained model use "
299
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
300
+ self.__class__.__name__, self.__class__.__name__
301
+ )
302
+ )
303
+ # Save config in model
304
+ self.config = config
305
+
306
+ @property
307
+ def base_model(self):
308
+ return getattr(self, self.base_model_prefix, self)
309
+
310
+ def get_input_embeddings(self):
311
+ """
312
+ Returns the model's input embeddings.
313
+
314
+ Returns:
315
+ :obj:`nn.Module`:
316
+ A torch module mapping vocabulary to hidden states.
317
+ """
318
+ base_model = getattr(self, self.base_model_prefix, self)
319
+ if base_model is not self:
320
+ return base_model.get_input_embeddings()
321
+ else:
322
+ raise NotImplementedError
323
+
324
+ def set_input_embeddings(self, value: nn.Module):
325
+ """
326
+ Set model's input embeddings
327
+
328
+ Args:
329
+ value (:obj:`nn.Module`):
330
+ A module mapping vocabulary to hidden states.
331
+ """
332
+ base_model = getattr(self, self.base_model_prefix, self)
333
+ if base_model is not self:
334
+ base_model.set_input_embeddings(value)
335
+ else:
336
+ raise NotImplementedError
337
+
338
+ def get_output_embeddings(self):
339
+ """
340
+ Returns the model's output embeddings.
341
+
342
+ Returns:
343
+ :obj:`nn.Module`:
344
+ A torch module mapping hidden states to vocabulary.
345
+ """
346
+ return None # Overwrite for models with output embeddings
347
+
348
+ def tie_weights(self):
349
+ """
350
+ Tie the weights between the input embeddings and the output embeddings.
351
+ If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
352
+ the weights instead.
353
+ """
354
+ output_embeddings = self.get_output_embeddings()
355
+ if output_embeddings is not None:
356
+ self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
357
+
358
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
359
+ """ Tie or clone module weights depending of whether we are using TorchScript or not
360
+ """
361
+ if self.config.torchscript:
362
+ output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
363
+ else:
364
+ output_embeddings.weight = input_embeddings.weight
365
+
366
+ if getattr(output_embeddings, "bias", None) is not None:
367
+ output_embeddings.bias.data = torch.nn.functional.pad(
368
+ output_embeddings.bias.data,
369
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
370
+ "constant",
371
+ 0,
372
+ )
373
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
374
+ output_embeddings.out_features = input_embeddings.num_embeddings
375
+
376
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
377
+ """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
378
+ Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
379
+
380
+ Arguments:
381
+
382
+ new_num_tokens: (`optional`) int:
383
+ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
384
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
385
+
386
+ Return: ``torch.nn.Embeddings``
387
+ Pointer to the input tokens Embeddings Module of the model
388
+ """
389
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
390
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
391
+ if new_num_tokens is None:
392
+ return model_embeds
393
+
394
+ # Update base model and current model config
395
+ self.config.vocab_size = new_num_tokens
396
+ base_model.vocab_size = new_num_tokens
397
+
398
+ # Tie weights again if needed
399
+ self.tie_weights()
400
+
401
+ return model_embeds
402
+
403
+ def _resize_token_embeddings(self, new_num_tokens):
404
+ old_embeddings = self.get_input_embeddings()
405
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
406
+ self.set_input_embeddings(new_embeddings)
407
+ return self.get_input_embeddings()
408
+
409
+ def _get_resized_embeddings(
410
+ self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
411
+ ) -> torch.nn.Embedding:
412
+ """ Build a resized Embedding Module from a provided token Embedding Module.
413
+ Increasing the size will add newly initialized vectors at the end
414
+ Reducing the size will remove vectors from the end
415
+
416
+ Args:
417
+ old_embeddings: ``torch.nn.Embedding``
418
+ Old embeddings to be resized.
419
+ new_num_tokens: (`optional`) int
420
+ New number of tokens in the embedding matrix.
421
+ Increasing the size will add newly initialized vectors at the end
422
+ Reducing the size will remove vectors from the end
423
+ If not provided or None: return the provided token Embedding Module.
424
+ Return: ``torch.nn.Embedding``
425
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
426
+ """
427
+ if new_num_tokens is None:
428
+ return old_embeddings
429
+
430
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
431
+ if old_num_tokens == new_num_tokens:
432
+ return old_embeddings
433
+
434
+ # Build new embeddings
435
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
436
+ new_embeddings.to(old_embeddings.weight.device)
437
+
438
+ # initialize all new embeddings (in particular added tokens)
439
+ self._init_weights(new_embeddings)
440
+
441
+ # Copy token embeddings from the previous weights
442
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
443
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
444
+
445
+ return new_embeddings
446
+
447
+ def init_weights(self):
448
+ """ Initialize and prunes weights if needed. """
449
+ # Initialize weights
450
+ self.apply(self._init_weights)
451
+
452
+ # Prune heads if needed
453
+ if self.config.pruned_heads:
454
+ self.prune_heads(self.config.pruned_heads)
455
+
456
+ # Tie weights if needed
457
+ self.tie_weights()
458
+
459
+ def prune_heads(self, heads_to_prune: Dict):
460
+ """ Prunes heads of the base model.
461
+
462
+ Arguments:
463
+
464
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
465
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
466
+ """
467
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
468
+ for layer, heads in heads_to_prune.items():
469
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
470
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
471
+
472
+ self.base_model._prune_heads(heads_to_prune)
473
+
474
+ def save_pretrained(self, save_directory):
475
+ """ Save a model and its configuration file to a directory, so that it
476
+ can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
477
+
478
+ Arguments:
479
+ save_directory: directory to which to save.
480
+ """
481
+ if os.path.isfile(save_directory):
482
+ logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
483
+ return
484
+ os.makedirs(save_directory, exist_ok=True)
485
+
486
+ # Only save the model itself if we are using distributed training
487
+ model_to_save = self.module if hasattr(self, "module") else self
488
+
489
+ # Attach architecture to the config
490
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
491
+
492
+ # If we save using the predefined names, we can load using `from_pretrained`
493
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
494
+
495
+ if getattr(self.config, "xla_device", False):
496
+ import torch_xla.core.xla_model as xm
497
+
498
+ if xm.is_master_ordinal():
499
+ # Save configuration file
500
+ model_to_save.config.save_pretrained(save_directory)
501
+ # xm.save takes care of saving only from master
502
+ xm.save(model_to_save.state_dict(), output_model_file)
503
+ else:
504
+ model_to_save.config.save_pretrained(save_directory)
505
+ torch.save(model_to_save.state_dict(), output_model_file)
506
+
507
+ logger.info("Model weights saved in {}".format(output_model_file))
508
+
509
+ @classmethod
510
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
511
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
512
+
513
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
514
+ To train the model, you should first set it back in training mode with ``model.train()``
515
+
516
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
517
+ It is up to you to train those weights with a downstream fine-tuning task.
518
+
519
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
520
+
521
+ Parameters:
522
+ pretrained_model_name_or_path: either:
523
+ - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
524
+ - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
525
+ - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
526
+ - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
527
+ - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
528
+
529
+ model_args: (`optional`) Sequence of positional arguments:
530
+ All remaning positional arguments will be passed to the underlying model's ``__init__`` method
531
+
532
+ config: (`optional`) one of:
533
+ - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
534
+ - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
535
+
536
+ Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
537
+ - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
538
+ - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
539
+ - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
540
+
541
+ state_dict: (`optional`) dict:
542
+ an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
543
+ This option can be used if you want to create a model from a pretrained configuration but load your own weights.
544
+ In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
545
+
546
+ cache_dir: (`optional`) string:
547
+ Path to a directory in which a downloaded pre-trained model
548
+ configuration should be cached if the standard cache should not be used.
549
+
550
+ force_download: (`optional`) boolean, default False:
551
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
552
+
553
+ resume_download: (`optional`) boolean, default False:
554
+ Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
555
+
556
+ proxies: (`optional`) dict, default None:
557
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
558
+ The proxies are used on each request.
559
+
560
+ output_loading_info: (`optional`) boolean:
561
+ Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
562
+
563
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
564
+ Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
565
+
566
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
567
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
568
+
569
+ Examples::
570
+
571
+ # For example purposes. Not runnable.
572
+ model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
573
+ model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
574
+ model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
575
+ assert model.config.output_attention == True
576
+ # Loading from a TF checkpoint file instead of a PyTorch model (slower)
577
+ config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
578
+ model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
579
+
580
+ """
581
+ config = kwargs.pop("config", None)
582
+ state_dict = kwargs.pop("state_dict", None)
583
+ cache_dir = kwargs.pop("cache_dir", None)
584
+ from_tf = kwargs.pop("from_tf", False)
585
+ force_download = kwargs.pop("force_download", False)
586
+ resume_download = kwargs.pop("resume_download", False)
587
+ proxies = kwargs.pop("proxies", None)
588
+ output_loading_info = kwargs.pop("output_loading_info", False)
589
+ local_files_only = kwargs.pop("local_files_only", False)
590
+ use_cdn = kwargs.pop("use_cdn", True)
591
+
592
+ # Load config if we don't provide a configuration
593
+ if not isinstance(config, PretrainedConfig):
594
+ config_path = config if config is not None else pretrained_model_name_or_path
595
+ config, model_kwargs = cls.config_class.from_pretrained(
596
+ config_path,
597
+ *model_args,
598
+ cache_dir=cache_dir,
599
+ return_unused_kwargs=True,
600
+ force_download=force_download,
601
+ resume_download=resume_download,
602
+ proxies=proxies,
603
+ local_files_only=local_files_only,
604
+ **kwargs,
605
+ )
606
+ else:
607
+ model_kwargs = kwargs
608
+
609
+ # Load model
610
+ if pretrained_model_name_or_path is not None:
611
+ if os.path.isdir(pretrained_model_name_or_path):
612
+ if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
613
+ # Load from a TF 1.0 checkpoint
614
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
615
+ elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
616
+ # Load from a TF 2.0 checkpoint
617
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
618
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
619
+ # Load from a PyTorch checkpoint
620
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
621
+ else:
622
+ raise EnvironmentError(
623
+ "Error no file named {} found in directory {} or `from_tf` set to False".format(
624
+ [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
625
+ pretrained_model_name_or_path,
626
+ )
627
+ )
628
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
629
+ archive_file = pretrained_model_name_or_path
630
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
631
+ assert (
632
+ from_tf
633
+ ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
634
+ pretrained_model_name_or_path + ".index"
635
+ )
636
+ archive_file = pretrained_model_name_or_path + ".index"
637
+ else:
638
+ archive_file = hf_bucket_url(
639
+ pretrained_model_name_or_path,
640
+ filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
641
+ use_cdn=use_cdn,
642
+ )
643
+ # pytorch_model.bin
644
+ # https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin
645
+ try:
646
+ # Load from URL or cache if already cached
647
+ resolved_archive_file = cached_path(
648
+ archive_file,
649
+ cache_dir=cache_dir,
650
+ force_download=force_download,
651
+ proxies=proxies,
652
+ resume_download=resume_download,
653
+ local_files_only=local_files_only,
654
+ )
655
+ if resolved_archive_file is None:
656
+ raise EnvironmentError
657
+ except EnvironmentError:
658
+ msg = (
659
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
660
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
661
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
662
+ )
663
+ raise EnvironmentError(msg)
664
+
665
+ if resolved_archive_file == archive_file:
666
+ logger.info("loading weights file {}".format(archive_file))
667
+ else:
668
+ logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
669
+ else:
670
+ resolved_archive_file = None
671
+
672
+ # Instantiate model.
673
+ model = cls(config, *model_args, **model_kwargs)
674
+
675
+ if state_dict is None and not from_tf:
676
+ try:
677
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
678
+ except Exception:
679
+ raise OSError(
680
+ "Unable to load weights from pytorch checkpoint file. "
681
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
682
+ )
683
+
684
+ missing_keys = []
685
+ unexpected_keys = []
686
+ error_msgs = []
687
+
688
+ if from_tf:
689
+ if resolved_archive_file.endswith(".index"):
690
+ # Load from a TensorFlow 1.X checkpoint - provided by original authors
691
+ model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
692
+ else:
693
+ # Load from our TensorFlow 2.0 checkpoints
694
+ try:
695
+ from transformers import load_tf2_checkpoint_in_pytorch_model
696
+
697
+ model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
698
+ except ImportError:
699
+ logger.error(
700
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
701
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
702
+ )
703
+ raise
704
+ else:
705
+ # Convert old format to new format if needed from a PyTorch state_dict
706
+ old_keys = []
707
+ new_keys = []
708
+ for key in state_dict.keys():
709
+ new_key = None
710
+ if "gamma" in key:
711
+ new_key = key.replace("gamma", "weight")
712
+ if "beta" in key:
713
+ new_key = key.replace("beta", "bias")
714
+ if new_key:
715
+ old_keys.append(key)
716
+ new_keys.append(new_key)
717
+ for old_key, new_key in zip(old_keys, new_keys):
718
+ state_dict[new_key] = state_dict.pop(old_key)
719
+
720
+ # copy state_dict so _load_from_state_dict can modify it
721
+ metadata = getattr(state_dict, "_metadata", None)
722
+ state_dict = state_dict.copy()
723
+ if metadata is not None:
724
+ state_dict._metadata = metadata
725
+
726
+ ##############################################################################################
727
+ # Print out state_dict's contents: keys
728
+ '''
729
+ for key, _ in state_dict.items():
730
+ print(key)
731
+ '''
732
+ ##############################################################################################
733
+
734
+
735
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
736
+ # so we need to apply the function recursively.
737
+ def load(module: nn.Module, prefix=""):
738
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
739
+ module._load_from_state_dict(
740
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
741
+ )
742
+ for name, child in module._modules.items():
743
+ if child is not None:
744
+ load(child, prefix + name + ".")
745
+
746
+ # Make sure we are able to load base models as well as derived models (with heads)
747
+ start_prefix = ""
748
+ model_to_load = model
749
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
750
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
751
+ start_prefix = cls.base_model_prefix + "."
752
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
753
+ model_to_load = getattr(model, cls.base_model_prefix)
754
+
755
+ load(model_to_load, prefix=start_prefix)
756
+
757
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
758
+ base_model_state_dict = model_to_load.state_dict().keys()
759
+ head_model_state_dict_without_base_prefix = [
760
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
761
+ ]
762
+
763
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
764
+
765
+ if len(unexpected_keys) > 0:
766
+ logger.warning(
767
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
768
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
769
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
770
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
771
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
772
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
773
+ )
774
+ else:
775
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
776
+ if len(missing_keys) > 0:
777
+ logger.warning(
778
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
779
+ f"and are newly initialized: {missing_keys}\n"
780
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
781
+ )
782
+ else:
783
+ logger.info(
784
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
785
+ f"If your task is similar to the task the model of the ckeckpoint was trained on, "
786
+ f"you can already use {model.__class__.__name__} for predictions without further training."
787
+ )
788
+ if len(error_msgs) > 0:
789
+ raise RuntimeError(
790
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
791
+ model.__class__.__name__, "\n\t".join(error_msgs)
792
+ )
793
+ )
794
+ model.tie_weights() # make sure token embedding weights are still tied if needed
795
+
796
+ # Set model in evaluation mode to deactivate DropOut modules by default
797
+ model.eval()
798
+
799
+ if output_loading_info:
800
+ loading_info = {
801
+ "missing_keys": missing_keys,
802
+ "unexpected_keys": unexpected_keys,
803
+ "error_msgs": error_msgs,
804
+ }
805
+ return model, loading_info
806
+
807
+ if hasattr(config, "xla_device") and config.xla_device:
808
+ import torch_xla.core.xla_model as xm
809
+
810
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
811
+ model.to(xm.xla_device())
812
+
813
+ return model
814
+
815
+
816
+ class Conv1D(nn.Module):
817
+ def __init__(self, nf, nx):
818
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
819
+ Basically works like a Linear layer but the weights are transposed
820
+ """
821
+ super().__init__()
822
+ self.nf = nf
823
+ w = torch.empty(nx, nf)
824
+ nn.init.normal_(w, std=0.02)
825
+ self.weight = nn.Parameter(w)
826
+ self.bias = nn.Parameter(torch.zeros(nf))
827
+
828
+ def forward(self, x):
829
+ size_out = x.size()[:-1] + (self.nf,)
830
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
831
+ x = x.view(*size_out)
832
+ return x
833
+
834
+
835
+ class PoolerStartLogits(nn.Module):
836
+ """ Compute SQuAD start_logits from sequence hidden states. """
837
+
838
+ def __init__(self, config):
839
+ super().__init__()
840
+ self.dense = nn.Linear(config.hidden_size, 1)
841
+
842
+ def forward(self, hidden_states, p_mask=None):
843
+ """ Args:
844
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
845
+ invalid position mask such as query and special symbols (PAD, SEP, CLS)
846
+ 1.0 means token should be masked.
847
+ """
848
+ x = self.dense(hidden_states).squeeze(-1)
849
+
850
+ if p_mask is not None:
851
+ if next(self.parameters()).dtype == torch.float16:
852
+ x = x * (1 - p_mask) - 65500 * p_mask
853
+ else:
854
+ x = x * (1 - p_mask) - 1e30 * p_mask
855
+
856
+ return x
857
+
858
+
859
+ class PoolerEndLogits(nn.Module):
860
+ """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
861
+ """
862
+
863
+ def __init__(self, config):
864
+ super().__init__()
865
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
866
+ self.activation = nn.Tanh()
867
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
868
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
869
+
870
+ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
871
+ """ Args:
872
+ One of ``start_states``, ``start_positions`` should be not None.
873
+ If both are set, ``start_positions`` overrides ``start_states``.
874
+
875
+ **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
876
+ hidden states of the first tokens for the labeled span.
877
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
878
+ position of the first token for the labeled span:
879
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
880
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
881
+ 1.0 means token should be masked.
882
+ """
883
+ assert (
884
+ start_states is not None or start_positions is not None
885
+ ), "One of start_states, start_positions should be not None"
886
+ if start_positions is not None:
887
+ slen, hsz = hidden_states.shape[-2:]
888
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
889
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
890
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
891
+
892
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
893
+ x = self.activation(x)
894
+ x = self.LayerNorm(x)
895
+ x = self.dense_1(x).squeeze(-1)
896
+
897
+ if p_mask is not None:
898
+ if next(self.parameters()).dtype == torch.float16:
899
+ x = x * (1 - p_mask) - 65500 * p_mask
900
+ else:
901
+ x = x * (1 - p_mask) - 1e30 * p_mask
902
+
903
+ return x
904
+
905
+
906
+ class PoolerAnswerClass(nn.Module):
907
+ """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
908
+
909
+ def __init__(self, config):
910
+ super().__init__()
911
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
912
+ self.activation = nn.Tanh()
913
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
914
+
915
+ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
916
+ """
917
+ Args:
918
+ One of ``start_states``, ``start_positions`` should be not None.
919
+ If both are set, ``start_positions`` overrides ``start_states``.
920
+
921
+ **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
922
+ hidden states of the first tokens for the labeled span.
923
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
924
+ position of the first token for the labeled span.
925
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
926
+ position of the CLS token. If None, take the last token.
927
+
928
+ note(Original repo):
929
+ no dependency on end_feature so that we can obtain one single `cls_logits`
930
+ for each sample
931
+ """
932
+ hsz = hidden_states.shape[-1]
933
+ assert (
934
+ start_states is not None or start_positions is not None
935
+ ), "One of start_states, start_positions should be not None"
936
+ if start_positions is not None:
937
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
938
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
939
+
940
+ if cls_index is not None:
941
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
942
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
943
+ else:
944
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
945
+
946
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
947
+ x = self.activation(x)
948
+ x = self.dense_1(x).squeeze(-1)
949
+
950
+ return x
951
+
952
+
953
+ class SQuADHead(nn.Module):
954
+ r""" A SQuAD head inspired by XLNet.
955
+
956
+ Parameters:
957
+ config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
958
+
959
+ Inputs:
960
+ **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
961
+ hidden states of sequence tokens
962
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
963
+ position of the first token for the labeled span.
964
+ **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
965
+ position of the last token for the labeled span.
966
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
967
+ position of the CLS token. If None, take the last token.
968
+ **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
969
+ Whether the question has a possible answer in the paragraph or not.
970
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
971
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
972
+ 1.0 means token should be masked.
973
+
974
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
975
+ **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
976
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
977
+ **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
978
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
979
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
980
+ **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
981
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
982
+ Indices for the top config.start_n_top start token possibilities (beam-search).
983
+ **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
984
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
985
+ Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
986
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
987
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
988
+ Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
989
+ **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
990
+ ``torch.FloatTensor`` of shape ``(batch_size,)``
991
+ Log probabilities for the ``is_impossible`` label of the answers.
992
+ """
993
+
994
+ def __init__(self, config):
995
+ super().__init__()
996
+ self.start_n_top = config.start_n_top
997
+ self.end_n_top = config.end_n_top
998
+
999
+ self.start_logits = PoolerStartLogits(config)
1000
+ self.end_logits = PoolerEndLogits(config)
1001
+ self.answer_class = PoolerAnswerClass(config)
1002
+
1003
+ def forward(
1004
+ self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
1005
+ ):
1006
+ outputs = ()
1007
+
1008
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
1009
+
1010
+ if start_positions is not None and end_positions is not None:
1011
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
1012
+ for x in (start_positions, end_positions, cls_index, is_impossible):
1013
+ if x is not None and x.dim() > 1:
1014
+ x.squeeze_(-1)
1015
+
1016
+ # during training, compute the end logits based on the ground truth of the start position
1017
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
1018
+
1019
+ loss_fct = CrossEntropyLoss()
1020
+ start_loss = loss_fct(start_logits, start_positions)
1021
+ end_loss = loss_fct(end_logits, end_positions)
1022
+ total_loss = (start_loss + end_loss) / 2
1023
+
1024
+ if cls_index is not None and is_impossible is not None:
1025
+ # Predict answerability from the representation of CLS and START
1026
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
1027
+ loss_fct_cls = nn.BCEWithLogitsLoss()
1028
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
1029
+
1030
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
1031
+ total_loss += cls_loss * 0.5
1032
+
1033
+ outputs = (total_loss,) + outputs
1034
+
1035
+ else:
1036
+ # during inference, compute the end logits based on beam search
1037
+ bsz, slen, hsz = hidden_states.size()
1038
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
1039
+
1040
+ start_top_log_probs, start_top_index = torch.topk(
1041
+ start_log_probs, self.start_n_top, dim=-1
1042
+ ) # shape (bsz, start_n_top)
1043
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
1044
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
1045
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
1046
+
1047
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1048
+ start_states
1049
+ ) # shape (bsz, slen, start_n_top, hsz)
1050
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1051
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
1052
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
1053
+
1054
+ end_top_log_probs, end_top_index = torch.topk(
1055
+ end_log_probs, self.end_n_top, dim=1
1056
+ ) # shape (bsz, end_n_top, start_n_top)
1057
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
1058
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1059
+
1060
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1061
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
1062
+
1063
+ outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
1064
+
1065
+ # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
1066
+ # or (if labels are provided) (total_loss,)
1067
+ return outputs
1068
+
1069
+
1070
+ class SequenceSummary(nn.Module):
1071
+ r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
1072
+ Args of the config class:
1073
+ summary_type:
1074
+ - 'last' => [default] take the last token hidden state (like XLNet)
1075
+ - 'first' => take the first token hidden state (like Bert)
1076
+ - 'mean' => take the mean of all tokens hidden states
1077
+ - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
1078
+ - 'attn' => Not implemented now, use multi-head attention
1079
+ summary_use_proj: Add a projection after the vector extraction
1080
+ summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
1081
+ summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
1082
+ summary_first_dropout: Add a dropout before the projection and activation
1083
+ summary_last_dropout: Add a dropout after the projection and activation
1084
+ """
1085
+
1086
+ def __init__(self, config: PretrainedConfig):
1087
+ super().__init__()
1088
+
1089
+ self.summary_type = getattr(config, "summary_type", "last")
1090
+ if self.summary_type == "attn":
1091
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
1092
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
1093
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
1094
+ raise NotImplementedError
1095
+
1096
+ self.summary = Identity()
1097
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1098
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
1099
+ num_classes = config.num_labels
1100
+ else:
1101
+ num_classes = config.hidden_size
1102
+ self.summary = nn.Linear(config.hidden_size, num_classes)
1103
+
1104
+ activation_string = getattr(config, "summary_activation", None)
1105
+ self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
1106
+
1107
+ self.first_dropout = Identity()
1108
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
1109
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
1110
+
1111
+ self.last_dropout = Identity()
1112
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
1113
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
1114
+
1115
+ def forward(self, hidden_states, cls_index=None):
1116
+ """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
1117
+ cls_index: [optional] position of the classification token if summary_type == 'cls_index',
1118
+ shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
1119
+ if summary_type == 'cls_index' and cls_index is None:
1120
+ we take the last token of the sequence as classification token
1121
+ """
1122
+ if self.summary_type == "last":
1123
+ output = hidden_states[:, -1]
1124
+ elif self.summary_type == "first":
1125
+ output = hidden_states[:, 0]
1126
+ elif self.summary_type == "mean":
1127
+ output = hidden_states.mean(dim=1)
1128
+ elif self.summary_type == "cls_index":
1129
+ if cls_index is None:
1130
+ cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
1131
+ else:
1132
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1133
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
1134
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1135
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
1136
+ elif self.summary_type == "attn":
1137
+ raise NotImplementedError
1138
+
1139
+ output = self.first_dropout(output)
1140
+ output = self.summary(output)
1141
+ output = self.activation(output)
1142
+ output = self.last_dropout(output)
1143
+
1144
+ return output
1145
+
1146
+
1147
+ def prune_linear_layer(layer, index, dim=0):
1148
+ """ Prune a linear layer (a model parameters) to keep only entries in index.
1149
+ Return the pruned layer as a new layer with requires_grad=True.
1150
+ Used to remove heads.
1151
+ """
1152
+ index = index.to(layer.weight.device)
1153
+ W = layer.weight.index_select(dim, index).clone().detach()
1154
+ if layer.bias is not None:
1155
+ if dim == 1:
1156
+ b = layer.bias.clone().detach()
1157
+ else:
1158
+ b = layer.bias[index].clone().detach()
1159
+ new_size = list(layer.weight.size())
1160
+ new_size[dim] = len(index)
1161
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1162
+ new_layer.weight.requires_grad = False
1163
+ new_layer.weight.copy_(W.contiguous())
1164
+ new_layer.weight.requires_grad = True
1165
+ if layer.bias is not None:
1166
+ new_layer.bias.requires_grad = False
1167
+ new_layer.bias.copy_(b.contiguous())
1168
+ new_layer.bias.requires_grad = True
1169
+ return new_layer
1170
+
1171
+
1172
+ def prune_conv1d_layer(layer, index, dim=1):
1173
+ """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
1174
+ A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
1175
+ Return the pruned layer as a new layer with requires_grad=True.
1176
+ Used to remove heads.
1177
+ """
1178
+ index = index.to(layer.weight.device)
1179
+ W = layer.weight.index_select(dim, index).clone().detach()
1180
+ if dim == 0:
1181
+ b = layer.bias.clone().detach()
1182
+ else:
1183
+ b = layer.bias[index].clone().detach()
1184
+ new_size = list(layer.weight.size())
1185
+ new_size[dim] = len(index)
1186
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
1187
+ new_layer.weight.requires_grad = False
1188
+ new_layer.weight.copy_(W.contiguous())
1189
+ new_layer.weight.requires_grad = True
1190
+ new_layer.bias.requires_grad = False
1191
+ new_layer.bias.copy_(b.contiguous())
1192
+ new_layer.bias.requires_grad = True
1193
+ return new_layer
1194
+
1195
+
1196
+ def prune_layer(layer, index, dim=None):
1197
+ """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
1198
+ Return the pruned layer as a new layer with requires_grad=True.
1199
+ Used to remove heads.
1200
+ """
1201
+ if isinstance(layer, nn.Linear):
1202
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
1203
+ elif isinstance(layer, Conv1D):
1204
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
1205
+ else:
1206
+ raise ValueError("Can't prune layer of class {}".format(layer.__class__))
1207
+
1208
+
1209
+ def apply_chunking_to_forward(
1210
+ chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
1211
+ ) -> torch.Tensor:
1212
+ """
1213
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
1214
+ It then applies a layer `forward_fn` to each chunk independently to save memory.
1215
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the
1216
+ same result as not applying it.
1217
+
1218
+ Args:
1219
+ chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
1220
+ chunk_dim: int - the dimension over which the input_tensors should be chunked
1221
+ forward_fn: fn - the forward fn of the model
1222
+ input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
1223
+ Returns:
1224
+ a Tensor with the same shape the foward_fn would have given if applied
1225
+
1226
+
1227
+ Examples::
1228
+
1229
+ # rename the usual forward() fn to forward_chunk()
1230
+ def forward_chunk(self, hidden_states):
1231
+ hidden_states = self.decoder(hidden_states)
1232
+ return hidden_states
1233
+
1234
+ # implement a chunked forward function
1235
+ def forward(self, hidden_states):
1236
+ return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
1237
+ """
1238
+
1239
+ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
1240
+ tensor_shape = input_tensors[0].shape
1241
+ assert all(
1242
+ input_tensor.shape == tensor_shape for input_tensor in input_tensors
1243
+ ), "All input tenors have to be of the same shape"
1244
+
1245
+ # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability
1246
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1247
+ assert num_args_in_forward_chunk_fn == len(
1248
+ input_tensors
1249
+ ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
1250
+ num_args_in_forward_chunk_fn, len(input_tensors)
1251
+ )
1252
+
1253
+ if chunk_size > 0:
1254
+ assert (
1255
+ input_tensors[0].shape[chunk_dim] % chunk_size == 0
1256
+ ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
1257
+ input_tensors[0].shape[chunk_dim], chunk_size
1258
+ )
1259
+
1260
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1261
+
1262
+ # chunk input tensor into tuples
1263
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1264
+ # apply forward fn to every tuple
1265
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1266
+ # concatenate output at same dimension
1267
+ return torch.cat(output_chunks, dim=chunk_dim)
1268
+
1269
+ return forward_fn(*input_tensors)
bert/tokenization_bert.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import os
21
+ import unicodedata
22
+ from typing import List, Optional
23
+
24
+ from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40
+ "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Constructs a BERT tokenizer. Based on WordPiece.
120
+
121
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122
+ should refer to the superclass for more information regarding methods.
123
+
124
+ Args:
125
+ vocab_file (:obj:`string`):
126
+ File containing the vocabulary.
127
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether to lowercase the input when tokenizing.
129
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130
+ Whether to do basic tokenization before WordPiece.
131
+ never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132
+ Collection of tokens which will never be split during tokenization. Only has an effect when
133
+ :obj:`do_basic_tokenize=True`
134
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136
+ token instead.
137
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139
+ for sequence classification or for a text and a question for question answering.
140
+ It is also used as the last token of a sequence built with special tokens.
141
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142
+ The token used for padding, for example when batching sequences of different lengths.
143
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144
+ The classifier token which is used when doing sequence classification (classification of the whole
145
+ sequence instead of per-token classification). It is the first token of the sequence when built with
146
+ special tokens.
147
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148
+ The token used for masking values. This is the token used when training this model with masked language
149
+ modeling. This is the token which the model will try to predict.
150
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151
+ Whether to tokenize Chinese characters.
152
+ This should likely be deactivated for Japanese:
153
+ see: https://github.com/huggingface/transformers/issues/328
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ **kwargs
174
+ ):
175
+ super().__init__(
176
+ unk_token=unk_token,
177
+ sep_token=sep_token,
178
+ pad_token=pad_token,
179
+ cls_token=cls_token,
180
+ mask_token=mask_token,
181
+ **kwargs,
182
+ )
183
+
184
+ if not os.path.isfile(vocab_file):
185
+ raise ValueError(
186
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188
+ )
189
+ self.vocab = load_vocab(vocab_file)
190
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191
+ self.do_basic_tokenize = do_basic_tokenize
192
+ if do_basic_tokenize:
193
+ self.basic_tokenizer = BasicTokenizer(
194
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195
+ )
196
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197
+
198
+ @property
199
+ def vocab_size(self):
200
+ return len(self.vocab)
201
+
202
+ def get_vocab(self):
203
+ return dict(self.vocab, **self.added_tokens_encoder)
204
+
205
+ def _tokenize(self, text):
206
+ split_tokens = []
207
+ if self.do_basic_tokenize:
208
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209
+
210
+ # If the token is part of the never_split set
211
+ if token in self.basic_tokenizer.never_split:
212
+ split_tokens.append(token)
213
+ else:
214
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
215
+ else:
216
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
217
+ return split_tokens
218
+
219
+ def _convert_token_to_id(self, token):
220
+ """ Converts a token (str) in an id using the vocab. """
221
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
222
+
223
+ def _convert_id_to_token(self, index):
224
+ """Converts an index (integer) in a token (str) using the vocab."""
225
+ return self.ids_to_tokens.get(index, self.unk_token)
226
+
227
+ def convert_tokens_to_string(self, tokens):
228
+ """ Converts a sequence of tokens (string) in a single string. """
229
+ out_string = " ".join(tokens).replace(" ##", "").strip()
230
+ return out_string
231
+
232
+ def build_inputs_with_special_tokens(
233
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
+ ) -> List[int]:
235
+ """
236
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
237
+ by concatenating and adding special tokens.
238
+ A BERT sequence has the following format:
239
+
240
+ - single sequence: ``[CLS] X [SEP]``
241
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242
+
243
+ Args:
244
+ token_ids_0 (:obj:`List[int]`):
245
+ List of IDs to which the special tokens will be added
246
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247
+ Optional second list of IDs for sequence pairs.
248
+
249
+ Returns:
250
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251
+ """
252
+ if token_ids_1 is None:
253
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254
+ cls = [self.cls_token_id]
255
+ sep = [self.sep_token_id]
256
+ return cls + token_ids_0 + sep + token_ids_1 + sep
257
+
258
+ def get_special_tokens_mask(
259
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260
+ ) -> List[int]:
261
+ """
262
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263
+ special tokens using the tokenizer ``prepare_for_model`` method.
264
+
265
+ Args:
266
+ token_ids_0 (:obj:`List[int]`):
267
+ List of ids.
268
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269
+ Optional second list of IDs for sequence pairs.
270
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271
+ Set to True if the token list is already formatted with special tokens for the model
272
+
273
+ Returns:
274
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275
+ """
276
+
277
+ if already_has_special_tokens:
278
+ if token_ids_1 is not None:
279
+ raise ValueError(
280
+ "You should not supply a second sequence if the provided sequence of "
281
+ "ids is already formated with special tokens for the model."
282
+ )
283
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284
+
285
+ if token_ids_1 is not None:
286
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287
+ return [1] + ([0] * len(token_ids_0)) + [1]
288
+
289
+ def create_token_type_ids_from_sequences(
290
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291
+ ) -> List[int]:
292
+ """
293
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294
+ A BERT sequence pair mask has the following format:
295
+
296
+ ::
297
+
298
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299
+ | first sequence | second sequence |
300
+
301
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
302
+
303
+ Args:
304
+ token_ids_0 (:obj:`List[int]`):
305
+ List of ids.
306
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307
+ Optional second list of IDs for sequence pairs.
308
+
309
+ Returns:
310
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311
+ sequence(s).
312
+ """
313
+ sep = [self.sep_token_id]
314
+ cls = [self.cls_token_id]
315
+ if token_ids_1 is None:
316
+ return len(cls + token_ids_0 + sep) * [0]
317
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318
+
319
+ def save_vocabulary(self, vocab_path):
320
+ """
321
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
322
+
323
+ Args:
324
+ vocab_path (:obj:`str`):
325
+ The directory in which to save the vocabulary.
326
+
327
+ Returns:
328
+ :obj:`Tuple(str)`: Paths to the files saved.
329
+ """
330
+ index = 0
331
+ if os.path.isdir(vocab_path):
332
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333
+ else:
334
+ vocab_file = vocab_path
335
+ with open(vocab_file, "w", encoding="utf-8") as writer:
336
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337
+ if index != token_index:
338
+ logger.warning(
339
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
340
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
348
+ class BasicTokenizer(object):
349
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350
+
351
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352
+ """ Constructs a BasicTokenizer.
353
+
354
+ Args:
355
+ **do_lower_case**: Whether to lower case the input.
356
+ **never_split**: (`optional`) list of str
357
+ Kept for backward compatibility purposes.
358
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359
+ List of token not to split.
360
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
361
+ Whether to tokenize Chinese characters.
362
+ This should likely be deactivated for Japanese:
363
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364
+ """
365
+ if never_split is None:
366
+ never_split = []
367
+ self.do_lower_case = do_lower_case
368
+ self.never_split = set(never_split)
369
+ self.tokenize_chinese_chars = tokenize_chinese_chars
370
+
371
+ def tokenize(self, text, never_split=None):
372
+ """ Basic Tokenization of a piece of text.
373
+ Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
374
+
375
+ Args:
376
+ **never_split**: (`optional`) list of str
377
+ Kept for backward compatibility purposes.
378
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379
+ List of token not to split.
380
+ """
381
+ # union() returns a new set by concatenating the two sets.
382
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383
+
384
+ # This was added on November 1st, 2018 for the multilingual and Chinese
385
+ # models. This is also applied to the English models now, but it doesn't
386
+ # matter since the English models were not trained on any Chinese data
387
+ # and generally don't have any Chinese data in them (there are Chinese
388
+ # characters in the vocabulary because Wikipedia does have some Chinese
389
+ # words in the English Wikipedia.).
390
+ if self.tokenize_chinese_chars:
391
+ text = self._tokenize_chinese_chars(text)
392
+ orig_tokens = whitespace_tokenize(text)
393
+ split_tokens = []
394
+ for token in orig_tokens:
395
+ if self.do_lower_case and token not in never_split:
396
+ token = token.lower()
397
+ token = self._run_strip_accents(token)
398
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
399
+
400
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
401
+ return output_tokens
402
+
403
+ def _run_strip_accents(self, text):
404
+ """Strips accents from a piece of text."""
405
+ text = unicodedata.normalize("NFD", text)
406
+ output = []
407
+ for char in text:
408
+ cat = unicodedata.category(char)
409
+ if cat == "Mn":
410
+ continue
411
+ output.append(char)
412
+ return "".join(output)
413
+
414
+ def _run_split_on_punc(self, text, never_split=None):
415
+ """Splits punctuation on a piece of text."""
416
+ if never_split is not None and text in never_split:
417
+ return [text]
418
+ chars = list(text)
419
+ i = 0
420
+ start_new_word = True
421
+ output = []
422
+ while i < len(chars):
423
+ char = chars[i]
424
+ if _is_punctuation(char):
425
+ output.append([char])
426
+ start_new_word = True
427
+ else:
428
+ if start_new_word:
429
+ output.append([])
430
+ start_new_word = False
431
+ output[-1].append(char)
432
+ i += 1
433
+
434
+ return ["".join(x) for x in output]
435
+
436
+ def _tokenize_chinese_chars(self, text):
437
+ """Adds whitespace around any CJK character."""
438
+ output = []
439
+ for char in text:
440
+ cp = ord(char)
441
+ if self._is_chinese_char(cp):
442
+ output.append(" ")
443
+ output.append(char)
444
+ output.append(" ")
445
+ else:
446
+ output.append(char)
447
+ return "".join(output)
448
+
449
+ def _is_chinese_char(self, cp):
450
+ """Checks whether CP is the codepoint of a CJK character."""
451
+ # This defines a "chinese character" as anything in the CJK Unicode block:
452
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453
+ #
454
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455
+ # despite its name. The modern Korean Hangul alphabet is a different block,
456
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457
+ # space-separated words, so they are not treated specially and handled
458
+ # like the all of the other languages.
459
+ if (
460
+ (cp >= 0x4E00 and cp <= 0x9FFF)
461
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
462
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
463
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
464
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
465
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466
+ or (cp >= 0xF900 and cp <= 0xFAFF)
467
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468
+ ): #
469
+ return True
470
+
471
+ return False
472
+
473
+ def _clean_text(self, text):
474
+ """Performs invalid character removal and whitespace cleanup on text."""
475
+ output = []
476
+ for char in text:
477
+ cp = ord(char)
478
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
479
+ continue
480
+ if _is_whitespace(char):
481
+ output.append(" ")
482
+ else:
483
+ output.append(char)
484
+ return "".join(output)
485
+
486
+
487
+ class WordpieceTokenizer(object):
488
+ """Runs WordPiece tokenization."""
489
+
490
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491
+ self.vocab = vocab
492
+ self.unk_token = unk_token
493
+ self.max_input_chars_per_word = max_input_chars_per_word
494
+
495
+ def tokenize(self, text):
496
+ """Tokenizes a piece of text into its word pieces.
497
+
498
+ This uses a greedy longest-match-first algorithm to perform tokenization
499
+ using the given vocabulary.
500
+
501
+ For example:
502
+ input = "unaffable"
503
+ output = ["un", "##aff", "##able"]
504
+
505
+ Args:
506
+ text: A single token or whitespace separated tokens. This should have
507
+ already been passed through `BasicTokenizer`.
508
+
509
+ Returns:
510
+ A list of wordpiece tokens.
511
+ """
512
+
513
+ output_tokens = []
514
+ for token in whitespace_tokenize(text):
515
+ chars = list(token)
516
+ if len(chars) > self.max_input_chars_per_word:
517
+ output_tokens.append(self.unk_token)
518
+ continue
519
+
520
+ is_bad = False
521
+ start = 0
522
+ sub_tokens = []
523
+ while start < len(chars):
524
+ end = len(chars)
525
+ cur_substr = None
526
+ while start < end:
527
+ substr = "".join(chars[start:end])
528
+ if start > 0:
529
+ substr = "##" + substr
530
+ if substr in self.vocab:
531
+ cur_substr = substr
532
+ break
533
+ end -= 1
534
+ if cur_substr is None:
535
+ is_bad = True
536
+ break
537
+ sub_tokens.append(cur_substr)
538
+ start = end
539
+
540
+ if is_bad:
541
+ output_tokens.append(self.unk_token)
542
+ else:
543
+ output_tokens.extend(sub_tokens)
544
+ return output_tokens
545
+
bert/tokenization_utils.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization classes for python tokenizers.
16
+ For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
17
+ """
18
+
19
+ import itertools
20
+ import logging
21
+ import re
22
+ import unicodedata
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ from .file_utils import add_end_docstrings
26
+ from .tokenization_utils_base import (
27
+ ENCODE_KWARGS_DOCSTRING,
28
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
29
+ AddedToken,
30
+ BatchEncoding,
31
+ EncodedInput,
32
+ EncodedInputPair,
33
+ PaddingStrategy,
34
+ PreTokenizedInput,
35
+ PreTokenizedInputPair,
36
+ PreTrainedTokenizerBase,
37
+ TensorType,
38
+ TextInput,
39
+ TextInputPair,
40
+ TruncationStrategy,
41
+ )
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ def _is_whitespace(char):
48
+ """Checks whether `chars` is a whitespace character."""
49
+ # \t, \n, and \r are technically contorl characters but we treat them
50
+ # as whitespace since they are generally considered as such.
51
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
52
+ return True
53
+ cat = unicodedata.category(char)
54
+ if cat == "Zs":
55
+ return True
56
+ return False
57
+
58
+
59
+ def _is_control(char):
60
+ """Checks whether `chars` is a control character."""
61
+ # These are technically control characters but we count them as whitespace
62
+ # characters.
63
+ if char == "\t" or char == "\n" or char == "\r":
64
+ return False
65
+ cat = unicodedata.category(char)
66
+ if cat.startswith("C"):
67
+ return True
68
+ return False
69
+
70
+
71
+ def _is_punctuation(char):
72
+ """Checks whether `chars` is a punctuation character."""
73
+ cp = ord(char)
74
+ # We treat all non-letter/number ASCII as punctuation.
75
+ # Characters such as "^", "$", and "`" are not in the Unicode
76
+ # Punctuation class but we treat them as punctuation anyways, for
77
+ # consistency.
78
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
79
+ return True
80
+ cat = unicodedata.category(char)
81
+ if cat.startswith("P"):
82
+ return True
83
+ return False
84
+
85
+
86
+ def _is_end_of_word(text):
87
+ """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
88
+ last_char = text[-1]
89
+ return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
90
+
91
+
92
+ def _is_start_of_word(text):
93
+ """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
94
+ first_char = text[0]
95
+ return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
96
+
97
+
98
+ class PreTrainedTokenizer(PreTrainedTokenizerBase):
99
+ """ Base class for all slow tokenizers.
100
+
101
+ Handle all the shared methods for tokenization and special tokens as well as methods
102
+ downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
103
+
104
+ This class also contain the added tokens in a unified way on top of all tokenizers so we don't
105
+ have to handle the specific vocabulary augmentation methods of the various underlying
106
+ dictionary structures (BPE, sentencepiece...).
107
+
108
+ Class attributes (overridden by derived classes):
109
+
110
+ - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
111
+ required by the model, and as associated values, the filename for saving the associated file (string).
112
+ - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
113
+ being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
114
+ `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
115
+ associated pretrained vocabulary file.
116
+ - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
117
+ models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
118
+ model has no maximum input size.
119
+ - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
120
+ pretrained models, and as associated values, a dictionnary of specific arguments to pass to the
121
+ ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the
122
+ ``from_pretrained()`` method.
123
+
124
+ Args:
125
+ - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
126
+ When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
127
+ model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`).
128
+ no associated max_length can be found in ``max_model_input_sizes``.
129
+ - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
130
+ Should be selected between ['right', 'left']
131
+ - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
132
+ model ("token_type_ids", "attention_mask"...).
133
+ - ``bos_token``: (`Optional`) string: a beginning of sentence token.
134
+ Will be associated to ``self.bos_token`` and ``self.bos_token_id``
135
+ - ``eos_token``: (`Optional`) string: an end of sentence token.
136
+ Will be associated to ``self.eos_token`` and ``self.eos_token_id``
137
+ - ``unk_token``: (`Optional`) string: an unknown token.
138
+ Will be associated to ``self.unk_token`` and ``self.unk_token_id``
139
+ - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
140
+ Will be associated to ``self.sep_token`` and ``self.sep_token_id``
141
+ - ``pad_token``: (`Optional`) string: a padding token.
142
+ Will be associated to ``self.pad_token`` and ``self.pad_token_id``
143
+ - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
144
+ leveraging self-attention along the full depth of the model).
145
+ Will be associated to ``self.cls_token`` and ``self.cls_token_id``
146
+ - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
147
+ modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
148
+ - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
149
+ Adding all special tokens here ensure they won't be split by the tokenization process.
150
+ Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
151
+
152
+
153
+ .. automethod:: __call__
154
+ """
155
+
156
+ def __init__(self, **kwargs):
157
+ super().__init__(**kwargs)
158
+
159
+ # Added tokens - We store this for both slow and fast tokenizers
160
+ # until the serialization of Fast tokenizers is updated
161
+ self.added_tokens_encoder: Dict[str, int] = {}
162
+ self.added_tokens_decoder: Dict[int, str] = {}
163
+ self.unique_no_split_tokens: List[str] = []
164
+
165
+ @property
166
+ def is_fast(self) -> bool:
167
+ return False
168
+
169
+ @property
170
+ def vocab_size(self) -> int:
171
+ """ Size of the base vocabulary (without the added tokens) """
172
+ raise NotImplementedError
173
+
174
+ def get_vocab(self):
175
+ """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
176
+ raise NotImplementedError()
177
+
178
+ def get_added_vocab(self) -> Dict[str, int]:
179
+ return self.added_tokens_encoder
180
+
181
+ def __len__(self):
182
+ """ Size of the full vocabulary with the added tokens """
183
+ return self.vocab_size + len(self.added_tokens_encoder)
184
+
185
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
186
+ """
187
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
188
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
189
+
190
+ Args:
191
+ new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
192
+ already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
193
+
194
+ Returns:
195
+ Number of tokens added to the vocabulary.
196
+
197
+ Examples::
198
+
199
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
200
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
201
+ model = BertModel.from_pretrained('bert-base-uncased')
202
+
203
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
204
+ print('We have added', num_added_toks, 'tokens')
205
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
206
+ """
207
+ new_tokens = [str(tok) for tok in new_tokens]
208
+
209
+ tokens_to_add = []
210
+ for token in new_tokens:
211
+ assert isinstance(token, str)
212
+ if not special_tokens and self.init_kwargs.get("do_lower_case", False):
213
+ token = token.lower()
214
+ if (
215
+ token != self.unk_token
216
+ and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
217
+ and token not in tokens_to_add
218
+ ):
219
+ tokens_to_add.append(token)
220
+ if self.verbose:
221
+ logger.info("Adding %s to the vocabulary", token)
222
+
223
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
224
+ added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
225
+ self.added_tokens_encoder.update(added_tok_encoder)
226
+ self.added_tokens_decoder.update(added_tok_decoder)
227
+
228
+ # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)
229
+ if special_tokens:
230
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
231
+ else:
232
+ # Or on the newly added tokens
233
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
234
+
235
+ return len(tokens_to_add)
236
+
237
+ def num_special_tokens_to_add(self, pair=False):
238
+ """
239
+ Returns the number of added tokens when encoding a sequence with special tokens.
240
+
241
+ Note:
242
+ This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
243
+ inside your training loop.
244
+
245
+ Args:
246
+ pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
247
+ number of added tokens in the case of a single sequence if set to False.
248
+
249
+ Returns:
250
+ Number of tokens added to sequences
251
+ """
252
+ token_ids_0 = []
253
+ token_ids_1 = []
254
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
255
+
256
+ def tokenize(self, text: TextInput, **kwargs):
257
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
258
+ Split in words for word-based vocabulary or sub-words for sub-word-based
259
+ vocabularies (BPE/SentencePieces/WordPieces).
260
+
261
+ Take care of added tokens.
262
+
263
+ Args:
264
+ text (:obj:`string`): The sequence to be encoded.
265
+ **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
266
+ """
267
+ # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
268
+ all_special_tokens_extended = dict(
269
+ (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
270
+ )
271
+
272
+ text, kwargs = self.prepare_for_tokenization(text, **kwargs)
273
+
274
+ if kwargs:
275
+ logger.warning(f"Keyword arguments {kwargs} not recognized.")
276
+
277
+ # TODO: should this be in the base class?
278
+ if self.init_kwargs.get("do_lower_case", False):
279
+ # convert non-special tokens to lowercase
280
+ escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
281
+ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
282
+ text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
283
+
284
+ def split_on_token(tok, text):
285
+ result = []
286
+ tok_extended = all_special_tokens_extended.get(tok, None)
287
+ split_text = text.split(tok)
288
+ full_word = ""
289
+ for i, sub_text in enumerate(split_text):
290
+ # AddedToken can control whitespace stripping around them.
291
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
292
+ # Cf. https://github.com/huggingface/transformers/pull/2778
293
+ # and https://github.com/huggingface/transformers/issues/3788
294
+ if isinstance(tok_extended, AddedToken):
295
+ if tok_extended.single_word:
296
+ # Try to avoid splitting on token
297
+ if (
298
+ i < len(split_text) - 1
299
+ and not _is_end_of_word(sub_text)
300
+ and not _is_start_of_word(split_text[i + 1])
301
+ ):
302
+ # Don't extract the special token
303
+ full_word += sub_text + tok
304
+ elif full_word:
305
+ full_word += sub_text
306
+ result += [full_word]
307
+ full_word = ""
308
+ continue
309
+ # Strip white spaces on the right
310
+ if tok_extended.rstrip and i > 0:
311
+ # A bit counter-intuitive but we strip the left of the string
312
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
313
+ sub_text = sub_text.lstrip()
314
+ # Strip white spaces on the left
315
+ if tok_extended.lstrip and i < len(split_text) - 1:
316
+ sub_text = sub_text.rstrip() # Opposite here
317
+ else:
318
+ # We strip left and right by default
319
+ if i < len(split_text) - 1:
320
+ sub_text = sub_text.rstrip()
321
+ if i > 0:
322
+ sub_text = sub_text.lstrip()
323
+
324
+ if i == 0 and not sub_text:
325
+ result += [tok]
326
+ elif i == len(split_text) - 1:
327
+ if sub_text:
328
+ result += [sub_text]
329
+ else:
330
+ pass
331
+ else:
332
+ if sub_text:
333
+ result += [sub_text]
334
+ result += [tok]
335
+ return result
336
+
337
+ def split_on_tokens(tok_list, text):
338
+ if not text.strip():
339
+ return []
340
+ if not tok_list:
341
+ return self._tokenize(text)
342
+
343
+ tokenized_text = []
344
+ text_list = [text]
345
+ for tok in tok_list:
346
+ tokenized_text = []
347
+ for sub_text in text_list:
348
+ if sub_text not in self.unique_no_split_tokens:
349
+ tokenized_text += split_on_token(tok, sub_text)
350
+ else:
351
+ tokenized_text += [sub_text]
352
+ text_list = tokenized_text
353
+
354
+ return list(
355
+ itertools.chain.from_iterable(
356
+ (
357
+ self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
358
+ for token in tokenized_text
359
+ )
360
+ )
361
+ )
362
+
363
+ no_split_token = self.unique_no_split_tokens
364
+ tokenized_text = split_on_tokens(no_split_token, text)
365
+ return tokenized_text
366
+
367
+ def _tokenize(self, text, **kwargs):
368
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
369
+ Split in words for word-based vocabulary or sub-words for sub-word-based
370
+ vocabularies (BPE/SentencePieces/WordPieces).
371
+
372
+ Do NOT take care of added tokens.
373
+ """
374
+ raise NotImplementedError
375
+
376
+ def convert_tokens_to_ids(self, tokens):
377
+ """ Converts a token string (or a sequence of tokens) in a single integer id
378
+ (or a sequence of ids), using the vocabulary.
379
+ """
380
+ if tokens is None:
381
+ return None
382
+
383
+ if isinstance(tokens, str):
384
+ return self._convert_token_to_id_with_added_voc(tokens)
385
+
386
+ ids = []
387
+ for token in tokens:
388
+ ids.append(self._convert_token_to_id_with_added_voc(token))
389
+ return ids
390
+
391
+ def _convert_token_to_id_with_added_voc(self, token):
392
+ if token is None:
393
+ return None
394
+
395
+ if token in self.added_tokens_encoder:
396
+ return self.added_tokens_encoder[token]
397
+ return self._convert_token_to_id(token)
398
+
399
+ def _convert_token_to_id(self, token):
400
+ raise NotImplementedError
401
+
402
+ def _encode_plus(
403
+ self,
404
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
405
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
406
+ add_special_tokens: bool = True,
407
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
408
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
409
+ max_length: Optional[int] = None,
410
+ stride: int = 0,
411
+ is_pretokenized: bool = False,
412
+ pad_to_multiple_of: Optional[int] = None,
413
+ return_tensors: Optional[Union[str, TensorType]] = None,
414
+ return_token_type_ids: Optional[bool] = None,
415
+ return_attention_mask: Optional[bool] = None,
416
+ return_overflowing_tokens: bool = False,
417
+ return_special_tokens_mask: bool = False,
418
+ return_offsets_mapping: bool = False,
419
+ return_length: bool = False,
420
+ verbose: bool = True,
421
+ **kwargs
422
+ ) -> BatchEncoding:
423
+ def get_input_ids(text):
424
+ if isinstance(text, str):
425
+ tokens = self.tokenize(text, **kwargs)
426
+ return self.convert_tokens_to_ids(tokens)
427
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
428
+ if is_pretokenized:
429
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
430
+ return self.convert_tokens_to_ids(tokens)
431
+ else:
432
+ return self.convert_tokens_to_ids(text)
433
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
434
+ return text
435
+ else:
436
+ if is_pretokenized:
437
+ raise ValueError(
438
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
439
+ )
440
+ else:
441
+ raise ValueError(
442
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
443
+ )
444
+
445
+ if return_offsets_mapping:
446
+ raise NotImplementedError(
447
+ "return_offset_mapping is not available when using Python tokenizers."
448
+ "To use this feature, change your tokenizer to one deriving from "
449
+ "transformers.PreTrainedTokenizerFast."
450
+ "More information on available tokenizers at "
451
+ "https://github.com/huggingface/transformers/pull/2674"
452
+ )
453
+
454
+ first_ids = get_input_ids(text)
455
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
456
+
457
+ return self.prepare_for_model(
458
+ first_ids,
459
+ pair_ids=second_ids,
460
+ add_special_tokens=add_special_tokens,
461
+ padding=padding_strategy.value,
462
+ truncation=truncation_strategy.value,
463
+ max_length=max_length,
464
+ stride=stride,
465
+ pad_to_multiple_of=pad_to_multiple_of,
466
+ return_tensors=return_tensors,
467
+ prepend_batch_axis=True,
468
+ return_attention_mask=return_attention_mask,
469
+ return_token_type_ids=return_token_type_ids,
470
+ return_overflowing_tokens=return_overflowing_tokens,
471
+ return_special_tokens_mask=return_special_tokens_mask,
472
+ return_length=return_length,
473
+ verbose=verbose,
474
+ )
475
+
476
+ def _batch_encode_plus(
477
+ self,
478
+ batch_text_or_text_pairs: Union[
479
+ List[TextInput],
480
+ List[TextInputPair],
481
+ List[PreTokenizedInput],
482
+ List[PreTokenizedInputPair],
483
+ List[EncodedInput],
484
+ List[EncodedInputPair],
485
+ ],
486
+ add_special_tokens: bool = True,
487
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
488
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
489
+ max_length: Optional[int] = None,
490
+ stride: int = 0,
491
+ is_pretokenized: bool = False,
492
+ pad_to_multiple_of: Optional[int] = None,
493
+ return_tensors: Optional[Union[str, TensorType]] = None,
494
+ return_token_type_ids: Optional[bool] = None,
495
+ return_attention_mask: Optional[bool] = None,
496
+ return_overflowing_tokens: bool = False,
497
+ return_special_tokens_mask: bool = False,
498
+ return_offsets_mapping: bool = False,
499
+ return_length: bool = False,
500
+ verbose: bool = True,
501
+ **kwargs
502
+ ) -> BatchEncoding:
503
+ def get_input_ids(text):
504
+ if isinstance(text, str):
505
+ tokens = self.tokenize(text, **kwargs)
506
+ return self.convert_tokens_to_ids(tokens)
507
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
508
+ if is_pretokenized:
509
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
510
+ return self.convert_tokens_to_ids(tokens)
511
+ else:
512
+ return self.convert_tokens_to_ids(text)
513
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
514
+ return text
515
+ else:
516
+ raise ValueError(
517
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
518
+ )
519
+
520
+ if return_offsets_mapping:
521
+ raise NotImplementedError(
522
+ "return_offset_mapping is not available when using Python tokenizers."
523
+ "To use this feature, change your tokenizer to one deriving from "
524
+ "transformers.PreTrainedTokenizerFast."
525
+ )
526
+
527
+ input_ids = []
528
+ for ids_or_pair_ids in batch_text_or_text_pairs:
529
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
530
+ ids, pair_ids = ids_or_pair_ids, None
531
+ elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
532
+ ids, pair_ids = ids_or_pair_ids, None
533
+ else:
534
+ ids, pair_ids = ids_or_pair_ids
535
+
536
+ first_ids = get_input_ids(ids)
537
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
538
+ input_ids.append((first_ids, second_ids))
539
+
540
+ batch_outputs = self._batch_prepare_for_model(
541
+ input_ids,
542
+ add_special_tokens=add_special_tokens,
543
+ padding_strategy=padding_strategy,
544
+ truncation_strategy=truncation_strategy,
545
+ max_length=max_length,
546
+ stride=stride,
547
+ pad_to_multiple_of=pad_to_multiple_of,
548
+ return_attention_mask=return_attention_mask,
549
+ return_token_type_ids=return_token_type_ids,
550
+ return_overflowing_tokens=return_overflowing_tokens,
551
+ return_special_tokens_mask=return_special_tokens_mask,
552
+ return_length=return_length,
553
+ return_tensors=return_tensors,
554
+ verbose=verbose,
555
+ )
556
+
557
+ return BatchEncoding(batch_outputs)
558
+
559
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
560
+ def _batch_prepare_for_model(
561
+ self,
562
+ batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
563
+ add_special_tokens: bool = True,
564
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
565
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
566
+ max_length: Optional[int] = None,
567
+ stride: int = 0,
568
+ pad_to_multiple_of: Optional[int] = None,
569
+ return_tensors: Optional[str] = None,
570
+ return_token_type_ids: Optional[bool] = None,
571
+ return_attention_mask: Optional[bool] = None,
572
+ return_overflowing_tokens: bool = False,
573
+ return_special_tokens_mask: bool = False,
574
+ return_length: bool = False,
575
+ verbose: bool = True,
576
+ ) -> BatchEncoding:
577
+ """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
578
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
579
+ manages a moving window (with user defined stride) for overflowing tokens
580
+
581
+ Args:
582
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
583
+ """
584
+
585
+ batch_outputs = {}
586
+ for first_ids, second_ids in batch_ids_pairs:
587
+ outputs = self.prepare_for_model(
588
+ first_ids,
589
+ second_ids,
590
+ add_special_tokens=add_special_tokens,
591
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
592
+ truncation=truncation_strategy.value,
593
+ max_length=max_length,
594
+ stride=stride,
595
+ pad_to_multiple_of=None, # we pad in batch afterward
596
+ return_attention_mask=False, # we pad in batch afterward
597
+ return_token_type_ids=return_token_type_ids,
598
+ return_overflowing_tokens=return_overflowing_tokens,
599
+ return_special_tokens_mask=return_special_tokens_mask,
600
+ return_length=return_length,
601
+ return_tensors=None, # We convert the whole batch to tensors at the end
602
+ prepend_batch_axis=False,
603
+ verbose=verbose,
604
+ )
605
+
606
+ for key, value in outputs.items():
607
+ if key not in batch_outputs:
608
+ batch_outputs[key] = []
609
+ batch_outputs[key].append(value)
610
+
611
+ batch_outputs = self.pad(
612
+ batch_outputs,
613
+ padding=padding_strategy.value,
614
+ max_length=max_length,
615
+ pad_to_multiple_of=pad_to_multiple_of,
616
+ return_attention_mask=return_attention_mask,
617
+ )
618
+
619
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
620
+
621
+ return batch_outputs
622
+
623
+ def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
624
+ """ Performs any necessary transformations before tokenization.
625
+
626
+ This method should pop the arguments from kwargs and return kwargs as well.
627
+ We test kwargs at the end of the encoding process to be sure all the arguments have been used.
628
+ """
629
+ return (text, kwargs)
630
+
631
+ def get_special_tokens_mask(
632
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
633
+ ) -> List[int]:
634
+ """
635
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
636
+ special tokens using the tokenizer ``prepare_for_model`` method.
637
+
638
+ Args:
639
+ token_ids_0: list of ids (must not contain special tokens)
640
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
641
+ for sequence pairs
642
+ already_has_special_tokens: (default False) Set to True if the token list is already formated with
643
+ special tokens for the model
644
+
645
+ Returns:
646
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
647
+ """
648
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
649
+
650
+ def convert_ids_to_tokens(
651
+ self, ids: Union[int, List[int]], skip_special_tokens: bool = False
652
+ ) -> Union[str, List[str]]:
653
+ """ Converts a single index or a sequence of indices (integers) in a token "
654
+ (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
655
+
656
+ Args:
657
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
658
+ """
659
+ if isinstance(ids, int):
660
+ if ids in self.added_tokens_decoder:
661
+ return self.added_tokens_decoder[ids]
662
+ else:
663
+ return self._convert_id_to_token(ids)
664
+ tokens = []
665
+ for index in ids:
666
+ index = int(index)
667
+ if skip_special_tokens and index in self.all_special_ids:
668
+ continue
669
+ if index in self.added_tokens_decoder:
670
+ tokens.append(self.added_tokens_decoder[index])
671
+ else:
672
+ tokens.append(self._convert_id_to_token(index))
673
+ return tokens
674
+
675
+ def _convert_id_to_token(self, index: int) -> str:
676
+ raise NotImplementedError
677
+
678
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
679
+ """ Converts a sequence of tokens (string) in a single string.
680
+ The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
681
+ but we often want to remove sub-word tokenization artifacts at the same time.
682
+ """
683
+ return " ".join(self.convert_ids_to_tokens(tokens))
684
+
685
+ def decode(
686
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
687
+ ) -> str:
688
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
689
+
690
+ # To avoid mixing byte-level and unicode for byte-level BPT
691
+ # we need to build string separatly for added tokens and byte-level tokens
692
+ # cf. https://github.com/huggingface/transformers/issues/1133
693
+ sub_texts = []
694
+ current_sub_text = []
695
+ for token in filtered_tokens:
696
+ if skip_special_tokens and token in self.all_special_ids:
697
+ continue
698
+ if token in self.added_tokens_encoder:
699
+ if current_sub_text:
700
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
701
+ current_sub_text = []
702
+ sub_texts.append(token)
703
+ else:
704
+ current_sub_text.append(token)
705
+ if current_sub_text:
706
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
707
+ text = " ".join(sub_texts)
708
+
709
+ if clean_up_tokenization_spaces:
710
+ clean_text = self.clean_up_tokenization(text)
711
+ return clean_text
712
+ else:
713
+ return text
714
+
715
+ def save_vocabulary(self, save_directory) -> Tuple[str]:
716
+ """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
717
+ and special token mappings.
718
+
719
+ Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full
720
+ Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
721
+ class method.
722
+ """
723
+ raise NotImplementedError
bert/tokenization_utils_base.py ADDED
The diff for this file is too large to render. See raw diff
 
criterions/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .label_smoothed_cross_entropy import AdjustLabelSmoothedCrossEntropyCriterion
criterions/label_smoothed_cross_entropy.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from OFA (https://github.com/OFA-Sys/OFA)
3
+ # Copyright 2022 The OFA-Sys Team.
4
+ # All rights reserved.
5
+ # This source code is licensed under the Apache 2.0 license
6
+ # found in the LICENSE file in the root directory.
7
+ # ------------------------------------------------------------------------
8
+ # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+
11
+ import math
12
+ from dataclasses import dataclass, field
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+ from fairseq import metrics, utils
19
+ from fairseq.criterions import FairseqCriterion, register_criterion
20
+ from fairseq.dataclass import FairseqDataclass
21
+ from omegaconf import II
22
+
23
+
24
+ @dataclass
25
+ class AdjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
26
+ label_smoothing: float = field(
27
+ default=0.0,
28
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
29
+ )
30
+ report_accuracy: bool = field(
31
+ default=False,
32
+ metadata={"help": "report accuracy metric"},
33
+ )
34
+ det_weight: float = field(
35
+ default=1.0,
36
+ metadata={"help": "weight of detection loss"},
37
+ )
38
+ cls_weight: float = field(
39
+ default=1.0,
40
+ metadata={"help": "weight of classification loss"},
41
+ )
42
+
43
+ ignore_prefix_size: int = field(
44
+ default=0,
45
+ metadata={"help": "Ignore first N tokens"},
46
+ )
47
+ ignore_eos: bool = field(
48
+ default=False,
49
+ metadata={"help": "Ignore eos token"},
50
+ )
51
+ sentence_avg: bool = II("optimization.sentence_avg")
52
+ drop_worst_ratio: float = field(
53
+ default=0.0,
54
+ metadata={"help": "ratio for discarding bad samples"},
55
+ )
56
+ drop_worst_after: int = field(
57
+ default=0,
58
+ metadata={"help": "steps for discarding bad samples"},
59
+ )
60
+ use_rdrop: bool = field(
61
+ default=False, metadata={"help": "use R-Drop"}
62
+ )
63
+ reg_alpha: float = field(
64
+ default=1.0, metadata={"help": "weight for R-Drop"}
65
+ )
66
+ sample_patch_num: int = field(
67
+ default=196, metadata={"help": "sample patches for v1"}
68
+ )
69
+ constraint_range: Optional[str] = field(
70
+ default=None,
71
+ metadata={"help": "constraint range"}
72
+ )
73
+
74
+
75
+ def construct_rdrop_sample(x):
76
+ if isinstance(x, dict):
77
+ for key in x:
78
+ x[key] = construct_rdrop_sample(x[key])
79
+ return x
80
+ elif isinstance(x, torch.Tensor):
81
+ return x.repeat(2, *([1] * (x.dim() - 1)))
82
+ elif isinstance(x, int):
83
+ return x * 2
84
+ elif isinstance(x, np.ndarray):
85
+ return x.repeat(2)
86
+ else:
87
+ raise NotImplementedError
88
+
89
+
90
+ def kl_loss(p, q):
91
+ p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
92
+ q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
93
+ loss = (p_loss + q_loss) / 2
94
+ return loss
95
+
96
+
97
+ def label_smoothed_nll_loss(
98
+ lprobs, target, epsilon, update_num, reduce=True,
99
+ drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
100
+ constraint_masks=None, constraint_start=None, constraint_end=None
101
+ ):
102
+ if target.dim() == lprobs.dim() - 1:
103
+ target = target.unsqueeze(-1)
104
+ nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
105
+ if constraint_masks is not None:
106
+ smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
107
+ eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
108
+ elif constraint_start is not None and constraint_end is not None:
109
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
110
+ smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
111
+ eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
112
+ else:
113
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
114
+ eps_i = epsilon / (lprobs.size(-1) - 1)
115
+ loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
116
+ if drop_worst_ratio > 0 and update_num > drop_worst_after:
117
+ if use_rdrop:
118
+ true_batch_size = loss.size(0) // 2
119
+ _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
120
+ loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
121
+ nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
122
+ lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
123
+ else:
124
+ loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
125
+ nll_loss = nll_loss[indices]
126
+ lprobs = lprobs[indices]
127
+
128
+
129
+ ntokens = loss.numel()
130
+ nll_loss = nll_loss.sum()
131
+
132
+ loss = loss.sum()
133
+ if use_rdrop:
134
+ true_batch_size = lprobs.size(0) // 2
135
+ p = lprobs[:true_batch_size]
136
+ q = lprobs[true_batch_size:]
137
+ if constraint_start is not None and constraint_end is not None:
138
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
139
+ p = p[:, constraint_range]
140
+ q = q[:, constraint_range]
141
+ loss += kl_loss(p, q) * reg_alpha
142
+
143
+ return loss, nll_loss, ntokens
144
+
145
+ @register_criterion(
146
+ "adjust_label_smoothed_cross_entropy", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
147
+ )
148
+ class AdjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
149
+ def __init__(
150
+ self,
151
+ task,
152
+ sentence_avg,
153
+ label_smoothing,
154
+ ignore_prefix_size=0,
155
+ ignore_eos=False,
156
+ report_accuracy=False,
157
+ drop_worst_ratio=0,
158
+ drop_worst_after=0,
159
+ use_rdrop=False,
160
+ reg_alpha=1.0,
161
+ sample_patch_num=196,
162
+ constraint_range=None,
163
+ det_weight=1.0,
164
+ cls_weight=1.0
165
+ ):
166
+ super().__init__(task)
167
+ self.sentence_avg = sentence_avg
168
+ self.eps = label_smoothing
169
+ self.ignore_prefix_size = ignore_prefix_size
170
+ self.ignore_eos = ignore_eos
171
+ self.report_accuracy = report_accuracy
172
+ self.drop_worst_ratio = drop_worst_ratio
173
+ self.drop_worst_after = drop_worst_after
174
+ self.use_rdrop = use_rdrop
175
+ self.reg_alpha = reg_alpha
176
+ self.sample_patch_num = sample_patch_num
177
+
178
+ self.det_weight = det_weight
179
+ self.cls_weight = cls_weight
180
+
181
+ self.constraint_start = None
182
+ self.constraint_end = None
183
+ if constraint_range is not None:
184
+ constraint_start, constraint_end = constraint_range.split(',')
185
+ self.constraint_start = int(constraint_start)
186
+ self.constraint_end = int(constraint_end)
187
+
188
+ def forward(self, model, sample, update_num=0, reduce=True):
189
+ """Compute the loss for the given sample.
190
+
191
+ Returns a tuple with three elements:
192
+ 1) the loss
193
+ 2) the sample size, which is used as the denominator for the gradient
194
+ 3) logging outputs to display while training
195
+ """
196
+ if isinstance(sample, list):
197
+ if self.sample_patch_num > 0:
198
+ sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
199
+ loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
200
+ loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
201
+ loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
202
+ sample_size = 1
203
+ logging_output = {
204
+ "loss": loss.data,
205
+ "loss_v1": loss_v1.data,
206
+ "loss_v2": loss_v2.data,
207
+ "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2[
208
+ "nll_loss"].data / sample_size_v2,
209
+ "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
210
+ "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
211
+ "sample_size": 1,
212
+ "sample_size_v1": sample_size_v1,
213
+ "sample_size_v2": sample_size_v2,
214
+ }
215
+ return loss, sample_size, logging_output
216
+
217
+ if self.use_rdrop:
218
+ construct_rdrop_sample(sample)
219
+
220
+ net_output = model(**sample["net_input"])
221
+ loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, det_weight=self.det_weight,
222
+ cls_weight=self.cls_weight, reduce=reduce)
223
+ sample_size = (
224
+ sample["target"].size(0)
225
+ )
226
+ logging_output = {
227
+ "loss": loss.data,
228
+ "nll_loss": nll_loss.data,
229
+ "ntokens": sample["ntokens"],
230
+ "nsentences": sample["nsentences"],
231
+ "sample_size": sample_size,
232
+ }
233
+ if self.report_accuracy:
234
+ n_correct, total = self.compute_accuracy(model, net_output, sample)
235
+ logging_output["n_correct"] = utils.item(n_correct.data)
236
+ logging_output["total"] = utils.item(total.data)
237
+ return loss, sample_size, logging_output
238
+
239
+ def get_lprobs_and_target(self, model, net_output, sample):
240
+ conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
241
+ constraint_masks = None
242
+ if "constraint_masks" in sample and sample["constraint_masks"] is not None:
243
+ constraint_masks = sample["constraint_masks"]
244
+ net_output[0].masked_fill_(~constraint_masks, -math.inf)
245
+ if self.constraint_start is not None and self.constraint_end is not None:
246
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
247
+ net_output[0][:, :, self.constraint_end:] = -math.inf
248
+ lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
249
+ target = sample["token_type"]
250
+ if self.ignore_prefix_size > 0:
251
+ lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
252
+ target = target[:, self.ignore_prefix_size:].contiguous()
253
+ if constraint_masks is not None:
254
+ constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous()
255
+ if self.ignore_eos:
256
+ bsz, seq_len, embed_dim = lprobs.size()
257
+ eos_indices = target.eq(self.task.tgt_dict.eos())
258
+ lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim)
259
+ target = target[~eos_indices].reshape(bsz, seq_len - 1)
260
+ if constraint_masks is not None:
261
+ constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len - 1, embed_dim)
262
+ if constraint_masks is not None:
263
+ constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
264
+
265
+ # index = torch.zeros(lprobs.shape[:2]).to(lprobs.device)
266
+ # index[:, :4] = 1 # 1 indicates the location of detection results
267
+
268
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks, None # index.view(-1)
269
+
270
+ def compute_loss(self, model, net_output, sample, update_num, det_weight=1.0, cls_weight=1.0, reduce=True):
271
+ b = sample['target'].shape[0]
272
+ lprobs, target, constraint_masks, index = self.get_lprobs_and_target(model, net_output, sample)
273
+ if constraint_masks is not None:
274
+ constraint_masks = constraint_masks[target != -1]
275
+ # index = index[target != self.padding_idx]
276
+ lprobs = lprobs[target != -1]
277
+ target = target[target != -1]
278
+
279
+ loss_cls, nll_loss, ntokens = label_smoothed_nll_loss(
280
+ lprobs,
281
+ target,
282
+ self.eps,
283
+ update_num,
284
+ reduce=reduce,
285
+ drop_worst_ratio=self.drop_worst_ratio,
286
+ drop_worst_after=self.drop_worst_after,
287
+ use_rdrop=self.use_rdrop,
288
+ reg_alpha=self.reg_alpha,
289
+ constraint_masks=constraint_masks,
290
+ constraint_start=self.constraint_start,
291
+ constraint_end=self.constraint_end
292
+ )
293
+ loss_cls = cls_weight * loss_cls/b
294
+
295
+ # compute regression loss
296
+ token_type = sample["token_type"]
297
+ token_type = torch.stack([token_type, token_type], -1)
298
+ target = sample["target"]
299
+ index = torch.zeros_like(target).to(target.device)
300
+ index[:, :2, :] = 1 # the first two tokens are bbox points; 1 indicates the location of detection results
301
+
302
+ target = target[token_type == 0]
303
+ index = index[token_type == 0]
304
+ regression_output = net_output[1].squeeze(-1)
305
+ regression_output = regression_output[token_type == 0]
306
+
307
+ loss_reg = F.l1_loss(target[index == 1], regression_output[index == 1]) * det_weight
308
+ if (index == 0).any():
309
+ loss_reg += F.l1_loss(target[index == 0], regression_output[index == 0])
310
+
311
+ loss = loss_reg + loss_cls
312
+ if update_num % 5000 == 1:
313
+ print(f"loss_reg: {loss_reg.item()} loss_cls: {loss_cls.item()}")
314
+
315
+ return loss, nll_loss, ntokens
316
+
317
+ def compute_accuracy(self, model, net_output, sample):
318
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
319
+ mask = target.ne(self.padding_idx)
320
+ n_correct = torch.sum(
321
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
322
+ )
323
+ total = torch.sum(mask)
324
+ return n_correct, total
325
+
326
+ @classmethod
327
+ def reduce_metrics(cls, logging_outputs) -> None:
328
+ """Aggregate logging outputs from data parallel training."""
329
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
330
+ loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
331
+ loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
332
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
333
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
334
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
335
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
336
+ sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
337
+ sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
338
+
339
+ metrics.log_scalar(
340
+ "loss", loss_sum / sample_size, sample_size, round=3
341
+ )
342
+ metrics.log_scalar(
343
+ "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
344
+ )
345
+ metrics.log_scalar(
346
+ "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
347
+ )
348
+ metrics.log_scalar(
349
+ "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
350
+ )
351
+ metrics.log_derived(
352
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
353
+ )
354
+
355
+ metrics.log_scalar(
356
+ "ntokens", ntokens, 1, round=3
357
+ )
358
+ metrics.log_scalar(
359
+ "nsentences", nsentences, 1, round=3
360
+ )
361
+ metrics.log_scalar(
362
+ "sample_size", sample_size, 1, round=3
363
+ )
364
+ metrics.log_scalar(
365
+ "sample_size_v1", sample_size_v1, 1, round=3
366
+ )
367
+ metrics.log_scalar(
368
+ "sample_size_v2", sample_size_v2, 1, round=3
369
+ )
370
+
371
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
372
+ if total > 0:
373
+ metrics.log_scalar("total", total)
374
+ n_correct = utils.item(
375
+ sum(log.get("n_correct", 0) for log in logging_outputs)
376
+ )
377
+ metrics.log_scalar("n_correct", n_correct)
378
+ metrics.log_derived(
379
+ "accuracy",
380
+ lambda meters: round(
381
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
382
+ )
383
+ if meters["total"].sum > 0
384
+ else float("nan"),
385
+ )
386
+
387
+ @staticmethod
388
+ def logging_outputs_can_be_summed() -> bool:
389
+ """
390
+ Whether the logging outputs returned by `forward` can be summed
391
+ across workers prior to calling `reduce_metrics`. Setting this
392
+ to True will improves distributed training speed.
393
+ """
394
+ return True
data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .refcoco_dataset import RefcocoDataset
2
+ from .refcoco_pretrain_dataset import RefcocoPretrainDataset
data/base_dataset.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from OFA (https://github.com/OFA-Sys/OFA)
3
+ # Copyright 2022 The OFA-Sys Team.
4
+ # All rights reserved.
5
+ # This source code is licensed under the Apache 2.0 license
6
+ # found in the LICENSE file in the root directory.
7
+ # ------------------------------------------------------------------------
8
+ # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+
11
+ import logging
12
+ import re
13
+ import torch.utils.data
14
+ from fairseq.data import FairseqDataset
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class BaseDataset(FairseqDataset):
20
+ def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
21
+ self.split = split
22
+ self.dataset = dataset
23
+ self.bpe = bpe
24
+ self.src_dict = src_dict
25
+ self.tgt_dict = tgt_dict
26
+
27
+ self.bos = src_dict.bos()
28
+ self.eos = src_dict.eos()
29
+ self.pad = src_dict.pad()
30
+ self.bos_item = torch.LongTensor([self.bos])
31
+ self.eos_item = torch.LongTensor([self.eos])
32
+
33
+ def __len__(self):
34
+ return len(self.dataset)
35
+
36
+ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
37
+ s = self.tgt_dict.encode_line(
38
+ line=self.bpe.encode(text) if use_bpe else text,
39
+ add_if_not_exist=False,
40
+ append_eos=False
41
+ ).long()
42
+ if length is not None:
43
+ s = s[:length]
44
+ if append_bos:
45
+ s = torch.cat([self.bos_item, s])
46
+ if append_eos:
47
+ s = torch.cat([s, self.eos_item])
48
+ return s
49
+
50
+ def pre_question(self, question, max_ques_words):
51
+ question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
52
+
53
+ question = re.sub(
54
+ r"\s{2,}",
55
+ ' ',
56
+ question,
57
+ )
58
+ question = question.rstrip('\n')
59
+ question = question.strip(' ')
60
+
61
+ # truncate question
62
+ question_words = question.split(' ')
63
+ if len(question_words) > max_ques_words:
64
+ question = ' '.join(question_words[:max_ques_words])
65
+
66
+ return question
67
+
68
+ def pre_caption(self, caption, max_words):
69
+ caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
70
+
71
+ caption = re.sub(
72
+ r"\s{2,}",
73
+ ' ',
74
+ caption,
75
+ )
76
+ caption = caption.rstrip('\n')
77
+ caption = caption.strip(' ')
78
+
79
+ # truncate caption
80
+ caption_words = caption.split(' ')
81
+ if len(caption_words) > max_words:
82
+ caption = ' '.join(caption_words[:max_words])
83
+
84
+ return caption
data/create_finetuning_data.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from refer.refer import REFER
2
+ import numpy as np
3
+ from PIL import Image
4
+ import random
5
+ import os
6
+ from tqdm import tqdm
7
+
8
+ import pickle
9
+ from poly_utils import is_clockwise, revert_direction, check_length, reorder_points, \
10
+ approximate_polygons, interpolate_polygons, image_to_base64, polygons_to_string
11
+
12
+
13
+ max_length = 400
14
+
15
+ data_root = './refer/data'
16
+ datasets = ['refcoco', 'refcoco+', 'refcocog']
17
+
18
+ image_dir = './datasets/images/mscoco/train2014'
19
+ val_test_files = pickle.load(open("data/val_test_files.p", "rb"))
20
+
21
+ combined_train_data = []
22
+
23
+ for dataset in datasets:
24
+ if dataset == 'refcoco':
25
+ splits = ['train', 'val', 'testA', 'testB']
26
+ splitBy = 'unc'
27
+ elif dataset == 'refcoco+':
28
+ splits = ['train', 'val', 'testA', 'testB']
29
+ splitBy = 'unc'
30
+ elif dataset == 'refcocog':
31
+ splits = ['train', 'val']
32
+ splitBy = 'umd'
33
+
34
+ save_dir = f'datasets/finetune/{dataset}'
35
+ os.makedirs(save_dir, exist_ok=True)
36
+ for split in splits:
37
+ num_pts = []
38
+ max_num_pts = 0
39
+ file_name = os.path.join(save_dir, f"{dataset}_{split}.tsv")
40
+ print("creating ", file_name)
41
+
42
+ uniq_ids = []
43
+ image_ids = []
44
+ sents = []
45
+ coeffs_strings = []
46
+ img_strings = []
47
+
48
+ writer = open(file_name, 'w')
49
+ refer = REFER(data_root, dataset, splitBy)
50
+
51
+ ref_ids = refer.getRefIds(split=split)
52
+
53
+ for this_ref_id in tqdm(ref_ids):
54
+ this_img_id = refer.getImgIds(this_ref_id)
55
+ this_img = refer.Imgs[this_img_id[0]]
56
+ fn = this_img['file_name']
57
+ img_id = fn.split(".")[0].split("_")[-1]
58
+
59
+ # load image
60
+ img = Image.open(os.path.join(image_dir, this_img['file_name'])).convert("RGB")
61
+
62
+ # convert image to string
63
+ img_base64 = image_to_base64(img, format='jpeg')
64
+
65
+ # load mask
66
+ ref = refer.loadRefs(this_ref_id)
67
+ ref_mask = np.array(refer.getMask(ref[0])['mask'])
68
+ annot = np.zeros(ref_mask.shape)
69
+ annot[ref_mask == 1] = 1 # 255
70
+ annot_img = Image.fromarray(annot.astype(np.uint8), mode="P")
71
+ annot_base64 = image_to_base64(annot_img, format='png')
72
+
73
+ polygons = refer.getPolygon(ref[0])['polygon']
74
+
75
+ polygons_processed = []
76
+ for polygon in polygons:
77
+ # make the polygon clockwise
78
+ if not is_clockwise(polygon):
79
+ polygon = revert_direction(polygon)
80
+
81
+ # reorder the polygon so that the first vertex is the one closest to image origin
82
+ polygon = reorder_points(polygon)
83
+ polygons_processed.append(polygon)
84
+
85
+ polygons = sorted(polygons_processed, key=lambda x: (x[0] ** 2 + x[1] ** 2, x[0], x[1]))
86
+ polygons_interpolated = interpolate_polygons(polygons)
87
+
88
+ polygons = approximate_polygons(polygons, 5, max_length)
89
+
90
+ pts_string = polygons_to_string(polygons)
91
+ pts_string_interpolated = polygons_to_string(polygons_interpolated)
92
+
93
+ # load box
94
+ box = refer.getRefBox(this_ref_id) # x,y,w,h
95
+ x, y, w, h = box
96
+ box_string = f'{x},{y},{x + w},{y + h}'
97
+
98
+ max_num_pts = max(max_num_pts, check_length(polygons))
99
+
100
+ num_pts.append(check_length(polygons))
101
+ # load text
102
+ ref_sent = refer.Refs[this_ref_id]
103
+ for i, (sent, sent_id) in enumerate(zip(ref_sent['sentences'], ref_sent['sent_ids'])):
104
+ uniq_id = f"{this_ref_id}_{i}"
105
+ instance = '\t'.join(
106
+ [uniq_id, str(this_img_id[0]), sent['sent'], box_string, pts_string, img_base64, annot_base64,
107
+ pts_string_interpolated]) + '\n'
108
+ writer.write(instance)
109
+
110
+ if img_id not in val_test_files and split == 'train': # filtered out val/test files
111
+ combined_train_data.append(instance)
112
+ writer.close()
113
+
114
+ random.shuffle(combined_train_data)
115
+ file_name = os.path.join("datasets/finetune/refcoco+g_train_shuffled.tsv")
116
+ print("creating ", file_name)
117
+ writer = open(file_name, 'w')
118
+ writer.writelines(combined_train_data)
119
+ writer.close()
120
+
121
+
122
+
123
+
data/create_pretraining_data.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from tqdm import tqdm
4
+ import random
5
+ import pickle
6
+
7
+ # set up image paths
8
+ imgsfile = dict(
9
+ coco='mscoco/train2014',
10
+ vg='visual-genome',
11
+ saiaprtc12='saiaprtc12',
12
+ flickr='flickr30k'
13
+ )
14
+
15
+ # load annotation files
16
+ f = open("datasets/annotations/instances.json")
17
+ print("Loading annotation file")
18
+ data = json.load(f)
19
+ f.close()
20
+
21
+ # load the validation and test image list of refcoco, refcoco+, and refcocog
22
+ val_test_files = pickle.load(open("data/val_test_files.p", "rb"))
23
+
24
+ # create result folder
25
+ os.makedirs("datasets/pretrain", exist_ok=True)
26
+
27
+ # generate training tsv file
28
+ train_instances = data['train']
29
+ tsv_filename = "datasets/pretrain/train_shuffled.tsv"
30
+ writer = open(tsv_filename, 'w')
31
+ print("generating ", tsv_filename)
32
+
33
+ lines = []
34
+ for i, data_i in enumerate(tqdm(train_instances)):
35
+ data_source = data_i['data_source']
36
+ image_id = data_i['image_id']
37
+ bbox = data_i['bbox']
38
+ expressions = data_i['expressions']
39
+ height, width = data_i['height'], data_i['width']
40
+ x, y, w, h = bbox
41
+ box_string = f'{x},{y},{x + w},{y + h}'
42
+ img_name = "COCO_train2014_%012d.jpg" if "coco" in data_source else "%d.jpg"
43
+ img_name = img_name % image_id
44
+ filepath = os.path.join(imgsfile[data_source], img_name)
45
+ line = '\t'.join([str(i), expressions[0].replace('\n', ''), box_string, filepath]) + '\n'
46
+ lines.append(line)
47
+
48
+ # shuffle the training set
49
+ random.shuffle(lines)
50
+
51
+ # write training tsv file
52
+ writer.writelines(lines)
53
+ writer.close()
54
+
55
+ # generate validation tsv files
56
+ val_sets = ['val_refcoco_unc', 'val_refcocoplus_unc', 'val_refcocog_umd', 'val_flickr30k', 'val_referitgame_berkeley']
57
+ for val_set in val_sets:
58
+ val_instances = data[val_set]
59
+ tsv_filename = f"datasets/pretrain/{val_set}.tsv"
60
+ writer = open(tsv_filename, 'w')
61
+ print("generating ", tsv_filename)
62
+
63
+ lines = []
64
+ for i, data_i in enumerate(tqdm(val_instances)):
65
+ data_source = data_i['data_source']
66
+ image_id = data_i['image_id']
67
+ bbox = data_i['bbox']
68
+ expressions = data_i['expressions']
69
+ height, width = data_i['height'], data_i['width']
70
+ x, y, w, h = bbox
71
+ box_string = f'{x},{y},{x + w},{y + h}'
72
+ img_name = "COCO_train2014_%012d.jpg" if "coco" in data_source else "%d.jpg"
73
+ img_name = img_name % image_id
74
+ filepath = os.path.join(imgsfile[data_source], img_name)
75
+ line = '\t'.join([str(i), expressions[0].replace('\n', ''), box_string, filepath]) + '\n'
76
+ lines.append(line)
77
+
78
+ # write tsv file
79
+ writer.writelines(lines)
80
+ writer.close()
data/data_utils.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from OFA (https://github.com/OFA-Sys/OFA)
3
+ # Copyright 2022 The OFA-Sys Team.
4
+ # All rights reserved.
5
+ # This source code is licensed under the Apache 2.0 license
6
+ # found in the LICENSE file in the root directory.
7
+ # ------------------------------------------------------------------------
8
+ # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+
11
+ try:
12
+ from collections.abc import Iterable
13
+ except ImportError:
14
+ from collections import Iterable
15
+ import contextlib
16
+ import itertools
17
+ import logging
18
+ import re
19
+ import warnings
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from fairseq.file_io import PathManager
26
+ from fairseq import utils
27
+ import os
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ def infer_language_pair(path):
33
+ """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
34
+ src, dst = None, None
35
+ for filename in PathManager.ls(path):
36
+ parts = filename.split(".")
37
+ if len(parts) >= 3 and len(parts[1].split("-")) == 2:
38
+ return parts[1].split("-")
39
+ return src, dst
40
+
41
+
42
+ def collate_tokens(
43
+ values,
44
+ pad_idx,
45
+ eos_idx=None,
46
+ left_pad=False,
47
+ move_eos_to_beginning=False,
48
+ pad_to_length=None,
49
+ pad_to_multiple=1,
50
+ pad_to_bsz=None,
51
+ ):
52
+ """Convert a list of 1d tensors into a padded 2d tensor."""
53
+ size = max(v.size(0) for v in values)
54
+ size = size if pad_to_length is None else max(size, pad_to_length)
55
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
56
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
57
+
58
+ def copy_tensor(src, dst):
59
+ assert dst.numel() == src.numel()
60
+ if move_eos_to_beginning:
61
+ if eos_idx is None:
62
+ # if no eos_idx is specified, then use the last token in src
63
+ dst[0] = src[-1]
64
+ else:
65
+ dst[0] = eos_idx
66
+ dst[1:] = src[:-1]
67
+ else:
68
+ dst.copy_(src)
69
+
70
+ if values[0].dim() == 1:
71
+ res = values[0].new(len(values), size).fill_(pad_idx)
72
+ elif values[0].dim() == 2:
73
+ assert move_eos_to_beginning is False
74
+ res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
75
+ else:
76
+ raise NotImplementedError
77
+
78
+ for i, v in enumerate(values):
79
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
80
+ return res
81
+
82
+
83
+ def load_indexed_dataset(
84
+ path, dictionary=None, dataset_impl=None, combine=False, default="cached"
85
+ ):
86
+ """A helper function for loading indexed datasets.
87
+
88
+ Args:
89
+ path (str): path to indexed dataset (e.g., 'data-bin/train')
90
+ dictionary (~fairseq.data.Dictionary): data dictionary
91
+ dataset_impl (str, optional): which dataset implementation to use. If
92
+ not provided, it will be inferred automatically. For legacy indexed
93
+ data we use the 'cached' implementation by default.
94
+ combine (bool, optional): automatically load and combine multiple
95
+ datasets. For example, if *path* is 'data-bin/train', then we will
96
+ combine 'data-bin/train', 'data-bin/train1', ... and return a
97
+ single ConcatDataset instance.
98
+ """
99
+ import fairseq.data.indexed_dataset as indexed_dataset
100
+ from fairseq.data.concat_dataset import ConcatDataset
101
+
102
+ datasets = []
103
+ for k in itertools.count():
104
+ path_k = path + (str(k) if k > 0 else "")
105
+ try:
106
+ path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
107
+ except Exception as e:
108
+ if "StorageException: [404] Path not found" in str(e):
109
+ logger.warning(f"path_k: {e} not found")
110
+ else:
111
+ raise e
112
+
113
+ dataset_impl_k = dataset_impl
114
+ if dataset_impl_k is None:
115
+ dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
116
+ dataset = indexed_dataset.make_dataset(
117
+ path_k,
118
+ impl=dataset_impl_k or default,
119
+ fix_lua_indexing=True,
120
+ dictionary=dictionary,
121
+ )
122
+ if dataset is None:
123
+ break
124
+ logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
125
+ datasets.append(dataset)
126
+ if not combine:
127
+ break
128
+ if len(datasets) == 0:
129
+ return None
130
+ elif len(datasets) == 1:
131
+ return datasets[0]
132
+ else:
133
+ return ConcatDataset(datasets)
134
+
135
+
136
+ @contextlib.contextmanager
137
+ def numpy_seed(seed, *addl_seeds):
138
+ """Context manager which seeds the NumPy PRNG with the specified seed and
139
+ restores the state afterward"""
140
+ if seed is None:
141
+ yield
142
+ return
143
+ if len(addl_seeds) > 0:
144
+ seed = int(hash((seed, *addl_seeds)) % 1e6)
145
+ state = np.random.get_state()
146
+ np.random.seed(seed)
147
+ try:
148
+ yield
149
+ finally:
150
+ np.random.set_state(state)
151
+
152
+
153
+ def collect_filtered(function, iterable, filtered):
154
+ """
155
+ Similar to :func:`filter` but collects filtered elements in ``filtered``.
156
+
157
+ Args:
158
+ function (callable): function that returns ``False`` for elements that
159
+ should be filtered
160
+ iterable (iterable): iterable to filter
161
+ filtered (list): list to store filtered elements
162
+ """
163
+ for el in iterable:
164
+ if function(el):
165
+ yield el
166
+ else:
167
+ filtered.append(el)
168
+
169
+
170
+ def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
171
+ def compare_leq(a, b):
172
+ return a <= b if not isinstance(a, tuple) else max(a) <= b
173
+
174
+ def check_size(idx):
175
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
176
+ return size_fn(idx) <= max_positions
177
+ elif isinstance(max_positions, dict):
178
+ idx_size = size_fn(idx)
179
+ assert isinstance(idx_size, dict)
180
+ intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
181
+ return all(
182
+ all(
183
+ a is None or b is None or a <= b
184
+ for a, b in zip(idx_size[key], max_positions[key])
185
+ )
186
+ for key in intersect_keys
187
+ )
188
+ else:
189
+ # For MultiCorpusSampledDataset, will generalize it later
190
+ if not isinstance(size_fn(idx), Iterable):
191
+ return all(size_fn(idx) <= b for b in max_positions)
192
+ return all(
193
+ a is None or b is None or a <= b
194
+ for a, b in zip(size_fn(idx), max_positions)
195
+ )
196
+
197
+ ignored = []
198
+ itr = collect_filtered(check_size, indices, ignored)
199
+ indices = np.fromiter(itr, dtype=np.int64, count=-1)
200
+ return indices, ignored
201
+
202
+
203
+ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
204
+ """
205
+ [deprecated] Filter indices based on their size.
206
+ Use `FairseqDataset::filter_indices_by_size` instead.
207
+
208
+ Args:
209
+ indices (List[int]): ordered list of dataset indices
210
+ dataset (FairseqDataset): fairseq dataset instance
211
+ max_positions (tuple): filter elements larger than this size.
212
+ Comparisons are done component-wise.
213
+ raise_exception (bool, optional): if ``True``, raise an exception if
214
+ any elements are filtered (default: False).
215
+ """
216
+ warnings.warn(
217
+ "data_utils.filter_by_size is deprecated. "
218
+ "Use `FairseqDataset::filter_indices_by_size` instead.",
219
+ stacklevel=2,
220
+ )
221
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
222
+ if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
223
+ ignored = indices[dataset.sizes[indices] > max_positions].tolist()
224
+ indices = indices[dataset.sizes[indices] <= max_positions]
225
+ elif (
226
+ hasattr(dataset, "sizes")
227
+ and isinstance(dataset.sizes, list)
228
+ and len(dataset.sizes) == 1
229
+ ):
230
+ ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
231
+ indices = indices[dataset.sizes[0][indices] <= max_positions]
232
+ else:
233
+ indices, ignored = _filter_by_size_dynamic(
234
+ indices, dataset.size, max_positions
235
+ )
236
+ else:
237
+ indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
238
+
239
+ if len(ignored) > 0 and raise_exception:
240
+ raise Exception(
241
+ (
242
+ "Size of sample #{} is invalid (={}) since max_positions={}, "
243
+ "skip this example with --skip-invalid-size-inputs-valid-test"
244
+ ).format(ignored[0], dataset.size(ignored[0]), max_positions)
245
+ )
246
+ if len(ignored) > 0:
247
+ logger.warning(
248
+ (
249
+ "{} samples have invalid sizes and will be skipped, "
250
+ "max_positions={}, first few sample ids={}"
251
+ ).format(len(ignored), max_positions, ignored[:10])
252
+ )
253
+ return indices
254
+
255
+
256
+ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
257
+ """Filter a list of sample indices. Remove those that are longer
258
+ than specified in max_sizes.
259
+
260
+ Args:
261
+ indices (np.array): original array of sample indices
262
+ max_sizes (int or list[int] or tuple[int]): max sample size,
263
+ can be defined separately for src and tgt (then list or tuple)
264
+
265
+ Returns:
266
+ np.array: filtered sample array
267
+ list: list of removed indices
268
+ """
269
+ if max_sizes is None:
270
+ return indices, []
271
+ if type(max_sizes) in (int, float):
272
+ max_src_size, max_tgt_size = max_sizes, max_sizes
273
+ else:
274
+ max_src_size, max_tgt_size = max_sizes
275
+ if tgt_sizes is None:
276
+ ignored = indices[src_sizes[indices] > max_src_size]
277
+ else:
278
+ ignored = indices[
279
+ (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
280
+ ]
281
+ if len(ignored) > 0:
282
+ if tgt_sizes is None:
283
+ indices = indices[src_sizes[indices] <= max_src_size]
284
+ else:
285
+ indices = indices[
286
+ (src_sizes[indices] <= max_src_size)
287
+ & (tgt_sizes[indices] <= max_tgt_size)
288
+ ]
289
+ return indices, ignored.tolist()
290
+
291
+
292
+ def batch_by_size(
293
+ indices,
294
+ num_tokens_fn,
295
+ num_tokens_vec=None,
296
+ max_tokens=None,
297
+ max_sentences=None,
298
+ required_batch_size_multiple=1,
299
+ fixed_shapes=None,
300
+ ):
301
+ """
302
+ Yield mini-batches of indices bucketed by size. Batches may contain
303
+ sequences of different lengths.
304
+
305
+ Args:
306
+ indices (List[int]): ordered list of dataset indices
307
+ num_tokens_fn (callable): function that returns the number of tokens at
308
+ a given index
309
+ num_tokens_vec (List[int], optional): precomputed vector of the number
310
+ of tokens for each index in indices (to enable faster batch generation)
311
+ max_tokens (int, optional): max number of tokens in each batch
312
+ (default: None).
313
+ max_sentences (int, optional): max number of sentences in each
314
+ batch (default: None).
315
+ required_batch_size_multiple (int, optional): require batch size to
316
+ be less than N or a multiple of N (default: 1).
317
+ fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
318
+ only be created with the given shapes. *max_sentences* and
319
+ *required_batch_size_multiple* will be ignored (default: None).
320
+ """
321
+ try:
322
+ from fairseq.data.data_utils_fast import (
323
+ batch_by_size_fn,
324
+ batch_by_size_vec,
325
+ batch_fixed_shapes_fast,
326
+ )
327
+ except ImportError:
328
+ raise ImportError(
329
+ "Please build Cython components with: "
330
+ "`python setup.py build_ext --inplace`"
331
+ )
332
+ except ValueError:
333
+ raise ValueError(
334
+ "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
335
+ )
336
+
337
+ # added int() to avoid TypeError: an integer is required
338
+ max_tokens = (
339
+ int(max_tokens) if max_tokens is not None else -1
340
+ )
341
+ max_sentences = max_sentences if max_sentences is not None else -1
342
+ bsz_mult = required_batch_size_multiple
343
+
344
+ if not isinstance(indices, np.ndarray):
345
+ indices = np.fromiter(indices, dtype=np.int64, count=-1)
346
+
347
+ if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
348
+ num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
349
+
350
+ if fixed_shapes is None:
351
+ if num_tokens_vec is None:
352
+ return batch_by_size_fn(
353
+ indices,
354
+ num_tokens_fn,
355
+ max_tokens,
356
+ max_sentences,
357
+ bsz_mult,
358
+ )
359
+ else:
360
+ return batch_by_size_vec(
361
+ indices,
362
+ num_tokens_vec,
363
+ max_tokens,
364
+ max_sentences,
365
+ bsz_mult,
366
+ )
367
+
368
+ else:
369
+ fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
370
+ sort_order = np.lexsort(
371
+ [
372
+ fixed_shapes[:, 1].argsort(), # length
373
+ fixed_shapes[:, 0].argsort(), # bsz
374
+ ]
375
+ )
376
+ fixed_shapes_sorted = fixed_shapes[sort_order]
377
+ return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
378
+
379
+
380
+ def post_process(sentence: str, symbol: str):
381
+ if symbol == "sentencepiece":
382
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
383
+ elif symbol == "wordpiece":
384
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
385
+ elif symbol == "letter":
386
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
387
+ elif symbol == "silence":
388
+ import re
389
+ sentence = sentence.replace("<SIL>", "")
390
+ sentence = re.sub(' +', ' ', sentence).strip()
391
+ elif symbol == "_EOW":
392
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
393
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
394
+ if symbol == "subword_nmt":
395
+ symbol = "@@ "
396
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
397
+ elif symbol == "none":
398
+ pass
399
+ elif symbol is not None:
400
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
401
+ return sentence
402
+
403
+
404
+ def compute_mask_indices(
405
+ shape: Tuple[int, int],
406
+ padding_mask: Optional[torch.Tensor],
407
+ mask_prob: float,
408
+ mask_length: int,
409
+ mask_type: str = "static",
410
+ mask_other: float = 0.0,
411
+ min_masks: int = 0,
412
+ no_overlap: bool = False,
413
+ min_space: int = 0,
414
+ ) -> np.ndarray:
415
+ """
416
+ Computes random mask spans for a given shape
417
+
418
+ Args:
419
+ shape: the the shape for which to compute masks.
420
+ should be of size 2 where first element is batch size and 2nd is timesteps
421
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
422
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
423
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
424
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
425
+ mask_type: how to compute mask lengths
426
+ static = fixed size
427
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
428
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
429
+ poisson = sample from possion distribution with lambda = mask length
430
+ min_masks: minimum number of masked spans
431
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
432
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
433
+ """
434
+
435
+ bsz, all_sz = shape
436
+ mask = np.full((bsz, all_sz), False)
437
+
438
+ all_num_mask = int(
439
+ # add a random number for probabilistic rounding
440
+ mask_prob * all_sz / float(mask_length)
441
+ + np.random.rand()
442
+ )
443
+
444
+ all_num_mask = max(min_masks, all_num_mask)
445
+
446
+ mask_idcs = []
447
+ for i in range(bsz):
448
+ if padding_mask is not None:
449
+ sz = all_sz - padding_mask[i].long().sum().item()
450
+ num_mask = int(
451
+ # add a random number for probabilistic rounding
452
+ mask_prob * sz / float(mask_length)
453
+ + np.random.rand()
454
+ )
455
+ num_mask = max(min_masks, num_mask)
456
+ else:
457
+ sz = all_sz
458
+ num_mask = all_num_mask
459
+
460
+ if mask_type == "static":
461
+ lengths = np.full(num_mask, mask_length)
462
+ elif mask_type == "uniform":
463
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
464
+ elif mask_type == "normal":
465
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
466
+ lengths = [max(1, int(round(x))) for x in lengths]
467
+ elif mask_type == "poisson":
468
+ lengths = np.random.poisson(mask_length, size=num_mask)
469
+ lengths = [int(round(x)) for x in lengths]
470
+ else:
471
+ raise Exception("unknown mask selection " + mask_type)
472
+
473
+ if sum(lengths) == 0:
474
+ lengths[0] = min(mask_length, sz - 1)
475
+
476
+ if no_overlap:
477
+ mask_idc = []
478
+
479
+ def arrange(s, e, length, keep_length):
480
+ span_start = np.random.randint(s, e - length)
481
+ mask_idc.extend(span_start + i for i in range(length))
482
+
483
+ new_parts = []
484
+ if span_start - s - min_space >= keep_length:
485
+ new_parts.append((s, span_start - min_space + 1))
486
+ if e - span_start - keep_length - min_space > keep_length:
487
+ new_parts.append((span_start + length + min_space, e))
488
+ return new_parts
489
+
490
+ parts = [(0, sz)]
491
+ min_length = min(lengths)
492
+ for length in sorted(lengths, reverse=True):
493
+ lens = np.fromiter(
494
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
495
+ np.int,
496
+ )
497
+ l_sum = np.sum(lens)
498
+ if l_sum == 0:
499
+ break
500
+ probs = lens / np.sum(lens)
501
+ c = np.random.choice(len(parts), p=probs)
502
+ s, e = parts.pop(c)
503
+ parts.extend(arrange(s, e, length, min_length))
504
+ mask_idc = np.asarray(mask_idc)
505
+ else:
506
+ min_len = min(lengths)
507
+ if sz - min_len <= num_mask:
508
+ min_len = sz - num_mask - 1
509
+
510
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
511
+
512
+ mask_idc = np.asarray(
513
+ [
514
+ mask_idc[j] + offset
515
+ for j in range(len(mask_idc))
516
+ for offset in range(lengths[j])
517
+ ]
518
+ )
519
+
520
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
521
+
522
+ min_len = min([len(m) for m in mask_idcs])
523
+ for i, mask_idc in enumerate(mask_idcs):
524
+ if len(mask_idc) > min_len:
525
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
526
+ mask[i, mask_idc] = True
527
+
528
+ return mask
529
+
530
+
531
+ def get_mem_usage():
532
+ try:
533
+ import psutil
534
+
535
+ mb = 1024 * 1024
536
+ return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
537
+ except ImportError:
538
+ return "N/A"
539
+
540
+
541
+ # lens: torch.LongTensor
542
+ # returns: torch.BoolTensor
543
+ def lengths_to_padding_mask(lens):
544
+ bsz, max_lens = lens.size(0), torch.max(lens).item()
545
+ mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
546
+ mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
547
+ return mask
548
+
549
+
550
+ # lens: torch.LongTensor
551
+ # returns: torch.BoolTensor
552
+ def lengths_to_mask(lens):
553
+ return ~lengths_to_padding_mask(lens)
554
+
555
+
556
+ def get_buckets(sizes, num_buckets):
557
+ buckets = np.unique(
558
+ np.percentile(
559
+ sizes,
560
+ np.linspace(0, 100, num_buckets + 1),
561
+ interpolation='lower',
562
+ )[1:]
563
+ )
564
+ return buckets
565
+
566
+
567
+ def get_bucketed_sizes(orig_sizes, buckets):
568
+ sizes = np.copy(orig_sizes)
569
+ assert np.min(sizes) >= 0
570
+ start_val = -1
571
+ for end_val in buckets:
572
+ mask = (sizes > start_val) & (sizes <= end_val)
573
+ sizes[mask] = end_val
574
+ start_val = end_val
575
+ return sizes
576
+
577
+
578
+
579
+ def _find_extra_valid_paths(dataset_path: str) -> set:
580
+ paths = utils.split_paths(dataset_path)
581
+ all_valid_paths = set()
582
+ for sub_dir in paths:
583
+ contents = PathManager.ls(sub_dir)
584
+ valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
585
+ all_valid_paths |= {os.path.basename(p) for p in valid_paths}
586
+ # Remove .bin, .idx etc
587
+ roots = {os.path.splitext(p)[0] for p in all_valid_paths}
588
+ return roots
589
+
590
+
591
+ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
592
+ """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
593
+ if (
594
+ train_cfg.dataset.ignore_unused_valid_subsets
595
+ or train_cfg.dataset.combine_valid_subsets
596
+ or train_cfg.dataset.disable_validation
597
+ or not hasattr(train_cfg.task, "data")
598
+ ):
599
+ return
600
+ other_paths = _find_extra_valid_paths(train_cfg.task.data)
601
+ specified_subsets = train_cfg.dataset.valid_subset.split(",")
602
+ ignored_paths = [p for p in other_paths if p not in specified_subsets]
603
+ if ignored_paths:
604
+ advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
605
+ msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
606
+ raise ValueError(msg)
data/file_dataset.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from OFA (https://github.com/OFA-Sys/OFA)
3
+ # Copyright 2022 The OFA-Sys Team.
4
+ # All rights reserved.
5
+ # This source code is licensed under the Apache 2.0 license
6
+ # found in the LICENSE file in the root directory.
7
+ # ------------------------------------------------------------------------
8
+ # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+
11
+ import os
12
+ import torch
13
+ import pickle
14
+
15
+
16
+ class FileDataset:
17
+ def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
18
+ self.file_path = file_path
19
+ assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)
20
+
21
+ self.separator = separator
22
+ if selected_col_ids is None:
23
+ # default to all fields
24
+ self.selected_col_ids = list(
25
+ range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
26
+ else:
27
+ self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
28
+ if dtypes is None:
29
+ # default to str
30
+ self.dtypes = [str for col_id in self.selected_col_ids]
31
+ else:
32
+ self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
33
+ assert len(self.dtypes) == len(self.selected_col_ids)
34
+
35
+ self.data_cnt = 0
36
+ try:
37
+ self.slice_id = torch.distributed.get_rank()
38
+ self.slice_count = torch.distributed.get_world_size()
39
+ except Exception:
40
+ self.slice_id = 0
41
+ self.slice_count = 1
42
+ self.cached_index = cached_index
43
+ self._init_seek_index()
44
+ self._reader = self._get_reader()
45
+ print("file {} slice_id {} row count {} total row count {}".format(
46
+ self.file_path, self.slice_id, self.row_count, self.total_row_count)
47
+ )
48
+
49
+ def _init_seek_index(self):
50
+ if self.cached_index:
51
+ cache_path = "{}.index".format(self.file_path)
52
+ assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
53
+ self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
54
+ print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
55
+ self.file_path, self.slice_id))
56
+ else:
57
+ # make an iteration over the file to get row_count and line_idx-to-offset mapping
58
+ fp = open(self.file_path, "r")
59
+ print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
60
+ self.file_path, self.slice_id))
61
+ self.total_row_count = 0
62
+ offset = 0
63
+ self.lineid_to_offset = []
64
+ for line in fp:
65
+ self.lineid_to_offset.append(offset)
66
+ self.total_row_count += 1
67
+ offset += len(line.encode('utf-8'))
68
+ self._compute_start_pos_and_row_count()
69
+ print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
70
+ self.file_path, self.slice_id))
71
+
72
+ def _compute_start_pos_and_row_count(self):
73
+ self.row_count = self.total_row_count // self.slice_count
74
+ if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
75
+ self.row_count += 1
76
+ self.start_pos = self.row_count * self.slice_id
77
+ else:
78
+ self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
79
+
80
+ def _get_reader(self):
81
+ fp = open(self.file_path, "r")
82
+ fp.seek(self.lineid_to_offset[self.start_pos])
83
+ return fp
84
+
85
+ def _seek(self, offset=0):
86
+ try:
87
+ print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
88
+ self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
89
+ self.data_cnt = offset
90
+ except Exception:
91
+ print("slice_id {} seek offset {}".format(self.slice_id, offset))
92
+ self._reader.seek(self.lineid_to_offset[offset])
93
+ self.data_cnt = offset
94
+
95
+ def __del__(self):
96
+ self._reader.close()
97
+
98
+ def __len__(self):
99
+ return self.row_count
100
+
101
+ def get_total_row_count(self):
102
+ return self.total_row_count
103
+
104
+ def __getitem__(self, index):
105
+ if self.data_cnt == self.row_count:
106
+ print("reach the end of datafile, start a new reader")
107
+ self.data_cnt = 0
108
+ self._reader = self._get_reader()
109
+ column_l = self._reader.readline().rstrip("\n").split(self.separator)
110
+ self.data_cnt += 1
111
+ column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
112
+ return column_l
data/poly_utils.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ from itertools import groupby
4
+ from PIL import Image
5
+ import math
6
+ from math import ceil, floor
7
+ from skimage import draw
8
+ from random import sample
9
+ import base64
10
+ from io import BytesIO
11
+
12
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
13
+ natrual_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
14
+
15
+
16
+ def points_to_token_string(box, polygons):
17
+ polygon_strings = []
18
+ for polygon in polygons:
19
+ polygon_string = " ".join([f"<bin_{int(p[0])}_{int(p[1])}>" for p in polygon])
20
+ polygon_strings.append(polygon_string)
21
+ polygon_string = " <separator> ".join(polygon_strings)
22
+ box_string = " ".join([f"<bin_{int(p[0])}_{int(p[1])}>" for p in box])
23
+ token_string = " ".join([box_string, polygon_string])
24
+
25
+ token_type = []
26
+ for token in token_string.split(" "):
27
+ if "bin" in token:
28
+ token_type.append(0) # 0 for coordinate tokens
29
+ else:
30
+ token_type.append(1) # 1 for separator tokens
31
+ return token_string, token_type
32
+
33
+
34
+ def resize_binary_mask(array, new_size):
35
+ image = Image.fromarray(array.astype(np.uint8) * 255)
36
+ image = image.resize(new_size)
37
+ return np.asarray(image).astype(np.bool_)
38
+
39
+
40
+ def close_contour(contour):
41
+ if not np.array_equal(contour[0], contour[-1]):
42
+ contour = np.vstack((contour, contour[0]))
43
+ return contour
44
+
45
+
46
+ def binary_mask_to_rle(binary_mask):
47
+ rle = {'counts': [], 'size': list(binary_mask.shape)}
48
+ counts = rle.get('counts')
49
+ for i, (value, elements) in enumerate(groupby(binary_mask.ravel(order='F'))):
50
+ if i == 0 and value == 1:
51
+ counts.append(0)
52
+ counts.append(len(list(elements)))
53
+
54
+ return rle
55
+
56
+
57
+ def revert_direction(poly):
58
+ poly = np.array(poly).reshape(int(len(poly) / 2), 2)
59
+ poly = poly[::-1, :]
60
+ return list(poly.flatten())
61
+
62
+
63
+ def reorder_points(poly):
64
+ poly = np.array(poly)
65
+ xs = poly[::2]
66
+ ys = poly[1::2]
67
+ points = np.array(poly).reshape(int(len(poly) / 2), 2)
68
+ start = np.argmin(xs ** 2 + ys ** 2) # smallest distance to the origin
69
+ poly_reordered = np.concatenate([points[start:], points[:start]], 0)
70
+ return list(poly_reordered.flatten())
71
+
72
+
73
+ def convert_pts(coeffs):
74
+ pts = []
75
+ for i in range(len(coeffs) // 2):
76
+ pts.append([coeffs[2 * i + 1], coeffs[2 * i]]) # y, x
77
+ return np.array(pts, np.int32)
78
+
79
+
80
+ def get_mask_from_codes(codes, img_size):
81
+ masks = [np.zeros(img_size)]
82
+ for code in codes:
83
+ if len(code) > 0:
84
+ mask = draw.polygon2mask(img_size, convert_pts(code))
85
+ mask = np.array(mask, np.uint8)
86
+ masks.append(mask)
87
+ mask = sum(masks)
88
+ mask = mask > 0
89
+ return mask.astype(np.uint8)
90
+
91
+
92
+ def is_clockwise(poly):
93
+ n = len(poly) // 2
94
+ xs = poly[::2]
95
+ xs.append(xs[0])
96
+ ys = poly[1::2]
97
+ ys.append(ys[0])
98
+ area = 0
99
+ for i in range(n):
100
+ x1, y1 = xs[i], ys[i]
101
+ x2, y2 = xs[i + 1], ys[i + 1]
102
+ area += (x2 - x1) * (y2 + y1)
103
+ return area < 0
104
+
105
+
106
+ def close_polygon_contour(poly):
107
+ poly = np.array(poly).reshape(int(len(poly) / 2), 2)
108
+ x1, y1 = poly[0]
109
+ x2, y2 = poly[-1]
110
+ if x1 != x2:
111
+ poly = np.concatenate([poly, [poly[0]]], 0)
112
+ return list(poly.flatten())
113
+
114
+
115
+ def close_polygons_contour(polygons):
116
+ polygons_closed = []
117
+ for polygon in polygons:
118
+ polygon_closed = close_polygon_contour(polygon)
119
+ polygons_closed.append(polygon_closed)
120
+ return polygons_closed
121
+
122
+
123
+ def image_to_base64(img, format):
124
+ output_buffer = BytesIO()
125
+ img.save(output_buffer, format=format)
126
+ byte_data = output_buffer.getvalue()
127
+ base64_str = base64.b64encode(byte_data)
128
+ base64_str = str(base64_str, encoding='utf-8')
129
+ return base64_str
130
+
131
+
132
+ def process_polygons(polygons, redirection=True, reorder=True, close=False):
133
+ polygons_processed = []
134
+ for polygon in polygons:
135
+ if redirection and not is_clockwise(polygon):
136
+ polygon = revert_direction(polygon)
137
+ if reorder:
138
+ polygon = reorder_points(polygon)
139
+ if close:
140
+ polygon = close_polygon_contour(polygon)
141
+ polygons_processed.append(polygon)
142
+ polygons = sorted(polygons_processed, key=lambda x: (x[0] ** 2 + x[1] ** 2, x[0], x[1]))
143
+ return polygons
144
+
145
+
146
+ def string_to_polygons(pts_strings):
147
+ pts_strings = pts_strings.split(" ")[:-1]
148
+ polygons = []
149
+ for pts_string in pts_strings:
150
+ polygon = pts_string.split(",")
151
+ polygon = [float(p) for p in polygon]
152
+ polygons.append(polygon)
153
+ return polygons
154
+
155
+
156
+ def downsample_polygon(polygon, ds_rate=25):
157
+ points = np.array(polygon).reshape(int(len(polygon) / 2), 2)
158
+ points = points[::ds_rate]
159
+ return list(points.flatten())
160
+
161
+
162
+ def downsample_polygons(polygons, ds_rate=25):
163
+ polygons_ds = []
164
+ for polygon in polygons:
165
+ polygons_ds.append(downsample_polygon(polygon, ds_rate))
166
+ return polygons_ds
167
+
168
+
169
+ def check_length(polygons):
170
+ length = 0
171
+ for polygon in polygons:
172
+ length += len(polygon)
173
+ return length
174
+
175
+
176
+ def approximate_polygon(poly, tolerance=2):
177
+ poly = np.array(poly).reshape(int(len(poly) / 2), 2)
178
+ new_poly = [poly[0]]
179
+ for i in range(1, len(poly)):
180
+ x1, y1 = new_poly[-1]
181
+ x2, y2 = poly[i]
182
+ dist = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
183
+ if dist > tolerance:
184
+ new_poly.append(poly[i])
185
+ new_poly = np.array(new_poly)
186
+ return list(new_poly.flatten())
187
+
188
+
189
+ def approximate_polygons(polys, tolerance=1.0, max_length=400):
190
+ tol = tolerance
191
+ while check_length(polys) > max_length:
192
+ polys_new = []
193
+ for poly in polys:
194
+ polys_new.append(approximate_polygon(poly, tolerance=tol))
195
+ polys = polys_new
196
+ tol += 2.0
197
+ return polys
198
+
199
+
200
+ def random_int(low, high):
201
+ if low < high:
202
+ return np.random.randint(low, high)
203
+ else:
204
+ return max(low, high)
205
+
206
+
207
+ def interpolate_points(ps, pe):
208
+ xs, ys = ps
209
+ xe, ye = pe
210
+ points = []
211
+ dx = xe - xs
212
+ dy = ye - ys
213
+ if dx != 0:
214
+ scale = dy / dx
215
+ if xe > xs:
216
+ x_interpolated = list(range(ceil(xs), floor(xe) + 1))
217
+ else:
218
+ x_interpolated = list(range(floor(xs), ceil(xe) - 1, -1))
219
+ for x in x_interpolated:
220
+ y = ys + (x - xs) * scale
221
+ points.append([x, y])
222
+ if dy != 0:
223
+ scale = dx / dy
224
+ if ye > ys:
225
+ y_interpolated = list(range(ceil(ys), floor(ye) + 1))
226
+ else:
227
+ y_interpolated = list(range(floor(ys), ceil(ye) - 1, -1))
228
+ for y in y_interpolated:
229
+ x = xs + (y - ys) * scale
230
+ points.append([x, y])
231
+ if xe > xs:
232
+ points = sorted(points, key=lambda x: x[0])
233
+ else:
234
+ points = sorted(points, key=lambda x: -x[0])
235
+ return points
236
+
237
+
238
+ def interpolate_polygon(polygon):
239
+ points = np.array(polygon).reshape(int(len(polygon) / 2), 2)
240
+ points_interpolated = []
241
+ points_interpolated.append(points[0])
242
+ for i in range(0, len(points) - 1):
243
+ points_i = interpolate_points(points[i], points[i + 1])
244
+ points_interpolated += points_i
245
+ points_interpolated.append(points[i + 1])
246
+ points_interpolated = prune_points(points_interpolated)
247
+ polygon_interpolated = np.array(points_interpolated)
248
+ return list(polygon_interpolated.flatten())
249
+
250
+
251
+ def prune_points(points, th=0.1):
252
+ points_pruned = [points[0]]
253
+ for i in range(1, len(points)):
254
+ x1, y1 = points_pruned[-1]
255
+ x2, y2 = points[i]
256
+ dist = (x2 - x1) ** 2 + (y2 - y1) ** 2
257
+ if dist > th:
258
+ points_pruned.append(points[i])
259
+ return points_pruned
260
+
261
+
262
+ def interpolate_polygons(polygons):
263
+ polygons_i = []
264
+ for polygon in polygons:
265
+ polygons_i.append(interpolate_polygon(polygon))
266
+ return polygons_i
267
+
268
+
269
+ def sample_polygon(polygon, sample_rate=0.5):
270
+ points = np.array(polygon).reshape(int(len(polygon) / 2), 2)
271
+ k = int(len(points) * sample_rate)
272
+ index = sorted(sample(list(range(len(points))), k))
273
+ points_sampled = points[index]
274
+ return list(np.array(points_sampled).flatten())
275
+
276
+
277
+ def sample_polygons(polygons, max_length=400.0):
278
+ n = check_length(polygons)
279
+ k = max_length / n
280
+ polygons_s = []
281
+ for polygon in polygons:
282
+ polygons_s.append(sample_polygon(polygon, k))
283
+ return polygons_s
284
+
285
+
286
+ def polygons_to_string(polygons):
287
+ pts_strings = []
288
+ for polygon in polygons:
289
+ pts_string = ','.join([str(num) for num in polygon])
290
+ pts_string += " " # separator
291
+ pts_strings.append(pts_string)
292
+ pts_strings = "".join(pts_strings)
293
+ return pts_strings
294
+
data/refcoco_dataset.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from OFA (https://github.com/OFA-Sys/OFA)
3
+ # Copyright 2022 The OFA-Sys Team.
4
+ # All rights reserved.
5
+ # This source code is licensed under the Apache 2.0 license
6
+ # found in the LICENSE file in the root directory.
7
+ # ------------------------------------------------------------------------
8
+ # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+
11
+ from io import BytesIO
12
+
13
+ import logging
14
+ import warnings
15
+
16
+ import numpy as np
17
+ import torch
18
+ import base64
19
+ import utils.transforms as T
20
+ import math
21
+ from PIL import Image, ImageFile
22
+
23
+ from data import data_utils
24
+ from data.base_dataset import BaseDataset
25
+ from bert.tokenization_bert import BertTokenizer
26
+ from data.poly_utils import string_to_polygons, downsample_polygons, polygons_to_string, points_to_token_string
27
+ import cv2
28
+
29
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
30
+ ImageFile.MAX_IMAGE_PIXELS = None
31
+ Image.MAX_IMAGE_PIXELS = None
32
+
33
+ logger = logging.getLogger(__name__)
34
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
35
+
36
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
37
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
38
+
39
+
40
+ class RefcocoDataset(BaseDataset):
41
+ def __init__(
42
+ self,
43
+ split,
44
+ dataset,
45
+ bpe,
46
+ src_dict,
47
+ tgt_dict=None,
48
+ max_src_length=80,
49
+ max_tgt_length=30,
50
+ patch_image_size=512,
51
+ imagenet_default_mean_and_std=False,
52
+ num_bins=1000,
53
+ max_image_size=512
54
+ ):
55
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
56
+ self.max_src_length = max_src_length
57
+ self.max_tgt_length = max_tgt_length
58
+ self.patch_image_size = patch_image_size
59
+ self.num_bins = num_bins
60
+
61
+ if imagenet_default_mean_and_std:
62
+ mean = IMAGENET_DEFAULT_MEAN
63
+ std = IMAGENET_DEFAULT_STD
64
+ else:
65
+ mean = [0.5, 0.5, 0.5]
66
+ std = [0.5, 0.5, 0.5]
67
+
68
+ # for positioning
69
+ self.positioning_transform = T.Compose([
70
+ T.RandomResize([patch_image_size], max_size=patch_image_size),
71
+ T.ToTensor(),
72
+ T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
73
+ ])
74
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
75
+
76
+ def __getitem__(self, index):
77
+ data = self.dataset[index]
78
+ if len(data) == 7:
79
+ uniq_id, base64_str, seg64_str, text, poly_original, region_coord, poly_interpolated = data
80
+ train = True
81
+ else:
82
+ uniq_id, base64_str, seg64_str, text, poly, region_coord = data
83
+ train = False
84
+
85
+ # load image and segmentation labels
86
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
87
+ label = Image.open(BytesIO(base64.urlsafe_b64decode(seg64_str)))
88
+ label = np.asarray(label)
89
+ label = cv2.resize(label, [self.patch_image_size, self.patch_image_size], interpolation=cv2.INTER_NEAREST)
90
+
91
+ w, h = image.size
92
+ patch_image = self.positioning_transform(image, target=None)
93
+ resize_h = self.patch_image_size
94
+ resize_w = self.patch_image_size
95
+ patch_mask = torch.tensor([True])
96
+
97
+ if train:
98
+ prob = np.random.uniform()
99
+ if prob < 0.5:
100
+ polygons_interpolated = string_to_polygons(poly_interpolated)
101
+ ds_rate = np.random.randint(25, 41)
102
+ polygons_augmented = downsample_polygons(polygons_interpolated, ds_rate)
103
+ poly = polygons_to_string(polygons_augmented)
104
+ else:
105
+ poly = poly_original
106
+
107
+ polygons = string_to_polygons(poly)
108
+ polygons_scaled = []
109
+ for polygon in polygons:
110
+ n_point = len(polygon) // 2
111
+ scale = np.concatenate([np.array([w, h]) for _ in range(n_point)], 0)
112
+ polygon = polygon / scale
113
+ polygon = polygon.reshape(n_point, 2)
114
+ polygons_scaled.append(polygon)
115
+
116
+ x0, y0, x1, y1 = region_coord.strip().split(',')
117
+ region_points = [float(x0), float(y0), float(x1), float(y1)]
118
+ region = np.array(region_points)
119
+
120
+ region_points = region_points / np.array([w, h, w, h]) # scaled to [0,1]
121
+ region_points = torch.tensor(region_points.reshape(2, 2))
122
+
123
+ quant_box = region_points * (self.num_bins - 1)
124
+ quant_box11 = [[math.floor(p[0]), math.floor(p[1])] for p in quant_box]
125
+ quant_box21 = [[math.ceil(p[0]), math.floor(p[1])] for p in quant_box]
126
+ quant_box12 = [[math.floor(p[0]), math.ceil(p[1])] for p in quant_box]
127
+ quant_box22 = [[math.ceil(p[0]), math.ceil(p[1])] for p in quant_box]
128
+
129
+ quant_poly = [poly * (self.num_bins - 1) for poly in polygons_scaled]
130
+ quant_poly11 = [[[math.floor(p[0]), math.floor(p[1])] for p in poly] for poly in quant_poly]
131
+ quant_poly21 = [[[math.ceil(p[0]), math.floor(p[1])] for p in poly] for poly in quant_poly]
132
+ quant_poly12 = [[[math.floor(p[0]), math.ceil(p[1])] for p in poly] for poly in quant_poly]
133
+ quant_poly22 = [[[math.ceil(p[0]), math.ceil(p[1])] for p in poly] for poly in quant_poly]
134
+
135
+ region_coord11, _ = points_to_token_string(quant_box11, quant_poly11)
136
+ region_coord21, _ = points_to_token_string(quant_box21, quant_poly21)
137
+ region_coord12, _ = points_to_token_string(quant_box12, quant_poly12)
138
+ region_coord22, token_type = points_to_token_string(quant_box22, quant_poly22)
139
+
140
+ # compute bilinear interpolation coefficient
141
+ delta_x1 = [0] + [p[0] - math.floor(p[0]) for p in quant_box] # [0] for bos token
142
+ for polygon in quant_poly:
143
+ delta = [poly_point[0] - math.floor(poly_point[0]) for poly_point in polygon]
144
+ delta_x1.extend(delta)
145
+ delta_x1.extend([0]) # for separator token
146
+ delta_x1 = delta_x1[:-1] # there is no separator token in the end
147
+ delta_x1 = torch.tensor(delta_x1)
148
+ delta_x2 = 1 - delta_x1
149
+
150
+ delta_y1 = [0] + [p[1] - math.floor(p[1]) for p in quant_box] # [0] for bos token
151
+ for polygon in quant_poly:
152
+ delta = [poly_point[1] - math.floor(poly_point[1]) for poly_point in polygon]
153
+ delta_y1.extend(delta)
154
+ delta_y1.extend([0]) # for separator token
155
+ delta_y1 = delta_y1[:-1] # there is no separator token in the end
156
+ delta_y1 = torch.tensor(delta_y1)
157
+ delta_y2 = 1 - delta_y1
158
+
159
+ token_type.append(2) # 2 for eos token
160
+
161
+ src_caption = self.pre_caption(text, self.max_src_length)
162
+
163
+ prompt = ' which region does the text " {} " describe?'.format(src_caption)
164
+
165
+ # tgt for input
166
+ tgt_item11 = self.encode_text(region_coord11, use_bpe=False)
167
+ tgt_item12 = self.encode_text(region_coord12, use_bpe=False)
168
+ tgt_item21 = self.encode_text(region_coord21, use_bpe=False)
169
+ tgt_item22 = self.encode_text(region_coord22, use_bpe=False)
170
+
171
+ # tgt for output
172
+ target_item = region_points
173
+ for poly in polygons_scaled:
174
+ target_item = torch.cat([target_item, torch.tensor(poly), torch.tensor([[0, 0]])], dim=0) # [0, 0] is padding token for separator and eos
175
+
176
+ #target_item = torch.cat([tgt_item, self.eos_item])
177
+ prev_output_item11 = torch.cat([self.bos_item, tgt_item11])
178
+ prev_output_item12 = torch.cat([self.bos_item, tgt_item12])
179
+ prev_output_item21 = torch.cat([self.bos_item, tgt_item21])
180
+ prev_output_item22 = torch.cat([self.bos_item, tgt_item22])
181
+ example = {
182
+ "id": uniq_id,
183
+ "source": prompt,
184
+ "patch_image": patch_image,
185
+ "patch_mask": patch_mask,
186
+ "target": target_item,
187
+ "prev_output_tokens_11": prev_output_item11,
188
+ "prev_output_tokens_12": prev_output_item12,
189
+ "prev_output_tokens_21": prev_output_item21,
190
+ "prev_output_tokens_22": prev_output_item22,
191
+ "delta_x1": delta_x1,
192
+ "delta_y1": delta_y1,
193
+ "delta_x2": delta_x2,
194
+ "delta_y2": delta_y2,
195
+ "w_resize_ratio": torch.tensor(resize_w / w),
196
+ "h_resize_ratio": torch.tensor(resize_h / h),
197
+ "region_coord": torch.tensor(region),
198
+ "token_type": torch.tensor(token_type),
199
+ "w": torch.tensor(w),
200
+ "h": torch.tensor(h),
201
+ "label": label,
202
+ "n_poly": len(polygons),
203
+ "text": src_caption
204
+ }
205
+ return example
206
+
207
+ def collate(self, samples, pad_idx, eos_idx):
208
+ if len(samples) == 0:
209
+ return {}
210
+
211
+ def merge(key, padding_item):
212
+ return data_utils.collate_tokens(
213
+ [s[key] for s in samples],
214
+ padding_item,
215
+ eos_idx=eos_idx,
216
+ )
217
+
218
+ id = np.array([s["id"] for s in samples])
219
+ captions = [s["source"] for s in samples]
220
+ tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
221
+ src_tokens = tokenized["input_ids"]
222
+ att_masks = tokenized["attention_mask"]
223
+ src_lengths = torch.LongTensor(att_masks.ne(0).long().sum())
224
+
225
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
226
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
227
+
228
+ w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
229
+ h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
230
+
231
+ delta_x1 = merge("delta_x1", 0)
232
+ delta_y1 = merge("delta_y1", 0)
233
+ delta_x2 = merge("delta_x2", 1)
234
+ delta_y2 = merge("delta_y2", 1)
235
+
236
+ region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
237
+
238
+ target = merge("target", pad_idx)
239
+ tgt_lengths = torch.LongTensor([s["target"].shape[0] for s in samples])
240
+ ntokens = tgt_lengths.sum().item()
241
+
242
+ prev_output_tokens_11 = merge("prev_output_tokens_11", pad_idx)
243
+ prev_output_tokens_12 = merge("prev_output_tokens_12", pad_idx)
244
+ prev_output_tokens_21 = merge("prev_output_tokens_21", pad_idx)
245
+ prev_output_tokens_22 = merge("prev_output_tokens_22", pad_idx)
246
+
247
+ token_type = merge("token_type", -1)
248
+ w = torch.stack([s["w"] for s in samples], dim=0)
249
+ h = torch.stack([s["h"] for s in samples], dim=0)
250
+ n_poly = [s['n_poly'] for s in samples]
251
+
252
+ labels = np.stack([sample['label'] for sample in samples], 0)
253
+ text = [s["text"] for s in samples]
254
+ batch = {
255
+ "id": id,
256
+ "nsentences": len(samples),
257
+ "ntokens": ntokens,
258
+ "net_input": {
259
+ "src_tokens": src_tokens,
260
+ "src_lengths": src_lengths,
261
+ "att_masks": att_masks,
262
+ "patch_images": patch_images,
263
+ "patch_masks": patch_masks,
264
+ "prev_output_tokens_11": prev_output_tokens_11,
265
+ "prev_output_tokens_12": prev_output_tokens_12,
266
+ "prev_output_tokens_21": prev_output_tokens_21,
267
+ "prev_output_tokens_22": prev_output_tokens_22,
268
+ "delta_x1": delta_x1,
269
+ "delta_y1": delta_y1,
270
+ "delta_x2": delta_x2,
271
+ "delta_y2": delta_y2
272
+ },
273
+ "target": target,
274
+ "w_resize_ratios": w_resize_ratios,
275
+ "h_resize_ratios": h_resize_ratios,
276
+ "region_coords": region_coords,
277
+ "label": labels,
278
+ "token_type": token_type,
279
+ "w": w,
280
+ "h": h,
281
+ "n_poly": n_poly,
282
+ "text": text
283
+ }
284
+
285
+ return batch
286
+
287
+ def collater(self, samples, pad_to_length=None):
288
+ """Merge a list of samples to form a mini-batch.
289
+ Args:
290
+ samples (List[dict]): samples to collate
291
+ Returns:
292
+ dict: a mini-batch containing the data of the task
293
+ """
294
+ return self.collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/refcoco_pretrain_dataset.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from OFA (https://github.com/OFA-Sys/OFA)
3
+ # Copyright 2022 The OFA-Sys Team.
4
+ # All rights reserved.
5
+ # This source code is licensed under the Apache 2.0 license
6
+ # found in the LICENSE file in the root directory.
7
+ # ------------------------------------------------------------------------
8
+ # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+
11
+ from io import BytesIO
12
+
13
+ import logging
14
+ import warnings
15
+
16
+ import numpy as np
17
+ import torch
18
+ import base64
19
+ import utils.transforms as T
20
+ import math
21
+ import os
22
+ from PIL import Image, ImageFile
23
+
24
+ from data import data_utils
25
+ from data.base_dataset import BaseDataset
26
+ from bert.tokenization_bert import BertTokenizer
27
+
28
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
29
+ ImageFile.MAX_IMAGE_PIXELS = None
30
+ Image.MAX_IMAGE_PIXELS = None
31
+
32
+ logger = logging.getLogger(__name__)
33
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
34
+
35
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
36
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
37
+
38
+
39
+ class RefcocoPretrainDataset(BaseDataset):
40
+ def __init__(
41
+ self,
42
+ split,
43
+ dataset,
44
+ bpe,
45
+ src_dict,
46
+ tgt_dict=None,
47
+ max_src_length=80,
48
+ max_tgt_length=30,
49
+ patch_image_size=512,
50
+ imagenet_default_mean_and_std=False,
51
+ num_bins=1000,
52
+ max_image_size=512,
53
+ image_path="../../datasets/images"
54
+ ):
55
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
56
+ self.max_src_length = max_src_length
57
+ self.max_tgt_length = max_tgt_length
58
+ self.patch_image_size = patch_image_size
59
+ self.num_bins = num_bins
60
+ self.image_path = image_path
61
+
62
+ if imagenet_default_mean_and_std:
63
+ mean = IMAGENET_DEFAULT_MEAN
64
+ std = IMAGENET_DEFAULT_STD
65
+ else:
66
+ mean = [0.5, 0.5, 0.5]
67
+ std = [0.5, 0.5, 0.5]
68
+
69
+ # for positioning
70
+ self.positioning_transform = T.Compose([
71
+ T.RandomResize([patch_image_size], max_size=patch_image_size),
72
+ T.ToTensor(),
73
+ T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
74
+ ])
75
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
76
+
77
+ def __getitem__(self, index):
78
+ uniq_id, img_file, text, region_coord = self.dataset[index]
79
+
80
+ img_path = os.path.join(self.image_path, img_file)
81
+ image = Image.open(img_path).convert("RGB")
82
+
83
+ w, h = image.size
84
+ boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
85
+ x0, y0, x1, y1 = region_coord.strip().split(',')
86
+ region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
87
+ boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
88
+ boxes_target["labels"] = np.array([0])
89
+ boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
90
+
91
+ patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
92
+ resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
93
+ patch_mask = torch.tensor([True])
94
+
95
+ quant_box = [patch_boxes["boxes"][0][i] * (self.num_bins - 1) for i in range(4)]
96
+ quant_box = np.array(quant_box).reshape(2, 2)
97
+
98
+ quant_box11 = [[math.floor(p[0]), math.floor(p[1])] for p in quant_box]
99
+ quant_box21 = [[math.ceil(p[0]), math.floor(p[1])] for p in quant_box]
100
+ quant_box12 = [[math.floor(p[0]), math.ceil(p[1])] for p in quant_box]
101
+ quant_box22 = [[math.ceil(p[0]), math.ceil(p[1])] for p in quant_box]
102
+
103
+
104
+ # compute linear interpolation coefficient (0 for bos token)
105
+ delta_x1 = torch.tensor([0] + [p[0] - math.floor(p[0]) for p in quant_box])
106
+ delta_y1 = torch.tensor([0] + [p[1] - math.floor(p[1]) for p in quant_box])
107
+ delta_x2 = 1 - delta_x1
108
+ delta_y2 = 1 - delta_y1
109
+
110
+ region_coord11 = " ".join([f"<bin_{int(p[0])}_{int(p[1])}>" for p in quant_box11])
111
+ region_coord21 = " ".join([f"<bin_{int(p[0])}_{int(p[1])}>" for p in quant_box21])
112
+ region_coord12 = " ".join([f"<bin_{int(p[0])}_{int(p[1])}>" for p in quant_box12])
113
+ region_coord22 = " ".join([f"<bin_{int(p[0])}_{int(p[1])}>" for p in quant_box22])
114
+
115
+ src_caption = self.pre_caption(text, self.max_src_length)
116
+
117
+ prompt = ' which region does the text " {} " describe?'.format(src_caption)
118
+
119
+ # tgt for input
120
+ tgt_item11 = self.encode_text(region_coord11, use_bpe=False)
121
+ tgt_item12 = self.encode_text(region_coord12, use_bpe=False)
122
+ tgt_item21 = self.encode_text(region_coord21, use_bpe=False)
123
+ tgt_item22 = self.encode_text(region_coord22, use_bpe=False)
124
+
125
+ # tgt for output
126
+ tgt_box = torch.reshape(patch_boxes["boxes"][0], (2, 2))
127
+ target_item = torch.cat([tgt_box, torch.tensor([[1, 1]])], dim=0) # [1, 1] is padding token for eos
128
+
129
+ #target_item = torch.cat([tgt_item, self.eos_item])
130
+ prev_output_item11 = torch.cat([self.bos_item, tgt_item11])
131
+ prev_output_item12 = torch.cat([self.bos_item, tgt_item12])
132
+ prev_output_item21 = torch.cat([self.bos_item, tgt_item21])
133
+ prev_output_item22 = torch.cat([self.bos_item, tgt_item22])
134
+ example = {
135
+ "id": uniq_id,
136
+ "source": prompt,
137
+ "patch_image": patch_image,
138
+ "patch_mask": patch_mask,
139
+ "target": target_item,
140
+ "prev_output_tokens_11": prev_output_item11,
141
+ "prev_output_tokens_12": prev_output_item12,
142
+ "prev_output_tokens_21": prev_output_item21,
143
+ "prev_output_tokens_22": prev_output_item22,
144
+ "delta_x1": delta_x1,
145
+ "delta_y1": delta_y1,
146
+ "delta_x2": delta_x2,
147
+ "delta_y2": delta_y2,
148
+ "w_resize_ratio": resize_w / w,
149
+ "h_resize_ratio": resize_h / h,
150
+ "region_coord": region,
151
+ "token_type": torch.tensor([0, 0, 2])
152
+ }
153
+ return example
154
+
155
+ def collate(self, samples, pad_idx, eos_idx):
156
+ if len(samples) == 0:
157
+ return {}
158
+
159
+ def merge(key):
160
+ return data_utils.collate_tokens(
161
+ [s[key] for s in samples],
162
+ pad_idx,
163
+ eos_idx=eos_idx,
164
+ )
165
+
166
+ id = np.array([s["id"] for s in samples])
167
+ captions = [s["source"] for s in samples]
168
+ tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
169
+ src_tokens = tokenized["input_ids"]
170
+ att_masks = tokenized["attention_mask"]
171
+ src_lengths = torch.LongTensor(att_masks.ne(0).long().sum())
172
+
173
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
174
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
175
+
176
+ w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
177
+ h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
178
+
179
+ delta_x1 = torch.stack([s["delta_x1"] for s in samples], dim=0)
180
+ delta_y1 = torch.stack([s["delta_y1"] for s in samples], dim=0)
181
+ delta_x2 = torch.stack([s["delta_x2"] for s in samples], dim=0)
182
+ delta_y2 = torch.stack([s["delta_y2"] for s in samples], dim=0)
183
+
184
+ region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
185
+
186
+ target = merge("target")
187
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
188
+ ntokens = tgt_lengths.sum().item()
189
+
190
+ prev_output_tokens_11 = merge("prev_output_tokens_11")
191
+ prev_output_tokens_12 = merge("prev_output_tokens_12")
192
+ prev_output_tokens_21 = merge("prev_output_tokens_21")
193
+ prev_output_tokens_22 = merge("prev_output_tokens_22")
194
+
195
+ token_type = merge("token_type")
196
+
197
+ batch = {
198
+ "id": id,
199
+ "nsentences": len(samples),
200
+ "ntokens": ntokens,
201
+ "net_input": {
202
+ "src_tokens": src_tokens,
203
+ "src_lengths": src_lengths,
204
+ "att_masks": att_masks,
205
+ "patch_images": patch_images,
206
+ "patch_masks": patch_masks,
207
+ "prev_output_tokens_11": prev_output_tokens_11,
208
+ "prev_output_tokens_12": prev_output_tokens_12,
209
+ "prev_output_tokens_21": prev_output_tokens_21,
210
+ "prev_output_tokens_22": prev_output_tokens_22,
211
+ "delta_x1": delta_x1,
212
+ "delta_y1": delta_y1,
213
+ "delta_x2": delta_x2,
214
+ "delta_y2": delta_y2
215
+ },
216
+ "target": target,
217
+ "token_type": token_type,
218
+ "w_resize_ratios": w_resize_ratios,
219
+ "h_resize_ratios": h_resize_ratios,
220
+ "region_coords": region_coords
221
+ }
222
+
223
+ return batch
224
+
225
+ def collater(self, samples, pad_to_length=None):
226
+ """Merge a list of samples to form a mini-batch.
227
+ Args:
228
+ samples (List[dict]): samples to collate
229
+ Returns:
230
+ dict: a mini-batch containing the data of the task
231
+ """
232
+ return self.collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/val_test_files.p ADDED
Binary file (152 kB). View file
 
demo.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from fairseq import utils,tasks
4
+ from utils.checkpoint_utils import load_model_ensemble_and_task
5
+ from utils.eval_utils import eval_step
6
+ from tasks.refcoco import RefcocoTask
7
+ from models.polyformer import PolyFormerModel
8
+ from PIL import Image
9
+ import cv2
10
+ import math
11
+ from skimage import draw
12
+
13
+
14
+ tasks.register_task('refcoco', RefcocoTask)
15
+
16
+ # turn on cuda if GPU is available
17
+ use_cuda = torch.cuda.is_available()
18
+ # use fp16 only when GPU is available
19
+ use_fp16 = True
20
+
21
+ # Load pretrained ckpt & config
22
+ overrides={"bpe_dir":"utils/BPE"}
23
+ models, cfg, task = load_model_ensemble_and_task(
24
+ utils.split_paths('weights/polyformer_l_refcocog.pt'),
25
+ arg_overrides=overrides
26
+ )
27
+ # print(cfg)
28
+ cfg.common.seed = 7
29
+ cfg.generation.beam = 5
30
+ cfg.generation.min_len = 12
31
+ cfg.generation.max_len_a = 0
32
+ cfg.generation.max_len_b = 420
33
+ cfg.generation.no_repeat_ngram_size = 3
34
+ # cfg.max_tgt_length = 256
35
+ #cfg.num_bins = 1000
36
+ cfg.task.patch_image_size = 512
37
+
38
+ from bert.tokenization_bert import BertTokenizer
39
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
40
+
41
+ # Fix seed for stochastic decoding
42
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
43
+ np.random.seed(cfg.common.seed)
44
+ utils.set_torch_seed(cfg.common.seed)
45
+
46
+ # model = ''
47
+ # Move models to GPU
48
+ for model in models:
49
+ model.eval()
50
+ if use_fp16:
51
+ model.half()
52
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
53
+ model.cuda()
54
+ model.prepare_for_inference_(cfg)
55
+
56
+ # Initialize generator
57
+ generator = task.build_generator(models, cfg.generation)
58
+
59
+
60
+ # Image transform
61
+ from torchvision import transforms
62
+ mean = [0.5, 0.5, 0.5]
63
+ std = [0.5, 0.5, 0.5]
64
+
65
+ patch_resize_transform = transforms.Compose([
66
+ lambda image: image.convert("RGB"),
67
+ transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize(mean=mean, std=std),
70
+ ])
71
+
72
+ # Text preprocess
73
+ bos_item = torch.LongTensor([task.src_dict.bos()])
74
+ eos_item = torch.LongTensor([task.src_dict.eos()])
75
+ pad_idx = task.src_dict.pad()
76
+
77
+
78
+ # Construct input for refcoco task
79
+ patch_image_size = cfg.task.patch_image_size
80
+ def construct_sample(image: Image, text: str):
81
+ w, h = image.size
82
+ w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
83
+ h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
84
+ patch_image = patch_resize_transform(image).unsqueeze(0)
85
+ patch_mask = torch.tensor([True])
86
+
87
+ prompt = ' which region does the text " {} " describe?'.format(text)
88
+ tokenized = tokenizer.batch_encode_plus([prompt], padding="longest", return_tensors="pt")
89
+ src_tokens = tokenized["input_ids"]
90
+ att_masks = tokenized["attention_mask"]
91
+ src_lengths = torch.LongTensor(att_masks.ne(0).long().sum())
92
+
93
+ sample = {
94
+ "id":np.array(['42']),
95
+ "net_input": {
96
+ "src_tokens": src_tokens,
97
+ "src_lengths": src_lengths,
98
+ "att_masks": att_masks,
99
+ "patch_images": patch_image,
100
+ "patch_masks": patch_mask,
101
+ },
102
+ "w_resize_ratios": w_resize_ratio,
103
+ "h_resize_ratios": h_resize_ratio,
104
+ "region_coords": torch.randn(1, 4),
105
+ "label": np.zeros((512,512)),
106
+ "poly": 'None',
107
+ "text": text
108
+ }
109
+ return sample
110
+
111
+ # Function to turn FP32 to FP16
112
+ def apply_half(t):
113
+ if t.dtype is torch.float32:
114
+ return t.to(dtype=torch.half)
115
+ return t
116
+
117
+
118
+ from io import BytesIO
119
+ import base64
120
+ import re
121
+
122
+ def pre_caption(caption):
123
+ caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
124
+
125
+ caption = re.sub(
126
+ r"\s{2,}",
127
+ ' ',
128
+ caption,
129
+ )
130
+ caption = caption.rstrip('\n')
131
+ caption = caption.strip(' ')
132
+ return caption
133
+
134
+
135
+ def convert_pts(coeffs):
136
+ pts = []
137
+ for i in range(len(coeffs) // 2):
138
+ pts.append([coeffs[2 * i + 1], coeffs[2 * i]]) # y, x
139
+ return np.array(pts, np.int32)
140
+
141
+ def get_mask_from_codes(codes, img_size):
142
+ masks = [np.zeros(img_size)]
143
+ for code in codes:
144
+ mask = draw.polygon2mask(img_size, convert_pts(code))
145
+ mask = np.array(mask, np.uint8)
146
+ masks.append(mask)
147
+ mask = sum(masks)
148
+ mask = mask > 0
149
+ return mask.astype(np.uint8)
150
+
151
+
152
+ def overlay_predictions(img, mask=None, polygons=None, bbox=None, color_box=(0, 255, 0), color_mask=[255, 102, 102], color_poly=[255, 0, 0], thickness=3, radius=6):
153
+ overlayed = img.copy()
154
+ if bbox is not None:
155
+ overlayed = draw_bbox(overlayed, bbox, color=color_box, thickness=thickness)
156
+ if mask is not None:
157
+ overlayed = overlay_davis(overlayed, mask, colors=[[0, 0, 0], color_mask])
158
+ if polygons is not None:
159
+ overlayed = plot_polygons(overlayed, polygons, color=color_poly, radius=radius)
160
+ return overlayed
161
+
162
+
163
+ def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 102, 102]], cscale=1, alpha=0.4): # [255, 178, 102] orange [102, 178, 255] red
164
+ from scipy.ndimage.morphology import binary_dilation
165
+
166
+ colors = np.reshape(colors, (-1, 3))
167
+ colors = np.atleast_2d(colors) * cscale
168
+
169
+ im_overlay = image.copy()
170
+ object_ids = np.unique(mask)
171
+
172
+ h_i, w_i = image.shape[0:2]
173
+ h_m, w_m = mask.shape[0:2]
174
+ if h_i != h_m:
175
+ mask = cv2.resize(mask, [h_i, w_i], interpolation=cv2.INTER_NEAREST)
176
+ for object_id in object_ids[1:]:
177
+ # Overlay color on binary mask
178
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
179
+ binary_mask = mask == object_id
180
+
181
+ # Compose image
182
+ im_overlay[binary_mask] = foreground[binary_mask]
183
+
184
+ return im_overlay.astype(image.dtype)
185
+
186
+
187
+ def draw_bbox(img, box, color=(0, 255, 0), thickness=2):
188
+ x1, y1, x2, y2 = box
189
+ return cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=thickness)
190
+
191
+ def plot_polygons(img, polygons, color=(255, 0, 0), radius=7):
192
+ for polygon in polygons:
193
+ if len(polygon) > 0:
194
+ polygon = np.reshape(polygon[:len(polygon)-len(polygon)%2], (len(polygon)//2, 2)).astype(np.int16)
195
+ for i, point in enumerate(polygon):
196
+ img = cv2.circle(img, point, radius, color, thickness=-1)
197
+ img = cv2.circle(img, polygon[0], radius, color, thickness=-1)
198
+ return img
199
+
200
+ def plot_arrow(img, polygons, color=(128, 128, 128), thickness=3, tip_length=0.3):
201
+ for polygon in polygons:
202
+ if len(polygon) > 0:
203
+ polygon = np.reshape(polygon[:len(polygon)-len(polygon)%2], (len(polygon)//2, 2)).astype(np.int16)
204
+ for i, point in enumerate(polygon):
205
+ if i > 0:
206
+ img = cv2.arrowedLine(img, polygon[i-1], point, color, thickness=thickness, tipLength=tip_length)
207
+ return img
208
+
209
+ def downsample_polygon(polygon, ds_rate=25):
210
+ points = np.array(polygon).reshape(int(len(polygon) / 2), 2)
211
+ points = points[::ds_rate]
212
+ return list(points.flatten())
213
+
214
+
215
+ def downsample_polygons(polygons, ds_rate=25):
216
+ polygons_ds = []
217
+ for polygon in polygons:
218
+ polygons_ds.append(downsample_polygon(polygon, ds_rate))
219
+ return polygons_ds
220
+
221
+
222
+
223
+ def visual_grounding(image, text):
224
+
225
+ # Construct input sample & preprocess for GPU if cuda available
226
+ sample = construct_sample(image, text.lower())
227
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
228
+ sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
229
+
230
+ with torch.no_grad():
231
+ if isinstance(models, list):
232
+ model = models[0]
233
+ min_len = 6
234
+ max_len = 210
235
+ model.eval()
236
+ img = sample["net_input"]["patch_images"]
237
+ b = img.shape[0]
238
+ prev_output_token_11 = [[0] for _ in range(b)]
239
+ prev_output_token_12 = [[0] for _ in range(b)]
240
+ prev_output_token_21 = [[0] for _ in range(b)]
241
+ prev_output_token_22 = [[0] for _ in range(b)]
242
+ delta_x1 = [[0] for _ in range(b)]
243
+ delta_y1 = [[0] for _ in range(b)]
244
+ delta_x2 = [[1] for _ in range(b)]
245
+ delta_y2 = [[1] for _ in range(b)]
246
+
247
+ gen_out = [[] for _ in range(b)]
248
+
249
+ n_bins = 64
250
+
251
+ unfinish_flag = np.ones(b)
252
+ i = 0
253
+
254
+ encoder_out = model.encoder(
255
+ sample['net_input']['src_tokens'],
256
+ src_lengths=sample['net_input']['src_lengths'],
257
+ att_masks=sample['net_input']['att_masks'],
258
+ patch_images=sample['net_input']['patch_images'],
259
+ patch_masks=sample['net_input']['patch_masks'],
260
+ token_embeddings=None,
261
+ return_all_hiddens=False,
262
+ sample_patch_num=None
263
+ )
264
+ attn_masks = []
265
+ while i < max_len and unfinish_flag.any():
266
+ # print(i)
267
+ prev_output_tokens_11_tensor = torch.tensor(np.array(prev_output_token_11)).to(img.device).long()
268
+ prev_output_tokens_12_tensor = torch.tensor(np.array(prev_output_token_12)).to(img.device).long()
269
+ prev_output_tokens_21_tensor = torch.tensor(np.array(prev_output_token_21)).to(img.device).long()
270
+ prev_output_tokens_22_tensor = torch.tensor(np.array(prev_output_token_22)).to(img.device).long()
271
+ delta_x1_tensor = torch.tensor(np.array(delta_x1)).to(img.device)
272
+ delta_x2_tensor = torch.tensor(np.array(delta_x2)).to(img.device)
273
+ delta_y1_tensor = torch.tensor(np.array(delta_y1)).to(img.device)
274
+ delta_y2_tensor = torch.tensor(np.array(delta_y2)).to(img.device)
275
+
276
+ net_output = model.decoder(
277
+ prev_output_tokens_11_tensor,
278
+ prev_output_tokens_12_tensor,
279
+ prev_output_tokens_21_tensor,
280
+ prev_output_tokens_22_tensor,
281
+ delta_x1_tensor,
282
+ delta_y1_tensor,
283
+ delta_x2_tensor,
284
+ delta_y2_tensor,
285
+ code_masks=None,
286
+ encoder_out=encoder_out,
287
+ features_only=False,
288
+ alignment_layer=None,
289
+ alignment_heads=None,
290
+ src_lengths=sample['net_input']['src_lengths'],
291
+ return_all_hiddens=False
292
+ )
293
+
294
+ cls_output = net_output[0]
295
+ cls_type = torch.argmax(cls_output, 2)
296
+ reg_output = net_output[1].squeeze(-1)
297
+ attn = net_output[2]['attn']
298
+ attn_arrays = [att.detach().cpu().numpy() for att in attn]
299
+ attn_arrays = np.concatenate(attn_arrays, 0)
300
+ attn_arrays = np.mean(attn_arrays, 0)
301
+ attn_arrays = attn_arrays[i, :256].reshape(16, 16)
302
+ h, w = image.size
303
+ attn_mask = cv2.resize(attn_arrays.astype(np.float32), (h, w))
304
+ attn_masks.append(attn_mask)
305
+
306
+ for j in range(b):
307
+ # print(j)
308
+ if unfinish_flag[j] == 1: # prediction is not finished
309
+ cls_j = cls_type[j, i].item()
310
+ if cls_j == 0 or (cls_j == 2 and i < min_len): # 0 for coordinate tokens; 2 for eos
311
+ output_j_x, output_j_y = reg_output[j, i].cpu().numpy()
312
+ output_j_x = min(output_j_x, 1)
313
+ output_j_y = min(output_j_y, 1)
314
+
315
+ gen_out[j].extend([output_j_x, output_j_y])
316
+
317
+ output_j_x = output_j_x * (n_bins - 1)
318
+ output_j_y = output_j_y * (n_bins - 1)
319
+
320
+ output_j_x_floor = math.floor(output_j_x)
321
+ output_j_y_floor = math.floor(output_j_y)
322
+ output_j_x_ceil = math.ceil(output_j_x)
323
+ output_j_y_ceil = math.ceil(output_j_y)
324
+
325
+ # convert to token
326
+ prev_output_token_11[j].append(output_j_x_floor * n_bins + output_j_y_floor + 4)
327
+ prev_output_token_12[j].append(output_j_x_floor * n_bins + output_j_y_ceil + 4)
328
+ prev_output_token_21[j].append(output_j_x_ceil * n_bins + output_j_y_floor + 4)
329
+ prev_output_token_22[j].append(output_j_x_ceil * n_bins + output_j_y_ceil + 4)
330
+
331
+ delta_x = output_j_x - output_j_x_floor
332
+ delta_y = output_j_y - output_j_y_floor
333
+ elif cls_j == 1: # 1 for separator tokens
334
+ gen_out[j].append(2) # insert 2 indicating separator tokens
335
+ prev_output_token_11[j].append(3)
336
+ prev_output_token_12[j].append(3)
337
+ prev_output_token_21[j].append(3)
338
+ prev_output_token_22[j].append(3)
339
+ delta_x = 0
340
+ delta_y = 0
341
+ else: # eos is predicted and i >= min_len
342
+ unfinish_flag[j] = 0
343
+ gen_out[j].append(-1)
344
+ prev_output_token_11[j].append(2) # 2 is eos token
345
+ prev_output_token_12[j].append(2) # 2 is eos token
346
+ prev_output_token_21[j].append(2) # 2 is eos token
347
+ prev_output_token_22[j].append(2) # 2 is eos token
348
+ delta_x = 0
349
+ delta_y = 0
350
+ else: # prediction is finished
351
+ gen_out[j].append(-1)
352
+ prev_output_token_11[j].append(1) # 1 is padding token
353
+ prev_output_token_12[j].append(1)
354
+ prev_output_token_21[j].append(1)
355
+ prev_output_token_22[j].append(1)
356
+ delta_x = 0
357
+ delta_y = 0
358
+ delta_x1[j].append(delta_x)
359
+ delta_y1[j].append(delta_y)
360
+ delta_x2[j].append(1 - delta_x)
361
+ delta_y2[j].append(1 - delta_y)
362
+ i += 1
363
+ print("inference step: ", i)
364
+
365
+ hyps = []
366
+ hyps_det = []
367
+ n_poly_pred = []
368
+ b = len(gen_out)
369
+ for i in range(b):
370
+ gen_out_i = np.array(gen_out[i])
371
+ gen_out_i = gen_out_i[gen_out_i != -1] # excluding eos and padding indices
372
+
373
+
374
+ gen_out_i_det = gen_out_i[:4]
375
+ w, h = image.size
376
+ gen_out_i_det[::2] *= w
377
+ gen_out_i_det[1::2] *= h
378
+
379
+ polygons_pred = gen_out_i[4:]
380
+ polygons_pred = np.append(polygons_pred, [2])
381
+ size = len(polygons_pred)
382
+ idx_list = [idx for idx, val in
383
+ enumerate(polygons_pred) if val == 2] # 2 indicates separator token
384
+
385
+ polygons_pred[::2] *= w
386
+ polygons_pred[1::2] *= h
387
+ if len(idx_list) > 0: # multiple polygons
388
+ polygons = []
389
+ pred_idx = 0
390
+ for idx in idx_list:
391
+ cur_idx = idx
392
+ if pred_idx == cur_idx or pred_idx == size:
393
+ pass
394
+ else:
395
+ polygons.append(polygons_pred[pred_idx: cur_idx])
396
+ pred_idx = cur_idx + 1
397
+ else:
398
+ polygons = [polygons_pred]
399
+
400
+ n_poly_pred.append(len(polygons))
401
+ hyps.append(polygons)
402
+ hyps_det.append(gen_out_i_det)
403
+
404
+
405
+ pred_mask = get_mask_from_codes(hyps[0], (h, w))
406
+ pred_overlayed = overlay_predictions(np.asarray(image), pred_mask, hyps[0], hyps_det[0])
407
+
408
+ return pred_overlayed, np.array(pred_mask*255, dtype=np.uint8)
409
+
410
+
evaluate.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright 2022 The OFA-Sys Team.
3
+ # All rights reserved.
4
+ # This source code is licensed under the Apache 2.0 license
5
+ # found in the LICENSE file in the root directory.
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+
11
+ import numpy as np
12
+ import torch
13
+ from fairseq import distributed_utils, options, tasks, utils
14
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
15
+ from fairseq.logging import progress_bar
16
+ from fairseq.utils import reset_logging
17
+ from omegaconf import DictConfig
18
+
19
+ from utils import checkpoint_utils
20
+ from utils.eval_utils import eval_step, merge_results
21
+
22
+ logging.basicConfig(
23
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
24
+ datefmt="%Y-%m-%d %H:%M:%S",
25
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
26
+ stream=sys.stdout,
27
+ )
28
+ logger = logging.getLogger("ofa.evaluate")
29
+
30
+
31
+ def apply_half(t):
32
+ if t.dtype is torch.float32:
33
+ return t.to(dtype=torch.half)
34
+ return t
35
+
36
+
37
+ def main(cfg: DictConfig, **kwargs):
38
+ utils.import_user_module(cfg.common)
39
+
40
+ reset_logging()
41
+ logger.info(cfg)
42
+
43
+ assert (
44
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
45
+ ), "Must specify batch size either with --max-tokens or --batch-size"
46
+
47
+ # Fix seed for stochastic decoding
48
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
49
+ np.random.seed(cfg.common.seed)
50
+ utils.set_torch_seed(cfg.common.seed)
51
+
52
+ use_fp16 = cfg.common.fp16
53
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
54
+
55
+ if use_cuda:
56
+ torch.cuda.set_device(cfg.distributed_training.device_id)
57
+
58
+ # Load ensemble
59
+ overrides = eval(cfg.common_eval.model_overrides)
60
+ # Deal with beam-search / all-candidate VQA eval
61
+ if cfg.task._name == "vqa_gen":
62
+ overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"
63
+
64
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
65
+ if kwargs["zero_shot"]:
66
+ task = tasks.setup_task(cfg.task)
67
+ models, saved_cfg = checkpoint_utils.load_model_ensemble(
68
+ utils.split_paths(cfg.common_eval.path),
69
+ arg_overrides=overrides,
70
+ task=task,
71
+ suffix=cfg.checkpoint.checkpoint_suffix,
72
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
73
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
74
+ )
75
+ else:
76
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
77
+ utils.split_paths(cfg.common_eval.path),
78
+ arg_overrides=overrides,
79
+ suffix=cfg.checkpoint.checkpoint_suffix,
80
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
81
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
82
+ )
83
+
84
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
85
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
86
+
87
+ # Move models to GPU
88
+ for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
89
+ if kwargs['ema_eval']:
90
+ logger.info("loading EMA weights from {}".format(ckpt_path))
91
+ model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
92
+ model.eval()
93
+ if use_fp16:
94
+ model.half()
95
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
96
+ model.cuda()
97
+ model.prepare_for_inference_(cfg)
98
+
99
+ # Load dataset (possibly sharded)
100
+ itr = task.get_batch_iterator(
101
+ dataset=task.dataset(cfg.dataset.gen_subset),
102
+ max_tokens=cfg.dataset.max_tokens,
103
+ max_sentences=cfg.dataset.batch_size,
104
+ max_positions=utils.resolve_max_positions(
105
+ task.max_positions(), *[m.max_positions() for m in models]
106
+ ),
107
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
108
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
109
+ seed=cfg.common.seed,
110
+ num_shards=cfg.distributed_training.distributed_world_size,
111
+ shard_id=cfg.distributed_training.distributed_rank,
112
+ num_workers=cfg.dataset.num_workers,
113
+ data_buffer_size=cfg.dataset.data_buffer_size,
114
+ ).next_epoch_itr(shuffle=False)
115
+ progress = progress_bar.progress_bar(
116
+ itr,
117
+ log_format=cfg.common.log_format,
118
+ log_interval=cfg.common.log_interval,
119
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
120
+ )
121
+
122
+ # Initialize generator
123
+ generator = task.build_generator(models, cfg.generation)
124
+
125
+ # for sample in progress:
126
+ # if "net_input" not in sample:
127
+ # continue
128
+ # sample = utils.move_to_cuda(sample) if use_cuda else sample
129
+ # sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
130
+ # with torch.no_grad():
131
+ # eval_step(task, generator, models, sample, **kwargs)
132
+ # progress.log({"sentences": sample["nsentences"]})
133
+ #
134
+ # merge_results(task, cfg, logger, kwargs['result_dir'])
135
+
136
+ results = []
137
+ prec_list = [.5, .6, .7, .8, .9]
138
+ prec_score_sum = [torch.FloatTensor([0]).cuda() for _ in prec_list]
139
+ f_score_sum = torch.FloatTensor([0]).cuda()
140
+ ap_det_score_sum = torch.FloatTensor([0]).cuda()
141
+ score_sum = torch.FloatTensor([0]).cuda()
142
+ score_cnt = torch.FloatTensor([0]).cuda()
143
+ cum_I_sum = torch.FloatTensor([0]).cuda()
144
+ cum_U_sum = torch.FloatTensor([0]).cuda()
145
+ for sample in progress:
146
+ if "net_input" not in sample:
147
+ continue
148
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
149
+ sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
150
+ with torch.no_grad():
151
+ result, scores, f_scores, ap_scores, cum_I, cum_U = eval_step(task, generator, models, sample, **kwargs)
152
+ results += result
153
+ for prec_score, prec in zip(prec_score_sum, prec_list):
154
+ prec_score += sum(scores >= prec) if scores is not None else 0
155
+ cum_I_sum += sum(cum_I) if scores is not None else 0
156
+ cum_U_sum += sum(cum_U) if scores is not None else 0
157
+ score_sum += sum(scores) if scores is not None else 0
158
+ f_score_sum += sum(f_scores) if scores is not None else 0
159
+ ap_det_score_sum += sum(ap_scores) if scores is not None else 0
160
+ score_cnt += len(scores) if scores is not None else 0
161
+ progress.log({"sentences": sample["nsentences"]})
162
+
163
+ merge_results(task, cfg, logger, score_cnt, score_sum, f_score_sum, ap_det_score_sum,prec_score_sum, cum_I_sum, cum_U_sum, results)
164
+
165
+
166
+ def cli_main():
167
+ parser = options.get_generation_parser()
168
+ parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
169
+ parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
170
+ parser.add_argument("--zero-shot", action='store_true')
171
+ parser.add_argument("--vis_dir", type=str, default=None)
172
+ parser.add_argument("--result_dir", type=str, default=None)
173
+ parser.add_argument("--vis", action='store_true', default=False)
174
+ args = options.parse_args_and_arch(parser)
175
+ cfg = convert_namespace_to_omegaconf(args)
176
+ if args.result_dir is None:
177
+ args.result_dir = args.vis_dir
178
+ distributed_utils.call_main(
179
+ cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot,
180
+ vis_dir=args.vis_dir, vis=args.vis, result_dir=args.result_dir
181
+ )
182
+
183
+
184
+ if __name__ == "__main__":
185
+ cli_main()
fairseq/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
2
+
3
+ Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🐛 Bug Report
3
+ about: Submit a bug report to help us improve
4
+ labels: 'bug, needs triage'
5
+ ---
6
+
7
+ ## 🐛 Bug
8
+
9
+ <!-- A clear and concise description of what the bug is. -->
10
+
11
+ ### To Reproduce
12
+
13
+ Steps to reproduce the behavior (**always include the command you ran**):
14
+
15
+ 1. Run cmd '....'
16
+ 2. See error
17
+
18
+ <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
19
+
20
+
21
+ #### Code sample
22
+ <!-- Ideally attach a minimal code sample to reproduce the decried issue.
23
+ Minimal means having the shortest code but still preserving the bug. -->
24
+
25
+ ### Expected behavior
26
+
27
+ <!-- A clear and concise description of what you expected to happen. -->
28
+
29
+ ### Environment
30
+
31
+ - fairseq Version (e.g., 1.0 or main):
32
+ - PyTorch Version (e.g., 1.0)
33
+ - OS (e.g., Linux):
34
+ - How you installed fairseq (`pip`, source):
35
+ - Build command you used (if compiling from source):
36
+ - Python version:
37
+ - CUDA/cuDNN version:
38
+ - GPU models and configuration:
39
+ - Any other relevant information:
40
+
41
+ ### Additional context
42
+
43
+ <!-- Add any other context about the problem here. -->
fairseq/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 📚 Documentation/Typos
3
+ about: Report an issue related to documentation or a typo
4
+ labels: 'documentation, needs triage'
5
+ ---
6
+
7
+ ## 📚 Documentation
8
+
9
+ For typos and doc fixes, please go ahead and:
10
+
11
+ 1. Create an issue.
12
+ 2. Fix the typo.
13
+ 3. Submit a PR.
14
+
15
+ Thanks!
fairseq/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🚀 Feature Request
3
+ about: Submit a proposal/request for a new feature
4
+ labels: 'enhancement, help wanted, needs triage'
5
+ ---
6
+
7
+ ## 🚀 Feature Request
8
+ <!-- A clear and concise description of the feature proposal -->
9
+
10
+ ### Motivation
11
+
12
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
13
+
14
+ ### Pitch
15
+
16
+ <!-- A clear and concise description of what you want to happen. -->
17
+
18
+ ### Alternatives
19
+
20
+ <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
21
+
22
+ ### Additional context
23
+
24
+ <!-- Add any other context or screenshots about the feature request here. -->
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: ❓ Questions/Help
3
+ about: If you have questions, please first search existing issues and docs
4
+ labels: 'question, needs triage'
5
+ ---
6
+
7
+ ## ❓ Questions and Help
8
+
9
+ ### Before asking:
10
+ 1. search the issues.
11
+ 2. search the docs.
12
+
13
+ <!-- If you still can't find what you need: -->
14
+
15
+ #### What is your question?
16
+
17
+ #### Code
18
+
19
+ <!-- Please paste a code snippet if your question requires it! -->
20
+
21
+ #### What have you tried?
22
+
23
+ #### What's your environment?
24
+
25
+ - fairseq Version (e.g., 1.0 or main):
26
+ - PyTorch Version (e.g., 1.0)
27
+ - OS (e.g., Linux):
28
+ - How you installed fairseq (`pip`, source):
29
+ - Build command you used (if compiling from source):
30
+ - Python version:
31
+ - CUDA/cuDNN version:
32
+ - GPU models and configuration:
33
+ - Any other relevant information:
fairseq/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Before submitting
2
+
3
+ - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
4
+ - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
5
+ - [ ] Did you make sure to update the docs?
6
+ - [ ] Did you write any new necessary tests?
7
+
8
+ ## What does this PR do?
9
+ Fixes # (issue).
10
+
11
+ ## PR review
12
+ Anyone in the community is free to review the PR once the tests have passed.
13
+ If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
14
+
15
+ ## Did you have fun?
16
+ Make sure you had fun coding 🙃
fairseq/.github/stale.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for probot-stale - https://github.com/probot/stale
2
+ # Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
3
+ # Number of days of inactivity before an issue becomes stale
4
+ daysUntilStale: 90
5
+ # Number of days of inactivity before a stale issue is closed
6
+ daysUntilClose: 7
7
+ # Issues with these labels will never be considered stale
8
+ exemptLabels:
9
+ - bug
10
+ # Label to use when marking an issue as stale
11
+ staleLabel: stale
12
+ issues:
13
+ # Comment to post when marking an issue as stale.
14
+ markComment: >
15
+ This issue has been automatically marked as stale.
16
+ **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
17
+ We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
18
+ # Comment to post when closing a stale issue.
19
+ closeComment: >
20
+ Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
21
+ pulls:
22
+ # Comment to post when marking a pull request as stale.
23
+ markComment: >
24
+ This pull request has been automatically marked as stale.
25
+ **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
26
+ We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
27
+ # Comment to post when closing a stale pull request.
28
+ closeComment: >
29
+ Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
30
+
fairseq/.github/workflows/build.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: build
2
+
3
+ on:
4
+ # Trigger the workflow on push to main or any pull request
5
+ push:
6
+ branches:
7
+ - main
8
+ pull_request:
9
+
10
+ jobs:
11
+ build:
12
+
13
+ strategy:
14
+ max-parallel: 4
15
+ matrix:
16
+ platform: [ubuntu-latest, macos-latest]
17
+ python-version: [3.6, 3.7]
18
+
19
+ runs-on: ${{ matrix.platform }}
20
+
21
+ steps:
22
+ - uses: actions/checkout@v2
23
+
24
+ - name: Set up Python ${{ matrix.python-version }}
25
+ uses: actions/setup-python@v2
26
+ with:
27
+ python-version: ${{ matrix.python-version }}
28
+
29
+ - name: Conditionally install pytorch
30
+ if: matrix.platform == 'windows-latest'
31
+ run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
32
+
33
+ - name: Install locally
34
+ run: |
35
+ python -m pip install --upgrade pip
36
+ git submodule update --init --recursive
37
+ python setup.py build_ext --inplace
38
+ python -m pip install --editable .
39
+
40
+ - name: Install optional test requirements
41
+ run: |
42
+ python -m pip install iopath transformers pyarrow
43
+ python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
44
+
45
+ - name: Lint with flake8
46
+ run: |
47
+ pip install flake8
48
+ # stop the build if there are Python syntax errors or undefined names
49
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
50
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
51
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
52
+
53
+ - name: Run tests
54
+ run: |
55
+ python setup.py test
fairseq/.github/workflows/build_wheels.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: build_wheels
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - v[0-9]+.[0-9]+.[x0-9]+
7
+ tags:
8
+ - v*
9
+
10
+ jobs:
11
+ build_wheels:
12
+ name: Build wheels on ${{ matrix.os }}
13
+ runs-on: ${{ matrix.os }}
14
+ strategy:
15
+ matrix:
16
+ os: [ubuntu-latest, macos-latest]
17
+
18
+ steps:
19
+ - uses: actions/checkout@v2
20
+
21
+ - name: Install Python
22
+ uses: actions/setup-python@v2
23
+ with:
24
+ python-version: '3.7'
25
+
26
+ - name: Install cibuildwheel
27
+ run: |
28
+ python -m pip install cibuildwheel
29
+
30
+ - name: Build wheels for CPython
31
+ run: |
32
+ python -m cibuildwheel --output-dir dist
33
+ env:
34
+ CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
35
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
36
+ CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
37
+
38
+ - uses: actions/upload-artifact@v2
39
+ with:
40
+ name: wheels
41
+ path: ./dist/*.whl
fairseq/.gitignore ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # JetBrains PyCharm IDE
2
+ .idea/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # macOS dir files
13
+ .DS_Store
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ env/
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+
34
+ # Checkpoints
35
+ checkpoints
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # pyenv
83
+ .python-version
84
+
85
+ # celery beat schedule file
86
+ celerybeat-schedule
87
+
88
+ # SageMath parsed files
89
+ *.sage.py
90
+
91
+ # dotenv
92
+ .env
93
+
94
+ # virtualenv
95
+ .venv
96
+ venv/
97
+ ENV/
98
+
99
+ # Spyder project settings
100
+ .spyderproject
101
+ .spyproject
102
+
103
+ # Rope project settings
104
+ .ropeproject
105
+
106
+ # mkdocs documentation
107
+ /site
108
+
109
+ # mypy
110
+ .mypy_cache/
111
+
112
+ # Generated files
113
+ /fairseq/temporal_convolution_tbc
114
+ /fairseq/modules/*_layer/*_forward.cu
115
+ /fairseq/modules/*_layer/*_backward.cu
116
+ /fairseq/version.py
117
+
118
+ # data
119
+ data-bin/
120
+
121
+ # reranking
122
+ /examples/reranking/rerank_data
123
+
124
+ # Cython-generated C++ source files
125
+ /fairseq/data/data_utils_fast.cpp
126
+ /fairseq/data/token_block_utils_fast.cpp
127
+
128
+ # VSCODE
129
+ .vscode/ftp-sync.json
130
+ .vscode/settings.json
131
+
132
+ # Experimental Folder
133
+ experimental/*
134
+
135
+ # Weights and Biases logs
136
+ wandb/
fairseq/.gitmodules ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [submodule "fairseq/model_parallel/megatron"]
2
+ path = fairseq/model_parallel/megatron
3
+ url = https://github.com/ngoyal2707/Megatron-LM
4
+ branch = fairseq
fairseq/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <[email protected]>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
77
+
fairseq/CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ ## License
26
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
+ you agree that your contributions will be licensed under the LICENSE file in
28
+ the root directory of this source tree.
fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
fairseq/README.md ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="docs/fairseq_logo.png" width="150">
3
+ <br />
4
+ <br />
5
+ <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
6
+ <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
7
+ <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
8
+ <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
9
+ </p>
10
+
11
+ --------------------------------------------------------------------------------
12
+
13
+ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
14
+ developers to train custom models for translation, summarization, language
15
+ modeling and other text generation tasks.
16
+
17
+ We provide reference implementations of various sequence modeling papers:
18
+
19
+ <details><summary>List of implemented papers</summary><p>
20
+
21
+ * **Convolutional Neural Networks (CNN)**
22
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
23
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
24
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
25
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
26
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
27
+ * **LightConv and DynamicConv models**
28
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
29
+ * **Long Short-Term Memory (LSTM) networks**
30
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
31
+ * **Transformer (self-attention) networks**
32
+ + Attention Is All You Need (Vaswani et al., 2017)
33
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
34
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
35
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
36
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
37
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
38
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
39
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
40
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
41
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
42
+ + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
43
+ + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
44
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
45
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
46
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
47
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
48
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
49
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
50
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
51
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
52
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
53
+ + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
54
+ * **Non-autoregressive Transformers**
55
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
56
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
57
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
58
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
59
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
60
+ * **Finetuning**
61
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
62
+
63
+ </p></details>
64
+
65
+ ### What's New:
66
+
67
+ * September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
68
+ * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
69
+ * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
70
+ * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
71
+ * May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
72
+ * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
73
+ * February 2021 [Added LASER training code](examples/laser/README.md)
74
+ * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
75
+ * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
76
+ * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
77
+ * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
78
+ * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
79
+ * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
80
+ * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
81
+ * October 2020: [Added CRISS models and code](examples/criss/README.md)
82
+
83
+ <details><summary>Previous updates</summary><p>
84
+
85
+ * September 2020: [Added Linformer code](examples/linformer/README.md)
86
+ * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
87
+ * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
88
+ * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
89
+ * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
90
+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
91
+ * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
92
+ * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
93
+ * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
94
+ * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
95
+ * February 2020: [mBART model and code released](examples/mbart/README.md)
96
+ * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
97
+ * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
98
+ * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
99
+ * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
100
+ * November 2019: [BART model and code released](examples/bart/README.md)
101
+ * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
102
+ * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
103
+ * August 2019: [WMT'19 models released](examples/wmt19/README.md)
104
+ * July 2019: fairseq relicensed under MIT license
105
+ * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
106
+ * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
107
+
108
+ </p></details>
109
+
110
+ ### Features:
111
+
112
+ * multi-GPU training on one machine or across multiple machines (data and model parallel)
113
+ * fast generation on both CPU and GPU with multiple search algorithms implemented:
114
+ + beam search
115
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
116
+ + sampling (unconstrained, top-k and top-p/nucleus)
117
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
118
+ * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
119
+ * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
120
+ * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
121
+ * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
122
+ * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
123
+ * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
124
+
125
+ We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
126
+ with a convenient `torch.hub` interface:
127
+
128
+ ``` python
129
+ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
130
+ en2de.translate('Hello world', beam=5)
131
+ # 'Hallo Welt'
132
+ ```
133
+
134
+ See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
135
+ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
136
+
137
+ # Requirements and Installation
138
+
139
+ * [PyTorch](http://pytorch.org/) version >= 1.5.0
140
+ * Python version >= 3.6
141
+ * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
142
+ * **To install fairseq** and develop locally:
143
+
144
+ ``` bash
145
+ git clone https://github.com/pytorch/fairseq
146
+ cd fairseq
147
+ pip install --editable ./
148
+
149
+ # on MacOS:
150
+ # CFLAGS="-stdlib=libc++" pip install --editable ./
151
+
152
+ # to install the latest stable release (0.10.x)
153
+ # pip install fairseq
154
+ ```
155
+
156
+ * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
157
+
158
+ ``` bash
159
+ git clone https://github.com/NVIDIA/apex
160
+ cd apex
161
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
162
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
163
+ --global-option="--fast_multihead_attn" ./
164
+ ```
165
+
166
+ * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
167
+ * If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
168
+ as command line options to `nvidia-docker run` .
169
+
170
+ # Getting Started
171
+
172
+ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
173
+ for getting started, training new models and extending fairseq with new model
174
+ types and tasks.
175
+
176
+ # Pre-trained models and examples
177
+
178
+ We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
179
+ as well as example training and evaluation commands.
180
+
181
+ * [Translation](examples/translation/README.md): convolutional and transformer models are available
182
+ * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
183
+
184
+ We also have more detailed READMEs to reproduce results from specific papers:
185
+
186
+ * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
187
+ * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
188
+ * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
189
+ * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
190
+ * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
191
+ * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
192
+ * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
193
+ * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
194
+ * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
195
+ * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
196
+ * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
197
+ * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
198
+ * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
199
+ * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
200
+ * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
201
+ * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
202
+ * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
203
+ * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
204
+ * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
205
+ * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
206
+
207
+ # Join the fairseq community
208
+
209
+ * Twitter: https://twitter.com/fairseq
210
+ * Facebook page: https://www.facebook.com/groups/fairseq.users
211
+ * Google group: https://groups.google.com/forum/#!forum/fairseq-users
212
+
213
+ # License
214
+
215
+ fairseq(-py) is MIT-licensed.
216
+ The license applies to the pre-trained models as well.
217
+
218
+ # Citation
219
+
220
+ Please cite as:
221
+
222
+ ``` bibtex
223
+ @inproceedings{ott2019fairseq,
224
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
225
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
226
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
227
+ year = {2019},
228
+ }
229
+ ```
fairseq/examples/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ !*/*.sh
2
+ !*/*.md
fairseq/examples/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from fairseq.version import __version__ # noqa
8
+ except ImportError:
9
+ pass
fairseq/examples/adaptive_span/README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adaptive Span
2
+
3
+ Adaptive Span is a novel self-attention mechanism that can learn its optimal
4
+ attention span. This allows us to extend significantly the maximum context size
5
+ used in Transformer, while maintaining control over their memory footprint
6
+ and computational time. It uses the Truncated BPTT technique for training,
7
+ as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
8
+
9
+ Adaptive Span was introduced by paper:
10
+ [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
11
+ which achieved state-of-the-art language modeling results at the time of publication.
12
+
13
+ We manage to reproduce their result in fairseq and keep most of the
14
+ [original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
15
+ You can refer to the their sweep file as well if any combination of hyperparameter is not clear.
16
+
17
+ ##### 0. Setup
18
+
19
+ First you need to process the Enwik8 dataset, we use the pre-tokenized dataset
20
+ from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
21
+ You can download the dataset, and then run:
22
+ ```bash
23
+ fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
24
+ --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
25
+ --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
26
+ ```
27
+
28
+ ##### 1. Train a Adaptive Span model on Enwik8
29
+
30
+ We will train a 12-layer Adaptive Span model following the [hyperparameters
31
+ used in the original
32
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
33
+
34
+ The following command assumes 4 GPUs, so that the total batch size is 64
35
+ sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
36
+ ```bash
37
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
38
+ --user-dir examples/adaptive_span \
39
+ --data ~/data/enwik8/data-bin/ \
40
+ --fp16 --fp16-no-flatten-grads --max-update 600000 \
41
+ --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
42
+ --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
43
+ --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
44
+ --validate-interval-updates 1000 \
45
+ --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
46
+ --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
47
+ --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
48
+ ```
49
+ This should land around 1.05 on validation, 1.03 on test. You can lower the
50
+ --aux-loss-scaler for better performance (longer span). It gives ~0.03 bpc
51
+ improvement to the transformerXL baseline here.
52
+ If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
53
+ and simulate training on 4 GPUs.
54
+ You can also reproduce the transformerXL result on enwik8 using this code base.
55
+ It should land around 1.06 on test,matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
56
+ You can try by
57
+ ```bash
58
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
59
+ --user-dir examples/truncated_bptt \
60
+ ~/data/enwik8/data-bin/ \
61
+ --task truncated_bptt_lm --fp16 --max-update 400000 \
62
+ --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
63
+ --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
64
+ --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
65
+ --lr-scheduler cosine --warmup-updates 0 \
66
+ --lr 0.0 --lr 0.00025 --batch-size 15 \
67
+ --update-freq 1 --seed 2 --log-format json --log-interval 25 \
68
+ --fp16
69
+ ```
70
+
71
+ ##### 2. Evaluate
72
+ For Adaptive Span:
73
+ ```bash
74
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
75
+ --user-dir examples/adaptive_span \
76
+ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
77
+ ```
78
+ For Transformer-XL evaluation:
79
+ ```bash
80
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
81
+ --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
82
+ --tokens-per-sample 80 \
83
+ --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
84
+ --gen-subset valid
85
+ ```
86
+
87
+ *Note:* During training the model saw 512 tokens of context
88
+ (``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
89
+ settings from [the original
90
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
fairseq/examples/adaptive_span/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+ # automatically import any Python files in the current directory
10
+ cur_dir = os.path.dirname(__file__)
11
+ for file in os.listdir(cur_dir):
12
+ path = os.path.join(cur_dir, file)
13
+ if (
14
+ not file.startswith("_")
15
+ and not file.startswith(".")
16
+ and (file.endswith(".py") or os.path.isdir(path))
17
+ ):
18
+ mod_name = file[: file.find(".py")] if file.endswith(".py") else file
19
+ module = importlib.import_module(__name__ + "." + mod_name)