Skip to content

Commit

Permalink
test sp,sp C60-Ih
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLeeJSL committed Dec 6, 2023
1 parent be6620a commit 34c2a5c
Showing 1 changed file with 14 additions and 32 deletions.
46 changes: 14 additions & 32 deletions tests/native/xla/eri_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,17 +375,17 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin):
thread_num = jnp.sum(ab_thread_num)
ab_thread_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(ab_thread_num)[:-1]), dtype=jnp.int32)
print(i,j,screened_cnt)
# output += abcd_eri_fun(cgto, orig_idx,sorted_ab_idx,
# sorted_cd_idx,
# screened_cd_idx_start,
# # start_offset,
# jnp.sum(cdcd_len-screened_cd_idx_start),
# pgto_idx_to_cgto_idx,
# rdm1,
# thread_load,
# thread_num,
# ab_thread_num,
# ab_thread_offset)
output += abcd_eri_fun(cgto, orig_idx,sorted_ab_idx,
sorted_cd_idx,
screened_cd_idx_start,
# start_offset,
jnp.sum(cdcd_len-screened_cd_idx_start),
pgto_idx_to_cgto_idx,
rdm1,
thread_load,
thread_num,
ab_thread_num,
ab_thread_offset)
# for i in range(6):
# for j in range(i, 6):
# eps = 1e-10
Expand Down Expand Up @@ -501,23 +501,6 @@ def tensorize(
pgto_normalization_factor = jnp.array(cgto.N[orig_idx])

har_jit = jax.jit(hartree)
output = har_jit(N,
jnp.array([thread_load], dtype=jnp.int32),
jnp.array([thread_num], dtype=jnp.int64),
jnp.array([screened_cnt], dtype=jnp.int64),
n, r, z, min_a, min_c, max_ab, max_cd, Ms,
jnp.array(sorted_ab_idx, dtype=jnp.int32),
jnp.array(sorted_cd_idx, dtype=jnp.int32),
jnp.array(screened_cd_idx_start, dtype=jnp.int32),
# jnp.array(start_offset, dtype=jnp.int32),
jnp.array(ab_thread_num, dtype=jnp.int32),
jnp.array(ab_thread_offset, dtype=jnp.int32),
pgto_coeff,
pgto_normalization_factor,
pgto_idx_to_cgto_idx,
rdm1,
jnp.array([cgto.n_cgtos], dtype=jnp.int32),
jnp.array([cgto.n_pgtos], dtype=jnp.int32))

t1 = time.time()
output = har_jit(N,
Expand Down Expand Up @@ -575,11 +558,10 @@ def tensorize(
abcd_idx = idx_counts[:, :4]

har_jit = jax.jit(hartree_uncontracted)
t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
t1 = time.time()
for cnt in range(10):
t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
jax.block_until_ready(t_abcd)
# for cnt in range(10):
t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
jax.block_until_ready(t_abcd)
t2 = time.time()
print(len(t_abcd))
print("abab time =", t2-t1)
Expand Down

0 comments on commit 34c2a5c

Please sign in to comment.