Horovod #211

Draft: Linux-cpp-lisp wants to merge 28 commits into develop from horovod
Commits (28)
All commits authored by Linux-cpp-lisp.

af9a7a6  initial (Feb 2, 2022)
a98d1c4  log hacks (Feb 2, 2022)
675986a  LR sched optim state (Feb 3, 2022)
e3286e1  cont (Feb 3, 2022)
8380098  samplers (Feb 3, 2022)
6879406  loss/metrics allgather (Feb 3, 2022)
48cb1c6  fix dataset barriers (Feb 3, 2022)
9de35d6  loss/metrics gather (Feb 3, 2022)
5a785d4  dataset sync (Feb 3, 2022)
a1475d3  device pinning (Feb 3, 2022)
6816ab8  Merge branch 'develop' into horovod (Feb 7, 2022)
08b390b  Merge branch 'develop' into horovod (Feb 8, 2022)
c46a8f7  set_to_none (Feb 8, 2022)
63ddaa9  confirm load cached dataset (Feb 8, 2022)
e88d08c  initial unit test (Feb 8, 2022)
2909e6d  fixes (Feb 8, 2022)
f471505  batch size (Feb 8, 2022)
b2de594  full passing integration test (Feb 8, 2022)
b3acb32  Merge branch 'develop' into horovod (Mar 17, 2022)
3a2c5f5  Parallel `nequip-evaluate` with Horovod (May 12, 2022)
c422e48  Merge branch 'develop' into horovod (May 12, 2022)
35566b4  lint (May 12, 2022)
c577ebf  fix tests and config (May 13, 2022)
59b37a1  bugfix (May 13, 2022)
e75ffb5  Merge branch 'develop' into horovod (May 13, 2022)
8a651f8  Merge branch 'develop' into horovod (May 15, 2022)
63b9f0c  Merge branch 'develop' into horovod (May 16, 2022)
c655667  Merge branch 'develop' into horovod (May 24, 2022)
10 changes: 8 additions & 2 deletions nequip/data/_build.py
@@ -7,7 +7,9 @@
from nequip.utils import instantiate, get_w_prefix


def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
def dataset_from_config(
config, prefix: str = "dataset", force_use_cached: bool = False
) -> AtomicDataset:
"""initialize database based on a config instance

It needs dataset type name (case insensitive),
@@ -20,6 +22,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:

config (dict, nequip.utils.Config): dict/object that store all the parameters
prefix (str): Optional. The prefix of all dataset parameters
force_use_cached (bool): Optional, default False. Whether to error if a cached dataset cannot be found.

Return:

@@ -78,7 +81,10 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
instance, _ = instantiate(
class_name,
prefix=prefix,
positional_args={"type_mapper": type_mapper},
positional_args={
"type_mapper": type_mapper,
"force_use_cached": force_use_cached,
},
optional_args=config,
)

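For orientation, a minimal sketch of how the new flag is meant to be used. Only `dataset_from_config` and the `force_use_cached` keyword come from this diff; the config keys shown are illustrative, not a complete dataset configuration.

```python
# Minimal sketch, not a working config: only dataset_from_config() and
# force_use_cached come from this PR; the keys below are illustrative.
from nequip.utils import Config
from nequip.data import dataset_from_config

config = Config({"dataset": "npz", "dataset_file_name": "./data.npz", "root": "./results"})

# On a non-root rank, rank 0 is expected to have already processed and cached
# the dataset; with force_use_cached=True a missing cache raises an error
# instead of triggering a second, redundant preprocessing pass.
dataset = dataset_from_config(config, prefix="dataset", force_use_cached=True)
```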
13 changes: 11 additions & 2 deletions nequip/data/dataset.py
@@ -94,6 +94,7 @@ def __init__(
extra_fixed_fields: Dict[str, Any] = {},
include_frames: Optional[List[int]] = None,
type_mapper: Optional[TypeMapper] = None,
force_use_cached: bool = False,
):
# TO DO, this may be simplified
# See if a subclass defines some inputs
@@ -125,7 +126,9 @@
# Initialize the InMemoryDataset, which runs download and process
# See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
# Then pre-process the data if disk files are not found
super().__init__(root=root, transform=type_mapper)
super().__init__(
root=root, transform=type_mapper, force_use_cached=force_use_cached
)
if self.data is None:
self.data, self.fixed_fields, include_frames = torch.load(
self.processed_paths[0]
@@ -150,7 +153,9 @@ def _get_parameters(self) -> Dict[str, Any]:
pnames = list(inspect.signature(self.__init__).parameters)
IGNORE_KEYS = {
# the type mapper is applied after saving, not before, so doesn't matter to cache validity
"type_mapper"
"type_mapper",
# this parameter controls loading, doesn't affect dataset
"force_use_cached",
}
params = {
k: getattr(self, k)
@@ -686,6 +691,7 @@ def __init__(
extra_fixed_fields: Dict[str, Any] = {},
include_frames: Optional[List[int]] = None,
type_mapper: TypeMapper = None,
force_use_cached: bool = False,
):
self.key_mapping = key_mapping
self.npz_fixed_field_keys = npz_fixed_field_keys
@@ -699,6 +705,7 @@
extra_fixed_fields=extra_fixed_fields,
include_frames=include_frames,
type_mapper=type_mapper,
force_use_cached=force_use_cached,
)

@property
@@ -854,6 +861,7 @@ def __init__(
type_mapper: TypeMapper = None,
key_mapping: Optional[dict] = None,
include_keys: Optional[List[str]] = None,
force_use_cached: bool = False,
):
self.ase_args = {}
self.ase_args.update(getattr(type(self), "ASE_ARGS", dict()))
@@ -872,6 +880,7 @@
extra_fixed_fields=extra_fixed_fields,
include_frames=include_frames,
type_mapper=type_mapper,
force_use_cached=force_use_cached,
)

@classmethod
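The `IGNORE_KEYS` addition above is what keeps the cache valid: `force_use_cached` only controls loading behaviour, so it must not enter the hash of processed-file parameters, otherwise toggling it would look like a different dataset. A self-contained illustration of that filtering idea (the helper below is a stand-in, not the `_get_parameters` method itself):

```python
# Illustrative stand-in for the _get_parameters() filtering shown in the diff.
import inspect

IGNORE_KEYS = {
    # applied after saving, so irrelevant to cache validity
    "type_mapper",
    # only controls how the dataset is loaded, never what gets processed
    "force_use_cached",
}

def cache_parameters(obj) -> dict:
    """Collect only the __init__ arguments that affect the processed data."""
    pnames = inspect.signature(type(obj).__init__).parameters
    return {
        k: getattr(obj, k)
        for k in pnames
        if k not in IGNORE_KEYS and hasattr(obj, k)
    }
```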
157 changes: 116 additions & 41 deletions nequip/scripts/evaluate.py
@@ -1,5 +1,8 @@
from typing import Optional
import sys
import math
import os
import shutil
import argparse
import logging
import textwrap
@@ -11,10 +14,15 @@

import torch

from nequip.data import AtomicData, Collater, dataset_from_config, register_fields
try:
import horovod.torch as hvd
except ImportError:
pass

from nequip.data import AtomicData, Collater, register_fields
from nequip.scripts.deploy import load_deployed_model, R_MAX_KEY
from nequip.scripts._logger import set_up_script_logger
from nequip.scripts.train import default_config, check_code_version
from nequip.scripts.train import default_config, check_code_version, _load_datasets
from nequip.utils._global_options import _set_global_options
from nequip.train import Trainer, Loss, Metrics
from nequip.utils import load_file, instantiate, Config
@@ -100,6 +108,12 @@ def main(args=None, running_as_script: bool = True):
type=str,
default=None,
)
parser.add_argument(
"--horovod",
help="Whether to distribute with horovod.",
type=bool,
default=False,
)
parser.add_argument(
"--output",
help="ExtXYZ (.xyz) file to write out the test set and model predictions to.",
@@ -172,15 +186,26 @@
assert args.output_fields == ""
args.output_fields = []

if running_as_script:
set_up_script_logger(args.log)
logger = logging.getLogger("nequip-evaluate")
logger.setLevel(logging.INFO)

# Handle devices and setup
if args.device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)

if running_as_script:
set_up_script_logger(args.log)
logger = logging.getLogger("nequip-evaluate")
logger.setLevel(logging.INFO)
if args.horovod:
hvd.init()
logger.info(f"Using Horovod; this is rank {hvd.rank()}/{hvd.size()}")
if device.type == "cuda":
assert device == torch.device("cuda") # no specific index
torch.cuda.set_device(hvd.local_rank())
if hvd.rank() != 0:
# disable outputs
logger.handlers = []

logger.info(f"Using device: {device}")
if device.type == "cuda":
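Condensed, the optional-Horovod setup in this hunk follows the usual pattern: import `horovod.torch` if available, call `hvd.init()`, pin each rank to its local GPU, and silence logging on every rank except 0. A standalone sketch of that pattern (the wrapper function is illustrative; the Horovod calls mirror the diff):

```python
# Sketch of the optional-Horovod setup pattern used above; assumes Horovod is
# installed whenever use_horovod=True.
import logging
import torch

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

def setup_distributed(use_horovod: bool, device: torch.device, logger: logging.Logger) -> bool:
    """Initialize Horovod if requested and return True if this process is rank 0."""
    if not use_horovod:
        return True
    hvd.init()
    logger.info(f"Using Horovod; this is rank {hvd.rank()}/{hvd.size()}")
    if device.type == "cuda":
        # Pin each rank to its own local GPU.
        torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        # Only rank 0 keeps its log handlers; the other ranks go quiet.
        logger.handlers = []
    return hvd.rank() == 0
```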
@@ -245,16 +270,23 @@
)

dataset_is_validation: bool = False
try:
# Try to get validation dataset
dataset = dataset_from_config(dataset_config, prefix="validation_dataset")
# look for validation and only fall back to `dataset` prefix
# have to tell the loading function whether to use horovod
dataset_config.horovod = args.horovod
# this function syncs horovod if it is enabled
datasets = _load_datasets(
dataset_config,
prefixes=["validation_dataset", "dataset"],
stop_on_first_found=True,
)
if datasets["validation_dataset"] is not None:
dataset = datasets["validation_dataset"]
dataset_is_validation = True
except KeyError:
pass
if not dataset_is_validation:
# Get shared train + validation dataset
# prefix `dataset`
dataset = dataset_from_config(dataset_config)
else:
dataset = datasets["dataset"]
del datasets
assert dataset is not None

logger.info(
f"Loaded {'validation_' if dataset_is_validation else ''}dataset specified in {args.dataset_config.name}.",
)
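`_load_datasets` (imported from `nequip.scripts.train` above) replaces the old try/except around `dataset_from_config`. Its body is not part of this diff; from the call site, the contract is: try each prefix in order, return a dict keyed by prefix with `None` for anything not loaded, and synchronize ranks when `config.horovod` is set. A rough, hypothetical illustration of that contract, not the actual implementation:

```python
# Hypothetical sketch of the calling contract of _load_datasets(); the real
# helper in nequip.scripts.train also handles Horovod synchronization.
from typing import Dict, List, Optional

from nequip.data import AtomicDataset, dataset_from_config

def load_datasets_sketch(
    config, prefixes: List[str], stop_on_first_found: bool = False
) -> Dict[str, Optional[AtomicDataset]]:
    datasets: Dict[str, Optional[AtomicDataset]] = {p: None for p in prefixes}
    for prefix in prefixes:
        try:
            datasets[prefix] = dataset_from_config(config, prefix=prefix)
        except KeyError:
            # this prefix is not configured; fall through to the next one
            continue
        if stop_on_first_found:
            break
    return datasets
```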
@@ -340,23 +372,42 @@
batch_i: int = 0
batch_size: int = args.batch_size

is_rank_zero: bool = True
if args.horovod:
is_rank_zero = hvd.rank() == 0
# divide the frames between ranks
n_per_rank = int(math.ceil(len(test_idcs) / hvd.size()))
test_idcs = test_idcs[hvd.rank() * n_per_rank : (hvd.rank() + 1) * n_per_rank]

logger.info("Starting...")
context_stack = contextlib.ExitStack()
with contextlib.ExitStack() as context_stack:
# "None" checks if in a TTY and disables if not
prog = context_stack.enter_context(tqdm(total=len(test_idcs), disable=None))
if do_metrics:
display_bar = context_stack.enter_context(
tqdm(
bar_format=""
if prog.disable # prog.ncols doesn't exist if disabled
else ("{desc:." + str(prog.ncols) + "}"),
disable=None,
if is_rank_zero:
# only do output on rank zero
# "None" checks if in a TTY and disables if not
prog = context_stack.enter_context(tqdm(total=len(test_idcs), disable=None))
if do_metrics:
display_bar = context_stack.enter_context(
tqdm(
bar_format=""
if prog.disable # prog.ncols doesn't exist if disabled
else ("{desc:." + str(prog.ncols) + "}"),
disable=None,
)
)
)

if output_type is not None:
output = context_stack.enter_context(open(args.output, "w"))
if args.horovod:
# give each rank its own output and merge later
# we do NOT guerantee that the final XYZ is in any order
# just that we include the indexes into the original dataset
# so this is OK
outfile = args.output.parent / (
args.output.stem + f"-rank{hvd.rank()}.xyz"
)
else:
outfile = args.output
output = context_stack.enter_context(open(outfile, "w"))
else:
output = None

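The split above hands each rank a contiguous block of `ceil(len(test_idcs) / hvd.size())` frame indices; the last rank absorbs the remainder (and would get an empty slice if there were more ranks than frames). A small worked example with made-up numbers:

```python
# Worked example of the contiguous per-rank split used above; numbers are made up.
import math

test_idcs = list(range(10))  # 10 test frames
size = 4                     # 4 Horovod ranks
n_per_rank = int(math.ceil(len(test_idcs) / size))  # ceil(10 / 4) == 3

chunks = [test_idcs[r * n_per_rank : (r + 1) * n_per_rank] for r in range(size)]
# chunks == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
# ranks 0-2 evaluate 3 frames each; rank 3 gets the single leftover frame
```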
@@ -394,24 +445,48 @@
# Accumulate metrics
if do_metrics:
metrics(out, batch)
display_bar.set_description_str(
" | ".join(
f"{k} = {v:4.4f}"
for k, v in metrics.flatten_metrics(
metrics.current_result(),
type_names=dataset.type_mapper.type_names,
)[0].items()
if args.horovod:
# sync metrics across ranks
metrics.gather()
if is_rank_zero:
display_bar.set_description_str(
" | ".join(
f"{k} = {v:4.4f}"
for k, v in metrics.flatten_metrics(
metrics.current_result()
)[0].items()
)
)
)

batch_i += 1
prog.update(batch.num_graphs)

prog.close()
if do_metrics:
display_bar.close()

if do_metrics:
if is_rank_zero:
prog.update(batch.num_graphs)

if is_rank_zero:
prog.close()
if do_metrics:
display_bar.close()

if args.horovod and output_type is not None:
os.sync()

if is_rank_zero:
logger.info("Merging output files...")
output_files = [
args.output.parent / (args.output.stem + f"-rank{rank}.xyz")
for rank in range(hvd.size())
]
with open(args.output, "wb") as wfd:
for f in output_files:
with open(f, "rb") as fd:
shutil.copyfileobj(fd, wfd)
wfd.write(b"\n")
os.sync()
# delete old ones
for f in output_files:
f.unlink()

if is_rank_zero and do_metrics:
logger.info("\n--- Final result: ---")
logger.critical(
"\n".join(
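After the loop, everything user-visible happens on rank 0: the progress bars are closed, the per-rank XYZ files are concatenated into `args.output` (the `os.sync()` calls flush each rank's file to disk before the merge), and the final metrics are printed. The cross-rank metric reduction itself goes through the `Metrics.gather()` call added by this PR; its implementation is not shown in this diff, but a generic Horovod reduction of a scalar looks like this (illustration only, not the PR's code):

```python
# Generic Horovod scalar reduction; NOT the Metrics.gather() from this PR.
import torch
import horovod.torch as hvd

def allreduce_mean(value: float) -> float:
    # hvd.allreduce averages across ranks by default, so every rank ends up
    # with the same mean value.
    return hvd.allreduce(torch.tensor(float(value)), name="metric_mean").item()
```

In practice the distributed evaluation would be launched with Horovod's usual launcher, e.g. something like `horovodrun -np 4 nequip-evaluate --horovod True ...`; only the `--horovod` flag itself comes from this diff, the launcher invocation is an assumption.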