Skip to content

Commit

Permalink
update eri_kernal_test
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLeeJSL committed Dec 8, 2023
1 parent 34c2a5c commit 4f28f4d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
1 change: 1 addition & 0 deletions d4ft/native/obara_saika/eri.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ eri(int nax, int nay, int naz, int nbx, int nby, int nbz,
FLOAT I[MAX_CD + 1][MAX_XYZ + 1] = {0};
FLOAT out[MAX_YZ + 1] = {0};
// set I[0] to Boys
// return 1;
for (int i = 0; i <= Ms[0]; ++i) {
I[0][i] = BoysIgamma<FLOAT>(i, T_);
}
Expand Down
8 changes: 3 additions & 5 deletions d4ft/native/obara_saika/eri_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,10 @@ void Hartree_64::Gpu(cudaStream_t stream,
int loc;
loc = ab_thread_offset.ptr[index] + i;
thread_ab_index[loc] = index;
// output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab
// output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
}
__syncthreads();
});

// get ncd, rcd, zcd in cd order
hemi::parallel_for(ep, 0, sorted_cd_idx.spec->shape[0], [=] HEMI_LAMBDA(int index) {
int cd;
Expand All @@ -160,7 +158,7 @@ void Hartree_64::Gpu(cudaStream_t stream,
cudaMemset(output.ptr, 0, sizeof(double));
// Now we have ab cd, we can compute eri and contract it to output
// For contract, we need 1. count 2. pgto normalization coeff 3. pgto coeff 4.rdm1 (Mocoeff)
hemi::parallel_for(ep, 0, thread_length, [=] HEMI_LAMBDA(int64_t index) {
hemi::parallel_for(ep, 0, thread_length, [=] HEMI_LAMBDA(int index) {
int a, b, c, d; // pgto 4c idx
int i, j, k, l; // cgto 4c idx
int ab_index, cd_index;
Expand Down Expand Up @@ -236,7 +234,7 @@ void Hartree_64::Gpu(cudaStream_t stream,
rbx, rby, rbz, // b
rcx, rcy, rcz, // c
rdx, rdy, rdz, // d
za, zb, zc, zd, // z
za, zb, zc, zd, // z
min_a.ptr, min_c.ptr,
max_ab.ptr, max_cd.ptr, Ms.ptr) * dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd;
// eri_result += dcount * Na * Nb * Nc * Nd * Ca * Cb * Cc * Cd * Mab * Mcd;
Expand Down
6 changes: 3 additions & 3 deletions tests/native/xla/eri_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin):

t1_abcd = time.time()
eri_abab = jnp.array(eri_abab)

# current support s, p, d
s_num = jnp.count_nonzero(l_xyz == 0)
p_num = jnp.count_nonzero(l_xyz == 1)
Expand Down Expand Up @@ -353,7 +353,7 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin):
# for (ss, ss) (pp, pp) (dd, dd), (sp, sp) ... need ensure idx > cnt. For anyone else, no need
output = 0
eps = 1e-10
i = 1
i = 0
j = 1
thread_load = 2**10
sorted_ab_idx = sorted_idx[i]
Expand Down Expand Up @@ -560,7 +560,7 @@ def tensorize(
har_jit = jax.jit(hartree_uncontracted)
t1 = time.time()
# for cnt in range(10):
t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
t_abcd = har_jit(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
jax.block_until_ready(t_abcd)
t2 = time.time()
print(len(t_abcd))
Expand Down

0 comments on commit 4f28f4d

Please sign in to comment.