From afad6a42ea58c07b121402068c313fcfc4875115 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 12 Mar 2020 18:12:16 +0200
Subject: [PATCH] Add initial slurm support (multiple nodes sharing the same task id)

---
 docs/trains.conf                      | 16 ++++-----
 trains/backend_interface/task/task.py |  2 ++
 trains/config/__init__.py             | 47 ++++++++++++++++++++++++++-
 trains/task.py                        |  9 ++++-
 4 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/docs/trains.conf b/docs/trains.conf
index b7caeb3a..2a8f9efa 100644
--- a/docs/trains.conf
+++ b/docs/trains.conf
@@ -43,8 +43,8 @@ sdk {
             subsampling: 0
         }
 
-        # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to True, each series should have its own graph)
-        tensorboard_single_series_per_graph: False
+        # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to true, each series should have its own graph)
+        tensorboard_single_series_per_graph: false
     }
 
     network {
@@ -125,11 +125,11 @@ sdk {
 
     log {
         # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
-        null_log_propagate: False
+        null_log_propagate: false
        task_log_buffer_capacity: 66
 
         # disable urllib info and lower levels
-        disable_urllib3_info: True
+        disable_urllib3_info: true
     }
 
     development {
@@ -139,14 +139,14 @@ sdk {
         task_reuse_time_window_in_hours: 72.0
 
         # Run VCS repository detection asynchronously
-        vcs_repo_detect_async: True
+        vcs_repo_detect_async: true
 
         # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
         # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
-        store_uncommitted_code_diff_on_train: True
+        store_uncommitted_code_diff: true
 
         # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
-        support_stopping: True
+        support_stopping: true
 
         # Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
         default_output_uri: ""
@@ -160,7 +160,7 @@ sdk {
             ping_period_sec: 30
 
             # Log all stdout & stderr
-            log_stdout: True
+            log_stdout: true
         }
     }
 }
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index eebcb366..dded2745 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -1059,6 +1059,8 @@ def __update_master_pid_task(cls, pid=None, task=None):
         pid = pid or os.getpid()
         if not task:
             PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':')
+        elif isinstance(task, str):
+            PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + task)
         else:
             PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + str(task.id))
         # make sure we refresh the edit lock next time we need it,
diff --git a/trains/config/__init__.py b/trains/config/__init__.py
index 0647c550..3d36a8e9 100644
--- a/trains/config/__init__.py
+++ b/trains/config/__init__.py
@@ -1,5 +1,6 @@
 """ Configuration module. Uses backend_config to load system configuration. """
""" import logging +import os from os.path import expandvars, expanduser from ..backend_api import load_config @@ -47,7 +48,51 @@ def get_log_to_backend(default=None): def get_node_id(default=0): - return NODE_ID_ENV_VAR.get(default=default) + node_id = NODE_ID_ENV_VAR.get() + + try: + mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK'))) + except: + mpi_world_rank = None + + try: + mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID'))) + except: + mpi_rank = None + + # if we have no node_id, use the mpi rank + if node_id is None and (mpi_world_rank is not None or mpi_rank is not None): + node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank + + # if node is is till None, use the default + if node_id is None: + node_id = default + + # check if we have pyTorch node/worker ID + try: + from torch.utils.data.dataloader import get_worker_info + worker_info = get_worker_info() + if not worker_info: + torch_rank = None + else: + w_id = worker_info.id + try: + torch_rank = int(w_id) + except Exception: + # guess a number based on wid hopefully unique value + import hashlib + h = hashlib.md5() + h.update(str(w_id).encode('utf-8')) + torch_rank = int(h.hexdigest(), 16) + except Exception: + torch_rank = None + + # if we also have a torch rank add it to the node rank + if torch_rank is not None: + # Since we dont know the world rank, we assume it is not bigger than 10k + node_id = (10000 * node_id) + torch_rank + + return node_id def get_log_redirect_level(): diff --git a/trains/task.py b/trains/task.py index fffead93..79fee37c 100644 --- a/trains/task.py +++ b/trains/task.py @@ -35,7 +35,7 @@ from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO from .binding.joblib_bind import PatchedJoblib from .binding.matplotlib_bind import PatchedMatplotlib -from .config import config, DEV_TASK_NO_REUSE +from .config import config, DEV_TASK_NO_REUSE, get_node_id from .config import running_remotely, get_remote_task_id from .config.cache import SessionCache from .debugging.log import LoggerRoot @@ -240,6 +240,13 @@ def __setattr__(self, attr, val): # we could not find a task ID, revert to old stub behaviour if not is_sub_process_task_id: return _TaskStub() + elif running_remotely() and get_node_id(default=0) != 0: + print("get_node_id", get_node_id(), get_remote_task_id()) + + # make sure we only do it once per process + cls.__forked_proc_main_pid = os.getpid() + # make sure everyone understands we should act as if we are a subprocess (fake pid 1) + cls.__update_master_pid_task(pid=1, task=get_remote_task_id()) else: # set us as master process (without task ID) cls.__update_master_pid_task()