
Commit

fix ci
SolenoidWGT committed May 21, 2024
1 parent 4b9593e commit f0af50b
Showing 2 changed files with 148 additions and 218 deletions.
322 changes: 104 additions & 218 deletions tests/test_model/test_npu_ops.py
@@ -1,249 +1,135 @@
"""
TODO: add NPU CI
"""

import copy
import math
import multiprocessing as mp
from functools import partial

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.model.modules.multi_head_attention import (
AscendFlashSelfAttention,
CrossAttention,
SelfAttention,
)

HEAD_NUM = 32
HIDDEN_SZIE = 4096
SEQ_LEN = 2048
MICRO_BSZ = 1
HEAD_DIM = HIDDEN_SZIE // HEAD_NUM
VOCAB_SIZE = 32000

MICRO_BSZ_LIST = [1, 2]
DTYPE_LIST = [torch.bfloat16, torch.float16]
NUM_KV_HEAD_LIST = [8, 32]
USE_PADDING = [True, False]

internlm_accelerator = get_accelerator()


def check_mean_and_std(name, out1, out2):
named1_mean = out1.to(dtype=torch.float64).mean()
named1_std = out1.to(dtype=torch.float64).std()
named2_mean = out2.to(dtype=torch.float64).mean()
named2_std = out2.to(dtype=torch.float64).std()
check_statistic_equality(name, named1_mean, named2_mean, eq=True, is_mean=True)
check_statistic_equality(name, named1_std, named2_std, eq=True, is_mean=False)


def check_statistic_equality(name, value1, value2, eq=False, is_mean=True, threshold=1e-9):
    # Report when the observed (in)equality of the two statistics contradicts the expectation `eq`.
    if (abs(value1 - value2) < threshold) ^ eq:
        print(
            f"On {name}, "
            f"the {'mean' if is_mean else 'std'}s of fa_out are "
            f"unexpectedly {'different' if eq else 'close'}: "
            f"{value1} vs {value2}",
            flush=True,
        )


def do_cmp_attn(
name,
B, # pylint: disable=W0613
S, # pylint: disable=W0613
N,
N_KV,
q,
k,
v,
dtype,
attention_mask, # pylint: disable=W0613
softmax_scale,
attention_dropout=0.0,
**attn_args, # pylint: disable=W0613
):

npu_attn_cls = CrossAttention if N != N_KV else SelfAttention
npu_attn = npu_attn_cls(
causal=True,
softmax_scale=softmax_scale,
attention_dropout=attention_dropout,
).to(dtype)
    # TODO: fix this.
npu_flash_attn = AscendFlashSelfAttention(
causal=True,
softmax_scale=softmax_scale,
attention_dropout=attention_dropout,
).to(dtype)

    if N == N_KV:
        a = npu_attn(torch.concat([q, k, v], dim=2))  # pylint: disable=E1102
    else:
        a = npu_attn(q.squeeze(dim=2), torch.concat([k, v], dim=2))  # pylint: disable=E1102

    b = npu_flash_attn(q=q, k=k, v=v)  # pylint: disable=E1102
    assert torch.isfinite(a).all().item() and torch.isfinite(b).all().item()

    if dtype == torch.bfloat16:
        # torch_npu's equal does not support bfloat16 yet.
        assert torch.allclose(a.to(torch.float32), b.to(torch.float32), atol=5e-2, rtol=1e-4), f"{name} not pass"
    else:
        assert torch.allclose(a, b, atol=5e-2, rtol=1e-4), f"{name} not pass"


from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
from internlm.model.modeling_internlm import InternLM1Decoder
from internlm.train.pipeline import initialize_parallel_communicator
from internlm.train.utils import create_param_groups
from internlm.utils.common import get_current_device
from tests.common_fixture import find_free_port
from tests.test_model.test_model_internlm import build_environment, seed_all


def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
    assert all(_.dtype == torch.float32 for _ in inputs)
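
# For reference, both attention paths compared in do_cmp_attn above should agree (within
# tolerance) with plain causal scaled-dot-product attention. The sketch below is illustrative
# only and not part of the test; the (batch, seqlen, heads, head_dim) layout is an assumption.
def naive_causal_attention(q, k, v, softmax_scale):
    # scores: (batch, heads, q_len, k_len)
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    seqlen = q.shape[1]
    # mask out future positions for causal attention
    causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)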


def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
    if isinstance(outputs, tuple):
        assert all(_.dtype == torch.half for _ in outputs)
    else:
        assert outputs.dtype == torch.half


def npu_transform(B, S, N, N_KV, D, dtype, use_padding):
    if use_padding:
        # pad the second half of the sequence with zeros (S // 2 = 1024 tokens when S = 2048)
        x = torch.LongTensor([[i + 1 if i < S // 2 else 0 for i in range(S)] for _ in range(B)]).npu()
    else:
        # no padding
        x = torch.LongTensor([[i + 1 for i in range(S)] for _ in range(B)]).npu()

wq = torch.zeros((N * D, N * D), dtype=dtype, device="npu")
wk = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu")
wv = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu")
wembed = torch.zeros((VOCAB_SIZE, HIDDEN_SZIE), dtype=dtype, device="npu")

# It is very important to set appropriate initialization values for parameters so
# that the values fall within an appropriate precision range to prevent overflow or underflow.
with torch.no_grad():
wq = nn.init.normal_(wq.data)
wk = nn.init.normal_(wk.data)
wv = nn.init.normal_(wv.data)
wembed = nn.init.normal_(wembed.data, std=0.02)

embed_x = F.embedding(x, wembed).to(dtype)
q = F.linear(embed_x, wq) # pylint: disable=E1102
k = F.linear(embed_x, wk) # pylint: disable=E1102
v = F.linear(embed_x, wv) # pylint: disable=E1102

q = rearrange(q, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
k = rearrange(k, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
v = rearrange(v, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)

do_cmp_attn(
f"B_{B}_S_{S}_N_{N}_N_KV_{N_KV}_D_{D}_{dtype}",
B,
S,
N,
N_KV,
q,
k,
v,
dtype,
None,
1 / math.sqrt(HIDDEN_SZIE // HEAD_NUM),
)
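
# Why the initialization scale in npu_transform matters: half-precision formats have a narrow
# range, so poorly scaled weights push activations to inf or 0. The numbers below are
# illustrative only, not values used by the test.
def _show_half_precision_limits():
    print(torch.tensor(70000.0).to(torch.float16))   # inf: float16 max is ~65504
    print(torch.tensor(70000.0).to(torch.bfloat16))  # representable: bfloat16 keeps an fp32-like range
    print(torch.tensor(1e-9).to(torch.float16))      # 0: below the smallest float16 subnormal (~6e-8)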


def check_RMSNormNPU():
    device = get_current_device()
    input_data = torch.randn(128).to(torch.float32).to(device)
    input_data_2 = input_data.clone().detach()

    rmsnorm_torch = RMSNormTorch(128, eps=1e-5).to(torch.bfloat16).to(device)
    output_torch = rmsnorm_torch(input_data)

    rmsnorm_npu = RMSNormNPU(128, eps=1e-5).to(torch.bfloat16).to(device)
    output_npu = rmsnorm_npu(input_data_2)

    if torch.equal(output_torch, output_npu):
        print("RMSNorm check passed: totally equal", flush=True)
    else:
        max_diff, index_max_diff = (output_torch - output_npu).abs().max(dim=0)
        max_diff = max_diff.item()
        index_max_diff = index_max_diff.item()
        rtol = max_diff / abs(output_npu[index_max_diff])
        print(
            f"The relative error is {rtol}. Between {output_torch[index_max_diff]} and {output_npu[index_max_diff]}",
            flush=True,
        )
        assert rtol <= 1e-5, f"RMSNorm check failed: The relative error is {rtol}"
        print("RMSNorm check passed: allclose", flush=True)


def check_fused_precision(args):
    # init
    rank, world_size, free_port = args
    build_environment(rank, world_size, free_port)
    device = get_current_device()

    # fix seed
    seed_all(1024)
    # define model
    model = InternLM1Decoder(
        hidden_size=16,  # 768
        num_attention_heads=2,  # 12
        mlp_ratio=2,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        dtype=torch.bfloat16,
        layer_norm_epsilon=1e-5,
        checkpoint=False,
        layer_idx=0,
        residual_in_fp32=False,
        device=device,
        norm_type="rmsnorm",
        dropout_selective_checkpoint=True,
        use_scaled_init=True,
        use_swiglu=True,
    )
    model = model.to(device)
    set_fp32_attr_to_module(model.norm1)
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.half,
        sync_buffer=False,
    )
    _ = initialize_parallel_communicator(model)
    model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check))
    model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check))

    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()

    # forward
    model(hidden_states)
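
# The same dtype check as check_fused_precision, sketched with plain PyTorch hooks and without
# NaiveAMPModel. Illustrative only; the module and dtypes below are assumptions, not the test's
# actual configuration.
def _attach_dtype_checks(module: nn.Module, in_dtype=torch.float32, out_dtype=torch.half):
    def _pre_hook(mod, inputs):  # inputs reaching an fp32-marked module should already be fp32
        assert all(t.dtype == in_dtype for t in inputs if torch.is_tensor(t))

    def _post_hook(mod, inputs, output):  # outputs handed back to the half-precision model
        outs = output if isinstance(output, tuple) else (output,)
        assert all(t.dtype == out_dtype for t in outs if torch.is_tensor(t))

    module.register_forward_pre_hook(_pre_hook)
    module.register_forward_hook(_post_hook)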


def check_AdamW():
    class MlpModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(128, 256)
            self.linear2 = nn.Linear(256, 512)

        def forward(self, x):
            x = self.linear1(x)
            x = self.linear2(x)
            return x

    device = get_current_device()
    dtype = torch.bfloat16
    input_data = torch.rand(16, 128, dtype=dtype).to(device)
    torch_model = MlpModel().to(dtype).to(get_current_device())
    npu_model = copy.deepcopy(torch_model)

    adamW_torch = torch.optim.AdamW(
        params=torch_model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

    adamW_npu = NPUFusedAdamW(
        params=npu_model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

    adamW_torch.zero_grad()
    adamW_npu.zero_grad()

    output_torch = torch_model(input_data)
    output_npu = npu_model(input_data)

    output_torch.mean().backward()
    output_npu.mean().backward()

    adamW_torch.step()
    adamW_npu.step()

    params_zip = zip(list(torch_model.parameters()), list(npu_model.parameters()))
    for torch_param, npu_param in params_zip:
        assert torch.allclose(torch_param, npu_param, rtol=1e-5, atol=1e-5)


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()
        self.linear2 = nn.Linear(1, 4, bias=False).float()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def check_split_fused_group(args):
    # init
    rank, world_size, free_port = args
    build_environment(rank, world_size, free_port)
    device = get_current_device()
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)
    # define model
    model = MlpModel().to(device)
    groups = create_param_groups(model, weight_decay=0.05)

    standard_group = (
        {
            "name": "default",
            "params": [torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
        {
            "name": "fp32",
            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
    )

    # check groups params
    for t1, t2 in zip(groups, standard_group):
        # assert t1["name"] == t2["name"]
        assert all(
            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
        )


@pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
@pytest.mark.parametrize("test_dtype", DTYPE_LIST)
@pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST)
@pytest.mark.parametrize("use_padding", USE_PADDING)
def test_NPU_fa(micro_bsz, test_dtype, num_kv_head, use_padding):
    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
        npu_transform(micro_bsz, SEQ_LEN, HEAD_NUM, num_kv_head, HIDDEN_SZIE // HEAD_NUM, test_dtype, use_padding)
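
# A minimal sketch of the split that create_param_groups is expected to produce for the
# MlpModel in check_split_fused_group above: fp32 parameters land in a dedicated "fp32" group,
# everything else stays in "default". Illustrative only; the real internlm implementation may
# differ in detail.
def split_params_by_dtype(model: nn.Module, weight_decay: float):
    default_params, fp32_params = [], []
    for param in model.parameters():
        (fp32_params if param.dtype == torch.float32 else default_params).append(param)
    return (
        {"name": "default", "params": default_params, "weight_decay": weight_decay},
        {"name": "fp32", "params": fp32_params, "weight_decay": weight_decay},
    )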


def test_RMSNorm():
    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
        check_RMSNormNPU()


def test_AdamW():
    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
        check_AdamW()


@pytest.mark.fused_precision
def test_fused_precision():
    ctx = mp.get_context("spawn")
    free_port = str(find_free_port())
    with ctx.Pool(processes=8) as pool:
        pool.map(check_fused_precision, [[rank, 8, free_port] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.split_groups
def test_split_fused_groups():
    ctx = mp.get_context("spawn")
    free_port = str(find_free_port())
    with ctx.Pool(processes=8) as pool:
        pool.map(check_split_fused_group, [[rank, 8, free_port] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
pytest.main(["-s", "-q", "test_npu_ops.py"])
pytest.main(["-s", "-q", "test_norm.py"])
44 changes: 44 additions & 0 deletions tests/test_model/test_npu_ops/test_npu_rmsnorm.py
@@ -0,0 +1,44 @@
import pytest
import torch

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.model.ops.norm import _RMSNorm as RMSNormTorch
from internlm.model.ops.norm import _RMSNormNPU as RMSNormNPU
from internlm.utils.common import get_current_device

internlm_accelerator = get_accelerator()


def check_RMSNormNPU():
device = get_current_device()
input_data = torch.randn(128).to(torch.float32).to(device)
input_data_2 = input_data.clone().detach()

rmsnorm_torch = RMSNormTorch(128, eps=1e-5).to(torch.bfloat16).to(device)
output_torch = rmsnorm_torch(input_data)

rmsnorm_npu = RMSNormNPU(128, eps=1e-5).to(torch.bfloat16).to(device)
output_npu = rmsnorm_npu(input_data_2)

if torch.equal(output_torch, output_npu):
print("RMSNorm check passed: totaly equal", flush=True)
else:
max_diff, index_max_diff = (output_torch - output_npu).abs().max(dim=0)
max_diff = max_diff.item()
index_max_diff = index_max_diff.item()
rtol = max_diff / abs(output_npu[index_max_diff])
print(
f"The relative error is {rtol}. Between {output_torch[index_max_diff]} and {output_npu[index_max_diff]}",
flush=True,
)
assert rtol <= 1e-5, f"RMSNorm check failed: The relative error is {rtol}"
print("RMSNorm check passed: allclose", flush=True)


def test_RMSNorm():
if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
check_RMSNormNPU()


if __name__ == "__main__":
pytest.main(["-s", "-q", "test_npu_ops.py"])
