Spaces:
Runtime error
Runtime error
Pedro Cuenca
commited on
Commit
·
95d2faf
1
Parent(s):
16f038a
* Data preprocessing pipeline proof of concept.
Browse files- model/data-pipeline.ipynb +366 -0
model/data-pipeline.ipynb
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "bf8fb38a",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Data Pipeline"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": 1,
|
14 |
+
"id": "9b83dcb9",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"from dataclasses import dataclass, field\n",
|
19 |
+
"from pathlib import Path\n",
|
20 |
+
"\n",
|
21 |
+
"import datasets\n",
|
22 |
+
"from datasets import Dataset, load_dataset\n",
|
23 |
+
"import numpy as np\n",
|
24 |
+
"\n",
|
25 |
+
"from transformers import BartTokenizer\n",
|
26 |
+
"\n",
|
27 |
+
"from tqdm import tqdm\n",
|
28 |
+
"\n",
|
29 |
+
"import jax\n",
|
30 |
+
"import jax.numpy as jnp\n",
|
31 |
+
"\n",
|
32 |
+
"from flax.training.common_utils import shard"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "markdown",
|
37 |
+
"id": "a661a89e",
|
38 |
+
"metadata": {},
|
39 |
+
"source": [
|
40 |
+
"File containing image paths, captions and VQGAN-encoded indices."
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 2,
|
46 |
+
"id": "0e84e889",
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "markdown",
|
55 |
+
"id": "7fdc640b",
|
56 |
+
"metadata": {},
|
57 |
+
"source": [
|
58 |
+
"TODO: generate train/test splits if necessary."
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": 3,
|
64 |
+
"id": "cc6789b4",
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [
|
67 |
+
{
|
68 |
+
"name": "stderr",
|
69 |
+
"output_type": "stream",
|
70 |
+
"text": [
|
71 |
+
"Using custom data configuration default-91833df78e844785\n",
|
72 |
+
"Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
|
73 |
+
]
|
74 |
+
}
|
75 |
+
],
|
76 |
+
"source": [
|
77 |
+
"dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": 4,
|
83 |
+
"id": "f3ed4919",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [
|
86 |
+
{
|
87 |
+
"data": {
|
88 |
+
"text/plain": [
|
89 |
+
"DatasetDict({\n",
|
90 |
+
" train: Dataset({\n",
|
91 |
+
" features: ['image_file', 'caption', 'encoding'],\n",
|
92 |
+
" num_rows: 9999\n",
|
93 |
+
" })\n",
|
94 |
+
"})"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
"execution_count": 4,
|
98 |
+
"metadata": {},
|
99 |
+
"output_type": "execute_result"
|
100 |
+
}
|
101 |
+
],
|
102 |
+
"source": [
|
103 |
+
"dataset"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": 5,
|
109 |
+
"id": "a70c7354",
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [
|
112 |
+
{
|
113 |
+
"data": {
|
114 |
+
"text/plain": [
|
115 |
+
"Dataset({\n",
|
116 |
+
" features: ['image_file', 'caption', 'encoding'],\n",
|
117 |
+
" num_rows: 9999\n",
|
118 |
+
"})"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
"execution_count": 5,
|
122 |
+
"metadata": {},
|
123 |
+
"output_type": "execute_result"
|
124 |
+
}
|
125 |
+
],
|
126 |
+
"source": [
|
127 |
+
"dataset = dataset[\"train\"]\n",
|
128 |
+
"dataset"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "markdown",
|
133 |
+
"id": "a73454cf",
|
134 |
+
"metadata": {},
|
135 |
+
"source": [
|
136 |
+
"We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "markdown",
|
141 |
+
"id": "7c0fa992",
|
142 |
+
"metadata": {},
|
143 |
+
"source": [
|
144 |
+
"## Preprocessing"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "markdown",
|
149 |
+
"id": "a0e36582",
|
150 |
+
"metadata": {},
|
151 |
+
"source": [
|
152 |
+
"The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": 6,
|
158 |
+
"id": "d46f6ac5",
|
159 |
+
"metadata": {},
|
160 |
+
"outputs": [],
|
161 |
+
"source": [
|
162 |
+
"# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
|
163 |
+
"max_length = 256 # Read from data_args.max_source_length\n",
|
164 |
+
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": 7,
|
170 |
+
"id": "4cac6643",
|
171 |
+
"metadata": {},
|
172 |
+
"outputs": [],
|
173 |
+
"source": [
|
174 |
+
"def preprocess_function(examples):\n",
|
175 |
+
" inputs = examples[\"caption\"]\n",
|
176 |
+
"# inputs = [prefix + inp for inp in inputs] # Do we need this?\n",
|
177 |
+
" model_inputs = tokenizer(\n",
|
178 |
+
" inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
|
179 |
+
" )\n",
|
180 |
+
"\n",
|
181 |
+
" model_inputs[\"eval_encoding\"] = [eval(indices) for indices in examples['encoding']]\n",
|
182 |
+
"\n",
|
183 |
+
" return model_inputs"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": 8,
|
189 |
+
"id": "e6a4cb91",
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"num_workers = 48 # We have 96 processors in the TPU\n",
|
194 |
+
"column_names = dataset.column_names\n",
|
195 |
+
"dataset = dataset.map(preprocess_function,\n",
|
196 |
+
" remove_columns=column_names,\n",
|
197 |
+
" batched=True,\n",
|
198 |
+
" num_proc=48\n",
|
199 |
+
")"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"cell_type": "code",
|
204 |
+
"execution_count": 9,
|
205 |
+
"id": "a9b1b467",
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
|
210 |
+
" \"\"\"\n",
|
211 |
+
" Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
|
212 |
+
" Shuffle batches if `shuffle` is `True`.\n",
|
213 |
+
" \"\"\"\n",
|
214 |
+
" steps_per_epoch = len(dataset) // batch_size\n",
|
215 |
+
"\n",
|
216 |
+
" if shuffle:\n",
|
217 |
+
" batch_idx = jax.random.permutation(rng, len(dataset))\n",
|
218 |
+
" else:\n",
|
219 |
+
" batch_idx = jnp.arange(len(dataset))\n",
|
220 |
+
"\n",
|
221 |
+
" batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
|
222 |
+
" batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
|
223 |
+
"\n",
|
224 |
+
" for idx in batch_idx:\n",
|
225 |
+
" batch = dataset[idx] \n",
|
226 |
+
" batch = {k: jnp.array(v) for k, v in batch.items()}\n",
|
227 |
+
" batch = shard(batch)\n",
|
228 |
+
" yield batch"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "code",
|
233 |
+
"execution_count": 10,
|
234 |
+
"id": "0a628505",
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [
|
237 |
+
{
|
238 |
+
"name": "stderr",
|
239 |
+
"output_type": "stream",
|
240 |
+
"text": [
|
241 |
+
"INFO:absl:Starting the local TPU driver.\n",
|
242 |
+
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
243 |
+
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter TPU Host\n"
|
244 |
+
]
|
245 |
+
}
|
246 |
+
],
|
247 |
+
"source": [
|
248 |
+
"rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
|
249 |
+
"batch_size = 64 # Per device\n",
|
250 |
+
"super_batch_size = batch_size * jax.device_count()"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": 11,
|
256 |
+
"id": "b3a5ce7d",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [],
|
259 |
+
"source": [
|
260 |
+
"loader = data_loader(rng, dataset, batch_size=super_batch_size)"
|
261 |
+
]
|
262 |
+
},
|
263 |
+
{
|
264 |
+
"cell_type": "code",
|
265 |
+
"execution_count": 12,
|
266 |
+
"id": "67aa8f9c",
|
267 |
+
"metadata": {},
|
268 |
+
"outputs": [],
|
269 |
+
"source": [
|
270 |
+
"superbatch = next(iter(loader))"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "code",
|
275 |
+
"execution_count": 13,
|
276 |
+
"id": "7cd99402",
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [
|
279 |
+
{
|
280 |
+
"data": {
|
281 |
+
"text/plain": [
|
282 |
+
"dict_keys(['attention_mask', 'eval_encoding', 'input_ids'])"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
"execution_count": 13,
|
286 |
+
"metadata": {},
|
287 |
+
"output_type": "execute_result"
|
288 |
+
}
|
289 |
+
],
|
290 |
+
"source": [
|
291 |
+
"superbatch.keys()"
|
292 |
+
]
|
293 |
+
},
|
294 |
+
{
|
295 |
+
"cell_type": "code",
|
296 |
+
"execution_count": 14,
|
297 |
+
"id": "652a4a9e",
|
298 |
+
"metadata": {},
|
299 |
+
"outputs": [
|
300 |
+
{
|
301 |
+
"data": {
|
302 |
+
"text/plain": [
|
303 |
+
"8"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
"execution_count": 14,
|
307 |
+
"metadata": {},
|
308 |
+
"output_type": "execute_result"
|
309 |
+
}
|
310 |
+
],
|
311 |
+
"source": [
|
312 |
+
"len(superbatch[\"eval_encoding\"])"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "code",
|
317 |
+
"execution_count": 15,
|
318 |
+
"id": "de7de4e8",
|
319 |
+
"metadata": {},
|
320 |
+
"outputs": [
|
321 |
+
{
|
322 |
+
"data": {
|
323 |
+
"text/plain": [
|
324 |
+
"(8, 64, 256)"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
"execution_count": 15,
|
328 |
+
"metadata": {},
|
329 |
+
"output_type": "execute_result"
|
330 |
+
}
|
331 |
+
],
|
332 |
+
"source": [
|
333 |
+
"superbatch[\"eval_encoding\"].shape"
|
334 |
+
]
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"cell_type": "code",
|
338 |
+
"execution_count": null,
|
339 |
+
"id": "cfe23a71",
|
340 |
+
"metadata": {},
|
341 |
+
"outputs": [],
|
342 |
+
"source": []
|
343 |
+
}
|
344 |
+
],
|
345 |
+
"metadata": {
|
346 |
+
"kernelspec": {
|
347 |
+
"display_name": "Python 3 (ipykernel)",
|
348 |
+
"language": "python",
|
349 |
+
"name": "python3"
|
350 |
+
},
|
351 |
+
"language_info": {
|
352 |
+
"codemirror_mode": {
|
353 |
+
"name": "ipython",
|
354 |
+
"version": 3
|
355 |
+
},
|
356 |
+
"file_extension": ".py",
|
357 |
+
"mimetype": "text/x-python",
|
358 |
+
"name": "python",
|
359 |
+
"nbconvert_exporter": "python",
|
360 |
+
"pygments_lexer": "ipython3",
|
361 |
+
"version": "3.8.10"
|
362 |
+
}
|
363 |
+
},
|
364 |
+
"nbformat": 4,
|
365 |
+
"nbformat_minor": 5
|
366 |
+
}
|