From 4f28f4d2f3e3a122af571d84ae6cca049a880fb7 Mon Sep 17 00:00:00 2001 From: JasonLeeJsL Date: Fri, 8 Dec 2023 15:05:52 +0800 Subject: [PATCH] update eri_kernel_test --- d4ft/native/obara_saika/eri.h | 1 + d4ft/native/obara_saika/eri_kernel.cu | 8 +++----- tests/native/xla/eri_kernel_test.py | 6 +++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/d4ft/native/obara_saika/eri.h b/d4ft/native/obara_saika/eri.h index 9887b69..5433b1c 100644 --- a/d4ft/native/obara_saika/eri.h +++ b/d4ft/native/obara_saika/eri.h @@ -185,6 +185,7 @@ eri(int nax, int nay, int naz, int nbx, int nby, int nbz, FLOAT I[MAX_CD + 1][MAX_XYZ + 1] = {0}; FLOAT out[MAX_YZ + 1] = {0}; // set I[0] to Boys + // return 1; for (int i = 0; i <= Ms[0]; ++i) { I[0][i] = BoysIgamma(i, T_); } diff --git a/d4ft/native/obara_saika/eri_kernel.cu b/d4ft/native/obara_saika/eri_kernel.cu index 5910112..c6d10eb 100644 --- a/d4ft/native/obara_saika/eri_kernel.cu +++ b/d4ft/native/obara_saika/eri_kernel.cu @@ -128,12 +128,10 @@ void Hartree_64::Gpu(cudaStream_t stream, int loc; loc = ab_thread_offset.ptr[index] + i; thread_ab_index[loc] = index; - // output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab - // output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd } __syncthreads(); }); - + // get ncd, rcd, zcd in cd order hemi::parallel_for(ep, 0, sorted_cd_idx.spec->shape[0], [=] HEMI_LAMBDA(int index) { int cd; @@ -160,7 +158,7 @@ void Hartree_64::Gpu(cudaStream_t stream, cudaMemset(output.ptr, 0, sizeof(double)); // Now we have ab cd, we can compute eri and contract it to output // For contract, we need 1. count 2. pgto normalization coeff 3. 
pgto coeff 4.rdm1 (Mocoeff) - hemi::parallel_for(ep, 0, thread_length, [=] HEMI_LAMBDA(int64_t index) { + hemi::parallel_for(ep, 0, thread_length, [=] HEMI_LAMBDA(int index) { int a, b, c, d; // pgto 4c idx int i, j, k, l; // cgto 4c idx int ab_index, cd_index; @@ -236,7 +234,7 @@ void Hartree_64::Gpu(cudaStream_t stream, rbx, rby, rbz, // b rcx, rcy, rcz, // c rdx, rdy, rdz, // d - za, zb, zc, zd, // z + za, zb, zc, zd, // z min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr) * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd; // eri_result += dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd; diff --git a/tests/native/xla/eri_kernel_test.py b/tests/native/xla/eri_kernel_test.py index 67128ac..8876ff2 100644 --- a/tests/native/xla/eri_kernel_test.py +++ b/tests/native/xla/eri_kernel_test.py @@ -287,7 +287,7 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin): t1_abcd = time.time() eri_abab = jnp.array(eri_abab) - + # current support s, p, d s_num = jnp.count_nonzero(l_xyz == 0) p_num = jnp.count_nonzero(l_xyz == 1) @@ -353,7 +353,7 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin): # for (ss, ss) (pp, pp) (dd, dd), (sp, sp) ... need ensure idx > cnt. For anyone else, no need output = 0 eps = 1e-10 - i = 1 + i = 0 j = 1 thread_load = 2**10 sorted_ab_idx = sorted_idx[i] @@ -560,7 +560,7 @@ def tensorize( har_jit = jax.jit(hartree_uncontracted) t1 = time.time() # for cnt in range(10): - t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms) + t_abcd = har_jit(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms) jax.block_until_ready(t_abcd) t2 = time.time() print(len(t_abcd))