Spaces:
Running
on
L4
Running
on
L4
from typing import List | |
import hydra | |
from lightning import Callback | |
from lightning.pytorch.loggers import Logger | |
from omegaconf import DictConfig | |
from matcha.utils import pylogger | |
log = pylogger.get_pylogger(__name__) | |
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build callback objects from a config mapping.

    Each entry of ``callbacks_cfg`` that is itself a ``DictConfig`` and contains a
    ``_target_`` key is handed to ``hydra.utils.instantiate``; every other entry is
    skipped.

    :param callbacks_cfg: A DictConfig object whose values are callback configurations.
    :return: The instantiated callbacks in config iteration order; an empty list when
        ``callbacks_cfg`` is falsy.
    :raises TypeError: If ``callbacks_cfg`` is truthy but not a DictConfig.
    """
    # The emptiness check comes first, so a falsy non-DictConfig value (e.g. None)
    # produces a warning and an empty result rather than a TypeError.
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    instantiated: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        # Skip anything that does not describe an instantiable target.
        if not isinstance(entry, DictConfig) or "_target_" not in entry:
            continue
        target = entry._target_  # pylint: disable=protected-access
        log.info(f"Instantiating callback <{target}>")
        instantiated.append(hydra.utils.instantiate(entry))
    return instantiated
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build logger objects from a config mapping.

    Each entry of ``logger_cfg`` that is itself a ``DictConfig`` and contains a
    ``_target_`` key is handed to ``hydra.utils.instantiate``; every other entry is
    skipped.

    :param logger_cfg: A DictConfig object whose values are logger configurations.
    :return: The instantiated loggers in config iteration order; an empty list when
        ``logger_cfg`` is falsy.
    :raises TypeError: If ``logger_cfg`` is truthy but not a DictConfig.
    """
    # The emptiness check comes first, so a falsy non-DictConfig value (e.g. None)
    # produces a warning and an empty result rather than a TypeError.
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    instantiated: List[Logger] = []
    for _name, entry in logger_cfg.items():
        # Skip anything that does not describe an instantiable target.
        if not isinstance(entry, DictConfig) or "_target_" not in entry:
            continue
        target = entry._target_  # pylint: disable=protected-access
        log.info(f"Instantiating logger <{target}>")
        instantiated.append(hydra.utils.instantiate(entry))
    return instantiated