valhalla commited on
Commit
c879290
·
1 Parent(s): 4b5a542

remove .ipynb

Browse files
demo/.ipynb_checkpoints/tpu-demo-checkpoint.ipynb DELETED
@@ -1,391 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "# !pip install flax transformers\n",
11
- "# !git clone https://github.com/patil-suraj/vqgan-jax.git"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 305,
17
- "id": "41db7534-f589-4b63-9165-9c9799e1b06e",
18
- "metadata": {},
19
- "outputs": [
20
- {
21
- "name": "stdout",
22
- "output_type": "stream",
23
- "text": [
24
- "/home/surajpatil/vqgan-jax\n"
25
- ]
26
- },
27
- {
28
- "data": {
29
- "text/plain": [
30
- "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
31
- " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
32
- " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
33
- " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
34
- " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
35
- " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
36
- " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
37
- " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
38
- ]
39
- },
40
- "execution_count": 305,
41
- "metadata": {},
42
- "output_type": "execute_result"
43
- }
44
- ],
45
- "source": [
46
- "%cd ~/vqgan-jax\n",
47
- "\n",
48
- "import random\n",
49
- "\n",
50
- "\n",
51
- "import jax\n",
52
- "import flax.linen as nn\n",
53
- "from flax.training.common_utils import shard\n",
54
- "from flax.jax_utils import replicate, unreplicate\n",
55
- "\n",
56
- "from transformers.models.bart.modeling_flax_bart import *\n",
57
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
58
- "\n",
59
- "import io\n",
60
- "\n",
61
- "import requests\n",
62
- "from PIL import Image\n",
63
- "import numpy as np\n",
64
- "import matplotlib.pyplot as plt\n",
65
- "\n",
66
- "import torch\n",
67
- "import torchvision.transforms as T\n",
68
- "import torchvision.transforms.functional as TF\n",
69
- "from torchvision.transforms import InterpolationMode\n",
70
- "\n",
71
- "\n",
72
- "from modeling_flax_vqgan import VQModel\n",
73
- "\n",
74
- "jax.devices()"
75
- ]
76
- },
77
- {
78
- "cell_type": "code",
79
- "execution_count": 2,
80
- "id": "b6a3462a-9004-4121-b365-3ae3aaf94dd2",
81
- "metadata": {},
82
- "outputs": [],
83
- "source": [
84
- "# TODO: set those args in a config file\n",
85
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
86
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
87
- "BOS_TOKEN_ID = 16384\n",
88
- "BASE_MODEL = 'facebook/bart-large'"
89
- ]
90
- },
91
- {
92
- "cell_type": "code",
93
- "execution_count": 3,
94
- "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
95
- "metadata": {},
96
- "outputs": [],
97
- "source": [
98
- "class CustomFlaxBartModule(FlaxBartModule):\n",
99
- " def setup(self):\n",
100
- " # we keep shared to easily load pre-trained weights\n",
101
- " self.shared = nn.Embed(\n",
102
- " self.config.vocab_size,\n",
103
- " self.config.d_model,\n",
104
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
105
- " dtype=self.dtype,\n",
106
- " )\n",
107
- " # a separate embedding is used for the decoder\n",
108
- " self.decoder_embed = nn.Embed(\n",
109
- " OUTPUT_VOCAB_SIZE,\n",
110
- " self.config.d_model,\n",
111
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
112
- " dtype=self.dtype,\n",
113
- " )\n",
114
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
115
- "\n",
116
- " # the decoder has a different config\n",
117
- " decoder_config = BartConfig(self.config.to_dict())\n",
118
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
119
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
120
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
121
- "\n",
122
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
123
- " def setup(self):\n",
124
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
125
- " self.lm_head = nn.Dense(\n",
126
- " OUTPUT_VOCAB_SIZE,\n",
127
- " use_bias=False,\n",
128
- " dtype=self.dtype,\n",
129
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
130
- " )\n",
131
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
132
- "\n",
133
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
134
- " module_class = CustomFlaxBartForConditionalGenerationModule"
135
- ]
136
- },
137
- {
138
- "cell_type": "code",
139
- "execution_count": null,
140
- "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
141
- "metadata": {},
142
- "outputs": [],
143
- "source": [
144
- "import wandb\n",
145
- "run = wandb.init()\n",
146
- "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:v7', type='bart_model')\n",
147
- "artifact_dir = artifact.download()"
148
- ]
149
- },
150
- {
151
- "cell_type": "code",
152
- "execution_count": 164,
153
- "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
154
- "metadata": {},
155
- "outputs": [
156
- {
157
- "name": "stderr",
158
- "output_type": "stream",
159
- "text": [
160
- "/home/surajpatil/transformers/src/transformers/models/bart/configuration_bart.py:177: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
161
- " warnings.warn(\n"
162
- ]
163
- },
164
- {
165
- "data": {
166
- "text/plain": [
167
- "(1, 16385)"
168
- ]
169
- },
170
- "execution_count": 164,
171
- "metadata": {},
172
- "output_type": "execute_result"
173
- }
174
- ],
175
- "source": [
176
- "# create our model and initialize it randomly\n",
177
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
178
- "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
179
- "model.config.force_bos_token_to_be_generated = False\n",
180
- "model.config.forced_bos_token_id = None\n",
181
- "model.config.forced_eos_token_id = None\n",
182
- "\n",
183
- "# we verify that the shape has not been modified\n",
184
- "model.params['final_logits_bias'].shape"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": 6,
190
- "id": "8d5e0f14-2502-470e-9553-daee6748601f",
191
- "metadata": {},
192
- "outputs": [
193
- {
194
- "data": {
195
- "application/vnd.jupyter.widget-view+json": {
196
- "model_id": "9b979a72ab9e449387a89bf9b3012af5",
197
- "version_major": 2,
198
- "version_minor": 0
199
- },
200
- "text/plain": [
201
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…"
202
- ]
203
- },
204
- "metadata": {},
205
- "output_type": "display_data"
206
- },
207
- {
208
- "name": "stdout",
209
- "output_type": "stream",
210
- "text": [
211
- "\n"
212
- ]
213
- },
214
- {
215
- "data": {
216
- "application/vnd.jupyter.widget-view+json": {
217
- "model_id": "01730e0e9d02428ca9dad680f9fdda42",
218
- "version_major": 2,
219
- "version_minor": 0
220
- },
221
- "text/plain": [
222
- "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=304307206.0, style=ProgressStyle(descri…"
223
- ]
224
- },
225
- "metadata": {},
226
- "output_type": "display_data"
227
- },
228
- {
229
- "name": "stdout",
230
- "output_type": "stream",
231
- "text": [
232
- "\n",
233
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
234
- ]
235
- }
236
- ],
237
- "source": [
238
- "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
239
- ]
240
- },
241
- {
242
- "cell_type": "code",
243
- "execution_count": 295,
244
- "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
245
- "metadata": {},
246
- "outputs": [],
247
- "source": [
248
- "def custom_to_pil(x):\n",
249
- " x = np.clip(x, 0., 1.)\n",
250
- " x = (255*x).astype(np.uint8)\n",
251
- " x = Image.fromarray(x)\n",
252
- " if not x.mode == \"RGB\":\n",
253
- " x = x.convert(\"RGB\")\n",
254
- " return x\n",
255
- "\n",
256
- "def generate(input, rng, params):\n",
257
- " return model.generate(\n",
258
- " **input,\n",
259
- " max_length=257,\n",
260
- " num_beams=1,\n",
261
- " do_sample=True,\n",
262
- " prng_key=rng,\n",
263
- " eos_token_id=50000,\n",
264
- " pad_token_id=50000,\n",
265
- " params=params\n",
266
- " )\n",
267
- "\n",
268
- "def get_images(indices, params):\n",
269
- " return vqgan.decode_code(indices, params=params)\n",
270
- "\n",
271
- "\n",
272
- "def plot_images(images):\n",
273
- " fig = plt.figure(figsize=(40, 20))\n",
274
- " columns = 4\n",
275
- " rows = 2\n",
276
- " plt.subplots_adjust(hspace=0, wspace=0)\n",
277
- "\n",
278
- " for i in range(1, columns*rows +1):\n",
279
- " fig.add_subplot(rows, columns, i)\n",
280
- " plt.imshow(images[i-1])\n",
281
- " plt.gca().axes.get_yaxis().set_visible(False)\n",
282
- " plt.show()\n",
283
- " \n",
284
- "def stack_reconstructions(images):\n",
285
- " w, h = images[0].size[0], images[0].size[1]\n",
286
- " img = Image.new(\"RGB\", (len(images)*w, h))\n",
287
- " for i, img_ in enumerate(images):\n",
288
- " img.paste(img_, (i*w,0))\n",
289
- " return img"
290
- ]
291
- },
292
- {
293
- "cell_type": "code",
294
- "execution_count": 166,
295
- "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
296
- "metadata": {},
297
- "outputs": [],
298
- "source": [
299
- "p_generate = jax.pmap(generate, \"batch\")\n",
300
- "p_get_images = jax.pmap(get_images, \"batch\")"
301
- ]
302
- },
303
- {
304
- "cell_type": "code",
305
- "execution_count": null,
306
- "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
307
- "metadata": {},
308
- "outputs": [],
309
- "source": [
310
- "bart_params = replicate(model.params)\n",
311
- "vqgan_params = replicate(vqgan.params)"
312
- ]
313
- },
314
- {
315
- "cell_type": "code",
316
- "execution_count": 328,
317
- "id": "e8b268d8-6992-422a-8373-95651474ae70",
318
- "metadata": {},
319
- "outputs": [],
320
- "source": [
321
- "prompts = [\n",
322
- " \"man in blue jacket walking on pathway in between trees during daytime\",\n",
323
- " 'white snow covered mountain under blue sky during daytime',\n",
324
- " 'white snow covered mountain under blue sky during night',\n",
325
- " \"orange tabby cat on persons hand\",\n",
326
- " \"aerial view of beach during daytime\",\n",
327
- " \"chess pieces on chess board\",\n",
328
- " \"laptop on brown wooden table\",\n",
329
- " \"white bus on road near high rise buildings\",\n",
330
- "]\n",
331
- "\n",
332
- "\n",
333
- "prompt = [prompts[-1]] * 8\n",
334
- "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
335
- "inputs = shard(inputs)"
336
- ]
337
- },
338
- {
339
- "cell_type": "code",
340
- "execution_count": null,
341
- "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
342
- "metadata": {},
343
- "outputs": [],
344
- "source": [
345
- "%%time\n",
346
- "for i in range(8):\n",
347
- " key = random.randint(0, 1e7)\n",
348
- " rng = jax.random.PRNGKey(key)\n",
349
- " rngs = jax.random.split(rng, jax.local_device_count())\n",
350
- " indices = p_generate(inputs, rngs, bart_params).sequences\n",
351
- " indices = indices[:, :, 1:]\n",
352
- "\n",
353
- " images = p_get_images(indices, vqgan_params)\n",
354
- " images = np.squeeze(np.asarray(images), 1)\n",
355
- " imges = [custom_to_pil(image) for image in images]\n",
356
- "\n",
357
- " plt.figure(figsize=(40, 20))\n",
358
- " plt.imshow(stack_reconstructions(imges))"
359
- ]
360
- },
361
- {
362
- "cell_type": "code",
363
- "execution_count": null,
364
- "id": "681af54e-da10-4b8e-80d0-ebcbdf23f376",
365
- "metadata": {},
366
- "outputs": [],
367
- "source": []
368
- }
369
- ],
370
- "metadata": {
371
- "kernelspec": {
372
- "display_name": "Python 3",
373
- "language": "python",
374
- "name": "python3"
375
- },
376
- "language_info": {
377
- "codemirror_mode": {
378
- "name": "ipython",
379
- "version": 3
380
- },
381
- "file_extension": ".py",
382
- "mimetype": "text/x-python",
383
- "name": "python",
384
- "nbconvert_exporter": "python",
385
- "pygments_lexer": "ipython3",
386
- "version": "3.8.10"
387
- }
388
- },
389
- "nbformat": 4,
390
- "nbformat_minor": 5
391
- }