re-merge from NVIDIA main #68

Open · wants to merge 28 commits into base: multi-query-attention
Changes from all commits (28 commits)
8c86034
Add option to specify a data cache path separate from data directory.
jaredcasper May 19, 2023
ae37924
Check for write failure of index cache and print error message.
jaredcasper May 23, 2023
41ba8a2
Merge branch 'main' into data-cache
jaredcasper May 30, 2023
f9283c5
Add option to overlap p2p communication.
jaredcasper May 31, 2023
621c9de
typo
jaredcasper May 31, 2023
2c13d1f
Consistent arg names.
jaredcasper May 31, 2023
e6d7e09
Merge branch 'p2p_overlap' into 'main'
jaredcasper Jun 1, 2023
e4adfbc
Merge branch 'data-cache' into 'main'
jaredcasper Jun 1, 2023
4d564cf
Supporting loading checkpoints without add_position_embedding arg.
jaredcasper Jun 2, 2023
1997e94
Fix GPTDataset assert.
jaredcasper Jun 2, 2023
f965380
Merge branch 'gptdataset-assert' into 'main'
jaredcasper Jun 5, 2023
a6c574d
Fixed rotary_pos_emb's position in layer's forward args.
lmcafee-nvidia Jun 5, 2023
382fd9d
Merge branch 'lmcafee/rotary-kwarg-dev' into 'main'
jaredcasper Jun 5, 2023
41221b8
fix indexation for output tensor after gradscaler call
aklife97 Jun 5, 2023
d2891b4
Merge branch 'outputtensor_index' into 'main'
jaredcasper Jun 5, 2023
ea76ecd
Perform grad sync at correct place in interleaved pipeline parallelism
timmoon10 Jun 6, 2023
992da75
Merge branch 'interleaved-pipeline-bugfix' into 'main'
jaredcasper Jun 6, 2023
f6c6d86
Merge branch 'ckpt-load-fix' into 'main'
jaredcasper Jun 6, 2023
2880267
Add workarounds for non-determinism in Megatron training
jon-barker Jun 8, 2023
db71a33
Merge branch 'jbarker/non_determinism_fix' into 'main'
jon-barker Jun 8, 2023
1af380d
Update gitlab to catch pytest errors
shanmugamr1992 Jun 10, 2023
c7a0145
Merge branch 'pytestError' into 'main'
shanmugamr1992 Jun 10, 2023
bf5206e
Remove use of deprecated np.float in indexed_dataset.py
jon-barker Jun 12, 2023
000590e
Merge branch 'jbarker/np_float64_deprecation' into 'main'
jaredcasper Jun 12, 2023
f479999
Retro fix for tensor parallelism.
lmcafee-nvidia Jun 13, 2023
0604155
Merge branch 'lmcafee/retro-dataloader-fix' into 'main'
jaredcasper Jun 13, 2023
e5a9d7b
Merge branch 'main' of github.com:NVIDIA/Megatron-LM into NVIDIA-main2
RaymondLi0 Jun 23, 2023
8b6ceeb
pass on data_cache_pass in build_dataset_group
Jun 27, 2023
6 changes: 4 additions & 2 deletions .gitlab-ci.yml
@@ -69,7 +69,8 @@ unit_tests:
- echo "Slurm job state $SLURM_STATE"
- if [[ "$SLURM_STATE" != "COMPLETED" ]]; then echo "Slurm job did not complete. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs. Skipping pytest."; exit 1; fi
- source $PYTHON_VIRTUAL_ENV
- pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py || echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs."
- cmd='pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py'
- if $cmd; then echo "Pytest succeeded"; else echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs"; fi
- echo "Completed the job"
rules:
- if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT
@@ -134,7 +135,8 @@ unit_tests:
if [[ $USE_TE -ne 1 ]]; then
echo "Checking against ground truth file"
export EXPECTED_METRICS_FILE=$BUILD_DIR/tests/functional_tests/test_results/$RUN_MODEL/$RUN_NAME.json
pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_ci_pipeline.py || echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs."
cmd='pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_ci_pipeline.py'
if $cmd; then echo "Pytest succeeded"; else echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs"; fi
fi
- echo "Completed the job"
rules:
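The CI change above stores the pytest invocation in a variable and branches explicitly on its exit status, printing a success or failure message either way instead of chaining `|| echo`. A rough equivalent of the pattern in Python, with a placeholder test path (illustrative only, not part of this diff):

```python
# Rough Python equivalent of the CI pattern above: run pytest, report the
# outcome explicitly, and let the surrounding job continue either way.
# The test path is a placeholder.
import subprocess

result = subprocess.run(
    ["pytest", "tests/functional_tests/python_test_utils/test_ci_pipeline.py"])
if result.returncode == 0:
    print("Pytest succeeded")
else:
    print("Pytest failed. See the results directory for result logs.")
```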
13 changes: 12 additions & 1 deletion README.md
@@ -56,6 +56,7 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization
* [Datasets](#datasets)
* [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data)
* [Collecting GPT Webtext Data](#collecting-gpt-webtext-data)
* [Reproducibility](#reproducibility)

# Setup
We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) with DGX nodes. If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks.
@@ -365,7 +366,7 @@ See [megatron/text_generation_server.py](megatron/text_generation_server.py) for
### Detoxify GPT via Self-generation
We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models.

See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus.
See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus.


## GPT Evaluation
@@ -513,3 +514,13 @@ We recommend using the `--json` argument when using WikiExtractor, which will du

## Collecting GPT Webtext Data
We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content.

# Reproducibility
Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary).

There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. They are only applicable when using NGC containers >=22.05. The following workarounds should be applied in cases where reproducibility is required:
1. When training using the `--bf16` option the backward pass of `torch.nn.functional.embedding` is non-deterministic. If reproducibility is required you should also use the option `--embedding-weights-in-fp32`. The speed and memory impact of this change is negligible.
2. Also when training using `--bf16`, reproducibility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used.
3. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`.

These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue.
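The `--embedding-weights-in-fp32` workaround in item 1 casts the word-embedding weights to fp32 before the lookup, so the non-deterministic bf16 backward of `torch.nn.functional.embedding` is avoided. A minimal, self-contained sketch of that idea (a simplified toy module, not the Megatron-LM implementation):

```python
# Toy illustration of casting bf16 embedding weights to fp32 before the
# lookup; this is only a sketch, the real Megatron-LM code differs.
import torch
import torch.nn.functional as F

class Fp32CastEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        # Weights stored in bf16, as they would be under --bf16 training.
        self.weight = torch.nn.Parameter(
            torch.randn(vocab_size, hidden_size).to(torch.bfloat16))

    def forward(self, token_ids, weights_in_fp32=True):
        # Cast to fp32 so the embedding forward/backward run deterministically.
        weight = self.weight.float() if weights_in_fp32 else self.weight
        out = F.embedding(token_ids, weight)
        # Cast back so downstream layers keep running in bf16.
        return out.to(self.weight.dtype)

emb = Fp32CastEmbedding(vocab_size=50304, hidden_size=1024)
tokens = torch.randint(0, 50304, (2, 8))
hidden = emb(tokens)  # shape (2, 8, 1024), dtype bf16
```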
14 changes: 11 additions & 3 deletions megatron/arguments.py
@@ -55,7 +55,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
# Args from environment
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))

return args

def validate_args(args, defaults={}):
@@ -626,6 +626,8 @@ def _add_network_size_args(parser):
help='Number of Experts in Switch Transformer (None means no Switch)')
group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
help='Untie embeddings and output weights.'),
group.add_argument('--embedding-weights-in-fp32', action='store_true',
help='Cast word embedding weights to fp32 before embedding fwd.'),
return parser


@@ -1020,6 +1022,10 @@ def _add_distributed_args(parser):
'--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--overlap-p2p-communication',
action='store_true',
help='overlap pipeline parallel communication with forward and backward chunks',
dest='overlap_p2p_comm')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
@@ -1212,6 +1218,8 @@ def __call__(self, parser, args, values, option_string=None):
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--data-cache-path', default=None,
help='Path to a directory to hold cached index files.')

group.add_argument('--vocab-size', type=int, default=None,
help='Size of vocab before EOD or padding.')
@@ -1385,14 +1393,14 @@ def _add_vision_args(parser):
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')

# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')

# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
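The new `--data-cache-path` argument pairs with the earlier commits that write dataset index caches to a directory separate from the data and report, rather than crash on, cache write failures. A rough sketch of that pattern; the helper names (`get_index_cache_dir`, `try_write_cache`) are hypothetical, not the actual Megatron functions:

```python
# Hypothetical sketch of routing index caches to a separate directory and
# reporting write failures; helper names are illustrative only.
import os

def get_index_cache_dir(data_prefix, data_cache_path=None):
    """Place index caches next to the data unless a cache path is given."""
    if data_cache_path is None:
        return os.path.dirname(data_prefix)
    # Mirror the dataset name under the cache root so datasets do not collide.
    return os.path.join(data_cache_path, os.path.basename(data_prefix) + "_indexmap")

def try_write_cache(cache_dir, filename, payload):
    """Write a cache file; print an error instead of raising on failure."""
    try:
        os.makedirs(cache_dir, exist_ok=True)
        with open(os.path.join(cache_dir, filename), "wb") as f:
            f.write(payload)
        return True
    except OSError as err:
        print(f"WARNING: could not write index cache to {cache_dir}: {err}")
        return False
```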
12 changes: 8 additions & 4 deletions megatron/checkpointing.py
@@ -38,11 +38,15 @@ def check_checkpoint_args(checkpoint_args):
arguments and the one retrieved from checkpoint."""
args = get_args()

def _compare(arg_name, old_arg_name=None):
def _compare(arg_name, old_arg_name=None, default=None):
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name)
ckpt_arg_name = old_arg_name
else:
checkpoint_value = getattr(checkpoint_args, arg_name)
ckpt_arg_name = arg_name
if default is not None:
checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default)
else:
checkpoint_value = getattr(checkpoint_args, ckpt_arg_name)
args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
@@ -52,7 +56,7 @@ def _compare(arg_name, old_arg_name=None):
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('add_position_embedding')
_compare('add_position_embedding', default=True)
try:
_compare('position_embedding_type')
except AttributeError as e:
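The new `default=` parameter in `_compare` lets checkpoints written before an argument existed (here `add_position_embedding`) still pass the compatibility check instead of raising `AttributeError`. A minimal illustration of the underlying `getattr` fallback, with made-up values:

```python
# Minimal illustration of falling back to a default for checkpoint args that
# predate a newly introduced option; the values here are made up.
from types import SimpleNamespace

old_checkpoint_args = SimpleNamespace(num_layers=24)  # no add_position_embedding
value = getattr(old_checkpoint_args, "add_position_embedding", True)
assert value is True  # older checkpoints implicitly used position embeddings
```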