[NeMo-UX] Support load_strictness (NVIDIA#10612)

* add load_strictness to nemo 2.0

Signed-off-by: ashors1 <[email protected]>

* bug fix

Signed-off-by: ashors1 <[email protected]>

* Revert "bug fix"

This reverts commit 54df253.

* Revert "add load_strictness to nemo 2.0"

This reverts commit 4be9bae.

* use 'strict' arg from PTL rather than adding another arg

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* control mcore strict loading via an environment variable as a temporary workaround

Signed-off-by: ashors1 <[email protected]>

* pass ckpt_load_strictness to megatron strategy

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* address comments

Signed-off-by: ashors1 <[email protected]>

* bug fix

Signed-off-by: ashors1 <[email protected]>

* remove unused import

Signed-off-by: ashors1 <[email protected]>

* fix selective restore

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* fix selective restore

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* forward strict arg to peft

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: ashors1 <[email protected]>
ashors1 and ashors1 authored Nov 26, 2024
1 parent 7198fa4 commit 5d97b70
Showing 4 changed files with 79 additions and 6 deletions.
11 changes: 11 additions & 0 deletions nemo/lightning/_strategy_lib.py
@@ -515,6 +515,17 @@ def get_safe(param_id):

def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag

## convert from StrictHandling to bool for PTL
if strict is not None and not isinstance(strict, bool):
strict = parse_strict_flag(strict)
strict_options = [
StrictHandling.ASSUME_OK_UNEXPECTED,
StrictHandling.RAISE_UNEXPECTED,
StrictHandling.RAISE_ALL,
]
strict = strict in strict_options

for index, module in enumerate(megatron_parallel):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
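For reference, the conversion above can be read as the following standalone sketch (the helper name `strictness_to_bool` is hypothetical; it assumes only, as the code does, that `parse_strict_flag` turns a string or `StrictHandling` member into a `StrictHandling` member):

```python
from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag


def strictness_to_bool(strict):
    """Mirror the conversion in load_model_state_dict: collapse an MCore
    strictness flag to the boolean PTL expects; None and bool pass through."""
    if strict is None or isinstance(strict, bool):
        return strict
    handling = parse_strict_flag(strict)  # str or StrictHandling -> StrictHandling
    return handling in (
        StrictHandling.ASSUME_OK_UNEXPECTED,
        StrictHandling.RAISE_UNEXPECTED,
        StrictHandling.RAISE_ALL,
    )
```

Under this mapping, `StrictHandling.RAISE_ALL` behaves as `strict=True` for PTL, while `StrictHandling.LOG_ALL` behaves as `strict=False`.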
50 changes: 48 additions & 2 deletions nemo/lightning/io/pl.py
@@ -170,7 +170,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio

@override
def load_checkpoint(
self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None
self,
path: _PATH,
sharded_state_dict=None,
map_location: Optional[Callable] = None,
strict: Optional['StrictHandling'] | bool = None,
) -> Dict[str, Any]:
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.
@@ -187,6 +191,7 @@
"""
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.validation import StrictHandling

if map_location is not None:
raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.")
@@ -220,8 +225,21 @@
if sharded_strategy is not None:
logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

if isinstance(strict, bool):
# For backward-compatibility reasons and a bug in MCore (strict check not applied to factories)
# we must apply a simple strict check here.
if not strict:
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
if strict is None:
# Default behavior
strict = StrictHandling.ASSUME_OK_UNEXPECTED

checkpoint = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path), sharded_strategy=sharded_strategy
sharded_state_dict=sharded_state_dict,
checkpoint_dir=str(path),
sharded_strategy=sharded_strategy,
strict=strict,
)
checkpoint = _fix_tensors_device(checkpoint)

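A hedged usage sketch of the new `strict` argument (the no-argument `MegatronCheckpointIO()` construction and the helper name are assumptions; only the `strict` keyword itself comes from this change):

```python
from typing import Any, Dict

from megatron.core.dist_checkpointing.validation import StrictHandling
from nemo.lightning.io.pl import MegatronCheckpointIO


def load_dist_ckpt(ckpt_dir: str, sharded_state_dict: Dict[str, Any], lenient: bool) -> Dict[str, Any]:
    """Load a Megatron distributed checkpoint with or without strict key checking."""
    ckpt_io = MegatronCheckpointIO()
    if lenient:
        # strict=False drops requested keys that the checkpoint lacks
        # (via adjust_non_strict_load) and logs the remaining mismatches.
        return ckpt_io.load_checkpoint(ckpt_dir, sharded_state_dict=sharded_state_dict, strict=False)
    # Any StrictHandling value is forwarded unchanged to dist_checkpointing.load;
    # strict=None keeps the previous default (ASSUME_OK_UNEXPECTED).
    return ckpt_io.load_checkpoint(
        ckpt_dir, sharded_state_dict=sharded_state_dict, strict=StrictHandling.RAISE_ALL
    )
```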
@@ -284,6 +302,34 @@ def save_sharded_strategy(self) -> 'SaveShardedStrategy':
self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
return self._save_sharded_strategy

def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase

ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
loaded_keys = []
missing_keys = []
unexpected_keys = []

def should_remove_missing_sharded_base(x: Any):
if isinstance(x, ShardedBase):
if x.key in ckpt_sharded_metadata:
loaded_keys.append(x.key)
return False
else:
unexpected_keys.append(x.key)
return True
return False

_, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

# TODO: compute missing_keys by:
# 1. all_gather_object of loaded_keys
# 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
return sharded_state_dict


def _fix_tensors_device(ckpt: Dict) -> Dict:
"""Ensure checkpoint tensors are on the correct device."""
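To make the behaviour of `adjust_non_strict_load` concrete, here is a toy, MCore-free stand-in for the filtering it performs (every name below is illustrative and not part of the PR):

```python
from typing import Iterable, List, Set, Tuple


def split_requested_keys(requested_keys: Iterable[str], ckpt_keys: Set[str]) -> Tuple[List[str], List[str]]:
    """Toy version of adjust_non_strict_load's filtering: keep only the requested
    (sharded) keys that the checkpoint metadata actually contains; report the rest."""
    kept = [k for k in requested_keys if k in ckpt_keys]
    dropped = [k for k in requested_keys if k not in ckpt_keys]
    return kept, dropped


kept, dropped = split_requested_keys(
    requested_keys=["decoder.weight", "adapter.lora_a.weight"],
    ckpt_keys={"decoder.weight", "decoder.bias"},
)
print(kept)     # ['decoder.weight']        -> stays in the sharded_state_dict and is loaded
print(dropped)  # ['adapter.lora_a.weight'] -> filtered out and only logged
# 'decoder.bias' (present in the checkpoint but never requested) is the kind of key
# the TODO about missing_keys refers to.
```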
8 changes: 6 additions & 2 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -405,7 +405,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio

@override
def load_checkpoint(
self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None
self,
path: _PATH,
sharded_state_dict=None,
map_location: Optional[Callable] = None,
strict: Optional['StrictHandling'] | bool = None,
) -> Dict[str, Any]:
"""
=====================
@@ -452,7 +456,7 @@ def load_checkpoint(
self.model_ckpt_path = path

# Note: this will include the Trainer-state of the model-checkpoint
model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location)
model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location, strict)
if adapter_ckpt is not None:
## PEFT Resume, FIRST TIME
adapter_ckpt['state_dict'].update(model_ckpt['state_dict'])
16 changes: 14 additions & 2 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -156,6 +156,9 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
ckpt_load_directly_on_device (bool): if True, loads the weights directly on GPU.
Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device).
Defaults to True.
ckpt_load_strictness (StrictHandling, optional): defines loading strictness.
If not None, overwrites the `strict` flag passed to `load_checkpoint`.
Defaults to None.
setup_optimizers (bool): Whether to call the trainer's setup_optimizers function to perform any
necessary conversions of optimizer parameters and move optimizer parameters to the correct device.
Defaults to True.
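The `ckpt_load_strictness` option documented in this hunk can be set when the strategy is built, for example (a hedged sketch; the trainer wiring is an assumption, only `ckpt_load_strictness` itself is introduced by this commit):

```python
from megatron.core.dist_checkpointing.validation import StrictHandling
from nemo import lightning as nl

# Log every state-dict mismatch instead of raising, e.g. when resuming a model
# whose structure has drifted slightly from the checkpoint being loaded.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    ckpt_load_strictness=StrictHandling.LOG_ALL,
)
trainer = nl.Trainer(devices=1, accelerator="gpu", strategy=strategy)
```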
@@ -204,6 +207,7 @@ def __init__(
ckpt_parallel_load: bool = True,
ckpt_parallel_save_optim: bool = True,
ckpt_load_directly_on_device: bool = True,
ckpt_load_strictness: Optional['StrictHandling'] = None,
setup_optimizers: bool = True,
init_model_parallel: bool = True,
replace_progress_bar: bool = True,
@@ -238,6 +242,7 @@
self.lazy_init = lazy_init
self.ckpt_load_optimizer = ckpt_load_optimizer
self.ckpt_save_optimizer = ckpt_save_optimizer
self.ckpt_load_strictness = ckpt_load_strictness
self.pipeline_dtype = pipeline_dtype
self._setup_optimizers = setup_optimizers
self._init_model_parallel = init_model_parallel
@@ -733,7 +738,12 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
if self.lightning_module.optimizers(use_pl_optimizer=False):
sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)]

checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)
strict = (
self.lightning_module.strict_loading if self.ckpt_load_strictness is None else self.ckpt_load_strictness
)
checkpoint = self.checkpoint_io.load_checkpoint(
checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict
)

if selective_restore:
final_checkpoint = {}
@@ -755,7 +765,8 @@ def selective_restore(self) -> None:

if self.restore_config.load_model_state:
logging.info(f"Restoring model weights from {self.restore_config}")
self.load_model_state_dict(checkpoint=checkpoint)
strict = True if self.ckpt_load_strictness is None else self.ckpt_load_strictness
self.load_model_state_dict(checkpoint=checkpoint, strict=strict)

if self.restore_config.load_optim_state:
logging.info(f"Restoring optimizer states from {self.restore_config}")
@@ -790,6 +801,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
"""loads model state dict"""
assert self.megatron_parallel is not None

strict = strict if self.ckpt_load_strictness is None else self.ckpt_load_strictness
_strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict)

if not 'optimizer' in checkpoint:
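Taken together, the last three hunks implement one precedence rule: a non-None `ckpt_load_strictness` on the strategy overrides whatever `strict` value PTL or the restore path would otherwise use. A minimal, dependency-free sketch of that resolution (hypothetical helper; a string stands in for `StrictHandling`):

```python
from typing import Optional, Union


def resolve_strictness(ckpt_load_strictness: Optional[str], strict: Union[bool, str]) -> Union[bool, str]:
    """Mirror the resolution in load_checkpoint / selective_restore / load_model_state_dict:
    the strategy-level setting, when provided, wins over the per-call flag."""
    return strict if ckpt_load_strictness is None else ckpt_load_strictness


print(resolve_strictness(None, True))       # True    -> PTL's strict_loading is respected
print(resolve_strictness("log_all", True))  # log_all -> strategy-level setting wins
```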
