boris commited on
Commit
dccd804
·
2 Parent(s): bb2758c 18f5a29

Merge branch 'main' into abidlabs/main

Browse files
.github/workflows/check_size.yml CHANGED
@@ -14,4 +14,4 @@ jobs:
14
  - name: Check large files
15
  uses: ActionsDesk/[email protected]
16
  with:
17
- filesizelimit: 900000 # so we can sync to HF spaces
 
14
  - name: Check large files
15
  uses: ActionsDesk/[email protected]
16
  with:
17
+ filesizelimit: 9000000 # so we can sync to HF spaces
.gitignore CHANGED
@@ -1 +1,3 @@
1
  __pycache__
 
 
 
1
  __pycache__
2
+ .ipynb_checkpoints
3
+ .streamlit
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021 The DALL·E mini Authors
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -22,29 +22,13 @@ You can create your own pictures with [the demo](https://huggingface.co/spaces/f
22
 
23
  Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
24
 
25
- ## Where does the logo come from?
26
-
27
- The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
28
-
29
  ## Development
30
 
31
- This section is for the adventurous people wanting to look into the code.
32
-
33
  ### Dependencies Installation
34
 
35
  The root folder and associated `requirements.txt` is only for the app.
36
 
37
- You will find necessary requirements in each sub-section.
38
-
39
- You should create a new python virtual environment and install the project dependencies inside the virtual env. You need to use the `-f` (`--find-links`) option for `pip` to be able to find the appropriate `libtpu` required for the TPU hardware.
40
-
41
- Adapt the installation to your own hardware and follow library installation instructions.
42
-
43
- ```
44
- $ pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
45
- ```
46
-
47
- If you use `conda`, you can create the virtual env and install everything using: `conda env update -f environments.yaml`
48
 
49
  ### Training of VQGAN
50
 
@@ -58,13 +42,19 @@ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
58
 
59
  ### Training of Seq2Seq
60
 
61
- Refer to `dev/seq2seq` folder.
62
 
63
  You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
64
 
65
- ### Inference
 
 
 
 
66
 
67
- Refer to `dev/notebooks/demo`.
 
 
68
 
69
  ## Authors
70
 
 
22
 
23
  Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
24
 
 
 
 
 
25
  ## Development
26
 
 
 
27
  ### Dependencies Installation
28
 
29
  The root folder and associated `requirements.txt` is only for the app.
30
 
31
+ For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
 
 
 
 
 
 
 
 
 
 
32
 
33
  ### Training of VQGAN
34
 
 
42
 
43
  ### Training of Seq2Seq
44
 
45
+ Refer to [`dev/seq2seq`](dev/seq2seq) folder.
46
 
47
  You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
48
 
49
+ ### Inference Pipeline
50
+
51
+ To generate sample predictions and understand the inference pipeline step by step, refer to [`dev/inference/inference_pipeline.ipynb`](dev/inference/inference_pipeline.ipynb).
52
+
53
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
54
 
55
+ ## Where does the logo come from?
56
+
57
+ The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
58
 
59
  ## Authors
60
 
app/app.py CHANGED
@@ -46,7 +46,7 @@ DALL·E mini is an AI model that generates images from any prompt you give!
46
  <p style='text-align: center'>
47
  Created by Boris Dayma et al. 2021
48
  <br/>
49
- <a href="https://github.com/borisdayma/dalle-mini">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">Project Report</a>
50
  </p>
51
  """, unsafe_allow_html=True)
52
 
 
46
  <p style='text-align: center'>
47
  Created by Boris Dayma et al. 2021
48
  <br/>
49
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
50
  </p>
51
  """, unsafe_allow_html=True)
52
 
app/gradio/app_gradio.py CHANGED
@@ -19,7 +19,7 @@ import numpy as np
19
  import matplotlib.pyplot as plt
20
 
21
 
22
- from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
23
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
24
 
25
  import gradio as gr
 
19
  import matplotlib.pyplot as plt
20
 
21
 
22
+ from vqgan_jax.modeling_flax_vqgan import VQModel
23
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
24
 
25
  import gradio as gr
dalle_mini/vqgan_jax/README.md DELETED
@@ -1,5 +0,0 @@
1
- ## vqgan-jax
2
-
3
- Files copied from [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax/tree/main/vqgan_jax)
4
-
5
- Required for VQGAN Jax model.
 
 
 
 
 
 
dalle_mini/vqgan_jax/__init__.py DELETED
File without changes
dalle_mini/vqgan_jax/configuration_vqgan.py DELETED
@@ -1,40 +0,0 @@
1
- from typing import Tuple
2
-
3
- from transformers import PretrainedConfig
4
-
5
-
6
- class VQGANConfig(PretrainedConfig):
7
- def __init__(
8
- self,
9
- ch: int = 128,
10
- out_ch: int = 3,
11
- in_channels: int = 3,
12
- num_res_blocks: int = 2,
13
- resolution: int = 256,
14
- z_channels: int = 256,
15
- ch_mult: Tuple = (1, 1, 2, 2, 4),
16
- attn_resolutions: int = (16,),
17
- n_embed: int = 1024,
18
- embed_dim: int = 256,
19
- dropout: float = 0.0,
20
- double_z: bool = False,
21
- resamp_with_conv: bool = True,
22
- give_pre_end: bool = False,
23
- **kwargs,
24
- ):
25
- super().__init__(**kwargs)
26
- self.ch = ch
27
- self.out_ch = out_ch
28
- self.in_channels = in_channels
29
- self.num_res_blocks = num_res_blocks
30
- self.resolution = resolution
31
- self.z_channels = z_channels
32
- self.ch_mult = list(ch_mult)
33
- self.attn_resolutions = list(attn_resolutions)
34
- self.n_embed = n_embed
35
- self.embed_dim = embed_dim
36
- self.dropout = dropout
37
- self.double_z = double_z
38
- self.resamp_with_conv = resamp_with_conv
39
- self.give_pre_end = give_pre_end
40
- self.num_resolutions = len(ch_mult)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/vqgan_jax/modeling_flax_vqgan.py DELETED
@@ -1,609 +0,0 @@
1
- # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
2
-
3
- from functools import partial
4
- from typing import Tuple
5
- import math
6
-
7
- import jax
8
- import jax.numpy as jnp
9
- import numpy as np
10
- import flax.linen as nn
11
- from flax.core.frozen_dict import FrozenDict
12
-
13
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
14
-
15
- from .configuration_vqgan import VQGANConfig
16
-
17
-
18
- class Upsample(nn.Module):
19
- in_channels: int
20
- with_conv: bool
21
- dtype: jnp.dtype = jnp.float32
22
-
23
- def setup(self):
24
- if self.with_conv:
25
- self.conv = nn.Conv(
26
- self.in_channels,
27
- kernel_size=(3, 3),
28
- strides=(1, 1),
29
- padding=((1, 1), (1, 1)),
30
- dtype=self.dtype,
31
- )
32
-
33
- def __call__(self, hidden_states):
34
- batch, height, width, channels = hidden_states.shape
35
- hidden_states = jax.image.resize(
36
- hidden_states,
37
- shape=(batch, height * 2, width * 2, channels),
38
- method="nearest",
39
- )
40
- if self.with_conv:
41
- hidden_states = self.conv(hidden_states)
42
- return hidden_states
43
-
44
-
45
- class Downsample(nn.Module):
46
- in_channels: int
47
- with_conv: bool
48
- dtype: jnp.dtype = jnp.float32
49
-
50
- def setup(self):
51
- if self.with_conv:
52
- self.conv = nn.Conv(
53
- self.in_channels,
54
- kernel_size=(3, 3),
55
- strides=(2, 2),
56
- padding="VALID",
57
- dtype=self.dtype,
58
- )
59
-
60
- def __call__(self, hidden_states):
61
- if self.with_conv:
62
- pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
63
- hidden_states = jnp.pad(hidden_states, pad_width=pad)
64
- hidden_states = self.conv(hidden_states)
65
- else:
66
- hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
67
- return hidden_states
68
-
69
-
70
- class ResnetBlock(nn.Module):
71
- in_channels: int
72
- out_channels: int = None
73
- use_conv_shortcut: bool = False
74
- temb_channels: int = 512
75
- dropout_prob: float = 0.0
76
- dtype: jnp.dtype = jnp.float32
77
-
78
- def setup(self):
79
- self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
80
-
81
- self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
82
- self.conv1 = nn.Conv(
83
- self.out_channels_,
84
- kernel_size=(3, 3),
85
- strides=(1, 1),
86
- padding=((1, 1), (1, 1)),
87
- dtype=self.dtype,
88
- )
89
-
90
- if self.temb_channels:
91
- self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
92
-
93
- self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
94
- self.dropout = nn.Dropout(self.dropout_prob)
95
- self.conv2 = nn.Conv(
96
- self.out_channels_,
97
- kernel_size=(3, 3),
98
- strides=(1, 1),
99
- padding=((1, 1), (1, 1)),
100
- dtype=self.dtype,
101
- )
102
-
103
- if self.in_channels != self.out_channels_:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = nn.Conv(
106
- self.out_channels_,
107
- kernel_size=(3, 3),
108
- strides=(1, 1),
109
- padding=((1, 1), (1, 1)),
110
- dtype=self.dtype,
111
- )
112
- else:
113
- self.nin_shortcut = nn.Conv(
114
- self.out_channels_,
115
- kernel_size=(1, 1),
116
- strides=(1, 1),
117
- padding="VALID",
118
- dtype=self.dtype,
119
- )
120
-
121
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
122
- residual = hidden_states
123
- hidden_states = self.norm1(hidden_states)
124
- hidden_states = nn.swish(hidden_states)
125
- hidden_states = self.conv1(hidden_states)
126
-
127
- if temb is not None:
128
- hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
129
-
130
- hidden_states = self.norm2(hidden_states)
131
- hidden_states = nn.swish(hidden_states)
132
- hidden_states = self.dropout(hidden_states, deterministic)
133
- hidden_states = self.conv2(hidden_states)
134
-
135
- if self.in_channels != self.out_channels_:
136
- if self.use_conv_shortcut:
137
- residual = self.conv_shortcut(residual)
138
- else:
139
- residual = self.nin_shortcut(residual)
140
-
141
- return hidden_states + residual
142
-
143
-
144
- class AttnBlock(nn.Module):
145
- in_channels: int
146
- dtype: jnp.dtype = jnp.float32
147
-
148
- def setup(self):
149
- conv = partial(
150
- nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
151
- )
152
-
153
- self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
154
- self.q, self.k, self.v = conv(), conv(), conv()
155
- self.proj_out = conv()
156
-
157
- def __call__(self, hidden_states):
158
- residual = hidden_states
159
- hidden_states = self.norm(hidden_states)
160
-
161
- query = self.q(hidden_states)
162
- key = self.k(hidden_states)
163
- value = self.v(hidden_states)
164
-
165
- # compute attentions
166
- batch, height, width, channels = query.shape
167
- query = query.reshape((batch, height * width, channels))
168
- key = key.reshape((batch, height * width, channels))
169
- attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
170
- attn_weights = attn_weights * (int(channels) ** -0.5)
171
- attn_weights = nn.softmax(attn_weights, axis=2)
172
-
173
- ## attend to values
174
- value = value.reshape((batch, height * width, channels))
175
- hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
176
- hidden_states = hidden_states.reshape((batch, height, width, channels))
177
-
178
- hidden_states = self.proj_out(hidden_states)
179
- hidden_states = hidden_states + residual
180
- return hidden_states
181
-
182
-
183
- class UpsamplingBlock(nn.Module):
184
- config: VQGANConfig
185
- curr_res: int
186
- block_idx: int
187
- dtype: jnp.dtype = jnp.float32
188
-
189
- def setup(self):
190
- if self.block_idx == self.config.num_resolutions - 1:
191
- block_in = self.config.ch * self.config.ch_mult[-1]
192
- else:
193
- block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
194
-
195
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
196
- self.temb_ch = 0
197
-
198
- res_blocks = []
199
- attn_blocks = []
200
- for _ in range(self.config.num_res_blocks + 1):
201
- res_blocks.append(
202
- ResnetBlock(
203
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
204
- )
205
- )
206
- block_in = block_out
207
- if self.curr_res in self.config.attn_resolutions:
208
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
209
-
210
- self.block = res_blocks
211
- self.attn = attn_blocks
212
-
213
- self.upsample = None
214
- if self.block_idx != 0:
215
- self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
216
-
217
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
218
- for res_block in self.block:
219
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
220
- for attn_block in self.attn:
221
- hidden_states = attn_block(hidden_states)
222
-
223
- if self.upsample is not None:
224
- hidden_states = self.upsample(hidden_states)
225
-
226
- return hidden_states
227
-
228
-
229
- class DownsamplingBlock(nn.Module):
230
- config: VQGANConfig
231
- curr_res: int
232
- block_idx: int
233
- dtype: jnp.dtype = jnp.float32
234
-
235
- def setup(self):
236
- in_ch_mult = (1,) + tuple(self.config.ch_mult)
237
- block_in = self.config.ch * in_ch_mult[self.block_idx]
238
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
239
- self.temb_ch = 0
240
-
241
- res_blocks = []
242
- attn_blocks = []
243
- for _ in range(self.config.num_res_blocks):
244
- res_blocks.append(
245
- ResnetBlock(
246
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
247
- )
248
- )
249
- block_in = block_out
250
- if self.curr_res in self.config.attn_resolutions:
251
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
252
-
253
- self.block = res_blocks
254
- self.attn = attn_blocks
255
-
256
- self.downsample = None
257
- if self.block_idx != self.config.num_resolutions - 1:
258
- self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
259
-
260
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
261
- for res_block in self.block:
262
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
263
- for attn_block in self.attn:
264
- hidden_states = attn_block(hidden_states)
265
-
266
- if self.downsample is not None:
267
- hidden_states = self.downsample(hidden_states)
268
-
269
- return hidden_states
270
-
271
-
272
- class MidBlock(nn.Module):
273
- in_channels: int
274
- temb_channels: int
275
- dropout: float
276
- dtype: jnp.dtype = jnp.float32
277
-
278
- def setup(self):
279
- self.block_1 = ResnetBlock(
280
- self.in_channels,
281
- self.in_channels,
282
- temb_channels=self.temb_channels,
283
- dropout_prob=self.dropout,
284
- dtype=self.dtype,
285
- )
286
- self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
287
- self.block_2 = ResnetBlock(
288
- self.in_channels,
289
- self.in_channels,
290
- temb_channels=self.temb_channels,
291
- dropout_prob=self.dropout,
292
- dtype=self.dtype,
293
- )
294
-
295
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
296
- hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
297
- hidden_states = self.attn_1(hidden_states)
298
- hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
299
- return hidden_states
300
-
301
-
302
- class Encoder(nn.Module):
303
- config: VQGANConfig
304
- dtype: jnp.dtype = jnp.float32
305
-
306
- def setup(self):
307
- self.temb_ch = 0
308
-
309
- # downsampling
310
- self.conv_in = nn.Conv(
311
- self.config.ch,
312
- kernel_size=(3, 3),
313
- strides=(1, 1),
314
- padding=((1, 1), (1, 1)),
315
- dtype=self.dtype,
316
- )
317
-
318
- curr_res = self.config.resolution
319
- downsample_blocks = []
320
- for i_level in range(self.config.num_resolutions):
321
- downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
322
-
323
- if i_level != self.config.num_resolutions - 1:
324
- curr_res = curr_res // 2
325
- self.down = downsample_blocks
326
-
327
- # middle
328
- mid_channels = self.config.ch * self.config.ch_mult[-1]
329
- self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
330
-
331
- # end
332
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
333
- self.conv_out = nn.Conv(
334
- 2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
335
- kernel_size=(3, 3),
336
- strides=(1, 1),
337
- padding=((1, 1), (1, 1)),
338
- dtype=self.dtype,
339
- )
340
-
341
- def __call__(self, pixel_values, deterministic: bool = True):
342
- # timestep embedding
343
- temb = None
344
-
345
- # downsampling
346
- hidden_states = self.conv_in(pixel_values)
347
- for block in self.down:
348
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
349
-
350
- # middle
351
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
352
-
353
- # end
354
- hidden_states = self.norm_out(hidden_states)
355
- hidden_states = nn.swish(hidden_states)
356
- hidden_states = self.conv_out(hidden_states)
357
-
358
- return hidden_states
359
-
360
-
361
- class Decoder(nn.Module):
362
- config: VQGANConfig
363
- dtype: jnp.dtype = jnp.float32
364
-
365
- def setup(self):
366
- self.temb_ch = 0
367
-
368
- # compute in_ch_mult, block_in and curr_res at lowest res
369
- block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
370
- curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
371
- self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
372
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
373
-
374
- # z to block_in
375
- self.conv_in = nn.Conv(
376
- block_in,
377
- kernel_size=(3, 3),
378
- strides=(1, 1),
379
- padding=((1, 1), (1, 1)),
380
- dtype=self.dtype,
381
- )
382
-
383
- # middle
384
- self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
385
-
386
- # upsampling
387
- upsample_blocks = []
388
- for i_level in reversed(range(self.config.num_resolutions)):
389
- upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
390
- if i_level != 0:
391
- curr_res = curr_res * 2
392
- self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
393
-
394
- # end
395
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
396
- self.conv_out = nn.Conv(
397
- self.config.out_ch,
398
- kernel_size=(3, 3),
399
- strides=(1, 1),
400
- padding=((1, 1), (1, 1)),
401
- dtype=self.dtype,
402
- )
403
-
404
- def __call__(self, hidden_states, deterministic: bool = True):
405
- # timestep embedding
406
- temb = None
407
-
408
- # z to block_in
409
- hidden_states = self.conv_in(hidden_states)
410
-
411
- # middle
412
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
413
-
414
- # upsampling
415
- for block in reversed(self.up):
416
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
417
-
418
- # end
419
- if self.config.give_pre_end:
420
- return hidden_states
421
-
422
- hidden_states = self.norm_out(hidden_states)
423
- hidden_states = nn.swish(hidden_states)
424
- hidden_states = self.conv_out(hidden_states)
425
-
426
- return hidden_states
427
-
428
-
429
- class VectorQuantizer(nn.Module):
430
- """
431
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
432
- ____________________________________________
433
- Discretization bottleneck part of the VQ-VAE.
434
- Inputs:
435
- - n_e : number of embeddings
436
- - e_dim : dimension of embedding
437
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
438
- _____________________________________________
439
- """
440
-
441
- config: VQGANConfig
442
- dtype: jnp.dtype = jnp.float32
443
-
444
- def setup(self):
445
- self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
446
-
447
- def __call__(self, hidden_states):
448
- """
449
- Inputs the output of the encoder network z and maps it to a discrete
450
- one-hot vector that is the index of the closest embedding vector e_j
451
- z (continuous) -> z_q (discrete)
452
- z.shape = (batch, channel, height, width)
453
- quantization pipeline:
454
- 1. get encoder input (B,C,H,W)
455
- 2. flatten input to (B*H*W,C)
456
- """
457
- # flatten
458
- hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
459
-
460
- # dummy op to init the weights, so we can access them below
461
- self.embedding(jnp.ones((1, 1), dtype="i4"))
462
-
463
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
464
- emb_weights = self.variables["params"]["embedding"]["embedding"]
465
- distance = (
466
- jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
467
- + jnp.sum(emb_weights ** 2, axis=1)
468
- - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
469
- )
470
-
471
- # get quantized latent vectors
472
- min_encoding_indices = jnp.argmin(distance, axis=1)
473
- z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
474
-
475
- # reshape to (batch, num_tokens)
476
- min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
477
-
478
- # compute the codebook_loss (q_loss) outside the model
479
- # here we return the embeddings and indices
480
- return z_q, min_encoding_indices
481
-
482
- def get_codebook_entry(self, indices, shape=None):
483
- # indices are expected to be of shape (batch, num_tokens)
484
- # get quantized latent vectors
485
- batch, num_tokens = indices.shape
486
- z_q = self.embedding(indices)
487
- z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
488
- return z_q
489
-
490
-
491
- class VQModule(nn.Module):
492
- config: VQGANConfig
493
- dtype: jnp.dtype = jnp.float32
494
-
495
- def setup(self):
496
- self.encoder = Encoder(self.config, dtype=self.dtype)
497
- self.decoder = Decoder(self.config, dtype=self.dtype)
498
- self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
499
- self.quant_conv = nn.Conv(
500
- self.config.embed_dim,
501
- kernel_size=(1, 1),
502
- strides=(1, 1),
503
- padding="VALID",
504
- dtype=self.dtype,
505
- )
506
- self.post_quant_conv = nn.Conv(
507
- self.config.z_channels,
508
- kernel_size=(1, 1),
509
- strides=(1, 1),
510
- padding="VALID",
511
- dtype=self.dtype,
512
- )
513
-
514
- def encode(self, pixel_values, deterministic: bool = True):
515
- hidden_states = self.encoder(pixel_values, deterministic=deterministic)
516
- hidden_states = self.quant_conv(hidden_states)
517
- quant_states, indices = self.quantize(hidden_states)
518
- return quant_states, indices
519
-
520
- def decode(self, hidden_states, deterministic: bool = True):
521
- hidden_states = self.post_quant_conv(hidden_states)
522
- hidden_states = self.decoder(hidden_states, deterministic=deterministic)
523
- return hidden_states
524
-
525
- def decode_code(self, code_b):
526
- hidden_states = self.quantize.get_codebook_entry(code_b)
527
- hidden_states = self.decode(hidden_states)
528
- return hidden_states
529
-
530
- def __call__(self, pixel_values, deterministic: bool = True):
531
- quant_states, indices = self.encode(pixel_values, deterministic)
532
- hidden_states = self.decode(quant_states, deterministic)
533
- return hidden_states, indices
534
-
535
-
536
- class VQGANPreTrainedModel(FlaxPreTrainedModel):
537
- """
538
- An abstract class to handle weights initialization and a simple interface
539
- for downloading and loading pretrained models.
540
- """
541
-
542
- config_class = VQGANConfig
543
- base_model_prefix = "model"
544
- module_class: nn.Module = None
545
-
546
- def __init__(
547
- self,
548
- config: VQGANConfig,
549
- input_shape: Tuple = (1, 256, 256, 3),
550
- seed: int = 0,
551
- dtype: jnp.dtype = jnp.float32,
552
- **kwargs,
553
- ):
554
- module = self.module_class(config=config, dtype=dtype, **kwargs)
555
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
556
-
557
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
558
- # init input tensors
559
- pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
560
- params_rng, dropout_rng = jax.random.split(rng)
561
- rngs = {"params": params_rng, "dropout": dropout_rng}
562
-
563
- return self.module.init(rngs, pixel_values)["params"]
564
-
565
- def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
566
- # Handle any PRNG if needed
567
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
568
-
569
- return self.module.apply(
570
- {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
571
- )
572
-
573
- def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
574
- # Handle any PRNG if needed
575
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
576
-
577
- return self.module.apply(
578
- {"params": params or self.params},
579
- jnp.array(hidden_states),
580
- not train,
581
- rngs=rngs,
582
- method=self.module.decode,
583
- )
584
-
585
- def decode_code(self, indices, params: dict = None):
586
- return self.module.apply(
587
- {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
588
- )
589
-
590
- def __call__(
591
- self,
592
- pixel_values,
593
- params: dict = None,
594
- dropout_rng: jax.random.PRNGKey = None,
595
- train: bool = False,
596
- ):
597
- # Handle any PRNG if needed
598
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
599
-
600
- return self.module.apply(
601
- {"params": params or self.params},
602
- jnp.array(pixel_values),
603
- not train,
604
- rngs=rngs,
605
- )
606
-
607
-
608
- class VQModel(VQGANPreTrainedModel):
609
- module_class = VQModule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/{notebooks/encoding → encoding}/vqgan-jax-encoding-with-captions.ipynb RENAMED
@@ -50,14 +50,6 @@
50
  "## VQGAN-JAX model"
51
  ]
52
  },
53
- {
54
- "cell_type": "markdown",
55
- "id": "bb408f6c",
56
- "metadata": {},
57
- "source": [
58
- "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
59
- ]
60
- },
61
  {
62
  "cell_type": "code",
63
  "execution_count": 2,
@@ -65,7 +57,7 @@
65
  "metadata": {},
66
  "outputs": [],
67
  "source": [
68
- "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
69
  ]
70
  },
71
  {
 
50
  "## VQGAN-JAX model"
51
  ]
52
  },
 
 
 
 
 
 
 
 
53
  {
54
  "cell_type": "code",
55
  "execution_count": 2,
 
57
  "metadata": {},
58
  "outputs": [],
59
  "source": [
60
+ "from vqgan_jax.modeling_flax_vqgan import VQModel"
61
  ]
62
  },
63
  {
dev/{notebooks/encoding → encoding}/vqgan-jax-encoding-yfcc100m.ipynb RENAMED
@@ -52,14 +52,6 @@
52
  "## VQGAN-JAX model"
53
  ]
54
  },
55
- {
56
- "cell_type": "markdown",
57
- "id": "bb408f6c",
58
- "metadata": {},
59
- "source": [
60
- "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
61
- ]
62
- },
63
  {
64
  "cell_type": "code",
65
  "execution_count": 93,
@@ -67,7 +59,7 @@
67
  "metadata": {},
68
  "outputs": [],
69
  "source": [
70
- "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
71
  ]
72
  },
73
  {
@@ -1111,9 +1103,13 @@
1111
  }
1112
  ],
1113
  "metadata": {
 
 
 
1114
  "kernelspec": {
1115
- "name": "python3",
1116
- "display_name": "Python 3.9.0 64-bit ('Python39')"
 
1117
  },
1118
  "language_info": {
1119
  "codemirror_mode": {
@@ -1125,12 +1121,9 @@
1125
  "name": "python",
1126
  "nbconvert_exporter": "python",
1127
  "pygments_lexer": "ipython3",
1128
- "version": "3.9.0"
1129
- },
1130
- "interpreter": {
1131
- "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
1132
  }
1133
  },
1134
  "nbformat": 4,
1135
  "nbformat_minor": 5
1136
- }
 
52
  "## VQGAN-JAX model"
53
  ]
54
  },
 
 
 
 
 
 
 
 
55
  {
56
  "cell_type": "code",
57
  "execution_count": 93,
 
59
  "metadata": {},
60
  "outputs": [],
61
  "source": [
62
+ "from vqgan_jax.modeling_flax_vqgan import VQModel"
63
  ]
64
  },
65
  {
 
1103
  }
1104
  ],
1105
  "metadata": {
1106
+ "interpreter": {
1107
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
1108
+ },
1109
  "kernelspec": {
1110
+ "display_name": "Python 3 (ipykernel)",
1111
+ "language": "python",
1112
+ "name": "python3"
1113
  },
1114
  "language_info": {
1115
  "codemirror_mode": {
 
1121
  "name": "python",
1122
  "nbconvert_exporter": "python",
1123
  "pygments_lexer": "ipython3",
1124
+ "version": "3.8.10"
 
 
 
1125
  }
1126
  },
1127
  "nbformat": 4,
1128
  "nbformat_minor": 5
1129
+ }
dev/{notebooks/encoding → encoding}/vqgan-jax-encoding.ipynb RENAMED
File without changes
dev/{seq2seq/environment.yaml → environment.yaml} RENAMED
File without changes
dev/{predictions → inference}/README.md RENAMED
File without changes
dev/{predictions → inference}/dalle_mini RENAMED
File without changes
dev/inference/inference_pipeline.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
dev/inference/wandb-examples-from-backend.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import wandb
6
+ import os
7
+
8
+ from dalle_mini.backend import ServiceError, get_images_from_backend
9
+ from dalle_mini.helpers import captioned_strip
10
+
11
+ os.environ["WANDB_SILENT"] = "true"
12
+ os.environ["WANDB_CONSOLE"] = "off"
13
+
14
+ def log_to_wandb(prompts):
15
+ try:
16
+ backend_url = os.environ["BACKEND_SERVER"]
17
+ for _ in range(1):
18
+ for prompt in prompts:
19
+ print(f"Getting selections for: {prompt}")
20
+ # make a separate run per prompt
21
+ with wandb.init(
22
+ entity='wandb',
23
+ project='hf-flax-dalle-mini',
24
+ job_type='predictions',# tags=['openai'],
25
+ config={'prompt': prompt}
26
+ ):
27
+ imgs = []
28
+ selected = get_images_from_backend(prompt, backend_url)
29
+ strip = captioned_strip(selected, prompt)
30
+ imgs.append(wandb.Image(strip))
31
+ wandb.log({"images": imgs})
32
+ except ServiceError as error:
33
+ print(f"Service unavailable, status: {error.status_code}")
34
+ except KeyError:
35
+ print("Error: BACKEND_SERVER unset")
36
+
37
+ prompts = [
38
+ # "white snow covered mountain under blue sky during daytime",
39
+ # "aerial view of beach during daytime",
40
+ # "aerial view of beach at night",
41
+ # "a farmhouse surrounded by beautiful flowers",
42
+ # "an armchair in the shape of an avocado",
43
+ # "young woman riding her bike trough a forest",
44
+ # "a unicorn is passing by a rainbow in a field of flowers",
45
+ # "illustration of a baby shark swimming around corals",
46
+ # "painting of an oniric forest glade surrounded by tall trees",
47
+ # "sunset over green mountains",
48
+ # "a forest glade surrounded by tall trees in a sunny Spring morning",
49
+ # "fishing village under the moonlight in a serene sunset",
50
+ # "cartoon of a carrot with big eyes",
51
+ # "still life in the style of Kandinsky",
52
+ # "still life in the style of Picasso",
53
+ # "a graphite sketch of a gothic cathedral",
54
+ # "a graphite sketch of Elon Musk",
55
+ # "a watercolor pond with green leaves and yellow flowers",
56
+ # "a logo of a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps",
57
+ # "happy celebration in a small village in Africa",
58
+ # "a logo of an armchair in the shape of an avocado"
59
+ # "Pele and Maradona in a hypothetical match",
60
+ # "Mohammed Ali and Mike Tyson in a hypothetical match",
61
+ # "a storefront that has the word 'openai' written on it",
62
+ # "a pentagonal green clock",
63
+ # "a collection of glasses is sitting on a table",
64
+ # "a small red block sitting on a large green block",
65
+ # "an extreme close-up view of a capybara sitting in a field",
66
+ # "a cross-section view of a walnut",
67
+ # "a professional high-quality emoji of a lovestruck cup of boba",
68
+ # "a photo of san francisco's golden gate bridge",
69
+ # "an illustration of a baby daikon radish in a tutu walking a dog",
70
+ # "a picture of the Eiffel tower on the Moon",
71
+ # "a colorful stairway to heaven",
72
+ "this is a detailed high-resolution scan of a human brain"
73
+ ]
74
+
75
+ for _ in range(1):
76
+ log_to_wandb(prompts)
dev/{predictions → inference}/wandb-examples.py RENAMED
@@ -4,16 +4,14 @@
4
  import random
5
 
6
  import jax
7
- import flax.linen as nn
8
  from flax.training.common_utils import shard
9
  from flax.jax_utils import replicate, unreplicate
10
 
11
  from transformers.models.bart.modeling_flax_bart import *
12
  from transformers import BartTokenizer, FlaxBartForConditionalGeneration
13
 
14
- import io
15
 
16
- import requests
17
  from PIL import Image
18
  import numpy as np
19
  import matplotlib.pyplot as plt
@@ -23,58 +21,24 @@ import torchvision.transforms as T
23
  import torchvision.transforms.functional as TF
24
  from torchvision.transforms import InterpolationMode
25
 
26
- from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
27
-
28
- # TODO: set those args in a config file
29
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
30
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
31
- BOS_TOKEN_ID = 16384
32
- BASE_MODEL = 'facebook/bart-large-cnn'
33
-
34
- class CustomFlaxBartModule(FlaxBartModule):
35
- def setup(self):
36
- # we keep shared to easily load pre-trained weights
37
- self.shared = nn.Embed(
38
- self.config.vocab_size,
39
- self.config.d_model,
40
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
41
- dtype=self.dtype,
42
- )
43
- # a separate embedding is used for the decoder
44
- self.decoder_embed = nn.Embed(
45
- OUTPUT_VOCAB_SIZE,
46
- self.config.d_model,
47
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
48
- dtype=self.dtype,
49
- )
50
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
51
-
52
- # the decoder has a different config
53
- decoder_config = BartConfig(self.config.to_dict())
54
- decoder_config.max_position_embeddings = OUTPUT_LENGTH
55
- decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
56
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
57
-
58
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
59
- def setup(self):
60
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
61
- self.lm_head = nn.Dense(
62
- OUTPUT_VOCAB_SIZE,
63
- use_bias=False,
64
- dtype=self.dtype,
65
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
66
- )
67
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
68
-
69
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
70
- module_class = CustomFlaxBartForConditionalGenerationModule
71
 
 
 
72
 
73
  import wandb
74
  import os
 
 
 
 
75
  os.environ["WANDB_SILENT"] = "true"
76
  os.environ["WANDB_CONSOLE"] = "off"
77
 
 
 
 
78
  # set id to None so our latest images don't get overwritten
79
  id = None
80
  run = wandb.init(id=id,
@@ -87,8 +51,10 @@ artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest', ty
87
  artifact_dir = artifact.download()
88
 
89
  # create our model
90
- tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
91
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
 
 
92
  model.config.force_bos_token_to_be_generated = False
93
  model.config.forced_bos_token_id = None
94
  model.config.forced_eos_token_id = None
@@ -143,9 +109,6 @@ p_get_images = jax.pmap(get_images, "batch")
143
  bart_params = replicate(model.params)
144
  vqgan_params = replicate(vqgan.params)
145
 
146
- # ## CLIP Scoring
147
- from transformers import CLIPProcessor, FlaxCLIPModel
148
-
149
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
150
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
151
 
@@ -170,16 +133,12 @@ def hallucinate(prompt, num_images=64):
170
 
171
  def clip_top_k(prompt, images, k=8):
172
  inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
 
173
  outputs = clip(**inputs)
174
  logits = outputs.logits_per_text
175
  scores = np.array(logits[0]).argsort()[-k:][::-1]
176
  return [images[score] for score in scores]
177
 
178
-
179
- # ## Log to wandb
180
-
181
- from dalle_mini.helpers import captioned_strip
182
-
183
  def log_to_wandb(prompts):
184
  strips = []
185
  for prompt in prompts:
 
4
  import random
5
 
6
  import jax
 
7
  from flax.training.common_utils import shard
8
  from flax.jax_utils import replicate, unreplicate
9
 
10
  from transformers.models.bart.modeling_flax_bart import *
11
  from transformers import BartTokenizer, FlaxBartForConditionalGeneration
12
 
13
+ import os
14
 
 
15
  from PIL import Image
16
  import numpy as np
17
  import matplotlib.pyplot as plt
 
21
  import torchvision.transforms.functional as TF
22
  from torchvision.transforms import InterpolationMode
23
 
24
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
25
+ from vqgan_jax.modeling_flax_vqgan import VQModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # ## CLIP Scoring
28
+ from transformers import CLIPProcessor, FlaxCLIPModel
29
 
30
  import wandb
31
  import os
32
+
33
+ from dalle_mini.helpers import captioned_strip
34
+
35
+
36
  os.environ["WANDB_SILENT"] = "true"
37
  os.environ["WANDB_CONSOLE"] = "off"
38
 
39
+ # TODO: used for legacy support
40
+ BASE_MODEL = 'facebook/bart-large-cnn'
41
+
42
  # set id to None so our latest images don't get overwritten
43
  id = None
44
  run = wandb.init(id=id,
 
51
  artifact_dir = artifact.download()
52
 
53
  # create our model
 
54
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
55
+
56
+ # TODO: legacy support (earlier models)
57
+ tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
58
  model.config.force_bos_token_to_be_generated = False
59
  model.config.forced_bos_token_id = None
60
  model.config.forced_eos_token_id = None
 
109
  bart_params = replicate(model.params)
110
  vqgan_params = replicate(vqgan.params)
111
 
 
 
 
112
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
113
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
114
 
 
133
 
134
  def clip_top_k(prompt, images, k=8):
135
  inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
136
+ # FIXME: image should be resized and normalized prior to being processed by CLIP
137
  outputs = clip(**inputs)
138
  logits = outputs.logits_per_text
139
  scores = np.array(logits[0]).argsort()[-k:][::-1]
140
  return [images[score] for score in scores]
141
 
 
 
 
 
 
142
  def log_to_wandb(prompts):
143
  strips = []
144
  for prompt in prompts:
dev/notebooks/README.md DELETED
@@ -1,5 +0,0 @@
1
- # Notebooks
2
-
3
- These notebooks were used during development.
4
-
5
- TODO: This section requires some refactor and clean up.
 
 
 
 
 
 
dev/notebooks/demo/CustomBARTv4b_model-generate.ipynb DELETED
@@ -1,394 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "ewer-Q-0w2xA"
7
- },
8
- "source": [
9
- "# Installation"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": null,
15
- "metadata": {
16
- "colab": {
17
- "base_uri": "https://localhost:8080/"
18
- },
19
- "id": "NpsF9ipLLl2s",
20
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
21
- },
22
- "outputs": [],
23
- "source": [
24
- "!pip install git+https://github.com/huggingface/transformers/\n",
25
- "!pip install git+https://github.com/google/flax"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": null,
31
- "metadata": {
32
- "id": "M1wVkrpjU6zO"
33
- },
34
- "outputs": [],
35
- "source": [
36
- "%load_ext autoreload\n",
37
- "%autoreload 2"
38
- ]
39
- },
40
- {
41
- "cell_type": "markdown",
42
- "metadata": {
43
- "id": "t47CH1H_IOT8"
44
- },
45
- "source": [
46
- "# Custom BART Model"
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": null,
52
- "metadata": {
53
- "id": "9jQnM6S2vCpn"
54
- },
55
- "outputs": [],
56
- "source": [
57
- "# TODO: set those args in a config file\n",
58
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
59
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
60
- "BOS_TOKEN_ID = 16384\n",
61
- "BASE_MODEL = 'facebook/bart-large'"
62
- ]
63
- },
64
- {
65
- "cell_type": "code",
66
- "execution_count": null,
67
- "metadata": {
68
- "id": "_eEaJVxAKpV5"
69
- },
70
- "outputs": [],
71
- "source": [
72
- "import jax\n",
73
- "import flax.linen as nn\n",
74
- "\n",
75
- "from transformers.models.bart.modeling_flax_bart import *\n",
76
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
77
- "\n",
78
- "class CustomFlaxBartModule(FlaxBartModule):\n",
79
- " def setup(self):\n",
80
- " # we keep shared to easily load pre-trained weights\n",
81
- " self.shared = nn.Embed(\n",
82
- " self.config.vocab_size,\n",
83
- " self.config.d_model,\n",
84
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
85
- " dtype=self.dtype,\n",
86
- " )\n",
87
- " # a separate embedding is used for the decoder\n",
88
- " self.decoder_embed = nn.Embed(\n",
89
- " OUTPUT_VOCAB_SIZE,\n",
90
- " self.config.d_model,\n",
91
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
92
- " dtype=self.dtype,\n",
93
- " )\n",
94
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
95
- "\n",
96
- " # the decoder has a different config\n",
97
- " decoder_config = BartConfig(self.config.to_dict())\n",
98
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
99
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
100
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
101
- "\n",
102
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
103
- " def setup(self):\n",
104
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
105
- " self.lm_head = nn.Dense(\n",
106
- " OUTPUT_VOCAB_SIZE,\n",
107
- " use_bias=False,\n",
108
- " dtype=self.dtype,\n",
109
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
110
- " )\n",
111
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
112
- "\n",
113
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
114
- " module_class = CustomFlaxBartForConditionalGenerationModule"
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": null,
120
- "metadata": {
121
- "colab": {
122
- "base_uri": "https://localhost:8080/"
123
- },
124
- "id": "S7CP9Td9m2ge",
125
- "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
126
- },
127
- "outputs": [],
128
- "source": [
129
- "# load pre-trained model for encoder weights\n",
130
- "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
131
- ]
132
- },
133
- {
134
- "cell_type": "code",
135
- "execution_count": null,
136
- "metadata": {
137
- "id": "6lmynR-poceH"
138
- },
139
- "outputs": [],
140
- "source": [
141
- "# set up our new model config\n",
142
- "config = BartConfig.from_pretrained(BASE_MODEL)\n",
143
- "config.tie_word_embeddings = False\n",
144
- "config.decoder_start_token_id = BOS_TOKEN_ID\n",
145
- "config.bos_token_id = BOS_TOKEN_ID # should not be used\n",
146
- "config.pos_token_id = BOS_TOKEN_ID # should not be used\n",
147
- "#config.eos_token_id = None # prevents generation from stopping until we reach max_length"
148
- ]
149
- },
150
- {
151
- "cell_type": "code",
152
- "execution_count": null,
153
- "metadata": {
154
- "id": "_6-XKK40oEfP"
155
- },
156
- "outputs": [],
157
- "source": [
158
- "# create our model and initialize it randomly\n",
159
- "model = CustomFlaxBartForConditionalGeneration(config)"
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "metadata": {
166
- "id": "-r_hZestr-NR"
167
- },
168
- "outputs": [],
169
- "source": [
170
- "# use pretrained weights\n",
171
- "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
172
- "model.params['model']['shared'] = base_model.params['model']['shared']"
173
- ]
174
- },
175
- {
176
- "cell_type": "code",
177
- "execution_count": null,
178
- "metadata": {
179
- "id": "5NEX8f62sVjx"
180
- },
181
- "outputs": [],
182
- "source": [
183
- "# no need for base_model anymore\n",
184
- "del base_model"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": null,
190
- "metadata": {
191
- "colab": {
192
- "base_uri": "https://localhost:8080/"
193
- },
194
- "id": "Jz032w73nHEf",
195
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
196
- },
197
- "outputs": [],
198
- "source": [
199
- "# we verify that the shape has not been modified\n",
200
- "model.params['final_logits_bias'].shape"
201
- ]
202
- },
203
- {
204
- "cell_type": "markdown",
205
- "metadata": {
206
- "id": "zLl24Ez5t7x1"
207
- },
208
- "source": [
209
- "## Inference"
210
- ]
211
- },
212
- {
213
- "cell_type": "code",
214
- "execution_count": null,
215
- "metadata": {
216
- "id": "XLLA2NK3uDQr"
217
- },
218
- "outputs": [],
219
- "source": [
220
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
221
- ]
222
- },
223
- {
224
- "cell_type": "code",
225
- "execution_count": null,
226
- "metadata": {
227
- "colab": {
228
- "base_uri": "https://localhost:8080/"
229
- },
230
- "id": "Ntow53I_t81D",
231
- "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
232
- },
233
- "outputs": [],
234
- "source": [
235
- "text = \"My friends are cool but they eat too many carbs.\"\n",
236
- "inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
237
- "encoder_outputs = model.encode(**inputs)"
238
- ]
239
- },
240
- {
241
- "cell_type": "code",
242
- "execution_count": null,
243
- "metadata": {
244
- "colab": {
245
- "base_uri": "https://localhost:8080/"
246
- },
247
- "id": "vcRNJnJ_uJOJ",
248
- "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
249
- },
250
- "outputs": [],
251
- "source": [
252
- "decoder_start_token_id = model.config.decoder_start_token_id\n",
253
- "decoder_start_token_id"
254
- ]
255
- },
256
- {
257
- "cell_type": "code",
258
- "execution_count": null,
259
- "metadata": {
260
- "id": "6QWmEwL_uMld"
261
- },
262
- "outputs": [],
263
- "source": [
264
- "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
265
- "outputs = model.decode(decoder_input_ids, encoder_outputs)"
266
- ]
267
- },
268
- {
269
- "cell_type": "code",
270
- "execution_count": null,
271
- "metadata": {
272
- "colab": {
273
- "base_uri": "https://localhost:8080/"
274
- },
275
- "id": "c_ys3yWBothF",
276
- "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
277
- },
278
- "outputs": [],
279
- "source": [
280
- "outputs"
281
- ]
282
- },
283
- {
284
- "cell_type": "code",
285
- "execution_count": null,
286
- "metadata": {
287
- "colab": {
288
- "base_uri": "https://localhost:8080/"
289
- },
290
- "id": "O6s0wtB_uTC_",
291
- "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
292
- },
293
- "outputs": [],
294
- "source": [
295
- "outputs.logits.shape"
296
- ]
297
- },
298
- {
299
- "cell_type": "code",
300
- "execution_count": null,
301
- "metadata": {
302
- "colab": {
303
- "base_uri": "https://localhost:8080/"
304
- },
305
- "id": "ELzemGP3uBzy",
306
- "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
307
- },
308
- "outputs": [],
309
- "source": [
310
- "outputs.logits.argmax(axis=-1)"
311
- ]
312
- },
313
- {
314
- "cell_type": "code",
315
- "execution_count": null,
316
- "metadata": {
317
- "colab": {
318
- "base_uri": "https://localhost:8080/"
319
- },
320
- "id": "fQjikkGEunpx",
321
- "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
322
- },
323
- "outputs": [],
324
- "source": [
325
- "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
326
- ]
327
- },
328
- {
329
- "cell_type": "code",
330
- "execution_count": null,
331
- "metadata": {
332
- "id": "P32mJJSbrU1F"
333
- },
334
- "outputs": [],
335
- "source": [
336
- "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
337
- ]
338
- },
339
- {
340
- "cell_type": "code",
341
- "execution_count": null,
342
- "metadata": {
343
- "id": "C7cHbIHruELT"
344
- },
345
- "outputs": [],
346
- "source": [
347
- "greedy_output = model.generate(input_ids_test, max_length=50)"
348
- ]
349
- },
350
- {
351
- "cell_type": "code",
352
- "execution_count": null,
353
- "metadata": {
354
- "colab": {
355
- "base_uri": "https://localhost:8080/"
356
- },
357
- "id": "jYugh9cOuwc9",
358
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
359
- },
360
- "outputs": [],
361
- "source": [
362
- "greedy_output[0]"
363
- ]
364
- }
365
- ],
366
- "metadata": {
367
- "accelerator": "TPU",
368
- "colab": {
369
- "collapsed_sections": [],
370
- "machine_shape": "hm",
371
- "name": "CustomBARTv4b-model-generate.ipynb",
372
- "provenance": []
373
- },
374
- "kernelspec": {
375
- "display_name": "Python 3 (ipykernel)",
376
- "language": "python",
377
- "name": "python3"
378
- },
379
- "language_info": {
380
- "codemirror_mode": {
381
- "name": "ipython",
382
- "version": 3
383
- },
384
- "file_extension": ".py",
385
- "mimetype": "text/x-python",
386
- "name": "python",
387
- "nbconvert_exporter": "python",
388
- "pygments_lexer": "ipython3",
389
- "version": "3.8.5"
390
- }
391
- },
392
- "nbformat": 4,
393
- "nbformat_minor": 4
394
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/notebooks/demo/demo_notebook.ipynb DELETED
@@ -1,387 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "ewer-Q-0w2xA"
7
- },
8
- "source": [
9
- "# Installation"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": null,
15
- "metadata": {
16
- "colab": {
17
- "base_uri": "https://localhost:8080/"
18
- },
19
- "id": "NpsF9ipLLl2s",
20
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
21
- },
22
- "outputs": [],
23
- "source": [
24
- "#!pip install git+https://github.com/huggingface/transformers/\n",
25
- "#!pip install git+https://github.com/google/flax"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": null,
31
- "metadata": {
32
- "id": "M1wVkrpjU6zO"
33
- },
34
- "outputs": [],
35
- "source": [
36
- "%load_ext autoreload\n",
37
- "%autoreload 2"
38
- ]
39
- },
40
- {
41
- "cell_type": "code",
42
- "execution_count": null,
43
- "metadata": {},
44
- "outputs": [],
45
- "source": [
46
- "%cd ../../vqgan-jax"
47
- ]
48
- },
49
- {
50
- "cell_type": "markdown",
51
- "metadata": {
52
- "id": "t47CH1H_IOT8"
53
- },
54
- "source": [
55
- "# Custom BART Model"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": null,
61
- "metadata": {
62
- "id": "9jQnM6S2vCpn"
63
- },
64
- "outputs": [],
65
- "source": [
66
- "# TODO: set those args in a config file\n",
67
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
68
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
69
- "BOS_TOKEN_ID = 16384\n",
70
- "BASE_MODEL = 'facebook/bart-large'"
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": null,
76
- "metadata": {
77
- "id": "_eEaJVxAKpV5"
78
- },
79
- "outputs": [],
80
- "source": [
81
- "import jax\n",
82
- "import flax.linen as nn\n",
83
- "\n",
84
- "from transformers.models.bart.modeling_flax_bart import *\n",
85
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
86
- "\n",
87
- "class CustomFlaxBartModule(FlaxBartModule):\n",
88
- " def setup(self):\n",
89
- " # we keep shared to easily load pre-trained weights\n",
90
- " self.shared = nn.Embed(\n",
91
- " self.config.vocab_size,\n",
92
- " self.config.d_model,\n",
93
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
94
- " dtype=self.dtype,\n",
95
- " )\n",
96
- " # a separate embedding is used for the decoder\n",
97
- " self.decoder_embed = nn.Embed(\n",
98
- " OUTPUT_VOCAB_SIZE,\n",
99
- " self.config.d_model,\n",
100
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
101
- " dtype=self.dtype,\n",
102
- " )\n",
103
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
104
- "\n",
105
- " # the decoder has a different config\n",
106
- " decoder_config = BartConfig(self.config.to_dict())\n",
107
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
108
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
109
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
110
- "\n",
111
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
112
- " def setup(self):\n",
113
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
114
- " self.lm_head = nn.Dense(\n",
115
- " OUTPUT_VOCAB_SIZE,\n",
116
- " use_bias=False,\n",
117
- " dtype=self.dtype,\n",
118
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
119
- " )\n",
120
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
121
- "\n",
122
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
123
- " module_class = CustomFlaxBartForConditionalGenerationModule"
124
- ]
125
- },
126
- {
127
- "cell_type": "code",
128
- "execution_count": null,
129
- "metadata": {
130
- "scrolled": true
131
- },
132
- "outputs": [],
133
- "source": [
134
- "import wandb\n",
135
- "run = wandb.init()\n",
136
- "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:latest', type='bart_model')\n",
137
- "artifact_dir = artifact.download()"
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": null,
143
- "metadata": {
144
- "id": "_6-XKK40oEfP",
145
- "scrolled": true
146
- },
147
- "outputs": [],
148
- "source": [
149
- "# create our model and initialize it randomly\n",
150
- "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
151
- ]
152
- },
153
- {
154
- "cell_type": "code",
155
- "execution_count": null,
156
- "metadata": {},
157
- "outputs": [],
158
- "source": [
159
- "model.config.forced_bos_token_id = None"
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "metadata": {
166
- "colab": {
167
- "base_uri": "https://localhost:8080/"
168
- },
169
- "id": "Jz032w73nHEf",
170
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
171
- },
172
- "outputs": [],
173
- "source": [
174
- "# we verify that the shape has not been modified\n",
175
- "model.params['final_logits_bias'].shape"
176
- ]
177
- },
178
- {
179
- "cell_type": "markdown",
180
- "metadata": {
181
- "id": "zLl24Ez5t7x1"
182
- },
183
- "source": [
184
- "## Inference"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": null,
190
- "metadata": {
191
- "id": "XLLA2NK3uDQr"
192
- },
193
- "outputs": [],
194
- "source": [
195
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
196
- ]
197
- },
198
- {
199
- "cell_type": "code",
200
- "execution_count": null,
201
- "metadata": {},
202
- "outputs": [],
203
- "source": [
204
- "input_text = ['I enjoy walking with my cute dog']*8"
205
- ]
206
- },
207
- {
208
- "cell_type": "code",
209
- "execution_count": null,
210
- "metadata": {
211
- "id": "P32mJJSbrU1F"
212
- },
213
- "outputs": [],
214
- "source": [
215
- "input_ids_test = tokenizer(input_text, return_tensors='jax')"
216
- ]
217
- },
218
- {
219
- "cell_type": "code",
220
- "execution_count": null,
221
- "metadata": {},
222
- "outputs": [],
223
- "source": [
224
- "input_ids_test"
225
- ]
226
- },
227
- {
228
- "cell_type": "code",
229
- "execution_count": null,
230
- "metadata": {
231
- "id": "C7cHbIHruELT"
232
- },
233
- "outputs": [],
234
- "source": [
235
- "greedy_output = model.generate(input_ids_test['input_ids'], max_length=257)"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": null,
241
- "metadata": {},
242
- "outputs": [],
243
- "source": [
244
- "greedy_output[0].shape"
245
- ]
246
- },
247
- {
248
- "cell_type": "code",
249
- "execution_count": null,
250
- "metadata": {
251
- "colab": {
252
- "base_uri": "https://localhost:8080/"
253
- },
254
- "id": "jYugh9cOuwc9",
255
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
256
- },
257
- "outputs": [],
258
- "source": [
259
- "greedy_output[0]"
260
- ]
261
- },
262
- {
263
- "cell_type": "code",
264
- "execution_count": null,
265
- "metadata": {},
266
- "outputs": [],
267
- "source": [
268
- "greedy_output[0][0]"
269
- ]
270
- },
271
- {
272
- "cell_type": "markdown",
273
- "metadata": {},
274
- "source": [
275
- "# VGAN Jax"
276
- ]
277
- },
278
- {
279
- "cell_type": "code",
280
- "execution_count": null,
281
- "metadata": {},
282
- "outputs": [],
283
- "source": [
284
- "import io\n",
285
- "\n",
286
- "import requests\n",
287
- "from PIL import Image\n",
288
- "import numpy as np\n",
289
- "\n",
290
- "import torch\n",
291
- "import torchvision.transforms as T\n",
292
- "import torchvision.transforms.functional as TF\n",
293
- "from torchvision.transforms import InterpolationMode"
294
- ]
295
- },
296
- {
297
- "cell_type": "code",
298
- "execution_count": null,
299
- "metadata": {},
300
- "outputs": [],
301
- "source": [
302
- "from modeling_flax_vqgan import VQModel"
303
- ]
304
- },
305
- {
306
- "cell_type": "code",
307
- "execution_count": null,
308
- "metadata": {},
309
- "outputs": [],
310
- "source": [
311
- "def custom_to_pil(x):\n",
312
- " x = np.clip(x, 0., 1.)\n",
313
- " x = (255*x).astype(np.uint8)\n",
314
- " x = Image.fromarray(x)\n",
315
- " if not x.mode == \"RGB\":\n",
316
- " x = x.convert(\"RGB\")\n",
317
- " return x"
318
- ]
319
- },
320
- {
321
- "cell_type": "code",
322
- "execution_count": null,
323
- "metadata": {
324
- "colab": {
325
- "base_uri": "https://localhost:8080/"
326
- },
327
- "id": "Jz032w73nHEf",
328
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
329
- "scrolled": true
330
- },
331
- "outputs": [],
332
- "source": [
333
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
334
- ]
335
- },
336
- {
337
- "cell_type": "code",
338
- "execution_count": null,
339
- "metadata": {},
340
- "outputs": [],
341
- "source": [
342
- "def get_images(indices, model):\n",
343
- " indices = indices[:, 1:]\n",
344
- " print(indices.shape)\n",
345
- " img = model.decode_code(indices)\n",
346
- " return img"
347
- ]
348
- },
349
- {
350
- "cell_type": "code",
351
- "execution_count": null,
352
- "metadata": {},
353
- "outputs": [],
354
- "source": [
355
- "custom_to_pil(np.asarray(get_images(jnp.expand_dims(greedy_output[0][0],0), model)[0]))"
356
- ]
357
- }
358
- ],
359
- "metadata": {
360
- "accelerator": "TPU",
361
- "colab": {
362
- "collapsed_sections": [],
363
- "machine_shape": "hm",
364
- "name": "CustomBARTv4b-model-generate.ipynb",
365
- "provenance": []
366
- },
367
- "kernelspec": {
368
- "display_name": "Python 3 (ipykernel)",
369
- "language": "python",
370
- "name": "python3"
371
- },
372
- "language_info": {
373
- "codemirror_mode": {
374
- "name": "ipython",
375
- "version": 3
376
- },
377
- "file_extension": ".py",
378
- "mimetype": "text/x-python",
379
- "name": "python",
380
- "nbconvert_exporter": "python",
381
- "pygments_lexer": "ipython3",
382
- "version": "3.8.5"
383
- }
384
- },
385
- "nbformat": 4,
386
- "nbformat_minor": 4
387
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/notebooks/demo/model-sweep.py DELETED
@@ -1,220 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- import random
5
-
6
- import jax
7
- import flax.linen as nn
8
- from flax.training.common_utils import shard
9
- from flax.jax_utils import replicate, unreplicate
10
-
11
- from transformers.models.bart.modeling_flax_bart import *
12
- from transformers import BartTokenizer, FlaxBartForConditionalGeneration
13
-
14
- import io
15
-
16
- import requests
17
- from PIL import Image
18
- import numpy as np
19
- import matplotlib.pyplot as plt
20
-
21
- import torch
22
- import torchvision.transforms as T
23
- import torchvision.transforms.functional as TF
24
- from torchvision.transforms import InterpolationMode
25
-
26
- from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
27
-
28
- # TODO: set those args in a config file
29
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
30
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
31
- BOS_TOKEN_ID = 16384
32
- BASE_MODEL = 'facebook/bart-large-cnn'
33
- WANDB_MODEL = '3iwhu4w6'
34
-
35
- class CustomFlaxBartModule(FlaxBartModule):
36
- def setup(self):
37
- # we keep shared to easily load pre-trained weights
38
- self.shared = nn.Embed(
39
- self.config.vocab_size,
40
- self.config.d_model,
41
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
42
- dtype=self.dtype,
43
- )
44
- # a separate embedding is used for the decoder
45
- self.decoder_embed = nn.Embed(
46
- OUTPUT_VOCAB_SIZE,
47
- self.config.d_model,
48
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
49
- dtype=self.dtype,
50
- )
51
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
52
-
53
- # the decoder has a different config
54
- decoder_config = BartConfig(self.config.to_dict())
55
- decoder_config.max_position_embeddings = OUTPUT_LENGTH
56
- decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
57
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
58
-
59
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
60
- def setup(self):
61
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
62
- self.lm_head = nn.Dense(
63
- OUTPUT_VOCAB_SIZE,
64
- use_bias=False,
65
- dtype=self.dtype,
66
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
67
- )
68
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
69
-
70
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
71
- module_class = CustomFlaxBartForConditionalGenerationModule
72
-
73
- tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
74
- vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
75
-
76
- def custom_to_pil(x):
77
- x = np.clip(x, 0., 1.)
78
- x = (255*x).astype(np.uint8)
79
- x = Image.fromarray(x)
80
- if not x.mode == "RGB":
81
- x = x.convert("RGB")
82
- return x
83
-
84
- def generate(input, rng, params):
85
- return model.generate(
86
- **input,
87
- max_length=257,
88
- num_beams=1,
89
- do_sample=True,
90
- prng_key=rng,
91
- eos_token_id=50000,
92
- pad_token_id=50000,
93
- params=params,
94
- )
95
-
96
- def get_images(indices, params):
97
- return vqgan.decode_code(indices, params=params)
98
-
99
- def plot_images(images):
100
- fig = plt.figure(figsize=(40, 20))
101
- columns = 4
102
- rows = 2
103
- plt.subplots_adjust(hspace=0, wspace=0)
104
-
105
- for i in range(1, columns*rows +1):
106
- fig.add_subplot(rows, columns, i)
107
- plt.imshow(images[i-1])
108
- plt.gca().axes.get_yaxis().set_visible(False)
109
- plt.show()
110
-
111
- def stack_reconstructions(images):
112
- w, h = images[0].size[0], images[0].size[1]
113
- img = Image.new("RGB", (len(images)*w, h))
114
- for i, img_ in enumerate(images):
115
- img.paste(img_, (i*w,0))
116
- return img
117
-
118
- p_generate = jax.pmap(generate, "batch")
119
- p_get_images = jax.pmap(get_images, "batch")
120
-
121
- # ## CLIP Scoring
122
- from transformers import CLIPProcessor, FlaxCLIPModel
123
-
124
- clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
125
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
126
-
127
- def hallucinate(prompt, num_images=64):
128
- prompt = [prompt] * jax.device_count()
129
- inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
130
- inputs = shard(inputs)
131
-
132
- all_images = []
133
- for i in range(num_images // jax.device_count()):
134
- key = random.randint(0, 1e7)
135
- rng = jax.random.PRNGKey(key)
136
- rngs = jax.random.split(rng, jax.local_device_count())
137
- indices = p_generate(inputs, rngs, bart_params).sequences
138
- indices = indices[:, :, 1:]
139
-
140
- images = p_get_images(indices, vqgan_params)
141
- images = np.squeeze(np.asarray(images), 1)
142
- for image in images:
143
- all_images.append(custom_to_pil(image))
144
- return all_images
145
-
146
- def clip_top_k(prompt, images, k=8):
147
- inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
148
- outputs = clip(**inputs)
149
- logits = outputs.logits_per_text
150
- scores = np.array(logits[0]).argsort()[-k:][::-1]
151
- return [images[score] for score in scores]
152
-
153
- from PIL import ImageDraw, ImageFont
154
-
155
- def captioned_strip(images, caption):
156
- w, h = images[0].size[0], images[0].size[1]
157
- img = Image.new("RGB", (len(images)*w, h + 48))
158
- for i, img_ in enumerate(images):
159
- img.paste(img_, (i*w, 48))
160
- draw = ImageDraw.Draw(img)
161
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
162
- draw.text((20, 3), caption, (255,255,255), font=font)
163
- return img
164
-
165
- def log_to_wandb(prompts):
166
- strips = []
167
- for prompt in prompts:
168
- print(f"Generating candidates for: {prompt}")
169
- images = hallucinate(prompt, num_images=32)
170
- selected = clip_top_k(prompt, images, k=8)
171
- strip = captioned_strip(selected, prompt)
172
- strips.append(wandb.Image(strip))
173
- wandb.log({"images": strips})
174
-
175
- ## Artifact loop
176
-
177
- import wandb
178
- import os
179
- os.environ["WANDB_SILENT"] = "true"
180
- os.environ["WANDB_CONSOLE"] = "off"
181
-
182
- id = wandb.util.generate_id()
183
- print(f"Logging images to wandb run id: {id}")
184
-
185
- run = wandb.init(id=id,
186
- entity='wandb',
187
- project="hf-flax-dalle-mini",
188
- job_type="predictions",
189
- resume="allow"
190
- )
191
-
192
- artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3iwhu4w6:v0', type='bart_model')
193
- producer_run = artifact.logged_by()
194
- logged_artifacts = producer_run.logged_artifacts()
195
-
196
- for artifact in logged_artifacts:
197
- print(f"Generating predictions with version {artifact.version}")
198
- artifact_dir = artifact.download()
199
-
200
- # create our model
201
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
202
- model.config.force_bos_token_to_be_generated = False
203
- model.config.forced_bos_token_id = None
204
- model.config.forced_eos_token_id = None
205
-
206
- bart_params = replicate(model.params)
207
- vqgan_params = replicate(vqgan.params)
208
-
209
- prompts = prompts = [
210
- "white snow covered mountain under blue sky during daytime",
211
- "aerial view of beach during daytime",
212
- "aerial view of beach at night",
213
- "an armchair in the shape of an avocado",
214
- "young woman riding her bike trough a forest",
215
- "rice fields by the mediterranean coast",
216
- "white houses on the hill of a greek coastline",
217
- "illustration of a shark with a baby shark",
218
- ]
219
-
220
- log_to_wandb(prompts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/notebooks/demo/tpu-demo.ipynb DELETED
@@ -1,455 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "f6d33374",
6
- "metadata": {},
7
- "source": [
8
- "# Test notebook with CLIP scoring"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": null,
14
- "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf",
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "# !pip install flax transformers\n",
19
- "# !git clone https://github.com/patil-suraj/vqgan-jax.git"
20
- ]
21
- },
22
- {
23
- "cell_type": "code",
24
- "execution_count": null,
25
- "id": "41db7534-f589-4b63-9165-9c9799e1b06e",
26
- "metadata": {},
27
- "outputs": [],
28
- "source": [
29
- "import random\n",
30
- "\n",
31
- "import jax\n",
32
- "import flax.linen as nn\n",
33
- "from flax.training.common_utils import shard\n",
34
- "from flax.jax_utils import replicate, unreplicate\n",
35
- "\n",
36
- "from transformers.models.bart.modeling_flax_bart import *\n",
37
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
38
- "\n",
39
- "import io\n",
40
- "\n",
41
- "import requests\n",
42
- "from PIL import Image\n",
43
- "import numpy as np\n",
44
- "import matplotlib.pyplot as plt\n",
45
- "\n",
46
- "import torch\n",
47
- "import torchvision.transforms as T\n",
48
- "import torchvision.transforms.functional as TF\n",
49
- "from torchvision.transforms import InterpolationMode\n",
50
- "\n",
51
- "jax.devices()"
52
- ]
53
- },
54
- {
55
- "cell_type": "markdown",
56
- "id": "d408065c",
57
- "metadata": {},
58
- "source": [
59
- "`dalle_mini` is a local package that contains the VQGAN-JAX model by Suraj, and other utilities. You can also `cd` to the directory that contains your checkout of [`vqgan-jax`](https://github.com/patil-suraj/vqgan-jax.git)"
60
- ]
61
- },
62
- {
63
- "cell_type": "code",
64
- "execution_count": null,
65
- "id": "09295910",
66
- "metadata": {},
67
- "outputs": [],
68
- "source": [
69
- "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel\n",
70
- "#%cd /content/vqgan-jax"
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": null,
76
- "id": "b6a3462a-9004-4121-b365-3ae3aaf94dd2",
77
- "metadata": {},
78
- "outputs": [],
79
- "source": [
80
- "# TODO: set those args in a config file\n",
81
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
82
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
83
- "BOS_TOKEN_ID = 16384\n",
84
- "BASE_MODEL = 'facebook/bart-large-cnn'"
85
- ]
86
- },
87
- {
88
- "cell_type": "code",
89
- "execution_count": null,
90
- "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
91
- "metadata": {},
92
- "outputs": [],
93
- "source": [
94
- "class CustomFlaxBartModule(FlaxBartModule):\n",
95
- " def setup(self):\n",
96
- " # we keep shared to easily load pre-trained weights\n",
97
- " self.shared = nn.Embed(\n",
98
- " self.config.vocab_size,\n",
99
- " self.config.d_model,\n",
100
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
101
- " dtype=self.dtype,\n",
102
- " )\n",
103
- " # a separate embedding is used for the decoder\n",
104
- " self.decoder_embed = nn.Embed(\n",
105
- " OUTPUT_VOCAB_SIZE,\n",
106
- " self.config.d_model,\n",
107
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
108
- " dtype=self.dtype,\n",
109
- " )\n",
110
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
111
- "\n",
112
- " # the decoder has a different config\n",
113
- " decoder_config = BartConfig(self.config.to_dict())\n",
114
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
115
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
116
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
117
- "\n",
118
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
119
- " def setup(self):\n",
120
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
121
- " self.lm_head = nn.Dense(\n",
122
- " OUTPUT_VOCAB_SIZE,\n",
123
- " use_bias=False,\n",
124
- " dtype=self.dtype,\n",
125
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
126
- " )\n",
127
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
128
- "\n",
129
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
130
- " module_class = CustomFlaxBartForConditionalGenerationModule"
131
- ]
132
- },
133
- {
134
- "cell_type": "code",
135
- "execution_count": null,
136
- "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
137
- "metadata": {},
138
- "outputs": [],
139
- "source": [
140
- "import wandb\n",
141
- "run = wandb.init()\n",
142
- "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:latest', type='bart_model')\n",
143
- "artifact_dir = artifact.download()"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": null,
149
- "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
150
- "metadata": {},
151
- "outputs": [],
152
- "source": [
153
- "# create our model\n",
154
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
155
- "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
156
- "model.config.force_bos_token_to_be_generated = False\n",
157
- "model.config.forced_bos_token_id = None\n",
158
- "model.config.forced_eos_token_id = None\n",
159
- "\n",
160
- "# we verify that the shape has not been modified\n",
161
- "model.params['final_logits_bias'].shape"
162
- ]
163
- },
164
- {
165
- "cell_type": "code",
166
- "execution_count": null,
167
- "id": "8d5e0f14-2502-470e-9553-daee6748601f",
168
- "metadata": {},
169
- "outputs": [],
170
- "source": [
171
- "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
172
- ]
173
- },
174
- {
175
- "cell_type": "code",
176
- "execution_count": null,
177
- "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
178
- "metadata": {},
179
- "outputs": [],
180
- "source": [
181
- "def custom_to_pil(x):\n",
182
- " x = np.clip(x, 0., 1.)\n",
183
- " x = (255*x).astype(np.uint8)\n",
184
- " x = Image.fromarray(x)\n",
185
- " if not x.mode == \"RGB\":\n",
186
- " x = x.convert(\"RGB\")\n",
187
- " return x\n",
188
- "\n",
189
- "def generate(input, rng, params):\n",
190
- " return model.generate(\n",
191
- " **input,\n",
192
- " max_length=257,\n",
193
- " num_beams=1,\n",
194
- " do_sample=True,\n",
195
- " prng_key=rng,\n",
196
- " eos_token_id=50000,\n",
197
- " pad_token_id=50000,\n",
198
- " params=params\n",
199
- " )\n",
200
- "\n",
201
- "def get_images(indices, params):\n",
202
- " return vqgan.decode_code(indices, params=params)\n",
203
- "\n",
204
- "\n",
205
- "def plot_images(images):\n",
206
- " fig = plt.figure(figsize=(40, 20))\n",
207
- " columns = 4\n",
208
- " rows = 2\n",
209
- " plt.subplots_adjust(hspace=0, wspace=0)\n",
210
- "\n",
211
- " for i in range(1, columns*rows +1):\n",
212
- " fig.add_subplot(rows, columns, i)\n",
213
- " plt.imshow(images[i-1])\n",
214
- " plt.gca().axes.get_yaxis().set_visible(False)\n",
215
- " plt.show()\n",
216
- " \n",
217
- "def stack_reconstructions(images):\n",
218
- " w, h = images[0].size[0], images[0].size[1]\n",
219
- " img = Image.new(\"RGB\", (len(images)*w, h))\n",
220
- " for i, img_ in enumerate(images):\n",
221
- " img.paste(img_, (i*w,0))\n",
222
- " return img"
223
- ]
224
- },
225
- {
226
- "cell_type": "code",
227
- "execution_count": null,
228
- "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
229
- "metadata": {},
230
- "outputs": [],
231
- "source": [
232
- "p_generate = jax.pmap(generate, \"batch\")\n",
233
- "p_get_images = jax.pmap(get_images, \"batch\")"
234
- ]
235
- },
236
- {
237
- "cell_type": "code",
238
- "execution_count": null,
239
- "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
240
- "metadata": {},
241
- "outputs": [],
242
- "source": [
243
- "bart_params = replicate(model.params)\n",
244
- "vqgan_params = replicate(vqgan.params)"
245
- ]
246
- },
247
- {
248
- "cell_type": "code",
249
- "execution_count": null,
250
- "id": "e8b268d8-6992-422a-8373-95651474ae70",
251
- "metadata": {},
252
- "outputs": [],
253
- "source": [
254
- "prompts = [\n",
255
- " \"man in blue jacket walking on pathway in between trees during daytime\",\n",
256
- " 'white snow covered mountain under blue sky during daytime',\n",
257
- " 'white snow covered mountain under blue sky during night',\n",
258
- " \"orange tabby cat on persons hand\",\n",
259
- " \"aerial view of beach during daytime\",\n",
260
- " \"chess pieces on chess board\",\n",
261
- " \"laptop on brown wooden table\",\n",
262
- " \"white bus on road near high rise buildings\",\n",
263
- "]\n",
264
- "\n",
265
- "\n",
266
- "prompt = [prompts[1]] * jax.device_count()\n",
267
- "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
268
- "inputs = shard(inputs)"
269
- ]
270
- },
271
- {
272
- "cell_type": "code",
273
- "execution_count": null,
274
- "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
275
- "metadata": {},
276
- "outputs": [],
277
- "source": [
278
- "%%time\n",
279
- "for i in range(8):\n",
280
- " key = random.randint(0, 1e7)\n",
281
- " rng = jax.random.PRNGKey(key)\n",
282
- " rngs = jax.random.split(rng, jax.local_device_count())\n",
283
- " indices = p_generate(inputs, rngs, bart_params).sequences\n",
284
- " indices = indices[:, :, 1:]\n",
285
- "\n",
286
- " images = p_get_images(indices, vqgan_params)\n",
287
- " images = np.squeeze(np.asarray(images), 1)\n",
288
- " imges = [custom_to_pil(image) for image in images]\n",
289
- "\n",
290
- " plt.figure(figsize=(40, 20))\n",
291
- " plt.imshow(stack_reconstructions(imges))"
292
- ]
293
- },
294
- {
295
- "cell_type": "markdown",
296
- "id": "b6e1060f",
297
- "metadata": {},
298
- "source": [
299
- "## CLIP Scoring"
300
- ]
301
- },
302
- {
303
- "cell_type": "code",
304
- "execution_count": null,
305
- "id": "c68724bc",
306
- "metadata": {},
307
- "outputs": [],
308
- "source": [
309
- "from transformers import CLIPProcessor, FlaxCLIPModel"
310
- ]
311
- },
312
- {
313
- "cell_type": "code",
314
- "execution_count": null,
315
- "id": "17158e5b",
316
- "metadata": {},
317
- "outputs": [],
318
- "source": [
319
- "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
320
- "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
321
- ]
322
- },
323
- {
324
- "cell_type": "code",
325
- "execution_count": null,
326
- "id": "f1b37b6d",
327
- "metadata": {},
328
- "outputs": [],
329
- "source": [
330
- "def hallucinate(prompt, num_images=64):\n",
331
- " prompt = [prompt] * jax.device_count()\n",
332
- " inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
333
- " inputs = shard(inputs)\n",
334
- "\n",
335
- " all_images = []\n",
336
- " for i in range(num_images // jax.device_count()):\n",
337
- " key = random.randint(0, 1e7)\n",
338
- " rng = jax.random.PRNGKey(key)\n",
339
- " rngs = jax.random.split(rng, jax.local_device_count())\n",
340
- " indices = p_generate(inputs, rngs, bart_params).sequences\n",
341
- " indices = indices[:, :, 1:]\n",
342
- "\n",
343
- " images = p_get_images(indices, vqgan_params)\n",
344
- " images = np.squeeze(np.asarray(images), 1)\n",
345
- " for image in images:\n",
346
- " all_images.append(custom_to_pil(image))\n",
347
- " return all_images"
348
- ]
349
- },
350
- {
351
- "cell_type": "code",
352
- "execution_count": null,
353
- "id": "831c715f",
354
- "metadata": {},
355
- "outputs": [],
356
- "source": [
357
- "def clip_top_k(prompt, images, k=8):\n",
358
- " inputs = processor(text=prompt, images=images, return_tensors=\"np\", padding=True)\n",
359
- " outputs = clip(**inputs)\n",
360
- " logits = outputs.logits_per_text\n",
361
- " scores = np.array(logits[0]).argsort()[-k:][::-1]\n",
362
- " return [images[score] for score in scores]"
363
- ]
364
- },
365
- {
366
- "cell_type": "code",
367
- "execution_count": null,
368
- "id": "00605e13",
369
- "metadata": {},
370
- "outputs": [],
371
- "source": [
372
- "prompt = \"white snow covered mountain under blue sky during daytime\"\n",
373
- "images = hallucinate(prompt)\n",
374
- "selected = clip_top_k(prompt, images, k=8)\n",
375
- "stack_reconstructions(selected)"
376
- ]
377
- },
378
- {
379
- "cell_type": "code",
380
- "execution_count": null,
381
- "id": "cc745da2",
382
- "metadata": {},
383
- "outputs": [],
384
- "source": [
385
- "prompt = \"aerial view of beach at night\"\n",
386
- "images = hallucinate(prompt)\n",
387
- "selected = clip_top_k(prompt, images, k=8)\n",
388
- "stack_reconstructions(selected)"
389
- ]
390
- },
391
- {
392
- "cell_type": "code",
393
- "execution_count": null,
394
- "id": "c9cc0b1d",
395
- "metadata": {},
396
- "outputs": [],
397
- "source": [
398
- "prompt = \"an armchair in the shape of an avocado\"\n",
399
- "images = hallucinate(prompt)\n",
400
- "selected = clip_top_k(prompt, images, k=8)\n",
401
- "stack_reconstructions(selected)"
402
- ]
403
- },
404
- {
405
- "cell_type": "code",
406
- "execution_count": null,
407
- "id": "574e9433",
408
- "metadata": {},
409
- "outputs": [],
410
- "source": [
411
- "prompt = \"young woman riding her bike into a forest\"\n",
412
- "images = hallucinate(prompt)\n",
413
- "selected = clip_top_k(prompt, images, k=8)\n",
414
- "stack_reconstructions(selected)"
415
- ]
416
- },
417
- {
418
- "cell_type": "markdown",
419
- "id": "4762c91e",
420
- "metadata": {},
421
- "source": [
422
- "`Forest` seems to dominate. Interesting cubist interpretation in the fourth image."
423
- ]
424
- },
425
- {
426
- "cell_type": "code",
427
- "execution_count": null,
428
- "id": "af30608a",
429
- "metadata": {},
430
- "outputs": [],
431
- "source": []
432
- }
433
- ],
434
- "metadata": {
435
- "kernelspec": {
436
- "display_name": "Python 3 (ipykernel)",
437
- "language": "python",
438
- "name": "python3"
439
- },
440
- "language_info": {
441
- "codemirror_mode": {
442
- "name": "ipython",
443
- "version": 3
444
- },
445
- "file_extension": ".py",
446
- "mimetype": "text/x-python",
447
- "name": "python",
448
- "nbconvert_exporter": "python",
449
- "pygments_lexer": "ipython3",
450
- "version": "3.8.5"
451
- }
452
- },
453
- "nbformat": 4,
454
- "nbformat_minor": 5
455
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/notebooks/model/data-pipeline.ipynb DELETED
@@ -1,385 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "bf8fb38a",
6
- "metadata": {},
7
- "source": [
8
- "# Data Pipeline"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": 1,
14
- "id": "9b83dcb9",
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "from dataclasses import dataclass, field\n",
19
- "from pathlib import Path\n",
20
- "\n",
21
- "import datasets\n",
22
- "from datasets import Dataset, load_dataset\n",
23
- "import numpy as np\n",
24
- "\n",
25
- "from transformers import BartTokenizer\n",
26
- "\n",
27
- "from tqdm import tqdm\n",
28
- "\n",
29
- "import jax\n",
30
- "import jax.numpy as jnp\n",
31
- "\n",
32
- "from flax.training.common_utils import shard"
33
- ]
34
- },
35
- {
36
- "cell_type": "markdown",
37
- "id": "a661a89e",
38
- "metadata": {},
39
- "source": [
40
- "File containing image paths, captions and VQGAN-encoded indices."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 2,
46
- "id": "0e84e889",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
51
- ]
52
- },
53
- {
54
- "cell_type": "markdown",
55
- "id": "7fdc640b",
56
- "metadata": {},
57
- "source": [
58
- "TODO: generate train/test splits if necessary."
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 3,
64
- "id": "cc6789b4",
65
- "metadata": {},
66
- "outputs": [
67
- {
68
- "name": "stderr",
69
- "output_type": "stream",
70
- "text": [
71
- "Using custom data configuration default-91833df78e844785\n",
72
- "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
73
- ]
74
- }
75
- ],
76
- "source": [
77
- "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": 4,
83
- "id": "f3ed4919",
84
- "metadata": {},
85
- "outputs": [
86
- {
87
- "data": {
88
- "text/plain": [
89
- "DatasetDict({\n",
90
- " train: Dataset({\n",
91
- " features: ['image_file', 'caption', 'encoding'],\n",
92
- " num_rows: 9999\n",
93
- " })\n",
94
- "})"
95
- ]
96
- },
97
- "execution_count": 4,
98
- "metadata": {},
99
- "output_type": "execute_result"
100
- }
101
- ],
102
- "source": [
103
- "dataset"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 5,
109
- "id": "a70c7354",
110
- "metadata": {},
111
- "outputs": [
112
- {
113
- "data": {
114
- "text/plain": [
115
- "Dataset({\n",
116
- " features: ['image_file', 'caption', 'encoding'],\n",
117
- " num_rows: 9999\n",
118
- "})"
119
- ]
120
- },
121
- "execution_count": 5,
122
- "metadata": {},
123
- "output_type": "execute_result"
124
- }
125
- ],
126
- "source": [
127
- "dataset = dataset[\"train\"]\n",
128
- "dataset"
129
- ]
130
- },
131
- {
132
- "cell_type": "markdown",
133
- "id": "a73454cf",
134
- "metadata": {},
135
- "source": [
136
- "We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
137
- ]
138
- },
139
- {
140
- "cell_type": "markdown",
141
- "id": "7c0fa992",
142
- "metadata": {},
143
- "source": [
144
- "## Preprocessing"
145
- ]
146
- },
147
- {
148
- "cell_type": "markdown",
149
- "id": "a0e36582",
150
- "metadata": {},
151
- "source": [
152
- "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
153
- ]
154
- },
155
- {
156
- "cell_type": "code",
157
- "execution_count": 6,
158
- "id": "d46f6ac5",
159
- "metadata": {},
160
- "outputs": [],
161
- "source": [
162
- "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
- "max_length = 256 # Read from data_args.max_source_length\n",
164
- "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
165
- "image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
166
- ]
167
- },
168
- {
169
- "cell_type": "code",
170
- "execution_count": 7,
171
- "id": "4cac6643",
172
- "metadata": {},
173
- "outputs": [],
174
- "source": [
175
- "def preprocess_function(examples):\n",
176
- " inputs = examples[\"caption\"]\n",
177
- "# inputs = [prefix + inp for inp in inputs] # Do we need this?\n",
178
- " model_inputs = tokenizer(\n",
179
- " inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
180
- " )\n",
181
- "\n",
182
- " model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
183
- "\n",
184
- " return model_inputs"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": 8,
190
- "id": "e6a4cb91",
191
- "metadata": {},
192
- "outputs": [],
193
- "source": [
194
- "num_workers = 48 # We have 96 processors in the TPU\n",
195
- "column_names = dataset.column_names\n",
196
- "input_dataset = dataset.map(preprocess_function,\n",
197
- " remove_columns=column_names,\n",
198
- " batched=True,\n",
199
- " num_proc=48\n",
200
- ")"
201
- ]
202
- },
203
- {
204
- "cell_type": "code",
205
- "execution_count": 9,
206
- "id": "a9b1b467",
207
- "metadata": {},
208
- "outputs": [],
209
- "source": [
210
- "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
211
- " \"\"\"\n",
212
- " Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
213
- " Shuffle batches if `shuffle` is `True`.\n",
214
- " \"\"\"\n",
215
- " steps_per_epoch = len(dataset) // batch_size\n",
216
- "\n",
217
- " if shuffle:\n",
218
- " batch_idx = jax.random.permutation(rng, len(dataset))\n",
219
- " else:\n",
220
- " batch_idx = jnp.arange(len(dataset))\n",
221
- "\n",
222
- " batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
223
- " batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
224
- "\n",
225
- " for idx in batch_idx:\n",
226
- " batch = dataset[idx] \n",
227
- " batch = {k: jnp.array(v) for k, v in batch.items()}\n",
228
- " batch = shard(batch)\n",
229
- " yield batch"
230
- ]
231
- },
232
- {
233
- "cell_type": "code",
234
- "execution_count": 10,
235
- "id": "0a628505",
236
- "metadata": {},
237
- "outputs": [
238
- {
239
- "name": "stderr",
240
- "output_type": "stream",
241
- "text": [
242
- "INFO:absl:Starting the local TPU driver.\n",
243
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
244
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
245
- ]
246
- }
247
- ],
248
- "source": [
249
- "rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
250
- "batch_size = 64 # Per device\n",
251
- "super_batch_size = batch_size * jax.device_count()"
252
- ]
253
- },
254
- {
255
- "cell_type": "code",
256
- "execution_count": 11,
257
- "id": "b3a5ce7d",
258
- "metadata": {},
259
- "outputs": [],
260
- "source": [
261
- "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
262
- ]
263
- },
264
- {
265
- "cell_type": "code",
266
- "execution_count": 12,
267
- "id": "67aa8f9c",
268
- "metadata": {},
269
- "outputs": [],
270
- "source": [
271
- "superbatch = next(iter(loader))"
272
- ]
273
- },
274
- {
275
- "cell_type": "code",
276
- "execution_count": 13,
277
- "id": "7cd99402",
278
- "metadata": {},
279
- "outputs": [
280
- {
281
- "data": {
282
- "text/plain": [
283
- "dict_keys(['attention_mask', 'input_ids', 'labels'])"
284
- ]
285
- },
286
- "execution_count": 13,
287
- "metadata": {},
288
- "output_type": "execute_result"
289
- }
290
- ],
291
- "source": [
292
- "superbatch.keys()"
293
- ]
294
- },
295
- {
296
- "cell_type": "code",
297
- "execution_count": 14,
298
- "id": "652a4a9e",
299
- "metadata": {},
300
- "outputs": [
301
- {
302
- "data": {
303
- "text/plain": [
304
- "8"
305
- ]
306
- },
307
- "execution_count": 14,
308
- "metadata": {},
309
- "output_type": "execute_result"
310
- }
311
- ],
312
- "source": [
313
- "len(superbatch[\"labels\"])"
314
- ]
315
- },
316
- {
317
- "cell_type": "code",
318
- "execution_count": 15,
319
- "id": "de7de4e8",
320
- "metadata": {},
321
- "outputs": [
322
- {
323
- "data": {
324
- "text/plain": [
325
- "(8, 64, 257)"
326
- ]
327
- },
328
- "execution_count": 15,
329
- "metadata": {},
330
- "output_type": "execute_result"
331
- }
332
- ],
333
- "source": [
334
- "superbatch[\"labels\"].shape"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "id": "6800153b",
340
- "metadata": {},
341
- "source": [
342
- "Any image sequence should begin with `image_bos`:"
343
- ]
344
- },
345
- {
346
- "cell_type": "code",
347
- "execution_count": 16,
348
- "id": "cfe23a71",
349
- "metadata": {},
350
- "outputs": [],
351
- "source": [
352
- "assert superbatch[\"labels\"][1][5][0].item() == image_bos"
353
- ]
354
- },
355
- {
356
- "cell_type": "code",
357
- "execution_count": null,
358
- "id": "0fb899b4",
359
- "metadata": {},
360
- "outputs": [],
361
- "source": []
362
- }
363
- ],
364
- "metadata": {
365
- "kernelspec": {
366
- "display_name": "Python 3 (ipykernel)",
367
- "language": "python",
368
- "name": "python3"
369
- },
370
- "language_info": {
371
- "codemirror_mode": {
372
- "name": "ipython",
373
- "version": 3
374
- },
375
- "file_extension": ".py",
376
- "mimetype": "text/x-python",
377
- "name": "python",
378
- "nbconvert_exporter": "python",
379
- "pygments_lexer": "ipython3",
380
- "version": "3.8.10"
381
- }
382
- },
383
- "nbformat": 4,
384
- "nbformat_minor": 5
385
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/predictions/wandb-examples-from-backend.py DELETED
@@ -1,52 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- from PIL import Image, ImageDraw, ImageFont
5
- import wandb
6
- import os
7
-
8
- from dalle_mini.backend import ServiceError, get_images_from_backend
9
- from dalle_mini.helpers import captioned_strip
10
-
11
- os.environ["WANDB_SILENT"] = "true"
12
- os.environ["WANDB_CONSOLE"] = "off"
13
-
14
- # set id to None so our latest images don't get overwritten
15
- id = None
16
- run = wandb.init(id=id,
17
- entity='wandb',
18
- project="hf-flax-dalle-mini",
19
- job_type="predictions",
20
- resume="allow"
21
- )
22
-
23
- def log_to_wandb(prompts):
24
- try:
25
- backend_url = os.environ["BACKEND_SERVER"]
26
-
27
- strips = []
28
- for prompt in prompts:
29
- print(f"Getting selections for: {prompt}")
30
- selected = get_images_from_backend(prompt, backend_url)
31
- strip = captioned_strip(selected, prompt)
32
- strips.append(wandb.Image(strip))
33
- wandb.log({"images": strips})
34
- except ServiceError as error:
35
- print(f"Service unavailable, status: {error.status_code}")
36
- except KeyError:
37
- print("Error: BACKEND_SERVER unset")
38
-
39
- prompts = [
40
- "white snow covered mountain under blue sky during daytime",
41
- "aerial view of beach during daytime",
42
- "aerial view of beach at night",
43
- "an armchair in the shape of an avocado",
44
- "a logo of an avocado armchair playing music",
45
- "young woman riding her bike trough a forest",
46
- "rice fields by the mediterranean coast",
47
- "white houses on the hill of a greek coastline",
48
- "illustration of a shark with a baby shark",
49
- "painting of an oniric forest glade surrounded by tall trees",
50
- ]
51
-
52
- log_to_wandb(prompts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/{seq2seq/requirements.txt → requirements.txt} RENAMED
@@ -10,6 +10,7 @@ jupyter
10
  wandb
11
  nltk
12
  optax
 
13
 
14
  # Inference
15
  ftfy
 
10
  wandb
11
  nltk
12
  optax
13
+ git+https://github.com/patil-suraj/vqgan-jax.git@610d842dd33c739325a944102ed33acc07692dd5
14
 
15
  # Inference
16
  ftfy
dev/vqgan/JAX_VQGAN_f16_16384_Reconstruction.ipynb ADDED
The diff for this file is too large to render. See raw diff