feat(modeling): support qwen2
SolenoidWGT committed Aug 7, 2024
1 parent 3d84b85 commit 99b0302
Showing 5 changed files with 396 additions and 7 deletions.
158 changes: 158 additions & 0 deletions configs/qwen2_7b.py
@@ -0,0 +1,158 @@
# Copyright (c) InternLM. All rights reserved.
model_type = "LLAMA2"

VOCAB_SIZE = 152064

HIDDEN_SIZE = 3584
NUM_ATTENTION_HEAD = 28
NUM_KV_ATTENTION_HEAD = 4
MLP_RATIO = 1
NUM_LAYER = 28
MULTIPLE_OF = 256

model = dict(
    checkpoint=False,
    num_chunks=1,
    num_attention_heads=NUM_ATTENTION_HEAD,
    num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYER,
    qkv_bias=True,
    o_bias=False,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.bfloat16",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-6,
    rope_base=1000000,
    sliding_window_cfg=dict(
        use_sliding_window=False,
        sliding_window=131072,
        max_window_layers=28,
    ),
    multiple_of=MULTIPLE_OF,
    intermediate_size=18944,
)

hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    overlap_sync_grad=True,
    overlap_sync_param=False,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,
)

parallel = dict(
    zero1=dict(size=8, fsdp=False),
    tensor=dict(size=1, mode="mtp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=1, overlap=False, memory_pool=False),
)


JOB_NAME = "qwen2"
LEARNING_RATE = 1e-3
MIN_LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.1
WARMUP_RATIO = 0.028
OPTIMIZER_WARMUP_STEP = 0

MICRO_NUM = 1
MICRO_BSZ = 1
SEQ_LEN = 4096
TOTAL_STEP = 75000
PACK_SAMPLE_INTO_ONE = False
USE_PACKED_DATASET = True
SAVED_DATA_PATH = ""

SAVE_CKPT_FOLDER = None
LOAD_MODEL_PATH = None
CHECKPOINT_EVERY = 1000

data = dict(
    seq_len=SEQ_LEN,
    micro_num=MICRO_NUM,
    micro_bsz=MICRO_BSZ,
    valid_micro_num=4,
    valid_every=0,
    pack_sample_into_one=PACK_SAMPLE_INTO_ONE,
    total_steps=TOTAL_STEP,
    skip_batches="",
    rampup_batch_size="",
    min_length=50,
    train_folder=None,
    valid_folder=None,
    empty_cache_and_diag_interval=200,
    diag_outlier_ratio=1.1,
    use_packed_dataset=USE_PACKED_DATASET,
)
loss = dict(label_smoothing=0.0)
adam = dict(
    lr=LEARNING_RATE,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_beta2_c=0,
    adam_eps=1e-8,
    weight_decay=WEIGHT_DECAY,
)

lr_scheduler = dict(
    total_steps=data["total_steps"],
    init_steps=OPTIMIZER_WARMUP_STEP,  # optimizer_warmup_step
    warmup_ratio=WARMUP_RATIO,
    eta_min=MIN_LEARNING_RATE,
    last_epoch=-1,
)
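# With WARMUP_RATIO = 0.028 and TOTAL_STEP = 75000, warmup presumably covers about
# 0.028 * 75000 = 2100 steps before the schedule decays toward MIN_LEARNING_RATE (eta_min).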

beta2_scheduler = dict(
    init_beta2=adam["adam_beta2"],
    c=adam["adam_beta2_c"],
    cur_iter=-1,
)
cudnn_deterministic = False
cudnn_benchmark = False
monitor = dict(
    alert=dict(
        enable_feishu_alert=False,
        feishu_alert_address=None,  # feishu webhook to send alert message
        light_monitor_address=None,  # light_monitor address to send heartbeat
        alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
    ),
    tensorboard=dict(
        queue_max_length=10,
    ),
)
grad_scaler = dict(
    fp16=dict(
        # the initial loss scale, defaults to 2**16
        initial_scale=2**14,
        # the minimum loss scale, defaults to None
        min_scale=1,
        # the number of steps to increase loss scale when no overflow occurs
        growth_interval=1000,
    ),
    # the multiplication factor for increasing loss scale, defaults to 2
    growth_factor=2,
    # the multiplication factor for decreasing loss scale, defaults to 0.5
    backoff_factor=0.5,
    # the maximum loss scale, defaults to None
    max_scale=2**24,
    # the number of overflows before decreasing loss scale, defaults to 2
    hysteresis=2,
)
ckpt = dict(
    enable_save_ckpt=False,  # enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    auto_resume=False,
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=False,  # async ckpt upload. (only work for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    oss_snapshot_freq=CHECKPOINT_EVERY,  # snapshot ckpt save frequency.
    load_ckpt_info=dict(path="./Qwen2-7B", content=("model",), ckpt_type="hf_qwen2"),
)
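For reference, the attention geometry implied by these sizes matches Qwen2-7B (grouped-query attention with head_dim 128). A minimal sanity-check sketch, assuming head_dim = hidden_size / num_attention_heads and that the explicit intermediate_size presumably takes precedence over MLP_RATIO / MULTIPLE_OF:

HIDDEN_SIZE = 3584
NUM_ATTENTION_HEAD = 28
NUM_KV_ATTENTION_HEAD = 4

head_dim = HIDDEN_SIZE // NUM_ATTENTION_HEAD  # 3584 / 28 = 128
kv_groups = NUM_ATTENTION_HEAD // NUM_KV_ATTENTION_HEAD  # 28 / 4 = 7 query heads share each KV head
assert HIDDEN_SIZE % NUM_ATTENTION_HEAD == 0 and NUM_ATTENTION_HEAD % NUM_KV_ATTENTION_HEAD == 0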
206 changes: 206 additions & 0 deletions internlm/checkpoint/load_funcs.py
@@ -1,5 +1,6 @@
# Copyright (c) InternLM. All rights reserved.
import os
import re

import torch

@@ -313,9 +314,214 @@ def load_hf_model_pretrained_weights(folder, model):
logger.info("Pretrained weights loaded successfully")


from safetensors import safe_open


def load_pp_hf_ckpt(path, prefix: str = None, suffix: str = None):
    assert path is not None, "Please specify the folder of the pretrained model"
    if gpc.is_rank_for_log():
        logger.info(f"Loading pretrained model from {path}")
    fns = get_fns(path)
    model_fns, ckpt_type = [], None
    if prefix or suffix:
        prefix = "" if prefix is None else prefix
        suffix = "" if suffix is None else suffix
        for fn in fns:
            if fn.endswith(suffix) and fn.startswith(prefix):
                if not ckpt_type:
                    if fn.endswith(".safetensors"):
                        ckpt_type = "safetensors"
                    else:
                        ckpt_type = "torch"
                model_fns.append(os.path.join(path, fn))
    else:
        for fn in fns:
            if not ckpt_type:
                if fn.endswith(".safetensors") and fn.startswith("model"):
                    ckpt_type = "safetensors"
                elif fn.endswith(".bin") and fn.startswith("pytorch_model"):
                    ckpt_type = "torch"
            if (ckpt_type == "safetensors" and fn.endswith(".safetensors")) or (
                ckpt_type == "torch" and fn.endswith(".bin")
            ):
                model_fns.append(os.path.join(path, fn))
    model_fns.sort()
    states = {}

    for model_fn in model_fns:
        tensors = {}
        if ckpt_type == "safetensors":
            with safe_open(model_fn, framework="pt", device="cpu") as f:
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
        else:
            # .bin shards are plain torch checkpoints and cannot be read with safe_open
            tensors = torch.load(model_fn, map_location="cpu")
        states.update(tensors)
    return (ckpt_type == "safetensors") if ckpt_type else False, states


def get_mapping(key, mappings):
    match = []
    for mapping in mappings:
        if isinstance(mapping, tuple) and len(mapping) and re.search(mapping[0], key):
            match.append(mapping)
    # search the ordinal number of the layer
    layer_pattern = re.search(r"\.\b(\d+)\.", key)
    if layer_pattern:
        layer = int(layer_pattern.group(1))
    else:
        layer = None
    return layer, match, "bias" in key


def replace_between_dots(text, old, new):
    pattern = re.compile(r"(?<=\.)" + re.escape(old) + r"(?=\.)")
    return pattern.sub(new, text)


def get_local_splited_weight(
    states,
    mappings,
    pp_layer_range,
    tp_world_size,
    tp_local_rank,
):
    def find_interval_index(number, intervals):
        # intervals are half-open layer ranges, e.g. [(0, 14), (14, 28)]
        for chunk_id, (start, end) in enumerate(intervals):
            if start <= number < end:
                return chunk_id
        return -1

    current_states = {}
    for k, v in states.items():
        # match the pattern in ckpt module name
        layer, matches, bias = get_mapping(k, mappings)
        if matches:
            # skip layers that do not belong to this pipeline stage
            if layer is not None and (chunk_id := find_interval_index(layer, pp_layer_range)) == -1:
                continue

            for mapping in matches:
                ckpt_name, model_name, chunk_dim, local_rank = mapping
                if local_rank:
                    # replace the pattern in ckpt module name with the model module name
                    key = re.sub(ckpt_name, model_name, k).replace("model.", "")
                    if layer is not None:
                        # renumber the layer relative to the first layer of this pipeline stage
                        key = replace_between_dots(key, str(layer), str(layer - pp_layer_range[chunk_id][0]))
                    # don't chunk 1-D tensors (row vectors) along dim 1
                    if tp_world_size > 1 and (chunk_dim == 0 or (chunk_dim == 1 and v.dim() > 1)):
                        value = torch.chunk(v, tp_world_size, dim=chunk_dim)[tp_local_rank]
                    else:
                        value = v
                    # for row-parallel (dim-1) weights, only tp rank 0 keeps the bias
                    if not bias or tp_local_rank == 0 or chunk_dim != 1:
                        current_states[key] = value
        else:
            print("unknown key: ", k)
    return current_states


def obtain_spliting_parameters():
    num_layers = gpc.config.model.num_layers
    num_chunks = gpc.config.model.num_chunks
    if gpc.is_initialized(ParallelMode.TENSOR):
        tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
        tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    else:
        tp_world_size = 1
        tp_local_rank = 0
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pp_world_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    else:
        pp_world_size = 1
        pp_local_rank = 0
    # currently only num_chunks=1 is supported
    assert num_chunks == 1, "May cause future collisions, ignore this if necessary"
    pp_layer_range = partition_uniform(num_layers, pp_world_size, num_chunks)
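    # partition_uniform presumably returns one list of half-open (start, end) layer ranges per
    # pipeline rank, e.g. [[(0, 14)], [(14, 28)]] for num_layers=28, pp_world_size=2, num_chunks=1.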
    return (
        num_layers,
        num_chunks,
        tp_world_size,
        tp_local_rank,
        pp_world_size,
        pp_local_rank,
        pp_layer_range[pp_local_rank],
    )


def load_qwen_2_pretrained_weights_dynamic(folder, model, **kwargs):  # pylint: disable=W0613
    # assert gpc.config.model_type == "QWEN", 'Please use model_type="QWEN" to load qwen huggingface checkpoint.'
    # is_st, states = load_ckpt(folder, "gemma", ".ckpt") # torch checkpoint
    _, states = load_pp_hf_ckpt(folder)  # huggingface checkpoint

    if gpc.config.model.num_layers >= 80 and gpc.is_rank_for_log():
        logger.warning(
            "you are loading a very large huggingface model, it may lead to out of CPU memory, "
            "you can try to manually let the rank with tp<4 sleep here for 2 minutes."
        )
        import time

        if gpc.get_global_rank() % 8 < 4:
            time.sleep(120)

    # rename huggingface parameter names to the internlm modeling names
    model_state_dict = {}
    for key, value in states.items():
        if "transformer.h" in key:
            model_state_dict[key.replace("transformer.h", "layers")] = value
        elif "transformer." in key:
            model_state_dict[key.replace("transformer.", "")] = value
        elif "post_attention_layernorm." in key:
            model_state_dict[key.replace("post_attention_layernorm.", "ln2.")] = value
        elif "input_layernorm." in key:
            model_state_dict[key.replace("input_layernorm.", "ln1.")] = value
        else:
            model_state_dict[key] = value
    del states

    (
        num_layers,
        num_chunks,  # pylint: disable=W0612
        tp_world_size,
        tp_local_rank,
        pp_world_size,
        pp_local_rank,
        pp_layer_range,
    ) = obtain_spliting_parameters()
    first = pp_local_rank == 0
    last = pp_local_rank + 1 == pp_world_size

    mappings = [
        ("self_attn.q_proj", "attention.wq", 0, True),
        ("self_attn.k_proj", "attention.wk", 0, True),
        ("self_attn.v_proj", "attention.wv", 0, True),
        ("self_attn.o_proj", "attention.wo", 1, True),
        ("mlp.up_proj", "feed_forward.w3", 0, True),
        ("mlp.gate_proj", "feed_forward.w1", 0, True),
        ("mlp.down_proj", "feed_forward.w2", 1, True),
        ("ln1", "attention_norm", -1, True),
        ("ln2", "ffn_norm", -1, True),
        ("norm", "norm", -1, last),
        ("embed_tokens", "tok_embeddings", 1, first),
        ("lm_head", "output", 0, last),
    ]
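    # Each mapping entry is (hf_name_pattern, internlm_param_name, tp_chunk_dim, load_on_this_rank):
    # chunk_dim 0 or 1 is the dimension split across tensor-parallel ranks, -1 (norms) is never split,
    # and the last field restricts embed_tokens to the first and norm/lm_head to the last pipeline stage.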

    block = model.layers[0]
    num_kv_heads = block.attention.num_kv_heads
    num_attn_heads = block.attention.num_heads
    head_dim = block.attention.head_dim

    current_states = get_local_splited_weight(model_state_dict, mappings, pp_layer_range, tp_world_size, tp_local_rank)

    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)

    if gpc.get_local_rank(ParallelMode.DATA) == 0:
        logger.info(
            f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
            f"tp:{tp_local_rank}, pp:{pp_local_rank}"
        )


LOAD_FUNC_DICT = {
    "llama": load_llama_pretrained_weights,
    "hf_llama": load_hf_llama_pretrained_weights,
    "internlm_test": load_internlm_with_dynamic_parallel_size,
    "hf_model": load_hf_model_pretrained_weights,
    "hf_qwen2": load_qwen_2_pretrained_weights_dynamic,
}
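The new loader is reached through the existing LOAD_FUNC_DICT dispatch, keyed by the ckpt_type field of load_ckpt_info in configs/qwen2_7b.py. A minimal sketch of that dispatch, assuming the checkpoint manager hands over the folder path and the constructed model (the call site is not part of this diff, so the variable names below are illustrative):

# illustrative sketch, not code from this commit
load_ckpt_info = dict(path="./Qwen2-7B", content=("model",), ckpt_type="hf_qwen2")

load_func = LOAD_FUNC_DICT[load_ckpt_info["ckpt_type"]]  # -> load_qwen_2_pretrained_weights_dynamic
load_func(load_ckpt_info["path"], model)  # each rank keeps only its tp/pp shard, then load_state_dict(strict=False)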
0 comments on commit 99b0302