import os
import random
import numpy as np
import torch
import math
import time
import datetime
import json
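
# Settings per output format: display name, tokenizer vocabulary file,
# and maximum output sequence length.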
FORMAT_INFO = {
    "inchi": {
        "name": "InChI_text",
        "tokenizer": "tokenizer_inchi.json",
        "max_len": 300
    },
    "atomtok": {
        "name": "SMILES_atomtok",
        "tokenizer": "tokenizer_smiles_atomtok.json",
        "max_len": 256
    },
    "nodes": {"max_len": 384},
    "atomtok_coords": {"max_len": 480},
    "chartok_coords": {"max_len": 480}
}
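
# Logger that mirrors plain messages to stdout and to `log_file`.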
def init_logger(log_file='train.log'):
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
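
# TensorBoard summary writer (tensorboardX) rooted at `save_path`.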
def init_summary_writer(save_path):
    from tensorboardX import SummaryWriter
    summary = SummaryWriter(save_path)
    return summary
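
# Write all parsed command-line arguments to a timestamped log file under
# `args.save_path`.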
def save_args(args):
    dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M")
    path = os.path.join(args.save_path, f'train_{dt}.log')
    with open(path, 'w') as f:
        for k, v in vars(args).items():
            f.write(f"**** {k} = *{v}*\n")
    return
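
# Seed the Python, NumPy, and PyTorch RNGs and make cuDNN deterministic
# for reproducible runs.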
def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
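
# AverageMeter that additionally accumulates a separate epoch-level average.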
class EpochMeter(AverageMeter):
    def __init__(self):
        super().__init__()
        self.epoch = AverageMeter()

    def update(self, val, n=1):
        super().update(val, n)
        self.epoch.update(val, n)
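
# EpochMeter for the total loss that also tracks one EpochMeter per named
# sub-loss (created lazily in update()).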
class LossMeter(EpochMeter):
    def __init__(self):
        self.subs = {}
        super().__init__()

    def reset(self):
        super().reset()
        for k in self.subs:
            self.subs[k].reset()

    def update(self, loss, losses, n=1):
        loss = loss.item()
        super().update(loss, n)
        losses = {k: v.item() for k, v in losses.items()}
        for k, v in losses.items():
            if k not in self.subs:
                self.subs[k] = EpochMeter()
            self.subs[k].update(v, n)
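
# Format a duration in seconds as "Xm Ys".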
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
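
# Elapsed time since `since`, plus the remaining time estimated from the
# fraction of work completed (`percent`).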
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))
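
# Print only on distributed rank 0 (or always when torch.distributed is not
# initialized) to avoid duplicated log lines across workers.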
def print_rank_0(message):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
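
# Recursively move tensors (including tensors nested in lists and dicts)
# to `device`.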
def to_device(data, device):
    if torch.is_tensor(data):
        return data.to(device)
    if type(data) is list:
        return [to_device(v, device) for v in data]
    if type(data) is dict:
        return {k: to_device(v, device) for k, v in data.items()}
    # Non-tensor leaves (ints, strings, None, ...) are returned unchanged.
    return data
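
# Recursively round floats to 3 decimal places inside nested
# dicts/lists/tuples; other values are returned unchanged.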
def round_floats(o):
    if isinstance(o, float):
        return round(o, 3)
    if isinstance(o, dict):
        return {k: round_floats(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [round_floats(x) for x in o]
    return o
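
# Serialize the prediction columns of a DataFrame to compact JSON strings
# (rounded floats, no whitespace); missing columns are skipped.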
def format_df(df):
    def _dumps(obj):
        if obj is None:
            return obj
        return json.dumps(round_floats(obj)).replace(" ", "")

    for field in ['node_coords', 'node_symbols', 'edges']:
        if field in df.columns:
            df[field] = [_dumps(obj) for obj in df[field]]
    return df
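
# A minimal usage sketch (not part of the original module): exercises the
# seeding, meter, timing, and printing helpers above. The loss names and
# values below are illustrative only.
if __name__ == "__main__":
    seed_torch(42)

    meter = LossMeter()
    total_loss = torch.tensor(1.5)
    sub_losses = {'coords': torch.tensor(1.0), 'edges': torch.tensor(0.5)}
    meter.update(total_loss, sub_losses, n=8)  # batch of 8 samples

    start = time.time()
    print_rank_0(f"loss={meter.avg:.3f} elapsed={timeSince(start, 0.5)}")
    print_rank_0(round_floats({'score': 0.123456}))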