Spaces:
Runtime error
Runtime error
David Piscasio
commited on
Commit
·
0c094b2
1
Parent(s):
32da7be
Added util folder
Browse files- util/__init__.py +1 -0
- util/__pycache__/__init__.cpython-38.pyc +0 -0
- util/__pycache__/html.cpython-38.pyc +0 -0
- util/__pycache__/util.cpython-38.pyc +0 -0
- util/__pycache__/visualizer.cpython-38.pyc +0 -0
- util/get_data.py +110 -0
- util/html.py +86 -0
- util/image_pool.py +54 -0
- util/util.py +103 -0
- util/visualizer.py +257 -0
util/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
"""This package includes a miscellaneous collection of useful helper functions."""
|
util/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (295 Bytes). View file
|
|
util/__pycache__/html.cpython-38.pyc
ADDED
Binary file (3.64 kB). View file
|
|
util/__pycache__/util.cpython-38.pyc
ADDED
Binary file (3.23 kB). View file
|
|
util/__pycache__/visualizer.cpython-38.pyc
ADDED
Binary file (9.42 kB). View file
|
|
util/get_data.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import requests
|
5 |
+
from warnings import warn
|
6 |
+
from zipfile import ZipFile
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
from os.path import abspath, isdir, join, basename
|
9 |
+
|
10 |
+
|
11 |
+
class GetData(object):
|
12 |
+
"""A Python script for downloading CycleGAN or pix2pix datasets.
|
13 |
+
|
14 |
+
Parameters:
|
15 |
+
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
|
16 |
+
verbose (bool) -- If True, print additional information.
|
17 |
+
|
18 |
+
Examples:
|
19 |
+
>>> from util.get_data import GetData
|
20 |
+
>>> gd = GetData(technique='cyclegan')
|
21 |
+
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
|
22 |
+
|
23 |
+
Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
|
24 |
+
and 'scripts/download_cyclegan_model.sh'.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, technique='cyclegan', verbose=True):
|
28 |
+
url_dict = {
|
29 |
+
'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
|
30 |
+
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
|
31 |
+
}
|
32 |
+
self.url = url_dict.get(technique.lower())
|
33 |
+
self._verbose = verbose
|
34 |
+
|
35 |
+
def _print(self, text):
|
36 |
+
if self._verbose:
|
37 |
+
print(text)
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def _get_options(r):
|
41 |
+
soup = BeautifulSoup(r.text, 'lxml')
|
42 |
+
options = [h.text for h in soup.find_all('a', href=True)
|
43 |
+
if h.text.endswith(('.zip', 'tar.gz'))]
|
44 |
+
return options
|
45 |
+
|
46 |
+
def _present_options(self):
|
47 |
+
r = requests.get(self.url)
|
48 |
+
options = self._get_options(r)
|
49 |
+
print('Options:\n')
|
50 |
+
for i, o in enumerate(options):
|
51 |
+
print("{0}: {1}".format(i, o))
|
52 |
+
choice = input("\nPlease enter the number of the "
|
53 |
+
"dataset above you wish to download:")
|
54 |
+
return options[int(choice)]
|
55 |
+
|
56 |
+
def _download_data(self, dataset_url, save_path):
|
57 |
+
if not isdir(save_path):
|
58 |
+
os.makedirs(save_path)
|
59 |
+
|
60 |
+
base = basename(dataset_url)
|
61 |
+
temp_save_path = join(save_path, base)
|
62 |
+
|
63 |
+
with open(temp_save_path, "wb") as f:
|
64 |
+
r = requests.get(dataset_url)
|
65 |
+
f.write(r.content)
|
66 |
+
|
67 |
+
if base.endswith('.tar.gz'):
|
68 |
+
obj = tarfile.open(temp_save_path)
|
69 |
+
elif base.endswith('.zip'):
|
70 |
+
obj = ZipFile(temp_save_path, 'r')
|
71 |
+
else:
|
72 |
+
raise ValueError("Unknown File Type: {0}.".format(base))
|
73 |
+
|
74 |
+
self._print("Unpacking Data...")
|
75 |
+
obj.extractall(save_path)
|
76 |
+
obj.close()
|
77 |
+
os.remove(temp_save_path)
|
78 |
+
|
79 |
+
def get(self, save_path, dataset=None):
|
80 |
+
"""
|
81 |
+
|
82 |
+
Download a dataset.
|
83 |
+
|
84 |
+
Parameters:
|
85 |
+
save_path (str) -- A directory to save the data to.
|
86 |
+
dataset (str) -- (optional). A specific dataset to download.
|
87 |
+
Note: this must include the file extension.
|
88 |
+
If None, options will be presented for you
|
89 |
+
to choose from.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
save_path_full (str) -- the absolute path to the downloaded data.
|
93 |
+
|
94 |
+
"""
|
95 |
+
if dataset is None:
|
96 |
+
selected_dataset = self._present_options()
|
97 |
+
else:
|
98 |
+
selected_dataset = dataset
|
99 |
+
|
100 |
+
save_path_full = join(save_path, selected_dataset.split('.')[0])
|
101 |
+
|
102 |
+
if isdir(save_path_full):
|
103 |
+
warn("\n'{0}' already exists. Voiding Download.".format(
|
104 |
+
save_path_full))
|
105 |
+
else:
|
106 |
+
self._print('Downloading Data...')
|
107 |
+
url = "{0}/{1}".format(self.url, selected_dataset)
|
108 |
+
self._download_data(url, save_path=save_path)
|
109 |
+
|
110 |
+
return abspath(save_path_full)
|
util/html.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML classes
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refresh itself; if 0; no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HMTL file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
util/image_pool.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class ImagePool():
|
6 |
+
"""This class implements an image buffer that stores previously generated images.
|
7 |
+
|
8 |
+
This buffer enables us to update discriminators using a history of generated images
|
9 |
+
rather than the ones produced by the latest generators.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, pool_size):
|
13 |
+
"""Initialize the ImagePool class
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
17 |
+
"""
|
18 |
+
self.pool_size = pool_size
|
19 |
+
if self.pool_size > 0: # create an empty pool
|
20 |
+
self.num_imgs = 0
|
21 |
+
self.images = []
|
22 |
+
|
23 |
+
def query(self, images):
|
24 |
+
"""Return an image from the pool.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
images: the latest generated images from the generator
|
28 |
+
|
29 |
+
Returns images from the buffer.
|
30 |
+
|
31 |
+
By 50/100, the buffer will return input images.
|
32 |
+
By 50/100, the buffer will return images previously stored in the buffer,
|
33 |
+
and insert the current images to the buffer.
|
34 |
+
"""
|
35 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
36 |
+
return images
|
37 |
+
return_images = []
|
38 |
+
for image in images:
|
39 |
+
image = torch.unsqueeze(image.data, 0)
|
40 |
+
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
|
41 |
+
self.num_imgs = self.num_imgs + 1
|
42 |
+
self.images.append(image)
|
43 |
+
return_images.append(image)
|
44 |
+
else:
|
45 |
+
p = random.uniform(0, 1)
|
46 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
|
47 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
48 |
+
tmp = self.images[random_id].clone()
|
49 |
+
self.images[random_id] = image
|
50 |
+
return_images.append(tmp)
|
51 |
+
else: # by another 50% chance, the buffer will return the current image
|
52 |
+
return_images.append(image)
|
53 |
+
return_images = torch.cat(return_images, 0) # collect all the images and return
|
54 |
+
return return_images
|
util/util.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module contains simple helper functions """
|
2 |
+
from __future__ import print_function
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
def tensor2im(input_image, imtype=np.uint8):
|
10 |
+
""""Converts a Tensor array into a numpy image array.
|
11 |
+
|
12 |
+
Parameters:
|
13 |
+
input_image (tensor) -- the input image tensor array
|
14 |
+
imtype (type) -- the desired type of the converted numpy array
|
15 |
+
"""
|
16 |
+
if not isinstance(input_image, np.ndarray):
|
17 |
+
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
18 |
+
image_tensor = input_image.data
|
19 |
+
else:
|
20 |
+
return input_image
|
21 |
+
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
|
22 |
+
if image_numpy.shape[0] == 1: # grayscale to RGB
|
23 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
24 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
|
25 |
+
else: # if it is a numpy array, do nothing
|
26 |
+
image_numpy = input_image
|
27 |
+
return image_numpy.astype(imtype)
|
28 |
+
|
29 |
+
|
30 |
+
def diagnose_network(net, name='network'):
|
31 |
+
"""Calculate and print the mean of average absolute(gradients)
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
net (torch network) -- Torch network
|
35 |
+
name (str) -- the name of the network
|
36 |
+
"""
|
37 |
+
mean = 0.0
|
38 |
+
count = 0
|
39 |
+
for param in net.parameters():
|
40 |
+
if param.grad is not None:
|
41 |
+
mean += torch.mean(torch.abs(param.grad.data))
|
42 |
+
count += 1
|
43 |
+
if count > 0:
|
44 |
+
mean = mean / count
|
45 |
+
print(name)
|
46 |
+
print(mean)
|
47 |
+
|
48 |
+
|
49 |
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
|
50 |
+
"""Save a numpy image to the disk
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
image_numpy (numpy array) -- input numpy array
|
54 |
+
image_path (str) -- the path of the image
|
55 |
+
"""
|
56 |
+
|
57 |
+
image_pil = Image.fromarray(image_numpy)
|
58 |
+
h, w, _ = image_numpy.shape
|
59 |
+
|
60 |
+
if aspect_ratio > 1.0:
|
61 |
+
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
|
62 |
+
if aspect_ratio < 1.0:
|
63 |
+
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
|
64 |
+
image_pil.save(image_path)
|
65 |
+
|
66 |
+
|
67 |
+
def print_numpy(x, val=True, shp=False):
|
68 |
+
"""Print the mean, min, max, median, std, and size of a numpy array
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
val (bool) -- if print the values of the numpy array
|
72 |
+
shp (bool) -- if print the shape of the numpy array
|
73 |
+
"""
|
74 |
+
x = x.astype(np.float64)
|
75 |
+
if shp:
|
76 |
+
print('shape,', x.shape)
|
77 |
+
if val:
|
78 |
+
x = x.flatten()
|
79 |
+
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
80 |
+
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
81 |
+
|
82 |
+
|
83 |
+
def mkdirs(paths):
|
84 |
+
"""create empty directories if they don't exist
|
85 |
+
|
86 |
+
Parameters:
|
87 |
+
paths (str list) -- a list of directory paths
|
88 |
+
"""
|
89 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
90 |
+
for path in paths:
|
91 |
+
mkdir(path)
|
92 |
+
else:
|
93 |
+
mkdir(paths)
|
94 |
+
|
95 |
+
|
96 |
+
def mkdir(path):
|
97 |
+
"""create a single empty directory if it didn't exist
|
98 |
+
|
99 |
+
Parameters:
|
100 |
+
path (str) -- a single directory path
|
101 |
+
"""
|
102 |
+
if not os.path.exists(path):
|
103 |
+
os.makedirs(path)
|
util/visualizer.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import ntpath
|
5 |
+
import time
|
6 |
+
from . import util, html
|
7 |
+
from subprocess import Popen, PIPE
|
8 |
+
|
9 |
+
|
10 |
+
try:
|
11 |
+
import wandb
|
12 |
+
except ImportError:
|
13 |
+
print('Warning: wandb package cannot be found. The option "--use_wandb" will result in error.')
|
14 |
+
|
15 |
+
if sys.version_info[0] == 2:
|
16 |
+
VisdomExceptionBase = Exception
|
17 |
+
else:
|
18 |
+
VisdomExceptionBase = ConnectionError
|
19 |
+
|
20 |
+
|
21 |
+
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, use_wandb=False):
|
22 |
+
"""Save images to the disk.
|
23 |
+
|
24 |
+
Parameters:
|
25 |
+
webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
|
26 |
+
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
27 |
+
image_path (str) -- the string is used to create image paths
|
28 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
29 |
+
width (int) -- the images will be resized to width x width
|
30 |
+
|
31 |
+
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
32 |
+
"""
|
33 |
+
image_dir = webpage.get_image_dir()
|
34 |
+
short_path = ntpath.basename(image_path[0])
|
35 |
+
name = os.path.splitext(short_path)[0]
|
36 |
+
|
37 |
+
webpage.add_header(name)
|
38 |
+
ims, txts, links = [], [], []
|
39 |
+
ims_dict = {}
|
40 |
+
for label, im_data in visuals.items():
|
41 |
+
im = util.tensor2im(im_data)
|
42 |
+
image_name = '%s_%s.png' % (name, label)
|
43 |
+
save_path = os.path.join(image_dir, image_name)
|
44 |
+
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
45 |
+
ims.append(image_name)
|
46 |
+
txts.append(label)
|
47 |
+
links.append(image_name)
|
48 |
+
if use_wandb:
|
49 |
+
ims_dict[label] = wandb.Image(im)
|
50 |
+
webpage.add_images(ims, txts, links, width=width)
|
51 |
+
if use_wandb:
|
52 |
+
wandb.log(ims_dict)
|
53 |
+
|
54 |
+
|
55 |
+
class Visualizer():
|
56 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
57 |
+
|
58 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(self, opt):
|
62 |
+
"""Initialize the Visualizer class
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
66 |
+
Step 1: Cache the training/test options
|
67 |
+
Step 2: connect to a visdom server
|
68 |
+
Step 3: create an HTML object for saveing HTML filters
|
69 |
+
Step 4: create a logging file to store training losses
|
70 |
+
"""
|
71 |
+
self.opt = opt # cache the option
|
72 |
+
self.display_id = opt.display_id
|
73 |
+
self.use_html = opt.isTrain and not opt.no_html
|
74 |
+
self.win_size = opt.display_winsize
|
75 |
+
self.name = opt.name
|
76 |
+
self.port = opt.display_port
|
77 |
+
self.saved = False
|
78 |
+
self.use_wandb = opt.use_wandb
|
79 |
+
self.current_epoch = 0
|
80 |
+
self.ncols = opt.display_ncols
|
81 |
+
|
82 |
+
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
83 |
+
import visdom
|
84 |
+
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
85 |
+
if not self.vis.check_connection():
|
86 |
+
self.create_visdom_connections()
|
87 |
+
|
88 |
+
if self.use_wandb:
|
89 |
+
self.wandb_run = wandb.init(project='CycleGAN-and-pix2pix', name=opt.name, config=opt) if not wandb.run else wandb.run
|
90 |
+
self.wandb_run._label(repo='CycleGAN-and-pix2pix')
|
91 |
+
|
92 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
93 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
94 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
95 |
+
print('create web directory %s...' % self.web_dir)
|
96 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
97 |
+
# create a logging file to store training losses
|
98 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
99 |
+
with open(self.log_name, "a") as log_file:
|
100 |
+
now = time.strftime("%c")
|
101 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
102 |
+
|
103 |
+
def reset(self):
|
104 |
+
"""Reset the self.saved status"""
|
105 |
+
self.saved = False
|
106 |
+
|
107 |
+
def create_visdom_connections(self):
|
108 |
+
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
|
109 |
+
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
110 |
+
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
111 |
+
print('Command: %s' % cmd)
|
112 |
+
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
113 |
+
|
114 |
+
def display_current_results(self, visuals, epoch, save_result):
|
115 |
+
"""Display current results on visdom; save current results to an HTML file.
|
116 |
+
|
117 |
+
Parameters:
|
118 |
+
visuals (OrderedDict) - - dictionary of images to display or save
|
119 |
+
epoch (int) - - the current epoch
|
120 |
+
save_result (bool) - - if save the current results to an HTML file
|
121 |
+
"""
|
122 |
+
if self.display_id > 0: # show images in the browser using visdom
|
123 |
+
ncols = self.ncols
|
124 |
+
if ncols > 0: # show all the images in one visdom panel
|
125 |
+
ncols = min(ncols, len(visuals))
|
126 |
+
h, w = next(iter(visuals.values())).shape[:2]
|
127 |
+
table_css = """<style>
|
128 |
+
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
129 |
+
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
130 |
+
</style>""" % (w, h) # create a table css
|
131 |
+
# create a table of images.
|
132 |
+
title = self.name
|
133 |
+
label_html = ''
|
134 |
+
label_html_row = ''
|
135 |
+
images = []
|
136 |
+
idx = 0
|
137 |
+
for label, image in visuals.items():
|
138 |
+
image_numpy = util.tensor2im(image)
|
139 |
+
label_html_row += '<td>%s</td>' % label
|
140 |
+
images.append(image_numpy.transpose([2, 0, 1]))
|
141 |
+
idx += 1
|
142 |
+
if idx % ncols == 0:
|
143 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
144 |
+
label_html_row = ''
|
145 |
+
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
146 |
+
while idx % ncols != 0:
|
147 |
+
images.append(white_image)
|
148 |
+
label_html_row += '<td></td>'
|
149 |
+
idx += 1
|
150 |
+
if label_html_row != '':
|
151 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
152 |
+
try:
|
153 |
+
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
|
154 |
+
padding=2, opts=dict(title=title + ' images'))
|
155 |
+
label_html = '<table>%s</table>' % label_html
|
156 |
+
self.vis.text(table_css + label_html, win=self.display_id + 2,
|
157 |
+
opts=dict(title=title + ' labels'))
|
158 |
+
except VisdomExceptionBase:
|
159 |
+
self.create_visdom_connections()
|
160 |
+
|
161 |
+
else: # show each image in a separate visdom panel;
|
162 |
+
idx = 1
|
163 |
+
try:
|
164 |
+
for label, image in visuals.items():
|
165 |
+
image_numpy = util.tensor2im(image)
|
166 |
+
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
|
167 |
+
win=self.display_id + idx)
|
168 |
+
idx += 1
|
169 |
+
except VisdomExceptionBase:
|
170 |
+
self.create_visdom_connections()
|
171 |
+
|
172 |
+
if self.use_wandb:
|
173 |
+
columns = [key for key, _ in visuals.items()]
|
174 |
+
columns.insert(0,'epoch')
|
175 |
+
result_table = wandb.Table(columns=columns)
|
176 |
+
table_row = [epoch]
|
177 |
+
ims_dict = {}
|
178 |
+
for label, image in visuals.items():
|
179 |
+
image_numpy = util.tensor2im(image)
|
180 |
+
wandb_image = wandb.Image(image_numpy)
|
181 |
+
table_row.append(wandb_image)
|
182 |
+
ims_dict[label] = wandb_image
|
183 |
+
self.wandb_run.log(ims_dict)
|
184 |
+
if epoch != self.current_epoch:
|
185 |
+
self.current_epoch = epoch
|
186 |
+
result_table.add_data(*table_row)
|
187 |
+
self.wandb_run.log({"Result": result_table})
|
188 |
+
|
189 |
+
|
190 |
+
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
191 |
+
self.saved = True
|
192 |
+
# save images to the disk
|
193 |
+
for label, image in visuals.items():
|
194 |
+
image_numpy = util.tensor2im(image)
|
195 |
+
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
196 |
+
util.save_image(image_numpy, img_path)
|
197 |
+
|
198 |
+
# update website
|
199 |
+
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
|
200 |
+
for n in range(epoch, 0, -1):
|
201 |
+
webpage.add_header('epoch [%d]' % n)
|
202 |
+
ims, txts, links = [], [], []
|
203 |
+
|
204 |
+
for label, image_numpy in visuals.items():
|
205 |
+
image_numpy = util.tensor2im(image)
|
206 |
+
img_path = 'epoch%.3d_%s.png' % (n, label)
|
207 |
+
ims.append(img_path)
|
208 |
+
txts.append(label)
|
209 |
+
links.append(img_path)
|
210 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
211 |
+
webpage.save()
|
212 |
+
|
213 |
+
def plot_current_losses(self, epoch, counter_ratio, losses):
|
214 |
+
"""display the current losses on visdom display: dictionary of error labels and values
|
215 |
+
|
216 |
+
Parameters:
|
217 |
+
epoch (int) -- current epoch
|
218 |
+
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
|
219 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
220 |
+
"""
|
221 |
+
if not hasattr(self, 'plot_data'):
|
222 |
+
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
223 |
+
self.plot_data['X'].append(epoch + counter_ratio)
|
224 |
+
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
|
225 |
+
try:
|
226 |
+
self.vis.line(
|
227 |
+
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
|
228 |
+
Y=np.array(self.plot_data['Y']),
|
229 |
+
opts={
|
230 |
+
'title': self.name + ' loss over time',
|
231 |
+
'legend': self.plot_data['legend'],
|
232 |
+
'xlabel': 'epoch',
|
233 |
+
'ylabel': 'loss'},
|
234 |
+
win=self.display_id)
|
235 |
+
except VisdomExceptionBase:
|
236 |
+
self.create_visdom_connections()
|
237 |
+
if self.use_wandb:
|
238 |
+
self.wandb_run.log(losses)
|
239 |
+
|
240 |
+
# losses: same format as |losses| of plot_current_losses
|
241 |
+
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
242 |
+
"""print current losses on console; also save the losses to the disk
|
243 |
+
|
244 |
+
Parameters:
|
245 |
+
epoch (int) -- current epoch
|
246 |
+
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
247 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
248 |
+
t_comp (float) -- computational time per data point (normalized by batch_size)
|
249 |
+
t_data (float) -- data loading time per data point (normalized by batch_size)
|
250 |
+
"""
|
251 |
+
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
252 |
+
for k, v in losses.items():
|
253 |
+
message += '%s: %.3f ' % (k, v)
|
254 |
+
|
255 |
+
print(message) # print the message
|
256 |
+
with open(self.log_name, "a") as log_file:
|
257 |
+
log_file.write('%s\n' % message) # save the message
|