Spaces:
Runtime error
Runtime error
feat: add functions
Browse files
dev/inference/wandb-backend.ipynb
CHANGED
@@ -2,13 +2,15 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
"import csv\n",
|
11 |
"import tempfile\n",
|
|
|
|
|
12 |
"import wandb\n",
|
13 |
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
14 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
@@ -42,26 +44,82 @@
|
|
42 |
},
|
43 |
{
|
44 |
"cell_type": "code",
|
45 |
-
"execution_count":
|
46 |
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
47 |
"metadata": {},
|
48 |
"outputs": [],
|
49 |
"source": [
|
50 |
"with open('samples.csv', newline='', encoding='utf8') as f:\n",
|
51 |
-
" reader = csv.
|
|
|
52 |
" for row in reader:\n",
|
53 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
]
|
55 |
},
|
56 |
{
|
57 |
"cell_type": "code",
|
58 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
"id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
|
60 |
"metadata": {},
|
61 |
"outputs": [],
|
62 |
"source": [
|
|
|
63 |
"wandb_run = wandb_runs[0]\n",
|
64 |
-
"
|
65 |
]
|
66 |
},
|
67 |
{
|
@@ -280,27 +338,30 @@
|
|
280 |
},
|
281 |
{
|
282 |
"cell_type": "code",
|
283 |
-
"execution_count":
|
284 |
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
285 |
"metadata": {},
|
286 |
"outputs": [],
|
287 |
-
"source": [
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
|
|
|
|
|
|
304 |
},
|
305 |
{
|
306 |
"cell_type": "code",
|
@@ -323,7 +384,7 @@
|
|
323 |
{
|
324 |
"cell_type": "code",
|
325 |
"execution_count": null,
|
326 |
-
"id": "
|
327 |
"metadata": {},
|
328 |
"outputs": [],
|
329 |
"source": []
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 197,
|
6 |
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
"import csv\n",
|
11 |
"import tempfile\n",
|
12 |
+
"from functools import partial\n",
|
13 |
+
"import jax\n",
|
14 |
"import wandb\n",
|
15 |
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
16 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
|
|
44 |
},
|
45 |
{
|
46 |
"cell_type": "code",
|
47 |
+
"execution_count": 245,
|
48 |
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
49 |
"metadata": {},
|
50 |
"outputs": [],
|
51 |
"source": [
|
52 |
"with open('samples.csv', newline='', encoding='utf8') as f:\n",
|
53 |
+
" reader = csv.DictReader(f)\n",
|
54 |
+
" samples = []\n",
|
55 |
" for row in reader:\n",
|
56 |
+
" samples.append(row)"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 246,
|
62 |
+
"id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"data": {
|
67 |
+
"text/plain": [
|
68 |
+
"101"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
"execution_count": 246,
|
72 |
+
"metadata": {},
|
73 |
+
"output_type": "execute_result"
|
74 |
+
}
|
75 |
+
],
|
76 |
+
"source": [
|
77 |
+
"len(samples)"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": 248,
|
83 |
+
"id": "2ea0b166-a20c-4d78-bffb-b792ca512d17",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [
|
86 |
+
{
|
87 |
+
"data": {
|
88 |
+
"text/plain": [
|
89 |
+
"104"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
"execution_count": 248,
|
93 |
+
"metadata": {},
|
94 |
+
"output_type": "execute_result"
|
95 |
+
}
|
96 |
+
],
|
97 |
+
"source": [
|
98 |
+
"samples_to_add = ['empty'] * (-len(samples) % 8)\n",
|
99 |
+
"samples.extend(samples_to_add)\n",
|
100 |
+
"len(samples)"
|
101 |
]
|
102 |
},
|
103 |
{
|
104 |
"cell_type": "code",
|
105 |
"execution_count": null,
|
106 |
+
"id": "a2c629e9-1a82-40c6-a260-ca1780c19a2e",
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"api = wandb.Api()"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"execution_count": 204,
|
116 |
"id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
|
117 |
"metadata": {},
|
118 |
"outputs": [],
|
119 |
"source": [
|
120 |
+
"# TODO: iterate on runs\n",
|
121 |
"wandb_run = wandb_runs[0]\n",
|
122 |
+
"functions_pmapped = False"
|
123 |
]
|
124 |
},
|
125 |
{
|
|
|
338 |
},
|
339 |
{
|
340 |
"cell_type": "code",
|
341 |
+
"execution_count": 207,
|
342 |
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
343 |
"metadata": {},
|
344 |
"outputs": [],
|
345 |
+
"source": [
|
346 |
+
"# function to generate encoded images\n",
|
347 |
+
"# we should generate this function only once per run\n",
|
348 |
+
"if not functions_pmapped:\n",
|
349 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
350 |
+
" def p_generate(tokenized_prompt, key, params):\n",
|
351 |
+
" return model.generate(\n",
|
352 |
+
" **tokenized_prompt,\n",
|
353 |
+
" do_sample=True,\n",
|
354 |
+
" num_beams=1,\n",
|
355 |
+
" prng_key=key,\n",
|
356 |
+
" params=params\n",
|
357 |
+
" )\n",
|
358 |
+
" \n",
|
359 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
360 |
+
" def p_decode(indices, params):\n",
|
361 |
+
" return vqgan.decode_code(indices, params=params)\n",
|
362 |
+
" \n",
|
363 |
+
" functions_pmapped = False"
|
364 |
+
]
|
365 |
},
|
366 |
{
|
367 |
"cell_type": "code",
|
|
|
384 |
{
|
385 |
"cell_type": "code",
|
386 |
"execution_count": null,
|
387 |
+
"id": "e79ac8f2-adc2-4a16-970c-dadcceadd566",
|
388 |
"metadata": {},
|
389 |
"outputs": [],
|
390 |
"source": []
|