forked from InternLM/InternEvo
Commit f0af50b
1 parent 4b9593e
Showing 2 changed files with 148 additions and 218 deletions.
Changed file 1 of 2:
@@ -1,249 +1,135 @@
-"""
-TODO: add NPU CI
-"""
-
-import copy
-import math
 import multiprocessing as mp
 from functools import partial

 import pytest
 import torch
-import torch.nn.functional as F
-from einops import rearrange
 from torch import nn

-from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.model.modules.multi_head_attention import (
-    AscendFlashSelfAttention,
-    CrossAttention,
-    SelfAttention,
-)
+from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
+from internlm.model.modeling_internlm import InternLM1Decoder
+from internlm.train.pipeline import initialize_parallel_communicator
+from internlm.train.utils import create_param_groups
+from internlm.utils.common import get_current_device
+from tests.common_fixture import find_free_port
+from tests.test_model.test_model_internlm import build_environment, seed_all

-HEAD_NUM = 32
-HIDDEN_SZIE = 4096
-SEQ_LEN = 2048
-MICRO_BSZ = 1
-HEAD_DIM = HIDDEN_SZIE // HEAD_NUM
-VOCAB_SIZE = 32000
-
-MICRO_BSZ_LIST = [1, 2]
-DTYPE_LIST = [torch.bfloat16, torch.float16]
-NUM_KV_HEAD_LIST = [8, 32]
-USE_PADDING = [True, False]
-
-internlm_accelerator = get_accelerator()
-def check_mean_and_std(name, out1, out2):
-    named1_mean = out1.to(dtype=torch.float64).mean()
-    named1_std = out1.to(dtype=torch.float64).std()
-    named2_mean = out2.to(dtype=torch.float64).mean()
-    named2_std = out2.to(dtype=torch.float64).std()
-    check_statistic_equality(name, named1_mean, named2_mean, eq=True, is_mean=True)
-    check_statistic_equality(name, named1_std, named2_std, eq=True, is_mean=False)

-def check_statistic_equality(name, value1, value2, eq=False, is_mean=True, threshold=1e-9):
-    if (abs(value1 - value2) < threshold) ^ eq:
-        print(
-            f"On {name}, "
-            f"we have {'mean' if is_mean else 'std'}s of fa_out "
-            f"very {'close' if not eq else 'different'}, "
-            f"from :{value1} "
-            f"and :{value2}",
-            flush=True,
-        )

-def do_cmp_attn(
-    name,
-    B,  # pylint: disable=W0613
-    S,  # pylint: disable=W0613
-    N,
-    N_KV,
-    q,
-    k,
-    v,
-    dtype,
-    attention_mask,  # pylint: disable=W0613
-    softmax_scale,
-    attention_dropout=0.0,
-    **attn_args,  # pylint: disable=W0613
-):
-    npu_attn_cls = CrossAttention if N != N_KV else SelfAttention
-    npu_attn = npu_attn_cls(
-        causal=True,
-        softmax_scale=softmax_scale,
-        attention_dropout=attention_dropout,
-    ).to(dtype)
-    # TODO: fix it.
-    npu_flash_attn = AscendFlashSelfAttention(
-        causal=True,
-        softmax_scale=softmax_scale,
-        attention_dropout=attention_dropout,
-    ).to(dtype)

-    if N == N_KV:
-        a = npu_attn(torch.concat([q, k, v], dim=2))  # pylint: disable=E1102
-    else:
-        a = npu_attn(q.squeeze(dim=2), torch.concat([k, v], dim=2))  # pylint: disable=E1102

-    b = npu_flash_attn(q=q, k=k, v=v)  # pylint: disable=E1102
-    assert torch.isfinite(a).all().item() and torch.isfinite(b).all().item()

-    if dtype == torch.bfloat16:
-        # torch_npu's equal does not support bfloat16 yet.
-        assert torch.allclose(a.to(torch.float32), b.to(torch.float32), atol=5e-2, rtol=1e-4), f"{name} not pass"
-    else:
-        assert torch.allclose(a, b, atol=5e-2, rtol=1e-4), f"{name} not pass"
+def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
+    assert all(_.dtype == torch.float32 for _ in inputs)

+def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
+    if isinstance(outputs, tuple):
+        assert all(_.dtype == torch.half for _ in outputs)
+    else:
+        assert outputs.dtype == torch.half

-def npu_transform(B, S, N, N_KV, D, dtype, use_padding):
-    if use_padding:
-        # pad out the second half of the sequence
-        x = torch.LongTensor([[i + 1 if i < S // 2 else 0 for i in range(S)] for _ in range(B)]).npu()
-    else:
-        # no padding
-        x = torch.LongTensor([[i + 1 for i in range(S)] for _ in range(B)]).npu()

-    wq = torch.zeros((N * D, N * D), dtype=dtype, device="npu")
-    wk = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu")
-    wv = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu")
-    wembed = torch.zeros((VOCAB_SIZE, HIDDEN_SZIE), dtype=dtype, device="npu")

-    # It is very important to set appropriate initialization values for the parameters so
-    # that the values fall within an appropriate precision range, preventing overflow or underflow.
-    with torch.no_grad():
-        wq = nn.init.normal_(wq.data)
-        wk = nn.init.normal_(wk.data)
-        wv = nn.init.normal_(wv.data)
-        wembed = nn.init.normal_(wembed.data, std=0.02)

-    embed_x = F.embedding(x, wembed).to(dtype)
-    q = F.linear(embed_x, wq)  # pylint: disable=E1102
-    k = F.linear(embed_x, wk)  # pylint: disable=E1102
-    v = F.linear(embed_x, wv)  # pylint: disable=E1102

-    q = rearrange(q, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
-    k = rearrange(k, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
-    v = rearrange(v, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)

-    do_cmp_attn(
-        f"B_{B}_S_{S}_N_{N}_N_KV_{N_KV}_D_{D}_{dtype}",
-        B,
-        S,
-        N,
-        N_KV,
-        q,
-        k,
-        v,
-        dtype,
-        None,
-        1 / math.sqrt(HIDDEN_SZIE // HEAD_NUM),
-    )
-def check_RMSNormNPU():
-    device = get_current_device()
-    input_data = torch.randn(128).to(torch.float32).to(device)
-    input_data_2 = input_data.clone().detach()

-    rmsnorm_torch = RMSNormTorch(128, eps=1e-5).to(torch.bfloat16).to(device)
-    output_torch = rmsnorm_torch(input_data)

-    rmsnorm_npu = RMSNormNPU(128, eps=1e-5).to(torch.bfloat16).to(device)
-    output_npu = rmsnorm_npu(input_data_2)

-    if torch.equal(output_torch, output_npu):
-        print("RMSNorm check passed: totally equal", flush=True)
-    else:
-        max_diff, index_max_diff = (output_torch - output_npu).abs().max(dim=0)
-        max_diff = max_diff.item()
-        index_max_diff = index_max_diff.item()
-        rtol = max_diff / abs(output_npu[index_max_diff])
-        print(
-            f"The relative error is {rtol}. Between {output_torch[index_max_diff]} and {output_npu[index_max_diff]}",
-            flush=True,
-        )
-        assert rtol <= 1e-5, f"RMSNorm check failed: The relative error is {rtol}"
-        print("RMSNorm check passed: allclose", flush=True)

+def check_fused_precision(args):
+    # init
+    rank, world_size, free_port = args
+    build_environment(rank, world_size, free_port)
+    device = get_current_device()
+    # fix seed
+    seed_all(1024)
+    # define model
+    model = InternLM1Decoder(
+        hidden_size=16,  # 768
+        num_attention_heads=2,  # 12
+        mlp_ratio=2,
+        attn_drop_rate=0.0,
+        drop_rate=0.0,
+        dtype=torch.bfloat16,
+        layer_norm_epsilon=1e-5,
+        checkpoint=False,
+        layer_idx=0,
+        residual_in_fp32=False,
+        device=device,
+        norm_type="rmsnorm",
+        dropout_selective_checkpoint=True,
+        use_scaled_init=True,
+        use_swiglu=True,
+    )
+    model = model.to(device)
+    set_fp32_attr_to_module(model.norm1)
+    model = NaiveAMPModel(
+        model=model,
+        output_to_fp32=True,
+        dtype=torch.half,
+        sync_buffer=False,
+    )
+    _ = initialize_parallel_communicator(model)
+    model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check))
+    model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check))

+    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()

+    # forward
+    model(hidden_states)
-class MlpModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear1 = nn.Linear(128, 256)
-        self.linear2 = nn.Linear(256, 512)
+class MlpModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(4, 1, bias=False).half()
+        self.linear2 = nn.Linear(1, 4, bias=False).float()

     def forward(self, x):
         x = self.linear1(x)
         x = self.linear2(x)
         return x

-def check_AdamW():
-    device = get_current_device()
-    dtype = torch.bfloat16
-    input_data = torch.rand(16, 128, dtype=dtype).to(device)
-    torch_model = MlpModel().to(dtype).to(get_current_device())
-    npu_model = copy.deepcopy(torch_model)

-    adamW_torch = torch.optim.AdamW(
-        params=torch_model.parameters(),
-        lr=1e-4,
-        betas=(0.9, 0.95),
-        eps=1e-8,
-    )

-    adamW_npu = NPUFusedAdamW(
-        params=npu_model.parameters(),
-        lr=1e-4,
-        betas=(0.9, 0.95),
-        eps=1e-8,
-    )

-    adamW_torch.zero_grad()
-    adamW_npu.zero_grad()

-    output_torch = torch_model(input_data)
-    output_npu = npu_model(input_data)

-    output_torch.mean().backward()
-    output_npu.mean().backward()

-    adamW_torch.step()
-    adamW_npu.step()

-    params_zip = zip(list(torch_model.parameters()), list(npu_model.parameters()))
-    for torch_param, npu_param in params_zip:
-        assert torch.allclose(torch_param, npu_param, rtol=1e-5, atol=1e-5)

+def check_split_fused_group(args):
+    # init
+    rank, world_size, free_port = args
+    build_environment(rank, world_size, free_port)
+    device = get_current_device()
+    rtol, atol = (1e-3, 5e-3)

+    # fix seed
+    seed_all(1024)
+    # define model
+    model = MlpModel().to(device)
+    groups = create_param_groups(model, weight_decay=0.05)

+    standard_group = (
+        {
+            "name": "default",
+            "params": [torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()],
+            "weight_decay": 0.05,
+        },
+        {
+            "name": "fp32",
+            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
+            "weight_decay": 0.05,
+        },
+    )

+    # check groups params
+    for t1, t2 in zip(groups, standard_group):
+        # assert t1["name"] == t2["name"]
+        assert all(
+            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
+        )
-@pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
-@pytest.mark.parametrize("test_dtype", DTYPE_LIST)
-@pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST)
-@pytest.mark.parametrize("use_padding", USE_PADDING)
-def test_NPU_fa(micro_bsz, test_dtype, num_kv_head, use_padding):
-    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
-        npu_transform(micro_bsz, SEQ_LEN, HEAD_NUM, num_kv_head, HIDDEN_SZIE // HEAD_NUM, test_dtype, use_padding)

-def test_RMSNorm():
-    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
-        check_RMSNormNPU()

-def test_AdamW():
-    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
-        check_AdamW()

+@pytest.mark.fused_precision
+def test_fused_precision():
+    ctx = mp.get_context("spawn")
+    free_port = str(find_free_port())
+    with ctx.Pool(processes=8) as pool:
+        pool.map(check_fused_precision, [[rank, 8, free_port] for rank in range(8)])
+        pool.close()
+        pool.join()

+@pytest.mark.split_groups
+def test_split_fused_groups():
+    ctx = mp.get_context("spawn")
+    free_port = str(find_free_port())
+    with ctx.Pool(processes=8) as pool:
+        pool.map(check_split_fused_group, [[rank, 8, free_port] for rank in range(8)])
+        pool.close()
+        pool.join()

 if __name__ == "__main__":
-    pytest.main(["-s", "-q", "test_npu_ops.py"])
+    pytest.main(["-s", "-q", "test_norm.py"])
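
The fused-precision test above hinges on set_fp32_attr_to_module and NaiveAMPModel running the marked norm module in fp32 while the surrounding model runs in half precision, which is exactly what the two hooks assert. Below is a minimal plain-PyTorch sketch of that mechanism; keep_module_fp32 is a hypothetical stand-in, not InternEvo's implementation:

import torch
from torch import nn


def keep_module_fp32(module: nn.Module) -> None:
    """Run `module` in fp32 inside a half-precision model: upcast its
    inputs on entry, downcast its outputs back to half on exit."""
    module.float()  # parameters of the marked module stay fp32

    def pre_hook(mod, inputs):  # mirrors what _pre_forward_hook_for_check asserts
        return tuple(t.float() for t in inputs)

    def post_hook(mod, inputs, output):  # mirrors _post_forward_hook_for_check
        return output.half()

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)


# run on a GPU/NPU if your CPU build lacks fp16 matmul support
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16)).half()
keep_module_fp32(model[1])  # the norm runs in fp32, the linear stays half
out = model(torch.rand(1, 16).half())
assert out.dtype == torch.half

Casting in a forward pre-hook keeps the marked module's arithmetic in fp32 without touching the rest of the half-precision model.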
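check_split_fused_group, in turn, expects create_param_groups to pull fp32 parameters out of the default half-precision group into a dedicated "fp32" group. A rough sketch of that splitting behavior, inferred from the expected standard_group tuple (split_param_groups is an invented name, not the real helper):

import torch
from torch import nn


def split_param_groups(model: nn.Module, weight_decay: float):
    """Split parameters by dtype: fp32 params go to their own group."""
    default, fp32 = [], []
    for param in model.parameters():
        (fp32 if param.dtype == torch.float32 else default).append(param)
    return (
        {"name": "default", "params": default, "weight_decay": weight_decay},
        {"name": "fp32", "params": fp32, "weight_decay": weight_decay},
    )


class TinyMlp(nn.Module):  # same layout as the MlpModel in the diff above
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()
        self.linear2 = nn.Linear(1, 4, bias=False).float()


groups = split_param_groups(TinyMlp(), weight_decay=0.05)
assert [g["name"] for g in groups] == ["default", "fp32"]
assert all(p.dtype == torch.float32 for p in groups[1]["params"])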
Changed file 2 of 2:

@@ -0,0 +1,44 @@
+import pytest
+import torch

+from internlm.accelerator import AcceleratorType, get_accelerator
+from internlm.model.ops.norm import _RMSNorm as RMSNormTorch
+from internlm.model.ops.norm import _RMSNormNPU as RMSNormNPU
+from internlm.utils.common import get_current_device

+internlm_accelerator = get_accelerator()

+def check_RMSNormNPU():
+    device = get_current_device()
+    input_data = torch.randn(128).to(torch.float32).to(device)
+    input_data_2 = input_data.clone().detach()

+    rmsnorm_torch = RMSNormTorch(128, eps=1e-5).to(torch.bfloat16).to(device)
+    output_torch = rmsnorm_torch(input_data)

+    rmsnorm_npu = RMSNormNPU(128, eps=1e-5).to(torch.bfloat16).to(device)
+    output_npu = rmsnorm_npu(input_data_2)

+    if torch.equal(output_torch, output_npu):
+        print("RMSNorm check passed: totally equal", flush=True)
+    else:
+        max_diff, index_max_diff = (output_torch - output_npu).abs().max(dim=0)
+        max_diff = max_diff.item()
+        index_max_diff = index_max_diff.item()
+        rtol = max_diff / abs(output_npu[index_max_diff])
+        print(
+            f"The relative error is {rtol}. Between {output_torch[index_max_diff]} and {output_npu[index_max_diff]}",
+            flush=True,
+        )
+        assert rtol <= 1e-5, f"RMSNorm check failed: The relative error is {rtol}"
+        print("RMSNorm check passed: allclose", flush=True)

+def test_RMSNorm():
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
+        check_RMSNormNPU()

+if __name__ == "__main__":
+    pytest.main(["-s", "-q", "test_npu_ops.py"])
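
For reference, both _RMSNorm implementations compared above are expected to compute root-mean-square normalization with a learned scale. A plain-PyTorch sketch of the textbook formula, assuming fp32 accumulation (the NPU kernel's exact accumulation dtype may differ):

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """y = x / sqrt(mean(x^2) + eps) * weight, accumulated in fp32."""
    x_fp32 = x.to(torch.float32)
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * (x_fp32 * inv_rms).to(x.dtype)


x = torch.randn(128)
w = torch.ones(128)
out = rms_norm_reference(x, w)
assert out.shape == x.shape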