Niv Sardi committed · commit c92a751 · 1 parent: f0a5526

augment: accept command line arguments

python/augment.py CHANGED (+140 -113)
@@ -12,6 +12,8 @@ import cv2
 import filetype
 from filetype.match import image_matchers
 
+from progress.bar import ChargingBar
+
 import imgaug as ia
 from imgaug import augmenters as iaa
 from imgaug.augmentables.batches import UnnormalizedBatch
@@ -23,145 +25,170 @@ import pipelines
 
 BATCH_SIZE = 16
 
-
-
-logo_images = []
-logo_alphas = []
-logo_labels = {}
-
-db = {}
-with open(defaults.MAIN_CSV_PATH, 'r') as f:
-    reader = csv.DictReader(f)
-    db = {e.bco: e for e in [Entity.from_dict(d) for d in reader]}
-…
-}
-…
-        continue
-…
-            img = cv2.imread(d.path, cv2.IMREAD_UNCHANGED)
-        else:
-            png = svg2png(url=d.path)
-            img = cv2.imdecode(np.asarray(bytearray(png), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
-        label = db[d.name.split('.')[0]].id
-…
-        d = cv2.merge([b, g, r])
-…
-        logo_labels.update({d.tobytes(): label})
-…
-        # because imgaug is pretty strict about what data it will process
-        # and that we want the alpha layer to pass the same transformations as the orig
-        logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
-…
-    s = (i*BATCH_SIZE)%n
-    e = min(s + BATCH_SIZE, n)
-    le = max(0, BATCH_SIZE - (e - s))
-…
-pipeline = pipelines.HUGE
-…
-    for b in lst:
-        print(f"Loading next unaugmented batch...")
-        yield b
-…
-except
-print(f'couldnt
-…
+def process(args):
+    dest_images_path = os.path.join(args.dest, 'images')
+    dest_labels_path = os.path.join(args.dest, 'labels')
+
+    mkdir.make_dirs([dest_images_path, dest_labels_path])
+    logo_images = []
+    logo_alphas = []
+    logo_labels = {}
+
+    db = {}
+    with open(defaults.MAIN_CSV_PATH, 'r') as f:
+        reader = csv.DictReader(f)
+        db = {e.bco: e for e in [Entity.from_dict(d) for d in reader]}
+
+    background_images = [d for d in os.scandir(args.backgrounds)]
+    assert(len(background_images))
+
+    stats = {
+        'failed': 0,
+        'ok': 0
+    }
+
+    for d in os.scandir(args.logos):
+        img = None
+        if not d.is_file():
+            stats['failed'] += 1
+            continue
+
+        try:
+            if filetype.match(d.path, matchers=image_matchers):
+                img = cv2.imread(d.path, cv2.IMREAD_UNCHANGED)
+            else:
+                png = svg2png(url=d.path)
+                img = cv2.imdecode(np.asarray(bytearray(png), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+            label = db[d.name.split('.')[0]].id
+
+            (h, w, c) = img.shape
+            if c == 3:
+                img = imtool.add_alpha(img)
+
+            if img.ndim < 3:
+                print(f'very bad dim: {img.ndim}')
+
+            img = imtool.remove_white(img)
+            (h, w, c) = img.shape
+
+            assert(w > 10)
+            assert(h > 10)
+
+            stats['ok'] += 1
+
+            (b, g, r, _) = cv2.split(img)
+            alpha = img[:, :, 3]/255
+            d = cv2.merge([b, g, r])
+
+            logo_images.append(d)
+            # tried id() tried __array_interface__, tried tagging, nothing works
+            logo_labels.update({d.tobytes(): label})
+
+            # XXX(xaiki): we pass alpha as a float32 heatmap,
+            # because imgaug is pretty strict about what data it will process
+            # and that we want the alpha layer to pass the same transformations as the orig
+            logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
+
+        except Exception as e:
+            stats['failed'] += 1
+            print(f'error loading: {d.path}: {e}')
+
+    print(stats)
+    #print(len(logo_alphas), len(logo_images), len(logo_labels))
+    assert(len(logo_alphas) == len(logo_images))
+
+    # so that we don't get a lot of the same logos on the same page.
+    zipped = list(zip(logo_images, logo_alphas))
+    random.shuffle(zipped)
+    logo_images, logo_alphas = zip(*zipped)
+
+    n = len(logo_images)
+    batches = []
+    for i in range(math.floor(n*2/BATCH_SIZE)):
+        s = (i*BATCH_SIZE)%n
+        e = min(s + BATCH_SIZE, n)
+        le = max(0, BATCH_SIZE - (e - s))
+
+        a = logo_images[0:le] + logo_images[s:e]
+        h = logo_alphas[0:le] + logo_alphas[s:e]
+
+        assert(len(a) == BATCH_SIZE)
+
+        batches.append(UnnormalizedBatch(images=a, heatmaps=h))
+
+    bar = ChargingBar('Processing', max=len(batches))
+    # We use a single, very fast augmenter here to show that batches
+    # are only loaded once there is space again in the buffer.
+    pipeline = pipelines.HUGE
+
+    def create_generator(lst):
+        for b in lst:
+            print(f"Loading next unaugmented batch...")
+            yield b
+
+    batches_generator = create_generator(batches)
+
+    with pipeline.pool(processes=-1, seed=1) as pool:
+        batches_aug = pool.imap_batches(batches_generator, output_buffer_size=5)
+
+        print(f"Requesting next augmented batch...")
+        for i, batch_aug in enumerate(batches_aug):
+            idx = list(range(len(batch_aug.images_aug)))
+            random.shuffle(idx)
+            for j, d in enumerate(background_images):
+                img = imtool.remove_white(cv2.imread(d.path))
+                basename = d.name.replace('.png', '') + f'.{i}.{j}'
+
+                anotations = []
+                for k in range(math.floor(len(batch_aug.images_aug)/3)):
+                    logo_idx = (j+k*4)%len(batch_aug.images_aug)
+
+                    orig = batch_aug.images_unaug[logo_idx]
+                    label = logo_labels[orig.tobytes()]
+                    logo = batch_aug.images_aug[logo_idx]
+
+                    assert(logo.shape == orig.shape)
+
+                    # XXX(xaiki): we get alpha from heatmap, but will only use one channel
+                    # we could make mix_alpha into mix_mask and pass all 3 chanels
+                    alpha = cv2.split(batch_aug.heatmaps_aug[logo_idx])
+
+                    try:
+                        bb = imtool.mix_alpha(img, logo, alpha[0],
+                                              random.random(), random.random())
+                        c = bb.to_centroid(img.shape)
+                        anotations.append(c.to_anotation(label))
+                    except AssertionError as e:
+                        print(f'couldnt process {i}, {j}: {e}')
+
                 try:
+                    cv2.imwrite(f'{dest_images_path}/{basename}.png', img)
+                    label_path = f"{dest_labels_path}/{basename}.txt"
+                    with open(label_path, 'a') as f:
+                        f.write('\n'.join(anotations))
+                except Exception:
+                    print(f'couldnt write image {basename}')
+
+            if i < len(batches)-1:
+                print("Requesting next augmented batch...")
+            bar.next()
+    bar.finish()
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='mix backgrounds and logos into augmented data for YOLO')
+    parser.add_argument('--logos', metavar='logos', type=str,
+                        default=defaults.LOGOS_DATA_PATH,
+                        help='dir containing logos')
+    parser.add_argument('--backgrounds', metavar='backgrounds', type=str,
+                        default=defaults.IMAGES_PATH,
+                        help='dir containing background plates')
+    parser.add_argument('--dst', dest='dest', type=str,
+                        default=defaults.AUGMENTED_DATA_PATH,
+                        help='dest dir')
+
+    args = parser.parse_args()
+    process(args)
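
With this change the script takes its paths from the command line rather than only from module-level constants. A hypothetical invocation (the directory names below are placeholders; any flag left out falls back to its defaults.* value):

    python python/augment.py --logos data/logos \
                             --backgrounds data/backgrounds \
                             --dst data/augmented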
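
The batch-assembly loop above walks the shuffled logo list roughly twice (math.floor(n*2/BATCH_SIZE) batches), wrapping s back to the start via the modulo and padding any short tail slice with logos from the front, so every batch holds exactly BATCH_SIZE images. A minimal sketch of just that indexing, assuming a hypothetical n = 20:

    BATCH_SIZE = 16
    n = 20  # hypothetical logo count
    for i in range(n * 2 // BATCH_SIZE):   # same as math.floor(n*2/BATCH_SIZE)
        s = (i * BATCH_SIZE) % n           # wrap around the list
        e = min(s + BATCH_SIZE, n)         # clamp the slice to the end
        le = max(0, BATCH_SIZE - (e - s))  # shortfall, padded from the front
        # i=0 -> s=0,  e=16, le=0:  logos 0..15
        # i=1 -> s=16, e=20, le=12: logos 16..19 plus logos 0..11
        assert le + (e - s) == BATCH_SIZE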
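
The XXX(xaiki) comment in the loading loop leans on imgaug's heatmap support: a float32 array in [0, 1] attached to a batch as a heatmap goes through the same geometric transforms as its image, which is how each logo's alpha layer survives augmentation and can later be split back out with cv2.split. A self-contained sketch of that trick on synthetic data (any geometric augmenter shows the effect; this is not the repo's pipelines.HUGE pipeline):

    import numpy as np
    import imgaug.augmenters as iaa
    from imgaug.augmentables.batches import UnnormalizedBatch

    # stand-ins for one logo and its alpha layer
    img = np.zeros((64, 64, 3), dtype=np.uint8)
    img[16:48, 16:48] = 255
    alpha = (img[:, :, 0] / 255).astype('float32')

    seq = iaa.Affine(rotate=(-25, 25))
    batch = UnnormalizedBatch(images=[img],
                              heatmaps=[np.dstack((alpha, alpha, alpha))])
    batch_aug = next(seq.augment_batches([batch], background=False))

    img_aug = batch_aug.images_aug[0]
    alpha_aug = batch_aug.heatmaps_aug[0][:, :, 0]  # rotated exactly like img_aug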
|