From 460d1f53d1dbc29898cd5aeb14806faa8f497ffb Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Tue, 18 Jun 2024 21:02:59 +0800
Subject: [PATCH] fix bug

---
 internlm/data/tokenized/packed_dataset.py | 5 ++---
 internlm/data/tokenized/single_dataset.py | 5 ++---
 internlm/initialize/launch.py             | 4 ----
 train.py                                  | 4 ++++
 4 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/internlm/data/tokenized/packed_dataset.py b/internlm/data/tokenized/packed_dataset.py
index 0284af15..1d525965 100644
--- a/internlm/data/tokenized/packed_dataset.py
+++ b/internlm/data/tokenized/packed_dataset.py
@@ -115,7 +115,7 @@ def __init__(
         self.seed = DEFAULT_SEED
         self.path = self.get_dataset_name()
 
-        if not self.dataset.use_shm:
+        if not gpc.config.data.use_shm:
             self._process_init()
         else:
             if self.dataset.found_cache:
@@ -282,7 +282,7 @@ def __init__(
     ):
         super().__init__(dataset, max_length_per_sample, packed_length)
         self.path = self.get_dataset_name()
-        if not self.dataset.use_shm:
+        if not gpc.config.data.use_shm:
             self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed)
             self.num_tokens = sum(self.lengths)
         else:
@@ -549,7 +549,6 @@ def get_packed_dataset_without_short_length(
                 fp,
                 ds_type_id,
                 min_length=min_length_num,
-                use_shm=gpc.config.data.use_shm,
                 pack_sample_into_one=pack_sample_into_one,
             )
 
diff --git a/internlm/data/tokenized/single_dataset.py b/internlm/data/tokenized/single_dataset.py
index 63b2a3b9..a3c95dc0 100644
--- a/internlm/data/tokenized/single_dataset.py
+++ b/internlm/data/tokenized/single_dataset.py
@@ -55,9 +55,8 @@ class JsonlDataset(torch.utils.data.Dataset):
     Note that only the "tokens" key is used.
     """
 
-    def __init__(self, path: str, dataset_type_id: int = 0, min_length=50, use_shm=False, pack_sample_into_one=False):
-        self.use_shm = use_shm
-        if not self.use_shm:
+    def __init__(self, path: str, dataset_type_id: int = 0, min_length=50, pack_sample_into_one=False):
+        if not gpc.config.data.use_shm:
             self._process_init(path, dataset_type_id, min_length)
         else:
             devices_per_node = internlm_accelerator.device_count()
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 64668b21..58a36cc3 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -2,7 +2,6 @@
 # -*- encoding: utf-8 -*-
 
 import argparse
-import gc
 import os
 from pathlib import Path
 from typing import Dict, Union
@@ -628,9 +627,6 @@
     """
     backend = internlm_accelerator._communication_backend_name
 
-    # close automatic garbage collection
-    gc.disable()
-
     if launcher == "torch":
         launch_from_torch(config=config, seed=seed, backend=backend)
     elif launcher == "slurm":
diff --git a/train.py b/train.py
index 5145a850..be50d7e8 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import gc
 import logging
 import os
 import shutil
@@ -178,6 +179,9 @@ def main(args):
     # transfer the train data loader into train data iterator
     train_iter = iter(train_dl)
 
+    # close automatic garbage collection
+    gc.disable()
+
     with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
         # start iterating the train data and begin training
         for batch_count in range(train_state.batch_count, total_steps):