Pedro Cuenca commited on
Commit
82fad8c
·
1 Parent(s): b4dfea0

Notebook to encode splitted YFCC100M files.

Browse files

File paths need to be updated.

Splits can be created using a command like:

```
mkdir metadata_splitted
cd metadata_splitted
split -l 500000 --numeric-suffixes ../metadata_YFCC100M.jsonl metadata_split_
```

Encoded files will be saved to the directory specified by
`yfcc100m_output`, and their names will be the same as the source
splits.

encoding/vqgan-jax-encoding-yfcc100m-splitted.ipynb ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# vqgan-jax-encoding-yfcc100m"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "747733a4",
14
+ "metadata": {},
15
+ "source": [
16
+ "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
17
+ "\n",
18
+ "This dataset was prepared by @borisdayma in Json lines format."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "3b59489e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import io\n",
29
+ "\n",
30
+ "import requests\n",
31
+ "from PIL import Image\n",
32
+ "import numpy as np\n",
33
+ "from tqdm import tqdm\n",
34
+ "\n",
35
+ "import torch\n",
36
+ "import torchvision.transforms as T\n",
37
+ "import torchvision.transforms.functional as TF\n",
38
+ "from torchvision.transforms import InterpolationMode\n",
39
+ "from torch.utils.data import Dataset, DataLoader\n",
40
+ "from torchvision.datasets.folder import default_loader\n",
41
+ "\n",
42
+ "import jax\n",
43
+ "from jax import pmap"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "511c3b9e",
49
+ "metadata": {},
50
+ "source": [
51
+ "## VQGAN-JAX model"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "markdown",
56
+ "id": "bb408f6c",
57
+ "metadata": {},
58
+ "source": [
59
+ "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 2,
65
+ "id": "2ca50dc7",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "id": "7b60da9a",
75
+ "metadata": {},
76
+ "source": [
77
+ "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 4,
83
+ "id": "29ce8b15",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "id": "c7c4c1e6",
93
+ "metadata": {},
94
+ "source": [
95
+ "## Dataset"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "id": "fd4c608e",
101
+ "metadata": {},
102
+ "source": [
103
+ "I splitted the files to do the process iteratively. Pandas struggles with memory and `datasets` has problems when filtering files, as described [in this issue](https://github.com/huggingface/datasets/issues/2644)."
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 5,
109
+ "id": "6c058636",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "import pandas as pd\n",
114
+ "from pathlib import Path"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 6,
120
+ "id": "81b19eca",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "yfcc100m = Path('/sddata/dalle-mini/YFCC100M_OpenAI_subset')\n",
125
+ "# Images are 'sharded' from the following directory\n",
126
+ "yfcc100m_images = yfcc100m/'data'/'images'\n",
127
+ "yfcc100m_metadata_splits = yfcc100m/'metadata_splitted'\n",
128
+ "yfcc100m_output = yfcc100m/'metadata_encoded'"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 7,
134
+ "id": "40873de9",
135
+ "metadata": {},
136
+ "outputs": [
137
+ {
138
+ "data": {
139
+ "text/plain": [
140
+ "[PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04'),\n",
141
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25'),\n",
142
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_17'),\n",
143
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_10'),\n",
144
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_22'),\n",
145
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_28'),\n",
146
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_09'),\n",
147
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_03'),\n",
148
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_07'),\n",
149
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_26'),\n",
150
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_14'),\n",
151
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_19'),\n",
152
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_13'),\n",
153
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_21'),\n",
154
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_00'),\n",
155
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_02'),\n",
156
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_08'),\n",
157
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_11'),\n",
158
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_29'),\n",
159
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_23'),\n",
160
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_24'),\n",
161
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_16'),\n",
162
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_05'),\n",
163
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_01'),\n",
164
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_12'),\n",
165
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_18'),\n",
166
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_20'),\n",
167
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_27'),\n",
168
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_15'),\n",
169
+ " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_06')]"
170
+ ]
171
+ },
172
+ "execution_count": 7,
173
+ "metadata": {},
174
+ "output_type": "execute_result"
175
+ }
176
+ ],
177
+ "source": [
178
+ "all_splits = [x for x in yfcc100m_metadata_splits.iterdir() if x.is_file()]\n",
179
+ "all_splits"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "id": "f604e3c9",
185
+ "metadata": {},
186
+ "source": [
187
+ "### Cleanup"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 8,
193
+ "id": "dea06b92",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "def image_exists(root: str, name: str, ext: str):\n",
198
+ " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(ext)\n",
199
+ " return image_path.exists()"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 9,
205
+ "id": "1d34d7aa",
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": [
209
+ "class YFC100Dataset(Dataset):\n",
210
+ " def __init__(self, image_list: pd.DataFrame, images_root: str, image_size: int, max_items=None):\n",
211
+ " \"\"\"\n",
212
+ " :param image_list: DataFrame with clean entries - all images must exist.\n",
213
+ " :param images_root: Root directory containing the images\n",
214
+ " :param image_size: Image size. Source images will be resized and center-cropped.\n",
215
+ " :max_items: Limit dataset size for debugging\n",
216
+ " \"\"\"\n",
217
+ " self.image_list = image_list\n",
218
+ " self.images_root = Path(images_root)\n",
219
+ " if max_items is not None: self.image_list = self.image_list[:max_items]\n",
220
+ " self.image_size = image_size\n",
221
+ " \n",
222
+ " def __len__(self):\n",
223
+ " return len(self.image_list)\n",
224
+ " \n",
225
+ " def _get_raw_image(self, i):\n",
226
+ " image_name = self.image_list.iloc[0].key\n",
227
+ " image_path = (self.images_root/image_name[0:3]/image_name[3:6]/image_name).with_suffix('.jpg')\n",
228
+ " return default_loader(image_path)\n",
229
+ " \n",
230
+ " def resize_image(self, image):\n",
231
+ " s = min(image.size)\n",
232
+ " r = self.image_size / s\n",
233
+ " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
234
+ " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
235
+ " image = TF.center_crop(image, output_size = 2 * [self.image_size])\n",
236
+ " # FIXME: np.array is necessary in my installation, but it should be automatic\n",
237
+ " image = torch.unsqueeze(T.ToTensor()(np.array(image)), 0)\n",
238
+ " image = image.permute(0, 2, 3, 1).numpy()\n",
239
+ " return image\n",
240
+ " \n",
241
+ " def __getitem__(self, i):\n",
242
+ " image = self._get_raw_image(i)\n",
243
+ " image = self.resize_image(image)\n",
244
+ " # Just return the image, not the caption\n",
245
+ " return image"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "markdown",
250
+ "id": "62ad01c3",
251
+ "metadata": {},
252
+ "source": [
253
+ "## Encoding"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 10,
259
+ "id": "88f36d0b",
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "def encode(model, batch):\n",
264
+ " print(\"jitting encode function\")\n",
265
+ " _, indices = model.encode(batch)\n",
266
+ "\n",
267
+ "# # FIXME: The model does not run in my computer (no cudNN currently installed) - faking it\n",
268
+ "# indices = np.random.randint(0, 16384, (batch.shape[0], 256))\n",
269
+ " return indices"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "id": "d1f45dd8",
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "#FIXME\n",
280
+ "# import random\n",
281
+ "# model = {}"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "code",
286
+ "execution_count": 11,
287
+ "id": "1f35f0cb",
288
+ "metadata": {},
289
+ "outputs": [],
290
+ "source": [
291
+ "from flax.training.common_utils import shard\n",
292
+ "\n",
293
+ "def superbatch_generator(dataloader):\n",
294
+ " iter_loader = iter(dataloader)\n",
295
+ " for batch in iter_loader:\n",
296
+ " batch = batch.squeeze(1)\n",
297
+ " # Skip incomplete last batch\n",
298
+ " if batch.shape[0] == dataloader.batch_size:\n",
299
+ " yield shard(batch)"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": 13,
305
+ "id": "2210705b",
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "import os\n",
310
+ "import jax\n",
311
+ "\n",
312
+ "def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):\n",
313
+ " if os.path.isfile(output_jsonl):\n",
314
+ " print(f\"Destination file {output_jsonl} already exists, please move away.\")\n",
315
+ " return\n",
316
+ " \n",
317
+ " num_tpus = jax.device_count()\n",
318
+ " dataloader = DataLoader(dataset, batch_size=num_tpus*batch_size, num_workers=num_workers)\n",
319
+ " superbatches = superbatch_generator(dataloader)\n",
320
+ " \n",
321
+ " p_encoder = pmap(lambda batch: encode(model, batch))\n",
322
+ "\n",
323
+ " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
324
+ " # We keep the file open to prevent excessive file seeks.\n",
325
+ " with open(output_jsonl, \"w\") as file:\n",
326
+ " iterations = len(dataset) // (batch_size * num_tpus)\n",
327
+ " for n in tqdm(range(iterations)):\n",
328
+ " superbatch = next(superbatches)\n",
329
+ " encoded = p_encoder(superbatch.numpy())\n",
330
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
331
+ "\n",
332
+ " # Extract fields from the dataset internal `image_list` property, and save to disk\n",
333
+ " # We need to read from the df because the Dataset only returns images\n",
334
+ " start_index = n * batch_size * num_tpus\n",
335
+ " end_index = (n+1) * batch_size * num_tpus\n",
336
+ " keys = dataset.image_list[\"key\"][start_index:end_index].values\n",
337
+ " captions = dataset.image_list[\"caption\"][start_index:end_index].values\n",
338
+ "# encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
339
+ " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded})\n",
340
+ " batch_df.to_json(file, orient='records', lines=True)"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 14,
346
+ "id": "7704863d",
347
+ "metadata": {},
348
+ "outputs": [
349
+ {
350
+ "name": "stdout",
351
+ "output_type": "stream",
352
+ "text": [
353
+ "Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04\n",
354
+ "54024 selected from 500000 total entries\n"
355
+ ]
356
+ },
357
+ {
358
+ "name": "stderr",
359
+ "output_type": "stream",
360
+ "text": [
361
+ "INFO:absl:Starting the local TPU driver.\n",
362
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
363
+ "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n",
364
+ " 0%| | 0/31 [00:00<?, ?it/s]"
365
+ ]
366
+ },
367
+ {
368
+ "name": "stdout",
369
+ "output_type": "stream",
370
+ "text": [
371
+ "jitting encode function\n"
372
+ ]
373
+ },
374
+ {
375
+ "name": "stderr",
376
+ "output_type": "stream",
377
+ "text": [
378
+ "100%|███████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 10.61it/s]\n"
379
+ ]
380
+ },
381
+ {
382
+ "name": "stdout",
383
+ "output_type": "stream",
384
+ "text": [
385
+ "Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25\n",
386
+ "99530 selected from 500000 total entries\n"
387
+ ]
388
+ },
389
+ {
390
+ "name": "stderr",
391
+ "output_type": "stream",
392
+ "text": [
393
+ " 3%|██▌ | 1/31 [00:01<00:53, 1.79s/it]"
394
+ ]
395
+ },
396
+ {
397
+ "name": "stdout",
398
+ "output_type": "stream",
399
+ "text": [
400
+ "jitting encode function\n"
401
+ ]
402
+ },
403
+ {
404
+ "name": "stderr",
405
+ "output_type": "stream",
406
+ "text": [
407
+ "100%|███████████████████████████████████████████████████████████████████████████████| 31/31 [00:03<00:00, 9.92it/s]\n"
408
+ ]
409
+ }
410
+ ],
411
+ "source": [
412
+ "for split in all_splits:\n",
413
+ " print(f\"Processing {split}\")\n",
414
+ " df = pd.read_json(split, orient=\"records\", lines=True)\n",
415
+ " df['image_exists'] = df.apply(lambda row: image_exists(yfcc100m_images, row['key'], '.' + row['ext']), axis=1)\n",
416
+ " print(f\"{len(df[df.image_exists])} selected from {len(df)} total entries\")\n",
417
+ " \n",
418
+ " df = df[df.image_exists]\n",
419
+ " captions = df.apply(lambda row: ' '.join([row[\"title_clean\"], row[\"description_clean\"]]), axis=1)\n",
420
+ " df[\"caption\"] = captions.values\n",
421
+ " \n",
422
+ " dataset = YFC100Dataset(\n",
423
+ " image_list = df,\n",
424
+ " images_root = yfcc100m_images,\n",
425
+ " image_size = 256,\n",
426
+ "# max_items = 2000,\n",
427
+ " )\n",
428
+ " \n",
429
+ " encode_captioned_dataset(dataset, yfcc100m_output/split.name, batch_size=64, num_workers=16)"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "markdown",
434
+ "id": "8953dd84",
435
+ "metadata": {},
436
+ "source": [
437
+ "----"
438
+ ]
439
+ }
440
+ ],
441
+ "metadata": {
442
+ "kernelspec": {
443
+ "display_name": "Python 3 (ipykernel)",
444
+ "language": "python",
445
+ "name": "python3"
446
+ },
447
+ "language_info": {
448
+ "codemirror_mode": {
449
+ "name": "ipython",
450
+ "version": 3
451
+ },
452
+ "file_extension": ".py",
453
+ "mimetype": "text/x-python",
454
+ "name": "python",
455
+ "nbconvert_exporter": "python",
456
+ "pygments_lexer": "ipython3",
457
+ "version": "3.8.10"
458
+ }
459
+ },
460
+ "nbformat": 4,
461
+ "nbformat_minor": 5
462
+ }