Pedro Cuenca commited on
Commit
f62b045
·
1 Parent(s): 6e79248

refactor: move `captioned_strip` to library.

Browse files

Different versions are used in several places.


Former-commit-id: e74b5ae17c035e953d2859a609d651464db0c64c

dalle_mini/helpers.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw, ImageFont
2
+
3
+ def captioned_strip(images, caption):
4
+ increased_h = 0 if caption is None else 48
5
+ w, h = images[0].size[0], images[0].size[1]
6
+ img = Image.new("RGB", (len(images)*w, h + increased_h))
7
+ for i, img_ in enumerate(images):
8
+ img.paste(img_, (i*w, increased_h))
9
+
10
+ if caption is not None:
11
+ draw = ImageDraw.Draw(img)
12
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
13
+ draw.text((20, 3), caption, (255,255,255), font=font)
14
+ return img
dev/predictions/wandb-examples-from-backend.py CHANGED
@@ -6,6 +6,7 @@ import wandb
6
  import os
7
 
8
  from dalle_mini.backend import ServiceError, get_images_from_backend
 
9
 
10
  os.environ["WANDB_SILENT"] = "true"
11
  os.environ["WANDB_CONSOLE"] = "off"
@@ -19,16 +20,6 @@ run = wandb.init(id=id,
19
  resume="allow"
20
  )
21
 
22
- def captioned_strip(images, caption):
23
- w, h = images[0].size[0], images[0].size[1]
24
- img = Image.new("RGB", (len(images)*w, h + 48))
25
- for i, img_ in enumerate(images):
26
- img.paste(img_, (i*w, 48))
27
- draw = ImageDraw.Draw(img)
28
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
29
- draw.text((20, 3), caption, (255,255,255), font=font)
30
- return img
31
-
32
  def log_to_wandb(prompts):
33
  try:
34
  backend_url = os.environ["BACKEND_SERVER"]
 
6
  import os
7
 
8
  from dalle_mini.backend import ServiceError, get_images_from_backend
9
+ from dalle_mini.helpers import captioned_strip
10
 
11
  os.environ["WANDB_SILENT"] = "true"
12
  os.environ["WANDB_CONSOLE"] = "off"
 
20
  resume="allow"
21
  )
22
 
 
 
 
 
 
 
 
 
 
 
23
  def log_to_wandb(prompts):
24
  try:
25
  backend_url = os.environ["BACKEND_SERVER"]
dev/predictions/wandb-examples.py CHANGED
@@ -178,18 +178,8 @@ def clip_top_k(prompt, images, k=8):
178
 
179
  # ## Log to wandb
180
 
181
- from PIL import ImageDraw, ImageFont
182
 
183
- def captioned_strip(images, caption):
184
- w, h = images[0].size[0], images[0].size[1]
185
- img = Image.new("RGB", (len(images)*w, h + 48))
186
- for i, img_ in enumerate(images):
187
- img.paste(img_, (i*w, 48))
188
- draw = ImageDraw.Draw(img)
189
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
190
- draw.text((20, 3), caption, (255,255,255), font=font)
191
- return img
192
-
193
  def log_to_wandb(prompts):
194
  strips = []
195
  for prompt in prompts:
 
178
 
179
  # ## Log to wandb
180
 
181
+ from dalle_mini.helpers import captioned_strip
182
 
 
 
 
 
 
 
 
 
 
 
183
  def log_to_wandb(prompts):
184
  strips = []
185
  for prompt in prompts: