|
|
|
from collections import OrderedDict
|
|
|
|
import numpy as np
|
|
|
|
|
|
class LogBuffer:
    """Accumulate named values with counts and compute weighted averages.

    Attributes are plain ordered mappings keyed by variable name:

    * ``val_history`` -- list of every value passed to :meth:`update`.
    * ``n_history`` -- list of the ``count`` recorded alongside each value.
    * ``output`` -- most recent result of :meth:`average` for each key.
    * ``ready`` -- ``True`` once :meth:`average` has filled ``output``;
      reset to ``False`` by :meth:`clear_output` / :meth:`clear`.
    """

    def __init__(self):
        """Create an empty buffer with no history and no output."""
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        self.output = OrderedDict()
        self.ready = False

    def clear(self):
        """Drop all recorded history and any computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self):
        """Drop the computed averages and mark the buffer as not ready."""
        self.output.clear()
        self.ready = False

    def update(self, vars, count=1):
        """Append one value per key to the history.

        Args:
            vars: ``dict`` mapping a variable name to the value to record.
            count: Weight stored alongside every value in this call; it is
                used as the multiplier when averaging.

        Raises:
            AssertionError: If ``vars`` is not a ``dict``.
        """
        # Explicit raise instead of an ``assert`` statement: asserts are
        # removed under ``python -O``. AssertionError is kept so the
        # exception type seen by callers does not change.
        if not isinstance(vars, dict):
            raise AssertionError('vars must be a dict')
        for key, var in vars.items():
            if key not in self.val_history:
                self.val_history[key] = []
                self.n_history[key] = []
            self.val_history[key].append(var)
            self.n_history[key].append(count)

    def average(self, n=0):
        """Average the latest ``n`` values of every key, or all values.

        For each key the result is ``sum(value * count) / sum(count)`` over
        the selected entries and is stored in ``self.output[key]``. After
        the call ``self.ready`` is ``True``.

        Args:
            n: Number of most recent entries to average. ``0`` selects the
                whole history, because ``history[-0:]`` is the full list.

        Raises:
            AssertionError: If ``n`` is negative.
        """
        # Must not be an ``assert`` statement: with the check stripped by
        # ``-O`` a negative ``n`` would turn ``[-n:]`` into ``[k:]`` and
        # silently average the wrong window.
        if not n >= 0:
            raise AssertionError('n must be non-negative')
        for key in self.val_history:
            values = np.array(self.val_history[key][-n:])
            nums = np.array(self.n_history[key][-n:])
            avg = np.sum(values * nums) / np.sum(nums)
            self.output[key] = avg
        self.ready = True
|
|
|