Spaces:
Runtime error
Runtime error
fix(inference): use float32 + flatten logits
Browse files
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -70,7 +70,7 @@
|
|
70 |
"# Model references\n",
|
71 |
"\n",
|
72 |
"# dalle-mini\n",
|
73 |
-
"DALLE_MODEL = 'dalle-mini/dalle-mini/model-
|
74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
75 |
"\n",
|
76 |
"# VQGAN model\n",
|
@@ -92,7 +92,13 @@
|
|
92 |
"import jax.numpy as jnp\n",
|
93 |
"\n",
|
94 |
"# type used for computation - use bfloat16 on TPU's\n",
|
95 |
-
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32"
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
]
|
97 |
},
|
98 |
{
|
@@ -281,7 +287,7 @@
|
|
281 |
},
|
282 |
"outputs": [],
|
283 |
"source": [
|
284 |
-
"prompt = '
|
285 |
]
|
286 |
},
|
287 |
{
|
@@ -292,7 +298,8 @@
|
|
292 |
},
|
293 |
"outputs": [],
|
294 |
"source": [
|
295 |
-
"processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt"
|
|
|
296 |
]
|
297 |
},
|
298 |
{
|
@@ -375,7 +382,7 @@
|
|
375 |
"outputs": [],
|
376 |
"source": [
|
377 |
"# number of predictions\n",
|
378 |
-
"n_predictions =
|
379 |
"\n",
|
380 |
"# We can customize top_k/top_p used for generating samples\n",
|
381 |
"gen_top_k = None\n",
|
@@ -431,7 +438,7 @@
|
|
431 |
"# get clip scores\n",
|
432 |
"clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
|
433 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
434 |
-
"logits = logits.squeeze()"
|
435 |
]
|
436 |
},
|
437 |
{
|
|
|
70 |
"# Model references\n",
|
71 |
"\n",
|
72 |
"# dalle-mini\n",
|
73 |
+
"DALLE_MODEL = 'dalle-mini/dalle-mini/model-3bqwu04f:latest' # can be wandb artifact or 🤗 Hub or local folder\n",
|
74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
75 |
"\n",
|
76 |
"# VQGAN model\n",
|
|
|
92 |
"import jax.numpy as jnp\n",
|
93 |
"\n",
|
94 |
"# type used for computation - use bfloat16 on TPU's\n",
|
95 |
+
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
|
96 |
+
"\n",
|
97 |
+
"# TODO:\n",
|
98 |
+
"# - we currently have an issue with model.generate() in bfloat16\n",
|
99 |
+
"# - https://github.com/google/jax/pull/9089 should fix it\n",
|
100 |
+
"# - remove below line and test on TPU with next release of JAX\n",
|
101 |
+
"dtype = jnp.float32"
|
102 |
]
|
103 |
},
|
104 |
{
|
|
|
287 |
},
|
288 |
"outputs": [],
|
289 |
"source": [
|
290 |
+
"prompt = 'a red T-shirt'"
|
291 |
]
|
292 |
},
|
293 |
{
|
|
|
298 |
},
|
299 |
"outputs": [],
|
300 |
"source": [
|
301 |
+
"processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
|
302 |
+
"processed_prompt"
|
303 |
]
|
304 |
},
|
305 |
{
|
|
|
382 |
"outputs": [],
|
383 |
"source": [
|
384 |
"# number of predictions\n",
|
385 |
+
"n_predictions = 32\n",
|
386 |
"\n",
|
387 |
"# We can customize top_k/top_p used for generating samples\n",
|
388 |
"gen_top_k = None\n",
|
|
|
438 |
"# get clip scores\n",
|
439 |
"clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
|
440 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
441 |
+
"logits = logits.squeeze().flatten()"
|
442 |
]
|
443 |
},
|
444 |
{
|
tools/inference/log_inference_samples.ipynb
DELETED
@@ -1,434 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
7 |
-
"metadata": {},
|
8 |
-
"outputs": [],
|
9 |
-
"source": [
|
10 |
-
"import tempfile\n",
|
11 |
-
"from functools import partial\n",
|
12 |
-
"import random\n",
|
13 |
-
"import numpy as np\n",
|
14 |
-
"from PIL import Image\n",
|
15 |
-
"from tqdm.notebook import tqdm\n",
|
16 |
-
"import jax\n",
|
17 |
-
"import jax.numpy as jnp\n",
|
18 |
-
"from flax.training.common_utils import shard, shard_prng_key\n",
|
19 |
-
"from flax.jax_utils import replicate\n",
|
20 |
-
"import wandb\n",
|
21 |
-
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
22 |
-
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
23 |
-
"from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n",
|
24 |
-
"from dalle_mini.text import TextNormalizer"
|
25 |
-
]
|
26 |
-
},
|
27 |
-
{
|
28 |
-
"cell_type": "code",
|
29 |
-
"execution_count": null,
|
30 |
-
"id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
|
31 |
-
"metadata": {},
|
32 |
-
"outputs": [],
|
33 |
-
"source": [
|
34 |
-
"run_ids = [\"63otg87g\"]\n",
|
35 |
-
"ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
|
36 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
37 |
-
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
38 |
-
" \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
|
39 |
-
")\n",
|
40 |
-
"latest_only = True # log only latest or all versions\n",
|
41 |
-
"suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
|
42 |
-
"add_clip_32 = False"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "code",
|
47 |
-
"execution_count": null,
|
48 |
-
"id": "71f27b96-7e6c-4472-a2e4-e99a8fb67a72",
|
49 |
-
"metadata": {},
|
50 |
-
"outputs": [],
|
51 |
-
"source": [
|
52 |
-
"# model.generate parameters - Not used yet\n",
|
53 |
-
"gen_top_k = None\n",
|
54 |
-
"gen_top_p = None\n",
|
55 |
-
"temperature = None"
|
56 |
-
]
|
57 |
-
},
|
58 |
-
{
|
59 |
-
"cell_type": "code",
|
60 |
-
"execution_count": null,
|
61 |
-
"id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
|
62 |
-
"metadata": {},
|
63 |
-
"outputs": [],
|
64 |
-
"source": [
|
65 |
-
"batch_size = 8\n",
|
66 |
-
"num_images = 128\n",
|
67 |
-
"top_k = 8\n",
|
68 |
-
"text_normalizer = TextNormalizer()\n",
|
69 |
-
"padding_item = \"NONE\"\n",
|
70 |
-
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
71 |
-
"key = jax.random.PRNGKey(seed)\n",
|
72 |
-
"api = wandb.Api()"
|
73 |
-
]
|
74 |
-
},
|
75 |
-
{
|
76 |
-
"cell_type": "code",
|
77 |
-
"execution_count": null,
|
78 |
-
"id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
|
79 |
-
"metadata": {},
|
80 |
-
"outputs": [],
|
81 |
-
"source": [
|
82 |
-
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
|
83 |
-
"vqgan_params = replicate(vqgan.params)\n",
|
84 |
-
"\n",
|
85 |
-
"clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
86 |
-
"processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
|
87 |
-
"clip16_params = replicate(clip16.params)\n",
|
88 |
-
"\n",
|
89 |
-
"if add_clip_32:\n",
|
90 |
-
" clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
91 |
-
" processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
92 |
-
" clip32_params = replicate(clip32.params)"
|
93 |
-
]
|
94 |
-
},
|
95 |
-
{
|
96 |
-
"cell_type": "code",
|
97 |
-
"execution_count": null,
|
98 |
-
"id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
|
99 |
-
"metadata": {},
|
100 |
-
"outputs": [],
|
101 |
-
"source": [
|
102 |
-
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
103 |
-
"def p_decode(indices, params):\n",
|
104 |
-
" return vqgan.decode_code(indices, params=params)\n",
|
105 |
-
"\n",
|
106 |
-
"\n",
|
107 |
-
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
108 |
-
"def p_clip16(inputs, params):\n",
|
109 |
-
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
110 |
-
" return logits\n",
|
111 |
-
"\n",
|
112 |
-
"\n",
|
113 |
-
"if add_clip_32:\n",
|
114 |
-
"\n",
|
115 |
-
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
116 |
-
" def p_clip32(inputs, params):\n",
|
117 |
-
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
118 |
-
" return logits"
|
119 |
-
]
|
120 |
-
},
|
121 |
-
{
|
122 |
-
"cell_type": "code",
|
123 |
-
"execution_count": null,
|
124 |
-
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
125 |
-
"metadata": {},
|
126 |
-
"outputs": [],
|
127 |
-
"source": [
|
128 |
-
"with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
|
129 |
-
" samples = [l.strip() for l in f.readlines()]\n",
|
130 |
-
" # make list multiple of batch_size by adding elements\n",
|
131 |
-
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
132 |
-
" samples.extend(samples_to_add)\n",
|
133 |
-
" # reshape\n",
|
134 |
-
" samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
|
135 |
-
]
|
136 |
-
},
|
137 |
-
{
|
138 |
-
"cell_type": "code",
|
139 |
-
"execution_count": null,
|
140 |
-
"id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
|
141 |
-
"metadata": {},
|
142 |
-
"outputs": [],
|
143 |
-
"source": [
|
144 |
-
"def get_artifact_versions(run_id, latest_only=False):\n",
|
145 |
-
" try:\n",
|
146 |
-
" if latest_only:\n",
|
147 |
-
" return [\n",
|
148 |
-
" api.artifact(\n",
|
149 |
-
" type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
|
150 |
-
" )\n",
|
151 |
-
" ]\n",
|
152 |
-
" else:\n",
|
153 |
-
" return api.artifact_versions(\n",
|
154 |
-
" type_name=\"bart_model\",\n",
|
155 |
-
" name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
|
156 |
-
" per_page=10000,\n",
|
157 |
-
" )\n",
|
158 |
-
" except:\n",
|
159 |
-
" return []"
|
160 |
-
]
|
161 |
-
},
|
162 |
-
{
|
163 |
-
"cell_type": "code",
|
164 |
-
"execution_count": null,
|
165 |
-
"id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
|
166 |
-
"metadata": {},
|
167 |
-
"outputs": [],
|
168 |
-
"source": [
|
169 |
-
"def get_training_config(run_id):\n",
|
170 |
-
" training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
|
171 |
-
" config = training_run.config\n",
|
172 |
-
" return config"
|
173 |
-
]
|
174 |
-
},
|
175 |
-
{
|
176 |
-
"cell_type": "code",
|
177 |
-
"execution_count": null,
|
178 |
-
"id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
|
179 |
-
"metadata": {},
|
180 |
-
"outputs": [],
|
181 |
-
"source": [
|
182 |
-
"# retrieve inference run details\n",
|
183 |
-
"def get_last_inference_version(run_id):\n",
|
184 |
-
" try:\n",
|
185 |
-
" inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
|
186 |
-
" return inference_run.summary.get(\"version\", None)\n",
|
187 |
-
" except:\n",
|
188 |
-
" return None"
|
189 |
-
]
|
190 |
-
},
|
191 |
-
{
|
192 |
-
"cell_type": "code",
|
193 |
-
"execution_count": null,
|
194 |
-
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
195 |
-
"metadata": {},
|
196 |
-
"outputs": [],
|
197 |
-
"source": [
|
198 |
-
"# compile functions - needed only once per run\n",
|
199 |
-
"def pmap_model_function(model):\n",
|
200 |
-
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
201 |
-
" def _generate(tokenized_prompt, key, params):\n",
|
202 |
-
" return model.generate(\n",
|
203 |
-
" **tokenized_prompt,\n",
|
204 |
-
" do_sample=True,\n",
|
205 |
-
" num_beams=1,\n",
|
206 |
-
" prng_key=key,\n",
|
207 |
-
" params=params,\n",
|
208 |
-
" top_k=gen_top_k,\n",
|
209 |
-
" top_p=gen_top_p\n",
|
210 |
-
" )\n",
|
211 |
-
"\n",
|
212 |
-
" return _generate"
|
213 |
-
]
|
214 |
-
},
|
215 |
-
{
|
216 |
-
"cell_type": "code",
|
217 |
-
"execution_count": null,
|
218 |
-
"id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
|
219 |
-
"metadata": {},
|
220 |
-
"outputs": [],
|
221 |
-
"source": [
|
222 |
-
"run_id = run_ids[0]\n",
|
223 |
-
"# TODO: loop over runs"
|
224 |
-
]
|
225 |
-
},
|
226 |
-
{
|
227 |
-
"cell_type": "code",
|
228 |
-
"execution_count": null,
|
229 |
-
"id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
|
230 |
-
"metadata": {},
|
231 |
-
"outputs": [],
|
232 |
-
"source": [
|
233 |
-
"artifact_versions = get_artifact_versions(run_id, latest_only)\n",
|
234 |
-
"last_inference_version = get_last_inference_version(run_id)\n",
|
235 |
-
"training_config = get_training_config(run_id)\n",
|
236 |
-
"run = None\n",
|
237 |
-
"p_generate = None\n",
|
238 |
-
"model_files = [\n",
|
239 |
-
" \"config.json\",\n",
|
240 |
-
" \"flax_model.msgpack\",\n",
|
241 |
-
" \"merges.txt\",\n",
|
242 |
-
" \"special_tokens_map.json\",\n",
|
243 |
-
" \"tokenizer.json\",\n",
|
244 |
-
" \"tokenizer_config.json\",\n",
|
245 |
-
" \"vocab.json\",\n",
|
246 |
-
"]\n",
|
247 |
-
"for artifact in artifact_versions:\n",
|
248 |
-
" print(f\"Processing artifact: {artifact.name}\")\n",
|
249 |
-
" version = int(artifact.version[1:])\n",
|
250 |
-
" results16, results32 = [], []\n",
|
251 |
-
" columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
|
252 |
-
"\n",
|
253 |
-
" if latest_only:\n",
|
254 |
-
" assert last_inference_version is None or version > last_inference_version\n",
|
255 |
-
" else:\n",
|
256 |
-
" if last_inference_version is None:\n",
|
257 |
-
" # we should start from v0\n",
|
258 |
-
" assert version == 0\n",
|
259 |
-
" elif version <= last_inference_version:\n",
|
260 |
-
" print(\n",
|
261 |
-
" f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
|
262 |
-
" )\n",
|
263 |
-
" else:\n",
|
264 |
-
" # check we are logging the correct version\n",
|
265 |
-
" assert version == last_inference_version + 1\n",
|
266 |
-
"\n",
|
267 |
-
" # start/resume corresponding run\n",
|
268 |
-
" if run is None:\n",
|
269 |
-
" run = wandb.init(\n",
|
270 |
-
" job_type=\"inference\",\n",
|
271 |
-
" entity=\"dalle-mini\",\n",
|
272 |
-
" project=\"dalle-mini\",\n",
|
273 |
-
" config=training_config,\n",
|
274 |
-
" id=f\"{run_id}-clip16{suffix}\",\n",
|
275 |
-
" resume=\"allow\",\n",
|
276 |
-
" )\n",
|
277 |
-
"\n",
|
278 |
-
" # work in temporary directory\n",
|
279 |
-
" with tempfile.TemporaryDirectory() as tmp:\n",
|
280 |
-
"\n",
|
281 |
-
" # download model files\n",
|
282 |
-
" artifact = run.use_artifact(artifact)\n",
|
283 |
-
" for f in model_files:\n",
|
284 |
-
" artifact.get_path(f).download(tmp)\n",
|
285 |
-
"\n",
|
286 |
-
" # load tokenizer and model\n",
|
287 |
-
" tokenizer = BartTokenizer.from_pretrained(tmp)\n",
|
288 |
-
" model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
|
289 |
-
" model_params = replicate(model.params)\n",
|
290 |
-
"\n",
|
291 |
-
" # pmap model function needs to happen only once per model config\n",
|
292 |
-
" if p_generate is None:\n",
|
293 |
-
" p_generate = pmap_model_function(model)\n",
|
294 |
-
"\n",
|
295 |
-
" # process one batch of captions\n",
|
296 |
-
" for batch in tqdm(samples):\n",
|
297 |
-
" processed_prompts = (\n",
|
298 |
-
" [text_normalizer(x) for x in batch]\n",
|
299 |
-
" if model.config.normalize_text\n",
|
300 |
-
" else list(batch)\n",
|
301 |
-
" )\n",
|
302 |
-
"\n",
|
303 |
-
" # repeat the prompts to distribute over each device and tokenize\n",
|
304 |
-
" processed_prompts = processed_prompts * jax.device_count()\n",
|
305 |
-
" tokenized_prompt = tokenizer(\n",
|
306 |
-
" processed_prompts,\n",
|
307 |
-
" return_tensors=\"jax\",\n",
|
308 |
-
" padding=\"max_length\",\n",
|
309 |
-
" truncation=True,\n",
|
310 |
-
" max_length=128,\n",
|
311 |
-
" ).data\n",
|
312 |
-
" tokenized_prompt = shard(tokenized_prompt)\n",
|
313 |
-
"\n",
|
314 |
-
" # generate images\n",
|
315 |
-
" images = []\n",
|
316 |
-
" pbar = tqdm(\n",
|
317 |
-
" range(num_images // jax.device_count()),\n",
|
318 |
-
" desc=\"Generating Images\",\n",
|
319 |
-
" leave=True,\n",
|
320 |
-
" )\n",
|
321 |
-
" for i in pbar:\n",
|
322 |
-
" key, subkey = jax.random.split(key)\n",
|
323 |
-
" encoded_images = p_generate(\n",
|
324 |
-
" tokenized_prompt, shard_prng_key(subkey), model_params\n",
|
325 |
-
" )\n",
|
326 |
-
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
327 |
-
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
328 |
-
" decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
|
329 |
-
" (-1, 256, 256, 3)\n",
|
330 |
-
" )\n",
|
331 |
-
" for img in decoded_images:\n",
|
332 |
-
" images.append(\n",
|
333 |
-
" Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
|
334 |
-
" )\n",
|
335 |
-
"\n",
|
336 |
-
" def add_clip_results(results, processor, p_clip, clip_params):\n",
|
337 |
-
" clip_inputs = processor(\n",
|
338 |
-
" text=batch,\n",
|
339 |
-
" images=images,\n",
|
340 |
-
" return_tensors=\"np\",\n",
|
341 |
-
" padding=\"max_length\",\n",
|
342 |
-
" max_length=77,\n",
|
343 |
-
" truncation=True,\n",
|
344 |
-
" ).data\n",
|
345 |
-
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
346 |
-
" images_per_prompt_indices = np.asarray(\n",
|
347 |
-
" range(0, len(images), batch_size)\n",
|
348 |
-
" )\n",
|
349 |
-
" clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
|
350 |
-
" list(\n",
|
351 |
-
" clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
|
352 |
-
" for i in range(batch_size)\n",
|
353 |
-
" )\n",
|
354 |
-
" )\n",
|
355 |
-
" clip_inputs = shard(clip_inputs)\n",
|
356 |
-
" logits = p_clip(clip_inputs, clip_params)\n",
|
357 |
-
" logits = logits.reshape(-1, num_images)\n",
|
358 |
-
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
359 |
-
" logits = jax.device_get(logits)\n",
|
360 |
-
" # add to results table\n",
|
361 |
-
" for i, (idx, scores, sample) in enumerate(\n",
|
362 |
-
" zip(top_scores, logits, batch)\n",
|
363 |
-
" ):\n",
|
364 |
-
" if sample == padding_item:\n",
|
365 |
-
" continue\n",
|
366 |
-
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
367 |
-
" top_images = [\n",
|
368 |
-
" wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
|
369 |
-
" for x in idx\n",
|
370 |
-
" ]\n",
|
371 |
-
" results.append([sample] + top_images)\n",
|
372 |
-
"\n",
|
373 |
-
" # get clip scores\n",
|
374 |
-
" pbar.set_description(\"Calculating CLIP 16 scores\")\n",
|
375 |
-
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
376 |
-
"\n",
|
377 |
-
" # get clip 32 scores\n",
|
378 |
-
" if add_clip_32:\n",
|
379 |
-
" pbar.set_description(\"Calculating CLIP 32 scores\")\n",
|
380 |
-
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
381 |
-
"\n",
|
382 |
-
" pbar.close()\n",
|
383 |
-
"\n",
|
384 |
-
" # log results\n",
|
385 |
-
" table = wandb.Table(columns=columns, data=results16)\n",
|
386 |
-
" run.log({\"Samples\": table, \"version\": version})\n",
|
387 |
-
" wandb.finish()\n",
|
388 |
-
"\n",
|
389 |
-
" if add_clip_32:\n",
|
390 |
-
" run = wandb.init(\n",
|
391 |
-
" job_type=\"inference\",\n",
|
392 |
-
" entity=\"dalle-mini\",\n",
|
393 |
-
" project=\"dalle-mini\",\n",
|
394 |
-
" config=training_config,\n",
|
395 |
-
" id=f\"{run_id}-clip32{suffix}\",\n",
|
396 |
-
" resume=\"allow\",\n",
|
397 |
-
" )\n",
|
398 |
-
" table = wandb.Table(columns=columns, data=results32)\n",
|
399 |
-
" run.log({\"Samples\": table, \"version\": version})\n",
|
400 |
-
" wandb.finish()\n",
|
401 |
-
" run = None # ensure we don't log on this run"
|
402 |
-
]
|
403 |
-
},
|
404 |
-
{
|
405 |
-
"cell_type": "code",
|
406 |
-
"execution_count": null,
|
407 |
-
"id": "415d3f54-7226-43de-9eea-4283a948dc93",
|
408 |
-
"metadata": {},
|
409 |
-
"outputs": [],
|
410 |
-
"source": []
|
411 |
-
}
|
412 |
-
],
|
413 |
-
"metadata": {
|
414 |
-
"kernelspec": {
|
415 |
-
"display_name": "Python 3 (ipykernel)",
|
416 |
-
"language": "python",
|
417 |
-
"name": "python3"
|
418 |
-
},
|
419 |
-
"language_info": {
|
420 |
-
"codemirror_mode": {
|
421 |
-
"name": "ipython",
|
422 |
-
"version": 3
|
423 |
-
},
|
424 |
-
"file_extension": ".py",
|
425 |
-
"mimetype": "text/x-python",
|
426 |
-
"name": "python",
|
427 |
-
"nbconvert_exporter": "python",
|
428 |
-
"pygments_lexer": "ipython3",
|
429 |
-
"version": "3.9.7"
|
430 |
-
}
|
431 |
-
},
|
432 |
-
"nbformat": 4,
|
433 |
-
"nbformat_minor": 5
|
434 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/inference/samples.txt
DELETED
@@ -1,124 +0,0 @@
|
|
1 |
-
t-shirt, size M
|
2 |
-
flower dress, size M
|
3 |
-
white snow covered mountain under blue sky during daytime
|
4 |
-
aerial view of the beach during daytime
|
5 |
-
aerial view of the beach at night
|
6 |
-
a beautiful sunset at a beach with a shell on the shore
|
7 |
-
a farmhouse surrounded by beautiful flowers
|
8 |
-
sunset over green mountains
|
9 |
-
a photo of san francisco golden gate bridge
|
10 |
-
painting of an oniric forest glade surrounded by tall trees
|
11 |
-
a graphite sketch of a gothic cathedral
|
12 |
-
a graphite sketch of Elon Musk
|
13 |
-
still life in the style of Kandinsky
|
14 |
-
still life in the style of Picasso
|
15 |
-
a colorful stairway to heaven
|
16 |
-
a background consisting of colors blue, green, and red
|
17 |
-
Mohammed Ali and Mike Tyson in a match
|
18 |
-
Pele and Maradona in a match
|
19 |
-
view of Mars from space
|
20 |
-
a picture of the Eiffel tower on the moon
|
21 |
-
a picture of the Eiffel tower on the moon, Earth is in the background
|
22 |
-
watercolor of the Eiffel tower on the moon
|
23 |
-
the moon is a skull
|
24 |
-
epic sword fight
|
25 |
-
underwater cathedral
|
26 |
-
a photo of a fantasy version of New York City
|
27 |
-
a picture of fantasy kingdoms
|
28 |
-
a volcano erupting next to San Francisco golden gate bridge
|
29 |
-
Paris in a far future, futuristic Paris
|
30 |
-
real painting of an alien from Monet
|
31 |
-
the communist statue of liberty
|
32 |
-
robots taking control over humans
|
33 |
-
illustration of an astronaut in a space suit playing guitar
|
34 |
-
a clown wearing a spacesuit floating in space
|
35 |
-
a dog playing with a ball
|
36 |
-
a cat sits on top of an alligator
|
37 |
-
a very cute cat laying by a big bike
|
38 |
-
a rat holding a red lightsaber in a white background
|
39 |
-
a very cute giraffe making a funny face
|
40 |
-
A unicorn is passing by a rainbow in a field of flowers
|
41 |
-
an elephant made of carrots
|
42 |
-
an elephant on a unicycle during a circus
|
43 |
-
photography of a penguin watching television
|
44 |
-
a penguin is walking on the Moon, Earth is in the background
|
45 |
-
a penguin standing on a tower of books holds onto a rope from a helicopter
|
46 |
-
rat wearing a crown
|
47 |
-
looking into the sky, 10 airplanes are seen overhead
|
48 |
-
shelves filled with books and alchemy potion bottles
|
49 |
-
this is a detailed high-resolution scan of a human brain
|
50 |
-
a restaurant menu
|
51 |
-
a bottle of coca-cola on a table
|
52 |
-
a peanut
|
53 |
-
a cross-section view of a walnut
|
54 |
-
a living room with two white armchairs and a painting of the collosseum. The painting is mounted above a modern fireplace.
|
55 |
-
a long line of alternating green and red blocks
|
56 |
-
a long line of green blocks on a beach at subset
|
57 |
-
a long line of peaches on a beach at sunset
|
58 |
-
a picture of a castle from minecraft
|
59 |
-
a cute pikachu teapot
|
60 |
-
an illustration of pikachu sitting on a bench eating an ice cream
|
61 |
-
mario is jumping over a zebra
|
62 |
-
famous anime hero
|
63 |
-
star wars concept art
|
64 |
-
Cartoon of a carrot with big eyes
|
65 |
-
a cartoon of a superhero bear
|
66 |
-
an illustration of a cute skeleton wearing a blue hoodie
|
67 |
-
illustration of a baby shark swimming around corals
|
68 |
-
an illustration of an avocado in a beanie riding a motorcycle
|
69 |
-
logo of a robot wearing glasses and reading a book
|
70 |
-
illustration of a cactus lifting weigths
|
71 |
-
logo of a cactus lifting weights
|
72 |
-
a photo of a camera from the future
|
73 |
-
a skeleton with the shape of a spider
|
74 |
-
a collection of glasses is sitting on a table
|
75 |
-
a painting of a capybara sitting on a mountain during fall in surrealist style
|
76 |
-
a pentagonal green clock
|
77 |
-
a small red block sitting on a large green block
|
78 |
-
a storefront that has the word 'openai' written on it
|
79 |
-
a tatoo of a black broccoli
|
80 |
-
a variety of clocks is sitting on a table
|
81 |
-
a table has a train model on it with other cars and things
|
82 |
-
a pixel art illustration of an eagle sitting in a field in the afternoon
|
83 |
-
an emoji of a baby fox wearing a blue hat, green gloves, red shirt, and yellow pants
|
84 |
-
an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
|
85 |
-
an extreme close-up view of a capybara sitting in a field
|
86 |
-
an illustration of a baby cucumber with a mustache playing chess
|
87 |
-
an illustration of a baby daikon radish in a tutu walking a dog
|
88 |
-
an illustration of a baby hedgehog in a cape staring at its reflection in a mirror
|
89 |
-
an illustration of a baby panda with headphones holding an umbrella in the rain
|
90 |
-
urinals are lined up in a jungle
|
91 |
-
a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
|
92 |
-
a human face
|
93 |
-
a person is holding a phone and a waterbottle, running a marathon
|
94 |
-
a child eating a birthday cake near some balloons
|
95 |
-
Young woman riding her bike through the forest
|
96 |
-
the best soccer team of the world
|
97 |
-
the best football team of the world
|
98 |
-
the best basketball team of the world
|
99 |
-
happy, happiness
|
100 |
-
sad, sadness
|
101 |
-
the representation of infinity
|
102 |
-
the end of the world
|
103 |
-
the last sunrise on earth
|
104 |
-
a portrait of a nightmare creature watching at you
|
105 |
-
an avocado armchair
|
106 |
-
an armchair in the shape of an avocado
|
107 |
-
illustration of an avocado armchair
|
108 |
-
illustration of an armchair in the shape of an avocado
|
109 |
-
logo of an avocado armchair
|
110 |
-
an avocado armchair flying into space
|
111 |
-
a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
|
112 |
-
an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
|
113 |
-
illustration of an avocado armchair getting married to a pineapple
|
114 |
-
half human half cat
|
115 |
-
half human half dog
|
116 |
-
half human half pen
|
117 |
-
half human half garbage
|
118 |
-
half human half avocado
|
119 |
-
half human half Eiffel tower
|
120 |
-
a propaganda poster for transhumanism
|
121 |
-
a propaganda poster for building a space elevator
|
122 |
-
a beautiful epic fantasy painting of a space elevator
|
123 |
-
a transformer architecture
|
124 |
-
a transformer in real life
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|