Spaces:
Running
on
A10G
Running
on
A10G
# -*- coding: utf-8 -*- | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
"""
This file contains components with some default boilerplate logic that users may need
in training / testing. They will not work for everyone, but many users may find them useful.

The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""
import argparse | |
import logging | |
import os | |
import sys | |
import weakref | |
from collections import OrderedDict | |
from typing import Optional | |
import torch | |
from fvcore.nn.precise_bn import get_bn_modules | |
from omegaconf import OmegaConf | |
from torch.nn.parallel import DistributedDataParallel | |
import annotator.oneformer.detectron2.data.transforms as T | |
from annotator.oneformer.detectron2.checkpoint import DetectionCheckpointer | |
from annotator.oneformer.detectron2.config import CfgNode, LazyConfig | |
from annotator.oneformer.detectron2.data import ( | |
MetadataCatalog, | |
build_detection_test_loader, | |
build_detection_train_loader, | |
) | |
from annotator.oneformer.detectron2.evaluation import ( | |
DatasetEvaluator, | |
inference_on_dataset, | |
print_csv_format, | |
verify_results, | |
) | |
from annotator.oneformer.detectron2.modeling import build_model | |
from annotator.oneformer.detectron2.solver import build_lr_scheduler, build_optimizer | |
from annotator.oneformer.detectron2.utils import comm | |
from annotator.oneformer.detectron2.utils.collect_env import collect_env_info | |
from annotator.oneformer.detectron2.utils.env import seed_all_rng | |
from annotator.oneformer.detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter | |
from annotator.oneformer.detectron2.utils.file_io import PathManager | |
from annotator.oneformer.detectron2.utils.logger import setup_logger | |
from . import hooks | |
from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase | |
# Names exported by `from <this module> import *`; helpers prefixed with an
# underscore below are intentionally left out.
__all__ = [
    "create_ddp_model",
    "default_argument_parser",
    "default_setup",
    "default_writers",
    "DefaultPredictor",
    "DefaultTrainer",
]
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Wrap ``model`` in DistributedDataParallel when more than one process is running.

    Args:
        model: a torch.nn.Module
        fp16_compression: when True, register the fp16 compression communication
            hook on the DDP wrapper.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
            If ``device_ids`` is not supplied, it is set to a one-element list
            holding the local rank of the current process.

    Returns:
        ``model`` unchanged when the world size is 1, otherwise the DDP wrapper.
    """  # noqa
    if comm.get_world_size() == 1:
        # Single process: there is nothing to synchronize.
        return model

    ddp_kwargs = dict(kwargs)
    if "device_ids" not in ddp_kwargs:
        # Only query the local rank when the caller did not choose devices.
        ddp_kwargs["device_ids"] = [comm.get_local_rank()]
    wrapped = DistributedDataParallel(model, **ddp_kwargs)

    if not fp16_compression:
        return wrapped

    # Imported lazily so the hooks module is only loaded when requested.
    from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

    wrapped.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return wrapped
def default_argument_parser(epilog=None):
    """
    Build an argument parser with the common command-line options used by
    detectron2 training / evaluation scripts.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.
            When empty or None, a default epilog with usage examples built
            from ``sys.argv[0]`` is used.

    Returns:
        argparse.ArgumentParser:
    """
    if epilog:
        epilog_text = epilog
    else:
        # Built only when needed, so a caller-supplied epilog never touches sys.argv.
        epilog_text = f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""

    parser = argparse.ArgumentParser(
        epilog=epilog_text,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    add = parser.add_argument
    add("--config-file", default="", metavar="FILE", help="path to config file")
    add(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    add("--eval-only", action="store_true", help="perform evaluation only")
    add("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    add("--num-machines", type=int, default=1, help="total number of machines")
    add(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    uid = 1 if sys.platform == "win32" else os.getuid()
    port = 2**15 + 2**14 + hash(uid) % 2**14
    add(
        "--dist-url",
        default=f"tcp://127.0.0.1:{port}",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )

    # Everything left on the command line is collected as config overrides.
    add(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
""".strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
def _try_get_key(cfg, *keys, default=None):
    """
    Look up ``keys`` in ``cfg`` in order and return the value of the first one
    that exists. Return ``default`` when none of them is present.

    A :class:`CfgNode` is first converted to an OmegaConf config (via its YAML
    dump) so that both config flavours can be queried with ``OmegaConf.select``.
    """
    if isinstance(cfg, CfgNode):
        cfg = OmegaConf.create(cfg.dump())

    # Unique sentinel: lets us tell "key missing" apart from a stored None.
    missing = object()
    for key in keys:
        value = OmegaConf.select(cfg, key, default=missing)
        if value is not missing:
            return value
    return default
def _highlight(code, filename): | |
try: | |
import pygments | |
except ImportError: | |
return code | |
from pygments.lexers import Python3Lexer, YamlLexer | |
from pygments.formatters import Terminal256Formatter | |
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer() | |
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai")) | |
return code | |
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory
    4. Seed the random number generators and set the cudnn benchmark flag

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read inside a context manager so the file handle is closed right away
        # instead of being left for the garbage collector.
        with PathManager.open(args.config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                _highlight(config_text, args.config_file),
            )
        )

    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # make sure each worker has a different, yet deterministic seed if specified
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed < 0 else seed + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )
def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Create the default list of :class:`EventWriter` objects: a
    :class:`CommonMetricPrinter`, a :class:`JSONWriter` and a
    :class:`TensorboardXWriter`, in that order.

    Args:
        output_dir: directory to store JSON metrics and tensorboard events.
            It is created if needed.
        max_iter: the total number of iterations

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    PathManager.mkdirs(output_dir)

    # It may not always print what you want to see, since it prints "common" metrics only.
    printer = CommonMetricPrinter(max_iter)
    json_writer = JSONWriter(os.path.join(output_dir, "metrics.json"))
    tensorboard_writer = TensorboardXWriter(output_dir)
    return [printer, json_writer, tensorboard_writer]
class DefaultPredictor:
    """
    A minimal end-to-end predictor built from a config. It runs on a single
    device and handles one input image per call.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST. Only set when cfg.DATASETS.TEST is non-empty.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): config used to build the model, load the weights and
                set up the test-time resizing.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()

        test_datasets = cfg.DATASETS.TEST
        if len(test_datasets):
            # Metadata comes from the first test dataset only.
            self.metadata = MetadataCatalog.get(test_datasets[0])

        DetectionCheckpointer(self.model).load(cfg.MODEL.WEIGHTS)

        min_size = cfg.INPUT.MIN_SIZE_TEST
        self.aug = T.ResizeShortestEdge([min_size, min_size], cfg.INPUT.MAX_SIZE_TEST)

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ("RGB", "BGR"), self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]

            # The pre-resize size is passed to the model with the image.
            height, width = original_image.shape[:2]

            transform = self.aug.get_transform(original_image)
            resized = transform.apply_image(original_image)
            # HWC -> CHW float32 tensor.
            image = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))

            batch = [{"image": image, "height": height, "width": width}]
            predictions = self.model(batch)[0]
            return predictions
class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    See the :doc:`/tutorials/training` tutorials for more details.

    Note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:
    ::
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler:
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            # Weak proxy: the checkpointer must not keep the trainer alive.
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (eg. optimizer and scheduler) and update iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered as an independent training. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.

        Args:
            resume (bool): whether to do resume or not
        """
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            # The checkpoint stores the training iteration that just finished, thus we start
            # at the next iteration
            self.start_iter = self.iter + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]: may contain a ``None`` entry in place of the
            PreciseBN hook when precise BN is disabled or the model has no
            BN modules.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Keep the latest results so that train() can verify/return them.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used using :func:`default_writers()`.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)

    def train(self):
        """
        Run training from ``self.start_iter`` to ``self.max_iter``.

        Returns:
            The last evaluation results when ``cfg.TEST.EXPECTED_RESULTS`` is
            non-empty and this is the main process (after they are verified).
            Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def run_step(self):
        """Run one training step by delegating to the wrapped trainer."""
        # Keep the wrapped trainer's iteration counter in sync with ours.
        self._trainer.iter = self.iter
        self._trainer.run_step()

    def state_dict(self):
        """Return the base state dict plus the wrapped trainer's state under "_trainer"."""
        ret = super().state_dict()
        ret["_trainer"] = self._trainer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        """Restore the base state and the wrapped trainer's state from ``state_dict``."""
        super().load_state_dict(state_dict)
        self._trainer.load_state_dict(state_dict["_trainer"])

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        raise NotImplementedError(
            """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``. A single :class:`DatasetEvaluator` is
                also accepted and treated as a one-element list.

        Returns:
            dict: a dict of result metrics. When exactly one dataset was
            processed, the results for that dataset are returned directly
            instead of a mapping keyed by dataset name.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # `Logger.warn` is deprecated; `warning` is the supported spelling.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:

        * training steps and warmup steps are scaled inverse proportionally.
        * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.

        For example, with the original config like the following:

        .. code-block:: yaml

            IMS_PER_BATCH: 16
            BASE_LR: 0.1
            REFERENCE_WORLD_SIZE: 8
            MAX_ITER: 5000
            STEPS: (4000,)
            CHECKPOINT_PERIOD: 1000

        When this config is used on 16 GPUs instead of the reference number 8,
        calling this method will return a new config with:

        .. code-block:: yaml

            IMS_PER_BATCH: 32
            BASE_LR: 0.2
            REFERENCE_WORLD_SIZE: 16
            MAX_ITER: 2500
            STEPS: (2000,)
            CHECKPOINT_PERIOD: 500

        Note that both the original config and this new config can be trained on 16 GPUs.
        It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).

        Args:
            cfg (CfgNode): the config to scale. It is not modified.
            num_workers (int): the number of workers currently in use.

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``
            or if it already equals ``num_workers``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = cfg.clone()
        # Remember the frozen state so it can be restored on the returned copy.
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )

        if frozen:
            cfg.freeze()
        return cfg
# Expose the wrapped trainer's basic attributes on DefaultTrainer as
# read/write properties, so `trainer.model` reads and writes `trainer._trainer.model`.
for _attr in ("model", "data_loader", "optimizer"):
    setattr(
        DefaultTrainer,
        _attr,
        property(
            # `x=_attr` binds the current name now; a plain closure would see
            # only the last value of the loop variable.
            fget=lambda self, x=_attr: getattr(self._trainer, x),
            fset=lambda self, value, x=_attr: setattr(self._trainer, x, value),
        ),
    )