# AnimeIns_CPU / utils/env_utils.py
# Last commit: 8aa4f1e "add utils" (ljsabc)
import os
import platform
import warnings
import torch.multiprocessing as mp
def set_multi_processing(
    mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
) -> None:
    """Set multi-processing related environment.

    Adapted from mmengine's ``setup_env.py``:
    https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py

    Three things are configured, in order:

    1. On non-Windows platforms, the global multiprocessing start method is
       force-set to ``mp_start_method``. A warning is emitted first if a
       different method had already been chosen.
    2. If OpenCV is importable, its thread count is set to
       ``opencv_num_threads``.
    3. If ``distributed`` is true, ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS``
       are each set to ``"1"`` in ``os.environ``. Variables the user already
       set are left untouched.

    Args:
        mp_start_method (str): Method used to start child processes. It is
            passed to ``torch.multiprocessing.set_start_method`` and ignored
            on Windows. Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for OpenCV, passed to
            ``cv2.setNumThreads``. 0 disables OpenCV multithreading.
            Defaults to 0.
        distributed (bool): True if running in a distributed environment.
            It enables the OMP/MKL thread-count defaults. Defaults to True.

    Raises:
        ValueError: Propagated from ``set_start_method`` when
            ``mp_start_method`` is not a supported start method.
    """  # noqa
    # Windows only supports the "spawn" start method, so it is left alone there.
    if platform.system() != "Windows":
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f"Multi-processing start method `{mp_start_method}` is "
                f"different from the previous setting `{current_method}`. "
                f"It will be force set to `{mp_start_method}`. You can "
                "change this behavior by changing `mp_start_method` in "
                "your config.",
                stacklevel=2,
            )
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2
    except ImportError:
        # OpenCV is an optional dependency; nothing to configure without it.
        pass
    else:
        # Limit/disable OpenCV's internal threading to avoid overloading the
        # system when many worker processes run concurrently.
        cv2.setNumThreads(opencv_num_threads)

    # Set up OMP/MKL threads.
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    # Changes to os.environ are inherited by child processes started after this
    # call. Whether runtimes already loaded in this process re-read them is not
    # guaranteed.
    if distributed:
        for env_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
            if env_var in os.environ:
                # Respect an explicit user setting.
                continue
            num_threads = 1
            warnings.warn(
                f"Setting {env_var} environment variable for each process"
                f" to be {num_threads} in default, to avoid your system "
                "being overloaded, please further tune the variable for "
                "optimal performance in your application as needed.",
                stacklevel=2,
            )
            os.environ[env_var] = str(num_threads)