Add initial slurm support (multiple nodes sharing the same task id)
allegroai committed Mar 12, 2020
1 parent 5b29aa1 commit afad6a4
Showing 4 changed files with 64 additions and 10 deletions.
16 changes: 8 additions & 8 deletions docs/trains.conf
@@ -43,8 +43,8 @@ sdk {
              subsampling: 0
          }

-         # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to True, each series should have its own graph)
-         tensorboard_single_series_per_graph: False
+         # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to true, each series should have its own graph)
+         tensorboard_single_series_per_graph: false
      }

      network {
@@ -125,11 +125,11 @@ sdk {

      log {
          # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
-         null_log_propagate: False
+         null_log_propagate: false
          task_log_buffer_capacity: 66

          # disable urllib info and lower levels
-         disable_urllib3_info: True
+         disable_urllib3_info: true
      }

      development {
@@ -139,14 +139,14 @@ sdk {
          task_reuse_time_window_in_hours: 72.0

          # Run VCS repository detection asynchronously
-         vcs_repo_detect_async: True
+         vcs_repo_detect_async: true

          # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
          # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
-         store_uncommitted_code_diff_on_train: True
+         store_uncommitted_code_diff: true

          # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
-         support_stopping: True
+         support_stopping: true

          # Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
          default_output_uri: ""
@@ -160,7 +160,7 @@ sdk {
              ping_period_sec: 30

              # Log all stdout & stderr
-             log_stdout: True
+             log_stdout: true
          }
      }
  }
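For context, these keys are consumed at runtime through the SDK's configuration object. Below is a minimal sketch of reading back two of the flags touched here, assuming `trains` is installed and that `config.get()` accepts dotted paths relative to the `sdk` section (as other call sites in this repository suggest); the exact paths are an assumption, not part of this commit.

```python
# Sketch only: read back two flags touched in docs/trains.conf (paths are assumed).
from trains.config import config

# key renamed here: store_uncommitted_code_diff_on_train -> store_uncommitted_code_diff
store_diff = config.get('development.store_uncommitted_code_diff', True)

# one of the booleans lowercased here
support_stopping = config.get('development.support_stopping', True)

print(store_diff, support_stopping)
```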
2 changes: 2 additions & 0 deletions trains/backend_interface/task/task.py
@@ -1059,6 +1059,8 @@ def __update_master_pid_task(cls, pid=None, task=None):
      pid = pid or os.getpid()
      if not task:
          PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':')
+     elif isinstance(task, str):
+         PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + task)
      else:
          PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + str(task.id))
      # make sure we refresh the edit lock next time we need it,
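PROC_MASTER_ID_ENV_VAR holds a single `<master pid>:<task id>` string; the new `elif` branch lets callers pass the task id directly as a plain string instead of a `Task` object. A minimal sketch of parsing such a value back, where the environment variable name and the helper are illustrative assumptions rather than the repository's API:

```python
import os

# Illustrative only: the actual variable name is defined by PROC_MASTER_ID_ENV_VAR.
PROC_MASTER_ID_VAR_NAME = 'TRAINS_PROC_MASTER_ID'

def read_master_pid_task():
    # Parse "<pid>:<task_id>"; the task id part may be empty for a master without a task.
    value = os.environ.get(PROC_MASTER_ID_VAR_NAME, '')
    if ':' not in value:
        return None, None
    pid_str, _, task_id = value.partition(':')
    master_pid = int(pid_str) if pid_str.isdigit() else None
    return master_pid, (task_id or None)

print(read_master_pid_task())
```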
Expand Down
47 changes: 46 additions & 1 deletion trains/config/__init__.py
@@ -1,5 +1,6 @@
  """ Configuration module. Uses backend_config to load system configuration. """
  import logging
+ import os
  from os.path import expandvars, expanduser

  from ..backend_api import load_config
@@ -47,7 +48,51 @@ def get_log_to_backend(default=None):


  def get_node_id(default=0):
-     return NODE_ID_ENV_VAR.get(default=default)
+     node_id = NODE_ID_ENV_VAR.get()
+
+     try:
+         mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
+     except:
+         mpi_world_rank = None
+
+     try:
+         mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID')))
+     except:
+         mpi_rank = None
+
+     # if we have no node_id, use the mpi rank
+     if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
+         node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank
+
+     # if node_id is still None, use the default
+     if node_id is None:
+         node_id = default
+
+     # check if we have a PyTorch node/worker ID
+     try:
+         from torch.utils.data.dataloader import get_worker_info
+         worker_info = get_worker_info()
+         if not worker_info:
+             torch_rank = None
+         else:
+             w_id = worker_info.id
+             try:
+                 torch_rank = int(w_id)
+             except Exception:
+                 # guess a number based on the (hopefully unique) worker id value
+                 import hashlib
+                 h = hashlib.md5()
+                 h.update(str(w_id).encode('utf-8'))
+                 torch_rank = int(h.hexdigest(), 16)
+     except Exception:
+         torch_rank = None
+
+     # if we also have a torch rank, add it to the node id
+     if torch_rank is not None:
+         # since we don't know the world size, we assume it is not bigger than 10k
+         node_id = (10000 * node_id) + torch_rank
+
+     return node_id


  def get_log_redirect_level():
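A minimal sketch of what the new resolution order means for a SLURM launch, assuming `trains` is installed, the node-id environment variable and the MPI variables are unset, and the call does not happen inside a PyTorch `DataLoader` worker:

```python
import os

# srun exports SLURM_PROCID per task; simulate the third process of a job step
os.environ['SLURM_PROCID'] = '2'

from trains.config import get_node_id

# With no explicit node id and no OMPI_*/PMI_RANK variables, the SLURM rank wins.
print(get_node_id(default=0))  # expected: 2

# Inside a PyTorch DataLoader worker with id 3 on the same node, the function would
# instead return 10000 * 2 + 3 == 20003, keeping worker ids distinct per node
# (assuming fewer than 10k workers, as the comment in the diff states).
```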
9 changes: 8 additions & 1 deletion trains/task.py
@@ -35,7 +35,7 @@
  from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
  from .binding.joblib_bind import PatchedJoblib
  from .binding.matplotlib_bind import PatchedMatplotlib
- from .config import config, DEV_TASK_NO_REUSE
+ from .config import config, DEV_TASK_NO_REUSE, get_node_id
  from .config import running_remotely, get_remote_task_id
  from .config.cache import SessionCache
  from .debugging.log import LoggerRoot
@@ -240,6 +240,13 @@ def __setattr__(self, attr, val):
          # we could not find a task ID, revert to old stub behaviour
          if not is_sub_process_task_id:
              return _TaskStub()
+         elif running_remotely() and get_node_id(default=0) != 0:
+             print("get_node_id", get_node_id(), get_remote_task_id())
+
+             # make sure we only do it once per process
+             cls.__forked_proc_main_pid = os.getpid()
+             # make sure everyone understands we should act as if we are a subprocess (fake pid 1)
+             cls.__update_master_pid_task(pid=1, task=get_remote_task_id())
          else:
              # set us as master process (without task ID)
              cls.__update_master_pid_task()
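Taken together with `get_node_id()`, the effect is that when the agent runs the same script on several SLURM nodes, only node 0 acts as the master of the task; every other node fakes pid 1 and attaches to the master's remote task id, so all nodes report into one experiment. A hedged sketch of the intended launch pattern follows; the project and task names are placeholders, and the rank-0/attach semantics are inferred from this diff rather than documented behavior:

```python
# Launched once per node by SLURM, e.g. inside an sbatch script: srun python train.py
import os
from trains import Task

# Every rank calls Task.init identically. Under the agent, the node with
# get_node_id() == 0 becomes the master process of the task, while non-zero
# nodes act as subprocesses of the same remote task id (see the diff above).
task = Task.init(project_name='examples', task_name='multi-node slurm job')

logger = task.get_logger()
logger.report_text('hello from SLURM rank %s' % os.environ.get('SLURM_PROCID', '0'))
```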
