Spaces:
Runtime error
Runtime error
fix: support smelu
Browse files- README.md +89 -78
- src/dalle_mini/model/modeling.py +25 -1
README.md
CHANGED
@@ -133,127 +133,138 @@ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization f
|
|
133 |
### Citations
|
134 |
|
135 |
```text
|
136 |
-
@misc{
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
}
|
144 |
```
|
145 |
|
146 |
```text
|
147 |
-
@misc{
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
}
|
155 |
```
|
156 |
|
157 |
```text
|
158 |
-
@misc{
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
}
|
166 |
```
|
167 |
|
168 |
```text
|
169 |
-
@misc{
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
}
|
177 |
```
|
178 |
|
179 |
```text
|
180 |
-
@misc{
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
}
|
188 |
```
|
189 |
|
190 |
```text
|
191 |
-
@misc{
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
}
|
197 |
```
|
198 |
|
199 |
```text
|
200 |
-
@misc{
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
}
|
208 |
```
|
209 |
|
210 |
```text
|
211 |
-
@misc{
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
}
|
219 |
```
|
220 |
|
221 |
```text
|
222 |
-
@inproceedings{
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
}
|
228 |
```
|
229 |
|
230 |
```text
|
231 |
-
@misc{
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
}
|
239 |
```
|
240 |
|
241 |
```text
|
242 |
-
@misc{
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
}
|
250 |
```
|
251 |
|
252 |
```text
|
253 |
-
@misc{
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
}
|
259 |
```
|
|
|
133 |
### Citations
|
134 |
|
135 |
```text
|
136 |
+
@misc{
|
137 |
+
title={Zero-Shot Text-to-Image Generation},
|
138 |
+
author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
|
139 |
+
year={2021},
|
140 |
+
eprint={2102.12092},
|
141 |
+
archivePrefix={arXiv},
|
142 |
+
primaryClass={cs.CV}
|
143 |
}
|
144 |
```
|
145 |
|
146 |
```text
|
147 |
+
@misc{
|
148 |
+
title={Learning Transferable Visual Models From Natural Language Supervision},
|
149 |
+
author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
|
150 |
+
year={2021},
|
151 |
+
eprint={2103.00020},
|
152 |
+
archivePrefix={arXiv},
|
153 |
+
primaryClass={cs.CV}
|
154 |
}
|
155 |
```
|
156 |
|
157 |
```text
|
158 |
+
@misc{
|
159 |
+
title={Taming Transformers for High-Resolution Image Synthesis},
|
160 |
+
author={Patrick Esser and Robin Rombach and Björn Ommer},
|
161 |
+
year={2021},
|
162 |
+
eprint={2012.09841},
|
163 |
+
archivePrefix={arXiv},
|
164 |
+
primaryClass={cs.CV}
|
165 |
}
|
166 |
```
|
167 |
|
168 |
```text
|
169 |
+
@misc{
|
170 |
+
title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
|
171 |
+
author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
|
172 |
+
year={2019},
|
173 |
+
eprint={1910.13461},
|
174 |
+
archivePrefix={arXiv},
|
175 |
+
primaryClass={cs.CL}
|
176 |
}
|
177 |
```
|
178 |
|
179 |
```text
|
180 |
+
@misc{
|
181 |
+
title={Scalable Second Order Optimization for Deep Learning},
|
182 |
+
author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
|
183 |
+
year={2021},
|
184 |
+
eprint={2002.09018},
|
185 |
+
archivePrefix={arXiv},
|
186 |
+
primaryClass={cs.LG}
|
187 |
}
|
188 |
```
|
189 |
|
190 |
```text
|
191 |
+
@misc{
|
192 |
+
title={GLU Variants Improve Transformer},
|
193 |
+
author={Noam Shazeer},
|
194 |
+
year={2020},
|
195 |
+
url={https://arxiv.org/abs/2002.05202}
|
196 |
}
|
197 |
```
|
198 |
|
199 |
```text
|
200 |
+
@misc{
|
201 |
+
title={DeepNet: Scaling transformers to 1,000 layers},
|
202 |
+
author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
|
203 |
+
year={2022},
|
204 |
+
eprint={2203.00555}
|
205 |
+
archivePrefix={arXiv},
|
206 |
+
primaryClass={cs.LG}
|
207 |
}
|
208 |
```
|
209 |
|
210 |
```text
|
211 |
+
@misc{
|
212 |
+
title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
|
213 |
+
author={Sam Shleifer and Jason Weston and Myle Ott},
|
214 |
+
year={2021},
|
215 |
+
eprint={2110.09456},
|
216 |
+
archivePrefix={arXiv},
|
217 |
+
primaryClass={cs.CL}
|
218 |
}
|
219 |
```
|
220 |
|
221 |
```text
|
222 |
+
@inproceedings{
|
223 |
+
title={Swin Transformer V2: Scaling Up Capacity and Resolution},
|
224 |
+
author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
|
225 |
+
booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
|
226 |
+
year={2022}
|
227 |
}
|
228 |
```
|
229 |
|
230 |
```text
|
231 |
+
@misc{
|
232 |
+
title = {CogView: Mastering Text-to-Image Generation via Transformers},
|
233 |
+
author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
|
234 |
+
year = {2021},
|
235 |
+
eprint = {2105.13290},
|
236 |
+
archivePrefix = {arXiv},
|
237 |
+
primaryClass = {cs.CV}
|
238 |
}
|
239 |
```
|
240 |
|
241 |
```text
|
242 |
+
@misc{
|
243 |
+
title = {Root Mean Square Layer Normalization},
|
244 |
+
author = {Biao Zhang and Rico Sennrich},
|
245 |
+
year = {2019},
|
246 |
+
eprint = {1910.07467},
|
247 |
+
archivePrefix = {arXiv},
|
248 |
+
primaryClass = {cs.LG}
|
249 |
}
|
250 |
```
|
251 |
|
252 |
```text
|
253 |
+
@misc{
|
254 |
+
title = {Sinkformers: Transformers with Doubly Stochastic Attention},
|
255 |
+
url = {https://arxiv.org/abs/2110.11773},
|
256 |
+
author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
|
257 |
+
publisher = {arXiv},
|
258 |
+
year = {2021},
|
259 |
+
}
|
260 |
+
```
|
261 |
+
|
262 |
+
```text
|
263 |
+
@misc{
|
264 |
+
title = {Smooth activations and reproducibility in deep networks},
|
265 |
+
url = {https://arxiv.org/abs/2010.09931},
|
266 |
+
author = {Shamir, Gil I. and Lin, Dong and Coviello, Lorenzo},
|
267 |
+
publisher = {arXiv},
|
268 |
+
year = {2020},
|
269 |
}
|
270 |
```
|
src/dalle_mini/model/modeling.py
CHANGED
@@ -32,7 +32,7 @@ from flax.linen import partitioning as nn_partitioning
|
|
32 |
from flax.linen.linear import PrecisionLike
|
33 |
from flax.serialization import from_bytes
|
34 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
35 |
-
from jax import lax
|
36 |
from jax.random import PRNGKey
|
37 |
from transformers.configuration_utils import PretrainedConfig
|
38 |
from transformers.file_utils import (
|
@@ -68,6 +68,30 @@ logger = logging.get_logger(__name__)
|
|
68 |
remat = nn_partitioning.remat
|
69 |
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# deepnet initialization
|
72 |
def deepnet_init(gain=1):
|
73 |
init = jax.nn.initializers.glorot_normal()
|
|
|
32 |
from flax.linen.linear import PrecisionLike
|
33 |
from flax.serialization import from_bytes
|
34 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
35 |
+
from jax import custom_jvp, lax
|
36 |
from jax.random import PRNGKey
|
37 |
from transformers.configuration_utils import PretrainedConfig
|
38 |
from transformers.file_utils import (
|
|
|
68 |
remat = nn_partitioning.remat
|
69 |
|
70 |
|
71 |
+
def smelu(beta: Any = 1.0):
|
72 |
+
"""
|
73 |
+
Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
|
74 |
+
https://arxiv.org/abs/2202.06499
|
75 |
+
"""
|
76 |
+
|
77 |
+
@custom_jvp
|
78 |
+
@jax.jit
|
79 |
+
def _smelu(x: Any) -> Any:
|
80 |
+
x = jnp.where(x <= -beta, 0.0, x)
|
81 |
+
return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))
|
82 |
+
|
83 |
+
_smelu.defjvps(
|
84 |
+
lambda g, ans, x: lax.select(
|
85 |
+
x == -beta,
|
86 |
+
lax.full_like(g, 0),
|
87 |
+
lax.select(x == beta, lax.full_like(g, 1), g),
|
88 |
+
)
|
89 |
+
)
|
90 |
+
return _smelu
|
91 |
+
|
92 |
+
|
93 |
+
ACT2FN.update({"smelu": smelu})
|
94 |
+
|
95 |
# deepnet initialization
|
96 |
def deepnet_init(gain=1):
|
97 |
init = jax.nn.initializers.glorot_normal()
|