boris committed on
Commit
0aab987
·
1 Parent(s): 47f6891

chore: reduce size of notebooks

Browse files

Former-commit-id: 4b1870193012ec35af398b3864eb37a43adf1e97

dev/notebooks/demo/CustomBARTv4b_model-generate.ipynb CHANGED
@@ -1,566 +1,394 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
 
 
 
 
 
 
 
 
 
 
 
5
  "colab": {
6
- "name": "CustomBARTv4b-model-generate.ipynb",
7
- "provenance": [],
8
- "collapsed_sections": [],
9
- "machine_shape": "hm"
10
- },
11
- "kernelspec": {
12
- "name": "python3",
13
- "display_name": "Python 3"
14
  },
15
- "language_info": {
16
- "name": "python"
17
- },
18
- "accelerator": "TPU"
 
 
 
 
19
  },
20
- "cells": [
21
- {
22
- "cell_type": "markdown",
23
- "metadata": {
24
- "id": "ewer-Q-0w2xA"
25
- },
26
- "source": [
27
- "# Installation"
28
- ]
29
- },
30
- {
31
- "cell_type": "code",
32
- "metadata": {
33
- "colab": {
34
- "base_uri": "https://localhost:8080/"
35
- },
36
- "id": "NpsF9ipLLl2s",
37
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
38
- },
39
- "source": [
40
- "!pip install git+https://github.com/huggingface/transformers/\n",
41
- "!pip install git+https://github.com/google/flax"
42
- ],
43
- "execution_count": 1,
44
- "outputs": [
45
- {
46
- "output_type": "stream",
47
- "text": [
48
- "Collecting git+https://github.com/huggingface/transformers/\n",
49
- " Cloning https://github.com/huggingface/transformers/ to /tmp/pip-req-build-oxejx1op\n",
50
- " Running command git clone -q https://github.com/huggingface/transformers/ /tmp/pip-req-build-oxejx1op\n",
51
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
52
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
53
- " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
54
- "Requirement already satisfied (use --upgrade to upgrade): transformers==4.9.0.dev0 from git+https://github.com/huggingface/transformers/ in /usr/local/lib/python3.7/dist-packages\n",
55
- "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (1.19.5)\n",
56
- "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (20.9)\n",
57
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (5.4.1)\n",
58
- "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.45)\n",
59
- "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.6.0)\n",
60
- "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.41.1)\n",
61
- "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (3.0.12)\n",
62
- "Requirement already satisfied: huggingface-hub==0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.12)\n",
63
- "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.10.3)\n",
64
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2019.12.20)\n",
65
- "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2.23.0)\n",
66
- "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.9.0.dev0) (2.4.7)\n",
67
- "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.15.0)\n",
68
- "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.0.1)\n",
69
- "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (7.1.2)\n",
70
- "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.7.4.3)\n",
71
- "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.4.1)\n",
72
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2021.5.30)\n",
73
- "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (3.0.4)\n",
74
- "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (1.24.3)\n",
75
- "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2.10)\n",
76
- "Building wheels for collected packages: transformers\n",
77
- " Building wheel for transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
78
- " Created wheel for transformers: filename=transformers-4.9.0.dev0-cp37-none-any.whl size=2582229 sha256=249c593273ccca3027c6427d2c6fd749a89f21d722d628d97eb438a2cf3185a8\n",
79
- " Stored in directory: /tmp/pip-ephem-wheel-cache-l2rqt1b7/wheels/61/69/33/974fccec4d0ab5feee9fe83bd93e680d269a805be9ede5ec60\n",
80
- "Successfully built transformers\n",
81
- "Collecting git+https://github.com/google/flax\n",
82
- " Cloning https://github.com/google/flax to /tmp/pip-req-build-rt9g1_wx\n",
83
- " Running command git clone -q https://github.com/google/flax /tmp/pip-req-build-rt9g1_wx\n",
84
- "Requirement already satisfied (use --upgrade to upgrade): flax==0.3.4 from git+https://github.com/google/flax in /usr/local/lib/python3.7/dist-packages\n",
85
- "Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.19.5)\n",
86
- "Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.2.13)\n",
87
- "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (3.2.2)\n",
88
- "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.0.2)\n",
89
- "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.0.9)\n",
90
- "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (3.3.0)\n",
91
- "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (0.12.0)\n",
92
- "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.8.1)\n",
93
- "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (0.10.0)\n",
94
- "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.4.7)\n",
95
- "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (1.3.1)\n",
96
- "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.0.8)\n",
97
- "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.1.66+cuda110)\n",
98
- "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->flax==0.3.4) (1.15.0)\n",
99
- "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.1.6)\n",
100
- "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.11.1)\n",
101
- "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.12)\n",
102
- "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.4.1)\n",
103
- "Building wheels for collected packages: flax\n",
104
- " Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
105
- " Created wheel for flax: filename=flax-0.3.4-cp37-none-any.whl size=184692 sha256=503b27995f372afe33631e71572d5edc1fffd4d2e0a4cd206d291ad6b0e4c299\n",
106
- " Stored in directory: /tmp/pip-ephem-wheel-cache-g1pzxnv6/wheels/3d/26/f4/0ea6051d7352289d9e4f8178348452b35a9a97bde6035405a5\n",
107
- "Successfully built flax\n"
108
- ],
109
- "name": "stdout"
110
- }
111
- ]
112
- },
113
- {
114
- "cell_type": "code",
115
- "metadata": {
116
- "id": "M1wVkrpjU6zO"
117
- },
118
- "source": [
119
- "%load_ext autoreload\n",
120
- "%autoreload 2"
121
- ],
122
- "execution_count": 2,
123
- "outputs": []
124
- },
125
- {
126
- "cell_type": "markdown",
127
- "metadata": {
128
- "id": "t47CH1H_IOT8"
129
- },
130
- "source": [
131
- "# Custom BART Model"
132
- ]
133
- },
134
- {
135
- "cell_type": "code",
136
- "metadata": {
137
- "id": "9jQnM6S2vCpn"
138
- },
139
- "source": [
140
- "# TODO: set those args in a config file\n",
141
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
142
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
143
- "BOS_TOKEN_ID = 16384\n",
144
- "BASE_MODEL = 'facebook/bart-large'"
145
- ],
146
- "execution_count": 3,
147
- "outputs": []
148
- },
149
- {
150
- "cell_type": "code",
151
- "metadata": {
152
- "id": "_eEaJVxAKpV5"
153
- },
154
- "source": [
155
- "import jax\n",
156
- "import flax.linen as nn\n",
157
- "\n",
158
- "from transformers.models.bart.modeling_flax_bart import *\n",
159
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
160
- "\n",
161
- "class CustomFlaxBartModule(FlaxBartModule):\n",
162
- " def setup(self):\n",
163
- " # we keep shared to easily load pre-trained weights\n",
164
- " self.shared = nn.Embed(\n",
165
- " self.config.vocab_size,\n",
166
- " self.config.d_model,\n",
167
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
168
- " dtype=self.dtype,\n",
169
- " )\n",
170
- " # a separate embedding is used for the decoder\n",
171
- " self.decoder_embed = nn.Embed(\n",
172
- " OUTPUT_VOCAB_SIZE,\n",
173
- " self.config.d_model,\n",
174
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
175
- " dtype=self.dtype,\n",
176
- " )\n",
177
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
178
- "\n",
179
- " # the decoder has a different config\n",
180
- " decoder_config = BartConfig(self.config.to_dict())\n",
181
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
182
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
183
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
184
- "\n",
185
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
186
- " def setup(self):\n",
187
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
188
- " self.lm_head = nn.Dense(\n",
189
- " OUTPUT_VOCAB_SIZE,\n",
190
- " use_bias=False,\n",
191
- " dtype=self.dtype,\n",
192
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
193
- " )\n",
194
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
195
- "\n",
196
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
197
- " module_class = CustomFlaxBartForConditionalGenerationModule"
198
- ],
199
- "execution_count": 4,
200
- "outputs": []
201
- },
202
- {
203
- "cell_type": "code",
204
- "metadata": {
205
- "id": "S7CP9Td9m2ge",
206
- "colab": {
207
- "base_uri": "https://localhost:8080/"
208
- },
209
- "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
210
- },
211
- "source": [
212
- "# load pre-trained model for encoder weights\n",
213
- "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
214
- ],
215
- "execution_count": 5,
216
- "outputs": [
217
- {
218
- "output_type": "stream",
219
- "text": [
220
- "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
221
- ],
222
- "name": "stderr"
223
- }
224
- ]
225
- },
226
- {
227
- "cell_type": "code",
228
- "metadata": {
229
- "id": "6lmynR-poceH"
230
- },
231
- "source": [
232
- "# set up our new model config\n",
233
- "config = BartConfig.from_pretrained(BASE_MODEL)\n",
234
- "config.tie_word_embeddings = False\n",
235
- "config.decoder_start_token_id = BOS_TOKEN_ID\n",
236
- "config.bos_token_id = BOS_TOKEN_ID # should not be used\n",
237
- "config.pos_token_id = BOS_TOKEN_ID # should not be used\n",
238
- "#config.eos_token_id = None # prevents generation from stopping until we reach max_length"
239
- ],
240
- "execution_count": 6,
241
- "outputs": []
242
- },
243
- {
244
- "cell_type": "code",
245
- "metadata": {
246
- "id": "_6-XKK40oEfP"
247
- },
248
- "source": [
249
- "# create our model and initialize it randomly\n",
250
- "model = CustomFlaxBartForConditionalGeneration(config)"
251
- ],
252
- "execution_count": 7,
253
- "outputs": []
254
- },
255
- {
256
- "cell_type": "code",
257
- "metadata": {
258
- "id": "-r_hZestr-NR"
259
- },
260
- "source": [
261
- "# use pretrained weights\n",
262
- "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
263
- "model.params['model']['shared'] = base_model.params['model']['shared']"
264
- ],
265
- "execution_count": 8,
266
- "outputs": []
267
- },
268
- {
269
- "cell_type": "code",
270
- "metadata": {
271
- "id": "5NEX8f62sVjx"
272
- },
273
- "source": [
274
- "# no need for base_model anymore\n",
275
- "del base_model"
276
- ],
277
- "execution_count": 9,
278
- "outputs": []
279
- },
280
- {
281
- "cell_type": "code",
282
- "metadata": {
283
- "colab": {
284
- "base_uri": "https://localhost:8080/"
285
- },
286
- "id": "Jz032w73nHEf",
287
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
288
- },
289
- "source": [
290
- "# we verify that the shape has not been modified\n",
291
- "model.params['final_logits_bias'].shape"
292
- ],
293
- "execution_count": 10,
294
- "outputs": [
295
- {
296
- "output_type": "execute_result",
297
- "data": {
298
- "text/plain": [
299
- "(1, 16385)"
300
- ]
301
- },
302
- "metadata": {
303
- "tags": []
304
- },
305
- "execution_count": 10
306
- }
307
- ]
308
- },
309
- {
310
- "cell_type": "markdown",
311
- "metadata": {
312
- "id": "zLl24Ez5t7x1"
313
- },
314
- "source": [
315
- "## Inference"
316
- ]
317
- },
318
- {
319
- "cell_type": "code",
320
- "metadata": {
321
- "id": "XLLA2NK3uDQr"
322
- },
323
- "source": [
324
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
325
- ],
326
- "execution_count": 11,
327
- "outputs": []
328
- },
329
- {
330
- "cell_type": "code",
331
- "metadata": {
332
- "colab": {
333
- "base_uri": "https://localhost:8080/"
334
- },
335
- "id": "Ntow53I_t81D",
336
- "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
337
- },
338
- "source": [
339
- "text = \"My friends are cool but they eat too many carbs.\"\n",
340
- "inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
341
- "encoder_outputs = model.encode(**inputs)"
342
- ],
343
- "execution_count": 12,
344
- "outputs": [
345
- {
346
- "output_type": "stream",
347
- "text": [
348
- "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
349
- ],
350
- "name": "stderr"
351
- }
352
- ]
353
  },
354
- {
355
- "cell_type": "code",
356
- "metadata": {
357
- "colab": {
358
- "base_uri": "https://localhost:8080/"
359
- },
360
- "id": "vcRNJnJ_uJOJ",
361
- "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
362
- },
363
- "source": [
364
- "decoder_start_token_id = model.config.decoder_start_token_id\n",
365
- "decoder_start_token_id"
366
- ],
367
- "execution_count": 13,
368
- "outputs": [
369
- {
370
- "output_type": "execute_result",
371
- "data": {
372
- "text/plain": [
373
- "16384"
374
- ]
375
- },
376
- "metadata": {
377
- "tags": []
378
- },
379
- "execution_count": 13
380
- }
381
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  },
383
- {
384
- "cell_type": "code",
385
- "metadata": {
386
- "id": "6QWmEwL_uMld"
387
- },
388
- "source": [
389
- "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
390
- "outputs = model.decode(decoder_input_ids, encoder_outputs)"
391
- ],
392
- "execution_count": 14,
393
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  },
395
- {
396
- "cell_type": "code",
397
- "metadata": {
398
- "colab": {
399
- "base_uri": "https://localhost:8080/"
400
- },
401
- "id": "c_ys3yWBothF",
402
- "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
403
- },
404
- "source": [
405
- "outputs"
406
- ],
407
- "execution_count": 15,
408
- "outputs": [
409
- {
410
- "output_type": "execute_result",
411
- "data": {
412
- "text/plain": [
413
- "FlaxCausalLMOutputWithCrossAttentions([('logits',\n",
414
- " DeviceArray([[[ 0.5263986 , -2.0947676 , -0.18830685, ..., 0.7599884 ,\n",
415
- " 0.6746795 , -1.0411576 ]]], dtype=float32))])"
416
- ]
417
- },
418
- "metadata": {
419
- "tags": []
420
- },
421
- "execution_count": 15
422
- }
423
- ]
424
  },
425
- {
426
- "cell_type": "code",
427
- "metadata": {
428
- "colab": {
429
- "base_uri": "https://localhost:8080/"
430
- },
431
- "id": "O6s0wtB_uTC_",
432
- "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
433
- },
434
- "source": [
435
- "outputs.logits.shape"
436
- ],
437
- "execution_count": 16,
438
- "outputs": [
439
- {
440
- "output_type": "execute_result",
441
- "data": {
442
- "text/plain": [
443
- "(1, 1, 16385)"
444
- ]
445
- },
446
- "metadata": {
447
- "tags": []
448
- },
449
- "execution_count": 16
450
- }
451
- ]
452
  },
453
- {
454
- "cell_type": "code",
455
- "metadata": {
456
- "colab": {
457
- "base_uri": "https://localhost:8080/"
458
- },
459
- "id": "ELzemGP3uBzy",
460
- "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
461
- },
462
- "source": [
463
- "outputs.logits.argmax(axis=-1)"
464
- ],
465
- "execution_count": 17,
466
- "outputs": [
467
- {
468
- "output_type": "execute_result",
469
- "data": {
470
- "text/plain": [
471
- "DeviceArray([[12459]], dtype=int32)"
472
- ]
473
- },
474
- "metadata": {
475
- "tags": []
476
- },
477
- "execution_count": 17
478
- }
479
- ]
480
  },
481
- {
482
- "cell_type": "code",
483
- "metadata": {
484
- "colab": {
485
- "base_uri": "https://localhost:8080/"
486
- },
487
- "id": "fQjikkGEunpx",
488
- "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
489
- },
490
- "source": [
491
- "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
492
- ],
493
- "execution_count": 18,
494
- "outputs": [
495
- {
496
- "output_type": "execute_result",
497
- "data": {
498
- "text/plain": [
499
- "(16384, 2, 1)"
500
- ]
501
- },
502
- "metadata": {
503
- "tags": []
504
- },
505
- "execution_count": 18
506
- }
507
- ]
508
  },
509
- {
510
- "cell_type": "code",
511
- "metadata": {
512
- "id": "P32mJJSbrU1F"
513
- },
514
- "source": [
515
- "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
516
- ],
517
- "execution_count": 19,
518
- "outputs": []
 
 
 
 
519
  },
520
- {
521
- "cell_type": "code",
522
- "metadata": {
523
- "id": "C7cHbIHruELT"
524
- },
525
- "source": [
526
- "greedy_output = model.generate(input_ids_test, max_length=50)"
527
- ],
528
- "execution_count": 20,
529
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  },
531
- {
532
- "cell_type": "code",
533
- "metadata": {
534
- "colab": {
535
- "base_uri": "https://localhost:8080/"
536
- },
537
- "id": "jYugh9cOuwc9",
538
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
539
- },
540
- "source": [
541
- "greedy_output[0]"
542
- ],
543
- "execution_count": 21,
544
- "outputs": [
545
- {
546
- "output_type": "execute_result",
547
- "data": {
548
- "text/plain": [
549
- "DeviceArray([[16384, 0, 3570, 13405, 10186, 2392, 16362, 1869,\n",
550
- " 15772, 13546, 15772, 13546, 9348, 14791, 15772, 15772,\n",
551
- " 15772, 11272, 15772, 13546, 15772, 15772, 13546, 15772,\n",
552
- " 13546, 15772, 6642, 15772, 10776, 6431, 15772, 14567,\n",
553
- " 13406, 15772, 14567, 6235, 15772, 4909, 16160, 568,\n",
554
- " 4664, 6650, 8952, 9089, 15772, 5952, 7375, 10843,\n",
555
- " 8952, 2]], dtype=int32)"
556
- ]
557
- },
558
- "metadata": {
559
- "tags": []
560
- },
561
- "execution_count": 21
562
- }
563
- ]
564
- }
565
- ]
 
 
566
  }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "ewer-Q-0w2xA"
7
+ },
8
+ "source": [
9
+ "# Installation"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
  "colab": {
17
+ "base_uri": "https://localhost:8080/"
 
 
 
 
 
 
 
18
  },
19
+ "id": "NpsF9ipLLl2s",
20
+ "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
21
+ },
22
+ "outputs": [],
23
+ "source": [
24
+ "!pip install git+https://github.com/huggingface/transformers/\n",
25
+ "!pip install git+https://github.com/google/flax"
26
+ ]
27
  },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {
32
+ "id": "M1wVkrpjU6zO"
33
+ },
34
+ "outputs": [],
35
+ "source": [
36
+ "%load_ext autoreload\n",
37
+ "%autoreload 2"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {
43
+ "id": "t47CH1H_IOT8"
44
+ },
45
+ "source": [
46
+ "# Custom BART Model"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {
53
+ "id": "9jQnM6S2vCpn"
54
+ },
55
+ "outputs": [],
56
+ "source": [
57
+ "# TODO: set those args in a config file\n",
58
+ "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
59
+ "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
60
+ "BOS_TOKEN_ID = 16384\n",
61
+ "BASE_MODEL = 'facebook/bart-large'"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {
68
+ "id": "_eEaJVxAKpV5"
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "import jax\n",
73
+ "import flax.linen as nn\n",
74
+ "\n",
75
+ "from transformers.models.bart.modeling_flax_bart import *\n",
76
+ "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
77
+ "\n",
78
+ "class CustomFlaxBartModule(FlaxBartModule):\n",
79
+ " def setup(self):\n",
80
+ " # we keep shared to easily load pre-trained weights\n",
81
+ " self.shared = nn.Embed(\n",
82
+ " self.config.vocab_size,\n",
83
+ " self.config.d_model,\n",
84
+ " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
85
+ " dtype=self.dtype,\n",
86
+ " )\n",
87
+ " # a separate embedding is used for the decoder\n",
88
+ " self.decoder_embed = nn.Embed(\n",
89
+ " OUTPUT_VOCAB_SIZE,\n",
90
+ " self.config.d_model,\n",
91
+ " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
92
+ " dtype=self.dtype,\n",
93
+ " )\n",
94
+ " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
95
+ "\n",
96
+ " # the decoder has a different config\n",
97
+ " decoder_config = BartConfig(self.config.to_dict())\n",
98
+ " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
99
+ " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
100
+ " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
101
+ "\n",
102
+ "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
103
+ " def setup(self):\n",
104
+ " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
105
+ " self.lm_head = nn.Dense(\n",
106
+ " OUTPUT_VOCAB_SIZE,\n",
107
+ " use_bias=False,\n",
108
+ " dtype=self.dtype,\n",
109
+ " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
110
+ " )\n",
111
+ " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
112
+ "\n",
113
+ "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
114
+ " module_class = CustomFlaxBartForConditionalGenerationModule"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {
121
+ "colab": {
122
+ "base_uri": "https://localhost:8080/"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  },
124
+ "id": "S7CP9Td9m2ge",
125
+ "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
126
+ },
127
+ "outputs": [],
128
+ "source": [
129
+ "# load pre-trained model for encoder weights\n",
130
+ "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {
137
+ "id": "6lmynR-poceH"
138
+ },
139
+ "outputs": [],
140
+ "source": [
141
+ "# set up our new model config\n",
142
+ "config = BartConfig.from_pretrained(BASE_MODEL)\n",
143
+ "config.tie_word_embeddings = False\n",
144
+ "config.decoder_start_token_id = BOS_TOKEN_ID\n",
145
+ "config.bos_token_id = BOS_TOKEN_ID # should not be used\n",
146
+ "config.pos_token_id = BOS_TOKEN_ID # should not be used\n",
147
+ "#config.eos_token_id = None # prevents generation from stopping until we reach max_length"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "metadata": {
154
+ "id": "_6-XKK40oEfP"
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "# create our model and initialize it randomly\n",
159
+ "model = CustomFlaxBartForConditionalGeneration(config)"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "id": "-r_hZestr-NR"
167
+ },
168
+ "outputs": [],
169
+ "source": [
170
+ "# use pretrained weights\n",
171
+ "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
172
+ "model.params['model']['shared'] = base_model.params['model']['shared']"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "metadata": {
179
+ "id": "5NEX8f62sVjx"
180
+ },
181
+ "outputs": [],
182
+ "source": [
183
+ "# no need for base_model anymore\n",
184
+ "del base_model"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "metadata": {
191
+ "colab": {
192
+ "base_uri": "https://localhost:8080/"
193
  },
194
+ "id": "Jz032w73nHEf",
195
+ "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
196
+ },
197
+ "outputs": [],
198
+ "source": [
199
+ "# we verify that the shape has not been modified\n",
200
+ "model.params['final_logits_bias'].shape"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "metadata": {
206
+ "id": "zLl24Ez5t7x1"
207
+ },
208
+ "source": [
209
+ "## Inference"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "metadata": {
216
+ "id": "XLLA2NK3uDQr"
217
+ },
218
+ "outputs": [],
219
+ "source": [
220
+ "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {
227
+ "colab": {
228
+ "base_uri": "https://localhost:8080/"
229
  },
230
+ "id": "Ntow53I_t81D",
231
+ "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
232
+ },
233
+ "outputs": [],
234
+ "source": [
235
+ "text = \"My friends are cool but they eat too many carbs.\"\n",
236
+ "inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
237
+ "encoder_outputs = model.encode(**inputs)"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "metadata": {
244
+ "colab": {
245
+ "base_uri": "https://localhost:8080/"
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  },
247
+ "id": "vcRNJnJ_uJOJ",
248
+ "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
249
+ },
250
+ "outputs": [],
251
+ "source": [
252
+ "decoder_start_token_id = model.config.decoder_start_token_id\n",
253
+ "decoder_start_token_id"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {
260
+ "id": "6QWmEwL_uMld"
261
+ },
262
+ "outputs": [],
263
+ "source": [
264
+ "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
265
+ "outputs = model.decode(decoder_input_ids, encoder_outputs)"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {
272
+ "colab": {
273
+ "base_uri": "https://localhost:8080/"
274
  },
275
+ "id": "c_ys3yWBothF",
276
+ "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
277
+ },
278
+ "outputs": [],
279
+ "source": [
280
+ "outputs"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "metadata": {
287
+ "colab": {
288
+ "base_uri": "https://localhost:8080/"
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  },
290
+ "id": "O6s0wtB_uTC_",
291
+ "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
292
+ },
293
+ "outputs": [],
294
+ "source": [
295
+ "outputs.logits.shape"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "metadata": {
302
+ "colab": {
303
+ "base_uri": "https://localhost:8080/"
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  },
305
+ "id": "ELzemGP3uBzy",
306
+ "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
307
+ },
308
+ "outputs": [],
309
+ "source": [
310
+ "outputs.logits.argmax(axis=-1)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": null,
316
+ "metadata": {
317
+ "colab": {
318
+ "base_uri": "https://localhost:8080/"
319
  },
320
+ "id": "fQjikkGEunpx",
321
+ "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
322
+ },
323
+ "outputs": [],
324
+ "source": [
325
+ "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "metadata": {
332
+ "id": "P32mJJSbrU1F"
333
+ },
334
+ "outputs": [],
335
+ "source": [
336
+ "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "metadata": {
343
+ "id": "C7cHbIHruELT"
344
+ },
345
+ "outputs": [],
346
+ "source": [
347
+ "greedy_output = model.generate(input_ids_test, max_length=50)"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": null,
353
+ "metadata": {
354
+ "colab": {
355
+ "base_uri": "https://localhost:8080/"
356
  },
357
+ "id": "jYugh9cOuwc9",
358
+ "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
359
+ },
360
+ "outputs": [],
361
+ "source": [
362
+ "greedy_output[0]"
363
+ ]
364
+ }
365
+ ],
366
+ "metadata": {
367
+ "accelerator": "TPU",
368
+ "colab": {
369
+ "collapsed_sections": [],
370
+ "machine_shape": "hm",
371
+ "name": "CustomBARTv4b-model-generate.ipynb",
372
+ "provenance": []
373
+ },
374
+ "kernelspec": {
375
+ "display_name": "Python 3 (ipykernel)",
376
+ "language": "python",
377
+ "name": "python3"
378
+ },
379
+ "language_info": {
380
+ "codemirror_mode": {
381
+ "name": "ipython",
382
+ "version": 3
383
+ },
384
+ "file_extension": ".py",
385
+ "mimetype": "text/x-python",
386
+ "name": "python",
387
+ "nbconvert_exporter": "python",
388
+ "pygments_lexer": "ipython3",
389
+ "version": "3.8.5"
390
+ }
391
+ },
392
+ "nbformat": 4,
393
+ "nbformat_minor": 4
394
  }
dev/notebooks/demo/demo_notebook.ipynb CHANGED
@@ -27,7 +27,7 @@
27
  },
28
  {
29
  "cell_type": "code",
30
- "execution_count": 1,
31
  "metadata": {
32
  "id": "M1wVkrpjU6zO"
33
  },
@@ -39,17 +39,9 @@
39
  },
40
  {
41
  "cell_type": "code",
42
- "execution_count": 2,
43
  "metadata": {},
44
- "outputs": [
45
- {
46
- "name": "stdout",
47
- "output_type": "stream",
48
- "text": [
49
- "/home/tmabraham/vqgan-jax\n"
50
- ]
51
- }
52
- ],
53
  "source": [
54
  "%cd ../../vqgan-jax"
55
  ]
@@ -65,7 +57,7 @@
65
  },
66
  {
67
  "cell_type": "code",
68
- "execution_count": 3,
69
  "metadata": {
70
  "id": "9jQnM6S2vCpn"
71
  },
@@ -80,7 +72,7 @@
80
  },
81
  {
82
  "cell_type": "code",
83
- "execution_count": 4,
84
  "metadata": {
85
  "id": "_eEaJVxAKpV5"
86
  },
@@ -133,44 +125,11 @@
133
  },
134
  {
135
  "cell_type": "code",
136
- "execution_count": 5,
137
  "metadata": {
138
  "scrolled": true
139
  },
140
- "outputs": [
141
- {
142
- "name": "stderr",
143
- "output_type": "stream",
144
- "text": [
145
- "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtmabraham\u001b[0m (use `wandb login --relogin` to force relogin)\n"
146
- ]
147
- },
148
- {
149
- "data": {
150
- "text/html": [
151
- "\n",
152
- " Tracking run with wandb version 0.10.33<br/>\n",
153
- " Syncing run <strong style=\"color:#cdcd00\">rare-night-7</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
154
- " Project page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax</a><br/>\n",
155
- " Run page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax/runs/qzxavce8\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax/runs/qzxavce8</a><br/>\n",
156
- " Run data is saved locally in <code>/home/tmabraham/vqgan-jax/wandb/run-20210715_075019-qzxavce8</code><br/><br/>\n",
157
- " "
158
- ],
159
- "text/plain": [
160
- "<IPython.core.display.HTML object>"
161
- ]
162
- },
163
- "metadata": {},
164
- "output_type": "display_data"
165
- },
166
- {
167
- "name": "stderr",
168
- "output_type": "stream",
169
- "text": [
170
- "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-1ef8yxby:latest, 1674.97MB. 2 files... Done. 0:0:0\n"
171
- ]
172
- }
173
- ],
174
  "source": [
175
  "import wandb\n",
176
  "run = wandb.init()\n",
@@ -180,24 +139,12 @@
180
  },
181
  {
182
  "cell_type": "code",
183
- "execution_count": 6,
184
  "metadata": {
185
  "id": "_6-XKK40oEfP",
186
  "scrolled": true
187
  },
188
- "outputs": [
189
- {
190
- "name": "stderr",
191
- "output_type": "stream",
192
- "text": [
193
- "/home/tmabraham/dalle-mini/src/transformers/src/transformers/models/bart/configuration_bart.py:180: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
194
- " warnings.warn(\n",
195
- "INFO:absl:Starting the local TPU driver.\n",
196
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
197
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
198
- ]
199
- }
200
- ],
201
  "source": [
202
  "# load our model from the pretrained checkpoint in artifact_dir\n",
203
  "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
@@ -205,7 +152,7 @@
205
  },
206
  {
207
  "cell_type": "code",
208
- "execution_count": 7,
209
  "metadata": {},
210
  "outputs": [],
211
  "source": [
@@ -214,7 +161,7 @@
214
  },
215
  {
216
  "cell_type": "code",
217
- "execution_count": 8,
218
  "metadata": {
219
  "colab": {
220
  "base_uri": "https://localhost:8080/"
@@ -222,18 +169,7 @@
222
  "id": "Jz032w73nHEf",
223
  "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
224
  },
225
- "outputs": [
226
- {
227
- "data": {
228
- "text/plain": [
229
- "(1, 16385)"
230
- ]
231
- },
232
- "execution_count": 8,
233
- "metadata": {},
234
- "output_type": "execute_result"
235
- }
236
- ],
237
  "source": [
238
  "# we verify that the shape has not been modified\n",
239
  "model.params['final_logits_bias'].shape"
@@ -250,7 +186,7 @@
250
  },
251
  {
252
  "cell_type": "code",
253
- "execution_count": 9,
254
  "metadata": {
255
  "id": "XLLA2NK3uDQr"
256
  },
@@ -261,7 +197,7 @@
261
  },
262
  {
263
  "cell_type": "code",
264
- "execution_count": 10,
265
  "metadata": {},
266
  "outputs": [],
267
  "source": [
@@ -270,7 +206,7 @@
270
  },
271
  {
272
  "cell_type": "code",
273
- "execution_count": 11,
274
  "metadata": {
275
  "id": "P32mJJSbrU1F"
276
  },
@@ -281,49 +217,16 @@
281
  },
282
  {
283
  "cell_type": "code",
284
- "execution_count": 12,
285
  "metadata": {},
286
- "outputs": [
287
- {
288
- "data": {
289
- "text/plain": [
290
- "{'input_ids': DeviceArray([[ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
291
- " 2],\n",
292
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
293
- " 2],\n",
294
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
295
- " 2],\n",
296
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
297
- " 2],\n",
298
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
299
- " 2],\n",
300
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
301
- " 2],\n",
302
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
303
- " 2],\n",
304
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
305
- " 2]], dtype=int32), 'attention_mask': DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
306
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
307
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
308
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
309
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
310
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
311
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
312
- " [1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}"
313
- ]
314
- },
315
- "execution_count": 12,
316
- "metadata": {},
317
- "output_type": "execute_result"
318
- }
319
- ],
320
  "source": [
321
  "input_ids_test"
322
  ]
323
  },
324
  {
325
  "cell_type": "code",
326
- "execution_count": 13,
327
  "metadata": {
328
  "id": "C7cHbIHruELT"
329
  },
@@ -334,27 +237,16 @@
334
  },
335
  {
336
  "cell_type": "code",
337
- "execution_count": 14,
338
  "metadata": {},
339
- "outputs": [
340
- {
341
- "data": {
342
- "text/plain": [
343
- "(8, 257)"
344
- ]
345
- },
346
- "execution_count": 14,
347
- "metadata": {},
348
- "output_type": "execute_result"
349
- }
350
- ],
351
  "source": [
352
  "greedy_output[0].shape"
353
  ]
354
  },
355
  {
356
  "cell_type": "code",
357
- "execution_count": 15,
358
  "metadata": {
359
  "colab": {
360
  "base_uri": "https://localhost:8080/"
@@ -362,76 +254,16 @@
362
  "id": "jYugh9cOuwc9",
363
  "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
364
  },
365
- "outputs": [
366
- {
367
- "data": {
368
- "text/plain": [
369
- "DeviceArray([[16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
370
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
371
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
372
- " ...,\n",
373
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
374
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
375
- " [16384, 10042, 10042, ..., 10042, 10042, 9570]], dtype=int32)"
376
- ]
377
- },
378
- "execution_count": 15,
379
- "metadata": {},
380
- "output_type": "execute_result"
381
- }
382
- ],
383
  "source": [
384
  "greedy_output[0]"
385
  ]
386
  },
387
  {
388
  "cell_type": "code",
389
- "execution_count": 16,
390
  "metadata": {},
391
- "outputs": [
392
- {
393
- "data": {
394
- "text/plain": [
395
- "DeviceArray([16384, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
396
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
397
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
398
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
399
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
400
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
401
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
402
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
403
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
404
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
405
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
406
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
407
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
408
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
409
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
410
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
411
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
412
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
413
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
414
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
415
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
416
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
417
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
418
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
419
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
420
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
421
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
422
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
423
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
424
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
425
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
426
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
427
- " 9570], dtype=int32)"
428
- ]
429
- },
430
- "execution_count": 16,
431
- "metadata": {},
432
- "output_type": "execute_result"
433
- }
434
- ],
435
  "source": [
436
  "greedy_output[0][0]"
437
  ]
@@ -445,7 +277,7 @@
445
  },
446
  {
447
  "cell_type": "code",
448
- "execution_count": 17,
449
  "metadata": {},
450
  "outputs": [],
451
  "source": [
@@ -463,7 +295,7 @@
463
  },
464
  {
465
  "cell_type": "code",
466
- "execution_count": 18,
467
  "metadata": {},
468
  "outputs": [],
469
  "source": [
@@ -472,7 +304,7 @@
472
  },
473
  {
474
  "cell_type": "code",
475
- "execution_count": 19,
476
  "metadata": {},
477
  "outputs": [],
478
  "source": [
@@ -487,7 +319,7 @@
487
  },
488
  {
489
  "cell_type": "code",
490
- "execution_count": 20,
491
  "metadata": {
492
  "colab": {
493
  "base_uri": "https://localhost:8080/"
@@ -496,22 +328,14 @@
496
  "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
497
  "scrolled": true
498
  },
499
- "outputs": [
500
- {
501
- "name": "stdout",
502
- "output_type": "stream",
503
- "text": [
504
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
505
- ]
506
- }
507
- ],
508
  "source": [
509
  "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
510
  ]
511
  },
512
  {
513
  "cell_type": "code",
514
- "execution_count": 21,
515
  "metadata": {},
516
  "outputs": [],
517
  "source": [
@@ -524,29 +348,9 @@
524
  },
525
  {
526
  "cell_type": "code",
527
- "execution_count": 22,
528
  "metadata": {},
529
- "outputs": [
530
- {
531
- "name": "stdout",
532
- "output_type": "stream",
533
- "text": [
534
- "(1, 256)\n",
535
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
536
- ]
537
- },
538
- {
539
- "data": {
540
- "image/png": "\n",
541
- "text/plain": [
542
- "<PIL.Image.Image image mode=RGB size=256x256 at 0x7FA20677A400>"
543
- ]
544
- },
545
- "execution_count": 22,
546
- "metadata": {},
547
- "output_type": "execute_result"
548
- }
549
- ],
550
  "source": [
551
  "custom_to_pil(np.asarray(get_images(jnp.expand_dims(greedy_output[0][0],0), model)[0]))"
552
  ]
@@ -561,7 +365,7 @@
561
  "provenance": []
562
  },
563
  "kernelspec": {
564
- "display_name": "Python 3",
565
  "language": "python",
566
  "name": "python3"
567
  },
@@ -575,9 +379,9 @@
575
  "name": "python",
576
  "nbconvert_exporter": "python",
577
  "pygments_lexer": "ipython3",
578
- "version": "3.8.8"
579
  }
580
  },
581
  "nbformat": 4,
582
- "nbformat_minor": 1
583
  }
 
27
  },
28
  {
29
  "cell_type": "code",
30
+ "execution_count": null,
31
  "metadata": {
32
  "id": "M1wVkrpjU6zO"
33
  },
 
39
  },
40
  {
41
  "cell_type": "code",
42
+ "execution_count": null,
43
  "metadata": {},
44
+ "outputs": [],
 
 
 
 
 
 
 
 
45
  "source": [
46
  "%cd ../../vqgan-jax"
47
  ]
 
57
  },
58
  {
59
  "cell_type": "code",
60
+ "execution_count": null,
61
  "metadata": {
62
  "id": "9jQnM6S2vCpn"
63
  },
 
72
  },
73
  {
74
  "cell_type": "code",
75
+ "execution_count": null,
76
  "metadata": {
77
  "id": "_eEaJVxAKpV5"
78
  },
 
125
  },
126
  {
127
  "cell_type": "code",
128
+ "execution_count": null,
129
  "metadata": {
130
  "scrolled": true
131
  },
132
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  "source": [
134
  "import wandb\n",
135
  "run = wandb.init()\n",
 
139
  },
140
  {
141
  "cell_type": "code",
142
+ "execution_count": null,
143
  "metadata": {
144
  "id": "_6-XKK40oEfP",
145
  "scrolled": true
146
  },
147
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
148
  "source": [
149
  "# load our model from the pretrained checkpoint in artifact_dir\n",
150
  "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
 
152
  },
153
  {
154
  "cell_type": "code",
155
+ "execution_count": null,
156
  "metadata": {},
157
  "outputs": [],
158
  "source": [
 
161
  },
162
  {
163
  "cell_type": "code",
164
+ "execution_count": null,
165
  "metadata": {
166
  "colab": {
167
  "base_uri": "https://localhost:8080/"
 
169
  "id": "Jz032w73nHEf",
170
  "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
171
  },
172
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
173
  "source": [
174
  "# we verify that the shape has not been modified\n",
175
  "model.params['final_logits_bias'].shape"
 
186
  },
187
  {
188
  "cell_type": "code",
189
+ "execution_count": null,
190
  "metadata": {
191
  "id": "XLLA2NK3uDQr"
192
  },
 
197
  },
198
  {
199
  "cell_type": "code",
200
+ "execution_count": null,
201
  "metadata": {},
202
  "outputs": [],
203
  "source": [
 
206
  },
207
  {
208
  "cell_type": "code",
209
+ "execution_count": null,
210
  "metadata": {
211
  "id": "P32mJJSbrU1F"
212
  },
 
217
  },
218
  {
219
  "cell_type": "code",
220
+ "execution_count": null,
221
  "metadata": {},
222
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  "source": [
224
  "input_ids_test"
225
  ]
226
  },
227
  {
228
  "cell_type": "code",
229
+ "execution_count": null,
230
  "metadata": {
231
  "id": "C7cHbIHruELT"
232
  },
 
237
  },
238
  {
239
  "cell_type": "code",
240
+ "execution_count": null,
241
  "metadata": {},
242
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
243
  "source": [
244
  "greedy_output[0].shape"
245
  ]
246
  },
247
  {
248
  "cell_type": "code",
249
+ "execution_count": null,
250
  "metadata": {
251
  "colab": {
252
  "base_uri": "https://localhost:8080/"
 
254
  "id": "jYugh9cOuwc9",
255
  "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
256
  },
257
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  "source": [
259
  "greedy_output[0]"
260
  ]
261
  },
262
  {
263
  "cell_type": "code",
264
+ "execution_count": null,
265
  "metadata": {},
266
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  "source": [
268
  "greedy_output[0][0]"
269
  ]
 
277
  },
278
  {
279
  "cell_type": "code",
280
+ "execution_count": null,
281
  "metadata": {},
282
  "outputs": [],
283
  "source": [
 
295
  },
296
  {
297
  "cell_type": "code",
298
+ "execution_count": null,
299
  "metadata": {},
300
  "outputs": [],
301
  "source": [
 
304
  },
305
  {
306
  "cell_type": "code",
307
+ "execution_count": null,
308
  "metadata": {},
309
  "outputs": [],
310
  "source": [
 
319
  },
320
  {
321
  "cell_type": "code",
322
+ "execution_count": null,
323
  "metadata": {
324
  "colab": {
325
  "base_uri": "https://localhost:8080/"
 
328
  "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
329
  "scrolled": true
330
  },
331
+ "outputs": [],
 
 
 
 
 
 
 
 
332
  "source": [
333
  "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
334
  ]
335
  },
336
  {
337
  "cell_type": "code",
338
+ "execution_count": null,
339
  "metadata": {},
340
  "outputs": [],
341
  "source": [
 
348
  },
349
  {
350
  "cell_type": "code",
351
+ "execution_count": null,
352
  "metadata": {},
353
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  "source": [
355
  "custom_to_pil(np.asarray(get_images(jnp.expand_dims(greedy_output[0][0],0), model)[0]))"
356
  ]
 
365
  "provenance": []
366
  },
367
  "kernelspec": {
368
+ "display_name": "Python 3 (ipykernel)",
369
  "language": "python",
370
  "name": "python3"
371
  },
 
379
  "name": "python",
380
  "nbconvert_exporter": "python",
381
  "pygments_lexer": "ipython3",
382
+ "version": "3.8.5"
383
  }
384
  },
385
  "nbformat": 4,
386
+ "nbformat_minor": 4
387
  }
dev/notebooks/demo/tpu-demo.ipynb CHANGED
The diff for this file is too large to render. See raw diff