Spaces:
Runtime error
Runtime error
fix: pmap clip32
Browse files
dev/inference/wandb-backend.ipynb
CHANGED
@@ -36,7 +36,8 @@
|
|
36 |
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
|
37 |
"normalize_text = True\n",
|
38 |
"latest_only = False # log only latest or all versions\n",
|
39 |
-
"suffix = '_1' # mainly for duplicate inference runs with a deleted version"
|
|
|
40 |
]
|
41 |
},
|
42 |
{
|
@@ -51,7 +52,8 @@
|
|
51 |
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
|
52 |
"normalize_text = False\n",
|
53 |
"latest_only = True # log only latest or all versions\n",
|
54 |
-
"suffix = '_2' # mainly for duplicate inference runs with a deleted version"
|
|
|
55 |
]
|
56 |
},
|
57 |
{
|
@@ -82,7 +84,12 @@
|
|
82 |
"clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
83 |
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
84 |
"clip_params = replicate(clip.params)\n",
|
85 |
-
"vqgan_params = replicate(vqgan.params)"
|
|
|
|
|
|
|
|
|
|
|
86 |
]
|
87 |
},
|
88 |
{
|
@@ -98,8 +105,14 @@
|
|
98 |
"\n",
|
99 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
100 |
"def p_clip(inputs):\n",
|
101 |
-
" logits = clip(**inputs).logits_per_image\n",
|
102 |
-
" return logits"
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
]
|
104 |
},
|
105 |
{
|
@@ -158,7 +171,7 @@
|
|
158 |
"# retrieve inference run details\n",
|
159 |
"def get_last_inference_version(run_id):\n",
|
160 |
" try:\n",
|
161 |
-
" inference_run = api.run(f'dalle-mini/dalle-mini/
|
162 |
" return inference_run.summary.get('version', None)\n",
|
163 |
" except:\n",
|
164 |
" return None"
|
@@ -215,6 +228,8 @@
|
|
215 |
" print(f'Processing artifact: {artifact.name}')\n",
|
216 |
" version = int(artifact.version[1:])\n",
|
217 |
" results = []\n",
|
|
|
|
|
218 |
" columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
|
219 |
" \n",
|
220 |
" if latest_only:\n",
|
@@ -232,7 +247,7 @@
|
|
232 |
"\n",
|
233 |
" # start/resume corresponding run\n",
|
234 |
" if run is None:\n",
|
235 |
-
" run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'
|
236 |
"\n",
|
237 |
" # work in temporary directory\n",
|
238 |
" with tempfile.TemporaryDirectory() as tmp:\n",
|
@@ -283,7 +298,6 @@
|
|
283 |
" logits = logits.reshape(-1, num_images)\n",
|
284 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
285 |
" logits = jax.device_get(logits)\n",
|
286 |
-
"\n",
|
287 |
" # add to results table\n",
|
288 |
" for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
|
289 |
" if sample == padding_item: continue\n",
|
@@ -291,11 +305,68 @@
|
|
291 |
" top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
|
292 |
" top_scores = [scores[x] for x in idx]\n",
|
293 |
" results.append([sample] + top_images + top_scores)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
"\n",
|
295 |
" # log results\n",
|
296 |
" table = wandb.Table(columns=columns, data=results)\n",
|
297 |
" run.log({'Samples': table, 'version': version})\n",
|
298 |
-
" wandb.finish()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
]
|
300 |
},
|
301 |
{
|
@@ -314,12 +385,10 @@
|
|
314 |
{
|
315 |
"cell_type": "code",
|
316 |
"execution_count": null,
|
317 |
-
"id": "
|
318 |
"metadata": {},
|
319 |
"outputs": [],
|
320 |
-
"source": [
|
321 |
-
"wandb.finish()"
|
322 |
-
]
|
323 |
}
|
324 |
],
|
325 |
"metadata": {
|
|
|
36 |
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
|
37 |
"normalize_text = True\n",
|
38 |
"latest_only = False # log only latest or all versions\n",
|
39 |
+
"suffix = '_1' # mainly for duplicate inference runs with a deleted version\n",
|
40 |
+
"add_clip_32 = False"
|
41 |
]
|
42 |
},
|
43 |
{
|
|
|
52 |
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
|
53 |
"normalize_text = False\n",
|
54 |
"latest_only = True # log only latest or all versions\n",
|
55 |
+
"suffix = '_2' # mainly for duplicate inference runs with a deleted version\n",
|
56 |
+
"add_clip_32 = True"
|
57 |
]
|
58 |
},
|
59 |
{
|
|
|
84 |
"clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
85 |
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
86 |
"clip_params = replicate(clip.params)\n",
|
87 |
+
"vqgan_params = replicate(vqgan.params)\n",
|
88 |
+
"\n",
|
89 |
+
"if add_clip_32:\n",
|
90 |
+
" clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
91 |
+
" processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
92 |
+
" clip32_params = replicate(clip32.params)"
|
93 |
]
|
94 |
},
|
95 |
{
|
|
|
105 |
"\n",
|
106 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
107 |
"def p_clip(inputs):\n",
|
108 |
+
" logits = clip(params=clip_params, **inputs).logits_per_image\n",
|
109 |
+
" return logits\n",
|
110 |
+
"\n",
|
111 |
+
"if add_clip_32:\n",
|
112 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
113 |
+
" def p_clip32(inputs):\n",
|
114 |
+
" logits = clip32(params=clip32_params, **inputs).logits_per_image\n",
|
115 |
+
" return logits"
|
116 |
]
|
117 |
},
|
118 |
{
|
|
|
171 |
"# retrieve inference run details\n",
|
172 |
"def get_last_inference_version(run_id):\n",
|
173 |
" try:\n",
|
174 |
+
" inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')\n",
|
175 |
" return inference_run.summary.get('version', None)\n",
|
176 |
" except:\n",
|
177 |
" return None"
|
|
|
228 |
" print(f'Processing artifact: {artifact.name}')\n",
|
229 |
" version = int(artifact.version[1:])\n",
|
230 |
" results = []\n",
|
231 |
+
" if add_clip_32:\n",
|
232 |
+
" results32 = []\n",
|
233 |
" columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
|
234 |
" \n",
|
235 |
" if latest_only:\n",
|
|
|
247 |
"\n",
|
248 |
" # start/resume corresponding run\n",
|
249 |
" if run is None:\n",
|
250 |
+
" run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')\n",
|
251 |
"\n",
|
252 |
" # work in temporary directory\n",
|
253 |
" with tempfile.TemporaryDirectory() as tmp:\n",
|
|
|
298 |
" logits = logits.reshape(-1, num_images)\n",
|
299 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
300 |
" logits = jax.device_get(logits)\n",
|
|
|
301 |
" # add to results table\n",
|
302 |
" for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
|
303 |
" if sample == padding_item: continue\n",
|
|
|
305 |
" top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
|
306 |
" top_scores = [scores[x] for x in idx]\n",
|
307 |
" results.append([sample] + top_images + top_scores)\n",
|
308 |
+
" \n",
|
309 |
+
" # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
|
310 |
+
" if add_clip_32:\n",
|
311 |
+
" print('Calculating CLIP 32 scores')\n",
|
312 |
+
" clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
|
313 |
+
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
314 |
+
" images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
|
315 |
+
" clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
|
316 |
+
" clip_inputs = shard(clip_inputs)\n",
|
317 |
+
" logits = p_clip32(clip_inputs)\n",
|
318 |
+
" logits = logits.reshape(-1, num_images)\n",
|
319 |
+
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
320 |
+
" logits = jax.device_get(logits)\n",
|
321 |
+
" # add to results table\n",
|
322 |
+
" for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
|
323 |
+
" if sample == padding_item: continue\n",
|
324 |
+
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
325 |
+
" top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
|
326 |
+
" top_scores = [scores[x] for x in idx]\n",
|
327 |
+
" results32.append([sample] + top_images + top_scores)\n",
|
328 |
"\n",
|
329 |
" # log results\n",
|
330 |
" table = wandb.Table(columns=columns, data=results)\n",
|
331 |
" run.log({'Samples': table, 'version': version})\n",
|
332 |
+
" wandb.finish()\n",
|
333 |
+
" \n",
|
334 |
+
" if add_clip_32: \n",
|
335 |
+
" run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')\n",
|
336 |
+
" table = wandb.Table(columns=columns, data=results32)\n",
|
337 |
+
" run.log({'Samples': table, 'version': version})\n",
|
338 |
+
" wandb.finish()\n",
|
339 |
+
" run = None # ensure we don't log on this run"
|
340 |
+
]
|
341 |
+
},
|
342 |
+
{
|
343 |
+
"cell_type": "code",
|
344 |
+
"execution_count": null,
|
345 |
+
"id": "fdcd09d6-079c-461a-a81a-d9e650d3b099",
|
346 |
+
"metadata": {},
|
347 |
+
"outputs": [],
|
348 |
+
"source": [
|
349 |
+
"p_clip32"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"cell_type": "code",
|
354 |
+
"execution_count": null,
|
355 |
+
"id": "7d86ceee-c9ac-4860-abad-410cadd16c3c",
|
356 |
+
"metadata": {},
|
357 |
+
"outputs": [],
|
358 |
+
"source": [
|
359 |
+
"clip_inputs['attention_mask'].shape, clip_inputs['pixel_values'].shape"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": null,
|
365 |
+
"id": "fbba4858-da2d-4dd5-97b7-ce3ab4746f96",
|
366 |
+
"metadata": {},
|
367 |
+
"outputs": [],
|
368 |
+
"source": [
|
369 |
+
"clip_inputs['input_ids'].shape"
|
370 |
]
|
371 |
},
|
372 |
{
|
|
|
385 |
{
|
386 |
"cell_type": "code",
|
387 |
"execution_count": null,
|
388 |
+
"id": "a7a5fdf5-3c6e-421b-96a8-5115f730328c",
|
389 |
"metadata": {},
|
390 |
"outputs": [],
|
391 |
+
"source": []
|
|
|
|
|
392 |
}
|
393 |
],
|
394 |
"metadata": {
|