[NeMo-UX] Turn on mcore performance optimizations (NVIDIA#10209)
* expose TP overlap

Signed-off-by: Jieming Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* add tp overlap recipes

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* turn on pipeline parallel overlap

Signed-off-by: Jimmy Zhang <[email protected]>

* refactor

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* Update base.py

Signed-off-by: JimmyZhang12 <[email protected]>

* Update megatron_parallel.py

Signed-off-by: JimmyZhang12 <[email protected]>

* remove env var

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* add optimization config

Signed-off-by: Jimmy Zhang <[email protected]>

* fix typo

Signed-off-by: Jimmy Zhang <[email protected]>

* refactor into megatron parallel setup

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* refactor

Signed-off-by: Jimmy Zhang <[email protected]>

* fix config ordering, add wgrad deferral

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* cleanup

Signed-off-by: Jimmy Zhang <[email protected]>

* use config

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* clean

Signed-off-by: Jimmy Zhang <[email protected]>

* enable wgrad deferral

Signed-off-by: Jimmy Zhang <[email protected]>

* add grad bucket size

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* move everything into a callback

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* cleanup

Signed-off-by: Jimmy Zhang <[email protected]>

* fix imports

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* move userbuffer init

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* cleanup

Signed-off-by: Jimmy Zhang <[email protected]>

* fix VP

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* address comments

Signed-off-by: Jimmy Zhang <[email protected]>

* add gradient accum guard

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* Update base.py

Signed-off-by: JimmyZhang12 <[email protected]>

* address comments

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* address comments

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

---------

Signed-off-by: Jieming Zhang <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Signed-off-by: Jimmy Zhang <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Co-authored-by: Jieming Zhang <[email protected]>
Co-authored-by: JimmyZhang12 <[email protected]>
3 people authored Sep 6, 2024
1 parent fdf1979 commit 1d5de59
Showing 6 changed files with 397 additions and 0 deletions.
11 changes: 11 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -32,6 +32,15 @@

_, HAVE_TE = safe_import("transformer_engine")

# Gradient accumulation fusion may be enabled if available; for more information, see:
# https://github.com/NVIDIA/Megatron-LM/blob/01945b98d1ea3a2acb5e8301e181a328104f4856/megatron/core/tensor_parallel/layers.py#L575
# TODO: Clean this up with a getter and install instructions
_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False

if TYPE_CHECKING:
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel

@@ -124,6 +133,8 @@ class GPTConfig(TransformerConfig, io.IOMixin):
    seq_length: int = 1024
    attention_softmax_in_fp32: bool = False
    masked_softmax_fusion: bool = True
    cross_entropy_loss_fusion: bool = True
    gradient_accumulation_fusion: bool = _grad_accum_fusion_available
    deallocate_pipeline_outputs = True

    transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = default_layer_spec
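For reference, a minimal sketch of overriding the two fusion flags added above when the fused kernels are unavailable. Only cross_entropy_loss_fusion and gradient_accumulation_fusion come from this diff; the model shape values and the direct GPTConfig construction are illustrative assumptions.

from nemo.collections.llm.gpt.model.base import GPTConfig

# Illustrative override: turn both fusions off, e.g. on a build where
# fused_weight_gradient_mlp_cuda or the fused cross-entropy kernel is missing.
config = GPTConfig(
    num_layers=32,            # assumed model shape, not from this diff
    hidden_size=4096,
    num_attention_heads=32,
    seq_length=8192,
    cross_entropy_loss_fusion=False,
    gradient_accumulation_fusion=False,
)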
25 changes: 25 additions & 0 deletions nemo/collections/llm/recipes/llama3_70b.py
@@ -14,7 +14,9 @@
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192
from nemo.collections.llm.utils import Config, Partial
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "llama3_70b"
@@ -93,6 +95,29 @@ def pretrain_recipe(
)


def pretrain_recipe_performance(
    name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain
) -> Partial:
    """'pretrain_recipe_performance' turns on performance optimizations that cannot be enabled by default
    because they are model-specific or lack sufficient support. For better compatibility, please use
    the default 'pretrain_recipe()' above."""
    recipe = pretrain_recipe(
        name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn
    )

    recipe.trainer.callbacks.append(
        Config(
            MegatronCommOverlapCallback,
            tp_comm_overlap=True,
            tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
            defer_embedding_wgrad_compute=True,
            wgrad_deferral_limit=22,
        )
    )

    return recipe


def hf_resume() -> Config[nl.AutoResume]:
    return Config(
        nl.AutoResume,
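A hedged usage sketch for the 'pretrain_recipe_performance' recipe added above. The run name, checkpoint directory, and node counts are placeholders, and executing the returned Partial is assumed to follow the same flow as the existing NeMo 2.0 recipes; only the function signature and the attached callback come from this diff.

from nemo.collections.llm.recipes import llama3_70b

recipe = llama3_70b.pretrain_recipe_performance(
    name="llama3_70b_perf",          # placeholder run name
    ckpt_dir="/checkpoints/llama3",  # placeholder checkpoint directory
    num_nodes=8,
    num_gpus_per_node=8,
)
# The returned recipe now carries a MegatronCommOverlapCallback configured with
# the bf16 H100 TP4 userbuffers layout and deferred embedding wgrad compute.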
20 changes: 20 additions & 0 deletions nemo/collections/llm/recipes/llama3_8b.py
@@ -14,6 +14,7 @@
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin
from nemo.collections.llm.utils import Config, Partial
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "llama3_8b"
@@ -92,6 +93,25 @@ def pretrain_recipe(
)


def pretrain_recipe_performance(
    name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain
) -> Partial:
    """'pretrain_recipe_performance' turns on performance optimizations that cannot be enabled by default
    because they are model-specific or lack sufficient support. For better compatibility, please use
    the default 'pretrain_recipe()' above."""
    recipe = pretrain_recipe(
        name=name, ckpt_dir=ckpt_dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn
    )

    recipe.trainer.callbacks.append(
        Config(
            MegatronCommOverlapCallback,
            tp_comm_overlap=False,
        )
    )
    return recipe


def hf_resume() -> Config[nl.AutoResume]:
    return Config(
        nl.AutoResume,
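The 8B recipe attaches the same callback but leaves tp_comm_overlap off, presumably because the default llama3_8b parallelism is too small for userbuffer-based TP overlap to pay off (an assumption; the recipe's parallelism settings are not shown in this diff). A usage sketch with placeholder arguments:

from nemo.collections.llm.recipes import llama3_8b

recipe = llama3_8b.pretrain_recipe_performance(
    name="llama3_8b_perf",     # placeholder run name
    ckpt_dir="/checkpoints",   # placeholder checkpoint directory
    num_nodes=1,
    num_gpus_per_node=8,
)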
Empty file.
73 changes: 73 additions & 0 deletions nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py
@@ -0,0 +1,73 @@
from dataclasses import dataclass


@dataclass
class TPOverlapCfg:
    pass


@dataclass
class PipelineOverlapCfg(TPOverlapCfg):
    num_sm: int
    cga_size: int
    num_splits: int
    set_sm_margin: bool
    fp8_buf: bool = False
    method: str = 'pipeline'


@dataclass
class RingExchangeOverlapCfg(TPOverlapCfg):
    aggregate: bool = False
    method: str = 'ring_exchange'


@dataclass
class BulkOverlapCfg(TPOverlapCfg):
    num_sm: int
    cga_size: int
    set_sm_margin: bool
    method: str = 'bulk'


@dataclass
class TransformerLayerTPOverlapCfg:
    qkv_dgrad: TPOverlapCfg
    qkv_wgrad: TPOverlapCfg
    fc1_dgrad: TPOverlapCfg
    fc1_wgrad: TPOverlapCfg
    qkv_fprop: TPOverlapCfg
    proj_dgrad: TPOverlapCfg
    fc1_fprop: TPOverlapCfg
    fc2_dgrad: TPOverlapCfg
    proj_fprop: TPOverlapCfg
    fc2_fprop: TPOverlapCfg


# TODO: Add more configs and create a getter function to expose a single API
# Model configs: H100/70B/TP4/MBS1/SeqLen8K
userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 = TransformerLayerTPOverlapCfg(
    qkv_dgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_wgrad=BulkOverlapCfg(num_sm=24, cga_size=2, set_sm_margin=False),
    fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False),
    fc1_wgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_fprop=RingExchangeOverlapCfg(aggregate=False),
    proj_dgrad=RingExchangeOverlapCfg(aggregate=False),
    fc1_fprop=RingExchangeOverlapCfg(aggregate=False),
    fc2_dgrad=RingExchangeOverlapCfg(aggregate=False),
    proj_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True),
    fc2_fprop=PipelineOverlapCfg(num_sm=16, cga_size=2, num_splits=4, set_sm_margin=True),
)

userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192 = TransformerLayerTPOverlapCfg(
    qkv_dgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_wgrad=BulkOverlapCfg(num_sm=24, cga_size=2, set_sm_margin=False),
    fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False),
    fc1_wgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_fprop=RingExchangeOverlapCfg(aggregate=False),
    proj_dgrad=RingExchangeOverlapCfg(aggregate=False),
    fc1_fprop=RingExchangeOverlapCfg(aggregate=False),
    fc2_dgrad=RingExchangeOverlapCfg(aggregate=False),
    proj_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True),
    fc2_fprop=PipelineOverlapCfg(num_sm=16, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True),
)
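The dataclasses in this file can also be combined into a custom per-layer layout. A sketch under the assumption that a hand-rolled config is acceptable wherever tp_comm_overlap_cfg is consumed; the SM counts and split sizes below are illustrative placeholders, not tuned values.

from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
    BulkOverlapCfg,
    PipelineOverlapCfg,
    RingExchangeOverlapCfg,
    TransformerLayerTPOverlapCfg,
)

# Hypothetical layout: bulk overlap for the dgrad/wgrad GEMMs, ring-exchange for
# most forward GEMMs, and pipelined overlap for the two largest forward GEMMs.
my_tp_overlap_cfg = TransformerLayerTPOverlapCfg(
    qkv_dgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_wgrad=BulkOverlapCfg(num_sm=16, cga_size=2, set_sm_margin=False),
    fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False),
    fc1_wgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_fprop=RingExchangeOverlapCfg(aggregate=False),
    proj_dgrad=RingExchangeOverlapCfg(aggregate=False),
    fc1_fprop=RingExchangeOverlapCfg(aggregate=False),
    fc2_dgrad=RingExchangeOverlapCfg(aggregate=False),
    proj_fprop=PipelineOverlapCfg(num_sm=16, cga_size=2, num_splits=4, set_sm_margin=True),
    fc2_fprop=PipelineOverlapCfg(num_sm=8, cga_size=2, num_splits=4, set_sm_margin=True),
)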
