From 69cb00c1a3ae260a70c10d1da248467140947900 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:32:22 +0800 Subject: [PATCH 01/12] feat(doc): add MOE installation (#265) --- doc/en/install.md | 5 +++++ doc/install.md | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/doc/en/install.md b/doc/en/install.md index 404a08ae1..304d110a7 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -82,6 +82,11 @@ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp cd ../../ ``` +### Additional Installation +```bash +pip install git+https://github.com/databricks/megablocks@v0.3.2 # MOE need +``` + ### Environment Image Users can use the provided dockerfile combined with docker.Makefile to build their own images, or obtain images with InternEvo runtime environment installed from https://hub.docker.com/r/internlm/internlm. diff --git a/doc/install.md b/doc/install.md index 99391efa7..b894f8fa5 100644 --- a/doc/install.md +++ b/doc/install.md @@ -82,6 +82,11 @@ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp cd ../../ ``` +### 额外安装 +```bash +pip install git+https://github.com/databricks/megablocks@v0.3.2 # MOE相关 +``` + ### 环境镜像 用户可以使用提供的 dockerfile 结合 docker.Makefile 来构建自己的镜像,或者也可以从 https://hub.docker.com/r/internlm/internlm 获取安装了 InternEvo 运行环境的镜像。 @@ -134,4 +139,3 @@ pip3 install setuptools wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc1-pytorch2.1.0/torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl pip install torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ``` - From 751c103e32138a88a8e8628499af48e2c83cce86 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:33:02 +0800 Subject: [PATCH 02/12] Fix(QA): fix loading ckpt and add launcher setting for test loss (#206) --- tests/test_training/test_loss.py | 36 +++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 7a2d5ca2a..e69073538 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -6,6 +6,7 @@ import torch.distributed as dist import internlm +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint import CheckpointManager from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc @@ -44,8 +45,8 @@ 4.616517543792725, ] - cur_loss_list = [] +internlm_accelerator = get_accelerator() def train( @@ -57,7 +58,8 @@ def train( interleaved: bool = False, tp_mode: str = "mtp", enable_sp: bool = False, - enable_ckpt: bool = False, + save_ckpt: bool = False, + load_ckpt: bool = False, model_type: str = "INTERNLM", optimizer_ver: str = "v1", ): @@ -69,6 +71,9 @@ def train( config.data.fixed_random_dataset_seqlen = False config.lr_scheduler.total_steps = TOTAL_STEPS config.model_type = model_type + config.ckpt.load_ckpt_folder = None + config.ckpt.load_ckpt_info = None + config.ckpt.auto_resume = False total_steps = config.data.total_steps skip_batches = config.data.skip_batches label_smoothing = config.loss.label_smoothing @@ -80,16 +85,16 @@ def train( # update ckpt config if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False: config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test") - config.ckpt.auto_resume = False - if enable_ckpt: + if save_ckpt: config.ckpt.enable_save_ckpt = True config.ckpt.checkpoint_every = 10 config.ckpt.save_ckpt_folder = "local:llm_ckpts/" - config.ckpt.load_ckpt_folder = "local:llm_ckpts/" - config.ckpt.load_ckpt_info["content"] = ("all",) config.ckpt.oss_snapshot_freq = 100 + if load_ckpt: + config.ckpt.load_ckpt_info = dict(path="local:llm_ckpts/10", content=("all",), ckpt_type="internevo") + # update parallel config config.parallel.tensor = dict(size=tp_size, mode=tp_mode) config.parallel.pipeline = dict(size=pp_size) @@ -98,7 +103,18 @@ def train( config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True) config.model.num_chunks = num_chunks - initialize_distributed_env(config=config) + if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [ + AcceleratorType.NPU, + AcceleratorType.DIPU, + ]: + config.data.use_packed_dataset = False + + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + launcher = "slurm" + else: + launcher = "torch" + + initialize_distributed_env(config=config, launcher=launcher) assert hasattr(gpc, "config") and gpc.config is not None # check parallel config @@ -241,7 +257,7 @@ def train( ) if gpc.is_rank_for_log(): assert loss is not None and not math.isnan(loss.item()) - global cur_loss_list + global cur_loss_list # pylint: disable=W0602 cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item())) timer("fwd-bwd").stop() @@ -463,7 +479,7 @@ def test_training_with_isp_save_ckpt(): CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" # model training save ckpt - train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, enable_ckpt=True) + train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, save_ckpt=True) @pytest.mark.training_8GPU_ISP_LOAD_CKPT @@ -476,7 +492,7 @@ def test_training_with_isp_load_ckpt(): TOTAL_STEPS = 20 # model training load ckpt - train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, enable_ckpt=True) + train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, load_ckpt=True) @pytest.mark.training_llama2 From 98ff6f8095867e394ed0992bf76b641842855fd5 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Tue, 9 Jul 2024 12:50:02 +0800 Subject: [PATCH 03/12] feat(moe): impl hugginface internln-moe (#271) --- transformers/convert2hf_internlm_moe.py | 410 +++++ transformers/internlm_moe_model/__init__.py | 9 + .../configuration_internlm_moe.py | 122 ++ .../modeling_internlm_moe.py | 1421 +++++++++++++++++ .../tokenization_internlm.py | 237 +++ 5 files changed, 2199 insertions(+) create mode 100644 transformers/convert2hf_internlm_moe.py create mode 100644 transformers/internlm_moe_model/__init__.py create mode 100644 transformers/internlm_moe_model/configuration_internlm_moe.py create mode 100644 transformers/internlm_moe_model/modeling_internlm_moe.py create mode 100644 transformers/internlm_moe_model/tokenization_internlm.py diff --git a/transformers/convert2hf_internlm_moe.py b/transformers/convert2hf_internlm_moe.py new file mode 100644 index 000000000..1800e6253 --- /dev/null +++ b/transformers/convert2hf_internlm_moe.py @@ -0,0 +1,410 @@ +# Copyright (c) InternLM. All rights reserved. +""" +python transformers/convert2hf_internlm.py --src /path/to/src --tgt /path/to/tgt \ + --max_shard 2G --maxx_pos 8192 \ + --tokenizer /path/to/tokenizer.model \ +""" +import argparse +import gc +import json +import os +import re +import time + +import torch +from datasets import Dataset +from internlm_model import InternLMTokenizer +from internlm_moe_model import InternLMMoEConfig, InternLMMoEForCausalLM +from tqdm import tqdm + +from transformers import Trainer, TrainingArguments +from transformers.modeling_utils import no_init_weights + +embedding_key_list = ["embedding.word_embeddings.weight", "embedding.weight", "tok_embeddings.weight", None] + + +def _find_max_tp_pp(names): + ckpt_names = [] + for name in names: + if name.startswith("model_t") and not name.endswith("md5"): + # _t: avoid conflictint with model_config.pt + ckpt_names.append(name) + + max_tp, max_pp = -1, -1 + for ckpt in ckpt_names: + _, tp, pp = os.path.splitext(ckpt)[0].split("_") + max_tp = max(max_tp, int(tp[2:]) + 1) + max_pp = max(max_pp, int(pp[2:]) + 1) + + return max_tp, max_pp + + +def load_source(src): + """ + load model_config.pt and model_tp{x}_pp{x}.pt from ``src`` + + :return: + - model_config: dict + - states: 2-d array. states[i][j] stands for state_dict of tp_i pp_j + """ + + # config + print("Config loading", flush=True) + config_file = os.path.join(src, "model_config.pt") + assert os.path.isfile(config_file), f"model_config.pt is not found in :{os.listdir(src)}" + model_config = torch.load(config_file) + print(model_config) + print("Config loaded.", flush=True) + + # checkpoint + # find tp pp + assert os.path.isdir(src), "not a folder." + ckpt_names = os.listdir(src) + max_tp, max_pp = _find_max_tp_pp(ckpt_names) + num_moe_layer = model_config["num_layers"] + num_experts = model_config["num_experts"] + + # 2-d array tp_rank, pp_rank + print("Source Checkpoint Loading", flush=True) + states = [[None for _ in range(max_pp)] for __ in range(max_tp)] + moe_states = [[{} for _ in range(max_pp)] for __ in range(max_tp)] + for tp in tqdm(range(max_tp)): + for pp in tqdm(range(max_pp)): + ckpt_name = os.path.join(src, f"model_tp{tp}_pp{pp}.pt") + states[tp][pp] = torch.load(ckpt_name, map_location="cpu") + for lay_id in tqdm(range(num_moe_layer)): + for expert_id in range(num_experts): + moe_ckpt_name = os.path.join(src, f"model_moe_layer{lay_id}_expert{expert_id}_tp{tp}.pt") + moe_states[tp][pp].update(torch.load(moe_ckpt_name, map_location="cpu")) + print("Source Checkpoint Loaded", flush=True) + return model_config, states, moe_states + + +def merge(states): + """ + Merge state dicts of pipeline format and shift some layers. + + :return: + - config: InternLMMoEConfig + - states: merged state dict + """ + # merge pp + merged_states = [] + print("Pipeline Merging", flush=True) + for tp_state in tqdm(states): + layer_shift = 0 + shifted_state = {} + # shift key + for tp_pp_state in tp_state: + _layer_shift = 0 + keys = list(tp_pp_state.keys()) + for key in keys: + if key.endswith(".inv_freq"): + continue + match = re.search(r"\.\d+\.", key) + name = key + if match is not None: + # layers + s, e = match.span() + layer_idx = int(key[s + 1 : e - 1]) + layer_shift + _layer_shift = max(_layer_shift, int(key[s + 1 : e - 1])) + name = key[:s] + f".{layer_idx}." + key[e:] + if name.startswith("model."): + name = name[6:] + shifted_state[name] = tp_pp_state[key] + layer_shift += _layer_shift + 1 + + merged_states.append(shifted_state) + + print("Pipeline Merged", flush=True) + + return merged_states + + +def convert(src, tgt, tokenizer, dtype, max_shard_size, max_pos, topk, rope_scaling): + """ + Convert state_dict to hf format. + + 1. Load and merge state dict + 2. Convert to huggingface + 3. Load tokneizer and save it with ``tokenizer.save_pretrained`` + 4. Load state dict to the model + 5. Call ``model.save_pretrained`` to save checkpoints. + """ + # load states + model_config, src_states, src_moe_states = load_source(src) + states = merge(src_states) + moe_states = merge(src_moe_states) + del src_states + del src_moe_states + + num_shards = len(states) + print("Converting to huggingface format...", flush=True) + + n_heads = model_config["num_attention_heads"] + dim = model_config["hidden_size"] + # n_heads_per_shard = n_heads // num_shards + # dims_per_head = dim // n_heads + intermediate_size = None + + print("Start converting...", flush=True) + state_dict = {} + for layer_i in tqdm(range(model_config["num_layers"])): + wqkvs = [ + states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.weight").reshape(3, n_heads // num_shards, -1, dim) + for tp in range(num_shards) + ] + bqkvs = [ + states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.bias").reshape(3, n_heads // num_shards, -1) + for tp in range(num_shards) + ] + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][f"blocks.{layer_i}.norm1.weight"].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"blocks.{layer_i}.norm2.weight" + ].clone(), + } + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [wqkvs[i][0] for i in range(num_shards)], + dim=0, + ).reshape(dim, dim) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat( + [bqkvs[i][0] for i in range(num_shards)], + dim=0, + ).reshape(-1) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [wqkvs[i][1] for i in range(num_shards)], + dim=0, + ).reshape(dim, dim) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat( + [bqkvs[i][1] for i in range(num_shards)], + dim=0, + ).reshape(-1) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [wqkvs[i][2] for i in range(num_shards)], + dim=0, + ).reshape(dim, dim) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat( + [bqkvs[i][2] for i in range(num_shards)], + dim=0, + ).reshape(-1) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"blocks.{layer_i}.mixer.out_proj.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.bias"] = states[0][f"blocks.{layer_i}.mixer.out_proj.bias"] + + state_dict[f"model.layers.{layer_i}.mlp.gate.weight"] = states[0][ + f"blocks.{layer_i}.mlp.moe_layer.gate.wg.weight" + ].clone() + + if model_config["moe_use_residual"]: + state_dict[f"model.layers.{layer_i}.mlp.shared_experts.gate_proj.weight"] = torch.cat( + [moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.residual_mlp.w1.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.mlp.shared_experts.down_proj.weight"] = torch.cat( + [moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.residual_mlp.w3.weight"] for i in range(num_shards)], + dim=1, + ) + state_dict[f"model.layers.{layer_i}.mlp.shared_experts.up_proj.weight"] = torch.cat( + [moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.residual_mlp.w2.weight"] for i in range(num_shards)], + dim=0, + ) + + for expert_id in range(model_config["num_experts"]): + state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_id}.gate_proj.weight"] = torch.cat( + [ + moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.experts.wrapped_experts.{expert_id}.w1.weight"] + for i in range(num_shards) + ], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_id}.down_proj.weight"] = torch.cat( + [ + moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.experts.wrapped_experts.{expert_id}.w3.weight"] + for i in range(num_shards) + ], + dim=1, + ) + state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_id}.up_proj.weight"] = torch.cat( + [ + moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.experts.wrapped_experts.{expert_id}.w2.weight"] + for i in range(num_shards) + ], + dim=0, + ) + + intermediate_size, _ = state_dict[f"model.layers.{0}.mlp.experts.{0}.gate_proj.weight"].shape + + # embedding + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] + else: + embed_concat_dim = 0 + _, size_1 = states[0][embedding_key].shape + embdim_pertp = size_1 // num_shards + tok_emb_list = [ + torch.concat( + [ + states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] + for tp in range(num_shards) + ], + dim=0, + ) + for local_rank in range(num_shards) + ] + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), + "lm_head.weight": torch.cat([states[i]["head.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # initialize model + # tokenizer + tokenizer = InternLMTokenizer(tokenizer) + # config + config = InternLMMoEConfig( + vocab_size=model_config["vocab_size"], + hidden_size=model_config["hidden_size"], + intermediate_size=intermediate_size, + num_attention_heads=model_config["num_attention_heads"], + num_hidden_layers=model_config["num_layers"], + rms_norm_eps=model_config["layer_norm_epsilon"], + bias=True, + rope_theta=model_config.get("rope_base", 10000), + rope_scaling=rope_scaling, + num_experts=model_config.get("num_experts", 1), + num_experts_per_tok=topk, + num_shared_experts=1 if model_config["moe_use_residual"] else 0, + ) + # tokenizer + config.max_position_embeddings = max_pos + # set bos eos pad to avoid improper generation + # since model.generate will create attention_mask + # according to pad_token_id and bos_token_id + config.bos_token_id = tokenizer.bos_token_id + config.eos_token_id = tokenizer.eos_token_id + config.pad_token_id = tokenizer.pad_token_id + + # model + print("Initializing model...", flush=True) + start = time.time() + with no_init_weights(): + model = InternLMMoEForCausalLM._from_config(config, torch_dtype=dtype) + print(f"Initializing model takes {time.time() - start}s", flush=True) + model.load_state_dict(state_dict) + + # 驱动选择 + device = "cuda" if torch.cuda.is_available() else "cpu" + + X = torch.zeros((32, 32), dtype=torch.int64).to(device=device) + labels = [] + for i in range(32): + labels.append((i + 1) % 32) + X[i] = 1 + labels = torch.tensor(labels) + dataset = Dataset.from_dict({"input_ids": X, "labels": X}) + + training_args = TrainingArguments( + output_dir="./results", # output directory 结果输出地址 + num_train_epochs=10, # total # of training epochs 训练总批次 + per_device_train_batch_size=1, # batch size per device during training 训练批大小 + per_device_eval_batch_size=1, # batch size for evaluation 评估批大小 + learning_rate=1e-3, # 学习率 + save_steps=False, # 不保存检查点 + ) + + trainer = Trainer( + model=model, # the instantiated 🤗 Transformers model to be trained 需要训练的模型 + args=training_args, # training arguments, defined above 训练参数 + train_dataset=dataset, # training dataset 训练集 + eval_dataset=dataset, # evaluation dataset 测试集 + ) + + trainer.train() + trainer.evaluate() + + del states + gc.collect() + print(f"Saving model to {tgt}...", flush=True) + tokenizer.save_pretrained(tgt) + model.save_pretrained(tgt, max_shard_size=max_shard_size) + + # fix auto_map in config + with open(os.path.join(tgt, "config.json")) as fp: + config_dict = json.load(fp) + config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMMoEForCausalLM" + with open(os.path.join(tgt, "config.json"), "w") as fp: + json.dump(config_dict, fp, indent=2) + + +def convert_tokenizer(src, tgt): + assert os.path.isfile(src) + tokenizer = InternLMTokenizer(src) + tokenizer.save_pretrained(tgt) + + +def get_rope_scaling(args): + if args.rotary_type == "origin": + return None + elif args.rotary_type == "dynamic": + return {"type": args.rotary_type, "factor": args.scaling_factor} + else: + raise NotImplementedError(f"Unknown rope type {args.rotary_type}") + + +def print_args(args): + print("-------------- Arguments --------------") + print(f"Source Path: {args.src}") + print(f"Target Path: {args.tgt}") + print(f"Dtype: {args.dtype}") + print(f"Max Shard Size: {args.max_shard}") + print(f"Max Position Embedding: {args.max_pos}") + print(f"Tokenizer Path: {args.tokenizer}") + print(f"Rotary Type: {args.rotary_type}") + print(f"Scaling Factor: {args.scaling_factor}") + print("---------------------------------------") + + +def parse_args(): + parser = argparse.ArgumentParser() + # model + parser.add_argument("--src", type=str, default=None, help="Input folder") + parser.add_argument("--tgt", type=str, help="Output folder") + parser.add_argument("--dtype", default="bfloat16", type=str, help="Data type after converting") + parser.add_argument("--max_shard", type=str, default="10GB", help="Max size of every sharded checkpoint.") + parser.add_argument("--max_pos", type=int, default=4096, help="Max position embedding of model.") + # tokenizer + parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer model.") + # rope + parser.add_argument("--rotary_type", type=str, default="origin", help="Rope type", choices=["origin", "dynamic"]) + parser.add_argument("--scaling_factor", type=float, default=1.0, help="Scaling factor of dynamic rope.") + parser.add_argument("--topk", type=int, default=1, help="top-k experts in MoE.") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + print_args(args) + dtype = getattr(torch, args.dtype) + rope_scaling = get_rope_scaling(args) + + assert args.src is not None, "--src is needed!" + assert args.tokenizer is not None, "--tokenizer is needed!" + assert args.topk is not None, "--topk is needed!" + start = time.time() + convert(args.src, args.tgt, args.tokenizer, dtype, args.max_shard, args.max_pos, args.topk, rope_scaling) + print(f"Converting model takes {time.time() - start}s totally", flush=True) diff --git a/transformers/internlm_moe_model/__init__.py b/transformers/internlm_moe_model/__init__.py new file mode 100644 index 000000000..d1c242ee4 --- /dev/null +++ b/transformers/internlm_moe_model/__init__.py @@ -0,0 +1,9 @@ +from .configuration_internlm_moe import InternLMMoEConfig +from .modeling_internlm_moe import InternLMMoEForCausalLM +from .tokenization_internlm import InternLMTokenizer + +__all__ = [ + "InternLMMoEConfig", + "InternLMMoEForCausalLM", + "InternLMTokenizer", +] diff --git a/transformers/internlm_moe_model/configuration_internlm_moe.py b/transformers/internlm_moe_model/configuration_internlm_moe.py new file mode 100644 index 000000000..e6a0cc62e --- /dev/null +++ b/transformers/internlm_moe_model/configuration_internlm_moe.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" InternLM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLMMoEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate + an InternLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLMModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from transformers import InternLMModel, InternLMConfig + >>> # Initializing a InternLM internlm-7b style configuration + >>> configuration = InternLMConfig() + >>> # Initializing a model from the internlm-7b style configuration + >>> model = InternLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "internlm" + _auto_class = "AutoConfig" + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + bias=True, + rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102 + attn_implementation="eager", + num_experts=1, + num_experts_per_tok=1, + num_shared_experts=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.bias = bias + self.rotary = rotary + self.attn_implementation = attn_implementation + self.num_routed_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.num_shared_experts = num_shared_experts + if self.attn_implementation is None: + self.attn_implementation = "eager" + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/transformers/internlm_moe_model/modeling_internlm_moe.py b/transformers/internlm_moe_model/modeling_internlm_moe.py new file mode 100644 index 000000000..249a6ca36 --- /dev/null +++ b/transformers/internlm_moe_model/modeling_internlm_moe.py @@ -0,0 +1,1421 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch InternLM model.""" +import math +import queue +import threading +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + +from .configuration_internlm_moe import InternLMMoEConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "InternLMMoEConfig" + +flash_attn_func, flash_attn_varlen_func = None, None +pad_input, index_first_axis, unpad_input = None, None, None + + +def _import_flash_attn(): + global flash_attn_func, flash_attn_varlen_func + global pad_input, index_first_axis, unpad_input + try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + except ImportError: + raise ImportError("flash_attn is not installed.") + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.llama.modeling_llama._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _compute_load_balancing_loss(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float: + """Calculate the load balancing loss contribution.""" + if gate_logits is None or not isinstance(gate_logits, tuple) or gate_logits[0] is None: + return 0 + moe_losses = [] + for logit in gate_logits: + gates = F.softmax(logit, dim=1) + weight, indices = torch.topk(gates, top_k, dim=1) + num_tokens_per_expert = torch.histc(indices, bins=num_experts, min=0, max=num_experts) + moe_losses.append(torch.dot(num_tokens_per_expert.to(weight.dtype), weight.mean(dim=0))) + + return sum(moe_losses) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM +class InternLMRMSNorm(nn.Module): + """RMSNorm implemention.""" + + def __init__(self, hidden_size, eps=1e-6): + """ + InternLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM +class InternLMRotaryEmbedding(torch.nn.Module): + """Implement InternLM's rotary embedding. + + Args: + dim (int): Characteristic dimension of each self-attentional head. + max_position_embeddings (int, optional): Model's training length. Defaults to 2048. + base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000. + device (Any, optional): Running device. Defaults to None. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.float32), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.float32), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + return ( + self.cos_cached[:seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM +class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module): + """Implement InternLM's DyanmicNTK extrapolation method, thereby broadening the model support context to 16K. + + Args: + dim (int): Characteristic dimension of each self-attentional head. + max_position_embeddings (int, optional): Model's training length. Defaults to 2048. + base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000. + device (Any, optional): Running device. Defaults to None. + scaling_factor (float, optional): NTK method extrapolation coefficient. Defaults to 1.0. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.dim = dim + self.base = base + self.scaling_factor = scaling_factor + + # Build here to make `torch.jit.trace` work. + self.max_position_embeddings = max_position_embeddings + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def _update_cached(self, x, seq_len=None): + self.max_seq_len_cached = max(seq_len, self.max_position_embeddings) + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)) + else: + inv_freq = self.inv_freq + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len <= self.max_position_embeddings: + # Reset the tables if the sequence length has changed, + if self.max_seq_len_cached > self.max_position_embeddings: + self._update_cached(x, seq_len) + else: + self._update_cached(x, seq_len) + + return ( + self.cos_cached[:seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + +# Copied from transformers.model.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + if position_ids.size(1) == 1: + q_cos = cos[position_ids].unsqueeze(1).expand(q.shape) + q_sin = sin[position_ids].unsqueeze(1).expand(q.shape) + q_embed = (q * q_cos) + (rotate_half(q) * q_sin) + + position_ids = position_ids.flatten() + 1 + max_length = max(position_ids) + position_ids = torch.stack( + [torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids] + ) + k_cos = cos[position_ids].unsqueeze(1).expand(k.shape) + k_sin = sin[position_ids].unsqueeze(1).expand(k.shape) + k_embed = (k * k_cos) + (rotate_half(k) * k_sin) + else: + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->InternLM +class InternLMMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# A mixed expert module containing shared experts. +class InternLMMoELayer(nn.Module): + def __init__(self, config: InternLMMoEConfig): + super().__init__() + self.config = config + self.num_shared_experts = config.num_shared_experts + self.num_shared_experts = config.num_shared_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.gate = nn.Linear(config.hidden_size, config.num_routed_experts, bias=False) + self.experts = torch.nn.ModuleList( + [ + InternLMMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + for _ in range(config.num_routed_experts) + ] + ) + if config.num_shared_experts > 0: + intermediate_size = config.intermediate_size * config.num_shared_experts + self.shared_experts = InternLMMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + ) + + def forward(self, x): + orig_shape = x.shape + orig_inputs = x + x = x.view(-1, x.shape[-1]) + # if self.gate.weight.dtype != torch.float32: + # self.gate = self.gate.float() + # x = x.float() + logits = self.gate(x) + gates = F.softmax(logits, dim=1) + weights, indices = torch.topk(gates, self.num_experts_per_tok, dim=1) + weights /= weights.sum(dim=-1, keepdim=True) + flat_indices = indices.view(-1) + + x = x.repeat_interleave(self.num_experts_per_tok, dim=0) + y = torch.empty_like(x) + for i, expert in enumerate(self.experts): + y[flat_indices == i] = expert(x[flat_indices == i]) + y = (y.view(*weights.shape, -1) * weights.unsqueeze(-1)).sum(dim=1) + + if self.config.num_shared_experts > 0: + y = y + self.shared_experts(orig_inputs) + + # moe_loss = self.load_balancing_loss(weights, indices) if self.training else None + return y.view(*orig_shape), logits + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->InternLM +class InternLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternLMMoEConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + self.rotary_emb = self._init_rope() + self.is_causal = True + + def _init_rope(self): + if self.config.rotary["type"] == "origin": + self.rotary_emb = InternLMRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rotary["base"], + ) + elif self.config.rotary["type"] == "dynamic": + self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rotary["base"], + scaling_factor=self.config.rotary.get("scaling_factor", 1.0), + ) + else: + raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').") + return self.rotary_emb + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->InternLM +class InternLMFlashAttention2(InternLMAttention): + """ + InternLM flash attention module. This module inherits from `InternLMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # InternLMFlashAttention2 attention does not support output_attentions + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q.to(torch.int64), + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +INTERNLM_ATTENTION_CLASSES = { + "eager": InternLMAttention, + "flash_attention_2": InternLMFlashAttention2, +} + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM +class InternLMMoEDecoderLayer(nn.Module): + def __init__(self, config: InternLMMoEConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = INTERNLM_ATTENTION_CLASSES[config.attn_implementation](config=config) + + self.mlp = ( + InternLMMoELayer(config) + if config.num_routed_experts > 1 + else InternLMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + ) + self.input_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + router_logits = None + if len(hidden_states) == 2: + hidden_states, router_logits = hidden_states + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +INTERNLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`InternLMMoEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPretrainedModel with Llama->InternLM +@add_start_docstrings( + "The bare InternLM Model outputting raw hidden-states without any specific head on top.", + INTERNLM_START_DOCSTRING, +) +class InternLMMoEPreTrainedModel(PreTrainedModel): + config_class = InternLMMoEConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["InternLMMoEDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): # pylint: disable=W0237 + if isinstance(module, InternLMMoEModel): + module.gradient_checkpointing = value + + +INTERNLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM +@add_start_docstrings( + "The bare InternLM Model outputting raw hidden-states without any specific head on top.", + INTERNLM_START_DOCSTRING, +) +class InternLMMoEModel(InternLMMoEPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`] + Args: + config: InternLMMoEConfig + """ + + _auto_class = "AutoModel" + + def __init__(self, config: InternLMMoEConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList([InternLMMoEDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = output_router_logits if output_router_logits is not None else False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.attn_implementation == "flash_attention_2": + _import_flash_attn() + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if self.config.attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + all_router_logits = () if output_router_logits else None + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->InternLM +class InternLMMoEForCausalLM(InternLMMoEPreTrainedModel): + _auto_class = "AutoModelForCausalLM" + + def __init__(self, config): + super().__init__(config) + self.model = InternLMMoEModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + + Example: + ```python + >>> from transformers import AutoTokenizer, InternLMForCausalLM + >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ``` + + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = output_router_logits if output_router_logits is not None else False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + moe_loss = None + if output_router_logits: + aux_loss = _compute_load_balancing_loss( + outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (moe_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + def build_inputs( + self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction="" + ): # pylint: disable=W0102 + if tokenizer.add_bos_token: + prompt = "" + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += f"""<|System|>:{meta_instruction}\n""" + for record in history: + prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}\n""" + prompt += f"""<|User|>:{query}\n<|Bot|>:""" + return tokenizer([prompt], return_tensors="pt") + + @torch.no_grad() + def chat( # pylint: disable=W0102 + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user " + "such as English and 中文.", + **kwargs, + ): + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split("")[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( # pylint: disable=W0102 + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + "The version of `transformers` is too low. Please make sure " + "that you have installed `transformers>=4.28.0`." + ) + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = "" + self.cache = [] + self.received_inputs = False + self.queue.put((self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("ChatStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode(self.cache, skip_special_tokens=True) + if "�" in token and len(token) <= 5: + return + if token.strip() != "": + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +@add_start_docstrings( + """ + The InternLM Model transformer with a sequence classification head on top (linear layer). + [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + INTERNLM_START_DOCSTRING, +) +class InternLMMoEForSequenceClassification(InternLMMoEPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLMMoEModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int # pylint: disable=R1714 + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/transformers/internlm_moe_model/tokenization_internlm.py b/transformers/internlm_moe_model/tokenization_internlm.py new file mode 100644 index 000000000..de4745591 --- /dev/null +++ b/transformers/internlm_moe_model/tokenization_internlm.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for InternLM.""" +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer -> InternLM2Tokenizer +class InternLMTokenizer(PreTrainedTokenizer): + """ + Construct a InternLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + _auto_class = "AutoTokenizer" + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + decode_with_prefix_space=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.decode_with_prefix_space = decode_with_prefix_space + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + self._no_prefix_space_tokens = None + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def no_prefix_space_tokens(self): + if self._no_prefix_space_tokens is None: + vocab = self.convert_ids_to_tokens(list(range(self.vocab_size))) + self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")} + return self._no_prefix_space_tokens + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def _maybe_add_prefix_space(self, tokens, decoded): + if tokens and tokens[0] not in self.no_prefix_space_tokens: + return " " + decoded + else: + return decoded + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + out_string = self.clean_up_tokenization(out_string) + out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) + return out_string[1:] + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] From db977826e9e5befeaf3fab79e4be12588676f08d Mon Sep 17 00:00:00 2001 From: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com> Date: Tue, 16 Jul 2024 10:35:28 +0800 Subject: [PATCH 04/12] feat(huggingface): native support for huggingface model and dataset (#244) Co-authored-by: zigzagcai Co-authored-by: SlinkierApple13 <18917988589@163.com> --- doc/code-docs/source/checkpoint.rst | 28 +- doc/code-docs/source/initialize.rst | 10 +- doc/usage.md | 165 +++++++-- internlm/checkpoint/load_funcs.py | 15 + internlm/core/parallel/comm/isp.py | 6 +- internlm/core/trainer_builder.py | 333 ++++++++++++++++++ internlm/data/build_dataloader.py | 39 +- internlm/data/streaming/__init__.py | 13 + internlm/data/streaming/batch_sampler.py | 78 ++++ internlm/data/streaming/collaters.py | 58 +++ internlm/data/streaming/dataset.py | 119 +++++++ internlm/data/streaming/utils.py | 19 + internlm/data/train_state.py | 2 +- internlm/data/utils.py | 12 + internlm/initialize/initialize_trainer.py | 14 +- internlm/model/builder.py | 15 +- internlm/model/registry.py | 1 + internlm/monitor/__init__.py | 3 +- internlm/monitor/monitor.py | 42 +++ internlm/train/pipeline.py | 31 +- internlm/utils/storage_manager.py | 19 +- requirements/runtime.txt | 2 +- tests/test_infer/test_trainer_generate.py | 6 +- .../test_forward_output_no_fa.py | 5 +- tests/test_training/test_load_ckpt_loss.py | 5 +- tests/test_training/test_loss.py | 6 +- tests/test_training/test_no_fa_train_temp.py | 5 +- tests/test_training/test_norm_weight.py | 5 +- .../test_swap_nb_loss_and_gradnorm.py | 5 +- tests/test_training/train_CI.py | 6 +- tests/test_utils/common_fixture.py | 2 +- tests/test_utils/test_model_checkpoint.py | 3 +- train.py | 318 +---------------- 33 files changed, 970 insertions(+), 420 deletions(-) create mode 100644 internlm/core/trainer_builder.py create mode 100644 internlm/data/streaming/__init__.py create mode 100644 internlm/data/streaming/batch_sampler.py create mode 100644 internlm/data/streaming/collaters.py create mode 100644 internlm/data/streaming/dataset.py create mode 100644 internlm/data/streaming/utils.py diff --git a/doc/code-docs/source/checkpoint.rst b/doc/code-docs/source/checkpoint.rst index aab161e91..c01c69504 100644 --- a/doc/code-docs/source/checkpoint.rst +++ b/doc/code-docs/source/checkpoint.rst @@ -16,7 +16,7 @@ CheckpointManager - ``checkpoint_every``: 检查点存储频率,参数类型 ``int``,默认为: ``50``。 -- ``load_ckpt_folder``: 初始化检查点/权重加载路径。参数类型 ``str``,默认为: ``None``,详见 :ref:`load-ckpt-folder`。 +- ``load_ckpt_info``: 初始化检查点/权重加载信息。参数类型 ``dict``,默认为: ``None``,详见 :ref:`load-ckpt-info`。 - ``async_upload``: 是否开启异步上传,默认值为:``False``,详见 :ref:`asyncupload`。 @@ -36,8 +36,8 @@ CheckpointManager ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"), - auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'. + load_ckpt_info=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"), + auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_info'. checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3, volc and oss2 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. @@ -52,7 +52,7 @@ CheckpointManager 加载与存储格式约定 -------------------------- -.. _load-ckpt-folder: +.. _load-ckpt-info: (1) 路径格式约定 ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -75,10 +75,10 @@ InternEvo对config中出现的所有存储路径都遵循以下的路径格式 -(2) 模型加载(load_ckpt_folder)格式约定 +(2) 模型加载(load_ckpt_info)格式约定 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -load_ckpt_folder 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_type`` 。 +load_ckpt_info 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_type`` 。 - ``path``:给出了检查点/初始化模型权重的加载路径(path的格式见下小节) @@ -92,17 +92,23 @@ load_ckpt_folder 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_ty - ``ckpt_type``:表示加载的模型权重类型,目前支持的字段包括: - - ``internlm``:internevo约定的checkpoint存储格式。 + - ``internevo``:internevo约定的checkpoint存储格式。 + - ``llama``:llama约定的checkpoint存储格式。 + - ``hf_llama``:huggingface llama约定的checkpoint存储格式。 + - ``hf_model``:适用于加载huggingface所有模型的checkpoint存储格式。 下面给出两个例子: .. code-block:: python # 从文件存储相对路径 ckpt_model 中加载已有模型权重初始化模型,适合 sft 等训练初始化 - load_ckpt_folder= dict(path="local:ckpt_model", content=["model",], ckpt_type="internlm") + load_ckpt_info = dict(path="local:ckpt_model", content=("model",), ckpt_type="internevo") # 从文件存储相对路径 ckpt_model 中加载所有的状态,适合断点续训的场景 - load_ckpt_folder= dict(path="local:ckpt_model", content=["all",], ckpt_type="internlm") + load_ckpt_info = dict(path="local:ckpt_model", content=("all",), ckpt_type="internevo") + + # 从 huggingface 下载指定模型,加载checkpoint + load_ckpt_info = dict(path="internlm/internlm-7b", content=("model",), ckpt_type="hf_model") .. _asyncupload: @@ -144,13 +150,13 @@ config.ckpt 中相关的参数: 检查点自动加载功能的目的是在resume训练时,自动加载 ``save_ckpt_folder`` 路径下最新的检查点(包括snapshot检查点)。配合上自动重启机制,可以实现无人干预的任务自动恢复。 -该功能默认开启,所以要注意如果需要加载 ``load_ckpt_folder`` 路径下的模型权重,要将 ``auto_resume`` 设置为 False,否则可能会产生预期外的行为。 +该功能默认开启,所以要注意如果需要加载 ``load_ckpt_info`` 路径下的模型权重,要将 ``auto_resume`` 设置为 False,否则可能会产生预期外的行为。 config.ckpt 中相关的参数: - ``auto_resume``: 是否开启检查点自动恢复。参数类型 ``bool``,默认为 ``True``。 -``auto_resume`` 如果为True,则尝试从 ``save_ckpt_folder`` 路径中自动加载最新的ckpt,如果找不到,则从step 0开始训练。如果为False,则尝试从 ``load_ckpt_folder`` 中加载模型参数。 +``auto_resume`` 如果为True,则尝试从 ``save_ckpt_folder`` 路径中自动加载最新的ckpt,如果找不到,则从step 0开始训练。如果为False,则尝试从 ``load_ckpt_info`` 中加载模型参数。 .. _stopfile: diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst index d1e7511b9..ff3985ee8 100644 --- a/doc/code-docs/source/initialize.rst +++ b/doc/code-docs/source/initialize.rst @@ -77,17 +77,19 @@ InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制 - 字段 ``model_type`` 指明了要初始化的模型类型 - 字段 ``model`` 中的参数指定了在模型初始化过程中的参数设置 -值得注意的是,用户可以定义新的模型类型,并使用装饰器 ``@MODEL_INITIALIZER.register_module`` 注册模型的初始化函数,其中 ``MODEL_INITIALIZER`` 是类 ``internlm.util.registry.Registry`` 的一个实例化对象,示例如下所示: +值得注意的是,用户可以定义新的模型类型,并通过 ``register_module`` 注册模型的初始化函数,示例如下所示: .. code-block:: python - MODEL_TYPE = "NEW_MODEL" + model_initializer = Registry("model_initializer") - @MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) - def build_new_model_with_cfg(*args, **kwargs): + def register_model_initializer() -> None: + model_initializer.register_module("INTERNLM", InternLM1) .. _InternLM-optim-init: +其中,"INTERNLM"为新的模型类型,InternLM1为新模型的入口函数。 + 优化器初始化 ------------------------- diff --git a/doc/usage.md b/doc/usage.md index a1dcef624..78b929603 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -83,32 +83,36 @@ MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" # boto3 Ckpt folder format: # import os # BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" CHECKPOINT_EVERY = 50 ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), - load_ckpt_folder="local:llm_ckpts/", # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama", "hf_model". + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=True, checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) -TRAIN_FOLDER = "/path/to/dataset" -VALID_FOLDER = "/path/to/dataset" +TRAIN_FOLDER = None # "/path/to/dataset" +VALID_FOLDER = None # "/path/to/dataset" data = dict( seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update @@ -122,13 +126,22 @@ data = dict( pack_sample_into_one=False, total_steps=50000, skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. rampup_batch_size="", # Datasets with less than 50 rows will be discarded min_length=50, - # train_folder=TRAIN_FOLDER, - # valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=10, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, diag_outlier_ratio=1.1, + # whether use shared memory to load meta files + use_shm=False, + # when use shm, the default shm_path is "/dev/shm/metacache" + # shm_path="/dev/shm/metacache" ) grad_scaler = dict( @@ -153,11 +166,16 @@ grad_scaler = dict( hybrid_zero_optimizer = dict( # Enable low_level_optimzer overlap_communication overlap_sync_grad=True, - overlap_sync_param=True, + overlap_sync_param=False, # bucket size for nccl communication params reduce_bucket_size=512 * 1024 * 1024, # grad clipping clip_grad_norm=1.0, + # whether use new optm + use_split_tensor_optim=False, + # when use split tensor optm + # Perform all gather with a set of parameters of all_gather_size + all_gather_size=512 * 1024 * 1024, ) loss = dict( @@ -187,6 +205,7 @@ beta2_scheduler = dict( cur_iter=-1, ) +use_fp32_norm = False model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, @@ -198,28 +217,50 @@ model = dict( num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. ) """ -zero1 parallel: - 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. pipeline parallel (dict): 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. """ parallel = dict( - zero1=8, + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=False, + weight=dict(size=1, overlap=True, memory_pool=True), ) cudnn_deterministic = False @@ -231,6 +272,10 @@ monitor = dict( enable_feishu_alert=DO_ALERT, feishu_alert_address=None, # feishu webhook to send alert message light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, ), ) ``` @@ -264,23 +309,56 @@ data = dict( ``` 数据集的详细内容可参考``数据准备``模块相关的介绍。 +同时,也支持huggingface格式的数据集处理。 +train_folder设置为huggingface上可以通过load_dataset直接下载的数据集路径,如:"roneneldan/TinyStories" +在data中,需要新增type及tokenizer_path字段,标示数据集是huggingface格式,并指定tokenizer路径,如: +```python +TRAIN_FOLDER = "roneneldan/TinyStories" +SEQ_LEN = 2048 +data = dict( + type="hf", + tokenizer_path="internlm/internlm-7b", + seq_len=SEQ_LEN, # 数据样本长度,默认值为 2048 + micro_num=1, # micro_num 是指在一次模型参数更新中会处理的 micro_batch 的数目,默认值为 1 + micro_bsz=1, # packed_length = micro_bsz * SEQ_LEN,为一次处理的 micro_batch 的数据大小,默认值为 1 + total_steps=50000, # 总的所需执行的 step 的数目,默认值为 50000 + min_length=50, # 若数据集文件中,数据行数少于50,将会被废弃 + train_folder=TRAIN_FOLDER, # 数据集文件路径,默认值为 None;若 train_folder 为空,则以自动生成的随机数据集 +进行训练测试 + pack_sample_into_one=False, # 数据整理的逻辑,决定是按照 seq_len 维度或者是 sequence 的真实长度来进行attention计算 +) +``` + #### 模型配置 如果在启动训练时要加载模型 `checkpoint`,可进行如下相关配置: ```python SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt" -LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt" +# MODEL_ONLY_FOLDER = "internlm/internlm-7b" +MODEL_ONLY_FOLDER = "local:/path/to/load/resume/ckpt" ckpt = dict( + enable_save_ckpt=True, # 是否开启保存 checkpoint 功能 save_ckpt_folder=SAVE_CKPT_FOLDER, # 存储模型和优化器 checkpoint 的路径 checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint,默认值为 inf # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 # content 表示哪些状态会被加载,支持: "model", "sampler", "optimizer", "scheduler", "all" - # ckpt_type 表示加载的模型类型,目前支持: "internlm" + # ckpt_type 表示加载的模型类型,目前支持: "internevo", "llama", "hf_llama", "hf_model" + # 其中,"hf_model"类型表示从huggingface上下载模型加载ckpt,MODEL_ONLY_FOLDER需要设置为可以 + # 通过AutoModel直接加载的模型路径,如:"internlm/internlm-7b" load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # 'auto_resume' 旨在在遇到由硬件故障引起的训练中断/挂起时,自动从 'save_ckpt_folder' 加载最新的检查点, + # 使用调度系统(例如 k8s/slurm)在训练重启时自动重启机制。 + # 请注意,如果未设置 auto_resume(其默认值为 True),它将不会默认加载 load_ckpt_info 中指定的检查点路径。 + # 如果你想从另一个模型初始化你的模型权重,必须将 auto_resume 设置为 False。 + # 如果你想从头开始训练,请将 auto_resume 设置为 False 并将 'load_ckpt_info' 设置为 None。 + auto_resume=False, + async_upload=True, # 异步检查点上传。(仅适用于 boto3 检查点) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # 异步上传期间临时文件的路径。 + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # 快照检查点保存频率。 ) ``` 注意: -- 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上 +- 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上;若无前缀,为huggingface上可以直接下载的模型路径。 模型相关关键参数配置如下所示: ```python @@ -306,7 +384,7 @@ model = dict( layer_norm_epsilon=1e-5, ) ``` -注意:用户可自定义模型类型名和模型结构,并配置相对应的模型参数。通过`utils/registry.py`下的`MODEL_INITIALIZER`对象进行模型初始化函数接口注册,在训练主函数`train.py`中初始化模型时,可通过`model_type`配置获取指定的模型初始化接口函数。 +注意:用户可自定义模型类型名和模型结构,并配置相对应的模型参数。通过`internlm/model/registry.py`下的`model_initializer`对象进行模型初始化函数接口注册,在训练主函数`train.py`中初始化模型时,可通过`model_type`配置获取指定的模型初始化接口函数。 *如果基于 InternLM 7B继续训练,可以参考 [ModelZoo](https://github.com/InternLM/InternLM/tree/main#model-zoo) 中 OpenXLab 链接下载权重* @@ -315,21 +393,32 @@ model = dict( 训练并行配置样例如下: ```python parallel = dict( - zero1=8, - tensor=1, + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=False, + weight=dict(size=1, overlap=True, memory_pool=True), ) ``` -- zero1:zero 并行策略,分如下三种情况,默认值为 -1 - - 当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 - - 当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 - - 当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集 -- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1 -- pipeline:流水线并行策略 - - size:流水线并行大小,默认值为 1 - - interleaved_overlap:bool 类型,交错式调度时,开启或关闭通信优化,默认值为关闭 -- sequence_parallel:是否开启序列化并行,默认值为 False +- zero1(字典): + 1. size: 整数 + - 当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 + - 当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 + - 当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集 + 2. fsdp: 布尔值,启用/禁用torch的完全分片数据并行,默认为False。 +- tensor(字典): + 1. size: 整数,张量并行的大小。 + 2. mode: 字符串,张量并行模式,应该是 ['mtp', 'msp', 'fsp', 'isp'] 中的一个, + - 默认为 'mtp',意味着没有序列并行的纯Megatron张量并行。 + - msp: 带序列并行的Megatron张量并行,序列并行大小 = 张量并行大小。 + - fsp: 通过flash-attn带序列并行的张量并行,序列并行大小 = 张量并行大小。 + - isp: 定制的内部序列并行,不带张量并行,可以与权重并行一起使用。 +- pipeline(字典): + 1. size: 整数,流水线并行的大小。 + 2. interleaved_overlap: 布尔值,启用/禁用在使用交错流水线调度器时的通信重叠,默认为False。 +- weight(字典): + 1. size: 整数,权重并行的大小。 + 2. overlap: 布尔值,启用/禁用all_gather/reduce_scatter通信重叠,默认为False。 + 3. memory_pool: 布尔值,启用/禁用内存池,默认为False。 注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小` diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index 6bdfd6346..1ba0ac6a3 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -9,6 +9,7 @@ from internlm.core.parallel.shard import partition_uniform from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load +from transformers import AutoModelForCausalLM logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -304,8 +305,22 @@ def load_internlm_with_dynamic_parallel_size(folder, model): ) +def load_hf_model_pretrained_weights(folder, model): + """NOTE: when loading huggingface's model pretrained weights, you should set `adapt_hf=True` in your config.""" + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + pretrained_model = AutoModelForCausalLM.from_pretrained(folder, trust_remote_code=True) + model.load_state_dict(pretrained_model.state_dict(), strict=False) + + if gpc.is_rank_for_log(): + logger.info("Pretrained weights loaded successfully") + + LOAD_FUNC_DICT = { "llama": load_llama_pretrained_weights, "hf_llama": load_hf_llama_pretrained_weights, "internlm_test": load_internlm_with_dynamic_parallel_size, + "hf_model": load_hf_model_pretrained_weights, } diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 14637912b..71dde3a35 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -258,8 +258,12 @@ def __init__( def _parse_model_structure(self, cid: int, model: nn.Module) -> None: self._overlap_states[cid] = ISPOverlapState() + def get_model(obj: nn.Module) -> nn.Module: + return get_model(obj.model) if hasattr(obj, "model") else obj + # Important: only works for llama-class models - for _, children in model.named_children(): + children_name = get_model(model).named_children() + for _, children in children_name: if isinstance(children, nn.ModuleList): self._overlap_states[cid].ckpt_block_num = int(self.model_conf.activation_checkpointing * len(children)) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py new file mode 100644 index 000000000..8933a5df6 --- /dev/null +++ b/internlm/core/trainer_builder.py @@ -0,0 +1,333 @@ +import gc +import logging +import time +from functools import partial +from typing import Dict, Optional + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader + +from internlm.checkpoint.checkpoint_manager import CheckpointManager +from internlm.core.context import global_context as gpc +from internlm.core.context.process_group_initializer import ParallelMode +from internlm.core.trainer import Trainer +from internlm.data.streaming.utils import hf_simple_resume +from internlm.data.train_state import get_train_state +from internlm.eval.evaluation import evaluate_on_val_dls +from internlm.initialize.initialize_trainer import initialize_trainer +from internlm.model.losses.ce_loss import FlashGPTLMLoss +from internlm.model.metrics import AccPerplex +from internlm.monitor.monitor import send_alert_message +from internlm.train.pipeline import ( + get_scheduler_hooks, + initialize_llm_profile, + initialize_optimizer, + initialize_parallel_communicator, + load_new_batch, + record_current_batch_training_metrics, +) +from internlm.utils.common import ( + BatchSkipper, + enable_pytorch_expandable_segments, + get_current_device, + get_megatron_flops, + launch_time, +) +from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler +from internlm.utils.writer import Writer + +# global llm logger +logger = logging.getLogger(__file__) + + +class TrainerBuilder(Trainer): + """ + Manage InternEvo training process. + + Args: + model (torch.nn.Module): The model to be trained. + train_dl (torch.utils.data.DataLoader): The training data loader. + val_dls (Optional[Dict[str, torch.utils.data.DataLoader]]): The validation data loaders. + kwargs: Additional keyward arguments. + """ + + def __init__( + self, + model: torch.nn.Module, + train_dl: DataLoader, + val_dls: Optional[Dict[str, DataLoader]] = None, + **kwargs, + ): + """ + Initialize InternEvo TrainerBuilder class. + + Args: + model (torch.nn.Module): The model to be trained. + train_dl (torch.utils.data.DataLoader): The training data loader. + val_dls (Optional[Dict[str, torch.utils.data.DataLoader]]): The validation data loaders. + kwargs: Additional keyward arguments. + """ + + # record very_begining_time + very_begining_time = time.time() + + # set torch expandable_segments + enable_pytorch_expandable_segments() + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + dist.broadcast_object_list(objs, src=0) + current_time = objs[0].replace(":", ".") + global logger + logger = get_logger( + __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() + ) + + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + with open(kwargs["config"], "r") as f: + config_lines = f.readlines() + + # initialize loss function + criterion = FlashGPTLMLoss( + parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing + ) + + # initialize and resume train state + train_state = get_train_state(train_dl) + + # initialize optimizer + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + + # initialize checkpoint manager + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + model_config_file="".join(config_lines), + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + + # load other persistent training states + ckpt_manager.try_resume_training(train_state, current_time) + + # initialize customed llm writer + writer = Writer( + job_name=gpc.config.JOB_NAME, + launch_time=current_time, + file_name=get_parallel_log_file_name(), + tensorboard_folder=gpc.config.tensorboard_folder, + resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt. + step_count=train_state.step_count, # resume from ckpt. + config=config_lines, + logger=logger, + enable_tb=gpc.config.enable_tb, + queue_max_length=gpc.config.tensorboard.queue_max_length, + total_steps=gpc.config.data.total_steps, + ) + + # initialize metric for calculating accuracy and perplexity + metric = AccPerplex( + device=get_current_device(), + tp_pg=gpc.get_group(ParallelMode.TENSOR), + dp_pg=gpc.get_group(ParallelMode.DATA), + dataset_types=kwargs["dataset_types"], + ) + + # initialize simple memory profiler + if kwargs["profiling"]: + self.memory_profiler = SimpleMemoryProfiler( + model, + optimizer.optim, + log_folder=f"RUN/{gpc.config.JOB_NAME}/{current_time}/memory_trace/rank{gpc.get_global_rank()}_" + + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" + + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", + ) + else: + self.memory_profiler = None + + # initialize batch skipper + skip_batches = gpc.config.data.skip_batches + if gpc.config.data.type == "hf" and gpc.config.ckpt.auto_resume: + skip_batches = hf_simple_resume(train_state) + self.batch_skipper = BatchSkipper(skip_batches) + + # set TrainerBuilder attributes + self.very_begining_time = very_begining_time + self.profiling = kwargs["profiling"] + self.current_time = current_time + self.train_dl = train_dl + self.val_dls = val_dls + self.train_state = train_state + self.optimizer = optimizer + self.beta2_scheduler = beta2_scheduler + self.isp_communicator = isp_communicator + self.writer = writer + self.ckpt_manager = ckpt_manager + self.metric = metric + + # initialize trainer + engine, scheduler = initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), + ) + + super().__init__(engine, scheduler) + + def fit(self): + """ + Launch InternEvo TrainerBuilder training process. + """ + + self.train() + train_iter = iter(self.train_dl) + + with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: + # close automatic garbage collection + gc.disable() + # start iterating the train data and begin training + for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): + empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) + # internlm_accelerator.memory._record_memory_history() + start_time = time.time() + timer("one-batch").start() + + # load batch data + batch, train_iter = load_new_batch( + train_dl=self.train_dl, train_iter=train_iter, train_state=self.train_state + ) + + # record the consumed samples in training + self.train_state.batch_count = batch_count + self.train_state.num_consumed_samples_in_epoch += len(batch[1]) + if self.batch_skipper(batch_count): # skip this batch + if gpc.is_rank_for_log(): + logger.info(f"Skip batch count:`{batch_count}`...") + timer("one-batch").stop() + continue + + # zero the grads of parameters + self.zero_grad() + # process data + if batch[0].get("type_ids", None) is not None: + self.metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + # if batch[0].get("cu_seqlens", None) is not None: + # metric.set_cu_seqlens(cu_seqlens=batch[0].pop("cu_seqlens", None)) + + # do forward and backward + timer("fwd-bwd").start() + + moe_loss = None + if hasattr(gpc.config.model, "num_experts"): + _, _, loss, moe_loss = self.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + else: + _, _, loss = self.execute_schedule( # pylint: disable=W0632 + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + timer("fwd-bwd").stop() + + if self.isp_communicator and self.isp_communicator.enable_memory_pool: + self.isp_communicator.memory_pool.reset_lazy_pools() + + # update parameters, and returns (success_update, grad_norm) + trainer_result = self.step() + assert trainer_result is not None + + success_update, grad_norm_groups = trainer_result + if success_update: # update parameters successfully + self.train_state.step_count += 1 + else: + self.train_state.inf_nan_skip_batches += ( + 1 # record the amount of updating parameters unsuccessfully. + ) + if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case + logger.warning(f"Warning: skip parameter update at step {batch_count}.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=f"Warning: skip parameter update at step {batch_count}.", + ) + + get_tflops_func = partial( + get_megatron_flops, + checkpoint=gpc.config.model.checkpoint, + seq_len=gpc.config.data["seq_len"], + hidden_size=gpc.config.model.hidden_size, + num_layers=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + global_batch_size=gpc.config.data.micro_bsz + * gpc.config.data.micro_num + * gpc.get_world_size(ParallelMode.DATA), + global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), + mlp_ratio=gpc.config.model["mlp_ratio"], + ) + + # calculate and record the training metrics, eg. loss, accuracy and so on. + record_current_batch_training_metrics( + get_tflops_func=get_tflops_func, + logger=logger, + writer=self.writer, + success_update=success_update, + batch_count=batch_count, + batch=batch, + train_state=self.train_state, + optimizer=self.optimizer, + beta2_scheduler=self.beta2_scheduler, + trainer=self, + start_time=start_time, + very_begining_time=self.very_begining_time, + loss=loss, + moe_loss=moe_loss, + grad_norm=grad_norm_groups, + metric=self.metric, + ) + + timer("one-batch").stop() + + # evaluate on validation data loaders + if gpc.config.data.valid_every > 0 and self.train_state.step_count % gpc.config.data.valid_every == 0: + evaluate_on_val_dls( + self, + val_dls=self.val_dls, + writer=self.writer, + logger=logger, + step_count=self.train_state.step_count, + ) + + # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" + # # save batch sampler that tracks the true consumed samples + now_break = self.ckpt_manager.try_save_checkpoint(self.train_state) + if now_break: + break + + if self.memory_profiler is not None: + self.memory_profiler.step() + + if batch_count % 2 == 0: + prof.step() + + # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + + self.ckpt_manager.wait_async_upload_finish() diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index c2c0ea690..5af73b84b 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -6,6 +6,12 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.data.streaming.batch_sampler import StreamingStaticBatchSampler +from internlm.data.streaming.collaters import nopack_collate_fn, pack_collate_fn +from internlm.data.streaming.dataset import ( + HuggingFacePackedDataset, + HuggingFaceStreamingDataset, +) from internlm.data.tokenized.batch_sampler import ( StaticBatchSampler, get_dpsampler_dataloader, @@ -108,6 +114,31 @@ def get_tokenized_valid_loader_items(data_cfg): return valid_ds, valid_collate_fn +def get_hf_train_loader_items(data_cfg): + train_ds = HuggingFaceStreamingDataset( + dataset_name=data_cfg.train_folder, + tokenizer_name=data_cfg.tokenizer_path, + model_max_length=data_cfg.seq_len, + subset_name=data_cfg.get("subset_name", None), + ) + if gpc.config.model_type == "hf" and not data_cfg.use_packed_dataset: + train_sampler = StreamingStaticBatchSampler( + batch_size=data_cfg.micro_num * data_cfg.micro_bsz, rampup_batch_size=data_cfg.rampup_batch_size + ) + train_collate_fn = partial( + nopack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + ) + else: + train_ds = HuggingFacePackedDataset(dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz) + train_sampler = StreamingStaticBatchSampler( + batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size + ) + train_collate_fn = partial( + pack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + ) + return train_ds, train_sampler, train_collate_fn + + def build_train_loader_with_data_type(): """ Build and return the training data loader based on data type. @@ -115,11 +146,15 @@ def build_train_loader_with_data_type(): Returns: A tuple of (train_dl, dataset_types). """ data_cfg = gpc.config.data + train_folder = data_cfg.get("train_folder", None) - dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"] if data_cfg.type == "tokenized": train_ds, train_sampler, train_collate_fn = get_tokenized_train_loader_items(data_cfg) + dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"] + elif data_cfg.type == "hf": + train_ds, train_sampler, train_collate_fn = get_hf_train_loader_items(data_cfg) + dataset_types = ["en"] else: raise ValueError(f"dataset type {data_cfg.type} is not supported") @@ -141,7 +176,7 @@ def build_valid_loader_with_data_type(): data_cfg = gpc.config.data - if data_cfg.type == "tokenized": + if data_cfg.type in ["tokenized", "hf"]: valid_ds, valid_collate_fn = get_tokenized_valid_loader_items(data_cfg) else: raise ValueError(f"dataset type {data_cfg.type} is not supported") diff --git a/internlm/data/streaming/__init__.py b/internlm/data/streaming/__init__.py new file mode 100644 index 000000000..513e3243b --- /dev/null +++ b/internlm/data/streaming/__init__.py @@ -0,0 +1,13 @@ +from .batch_sampler import StreamingStaticBatchSampler +from .collaters import nopack_collate_fn, pack_collate_fn +from .dataset import HuggingFacePackedDataset, HuggingFaceStreamingDataset +from .utils import hf_simple_resume + +__all__ = [ + "StreamingStaticBatchSampler", + "nopack_collate_fn", + "pack_collate_fn", + "HuggingFaceStreamingDataset", + "HuggingFacePackedDataset", + "hf_simple_resume", +] diff --git a/internlm/data/streaming/batch_sampler.py b/internlm/data/streaming/batch_sampler.py new file mode 100644 index 000000000..11f9bb8b0 --- /dev/null +++ b/internlm/data/streaming/batch_sampler.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import sys +from typing import Optional + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +class StreamingStaticBatchSampler: + """ + StreamingStaticBatchSampler is used for the training process. + """ + + def __init__(self, batch_size: int = 1, rampup_batch_size: Optional[str] = None, micro_bsz: int = 1): + if rampup_batch_size: + start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split()) + else: + start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1 + + self.raw_rampup_batch_size = rampup_batch_size + self.start_bsz = start_bsz + self.bsz_incre = bsz_incre + self.incre_every = incre_every + + if gpc.is_initialized(ParallelMode.PIPELINE): + assert ( + batch_size - self.start_bsz + ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}" + assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})" + assert ( + self.start_bsz % micro_bsz == 0 + ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})" + assert ( + self.bsz_incre % micro_bsz == 0 + ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})" + + self.batch_size = batch_size + self.num_consumed_samples_in_epoch = 0 + self.batch_count = 0 + + def __len__(self): + return sys.maxsize + + def __iter__(self): + while True: + batch_rampup_idx = self.batch_count // self.incre_every + cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz + cur_batch_size = min(cur_batch_size, self.batch_size) + + self.num_consumed_samples_in_epoch += cur_batch_size + self.batch_count += 1 + yield [0] * cur_batch_size + + def state_dict(self): + states = { + "batch_size": self.batch_size, + "raw_rampup_batch_size": self.raw_rampup_batch_size, + "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, + "batch_count": self.batch_count, + } + return states + + def load_state_dict(self, states): + for name in ("raw_rampup_batch_size",): # 'batch_size' + assert states[name] == getattr(self, name), (name, states[name], getattr(self, name)) # should not change + self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"] + self.batch_count = states["batch_count"] + + def copy(self): + copy_sampler = StreamingStaticBatchSampler(self.batch_size, self.raw_rampup_batch_size) + + copy_sampler.load_state_dict(self.state_dict()) + return copy_sampler diff --git a/internlm/data/streaming/collaters.py b/internlm/data/streaming/collaters.py new file mode 100644 index 000000000..4391fd236 --- /dev/null +++ b/internlm/data/streaming/collaters.py @@ -0,0 +1,58 @@ +import torch + + +def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len): + input_ids_list = [] + attention_mask_list = [] + labels_list = [] + for b in batch: + attention_mask = torch.tensor(b["attention_mask"]) + input_ids = torch.LongTensor(b["input_ids"]) + input_ids = torch.abs(input_ids * attention_mask) + input_ids = torch.nn.functional.pad(input_ids, (0, seq_len - len(input_ids)), mode="constant", value=0) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, seq_len - len(attention_mask)), mode="constant", value=0 + ) + label = torch.LongTensor([w if w > 0 else -100 for w in input_ids.tolist()][1:] + [-100]) + input_ids_list.append(input_ids) + attention_mask_list.append(attention_mask) + labels_list.append(label) + input_ids = torch.stack(input_ids_list) + attention_mask = torch.stack(attention_mask_list) + labels = torch.stack(labels_list) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "type_ids": torch.zeros(micro_num, micro_bsz, seq_len, dtype=torch.int64), + }, labels + + +def pack_collate_fn(batch, micro_num, micro_bsz, seq_len): + packed_length = micro_bsz * seq_len + + input_ids_list = [] + cu_seqlens_list = [] + indexes_list = [] + labels_list = [] + + for b in batch: + assert len(b["input_ids"]) == packed_length + assert b["cu_seqlens"][0] == 0 and b["cu_seqlens"][-1] == packed_length + assert len(b["indexes"]) == packed_length + assert len(b["labels"]) == packed_length + + input_ids_list.append(torch.LongTensor(b["input_ids"])) + cu_seqlens_list.append(torch.IntTensor(b["cu_seqlens"])) + indexes_list.append(torch.IntTensor(b["indexes"])) + labels_list.append(torch.LongTensor(b["labels"])) + + input_ids = torch.stack(input_ids_list) + indexes = torch.stack(indexes_list) + labels = torch.stack(labels_list) + + return { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens_list, + "indexes": indexes, + "type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64), + }, labels diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py new file mode 100644 index 000000000..a3844d706 --- /dev/null +++ b/internlm/data/streaming/dataset.py @@ -0,0 +1,119 @@ +import itertools +import sys + +import datasets +import numpy as np +from datasets.distributed import split_dataset_by_node +from torch.utils.data import Dataset + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from transformers import AutoTokenizer + + +class HuggingFaceStreamingDataset(Dataset): + """ + Streaming and on-the-fly tokenized dataset for huggingface + """ + + def __init__( + self, dataset_name, tokenizer_name, model_max_length, split="train", buffer_size=1000, subset_name=None + ): + self.dataset = datasets.load_dataset(dataset_name, data_dir=subset_name, split=split, streaming=True) + self.dataset = split_dataset_by_node( + self.dataset, rank=gpc.get_local_rank(ParallelMode.DATA), world_size=gpc.get_world_size(ParallelMode.DATA) + ) + self.buffer_size = buffer_size + self.senior_iterator = iter(self) + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + self.tokenizer.model_max_length = model_max_length + + def __iter__(self): + buffer = [] + for sample in self.dataset: + buffer.append(sample) + if len(buffer) >= self.buffer_size: + yield from self._tokenize(buffer) + buffer = [] + + if buffer: + yield from self._tokenize(buffer) + + def __len__(self): + return sys.maxsize + + def _tokenize(self, samples): + texts = [sample["text"] for sample in samples] + tokenized_outputs = self.tokenizer(texts, truncation=True) + for i in range(len(samples)): + yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} + + def __getitem__(self, _): + return next(self.senior_iterator) + + +class HuggingFacePackedDataset(Dataset): + """ + Simple packed dataset for huggingface. + """ + + def __init__(self, dataset, seq_len, micro_bsz): + self.dataset = dataset + self.seq_len = seq_len + self.micro_bsz = micro_bsz + + self.senior_iterator = iter(self) + + def __iter__(self): + input_ids = [] + cu_seqlens = [0] + labels = [] + for sample in self.dataset: + if len(input_ids + sample["input_ids"]) > self.micro_bsz * self.seq_len: + assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len + input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids)) + cu_seqlens = ( + cu_seqlens + [self.micro_bsz * self.seq_len] + if cu_seqlens[-1] < self.micro_bsz * self.seq_len + else cu_seqlens + ) + labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels)) + yield { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + } + input_ids = sample["input_ids"] + cu_seqlens = [0, len(sample["input_ids"])] + labels = sample["input_ids"][1:] + [-100] + else: + input_ids = input_ids + sample["input_ids"] + cu_seqlens.append(len(sample["input_ids"]) + cu_seqlens[-1]) + labels = labels + sample["input_ids"][1:] + [-100] + if input_ids: + assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len + input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids)) + cu_seqlens = ( + cu_seqlens + [self.micro_bsz * self.seq_len] + if cu_seqlens[-1] < self.micro_bsz * self.seq_len + else cu_seqlens + ) + labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels)) + yield { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + } + + def __len__(self): + return sys.maxsize + + def __getitem__(self, _): + return next(self.senior_iterator) diff --git a/internlm/data/streaming/utils.py b/internlm/data/streaming/utils.py new file mode 100644 index 000000000..ee331ba22 --- /dev/null +++ b/internlm/data/streaming/utils.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from internlm.core.context import global_context as gpc + + +# simple auto_resume for huggingface streaming dataloader +def hf_simple_resume(train_state): + skip_batches = gpc.config.data.get("skip_batches", "") + if train_state.batch_count > 0: + assert skip_batches == "", "skip_batches should be empty when huggingface dataloader resume from ckpts" + skip_batches = f"0-{train_state.batch_count - 1}" + train_state.batch_count = 0 + train_state.num_consumed_samples_in_epoch = 0 + if hasattr(train_state, "batch_sampler"): + train_state.batch_sampler.batch_count = 0 + train_state.batch_sampler.num_consumed_samples_in_epoch = 0 + train_state.batch_sampler_iter = iter(train_state.batch_sampler) + return skip_batches diff --git a/internlm/data/train_state.py b/internlm/data/train_state.py index 6564a4df0..cd1cc8a1e 100644 --- a/internlm/data/train_state.py +++ b/internlm/data/train_state.py @@ -5,7 +5,7 @@ def get_train_state(dataloader): # initialize and resume train state - if gpc.config.data.type == "tokenized": + if gpc.config.data.type in ["tokenized", "hf"]: train_state = TrainState(gpc.config, dataloader.batch_sampler) else: raise ValueError(f"dataset type {gpc.config.data.type} is not supported") diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 91585a707..19e74ae2d 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -51,6 +51,10 @@ def unpack_type_ids(type_ids, cu_seqlens): def unpack_data(data, label): + + if gpc.config.model_type == "hf": + return data, label + data["input_ids"] = _unpack_data(data["input_ids"], data["cu_seqlens"], padding_v=0).squeeze(0) label = _unpack_data(label, data["cu_seqlens"], padding_v=-100).squeeze(0) @@ -73,4 +77,12 @@ def packed_data_normalizer(data, label): if gpc.config.parallel.sequence_parallel and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp": data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=0) + if gpc.config.model_type == "hf": + data.pop("cu_seqlens") + data.pop("max_seqlen") + data["position_ids"] = data.pop("indexes") + data["attention_mask"] = torch.ones( + (data["input_ids"].shape), dtype=torch.bool, device=data["input_ids"].device + ) + return data, label diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index b90a25e9f..7e440528a 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -3,7 +3,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple from torch import nn from torch.nn.modules.loss import _Loss @@ -32,8 +32,6 @@ def initialize_trainer( model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, - train_dataloader: Optional[Iterable] = None, - test_dataloader: Optional[Iterable] = None, lr_scheduler: Optional[_LRScheduler] = None, beta2_scheduler: Optional[Beta2Scheduler] = None, scheduler_hooks: Optional[List[SchedulerHook]] = None, @@ -45,14 +43,10 @@ def initialize_trainer( model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model. optimizer (:class:`BaseOptimizer`): Your optimizer for training. criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. - train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. - test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. Returns: - Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler): - A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)`` - where only ``trainer`` could not be None. + Tuple (engine, scheduler) """ if isinstance(model, nn.Module): @@ -131,6 +125,4 @@ def initialize_trainer( clip_grad_norm=clip_grad_norm, ) - trainer = Trainer(engine, scheduler) - - return trainer, train_dataloader, test_dataloader, lr_scheduler + return engine, scheduler diff --git a/internlm/model/builder.py b/internlm/model/builder.py index 2b10406bb..c8adcd41f 100644 --- a/internlm/model/builder.py +++ b/internlm/model/builder.py @@ -5,7 +5,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper -from internlm.model.registry import model_initializer +from internlm.model.registry import hf_config_initializer, model_initializer from internlm.utils.common import get_current_device @@ -24,10 +24,15 @@ def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module model_buidler = model_initializer.get_module(module_name=model_type) if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): - kwargs["first"] = kwargs["last"] = True - kwargs["start_layer_idx"] = 0 - kwargs["num_layers"] = num_layers - model = model_buidler(*args, **kwargs).to(kwargs["device"]) + if model_type == "hf": + hf_config_builder = hf_config_initializer.get_module(module_name=model_type) + config = hf_config_builder(return_dict=False) + model = model_buidler(*args, config).to(kwargs["device"]) + else: + kwargs["first"] = kwargs["last"] = True + kwargs["start_layer_idx"] = 0 + kwargs["num_layers"] = num_layers + model = model_buidler(*args, **kwargs).to(kwargs["device"]) setattr(model, "first_layer", 0) setattr(model, "last_layer", num_layers) else: diff --git a/internlm/model/registry.py b/internlm/model/registry.py index e91a22551..01f02dc1f 100644 --- a/internlm/model/registry.py +++ b/internlm/model/registry.py @@ -73,6 +73,7 @@ def has(self, module_name: str): model_initializer = Registry("model_initializer") +hf_config_initializer = Registry("hf_config_initializer") def register_model_initializer() -> None: diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py index 56c8309bd..2bcfa2ccf 100644 --- a/internlm/monitor/__init__.py +++ b/internlm/monitor/__init__.py @@ -1,8 +1,9 @@ -from .monitor import initialize_monitor_manager, send_alert_message +from .monitor import initialize_monitor_manager, internevo_monitor, send_alert_message from .utils import set_env_var __all__ = [ "send_alert_message", "initialize_monitor_manager", "set_env_var", + "internevo_monitor", ] diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index cca9ca448..fc33de62a 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -1,17 +1,59 @@ import fcntl +import logging import os +import shutil import signal import socket import time +import traceback from contextlib import contextmanager +from functools import wraps from threading import Thread +from internlm.accelerator.abstract_accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.monitor.alert import send_feishu_msg_with_webhook from internlm.utils.common import SingletonMeta from .utils import get_job_key, set_env_var +logger = logging.getLogger(__file__) +internlm_accelerator = get_accelerator() + + +def internevo_monitor(feishu_alert=True, clean_run=True): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if feishu_alert: + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): + return execute_with_exception_handling(func, *args, **kwargs) + else: + return execute_with_exception_handling(func, *args, **kwargs) + + def execute_with_exception_handling(func, *args, **kwargs): + if not clean_run: + return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except Exception: + hostname = socket.gethostname() + logger.error( + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", + ) + finally: + devices_per_node = internlm_accelerator.device_count() + local_rank = gpc.get_global_rank() % devices_per_node + if gpc.config.data.use_shm and local_rank == 0: + if os.path.exists(gpc.config.data.shm_path): + shutil.rmtree(gpc.config.data.shm_path) + + return wrapper + + return decorator + def send_alert_message(address: str = None, title: str = None, message: str = None): """ diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index bdafdedb6..0c5156615 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -145,10 +145,21 @@ def _check_module(name, module): for param in module.parameters(): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + def _check_module_hf(_, module): + # TODO: check parallel attribute for hf model + for param in module.parameters(): + if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): + setattr(param, IS_TENSOR_DATA_PARALLEL, True) + elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): + setattr(param, IS_TENSOR_ZERO_PARALLEL, True) + for _chunk in unwrap_naive_amp(model): # set param parallel attribute for name, module in _chunk.named_modules(): - _check_module(name, module) + if gpc.config.model_type == "hf": + _check_module_hf(name, module) + else: + _check_module(name, module) for name, param in _chunk.named_parameters(): assert ( @@ -464,7 +475,8 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai if batch[0].get("type_ids", None) is not None: # if use_packed_dataset is False, we need to unpack type_ids if not gpc.config.data.use_packed_dataset: - batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) + if gpc.config.data.type != "hf" or gpc.config.model_type != "hf": + batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) return batch, train_iter @@ -554,10 +566,17 @@ def record_current_batch_training_metrics( num_tokens_in_batch = batch[1].nelement() real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL)) - num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) - max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + # TODO: check logic + if gpc.config.data.type == "hf" and gpc.config.model_type == "hf" and not gpc.config.data.use_packed_dataset: + num_samples_in_batch = gpc.config.data.micro_bsz * gpc.config.data.micro_num + max_length_in_batch = batch[0]["attention_mask"].sum(dim=1).max().item() + max_samples_in_batch = gpc.config.data.micro_bsz + min_samples_in_batch = gpc.config.data.micro_bsz + else: + num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) + max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) time_cost = time.time() - start_time tk_per_gpu = round( num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL), diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 563ea69cc..6aa1ebd12 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -4,6 +4,8 @@ import multiprocessing import os +from internlm.utils.common import SingletonMeta + if "USE_DILL_PICKLE" in os.environ: import dill @@ -964,23 +966,6 @@ def check_tmp_folder_accessibility(tmp_local_folder: str): raise RuntimeError(error_str) -class SingletonMeta(type): - """ - Singleton Meta. - """ - - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - else: - assert ( - len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." - return cls._instances[cls] - - class StorageManager(metaclass=SingletonMeta): """ Storage Manager for saving or loading checkpoint. diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 8416d49fc..595a31bd3 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,4 +1,4 @@ -transformers<4.30.0 +transformers sentencepiece numpy tqdm diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py index 3b7fffd05..3ccbfb54d 100644 --- a/tests/test_infer/test_trainer_generate.py +++ b/tests/test_infer/test_trainer_generate.py @@ -7,7 +7,7 @@ from internlm.apis.inference import SequenceGenerator, batch_tokenize from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState # noqa: E402 +from internlm.core.trainer import TrainState, Trainer # noqa: E402 from internlm.data import build_train_loader_with_data_type # noqa: E402 from internlm.initialize import initialize_distributed_env # noqa: E402 from internlm.model.losses import FlashGPTLMLoss # noqa: E402 @@ -47,15 +47,15 @@ def setup_generator(config, tokenizer): ckpt_manager.try_resume_training(train_state) # initialize trainer - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(None, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) trainer.schedule.data_process_func = None diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index 00025a7b5..4c36ad87a 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -12,6 +12,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.initialize.launch import args_sanity_check from internlm.model.losses import FlashGPTLMLoss @@ -197,15 +198,15 @@ def train_check_output(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) # transfer the train data loader into train data iterator trainer.train() diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index 0cd221458..45cd319c4 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -29,6 +29,7 @@ ) from internlm.core.trainer import ( # noqa: E402 #pylint: disable=wrong-import-position TrainState, + Trainer, ) from internlm.data import ( # noqa: E402 #pylint: disable=wrong-import-position build_train_loader_with_data_type, @@ -265,15 +266,15 @@ def train_model(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) trainer.train() train_iter = iter(train_dl) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index e69073538..fa8147cde 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -10,7 +10,7 @@ from internlm.checkpoint import CheckpointManager from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.trainer import TrainState +from internlm.core.trainer import TrainState, Trainer from internlm.data import build_train_loader_with_data_type from internlm.initialize import initialize_distributed_env from internlm.model.losses import FlashGPTLMLoss @@ -193,15 +193,15 @@ def train( ) # initialize trainer - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) # initialize the batch skipper batch_skipper = BatchSkipper(skip_batches) diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index d430c16a8..afc1c4934 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -6,6 +6,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex @@ -70,15 +71,15 @@ def train_check(args): dataset_types=dataset_types, ) - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) # transfer the train data loader into train data iterator trainer.train() diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 848cf7402..98b3093dc 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -9,6 +9,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex @@ -90,15 +91,15 @@ def train_check_norm_weight(args): dataset_types=dataset_types, ) - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) # transfer the train data loader into train data iterator trainer.train() diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index f6e523823..92f09ada9 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -14,6 +14,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.core.trainer import Trainer from internlm.data import ( build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -302,15 +303,15 @@ def exam_loss(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) trainer.train() diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index ab027e1c6..4e9ab7490 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -20,7 +20,7 @@ from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState # noqa: E402 +from internlm.core.trainer import TrainState, Trainer # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -181,15 +181,15 @@ def main(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) # initialize simple memory profiler if args.profiling: diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index b8b56ec66..023b085c1 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -12,7 +12,7 @@ from internlm.model.registry import register_model_initializer from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer from internlm.train.utils import create_param_groups -from internlm.utils.storage_manager import SingletonMeta +from internlm.utils.common import SingletonMeta OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None) OSS_IP = os.environ.get("OSS_IP", None) diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 65325d329..5fe8b3c49 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -13,7 +13,8 @@ from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer -from internlm.utils.storage_manager import SingletonMeta, wait_async_upload_finish +from internlm.utils.common import SingletonMeta +from internlm.utils.storage_manager import wait_async_upload_finish from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import ASYNC_TMP_FOLDER, BOTO_SAVE_PATH, diff --git a/train.py b/train.py index d218b4751..085344202 100644 --- a/train.py +++ b/train.py @@ -1,329 +1,45 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import gc -import logging -import os -import shutil -import socket -import time -import traceback -from functools import partial - -import torch.distributed as dist - -import internlm -from internlm.accelerator import get_accelerator -from internlm.checkpoint import CheckpointManager -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.trainer_builder import TrainerBuilder from internlm.data import ( build_train_loader_with_data_type, build_valid_loader_with_data_type, ) -from internlm.data.train_state import get_train_state -from internlm.eval.evaluation import evaluate_on_val_dls from internlm.initialize import initialize_distributed_env -from internlm.model.losses import FlashGPTLMLoss -from internlm.model.metrics import AccPerplex -from internlm.monitor import initialize_monitor_manager, send_alert_message -from internlm.monitor.monitor import monitor_manager as mm -from internlm.train import ( - get_scheduler_hooks, - initialize_llm_profile, - initialize_model, - initialize_optimizer, - initialize_parallel_communicator, - load_new_batch, - record_current_batch_training_metrics, -) -from internlm.utils.common import ( - BatchSkipper, - enable_pytorch_expandable_segments, - get_current_device, - get_megatron_flops, - launch_time, - parse_args, -) -from internlm.utils.gputest import empty_cache_and_diag -from internlm.utils.logger import get_logger -from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.parallel import get_parallel_log_file_name -from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler -from internlm.utils.writer import Writer - -# global llm logger -logger = logging.getLogger(__file__) -internlm_accelerator = get_accelerator() +from internlm.monitor import internevo_monitor +from internlm.train import initialize_model +from internlm.utils.common import parse_args +@internevo_monitor(feishu_alert=True, clean_run=True) def main(args): - very_begining_time = time.time() - enable_pytorch_expandable_segments() - - # init setting - skip_batches = gpc.config.data.skip_batches - total_steps = gpc.config.data.total_steps - valid_every = gpc.config.data.valid_every - label_smoothing = gpc.config.loss.label_smoothing - - get_tflops_func = partial( - get_megatron_flops, - checkpoint=gpc.config.model.checkpoint, - seq_len=gpc.config.data["seq_len"], - hidden_size=gpc.config.model.hidden_size, - num_layers=gpc.config.model.num_layers, - vocab_size=gpc.config.model.vocab_size, - global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), - global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), - mlp_ratio=gpc.config.model["mlp_ratio"], - ) - - # get and broadcast current time - current_time = launch_time() - objs = [current_time] - dist.broadcast_object_list(objs, src=0) - current_time = objs[0].replace(":", ".") - global logger - logger = get_logger( - __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() - ) - # initialize model model = initialize_model() - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) - - with open(args.config, "r") as f: - config_lines = f.readlines() - - # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) - - # initialize the train and validation data loader + # initialize train dataloader train_dl, dataset_types = build_train_loader_with_data_type() - val_dls = build_valid_loader_with_data_type() - - # initialize and resume train state - train_state = get_train_state(train_dl) - - optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) - - ckpt_manager = CheckpointManager( - ckpt_config=gpc.config.ckpt, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - train_dl=train_dl, - model_config=gpc.config.model, - model_config_file="".join(config_lines), - feishu_address=gpc.config.monitor.alert.feishu_alert_address, - ) - - # Loading other persistent training states. - ckpt_manager.try_resume_training(train_state, current_time) - - # initialize customed llm writer - writer = Writer( - job_name=gpc.config.JOB_NAME, - launch_time=current_time, - file_name=get_parallel_log_file_name(), - tensorboard_folder=gpc.config.tensorboard_folder, - resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt. - step_count=train_state.step_count, # resume from ckpt. - config=config_lines, - logger=logger, - enable_tb=gpc.config.enable_tb, - queue_max_length=gpc.config.tensorboard.queue_max_length, - total_steps=total_steps, - ) - - # initialize metric for calculating accuracy and perplexity - metric = AccPerplex( - device=get_current_device(), - tp_pg=gpc.get_group(ParallelMode.TENSOR), - dp_pg=gpc.get_group(ParallelMode.DATA), - dataset_types=dataset_types, - ) - - # initialize trainer - trainer, train_dl, _, _ = internlm.initialize_trainer( - model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dl, - lr_scheduler=lr_scheduler, - beta2_scheduler=beta2_scheduler, - scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), - ) - # initialize simple memory profiler - if args.profiling: - memory_profiler = SimpleMemoryProfiler( - model, - optimizer.optim, - log_folder=f"RUN/{gpc.config.JOB_NAME}/{current_time}/memory_trace/rank{gpc.get_global_rank()}_" - + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" - + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" - + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", - ) - else: - memory_profiler = None - - # initialize the batch skipper - batch_skipper = BatchSkipper(skip_batches) - - trainer.train() - - # transfer the train data loader into train data iterator - train_iter = iter(train_dl) - - with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: - # close automatic garbage collection - gc.disable() - # start iterating the train data and begin training - for batch_count in range(train_state.batch_count, total_steps): - empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) - # internlm_accelerator.memory._record_memory_history() - start_time = time.time() - timer("one-batch").start() - - # load batch data - batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) - - # record the consumed samples in training - train_state.batch_count = batch_count - train_state.num_consumed_samples_in_epoch += len(batch[1]) - if batch_skipper(batch_count): # skip this batch - if gpc.is_rank_for_log(): - logger.info(f"Skip batch count:`{batch_count}`...") - timer("one-batch").stop() - continue - - # zero the grads of parameters - trainer.zero_grad() - # process data - if batch[0].get("type_ids", None) is not None: - metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) - # if batch[0].get("cu_seqlens", None) is not None: - # metric.set_cu_seqlens(cu_seqlens=batch[0].pop("cu_seqlens", None)) - - # do forward and backward - timer("fwd-bwd").start() - - moe_loss = None - if hasattr(gpc.config.model, "num_experts"): - _, _, loss, moe_loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - else: - _, _, loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - timer("fwd-bwd").stop() - - if isp_communicator and isp_communicator.enable_memory_pool: - isp_communicator.memory_pool.reset_lazy_pools() - - # update parameters, and returns (success_update, grad_norm) - trainer_result = trainer.step() - assert trainer_result is not None - - success_update, grad_norm_groups = trainer_result - if success_update: # update parameters successfully - train_state.step_count += 1 - else: - train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. - if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case - logger.warning(f"Warning: skip parameter update at step {batch_count}.") - send_alert_message( - address=gpc.config.monitor.alert.feishu_alert_address, - message=f"Warning: skip parameter update at step {batch_count}.", - ) - - # calculate and record the training metrics, eg. loss, accuracy and so on. - record_current_batch_training_metrics( - get_tflops_func=get_tflops_func, - logger=logger, - writer=writer, - success_update=success_update, - batch_count=batch_count, - batch=batch, - train_state=train_state, - optimizer=optimizer, - beta2_scheduler=beta2_scheduler, - trainer=trainer, - start_time=start_time, - very_begining_time=very_begining_time, - loss=loss, - moe_loss=moe_loss, - grad_norm=grad_norm_groups, - metric=metric, - ) - - timer("one-batch").stop() - - # evaluate on validation data loaders - if valid_every > 0 and train_state.step_count % valid_every == 0: - evaluate_on_val_dls( - trainer=trainer, - val_dls=val_dls, - writer=writer, - logger=logger, - step_count=train_state.step_count, - ) - - # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" - # # save batch sampler that tracks the true consumed samples - now_break = ckpt_manager.try_save_checkpoint(train_state) - if now_break: - break - - if memory_profiler is not None: - memory_profiler.step() + # initialize validation dataloader + val_dls = build_valid_loader_with_data_type() - if batch_count % 2 == 0: - prof.step() + # initialize kwargs + kwargs = vars(args) | {"dataset_types": dataset_types} - # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # build trainer + trainer = TrainerBuilder(model, train_dl, val_dls, **kwargs) - ckpt_manager.wait_async_upload_finish() + # training + trainer.fit() if __name__ == "__main__": args = parse_args() - hostname = socket.gethostname() - # initialize distributed environment + # Initialize distributed environment initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) assert hasattr(gpc, "config") and gpc.config is not None - # initialize monitor manager context - with initialize_monitor_manager( - job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address - ): - try: - main(args) - except Exception: - logger.error( - f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", - ) - mm.monitor_exception( - alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() - ) - - # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") - finally: - # local rank0 delete all files in shm_path, when use shm - devices_per_node = internlm_accelerator.device_count() - local_rank = gpc.get_global_rank() % devices_per_node - if gpc.config.data.use_shm and local_rank == 0: - if os.path.exists(gpc.config.data.shm_path): - shutil.rmtree(gpc.config.data.shm_path) + # Run the main function with parsed arguments + main(args) From 2ac2d08fc9f6f526bf3f3f9aad14b0783a9b4a24 Mon Sep 17 00:00:00 2001 From: Season Date: Tue, 16 Jul 2024 10:36:19 +0800 Subject: [PATCH 05/12] Fix(ckpt): fix llama2 loading function (#276) --- internlm/checkpoint/load_funcs.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index 1ba0ac6a3..423695adb 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -148,12 +148,6 @@ def load_hf_llama_pretrained_weights(folder, model): if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states: states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") - if gpc.config.model_type in ("LLAMA2",): - w2 = states.pop(f"layers.{i}.feed_forward.w2.weight") - w3 = states.pop(f"layers.{i}.feed_forward.w3.weight") - states[f"layers.{i}.feed_forward.w2.weight"] = w3 - states[f"layers.{i}.feed_forward.w3.weight"] = w2 - for name in list(states.keys()): if name.startswith(f"layers.{i}"): current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) From fb09282b5733c8f7e6fb8c6895f0e332eed69ad3 Mon Sep 17 00:00:00 2001 From: Chang Cheng <1953414760@qq.com> Date: Tue, 16 Jul 2024 10:38:00 +0800 Subject: [PATCH 06/12] feat(tools): update InternEvo style ckpt inference tool. (#260) --- configs/7B_internlm2.py | 17 +- configs/_base_/models/internlm2_1B.py | 2 +- configs/_base_/models/internlm2_20B.py | 2 +- configs/_base_/models/internlm2_7B.py | 2 +- doc/usage.md | 25 ++ generate.py | 251 ++++++++++++++++++ internlm/data/__init__.py | 2 + internlm/data/build_dataloader.py | 49 +++- internlm/data/tokenized/collaters.py | 26 ++ internlm/data/tokenized/dataset.py | 2 +- internlm/data/tokenized/packed_dataset.py | 29 +- tools/README.md | 57 +++- tools/README_EN.md | 2 +- ...ernlm_model.py => load_internlm2_model.py} | 80 ++++-- web_demo_internlm.py | 2 +- 15 files changed, 504 insertions(+), 44 deletions(-) create mode 100644 generate.py rename tools/{load_internlm_model.py => load_internlm2_model.py} (85%) diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index a69896ce5..891885c37 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -1,5 +1,5 @@ JOB_NAME = "7b_internlm2_train" -model_type="INTERNLM2_PUBLIC" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False VOCAB_SIZE = 92544 @@ -205,3 +205,18 @@ # metric_dtype can be "fp32" or other string # only when set to "fp32" will use fp32 to calc in metrics # metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index ff0569d32..7d0639197 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -25,7 +25,7 @@ mlp_ratio=MLP_RATIO, multiple_of=MULTIPLE_OF, norm_type="rmsnorm", - adapt_hf=True, + qk_interleaved=False, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py index 82b062493..1347b98f6 100644 --- a/configs/_base_/models/internlm2_20B.py +++ b/configs/_base_/models/internlm2_20B.py @@ -23,7 +23,7 @@ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - adapt_hf=True, + qk_interleaved=False, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py index 81f5acd47..94cae4b36 100644 --- a/configs/_base_/models/internlm2_7B.py +++ b/configs/_base_/models/internlm2_7B.py @@ -23,7 +23,7 @@ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - adapt_hf=False, + qk_interleaved=True, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/doc/usage.md b/doc/usage.md index 78b929603..ad78fe2e9 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -459,6 +459,31 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - 2023-07-07 12:29:16,994 INFO train.py:323 in record_current_batch_training_metrics -- tflops=189.3109313713174,step=5,loss=9.822169303894043,tgs (tokens/gpu/second)=4262.67,lr=1.4000000000000001e-06,loss_scale=65536.0,grad_norm=47.10386835560855,micro_num=4,num_consumed_tokens=786432,inf_nan_skip_batches=0,num_samples_in_batch=17,largest_length=2048,largest_batch=6,smallest_batch=3,adam_beta2=0.95,fwd_bwd_time=3.69 ``` +### 加载训练的checkpoint并生成 + +若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示: +```bash +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python generate.py --config ./configs/7B_sft.py +``` + +在配置文件中添加`generation`配置 +``` +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) +``` + ### 长文本生成 在推理阶段,我们可以使用 Dynamic NTK RoPE 来代替原始的 RoPE,从而使得模型能够适应长文本的输入输出,达到 16K 的外推效果。 diff --git a/generate.py b/generate.py new file mode 100644 index 000000000..4ae760299 --- /dev/null +++ b/generate.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import gc +import json +import logging +import os +import shutil +import socket +import traceback +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +from internlm.accelerator import get_accelerator +from internlm.apis.inference import SequenceGenerator +from internlm.core.context import global_context as gpc +from internlm.data import build_generation_loader_with_data_type +from internlm.initialize import initialize_distributed_env +from internlm.monitor import initialize_monitor_manager +from internlm.monitor.monitor import monitor_manager as mm +from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.utils.common import ( + enable_pytorch_expandable_segments, + launch_time, + parse_args, +) +from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.storage_manager import init_storage_manager +from tools.load_internlm2_model import get_model_device, merge_pp_within_tp + +# global llm logger +logger = logging.getLogger(__file__) +internlm_accelerator = get_accelerator() + + +def get_latest_subdirectory(folder_path): + if ":" in folder_path: + prefix, folder_path = folder_path.split(":", 1) + prefix += ":" + else: + prefix = "" + subdirectories = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))] + subdirectories_sorted = sorted( + subdirectories, key=lambda x: os.path.getctime(os.path.join(folder_path, x)), reverse=True + ) + if subdirectories_sorted: + return prefix + os.path.join(folder_path, subdirectories_sorted[0]) + else: + return None + + +def main(): + enable_pytorch_expandable_segments() + + generation_config = gpc.config["generation"] + + generation_config = type( + "", + (object,), + { + "output_folder": Path(generation_config["output_folder"]), + "ckpt_folder": generation_config["ckpt_folder"] + if "ckpt_folder" in generation_config + else get_latest_subdirectory(gpc.config.ckpt.save_ckpt_folder), + "data_folder": generation_config["data_folder"] if "data_folder" in generation_config else None, + "batch_size": generation_config.get("batch_size", None), + "eos_id": generation_config.get("eos_id", 2), + "bos_id": generation_config.get("bos_id", 1), + "pad_id": generation_config.get("bos_id", 1), + "additional_eos_token_list": generation_config.get("additional_eos_token_list", None), + "max_length": generation_config.get("max_length", 100), + "do_sample": generation_config.get("do_sample", True), + "temperature": generation_config.get("temperature", 1.0), + "num_beams": generation_config.get("num_beams", 1), + "top_k": generation_config.get("top_k", 50), + "top_p": generation_config.get("top_p", 1.0), + "repetition_penalty": generation_config.get("repetition_penalty", 1), + "length_penalty": generation_config.get("length_penalty", 1.0), + }, + ) + + if not os.path.exists(generation_config.output_folder.absolute()): + generation_config.output_folder.mkdir(exist_ok=True, parents=True) + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + torch.distributed.broadcast_object_list(objs, src=0) + current_time = objs[0].replace(":", ".") + global logger + logger = get_logger( + __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() + ) + + try: + init_storage_manager(False, None, None) + except AssertionError: + pass + except Exception as e: + raise e + + # initialize model + model = initialize_model() + _ = initialize_parallel_communicator(model) + model = model.model + + state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True) + missing_k, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if len(missing_k) != 0: + logger.warning(f"Warning: missing keys {missing_k}") + if len(unexpected_keys) != 0: + logger.warning(f"Warning: unexpected keys {unexpected_keys}") + + param_dtype = gpc.config.model.dtype + if isinstance(param_dtype, str): + try: + param_dtype = eval(param_dtype) # pylint: disable=W0123 + finally: + pass + if param_dtype == "torch.tf32": + param_dtype = torch.float32 + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + model.to(param_dtype) + model.eval() + torch.distributed.barrier() + + data_cfg = gpc.config.data + if generation_config.data_folder: + data_cfg.valid_folder = generation_config.data_folder + gene_dls = build_generation_loader_with_data_type(data_cfg, generation_config) + + sequenece_generator = SequenceGenerator( + decoder=model, + eos_token_id=generation_config.eos_id, + pad_token_id=generation_config.bos_id, + bos_token_id=generation_config.pad_id, + additional_eos_token_list=generation_config.additional_eos_token_list, + ) + + ds_count = 0 + gc.disable() + with torch.inference_mode(): + for ds_name, gene_dl in gene_dls.items(): + if len(gene_dl) == 0: + logger.info(f"Validation dataset: {ds_name} is empty") + continue + timer(f"dataset {ds_count}").start() + + # pylint: disable=forgotten-debug-statement + all_output_str = [] + # pylint: disable=unused-variable + for val_idx, (labels, input_ids) in tqdm( + enumerate(gene_dl), + desc="generate.", + total=len(gene_dl), + position=1, + leave=False, + ): + empty_cache_and_diag(val_idx, interval=gpc.config.data.empty_cache_and_diag_interval) + input_ids = torch.LongTensor(input_ids) + if input_ids.size(1) >= generation_config.max_length: + logger.warning( + f"Not generating for the {val_idx}'th batch, because the sequence " + f"length of the batch is {input_ids.size(1)} over the max generation" + f"length {generation_config.max_length}" + ) + output_ids = input_ids[:, : generation_config.max_length, ...] + else: + input_ids = input_ids.clamp(min=0, max=gpc.config.model.vocab_size).to(get_model_device(model)) + output_ids = sequenece_generator.generate( + tokens=input_ids, + max_length=generation_config.max_length, + do_sample=generation_config.do_sample, + temperature=generation_config.temperature, + num_beams=generation_config.num_beams, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + repetition_penalty=generation_config.repetition_penalty, + length_penalty=generation_config.length_penalty, + ) + for output in output_ids: + not_pad_indices = torch.nonzero(output != generation_config.pad_id) + if not_pad_indices.nelement() != 0: + sequence = output[not_pad_indices[0] :] + else: + sequence = output + sequence = sequence.tolist() + line = str.encode(json.dumps({"tokens": sequence})) + all_output_str.append( + ( + line, + len(line), + ) + ) + + bin_meta, last_position = [], 0 + with open(generation_config.output_folder.joinpath(f"{ds_name}.bin"), "wb") as file: + for line, token_num in all_output_str: + file.write(line) + bin_meta.append((last_position, token_num)) + last_position += len(line) + + with open(generation_config.output_folder.joinpath(f"{ds_name}.bin.meta"), "wb") as file: + np.save(file, bin_meta) + + timer(f"dataset {ds_count}").stop() + ds_count += 1 + + +if __name__ == "__main__": + args = parse_args() + hostname = socket.gethostname() + + # initialize distributed environment + initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + assert hasattr(gpc, "config") and gpc.config is not None + assert "generation" in gpc.config, f"Please set `generation` config in `{args.config}` file" + assert ( + "output_folder" in gpc.config["generation"] + ), "Must set `output_folder` for the save folder of generation data" + + # initialize monitor manager context + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): + try: + main() + except Exception: + logger.error( + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", + ) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + ) + + # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + finally: + # local rank0 delete all files in shm_path, when use shm + devices_per_node = internlm_accelerator.device_count() + local_rank = gpc.get_global_rank() % devices_per_node + if gpc.config.data.use_shm and local_rank == 0: + if os.path.exists(gpc.config.data.shm_path): + shutil.rmtree(gpc.config.data.shm_path) diff --git a/internlm/data/__init__.py b/internlm/data/__init__.py index 08ad5d884..35f6ade4a 100644 --- a/internlm/data/__init__.py +++ b/internlm/data/__init__.py @@ -1,4 +1,5 @@ from .build_dataloader import ( + build_generation_loader_with_data_type, build_train_loader_with_data_type, build_valid_loader_with_data_type, ) @@ -6,4 +7,5 @@ __all__ = [ "build_train_loader_with_data_type", "build_valid_loader_with_data_type", + "build_generation_loader_with_data_type", ] diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index 5af73b84b..aa09a9607 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -16,7 +16,11 @@ StaticBatchSampler, get_dpsampler_dataloader, ) -from internlm.data.tokenized.collaters import jsonl_ds_collate_fn, packed_collate_fn +from internlm.data.tokenized.collaters import ( + generation_collate_fn, + jsonl_ds_collate_fn, + packed_collate_fn, +) from internlm.data.tokenized.dataset import get_dataset_dict from internlm.data.tokenized.dummy_dataset import RandomDataset from internlm.data.tokenized.dummy_dataset_multimodal import RandomDatasetMultimodal @@ -213,3 +217,46 @@ def build_valid_loader_with_data_type(): ) return val_dls + + +def build_generation_loader_with_data_type(data_cfg, generation_cfg): + """Generate and return the validation data loader based on data type.""" + + if data_cfg.type == "tokenized": + gene_ds, _ = get_tokenized_valid_loader_items(data_cfg) + else: + raise ValueError(f"dataset type {data_cfg.type} is not supported") + + if gene_ds is None: + return None + + gene_dls = {} + for gene_name, ds in gene_ds.items(): + # making the batch_size of validate larger can speed up the evaluation, but it should not be too large, + # otherwise too much data may be dropped + batch_size = min( + data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA) + ) + batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz + if generation_cfg.batch_size: + batch_size = generation_cfg.batch_size + + if batch_size == 0 and gpc.is_rank_for_log(): + logger.info(f"skip validate {gene_name}.") + continue + + gene_dls[gene_name] = get_dpsampler_dataloader( + ds, + shuffle=False, + num_workers=data_cfg.get("num_worker", 0), + batch_size=batch_size, + collate_fn=partial(generation_collate_fn, pad_id=generation_cfg.pad_id), + ) + + if gpc.is_rank_for_log(): + logger.info( + f"load validation dataset {gene_name} with valid batch size {str(batch_size)} and " + f"samples {str(len(gene_dls[gene_name]))}." + ) + + return gene_dls diff --git a/internlm/data/tokenized/collaters.py b/internlm/data/tokenized/collaters.py index 785ecc60f..fab7c5ac7 100644 --- a/internlm/data/tokenized/collaters.py +++ b/internlm/data/tokenized/collaters.py @@ -100,3 +100,29 @@ def jsonl_ds_collate_fn(batch, max_length_per_sample): return {"input_ids": xs, "images": images}, ys else: return {"input_ids": xs}, ys + + +def generation_collate_fn(batch, pad_id=0): + """ + Collate function for generation dataset. + + Args: + batch (List[Dict]): List of dictionaries representing each sample in batch. + Each dictionary contains "tokens". + + Returns: + Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", + and the tensor of padded "labels". + + """ + xs, ys = [], [] + for x in batch: + tokens = [abs(w) for w in x["tokens"]] + labels = [w if w > 0 else -100 for w in x["tokens"]] + labels = labels[1:] + [-100] + xs.append(torch.as_tensor(tokens[::-1])) + ys.append(torch.as_tensor(labels[::-1])) # y has been shifted + xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=pad_id).flip(dims=[1]) + ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100).flip(dims=[1]) + + return {"input_ids": xs}, ys diff --git a/internlm/data/tokenized/dataset.py b/internlm/data/tokenized/dataset.py index e39a39b78..8991272b2 100644 --- a/internlm/data/tokenized/dataset.py +++ b/internlm/data/tokenized/dataset.py @@ -51,6 +51,6 @@ def get_dataset_dict(folder, split="valid") -> Dict: datasets.append(ds) if datasets: ds = ConcatDataset(datasets=datasets) - data_dict[os.path.basename(root)] = ds + data_dict[os.path.basename(root.rstrip(os.path.sep))] = ds return data_dict diff --git a/internlm/data/tokenized/packed_dataset.py b/internlm/data/tokenized/packed_dataset.py index 1d5259657..b2a8b1092 100644 --- a/internlm/data/tokenized/packed_dataset.py +++ b/internlm/data/tokenized/packed_dataset.py @@ -599,6 +599,7 @@ class PackedDatasetWithPadForMultimodal(PackedDataset): Args: dataset: The original dataset to pack. max_length_per_sample: The maximum length of each original sample. Default is 2048. + padding_side: The padding side. Default is "right". packed_length: The length of each packed sample. Default is 4096. padding_idx: The token id of padding. Default is 0. """ @@ -609,13 +610,17 @@ def __init__( max_length_per_sample: int = 2048, packed_length: int = 4096, padding_idx: int = 0, + padding_side: str = "right", image_token_id: int = 200000, + has_image: bool = True, ): super().__init__(dataset, max_length_per_sample, packed_length) self.padding_idx = padding_idx + self.padding_side = padding_side self.sample_indices, self.belongs = self.accu_sample_len(self.seed) self.num_tokens = sum(self.lengths) self.image_token_id = image_token_id + self.has_image = has_image def get_dataset_name(self): return self.dataset.get_dataset_name() @@ -653,7 +658,10 @@ def __len__(self): def build_pack(self, index): - pack, cu_seqlens, indexes, labels, type_ids, images = [], [0], [], [], [], [] + pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], [] + + if self.has_image: + images = [] start_pos = np.searchsorted(self.belongs, index, "left") end_pos = np.searchsorted(self.belongs, index, "right") @@ -665,8 +673,9 @@ def build_pack(self, index): for sample_idx in cur_samples: sample = self.dataset[sample_idx] length = min(len(sample["tokens"]), self.max_length_per_sample) - cur_images = sample["images"] - images.extend(cur_images) + if self.has_image: + cur_images = sample["images"] + images.extend(cur_images) chunk = sample["tokens"][:length] pack.extend(chunk) cu_seqlens.append(cu_seqlens[-1] + len(chunk)) @@ -680,10 +689,16 @@ def build_pack(self, index): indexes.extend(list(range(length))) if cu_seqlens[-1] != self.packed_length: - pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) - labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) - type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1]) - indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + if self.padding_side == "right": + pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) + type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1]) + indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + else: + pack = [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + pack + labels = [-100] * (self.packed_length - cu_seqlens[-1]) + labels + type_ids = [0] * (self.packed_length - cu_seqlens[-1]) + type_ids + indexes = [0] * (self.packed_length - cu_seqlens[-1]) + indexes cu_seqlens.append(self.packed_length) out = { diff --git a/tools/README.md b/tools/README.md index 2b47b1f40..a24040cae 100644 --- a/tools/README.md +++ b/tools/README.md @@ -5,7 +5,7 @@ ├── interface.py # 生成用的接口 ├── internlm_sft_on_moss.py # 在 moss 数据集上进行 SFT 训练的样例 ├── intern_moss_example.py # 在 moss 数据集上进行训练的样例 -├── load_internlm_model.py # 加载 InternLM 原生格式并进行推理的工具 +├── load_internlm2_model.py # 加载 InternLM 原生格式并进行推理的工具 ├── openai_api.py # 使用 OpenAI 接口实现的流式部署 ├── pal_inference.py # PAL 范式推理的工具 ├── README_EN.md @@ -141,3 +141,58 @@ if __name__ == "__main__": if hasattr(chunk.choices[0].delta, "content"): print(chunk.choices[0].delta.content, end="", flush=True) ``` + +# load_internlm2_model.py + +加载`InternEvo`框架训练的模型权重并进行推理 + +```bash +torchrun --master_port 12321 --nnodes=1 --node_rank=0 --nproc_per_node=1 --ckpt_dir=[where the internlm2 model weights are stored] --tokenizer_path=tools/tokenizer_internlm2.model tools/load_internlm2_model.py +``` + +LLaMA 7B推理的例子: + +```python + model = initialize_internlm_model( + model_type="LLAMA2", + ckpt_dir=args.ckpt_dir, + model_config=dict( + num_chunks=1, + checkpoint=0.2, + dtype="torch.bfloat16", + embed_split_hidden=True, + num_layers=32, + hidden_size=4096, + vocab_size=32000, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=32, + mlp_ratio=2.675, + use_flash_attn=True, + norm_type="rmsnorm", + apply_post_layer_norm=False, + no_bias=True, + layer_norm_epsilon=1e-5, + ), + del_model_prefix=True, + ) + + from sentencepiece import SentencePieceProcessor + + prompt = """<|User|>:{query}\n<|Bot|>:""" + prompt = prompt.replace("{query}", "hello") + # LLaMA tokenizer转换成SentencePieceProcessor 或 此处加载Huggingface Tokenizer,则需额外将generate中调用的decode等方法修改成HF风格 + tokenizer = SentencePieceProcessor(args.tokenizer_path) + generation_config = GenerationConfig() + output_generator = internlm_interactive_generation( + model=model, + tokenizer=tokenizer, + prompt=prompt, + generation_config=generation_config, + additional_eos_token_list=[tokenizer.eos_id()], + ) + + for text in output_generator: + print(text) +``` diff --git a/tools/README_EN.md b/tools/README_EN.md index 63aba4106..fe93560d7 100644 --- a/tools/README_EN.md +++ b/tools/README_EN.md @@ -6,7 +6,7 @@ This directory provide some tools for model training with the following file str ├── interface.py # interface for generation ├── internlm_sft_on_moss.py # example for SFT training on moss dataset ├── intern_moss_example.py # example for training on moss dataset -├── load_internlm_model.py # tools for loading InternLM checkpoints and generating +├── load_internlm2_model.py # tools for loading InternLM checkpoints and generating ├── openai_api.py # stream deployment with OpenAI APIs ├── pal_inference.py # tools for PAL reasoning ├── README_EN.md diff --git a/tools/load_internlm_model.py b/tools/load_internlm2_model.py similarity index 85% rename from tools/load_internlm_model.py rename to tools/load_internlm2_model.py index 3de52c22c..6f1561b0e 100644 --- a/tools/load_internlm_model.py +++ b/tools/load_internlm2_model.py @@ -1,3 +1,4 @@ +import argparse import inspect import logging import os @@ -9,9 +10,8 @@ from internlm.apis.inference import SequenceGenerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.initialize.launch import launch_from_torch -from internlm.model.registry import model_initializer -from internlm.train import initialize_model +from internlm.initialize.launch import initialize_distributed_env +from internlm.train import initialize_model, initialize_parallel_communicator from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -102,6 +102,10 @@ def match_fn_signature(func: Callable, args_dict: Dict) -> None: logger.warning(f"These args:{args_set} are popped for func:{func.__name__}.") +def use_torchrun_starter(): + return os.getenv("RANK") is not None + + def get_tp_rank() -> int: """Get the tensor parallel rank. This script uses torchrun to initialize the environment, so RANK in the environment variable is the tensor @@ -119,7 +123,7 @@ def get_tp_world_size() -> int: Returns: int: The tensor parallel world size to which the current process belongs. """ - return int(os.environ.get("WORLD_SIZE", 0)) + return int(os.environ.get("WORLD_SIZE", 1)) def initialize_internlm_model( @@ -173,27 +177,32 @@ def initialize_internlm_model( model_config["dtype"] = param_dtype model_config["parallel_output"] = False # FIXME: fix it. - match_fn_signature(model_initializer.get_module(model_type), model_config) if gpc.is_rank_for_log(): logger.info(f"model_config: {model_config}.") - launch_from_torch( + + initialize_distributed_env( config=dict( model_type=model_type, model=model_config, parallel=dict( zero1=dict(size=1, fsdp=False), pipeline=dict(size=1, interleaved_overlap=True), - tensor=get_tp_world_size(), + tensor=dict(size=get_tp_world_size(), mode="mtp"), sequence_parallel=0, ), ), + launcher="torch" if use_torchrun_starter() else "slurm", seed=seed, + master_port=23574, + args_check=False, ) - model = initialize_model() # Directly get the origin model without NativeAMP wrapper. + model = initialize_model() + _ = initialize_parallel_communicator(model) model = model.model state_dict = merge_pp_within_tp(ckpt_dir, del_model_prefix=del_model_prefix) + load_info = model.load_state_dict(state_dict, strict=False) logger.info(f"Rank:{gpc.get_local_rank(ParallelMode.TENSOR)}. Load info: {load_info}.") @@ -224,11 +233,11 @@ def internlm_interactive_generation( sequenece_generator = SequenceGenerator( decoder=model, eos_token_id=tokenizer.eos_id(), - pad_token_id=tokenizer.eos_id(), + pad_token_id=tokenizer.bos_id(), bos_token_id=tokenizer.bos_id(), additional_eos_token_list=additional_eos_token_list, ) - additional_eos_token_list = torch.LongTensor(additional_eos_token_list) + additional_eos_token_list = torch.LongTensor(additional_eos_token_list) if additional_eos_token_list else None input_ids = [tokenizer.bos_id()] + tokenizer.encode(prompt) input_ids = torch.LongTensor([input_ids]).to(get_model_device(model)) output_generator = sequenece_generator.streaming_generate( @@ -250,32 +259,48 @@ def internlm_interactive_generation( yield cur_output +def get_default_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_dir", type=str, help="path to the ckpt file", required=True) + parser.add_argument( + "--tokenizer_path", type=str, default="tools/tokenizer_internlm2.model", help="path to the tokenizer file" + ) + + return parser + + if __name__ == "__main__": + parser = get_default_parser() + args = parser.parse_args() + """ Here is a simple example to generate with origin internlm model architecture. Use the following command to run: - >>> torchrun --master_port 12331 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm_model.py + >>> torchrun --master_port 12321 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm2_model.py """ model = initialize_internlm_model( - model_type="INTERNLM", - ckpt_dir="[Please replace this with the directory where the internlm model weights are stored]", + model_type="INTERNLM2_PUBLIC", + ckpt_dir=args.ckpt_dir, model_config=dict( - checkpoint=False, - num_attention_heads=32, + num_chunks=1, + checkpoint=0.2, + dtype="torch.bfloat16", embed_split_hidden=True, - vocab_size=103168, - embed_grad_scale=1, - parallel_output=False, - hidden_size=4096, num_layers=32, - mlp_ratio=8 / 3, - apply_post_layer_norm=False, - dtype="torch.bfloat16", + hidden_size=4096, + vocab_size=92544, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=8, + mlp_ratio=3.5, + use_flash_attn=True, norm_type="rmsnorm", + qk_interleaved=True, + apply_post_layer_norm=False, + no_bias=True, layer_norm_epsilon=1e-5, - use_flash_attn=True, - num_chunks=1, - use_dynamic_ntk_rope=True, + rope_base=1000000, ), del_model_prefix=True, ) @@ -284,15 +309,14 @@ def internlm_interactive_generation( prompt = """<|User|>:{query}\n<|Bot|>:""" prompt = prompt.replace("{query}", "hello") - tokenizer = SentencePieceProcessor("tools/tokenizer_internlm.model") # pylint: disable=E1121 - + tokenizer = SentencePieceProcessor(args.tokenizer_path) # pylint: disable=E1121 generation_config = GenerationConfig() output_generator = internlm_interactive_generation( model=model, tokenizer=tokenizer, prompt=prompt, generation_config=generation_config, - additional_eos_token_list=[103028], + additional_eos_token_list=[tokenizer.eos_id()], ) for text in output_generator: diff --git a/web_demo_internlm.py b/web_demo_internlm.py index 8730c0c2b..abe0568e7 100644 --- a/web_demo_internlm.py +++ b/web_demo_internlm.py @@ -8,7 +8,7 @@ from internlm.accelerator import get_accelerator from tools.interface import GenerationConfig -from tools.load_internlm_model import ( +from tools.load_internlm2_model import ( initialize_internlm_model, internlm_interactive_generation, ) From aa3e9c4462517342b1eb0670b41b2867618f7ea6 Mon Sep 17 00:00:00 2001 From: Season Date: Tue, 16 Jul 2024 14:31:28 +0800 Subject: [PATCH 07/12] feat(singleton): ensure singleton thread safety and no performance degradation (#205) --- internlm/utils/common.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 82161c7d0..323613d78 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -5,6 +5,7 @@ import inspect import os import random +import threading from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime @@ -169,18 +170,27 @@ def __call__(self, batch_count): class SingletonMeta(type): """ - Singleton Meta. + Thread-safe Singleton Meta with double-checked locking. + Reference: https://en.wikipedia.org/wiki/Double-checked_locking """ _instances = {} + _lock = threading.Lock() def __call__(cls, *args, **kwargs): + # First check (without locking) for performance reasons if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) + # Acquire a lock before proceeding to the second check + with cls._lock: + # Second check with lock held to ensure thread safety + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance else: assert ( len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." + ), f"{cls.__name__} is a singleton class and an instance has been created." + return cls._instances[cls] From 0f87f47644813737794312d19d43363b187959b9 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:17:45 +0800 Subject: [PATCH 08/12] Fix(docker): update docker image and dockerfile for new version (#200) --- README-zh-Hans.md | 6 +++--- doc/en/install.md | 24 ++++++++++++++++-------- doc/install.md | 21 ++++++++++++++------- docker.Makefile | 24 ++++++++++-------------- docker/Dockerfile-centos | 15 +++++++++------ docker/Dockerfile-ubuntu | 15 +++++++++------ experiment/Dockerfile-centos | 23 +++++++++++++---------- experiment/Dockerfile-ubuntu | 23 +++++++++++++---------- experiment/README-CN.md | 12 +++--------- experiment/README-EN.md | 12 +++--------- 10 files changed, 93 insertions(+), 82 deletions(-) diff --git a/README-zh-Hans.md b/README-zh-Hans.md index 10768d650..237a50e28 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -17,9 +17,9 @@ [![使用文档](https://readthedocs.org/projects/internevo/badge/?version=latest)](https://internevo.readthedocs.io/zh_CN/latest/?badge=latest) [![license](./doc/imgs/license.svg)](./LICENSE) -[📘使用教程](./doc/en/usage.md) | -[🛠️安装指引](./doc/en/install.md) | -[📊框架性能](./doc/en/train_performance.md) | +[📘使用教程](./doc/usage.md) | +[🛠️安装指引](./doc/install.md) | +[📊框架性能](./doc/train_performance.md) | [🤔问题报告](https://github.com/InternLM/InternEvo/issues/new) [English](./README.md) | diff --git a/doc/en/install.md b/doc/en/install.md index 304d110a7..eae4a12c6 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -78,7 +78,10 @@ cd ../../../../ Install Apex (version 23.05): ```bash cd ./third_party/apex -pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ +# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ +# otherwise +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ cd ../../ ``` @@ -88,31 +91,36 @@ pip install git+https://github.com/databricks/megablocks@v0.3.2 # MOE need ``` ### Environment Image -Users can use the provided dockerfile combined with docker.Makefile to build their own images, or obtain images with InternEvo runtime environment installed from https://hub.docker.com/r/internlm/internlm. +Users can use the provided dockerfile combined with docker.Makefile to build their own images, or obtain images with InternEvo runtime environment installed from https://hub.docker.com/r/internlm/internevo/tags. #### Image Configuration and Build The configuration and build of the Dockerfile are implemented through the docker.Makefile. To build the image, execute the following command in the root directory of InternEvo: ``` bash make -f docker.Makefile BASE_OS=centos7 ``` -In docker.Makefile, you can customize the basic image, environment version, etc., and the corresponding parameters can be passed directly through the command line. For BASE_OS, ubuntu20.04 and centos7 are respectively supported. +In docker.Makefile, you can customize the basic image, environment version, etc., and the corresponding parameters can be passed directly through the command line. The default is the recommended environment version. For BASE_OS, ubuntu20.04 and centos7 are respectively supported. #### Pull Standard Image The standard image based on ubuntu and centos has been built and can be directly pulled: ```bash # ubuntu20.04 -docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-ubuntu20.04 +docker pull internlm/internevo:torch2.1.0-cuda11.8.0-flashatten2.2.1-ubuntu20.04 # centos7 -docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7 +docker pull internlm/internevo:torch2.1.0-cuda11.8.0-flashatten2.2.1-centos7 ``` #### Run Container For the local standard image built with dockerfile or pulled, use the following command to run and enter the container: ```bash -docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7 bash +docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name internevo_centos internlm/internevo:torch2.1.0-cuda11.8.0-flashatten2.2.1-centos7 bash +``` + +#### Start Training +The default directory in the container is `/InternEvo`, please start training according to the [Usage](./usage.md). The default 7B model starts the single-machine with 8-GPU training command example as follows: +```bash +torchrun --nproc_per_node=8 --nnodes=1 train.py --config configs/7B_sft.py --launcher torch ``` -The default directory in the container is `/InternLM`, please start training according to the [Usage](./usage.md). ## Environment Installation (NPU) For machines with NPU, the version of the installation environment can refer to that of GPU. Use Ascend's torch_npu instead of torch on NPU machines. Additionally, Flash-Attention and Apex are no longer supported for installation on NPU. The corresponding functionalities have been internally implemented in the InternEvo codebase. The following tutorial is only for installing torch_npu. @@ -135,4 +143,4 @@ pip3 install pyyaml pip3 install setuptools wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc1-pytorch2.1.0/torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl pip install torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -``` \ No newline at end of file +``` diff --git a/doc/install.md b/doc/install.md index b894f8fa5..f14934726 100644 --- a/doc/install.md +++ b/doc/install.md @@ -78,7 +78,10 @@ cd ../../../../ 安装 Apex (version 23.05): ```bash cd ./third_party/apex -pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ +# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ +# otherwise +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ cd ../../ ``` @@ -88,32 +91,36 @@ pip install git+https://github.com/databricks/megablocks@v0.3.2 # MOE相关 ``` ### 环境镜像 -用户可以使用提供的 dockerfile 结合 docker.Makefile 来构建自己的镜像,或者也可以从 https://hub.docker.com/r/internlm/internlm 获取安装了 InternEvo 运行环境的镜像。 +用户可以使用提供的 dockerfile 结合 docker.Makefile 来构建自己的镜像,或者也可以从 https://hub.docker.com/r/internlm/internevo/tags 获取安装了 InternEvo 运行环境的镜像。 #### 镜像配置及构造 dockerfile 的配置以及构造均通过 docker.Makefile 文件实现,在 InternEvo 根目录下执行如下命令即可 build 镜像: ``` bash make -f docker.Makefile BASE_OS=centos7 ``` -在 docker.Makefile 中可自定义基础镜像,环境版本等内容,对应参数可直接通过命令行传递。对于 BASE_OS 分别支持 ubuntu20.04 和 centos7。 +在 docker.Makefile 中可自定义基础镜像,环境版本等内容,对应参数可直接通过命令行传递,默认为推荐的环境版本。对于 BASE_OS 分别支持 ubuntu20.04 和 centos7。 #### 镜像拉取 基于 ubuntu 和 centos 的标准镜像已经 build 完成也可直接拉取使用: ```bash # ubuntu20.04 -docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-ubuntu20.04 +docker pull internlm/internevo:torch2.1.0-cuda11.8.0-flashatten2.2.1-ubuntu20.04 # centos7 -docker pull internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7 +docker pull internlm/internevo:torch2.1.0-cuda11.8.0-flashatten2.2.1-centos7 ``` #### 容器启动 对于使用 dockerfile 构建或拉取的本地标准镜像,使用如下命令启动并进入容器: ```bash -docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:torch1.13.1-cuda11.7.1-flashatten1.0.5-centos7 bash +docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name internevo_centos internlm/internevo:torch2.1.0-cuda11.8.0-flashatten2.2.1-centos7 bash ``` -容器内默认目录即 `/InternLM`,根据[使用文档](./usage.md)即可启动训练。 +#### 训练启动 +容器内默认目录即 `/InternEvo`,参考[使用文档](./usage.md)可获取具体使用方法。默认7B模型启动单机8卡训练命令样例: +```bash +torchrun --nproc_per_node=8 --nnodes=1 train.py --config configs/7B_sft.py --launcher torch +``` ## 环境安装(NPU) 在搭载NPU的机器上安装环境的版本可参考GPU,在NPU上使用昇腾torch_npu代替torch,同时Flash-Attention和Apex不再支持安装,相应功能已由InternEvo代码内部实现。以下教程仅为torch_npu安装。 diff --git a/docker.Makefile b/docker.Makefile index 7cfd55afe..2bcbbae02 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -1,12 +1,11 @@ DOCKER_REGISTRY ?= docker.io -DOCKER_ORG ?= my -DOCKER_IMAGE ?= internlm +DOCKER_ORG ?= internlm +DOCKER_IMAGE ?= internevo DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) -CUDA_VERSION = 11.7.1 -GCC_VERSION = 10.2.0 - +CUDA_VERSION = 11.8.0 CUDNN_VERSION = 8 + BASE_RUNTIME = # ubuntu20.04 centos7 BASE_OS = centos7 @@ -17,9 +16,10 @@ CUDA_CHANNEL = nvidia INSTALL_CHANNEL ?= pytorch PYTHON_VERSION ?= 3.10 -PYTORCH_VERSION ?= 1.13.1 -TORCHVISION_VERSION ?= 0.14.1 -TORCHAUDIO_VERSION ?= 0.13.1 +PYTORCH_TAG ?= 2.1.0 +PYTORCH_VERSION ?= 2.1.0+cu118 +TORCHVISION_VERSION ?= 0.16.0+cu118 +TORCHAUDIO_VERSION ?= 2.1.0+cu118 BUILD_PROGRESS ?= auto TRITON_VERSION ?= GMP_VERSION ?= 6.2.1 @@ -28,18 +28,14 @@ MPC_VERSION ?= 1.2.1 GCC_VERSION ?= 10.2.0 HTTPS_PROXY_I ?= HTTP_PROXY_I ?= -FLASH_ATTEN_VERSION ?= 1.0.5 +FLASH_ATTEN_VERSION ?= 2.2.1 FLASH_ATTEN_TAG ?= v${FLASH_ATTEN_VERSION} BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ - --build-arg CUDA_VERSION=$(CUDA_VERSION) \ - --build-arg CUDA_CHANNEL=$(CUDA_CHANNEL) \ --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \ --build-arg TORCHVISION_VERSION=$(TORCHVISION_VERSION) \ --build-arg TORCHAUDIO_VERSION=$(TORCHAUDIO_VERSION) \ - --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) \ - --build-arg TRITON_VERSION=$(TRITON_VERSION) \ --build-arg GMP_VERSION=$(GMP_VERSION) \ --build-arg MPFR_VERSION=$(MPFR_VERSION) \ --build-arg MPC_VERSION=$(MPC_VERSION) \ @@ -98,7 +94,7 @@ all: devel-image .PHONY: devel-image devel-image: BASE_IMAGE := $(BASE_DEVEL) -devel-image: DOCKER_TAG := torch${PYTORCH_VERSION}-cuda${CUDA_VERSION}-flashatten${FLASH_ATTEN_VERSION}-${BASE_OS} +devel-image: DOCKER_TAG := torch${PYTORCH_TAG}-cuda${CUDA_VERSION}-flashatten${FLASH_ATTEN_VERSION}-${BASE_OS} devel-image: $(DOCKER_BUILD) diff --git a/docker/Dockerfile-centos b/docker/Dockerfile-centos index 9a8f8e5bd..7b2a0fd0f 100644 --- a/docker/Dockerfile-centos +++ b/docker/Dockerfile-centos @@ -107,18 +107,18 @@ ENV CXX=${GCC_HOME}/bin/c++ ############################################################################## -# Install InternLM development environment, including flash-attention and apex +# Install InternEvo development environment, including flash-attention and apex ############################################################################## FROM dep as intrenlm-dev -COPY . /InternLM -WORKDIR /InternLM +COPY . /InternEvo +WORKDIR /InternEvo ARG https_proxy ARG http_proxy ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" RUN git submodule update --init --recursive \ && /opt/conda/bin/pip --no-cache-dir install -r requirements/torch.txt \ && /opt/conda/bin/pip --no-cache-dir install -r requirements/runtime.txt \ - && cd /InternLM/third_party/flash-attention \ + && cd /InternEvo/third_party/flash-attention \ && /opt/conda/bin/python setup.py install \ && cd ./csrc \ && cd fused_dense_lib && /opt/conda/bin/pip install -v . \ @@ -127,6 +127,9 @@ RUN git submodule update --init --recursive \ && cd ../layer_norm && /opt/conda/bin/pip install -v . \ && cd ../../../../ \ && cd ./third_party/apex \ - && /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \ + && /opt/conda/bin/pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \ + && /opt/conda/bin/pip install pytorch-extension \ && /opt/conda/bin/pip cache purge \ - && rm -rf ~/.cache/pip + && rm -rf ~/.cache/pip \ + && /opt/conda/bin/conda init \ + && . ~/.bashrc diff --git a/docker/Dockerfile-ubuntu b/docker/Dockerfile-ubuntu index da16f5601..8c4293819 100644 --- a/docker/Dockerfile-ubuntu +++ b/docker/Dockerfile-ubuntu @@ -88,18 +88,18 @@ ENV CXX=${GCC_HOME}/bin/c++ ############################################################################## -# Install InternLM development environment, including flash-attention and apex +# Install InternEvo development environment, including flash-attention and apex ############################################################################## FROM dep as intrenlm-dev -COPY . /InternLM -WORKDIR /InternLM +COPY . /InternEvo +WORKDIR /InternEvo ARG https_proxy ARG http_proxy ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" RUN git submodule update --init --recursive \ && /opt/conda/bin/pip --no-cache-dir install -r requirements/torch.txt \ && /opt/conda/bin/pip --no-cache-dir install -r requirements/runtime.txt \ - && cd /InternLM/third_party/flash-attention \ + && cd /InternEvo/third_party/flash-attention \ && /opt/conda/bin/python setup.py install \ && cd ./csrc \ && cd fused_dense_lib && /opt/conda/bin/pip install -v . \ @@ -108,6 +108,9 @@ RUN git submodule update --init --recursive \ && cd ../layer_norm && /opt/conda/bin/pip install -v . \ && cd ../../../../ \ && cd ./third_party/apex \ - && /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \ + && /opt/conda/bin/pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \ + && /opt/conda/bin/pip install pytorch-extension \ && /opt/conda/bin/pip cache purge \ - && rm -rf ~/.cache/pip + && rm -rf ~/.cache/pip \ + && /opt/conda/bin/conda init \ + && . ~/.bashrc diff --git a/experiment/Dockerfile-centos b/experiment/Dockerfile-centos index 4ac9c64ef..e967a1c66 100644 --- a/experiment/Dockerfile-centos +++ b/experiment/Dockerfile-centos @@ -106,11 +106,11 @@ ENV CXX=${GCC_HOME}/bin/c++ ############################################################################## -# Install InternLM development environment, including flash-attention and apex +# Install InternEvo development environment, including flash-attention and apex ############################################################################## FROM dep as intrenlm-dev -COPY . /InternLM -WORKDIR /InternLM +COPY . /InternEvo +WORKDIR /InternEvo ARG https_proxy ARG http_proxy ARG PYTORCH_VERSION @@ -134,11 +134,11 @@ RUN /opt/conda/bin/pip --no-cache-dir install \ torch-scatter \ pyecharts \ py-libnuma \ - -f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}+cu117.html \ + -f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}.html \ && /opt/conda/bin/pip --no-cache-dir install \ - --extra-index-url https://download.pytorch.org/whl/cu117 \ - torch==${PYTORCH_VERSION}+cu117 \ - torchvision==${TORCHVISION_VERSION}+cu117 \ + --extra-index-url https://download.pytorch.org/whl/cu118 \ + torch==${PYTORCH_VERSION} \ + torchvision==${TORCHVISION_VERSION} \ torchaudio==${TORCHAUDIO_VERSION} ARG https_proxy @@ -147,7 +147,7 @@ ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" ARG FLASH_ATTEN_TAG RUN git submodule update --init --recursive \ - && cd /InternLM/third_party/flash-attention \ + && cd /InternEvo/third_party/flash-attention \ && git checkout ${FLASH_ATTEN_TAG} \ && /opt/conda/bin/python setup.py install \ && cd ./csrc \ @@ -157,6 +157,9 @@ RUN git submodule update --init --recursive \ && cd ../layer_norm && /opt/conda/bin/pip install -v . \ && cd ../../../../ \ && cd ./third_party/apex \ - && /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \ + && /opt/conda/bin/pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \ + && /opt/conda/bin/pip install pytorch-extension \ && /opt/conda/bin/pip cache purge \ - && rm -rf ~/.cache/pip + && rm -rf ~/.cache/pip \ + && /opt/conda/bin/conda init \ + && . ~/.bashrc diff --git a/experiment/Dockerfile-ubuntu b/experiment/Dockerfile-ubuntu index 055f9a620..799457021 100644 --- a/experiment/Dockerfile-ubuntu +++ b/experiment/Dockerfile-ubuntu @@ -87,11 +87,11 @@ ENV CXX=${GCC_HOME}/bin/c++ ############################################################################## -# Install InternLM development environment, including flash-attention and apex +# Install InternEvo development environment, including flash-attention and apex ############################################################################## FROM dep as intrenlm-dev -COPY . /InternLM -WORKDIR /InternLM +COPY . /InternEvo +WORKDIR /InternEvo ARG https_proxy ARG http_proxy ARG PYTORCH_VERSION @@ -115,11 +115,11 @@ RUN /opt/conda/bin/pip --no-cache-dir install \ torch-scatter \ pyecharts \ py-libnuma \ - -f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}+cu117.html \ + -f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}.html \ && /opt/conda/bin/pip --no-cache-dir install \ - --extra-index-url https://download.pytorch.org/whl/cu117 \ - torch==${PYTORCH_VERSION}+cu117 \ - torchvision==${TORCHVISION_VERSION}+cu117 \ + --extra-index-url https://download.pytorch.org/whl/cu118 \ + torch==${PYTORCH_VERSION} \ + torchvision==${TORCHVISION_VERSION} \ torchaudio==${TORCHAUDIO_VERSION} ARG https_proxy @@ -128,7 +128,7 @@ ARG TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" ARG FLASH_ATTEN_TAG RUN git submodule update --init --recursive \ - && cd /InternLM/third_party/flash-attention \ + && cd /InternEvo/third_party/flash-attention \ && git checkout ${FLASH_ATTEN_TAG} \ && /opt/conda/bin/python setup.py install \ && cd ./csrc \ @@ -138,6 +138,9 @@ RUN git submodule update --init --recursive \ && cd ../layer_norm && /opt/conda/bin/pip install -v . \ && cd ../../../../ \ && cd ./third_party/apex \ - && /opt/conda/bin/pip --no-cache-dir install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \ + && /opt/conda/bin/pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \ + && /opt/conda/bin/pip install pytorch-extension \ && /opt/conda/bin/pip cache purge \ - && rm -rf ~/.cache/pip + && rm -rf ~/.cache/pip \ + && /opt/conda/bin/conda init \ + && . ~/.bashrc diff --git a/experiment/README-CN.md b/experiment/README-CN.md index 7fee559bb..de56039b7 100644 --- a/experiment/README-CN.md +++ b/experiment/README-CN.md @@ -2,24 +2,18 @@ 本模块用于测试新版本环境,默认测试新环境 torch=2.0.1,flash-attention=2.1.0。新环境可能具有不稳定性,标准环境安装请参考:[安装文档](../doc/install.md) ### 镜像构建及拉取 -构建镜像时请于 InternLM 根目录下执行 docker.Makefile,该文件与标准环境镜像共用,所使用的 Dockerfile 位于 experiment 目录下。也可直接从 https://hub.docker.com/r/internlm/internlm 拉取镜像,命令如下: +构建镜像时请于 InternEvo 根目录下执行 docker.Makefile,该文件与标准环境镜像共用,所使用的 Dockerfile 位于 experiment 目录下。也可直接从 https://hub.docker.com/r/internlm/internevo/tags 拉取镜像,命令如下: ```bash # 构建镜像 # ubuntu20.04 make -f docker.Makefile BASE_OS=ubuntu20.04 DOCKERFILE_PATH=./experiment/Dockerfile-ubuntu PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0 # centos7 make -f docker.Makefile BASE_OS=centos7 DOCKERFILE_PATH=./experiment/Dockerfile-centos PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0 - -# 拉取镜像 -# ubuntu20.04 -docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-ubuntu20.04 -# centos7 -docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7 ``` ### 容器启动 对于使用 dockerfile 构建或拉取的本地标准镜像,使用如下命令启动并进入容器: ```bash -docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7 bash +docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/Internevo:experiment-torch2.0.1-flashatten2.1.0-centos7 bash ``` -容器内默认目录即 `/InternLM`,根据[使用文档](../doc/usage.md)即可启动训练。 +容器内默认目录即 `/InternEvo`,根据[使用文档](../doc/usage.md)即可启动训练。 diff --git a/experiment/README-EN.md b/experiment/README-EN.md index f68efc86c..8f4daf6a9 100644 --- a/experiment/README-EN.md +++ b/experiment/README-EN.md @@ -2,24 +2,18 @@ This module is used to test the new version environment, the default test new environment is torch=2.0.1, flash-attention=2.1.0. The new environment may be unstable, for the standard environment installation please refer to: [installation guide](../doc/en/install.md) ### Build and Pull Image -When building the image, please make docker.Makefile in the InternLM root directory. This Makefile is shared with the standard environment image, and the Dockerfile used is located in the experiment directory. You can also pull the image directly from https://hub.docker.com/r/internlm/internlm, the command is as follows: +When building the image, please make docker.Makefile in the InternEvo root directory. This Makefile is shared with the standard environment image, and the Dockerfile used is located in the experiment directory. You can also pull the image directly from https://hub.docker.com/r/internlm/internevo/tags, the command is as follows: ```bash # Build Image # ubuntu20.04 make -f docker.Makefile BASE_OS=ubuntu20.04 DOCKERFILE_PATH=./experiment/Dockerfile-ubuntu PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0 # centos7 make -f docker.Makefile BASE_OS=centos7 DOCKERFILE_PATH=./experiment/Dockerfile-centos PYTORCH_VERSION=2.0.1 TORCHVISION_VERSION=0.15.2 TORCHAUDIO_VERSION=2.0.2 FLASH_ATTEN_VERSION=2.1.0 - -# Pull Image -# ubuntu20.04 -docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-ubuntu20.04 -# centos7 -docker pull internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7 ``` ### Run Container For the local standard image built with dockerfile or pulled, use the following command to run and enter the container: ```bash -docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/internlm:experiment-torch2.0.1-flashatten2.1.0-centos7 bash +docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size 20g --network=host --name myinternlm internlm/Internevo:experiment-torch2.0.1-flashatten2.1.0-centos7 bash ``` -The default directory in the container is `/InternLM`, please start training according to the [Usage](../doc/en/usage.md). +The default directory in the container is `/InternEvo`, please start training according to the [Usage](../doc/en/usage.md). From 19d00ac9a1a69fbad11c9bc1717d8f860aa22bda Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Tue, 16 Jul 2024 19:25:50 +0800 Subject: [PATCH 09/12] feat(config): add 1.8B config for 16 experts (#239) --- configs/1.8B_MoE16_sft.py | 237 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 237 insertions(+) create mode 100644 configs/1.8B_MoE16_sft.py diff --git a/configs/1.8B_MoE16_sft.py b/configs/1.8B_MoE16_sft.py new file mode 100644 index 000000000..62babb450 --- /dev/null +++ b/configs/1.8B_MoE16_sft.py @@ -0,0 +1,237 @@ +JOB_NAME = "1.8b_moe_train" +DO_ALERT = False + +SEQ_LEN = 2048 +HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEAD = 16 +MLP_RATIO = 1.5 +NUM_LAYER = 24 +VOCAB_SIZE = 92544 +MULTIPLE_OF = 128 + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + load_ckpt_folder="local:llm_ckpts/", + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama". + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=True, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = None # "/path/to/dataset" +VALID_FOLDER = None # "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=2, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=5000, + pack_sample_into_one=False, + total_steps=5000, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=False, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, + moe_loss_coeff=0.1, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False +model = dict( + checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=False, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + multiple_of=MULTIPLE_OF, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + num_experts=16, + moe_use_residual=False, + moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D" +) +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=-1, fsdp=False), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# custom moe impl configs +# GShard MoE config +# moe = dict( +# top_k=2, +# capacity_factor=1.0, +# eval_capacity_factor=1.0, +# min_capacity=4, +# noisy_gate_policy=None, +# drop_tokens=True, +# use_rts=True, +# use_fused_gating=False, +# ) + +# MegaBlock MoE config +moe = dict( + top_k=2, + # capacity_factor=1.0, # only used in MegaBlock(non-dmoe) + # drop_tokens=True, # only used in MegaBlock(non-dmoe) + # parallel_mode="tensor", # only used in MegaBlock-D(dmoe), parallel_mode can be tensor or weight +) + +model_type = "INTERNLM_MoE" + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" From 2c6df5cdbb70fc1bf5466c23d1cc9ffc14c94a6a Mon Sep 17 00:00:00 2001 From: Season Date: Tue, 16 Jul 2024 20:34:07 +0800 Subject: [PATCH 10/12] fix(huggingface): fix huggingface dataloader when using some huggingface third-party tokenizers (#277) --- internlm/data/build_dataloader.py | 11 ++++++++-- internlm/data/streaming/collaters.py | 33 ++++++++++++++++++---------- internlm/data/streaming/dataset.py | 19 +++++++++------- 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index aa09a9607..a95e3c98b 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -125,15 +125,22 @@ def get_hf_train_loader_items(data_cfg): model_max_length=data_cfg.seq_len, subset_name=data_cfg.get("subset_name", None), ) + pad_token_id = gpc.config.model.get("pad_token_id", 0) if gpc.config.model_type == "hf" and not data_cfg.use_packed_dataset: train_sampler = StreamingStaticBatchSampler( batch_size=data_cfg.micro_num * data_cfg.micro_bsz, rampup_batch_size=data_cfg.rampup_batch_size ) train_collate_fn = partial( - nopack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + nopack_collate_fn, + micro_num=data_cfg.micro_num, + micro_bsz=data_cfg.micro_bsz, + seq_len=data_cfg.seq_len, + pad_token_id=pad_token_id, ) else: - train_ds = HuggingFacePackedDataset(dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz) + train_ds = HuggingFacePackedDataset( + dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz, pad_token_id=pad_token_id + ) train_sampler = StreamingStaticBatchSampler( batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size ) diff --git a/internlm/data/streaming/collaters.py b/internlm/data/streaming/collaters.py index 4391fd236..01b9ba866 100644 --- a/internlm/data/streaming/collaters.py +++ b/internlm/data/streaming/collaters.py @@ -1,25 +1,34 @@ import torch -def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len): +def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len, pad_token_id=0): input_ids_list = [] attention_mask_list = [] labels_list = [] + for b in batch: - attention_mask = torch.tensor(b["attention_mask"]) - input_ids = torch.LongTensor(b["input_ids"]) - input_ids = torch.abs(input_ids * attention_mask) - input_ids = torch.nn.functional.pad(input_ids, (0, seq_len - len(input_ids)), mode="constant", value=0) - attention_mask = torch.nn.functional.pad( - attention_mask, (0, seq_len - len(attention_mask)), mode="constant", value=0 - ) - label = torch.LongTensor([w if w > 0 else -100 for w in input_ids.tolist()][1:] + [-100]) - input_ids_list.append(input_ids) - attention_mask_list.append(attention_mask) - labels_list.append(label) + assert len(b["input_ids"]) > 0 + + if "attention_mask" in b: + assert len(b["input_ids"]) == len( + b["attention_mask"] + ), "input_ids and attention_mask should be equal length" + else: + b["attention_mask"] = [True] * len(b["input_ids"]) + + input_ids = b["input_ids"] + [pad_token_id] * (seq_len - len(b["input_ids"])) + attention_mask = b["attention_mask"] + [False] * (seq_len - len(b["attention_mask"])) + labels = [w if w > 0 else -100 for w in b["input_ids"]][1:] + [-100] + labels = labels + [-100] * (seq_len - len(b["input_ids"])) + + input_ids_list.append(torch.LongTensor(input_ids)) + attention_mask_list.append(torch.BoolTensor(attention_mask)) + labels_list.append(torch.LongTensor(labels)) + input_ids = torch.stack(input_ids_list) attention_mask = torch.stack(attention_mask_list) labels = torch.stack(labels_list) + return { "input_ids": input_ids, "attention_mask": attention_mask, diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py index a3844d706..681787361 100644 --- a/internlm/data/streaming/dataset.py +++ b/internlm/data/streaming/dataset.py @@ -47,7 +47,9 @@ def _tokenize(self, samples): texts = [sample["text"] for sample in samples] tokenized_outputs = self.tokenizer(texts, truncation=True) for i in range(len(samples)): - yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} + assert "input_ids" in tokenized_outputs, "huggingface tokenizer should generate input_ids" + if len(tokenized_outputs["input_ids"][i]) > 0: + yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} def __getitem__(self, _): return next(self.senior_iterator) @@ -55,14 +57,14 @@ def __getitem__(self, _): class HuggingFacePackedDataset(Dataset): """ - Simple packed dataset for huggingface. + Simple packed dataset for huggingface """ - def __init__(self, dataset, seq_len, micro_bsz): + def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0): self.dataset = dataset self.seq_len = seq_len self.micro_bsz = micro_bsz - + self.pad_token_id = pad_token_id self.senior_iterator = iter(self) def __iter__(self): @@ -72,7 +74,7 @@ def __iter__(self): for sample in self.dataset: if len(input_ids + sample["input_ids"]) > self.micro_bsz * self.seq_len: assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len - input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids)) + input_ids = input_ids + [self.pad_token_id] * (self.micro_bsz * self.seq_len - len(input_ids)) cu_seqlens = ( cu_seqlens + [self.micro_bsz * self.seq_len] if cu_seqlens[-1] < self.micro_bsz * self.seq_len @@ -89,14 +91,15 @@ def __iter__(self): } input_ids = sample["input_ids"] cu_seqlens = [0, len(sample["input_ids"])] - labels = sample["input_ids"][1:] + [-100] + labels = [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100] else: input_ids = input_ids + sample["input_ids"] cu_seqlens.append(len(sample["input_ids"]) + cu_seqlens[-1]) - labels = labels + sample["input_ids"][1:] + [-100] + labels = labels + [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100] + if input_ids: assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len - input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids)) + input_ids = input_ids + [self.pad_token_id] * (self.micro_bsz * self.seq_len - len(input_ids)) cu_seqlens = ( cu_seqlens + [self.micro_bsz * self.seq_len] if cu_seqlens[-1] < self.micro_bsz * self.seq_len From 7cd091c9f046124f185d6d5c40b5e486f3e10fe3 Mon Sep 17 00:00:00 2001 From: cx <759046501@qq.com> Date: Wed, 17 Jul 2024 19:36:57 +0800 Subject: [PATCH 11/12] feat(*): re-impl embedding/head of isp version (#261) Co-authored-by: huangting4201 <1538303371@qq.com> --- internlm/checkpoint/components.py | 64 ++++----- internlm/core/context/__init__.py | 2 - internlm/core/context/parallel_context.py | 8 +- .../core/context/process_group_initializer.py | 63 ++++++++ internlm/core/parallel/comm/isp.py | 126 ++++++++++++++++ internlm/core/parallel/comm/tensor.py | 4 +- internlm/core/parallel/comm/utils.py | 33 +++++ internlm/core/parallel/shard.py | 24 +++- internlm/core/trainer_builder.py | 8 +- internlm/data/utils.py | 6 - internlm/initialize/initialize_trainer.py | 27 +++- internlm/model/metrics.py | 12 ++ internlm/model/modules/embedding.py | 5 +- internlm/model/modules/linear.py | 2 +- internlm/model/ops/cross_entropy.py | 136 ++++++++++++++++++ .../solver/optimizer/hybrid_zero_optim.py | 24 ++-- .../solver/optimizer/hybrid_zero_optim_v2.py | 5 - internlm/solver/optimizer/utils.py | 20 +-- internlm/train/pipeline.py | 25 ++-- internlm/train/utils.py | 9 +- internlm/utils/parallel.py | 10 -- tests/test_training/test_loss.py | 31 ++-- 22 files changed, 504 insertions(+), 140 deletions(-) diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py index 25435d3c6..67837a3e9 100644 --- a/internlm/checkpoint/components.py +++ b/internlm/checkpoint/components.py @@ -100,7 +100,7 @@ def load_model_checkpoint(folder, model): If tensor parallel mode is isp, the saved weight is named: - folder - - model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt + - model_wp{wp_rank}_pp{pp_rank}.pt If fsdp is activated, the saved weight is named: - folder @@ -122,19 +122,19 @@ def load_model_checkpoint(folder, model): fns = get_fns(folder) # avoid ckpt misuse between FSDP and no-FSDP - test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop() + _start_with = "model_w" if is_using_isp() else "model_t" + test_fn = list([f for f in fns if f.startswith(_start_with) and not f.endswith(".md5")]).pop() assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or ( "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp ), "FSDP model wants to load no-FSDP ckpts or reverse" max_pp, max_wp, max_tp, max_zo = 0, 0, 0, 0 for fn in fns: - if fn.startswith("model_t") and not fn.endswith(".md5"): + if fn.startswith(_start_with) and not fn.endswith(".md5"): segements = os.path.splitext(fn)[0].split("_") if is_using_isp(): max_pp = max(max_pp, int(segements[-1][2:])) max_wp = max(max_wp, int(segements[-2][2:])) - max_tp = max(max_tp, int(segements[-3][2:])) elif gpc.config.parallel.zero1.fsdp: max_zo = max(max_zo, int(segements[-1][2:])) max_pp = max(max_pp, int(segements[-2][2:])) @@ -149,16 +149,17 @@ def load_model_checkpoint(folder, model): assert ( wp_size == max_wp + 1 ), f"The weights are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism" - assert ( - tp_size == max_tp + 1 - ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" + if not is_using_isp(): + assert ( + tp_size == max_tp + 1 + ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" if gpc.config.parallel.zero1.fsdp: assert ( dp_size == max_zo + 1 ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards" if is_using_isp(): - should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt" + should_load_name = f"model_wp{wp_rank}_pp{pp_rank}.pt" elif gpc.config.parallel.zero1.fsdp: should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt" else: @@ -205,7 +206,7 @@ def save_model_checkpoint(folder, model): If tensor parallel mode is isp, the saved weight is named: - folder - - model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt + - model_wp{wp_rank}_pp{pp_rank}.pt If fsdp is activated, the saved weight is named: - folder @@ -243,11 +244,11 @@ def save_model_checkpoint(folder, model): # for tensor parallel mode with isp if is_using_isp(): - if wdp_rank == 0 or dp_rank == 0: - fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt" + if wdp_rank == 0: + fn = f"model_wp{wp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, fn) llm_save(fp, saved_obj=states) - topo_fn = f"topo_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.json" + topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json" topo_fp = os.path.join(folder, topo_fn) llm_save(topo_fp, saved_obj=topo) else: @@ -292,13 +293,12 @@ def load_optimizer_checkpoint(folder, optim): """ fns = get_fns(folder) - max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0 + max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0 for fn in fns: if fn.startswith("optimizer_") and not fn.endswith(".md5"): if is_using_isp(): - _, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_") - max_dp = max(max_dp, int(dp[2:])) - max_tp = max(max_tp, int(tp[2:])) + _, wp, pp, zero = os.path.splitext(fn)[0].split("_") + max_zero = max(max_zero, int(zero[2:])) max_wp = max(max_wp, int(wp[2:])) max_pp = max(max_pp, int(pp[2:])) else: @@ -311,24 +311,18 @@ def load_optimizer_checkpoint(folder, optim): tp_size = gpc.get_world_size(ParallelMode.TENSOR) wp_size = gpc.get_world_size(ParallelMode.WEIGHT) pp_size = gpc.get_world_size(ParallelMode.PIPELINE) - dp_size = gpc.get_world_size(ParallelMode.DATA) - if is_using_isp(): - assert dp_size == max_dp + 1, ( - f"The optimizer states are save for {max_dp+1} data parallelism, " - f"while current has {dp_size} data parallelism" - ) - if not is_using_isp(): - assert zero_size == max_zero + 1, ( - f"The optimizer states are save for {max_zero+1} zero parallel, " - f"while current has {zero_size} zero broadcast range." - ) + assert zero_size == max_zero + 1, ( + f"The optimizer states are save for {max_zero+1} zero parallel, " + f"while current has {zero_size} zero broadcast range." + ) assert ( pp_size == max_pp + 1 ), f"The optimizer states are save for {max_pp+1} pipelines, while current has {pp_size} pipelines" - assert ( - tp_size == max_tp + 1 - ), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" + if not is_using_isp(): + assert ( + tp_size == max_tp + 1 + ), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" assert ( wp_size == max_wp + 1 ), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism" @@ -337,9 +331,8 @@ def load_optimizer_checkpoint(folder, optim): tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - dp_rank = gpc.get_local_rank(ParallelMode.DATA) if is_using_isp(): - fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt" + fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt" else: fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" @@ -387,16 +380,17 @@ def save_optimizer_checkpoint(optim, state_path): tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - dp_rank = gpc.get_local_rank(ParallelMode.DATA) zero_size = gpc.get_world_size(ParallelMode.ZERO1) tp_size = gpc.get_world_size(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) dp_size = gpc.get_world_size(ParallelMode.DATA) states = optim.state_dict() if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): if is_using_isp(): - fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt" - llm_save(os.path.join(state_path, fp), states) + fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * wp_size: + llm_save(os.path.join(state_path, fp), states) else: fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * tp_size: diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py index a306ad70c..983beda9c 100644 --- a/internlm/core/context/__init__.py +++ b/internlm/core/context/__init__.py @@ -1,6 +1,5 @@ from .parallel_context import ( IS_REPLICA_ZERO_PARALLEL, - IS_TENSOR_DATA_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, @@ -32,7 +31,6 @@ __all__ = [ "Config", "IS_TENSOR_ZERO_PARALLEL", - "IS_TENSOR_DATA_PARALLEL", "IS_REPLICA_ZERO_PARALLEL", "IS_WEIGHT_ZERO_PARALLEL", "IS_TENSOR_EXPERT_DATA_PARALLEL", diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 6b23fdae6..de141c661 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -24,12 +24,13 @@ from .process_group_initializer import ParallelMode from .random import add_seed, get_seeds, set_mode +# for layernorm IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel" -# for isp, with optimizer split in dp group -IS_TENSOR_DATA_PARALLEL = "is_tensor_data_parallel" -# for mtp/msp/fsp, with optimizer split in zero1 group +# for mtp/msp/fsp with tensor parallel, and optimizer split in zero1 group IS_TENSOR_ZERO_PARALLEL = "is_tensor_zero_parallel" +# for isp with weight parallel, and optimizer split in zero1 group IS_WEIGHT_ZERO_PARALLEL = "is_weight_zero_parallel" +# for moe IS_TENSOR_EXPERT_DATA_PARALLEL = "is_tensor_expert_data_parallel" logger = get_logger(__file__) @@ -564,6 +565,7 @@ def init_parallel_groups(self): initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Data(*initializer_args)) + initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args)) if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp": initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args)) else: diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 5519c7d84..bbf319187 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -60,6 +60,9 @@ class ParallelMode(Enum): # sequence parallel SEQUENCE = "sequence" + # real data parallel for isp + ISP_DATA = "isp_data" + # grouped query attention GQA = "gqa" @@ -854,6 +857,66 @@ def init_dist_group(self, use_cpu: bool = False): return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode +class Initializer_ISP_Data(ProcessGroupInitializer): + """A ProcessGroupInitializer for real data parallel group in isp. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + weight_parallel_size (int): Size of model weight parallel. + weight_data_parallel_size (int): Size of data parallel for common weight. + sequence_parallel_size (int): Size of data sequence parallel. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + zero1_parallel_size (int): Size of zero1 parallel. + nettest_parallel_size (int): Size of net testing parallel. + expert_parallel_size (int): Size of expert parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.isp_data_parallel_size = self.tensor_parallel_size * self.data_parallel_size + self.num_isp_data_parallel_group = self.world_size // self.isp_data_parallel_size + + assert self.world_size % self.isp_data_parallel_size == 0 + + def init_dist_group(self, use_cpu: bool = False): + """Initialize real data parallel groups for isp, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A real data parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.ISP_DATA + + for i in range(self.num_isp_data_parallel_group): + ranks = [i * self.isp_data_parallel_size + j for j in range(self.isp_data_parallel_size)] + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) + if use_cpu: + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) + else: + group_cpu = None + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + class Initializer_GQA(ProcessGroupInitializer): """A ProcessGroupInitializer for allreduce kv gradients with common attention head. diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 71dde3a35..ca3e55075 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -18,9 +18,12 @@ from internlm.core.parallel.comm.utils import ( DUMMY_HANDLE_CONST, AsyncCommHandle, + _gather, all_gather_raw, + apply_to_tensors_only, reduce_scatter_raw, ) +from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import ParallelLinearWithCommExt from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.utils import ( @@ -59,6 +62,129 @@ def grad_hook(self, tensor: torch.Tensor, async_op: bool = False, **kwargs) -> T pass +class HeadWeightParallelCommunicator(WPCommunicator): + """ + Weight parallel communicator for Head module. + """ + + def __init__(self, process_group: dist.ProcessGroup = None) -> None: + self.process_group = process_group + + def communication_mode(self) -> str: + return "wp" + + def weight_hook( + self, + tensor: torch.Tensor, + async_op: bool = False, + module: nn.Module = None, # pylint: disable=W0613 + is_bias: bool = False, # pylint: disable=W0613 + ) -> torch.Tensor: + if dist.get_world_size(self.process_group) <= 1: + return tensor + + result, _ = all_gather_raw(tensor, self.process_group, async_op=async_op) + return result + + def grad_hook( + self, + tensor: torch.Tensor, + async_op: bool = False, + module: nn.Module = None, # pylint: disable=W0613 + reduce_op: dist.ReduceOp = dist.ReduceOp.AVG, + is_bias: bool = False, # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + if dist.get_world_size(self.process_group) <= 1: + return tensor, DUMMY_HANDLE_CONST + + result, handle = reduce_scatter_raw(tensor, self.process_group, op=reduce_op, async_op=async_op) + return result, handle + + +class EmbeddingWeightParallelCommunicator: + """ + Weight parallel communicator for embedding layer. + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self.parallel_mode = parallel_mode + self.emb_column = 1 + + self._cur_micro_step = 0 + self._num_micro_step = gpc.config.data.micro_num + + def register_module_hook(self, module: Embedding1D) -> None: + assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D" + + module.weight.evo_tensor = None + + class PreModuleWrapper(torch.autograd.Function): + """ + Wrapper pre module to prefetch module weight for forward pass. + """ + + @staticmethod + def forward(ctx, inputs: torch.Tensor): # pylint: disable=W0613 + if module.weight.evo_tensor is None: + module.weight.evo_tensor = module.weight.data + + module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column) + inputs = inputs.detach() + return inputs + + @staticmethod + def backward(ctx: Any, grad_input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0613 + # since input of embedding is int64 dtype, requires_grad=False, the backward fn may not be called + module.weight.data = module.weight.evo_tensor + return grad_input + + class PostModuleWrapper(torch.autograd.Function): + """ + Wrapper post module to prefetch module weight for backward pass. + """ + + @staticmethod + def forward(ctx, output: torch.Tensor): # pylint: disable=W0613 + module.weight.data = module.weight.evo_tensor + output = output.detach() + return output + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: # pylint: disable=W0613 + module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column) + return grad_output + + def _pre_forward_hook(module, inputs): # pylint: disable=W0613 + return apply_to_tensors_only(PreModuleWrapper.apply, inputs) + + def _post_forward_hook(module, inputs, output): # pylint: disable=W0613 + return apply_to_tensors_only(PostModuleWrapper.apply, output) + + module.register_forward_pre_hook(_pre_forward_hook) + module.register_forward_hook(_post_forward_hook) + + module.weight.register_post_accumulate_grad_hook(self.grad_reduce_hook) + + def grad_reduce_hook(self, param: torch.Tensor): + + _grad, _ = reduce_scatter_raw( + param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column + ) + if param.evo_tensor.grad is None: + param.evo_tensor.grad = _grad + else: + param.evo_tensor.grad += _grad + + param.data = param.evo_tensor + param.grad = None + + self._cur_micro_step += 1 + if self._cur_micro_step == self._num_micro_step: + param.grad = param.evo_tensor.grad + param.evo_tensor.grad = None + self._cur_micro_step = 0 + + class ISPCommModelConfig: """ model config for isp communicator. diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index ca8c19003..2dfc8bd28 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -322,7 +322,7 @@ def output_hook(self, module: MoE, args: Any, output: Tuple[Any]) -> Tuple[Any]: return (_output, *_others) -class EmbbedingTensorParallelCommunicator: +class EmbeddingTensorParallelCommunicator: """ tensor parallel communicator for embbeding layer """ @@ -344,7 +344,7 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) -class EmbbedingSequenceParallelCommunicator: +class EmbeddingSequenceParallelCommunicator: """ sequence parallel communictor for embbeding layer """ diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py index dbfeb3fda..aec3385a4 100644 --- a/internlm/core/parallel/comm/utils.py +++ b/internlm/core/parallel/comm/utils.py @@ -224,3 +224,36 @@ def reduce_scatter_raw( handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op) return output, handle + + +def apply_to_tensors_only(function, value): + """ + Apply `function` to every Tensor in `value`. + + Args: + functional: The function class to apply. + value (Any): Target object to apply `function` to. + + Returns: + Any: Output of `function`. + """ + if isinstance(value, (tuple, list)): + touched_outputs = [] + for elem in value: + touched_output = apply_to_tensors_only(function, elem) + touched_outputs.append(touched_output) + + return value.__class__(touched_outputs) + elif isinstance(value, dict): + # apply inplace to avoid recreating dict inherited objects + for key in value.keys(): + value[key] = apply_to_tensors_only(function, value[key]) + return value + + elif isinstance(value, torch.Tensor): + # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter + touched_output = function(value) + + return touched_output + else: + return value diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 33c187ec5..d0990b8cb 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -9,17 +9,37 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm.utils import _split from internlm.utils.logger import get_logger logger = get_logger(__file__) +def split_data_sequence_parallel(data, label): + _seq_dim = 1 # [batch, seqlen, ...] + _indexes_seq_dim = 0 # [seqlen, ...] + + # NOTICE: since cu_seqlens is used by attention, it should not be splited. + # NOTICE: indexes are only used by rotary embedding. There are a few cases: + # 1. msp/fsp: After wqkv computation, the hidden states are complete along the sequence dimension, + # so we should use the complete indexes when computing the rotary embedding. + # 2. isp: After wqkv computation, the hidden states are segmented along the sequence dimension, + # so we need to segment the indexes accordingly. + if "indexes" in data and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp": + data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim) + + data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim) + label = _split(label, ParallelMode.TENSOR, dim=_seq_dim) + + return data, label + + # The head layer in ISP mode is actually a special case, # and we would prefer a unified segmentation and communication logic. -def get_tensor_split_parallel_mode(is_head: bool = False) -> ParallelMode: +def get_tensor_split_parallel_mode() -> ParallelMode: tp_mode = gpc.config.parallel.tensor.mode - if tp_mode == "isp" and is_head is False: + if tp_mode == "isp": return ParallelMode.WEIGHT else: return ParallelMode.TENSOR diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 8933a5df6..9e2c13dff 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -37,7 +37,7 @@ from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.parallel import get_parallel_log_file_name, is_using_isp from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler from internlm.utils.writer import Writer @@ -137,10 +137,12 @@ def __init__( ) # initialize metric for calculating accuracy and perplexity + _dp_pg = gpc.get_group(ParallelMode.ISP_DATA) if is_using_isp() else gpc.get_group(ParallelMode.DATA) + _tp_pg = dist.new_group([gpc.get_global_rank()]) if is_using_isp() else gpc.get_group(ParallelMode.TENSOR) metric = AccPerplex( device=get_current_device(), - tp_pg=gpc.get_group(ParallelMode.TENSOR), - dp_pg=gpc.get_group(ParallelMode.DATA), + tp_pg=_tp_pg, + dp_pg=_dp_pg, dataset_types=kwargs["dataset_types"], ) diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 19e74ae2d..d282a865b 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -5,9 +5,7 @@ import torch -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.parallel.comm.utils import _split def get_dataset_type_ids_map(path): @@ -73,10 +71,6 @@ def packed_data_normalizer(data, label): data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0) data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() - # Move to parallel package for standardization - if gpc.config.parallel.sequence_parallel and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp": - data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=0) - if gpc.config.model_type == "hf": data.pop("cu_seqlens") data.pop("max_seqlen") diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index 7e440528a..3881703c1 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -15,6 +15,7 @@ from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler +from internlm.core.parallel.shard import split_data_sequence_parallel from internlm.core.scheduler import ( InterleavedPipelineScheduler, NonPipelineScheduler, @@ -26,6 +27,7 @@ from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler from internlm.utils.common import SchedulerHook, get_current_device +from internlm.utils.parallel import is_using_isp def initialize_trainer( @@ -74,7 +76,24 @@ def initialize_trainer( # initialize scheduler for trainer scheduler = None - data_fn = packed_data_normalizer if gpc.config.data.use_packed_dataset else unpack_data + data_fns = [] + # default data process function + if gpc.config.data.use_packed_dataset: + data_fns.append(packed_data_normalizer) + else: + data_fns.append(unpack_data) + + # support sequence parallel for isp + if is_using_isp(): + data_fns.append(split_data_sequence_parallel) + + # TODO: support context parallel + + def _data_preparation_func(_data, _label): + for fn in data_fns: + _data, _label = fn(_data, _label) + + return _data, _label if gpc.is_using_parallel_mode(ParallelMode.PIPELINE): gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num @@ -89,7 +108,7 @@ def initialize_trainer( communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False) scheduler = InterleavedPipelineScheduler( - data_process_func=data_fn, + data_process_func=_data_preparation_func, num_microbatches=gpc.config.NUM_MICRO_BATCHES, num_chunks=gpc.config.model.num_chunks, dtype=gpc.config.model["dtype"], @@ -100,7 +119,7 @@ def initialize_trainer( ) else: scheduler = PipelineScheduler( - data_process_func=data_fn, + data_process_func=_data_preparation_func, num_microbatches=gpc.config.NUM_MICRO_BATCHES, dtype=gpc.config.model["dtype"], tensor_shape=tensor_shape, @@ -109,7 +128,7 @@ def initialize_trainer( ) else: scheduler = NonPipelineScheduler( - data_process_func=data_fn, + data_process_func=_data_preparation_func, gradient_accumulation_size=gpc.config.data.gradient_accumulation, scheduler_hooks=scheduler_hooks, ) diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 54cc41ba2..6f887259d 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -3,11 +3,13 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.ops.cross_entropy import new_cross_entropy from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import is_using_isp try: from torch_scatter import scatter as cuda_scatter @@ -114,6 +116,10 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str def set_current_type_ids(self, type_ids: torch.Tensor): self.batch_shift = 0 + if is_using_isp(): + step_seqlen = type_ids.shape[1] // gpc.get_world_size(ParallelMode.TENSOR) + sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + type_ids = type_ids[:, step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)] self.type_ids = type_ids.to(get_current_device()) def set_cu_seqlens(self, cu_seqlens: List): @@ -295,6 +301,12 @@ def update(self, logits, labels, type_ids=None): loss_list = self.loss_fn(logits, labels) + # get current rank part loss_list + if is_using_isp(): + step_seqlen = logits.shape[0] + sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + loss_list = loss_list[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)] + cond = labels != -100 real_loss_list = loss_list[cond] self.loss += real_loss_list.sum() diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index fa922daaa..b85295af2 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -10,6 +10,7 @@ from internlm.core.context import global_context as gpc from internlm.model.ops.rotary_emb import apply_rotary_emb +from internlm.utils.parallel import is_using_isp class Embedding1D(nn.Module): @@ -42,7 +43,9 @@ def __init__( self.embed_args = args self.embed_kwargs = kwargs - embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size + _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size + + embed_dim_per_partition = embedding_dim // _parallel_size self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype)) def forward(self, input_: Tensor) -> Tensor: diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 0d8c4bf82..820df33be 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -455,7 +455,7 @@ def __init__( if norm_head: logger.info("Notice that norm head is enabled to normalize head weight.") - parallel_mode = get_tensor_split_parallel_mode(is_head=True) + parallel_mode = get_tensor_split_parallel_mode() super().__init__( in_features, out_features, parallel_mode, bias=bias, device=device, dtype=dtype, split_mode="column" ) diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py index f3fdccf96..3567f16a6 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/ops/cross_entropy.py @@ -6,12 +6,14 @@ This file implements support for the cross entropy operators. """ +import torch from torch import nn from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp try: from flash_attn.losses.cross_entropy import ( @@ -26,6 +28,130 @@ internlm_accelerator = get_accelerator() +class _VocabSequenceParallelCrossEntropy(torch.autograd.Function): + """ + Cross Entropy module for isp. + """ + + @staticmethod + def forward(ctx, vocab_seq_parallel_logits, target, reduction, label_smoothing=0.0): # pylint: disable=W0613 + sp_size = gpc.get_world_size(ParallelMode.TENSOR) + + # reshape + # vocab_seq_parallel_logits: [B * (S/P), V] -> [B, S/P, V] + # target: [B * S/P] -> [B, S/P] + vocab_seq_parallel_logits = vocab_seq_parallel_logits.view( + -1, gpc.config.data.seq_len // sp_size, gpc.config.VOCAB_SIZE + ) + target = target.view(-1, gpc.config.data.seq_len // sp_size) + + # transpose + # vocab_seq_parallel_logits: [B, S/P, V] -> [S/P, B, V] + # target: [B, S/P] -> [S/P, B] + # return: [S, B] + vocab_seq_parallel_logits = vocab_seq_parallel_logits.transpose(0, 1).contiguous() + target = target.transpose(0, 1).contiguous() + + ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_size + batch_size = vocab_seq_parallel_logits.size(1) + + # Need softmax for backward + softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1) + ctx.vocab_size = vocab_seq_parallel_logits.size(2) + loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction="none") + + loss_all = torch.empty( + ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device + ) + + torch.distributed.all_gather_into_tensor(loss_all, loss, group=gpc.get_group(ParallelMode.TENSOR)) + + # [s b] => [b, s] + loss_all = loss_all.transpose(0, 1).contiguous() + + ctx.save_for_backward(softmax, target) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + softmax, target = ctx.saved_tensors + + # transpose + grad_output = grad_output.transpose(0, 1).contiguous() + + step_seqlen = ctx.seqlen // gpc.get_world_size(ParallelMode.TENSOR) + sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1), :] + + grad_input = softmax + grad_2d = grad_input.view(-1, ctx.vocab_size) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + grad_2d[arange_1d, target.view(-1)] -= 1 + grad_input.mul_(grad_output_part.unsqueeze(dim=-1)) + + # transpose + grad_input = grad_input.transpose(0, 1).contiguous() + # reshape + grad_input = grad_input.view(-1, gpc.config.VOCAB_SIZE) + + return grad_input, None, None + + +def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): + return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) + + +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, group=gpc.get_group(ParallelMode.DATA)) + averaged_losses = averaged_losses / gpc.get_world_size(ParallelMode.DATA) + + return averaged_losses + + +class VocabSequenceParallelCrossEntropyLoss(nn.Module): + """ + Cross Entropy module for isp. + """ + + def __init__( + self, + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0, + process_group=None, + ): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.process_group = process_group + + def loss_mean_func(self, output_tensor): + losses = output_tensor.float() + loss = torch.sum(losses.view(-1)) / losses.numel() + + # TODO: allreduce loss in dp group + + return loss + + def forward(self, _input, target): + assert _input.is_cuda and target.is_cuda + + _loss_list = vocab_sequence_parallel_cross_entropy(_input, target, self.label_smoothing) + + if self.reduction == "mean": + loss = self.loss_mean_func(_loss_list) + return loss + + return _loss_list.view(-1) + + # TODO: ops是否需要实现更加统一的形式 def new_cross_entropy( ignore_index: int = -100, @@ -34,6 +160,16 @@ def new_cross_entropy( parallel_output: bool = False, **kwargs, ): + if is_using_isp(): + if gpc.is_rank_for_log(): + logger.warning("Use VocabSequenceParallelCrossEntropyLoss.") + return VocabSequenceParallelCrossEntropyLoss( + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + process_group=gpc.get_group(ParallelMode.TENSOR), + ) + if parallel_output: assert ( gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 5461f9228..ddfad64a4 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -15,7 +15,6 @@ from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import ( IS_REPLICA_ZERO_PARALLEL, - IS_TENSOR_DATA_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, @@ -165,7 +164,7 @@ def __init__( if self._is_moe_group(param_group): grad_reduce_mode = ParallelMode.EXPERT_DATA - elif param_group["name"] != "embed_head" and self.use_isp: + elif self.use_isp: grad_reduce_mode = ParallelMode.WEIGHT_DATA else: grad_reduce_mode = ParallelMode.DATA @@ -344,8 +343,9 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 # get the AccumulateGrad object of the param itself # If these objects are not kept, reduction hooks may not be attached successfully. - accum_grad_obj = get_grad_accumulate_object(param) - self._grad_store.add_accumulate_grad_object(accum_grad_obj) + if not hasattr(param, "evo_tensor"): + accum_grad_obj = get_grad_accumulate_object(param) + self._grad_store.add_accumulate_grad_object(accum_grad_obj) # the grad of layernorm should be all-reduce across the global process group # here is the first stage all-reduce in tp/wp process group @@ -364,10 +364,16 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 and self._isp_communicator.overlap and gpc.config.parallel.weight.size > 1 ): - accum_grad_obj.register_hook(accum_grad_hook) + if hasattr(param, "evo_tensor"): + param.register_post_accumulate_grad_hook(accum_grad_hook) + else: + accum_grad_obj.register_hook(accum_grad_hook) if self._overlap_sync_grad: - accum_grad_obj.register_hook(reduce_grad_hook) + if hasattr(param, "evo_tensor"): + param.register_post_accumulate_grad_hook(reduce_grad_hook) + else: + accum_grad_obj.register_hook(reduce_grad_hook) _define_and_attach(param, reduce_rank) @@ -619,10 +625,6 @@ def _compute_norm(self, group_id: int = 0): elif self.optim.param_groups[group_id]["name"] == "fp32": for param in params: setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - elif self.optim.param_groups[group_id]["name"] == "embed_head": - # should be isp mode - for param in params: - setattr(param, IS_TENSOR_DATA_PARALLEL, True) elif self._is_moe_group(self.optim.param_groups[group_id]): for param in params: setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True) @@ -638,8 +640,6 @@ def _compute_norm(self, group_id: int = 0): for param in params: if hasattr(param, IS_REPLICA_ZERO_PARALLEL): delattr(param, IS_REPLICA_ZERO_PARALLEL) - if hasattr(param, IS_TENSOR_DATA_PARALLEL): - delattr(param, IS_TENSOR_DATA_PARALLEL) if hasattr(param, IS_TENSOR_ZERO_PARALLEL): delattr(param, IS_TENSOR_ZERO_PARALLEL) if hasattr(param, IS_WEIGHT_ZERO_PARALLEL): diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index d231f407f..eab75b6ad 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -11,7 +11,6 @@ from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import ( IS_REPLICA_ZERO_PARALLEL, - IS_TENSOR_DATA_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, @@ -280,10 +279,6 @@ def add_attr_for_splited_param(self, origin_param, splited_param_current_rank): value = getattr(origin_param, IS_REPLICA_ZERO_PARALLEL) setattr(splited_param_current_rank, IS_REPLICA_ZERO_PARALLEL, value) - if hasattr(origin_param, IS_TENSOR_DATA_PARALLEL): - value = getattr(origin_param, IS_TENSOR_DATA_PARALLEL) - setattr(splited_param_current_rank, IS_TENSOR_DATA_PARALLEL, value) - if hasattr(origin_param, IS_TENSOR_EXPERT_DATA_PARALLEL): value = getattr(origin_param, IS_TENSOR_EXPERT_DATA_PARALLEL) setattr(splited_param_current_rank, IS_TENSOR_EXPERT_DATA_PARALLEL, value) diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 9279c1138..574e7cabf 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -16,7 +16,6 @@ from internlm.utils.logger import get_logger from internlm.utils.parallel import ( is_replica_zero_parallel_parameter, - is_tensor_data_parallel_parameter, is_tensor_expert_data_parallel_parameter, is_tensor_zero_parallel_parameter, is_using_isp, @@ -241,9 +240,6 @@ def reduce_grads(gradients, parameters, weight_parallel_mode): is_replica_zero_parallel_parameter(p) and gpc.get_local_rank(weight_parallel_mode) == 0 ): # if not used in each chunk, such as layernorm IS_REPLICA_ZERO_PARALLEL parameter group parallel_grads.append(g.data.float()) - elif is_tensor_data_parallel_parameter(p): - # process all ranks for IS_TENSOR_DATA_PARALLEL parameter group - parallel_grads.append(g.data.float()) elif is_tensor_zero_parallel_parameter(p): # process all ranks for IS_TENSOR_ZERO_PARALLEL parameter group parallel_grads.append(g.data.float()) @@ -283,10 +279,7 @@ def compute_norm(gradients, parameters, norm_type=2, zero_mode=ParallelMode.ZERO total_norm_cuda = torch.tensor([float(total_norm)], device=gradients[0].device, dtype=torch.float32) # Take max across all model-parallel GPUs. - if is_tensor_data_parallel_parameter(parameters[0]): - if gpc.is_using_parallel_mode(ParallelMode.TENSOR): - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.TENSOR)) - elif is_tensor_zero_parallel_parameter(parameters[0]): + if is_tensor_zero_parallel_parameter(parameters[0]): if gpc.is_using_parallel_mode(ParallelMode.TENSOR): dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.TENSOR)) else: @@ -322,17 +315,12 @@ def compute_norm(gradients, parameters, norm_type=2, zero_mode=ParallelMode.ZERO Sum across all model-parallel GPUs. 1. For the IS_REPLICA_ZERO_PARALLEL parameter group, gradients from rank 0 in the tp/wp process group and gradients along the pp+zero dimensions from all ranks should be aggregated. - 2. For the IS_TENSOR_DATA_PARALLEL parameter group, gradients along the tp+pp+zero(dp) dimensions + 2. For the IS_TENSOR_ZERO_PARALLEL parameter group, gradients along the tp+pp+zero dimensions from all ranks should be aggregated. - 3. For the IS_TENSOR_ZERO_PARALLEL parameter group, gradients along the tp+pp+zero dimensions - from all ranks should be aggregated. - 4. For the IS_WEIGHT_ZERO_PARALLEL parameter group, gradients along the wp+pp+zero dimensions + 3. For the IS_WEIGHT_ZERO_PARALLEL parameter group, gradients along the wp+pp+zero dimensions from all ranks should be aggregated. """ - if is_tensor_data_parallel_parameter(parameters[0]): - if gpc.is_using_parallel_mode(ParallelMode.TENSOR): - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) - elif is_tensor_zero_parallel_parameter(parameters[0]): + if is_tensor_zero_parallel_parameter(parameters[0]): if gpc.is_using_parallel_mode(ParallelMode.TENSOR): dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) else: diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 0c5156615..79aab6b84 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -12,7 +12,6 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ( IS_REPLICA_ZERO_PARALLEL, - IS_TENSOR_DATA_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, @@ -26,13 +25,15 @@ unwrap_naive_amp, ) from internlm.core.parallel.comm.isp import ( + EmbeddingWeightParallelCommunicator, + HeadWeightParallelCommunicator, ISPCommModelConfig, ISPCommunicator, ISPCommunicatorSchedulerHook, ) from internlm.core.parallel.comm.tensor import ( - EmbbedingSequenceParallelCommunicator, - EmbbedingTensorParallelCommunicator, + EmbeddingSequenceParallelCommunicator, + EmbeddingTensorParallelCommunicator, HeadSequenceParallelCommunicator, HeadTensorParallelCommunicator, LinearRole, @@ -77,7 +78,6 @@ from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import ( is_replica_zero_parallel_parameter, - is_tensor_data_parallel_parameter, is_tensor_expert_data_parallel_parameter, is_tensor_zero_parallel_parameter, is_using_isp, @@ -118,11 +118,10 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) # embedding and head - if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)): for param in module.parameters(): - if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): - setattr(param, IS_TENSOR_DATA_PARALLEL, True) + if gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): + setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): setattr(param, IS_TENSOR_ZERO_PARALLEL, True) @@ -149,7 +148,7 @@ def _check_module_hf(_, module): # TODO: check parallel attribute for hf model for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): - setattr(param, IS_TENSOR_DATA_PARALLEL, True) + setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): setattr(param, IS_TENSOR_ZERO_PARALLEL, True) @@ -164,7 +163,6 @@ def _check_module_hf(_, module): for name, param in _chunk.named_parameters(): assert ( is_replica_zero_parallel_parameter(param) - or is_tensor_data_parallel_parameter(param) or is_tensor_zero_parallel_parameter(param) or is_weight_zero_parallel_parameter(param) or is_tensor_expert_data_parallel_parameter(param) @@ -273,8 +271,8 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): ColumnParallelLinear.register_cls_communicator(isp_communicator) # row parallel linear will not be used. RowParallelLinear.register_cls_communicator(None) - _head_communicator = HeadSequenceParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) - _embedding_communicator = EmbbedingSequenceParallelCommunicator(ParallelMode.TENSOR) + _head_communicator = HeadWeightParallelCommunicator(gpc.get_group(ParallelMode.WEIGHT)) + _embedding_communicator = EmbeddingWeightParallelCommunicator(ParallelMode.WEIGHT) # register communictor for mtp/msp/fsp linear. @@ -287,7 +285,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) ) _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) - _embedding_communicator = EmbbedingTensorParallelCommunicator(ParallelMode.TENSOR) + _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR) # sequence parallel if gpc.config.parallel.tensor.mode in ("msp", "fsp"): save_total_input_as_activation = gpc.config.parallel.tensor.mode == "msp" @@ -310,7 +308,8 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): _head_communicator = HeadSequenceParallelCommunicator( ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation ) - _embedding_communicator = EmbbedingSequenceParallelCommunicator(ParallelMode.TENSOR) + + _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR) # MoE sequence parallel if gpc.config.model.get("num_experts", 1) > 1: diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 2f3034434..805dc6536 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -7,7 +7,6 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import unwrap_naive_amp from internlm.model.modules.utils import is_moe_param -from internlm.utils.parallel import is_tensor_data_parallel_parameter, is_using_isp def split_params_into_different_groups_for_optimizer( @@ -39,10 +38,7 @@ def split_params_into_different_groups_for_optimizer( elif not isinstance(param_groups, list): raise ValueError(f"Unknown param group type of {type(param_groups)}") - # create new groups for IS_TENSOR_DATA_PARALLEL parameter group new_groups = {} - if is_using_isp(): - new_groups["embed_head"] = {"name": "embed_head", "params": [], "optimizer_mode": ParallelMode.DATA} # create new groups for fp32 parameter group new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1} @@ -60,11 +56,8 @@ def split_params_into_different_groups_for_optimizer( # assign param origin_params = [] for param in pgroup["params"]: - if is_tensor_data_parallel_parameter(param): - # should not be here if not isp mode - new_groups["embed_head"]["params"].append(param) # moe param means MoE is enabled - elif is_moe_param(param): + if is_moe_param(param): new_groups[param.group_name]["params"].append(param) elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32: new_groups["fp32"]["params"].append(param) diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 1b92974d8..a793f0f48 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -5,7 +5,6 @@ from internlm.core.context import ( IS_REPLICA_ZERO_PARALLEL, - IS_TENSOR_DATA_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, @@ -30,15 +29,6 @@ def is_replica_zero_parallel_parameter(p): return hasattr(p, IS_REPLICA_ZERO_PARALLEL) and getattr(p, IS_REPLICA_ZERO_PARALLEL) -def is_tensor_data_parallel_parameter(p): - return ( - gpc.is_initialized(ParallelMode.TENSOR) - and is_using_isp() - and hasattr(p, IS_TENSOR_DATA_PARALLEL) - and getattr(p, IS_TENSOR_DATA_PARALLEL) - ) - - def is_tensor_zero_parallel_parameter(p): return ( gpc.is_initialized(ParallelMode.TENSOR) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index fa8147cde..37123588e 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -185,12 +185,7 @@ def train( ckpt_manager.try_resume_training(train_state, current_time) # initialize metric for calculating accuracy and perplexity - metric = AccPerplex( - device=get_current_device(), - tp_pg=gpc.get_group(ParallelMode.TENSOR), - dp_pg=gpc.get_group(ParallelMode.DATA), - dataset_types=dataset_types, - ) + metric = None # initialize trainer engine, scheduler = internlm.initialize_trainer( @@ -240,7 +235,7 @@ def train( trainer.zero_grad() # process data if batch[0].get("type_ids", None) is not None: - metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + batch[0].pop("type_ids", None) # do forward and backward timer("fwd-bwd").start() @@ -296,6 +291,7 @@ def check_loss_spike(): def check_loss_accuracy(): if gpc.is_rank_for_log(): + print(f"cur_loss_list:{cur_loss_list}", flush=True) for cur, target in zip(cur_loss_list, BASELINE_LOSS_LIST): assert ( abs(cur - target) < LOSS_DEVIATION_LIMIT @@ -451,17 +447,18 @@ def test_training_with_isp(): global CONFIG_FILE_PATH, BASELINE_LOSS_LIST CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" BASELINE_LOSS_LIST = [ - 11.594964981079102, - 8.874114990234375, - 7.090242385864258, - 6.782063961029053, - 5.961512088775635, - 5.606202125549316, - 5.305666446685791, - 5.0156569480896, - 4.9411516189575195, - 4.983800411224365, + 10.711931228637695, + 7.549415588378906, + 6.495877742767334, + 5.944756507873535, + 5.246580123901367, + 5.334012031555176, + 4.999225616455078, + 4.70023250579834, + 4.591017723083496, + 4.589826583862305, ] + # model training train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True) From 57b7cd5bd87f26e63fd3f3bd289d6b5f71496064 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Thu, 18 Jul 2024 13:49:30 +0800 Subject: [PATCH 12/12] feat(op): support varlen npu flash attention (#209) --- internlm/core/parallel/comm/isp.py | 7 +- internlm/model/ops/attention.py | 224 +++++++++++--- .../test_npu_ops/test_flash_attention.py | 292 +++++++++++------- 3 files changed, 363 insertions(+), 160 deletions(-) diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index ca3e55075..5406dc7e4 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -900,7 +900,12 @@ def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, floa def _attetion_constructor( local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0 ) -> nn.Module: - if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp": + try: + tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") + except AttributeError: + tp_mode = "mtp" + + if tp_mode != "isp": return local_attn_cls(causal, softmax_scale, attention_dropout) else: return DistributedAttention( diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 9205652aa..9bb538063 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -79,7 +79,7 @@ def _nyi_attn(func_name, *args, **kwargs): # pylint: disable=W0613 def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs): if gpc.config.model.dtype is torch.float32: - inputs = (args[idx] for idx in input_idxs) + inputs = [args[idx] for idx in input_idxs] input_dtype = inputs[0].dtype other_args = [args[idx] for idx in range(len(inputs), len(args))] @@ -194,10 +194,35 @@ def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None, # npu flash attention operators -# TODO: should we add _flash_float32_compatibility_wrapper support for npu. +def _npu_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + return _flash_float32_compatibility_wrapper( + (0, 1, 2), + _npu_varlen_qkvsplited_func, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + ) -def _npu_varlen_qkvsplited_attn( +def _npu_varlen_qkvsplited_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -208,17 +233,32 @@ def _npu_varlen_qkvsplited_attn( dropout_p=0.0, softmax_scale=None, causal=False, + use_fixlen=False, ): - # TODO: support npu native varlen flash attention + """Support Huawei Ascend's torch_npu flash attention. + Tested version: + torch: 2.1.0+cpu + torch_npu: 2.1.0.post3+git7c4136d + cann: 8.0.RC1.alpha003 + """ packed_length = q.size(dim=1) + softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) - k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k) - v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k) + if use_fixlen: - output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k) + v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k) - return pack_output_after_attn(output, cu_seqlens_q, packed_length) + output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + output = pack_output_after_attn(output, cu_seqlens_q, packed_length) + else: + output = _npu_fused_varlen_qkvsplited_attn( + q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k + ) + + return output def _npu_fixedlen_qkvsplited_attn( @@ -236,6 +276,7 @@ def _npu_fixedlen_qkvsplited_attn( q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2) _, seqlen, n_head, _ = q.shape + sparse_mode = 0 attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool() return _origin_npu_fixedlen_qkvsplited_func( @@ -247,25 +288,71 @@ def _npu_fixedlen_qkvsplited_attn( pse=None, atten_mask=attention_mask, scale=softmax_scale, - sparse_mode=0, # If necessary, expose the interface + sparse_mode=sparse_mode, # If necessary, expose the interface pre_tockens=seqlen, # Used for sparse calculations, representing the left boundary of the slides window next_tockens=0, # If necessary, expose the interface keep_prob=1 - dropout_p, inner_precise=0, # If necessary, expose the interface - ) + )[0] -def _npu_varlen_qkvpacked_attn( - qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False # pylint: disable=W0613 +def _npu_fused_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale=None, + causal=False, + max_seqlen_q: int = None, + max_seqlen_k: int = None, + cu_seqlens_q=None, + cu_seqlens_kv=None, + deterministic=False, ): - # TODO: support npu native varlen flash attention - packed_length = qkv.size(dim=1) + assert causal is True + assert q.dtype in (torch.bfloat16, torch.float16) - qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens) + if len(q.shape) == 4: # [1, packedseqlen, n_head, headdim] + q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0) - output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal) + S, N = max(max_seqlen_q, max_seqlen_k), q.shape[1] + device = get_current_device() + sparse_mode = 0 - return pack_output_after_attn(output, cu_seqlens, packed_length) + if max_seqlen_k > 2048 and max_seqlen_q > 2048: + sparse_mode = 2 + max_seqlen_k = 2048 + max_seqlen_q = 2048 + + attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k, device=device), 1).bool() + cu_seqlens_q = cu_seqlens_q[1:].tolist() + cu_seqlens_kv = cu_seqlens_kv[1:].tolist() + + return _origin_npu_fixedlen_qkvsplited_func( + query=q, + key=k, + value=v, + head_num=N, + input_layout="TND", + pse=None, + atten_mask=attention_mask, + scale=softmax_scale, + sparse_mode=sparse_mode, + pre_tockens=S, # Used for sparse calculations, representing the left boundary of the slides window + next_tockens=0, + keep_prob=1 - dropout_p, + inner_precise=0 if not deterministic else 2, + actual_seq_kvlen=cu_seqlens_kv, + actual_seq_qlen=cu_seqlens_q, + )[0].unsqueeze(dim=0) + + +def _npu_varlen_qkvpacked_attn( + qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False # pylint: disable=W0613 +): + # TODO: support npu native varlen flash attention + q, k, v = qkv.unbind(dim=2) + return _npu_varlen_qkvsplited_attn(q, k, v, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal) def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False): @@ -285,14 +372,20 @@ def _npu_varlen_kvpacked_attn( causal=False, ): # TODO: support npu native varlen flash attention - packed_length = q.size(dim=1) - - q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) - kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k) - - output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal) - - return pack_output_after_attn(output, cu_seqlens_q, packed_length) + k, v = kv.unbind(dim=2) + k, v = k.squeeze(dim=2), v.squeeze(dim=2) + return _npu_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + ) def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False): @@ -335,12 +428,6 @@ def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs): # torch attention operators - - -def _torch_varlen_qkvpacked_attn(*args, **kwargs): - _nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs) - - # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None): batch_size, seqlen = qkv.shape[0], qkv.shape[1] @@ -369,10 +456,6 @@ def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=Non return output -def _torch_varlen_kvpacked_attn(*args, **kwargs): - _nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs) - - # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py def _torch_fixedlen_kvpacked_attn( q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None @@ -407,10 +490,6 @@ def _torch_fixedlen_kvpacked_attn( return output -def _torch_varlen_qkvsplited_attn(*args, **kwargs): - _nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs) - - def _torch_fixedlen_qkvsplited_attn( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None ): @@ -418,6 +497,71 @@ def _torch_fixedlen_qkvsplited_attn( return _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask) +def _torch_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout, + softmax_scale=None, + causal=False, + key_padding_mask=None, +): + kv = torch.stack([k, v], dim=2) + packed_length = q.size(dim=1) + + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k) + + output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask) + + return pack_output_after_attn(output, cu_seqlens_q, packed_length) + + +def _torch_varlen_qkvpacked_attn( + qkv: torch.Tensor, + cu_seqlens, + max_seqlen, # pylint: disable=W0613 + dropout, + softmax_scale=None, + causal=False, + key_padding_mask=None, +): + + packed_length = qkv.size(dim=1) + qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens) + + output = _torch_fixedlen_qkvpacked_attn(qkv, dropout, softmax_scale, causal, key_padding_mask) + + return pack_output_after_attn(output, cu_seqlens, packed_length) + + +def _torch_varlen_kvpacked_attn( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout, + softmax_scale=None, + causal=False, + key_padding_mask=None, +): + + packed_length = q.size(dim=1) + + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k) + + output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask) + + return pack_output_after_attn(output, cu_seqlens_q, packed_length) + + @auto_wrap_distributed_attention class SelfAttention(nn.Module): """Implements scaled dot-product attention with optional softmax scaling. diff --git a/tests/test_model/test_npu_ops/test_flash_attention.py b/tests/test_model/test_npu_ops/test_flash_attention.py index 31a8ba61b..8ab300e55 100644 --- a/tests/test_model/test_npu_ops/test_flash_attention.py +++ b/tests/test_model/test_npu_ops/test_flash_attention.py @@ -3,6 +3,7 @@ """ import math +import random import pytest import torch @@ -11,152 +12,205 @@ from torch import nn from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.model.modules.multi_head_attention import ( - AscendFlashSelfAttention, - CrossAttention, - SelfAttention, -) +from internlm.core.context import Config +from internlm.core.context import global_context as gpc +from internlm.model.ops.attention import SelfAttention +from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn +from internlm.utils.common import get_current_device, set_random_seed HEAD_NUM = 32 HIDDEN_SZIE = 4096 -SEQ_LEN = 2048 +SEQ_LEN = [2048, 4096] MICRO_BSZ = 1 HEAD_DIM = HIDDEN_SZIE // HEAD_NUM VOCAB_SIZE = 32000 - +NUM_KV_HEAD_LIST = [8, 32] MICRO_BSZ_LIST = [1, 2] DTYPE_LIST = [torch.bfloat16, torch.float16] -NUM_KV_HEAD_LIST = [8, 32] -USE_PADDING = [True, False] internlm_accelerator = get_accelerator() -def check_mean_and_std(name, out1, out2): - named1_mean = out1.to(dtype=torch.float64).mean() - named1_std = out1.to(dtype=torch.float64).std() - named2_mean = out2.to(dtype=torch.float64).mean() - named2_std = out2.to(dtype=torch.float64).std() - check_statistic_equality(name, named1_mean, named2_mean, eq=True, is_mean=True) - check_statistic_equality(name, named1_std, named2_std, eq=True, is_mean=False) - - -def check_statistic_equality(name, value1, value2, eq=False, is_mean=True, threshold=1e-9): - if (abs(value1 - value2) < threshold) ^ eq: - if eq: - print( - f"On {name}, " - f"we have {'mean' if is_mean else 'std'}s of fa_out " - f"very {'close' if not eq else 'different'}, " - f"from :{value1} " - f"and :{value2}", - flush=True, - ) - else: - print( - f"On {name}, " - f"we have {'mean' if is_mean else 'std'}s of fa_out " - f"very {'close' if not eq else 'different'}, " - f"from :{value1} " - f"and :{value2}", - flush=True, - ) - - -def do_cmp_attn( - name, - B, # pylint: disable=W0613 - S, # pylint: disable=W0613 - N, - N_KV, - q, - k, - v, - dtype, - attention_mask, # pylint: disable=W0613 - softmax_scale, - attention_dropout=0.0, - **attn_args, # pylint: disable=W0613 -): - - npu_attn_cls = CrossAttention if N != N_KV else SelfAttention - npu_attn = npu_attn_cls( - causal=True, - softmax_scale=softmax_scale, - attention_dropout=attention_dropout, - ).to(dtype) - # TODO: 修复它. - npu_flash_attn = AscendFlashSelfAttention( - causal=True, - softmax_scale=softmax_scale, - attention_dropout=attention_dropout, - ).to(dtype) - - if N == N_KV: - a = npu_attn(torch.concat([q, k, v], dim=2)) # pylint: disable=E1102 - else: - a = npu_attn(q.squeeze(dim=2), torch.concat([k, v], dim=2)) # pylint: disable=E1102 - - b = npu_flash_attn(q=q, k=k, v=v) # pylint: disable=E1102 - assert torch.isfinite(a).all().item() and torch.isfinite(b).all().item() - - if dtype == torch.bfloat16: - # torch_npu's equal not support bfloat16 by now. - assert torch.allclose(a.to(torch.float32), b.to(torch.float32), atol=5e-2, rtol=1e-4), f"{name} not pass" - else: - assert torch.allclose(a, b, atol=5e-2, rtol=1e-4), f"{name} not pass" +def init_qkv(B, S, N_KV, dtype, device): + x = torch.LongTensor([[i + 1 for i in range(S)] for _ in range(B)]).to(device) + cu_seqlens = [0] + sorted(random.sample(list(range(x.numel())), 4)) + if cu_seqlens[-1] != x.numel(): + cu_seqlens.append(x.numel()) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int64, device=device) + x = rearrange(x, "b s -> (b s)").unsqueeze(0) - -def npu_transform(B, S, N, N_KV, D, dtype, use_padding): - if use_padding: - x = torch.LongTensor([[i + 1 if i < S // 2 else 0 for i in range(S)] for _ in range(B)]).npu() # padding S-1024 - else: - x = torch.LongTensor([[i + 1 for i in range(S)] for _ in range(B)]).npu() # no-padiing - - wq = torch.zeros((N * D, N * D), dtype=dtype, device="npu") - wk = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu") - wv = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu") - wembed = torch.zeros((VOCAB_SIZE, HIDDEN_SZIE), dtype=dtype, device="npu") + KV_DIM = HEAD_DIM * N_KV + Q_PER_KV = HEAD_NUM // N_KV + wqkv = torch.rand((HIDDEN_SZIE + 2 * KV_DIM, HIDDEN_SZIE), dtype=dtype, device=device) + wembed = torch.rand((VOCAB_SIZE, HIDDEN_SZIE), dtype=dtype, device=device) # It is very important to set appropriate initialization values for parameters so # that the values fall within an appropriate precision range to prevent overflow or underflow. with torch.no_grad(): - wq = nn.init.normal_(wq.data) - wk = nn.init.normal_(wk.data) - wv = nn.init.normal_(wv.data) + wqkv.data = nn.init.normal_(wqkv.data) wembed = nn.init.normal_(wembed.data, std=0.02) embed_x = F.embedding(x, wembed).to(dtype) - q = F.linear(embed_x, wq) # pylint: disable=E1102 - k = F.linear(embed_x, wk) # pylint: disable=E1102 - v = F.linear(embed_x, wv) # pylint: disable=E1102 - - q = rearrange(q, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1) - k = rearrange(k, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1) - v = rearrange(v, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1) - - do_cmp_attn( - f"B_{B}_S_{S}_N_{N}_N_KV_{N_KV}_D_{D}_{dtype}", - B, - S, - N, - N_KV, - q, - k, - v, - dtype, - None, - 1 / math.sqrt(HIDDEN_SZIE // HEAD_NUM), + qkv = F.linear(embed_x, wqkv) # pylint: disable=E1102 + qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=Q_PER_KV + 2, d=HEAD_DIM) + q, k, v = (qkv[..., :Q_PER_KV, :], qkv[..., -2, :], qkv[..., -1, :]) + q = rearrange(q, "b t h gs d -> b t (h gs) d") + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + return q, kv, cu_seqlens + + +def fixed_length_fa(q, kv, cu_seqlens, packed_len, attn_cls, use_fa=False): + q = unpack_qkv_before_attn(q, cu_seqlens) + kv = unpack_qkv_before_attn(kv, cu_seqlens) + gpc._config = Config(dict(model=dict(use_flash_attn=use_fa, dtype=torch.bfloat16))) + c = attn_cls(q=q, kv=kv) # fix length self attention in npu + c = rearrange(c, "b s h d -> b s (h d)") + return pack_output_after_attn(c, cu_seqlens, packed_length=packed_len) + + +def var_length_fa(q, kv, cu_seqlens, max_seqlen, attn_cls): + gpc._config = Config(dict(model=dict(use_flash_attn=True, dtype=torch.bfloat16))) + b = attn_cls( # pylint: disable=E1102 + q=q, + kv=kv, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, ) + return rearrange(b, "b s h d -> b s (h d)") + + +def assert_equal(a, b, atol_bf16=5e-2, rtol_bf16=1e-4, atol_fp16=5e-2, rtol_fp16=1e-4): + assert a.dtype == b.dtype + assert torch.isfinite(a).all().item() and torch.isfinite(b).all().item() + if a.dtype is torch.bfloat16: + assert torch.allclose(a, b, atol=atol_bf16, rtol=rtol_bf16), f"a: {a}, b: {b}" + elif a.dtype is torch.float16: + assert torch.allclose(a, b, atol=atol_fp16, rtol=rtol_fp16), f"a: {a}, b: {b}" + else: + assert False + + +def npu_fwd_transform(B, S, N_KV, dtype): + + set_random_seed(1024) + softmax_scale = 1 / math.sqrt(HEAD_DIM) + cross_attn = SelfAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype) + npu_flash_attn = SelfAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype) + + with torch.no_grad(): + q, kv, cu_seqlens = init_qkv(B, S, N_KV, dtype, get_current_device()) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + q, kv = q.requires_grad_(), kv.requires_grad_() + a = fixed_length_fa(q, kv, cu_seqlens, B * S, cross_attn, use_fa=False) + + q_2, kv_2 = q.detach().clone().requires_grad_(), kv.detach().clone().requires_grad_() + b = fixed_length_fa(q_2, kv_2, cu_seqlens, B * S, npu_flash_attn, use_fa=True) + + q_3, kv_3 = q.detach().clone().requires_grad_(), kv.detach().clone().requires_grad_() + c = var_length_fa(q_3, kv_3, cu_seqlens, max_seqlen, npu_flash_attn) + + # assert_equal(a, b, atol_bf16=1e-1) + assert_equal(a, c, atol_bf16=1e-1) + print("test npu_fwd_transform done!", flush=True) + + return a, b, c, q, q_2, q_3, kv, kv_2, kv_3 + + +def npu_transform(B, S, N_KV, dtype): + a, b, c, q, q_2, q_3, kv, kv_2, kv_3 = npu_fwd_transform(B, S, N_KV, dtype) # pylint: disable=W0612 + g = torch.randn_like(b) + g.uniform_(-2, 2) + + b.backward(g.clone(), retain_graph=True) + a.backward(g.clone(), retain_graph=True) + c.backward(g.clone(), retain_graph=True) + + # assert_equal(q.grad, W0612.grad, atol_bf16=1e-1) + assert_equal(q.grad, q_3.grad, atol_bf16=1e-1) + # assert_equal(kv.grad, kv_2.grad, atol_bf16=5e-1, rtol_bf16=1e-3) + assert_equal(kv.grad, kv_3.grad, atol_bf16=5e-1) + + print("test npu_transform done!", flush=True) + + +def deeplink_fwd_transform(B, S, N_KV, dtype): + from deeplink_ext.internevo_ops import FlashSelfAttention + + from internlm.model.modules.multi_head_attention import CrossAttention + + set_random_seed(1024) + softmax_scale = 1 / math.sqrt(HEAD_DIM) + cross_attn = CrossAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype) + dp_flash_attn = FlashSelfAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype) + + with torch.no_grad(): + q, kv, cu_seqlens = init_qkv(B, S, N_KV, dtype, get_current_device()) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + q, kv = q.requires_grad_(), kv.requires_grad_() + a = fixed_length_fa(q, kv, cu_seqlens, B * S, cross_attn) + + q_2, kv_2 = q.detach().clone().requires_grad_(), kv.detach().clone().requires_grad_() + b = var_length_fa(q_2, kv_2, cu_seqlens, max_seqlen, dp_flash_attn) + + assert_equal(a, b) + print("test deeplink_fwd_transform done!", flush=True) + + return a, b, q, q_2, kv, kv_2 + + +def deeplink_transform(B, S, N_KV, dtype): + a, b, q, q_2, kv, kv_2 = deeplink_fwd_transform(B, S, N_KV, dtype) + + g = torch.randn_like(b) + g.uniform_(-2, 2) + + b.backward(g.clone(), retain_graph=True) + a.backward(g.clone(), retain_graph=True) + + assert_equal(q.grad, q_2.grad, atol_bf16=1e-1) + assert_equal(kv.grad, kv_2.grad, atol_bf16=1e-1) + + print("test deeplink_transform done!", flush=True) + + +@pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST) +@pytest.mark.parametrize("test_dtype", DTYPE_LIST) +@pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST) +@pytest.mark.parametrize("seqlen", SEQ_LEN) +def test_NPU_fa_fwd(micro_bsz, test_dtype, num_kv_head, seqlen): + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: + npu_fwd_transform(micro_bsz, seqlen, num_kv_head, test_dtype) @pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST) @pytest.mark.parametrize("test_dtype", DTYPE_LIST) @pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST) -@pytest.mark.parametrize("use_padding", USE_PADDING) -def test_NPU_fa(micro_bsz, test_dtype, num_kv_head, use_padding): +@pytest.mark.parametrize("seqlen", SEQ_LEN) +def test_NPU_fa_bwd(micro_bsz, test_dtype, num_kv_head, seqlen): if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: - npu_transform(micro_bsz, SEQ_LEN, HEAD_NUM, num_kv_head, HIDDEN_SZIE // HEAD_NUM, test_dtype, use_padding) + npu_transform(micro_bsz, seqlen, num_kv_head, test_dtype) + + +# @pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST) +# @pytest.mark.parametrize("test_dtype", DTYPE_LIST) +# @pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST) +# def test_deeplink_fa_fwd(micro_bsz, test_dtype, num_kv_head): +# if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: +# deeplink_fwd_transform(micro_bsz, SEQ_LEN, num_kv_head, test_dtype) + + +# @pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST) +# @pytest.mark.parametrize("test_dtype", DTYPE_LIST) +# @pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST) +# def test_deeplink_fa_bwd(micro_bsz, test_dtype, num_kv_head): +# if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: +# deeplink_transform(micro_bsz, SEQ_LEN, num_kv_head, test_dtype) if __name__ == "__main__":