shayekh's picture
Upload 61 files
cc9c7ee
raw
history blame
4.61 kB
"""Miscellaneous utility functions."""
import random
import numpy as np
import torch
import copy
import itertools
def seed(value=42):
    """Seed every random number generator this project relies on.

    Seeds Python's ``random`` module, NumPy's global generator and torch,
    then turns off cuDNN benchmarking and requests deterministic cuDNN
    kernels.

    Args:
        value (int): Seed passed to every generator. Defaults to 42.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def map_dict_to_obj(dic):
    """Recursively resolve the values of a (possibly nested) dict via ``configmapper``.

    Every non-dict value ``v`` is replaced by
    ``configmapper.get_object("params", v)`` when that call succeeds; if it
    raises, ``v`` is kept unchanged. Nested dicts are processed recursively.

    Args:
        dic (dict or None): Mapping to resolve. ``None`` yields an empty dict.

    Returns:
        dict: A new dict with the same keys and resolved values. ``dic``
        itself is not modified.
    """
    result_dic = {}
    if dic is None:
        return result_dic
    for k, v in dic.items():
        if isinstance(v, dict):
            result_dic[k] = map_dict_to_obj(v)
            continue
        try:
            result_dic[k] = configmapper.get_object("params", v)
        except Exception:
            # Best-effort lookup: presumably raised when ``v`` is not a
            # registered name (or is unhashable); keep the raw value.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            # NOTE(review): ``configmapper`` is not imported in the visible
            # part of this file; if it is undefined, the NameError lands here
            # and every value passes through unchanged. Confirm the import.
            result_dic[k] = v
    return result_dic
def get_item_in_config(config, path):
    """Follow ``path`` down from ``config`` and return the value it reaches.

    The access style is chosen once from the type of ``config`` itself: a
    dict is walked with ``node[key]``; anything else is walked by calling
    ``node.__getattr__(key)`` directly, so such objects must define
    ``__getattr__``.

    Args:
        config: Root dict or config object.
        path: Sequence of keys/attribute names to follow.

    Returns:
        The value at the end of ``path`` (``config`` itself for an empty
        path), or ``None`` as soon as an intermediate value is ``None``.
    """
    use_keys = isinstance(config, dict)
    node = config
    for key in path:
        node = node[key] if use_keys else node.__getattr__(key)
        if node is None:
            return None
    return node
# Example usage (kept for reference):
# init = train_config.grid_search
# curr = get_item_in_config(init, ['hyperparams', 'loader_params'])
# curr.set_value('batch_size', 1)
# print(train_config.grid_search)
def _collect_grid_locations(init, root):
    """Depth-first walk of the grid-search dict, collecting the swept locations.

    Args:
        init (dict): Plain-dict form of the grid-search config.
        root (str): Top-level key of ``init`` where the walk starts.

    Returns:
        tuple: ``(locations_values_pair, log_label_path, hparams_path)``.
        ``locations_values_pair`` maps each leaf path (a tuple starting with
        ``root``) to its list of candidate values, in traversal order.
        ``log_label_path`` is the path of the last ``log_label`` key seen, or
        ``None``. ``hparams_path`` is the path of the last list found under an
        ``hparams`` key, or ``None``.
    """
    locations_values_pair = {}
    log_label_path = None
    hparams_path = None
    stack = [root]
    # A set gives O(1) membership tests; a list would make the walk quadratic.
    visited = {root}
    while stack:
        node = get_item_in_config(init, stack)
        if not isinstance(node, dict) and "hparams" not in stack:
            # Leaf: a list holds the candidate values; anything else is a
            # single fixed value.
            locations_values_pair[tuple(stack)] = (
                node if isinstance(node, list) else [node]
            )
            stack.pop()
            continue
        if isinstance(node, list) and "hparams" in stack:
            # The hparams list is copied over wholesale rather than swept.
            hparams_path = list(stack)
            visited.add(".".join(stack))
            stack.pop()
            continue
        if "log_label" in node.keys():
            log_label_path = stack + ["log_label"]
        # Descend into the first child not visited yet; pop once all are done.
        for key in node.keys():
            child = ".".join(stack + [key])
            if child not in visited:
                stack.append(key)
                visited.add(child)
                break
        else:
            stack.pop()
    return locations_values_pair, log_label_path, hparams_path


def generate_grid_search_configs(main_config, grid_config, root="hyperparams"):
    """Expand a grid-search config into one config per value combination.

    The ``root`` section of ``grid_config.as_dict()`` is walked depth-first.
    Every non-dict leaf that is not under an ``hparams`` key is a swept
    location: a list supplies its candidate values, and any other value is
    treated as a single candidate. The Cartesian product of those candidates
    is taken in traversal order. For each combination:

    - each value is written into ``main_config`` with ``set_value``;
    - if the grid has a ``log_label`` key, it is set to the combination's
      1-based index as a string;
    - if the grid has a list under an ``hparams`` key, that list is copied
      into ``main_config`` unchanged;
    - a deep copy of ``main_config`` is appended to the result.

    The ``log_label`` and ``hparams`` steps are skipped when the grid has no
    such key.

    Args:
        main_config: Config corresponding to the ``root`` section. Paths are
            applied without their leading ``root`` component. It must be
            navigable by ``get_item_in_config``, and its nested sections must
            support ``set_value``. Presumably this is the project's Config
            class; confirm.
        grid_config: Config object exposing ``as_dict()`` and attribute
            access to the ``root`` section.
        root (str): Name of the top-level grid section to expand.

    Returns:
        list: Deep copies of ``main_config``, one per combination.

    Note:
        ``main_config`` is mutated in place and is left holding the values of
        the last combination.
    """
    init = grid_config.as_dict()
    locations_values_pair, log_label_path, hparams_path = _collect_grid_locations(
        init, root
    )
    paths = list(locations_values_pair)
    result_configs = []
    for combination in itertools.product(*locations_values_pair.values()):
        for path, item in zip(paths, combination):
            # path[0] is ``root`` (absent from main_config); path[-1] is the
            # key to set inside the section reached by the middle components.
            get_item_in_config(main_config, path[1:-1]).set_value(path[-1], item)
        if log_label_path is not None:
            # Label each run with its 1-based position in the sweep.
            get_item_in_config(main_config, log_label_path[1:-1]).set_value(
                log_label_path[-1], str(len(result_configs) + 1)
            )
        if hparams_path is not None:
            # Read from the config object (not the plain dict ``init``),
            # starting at the section named by ``root``.
            hparams_value = get_item_in_config(
                getattr(grid_config, root), hparams_path[1:]
            )
            get_item_in_config(main_config, hparams_path[1:-1]).set_value(
                hparams_path[-1], hparams_value
            )
        result_configs.append(copy.deepcopy(main_config))
    return result_configs