diff --git a/tests/native/xla/eri_kernel_test.py b/tests/native/xla/eri_kernel_test.py index 1fbe43d..67128ac 100644 --- a/tests/native/xla/eri_kernel_test.py +++ b/tests/native/xla/eri_kernel_test.py @@ -375,17 +375,17 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin): thread_num = jnp.sum(ab_thread_num) ab_thread_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(ab_thread_num)[:-1]), dtype=jnp.int32) print(i,j,screened_cnt) - # output += abcd_eri_fun(cgto, orig_idx,sorted_ab_idx, - # sorted_cd_idx, - # screened_cd_idx_start, - # # start_offset, - # jnp.sum(cdcd_len-screened_cd_idx_start), - # pgto_idx_to_cgto_idx, - # rdm1, - # thread_load, - # thread_num, - # ab_thread_num, - # ab_thread_offset) + output += abcd_eri_fun(cgto, orig_idx,sorted_ab_idx, + sorted_cd_idx, + screened_cd_idx_start, + # start_offset, + jnp.sum(cdcd_len-screened_cd_idx_start), + pgto_idx_to_cgto_idx, + rdm1, + thread_load, + thread_num, + ab_thread_num, + ab_thread_offset) # for i in range(6): # for j in range(i, 6): # eps = 1e-10 @@ -501,23 +501,6 @@ def tensorize( pgto_normalization_factor = jnp.array(cgto.N[orig_idx]) har_jit = jax.jit(hartree) - output = har_jit(N, - jnp.array([thread_load], dtype=jnp.int32), - jnp.array([thread_num], dtype=jnp.int64), - jnp.array([screened_cnt], dtype=jnp.int64), - n, r, z, min_a, min_c, max_ab, max_cd, Ms, - jnp.array(sorted_ab_idx, dtype=jnp.int32), - jnp.array(sorted_cd_idx, dtype=jnp.int32), - jnp.array(screened_cd_idx_start, dtype=jnp.int32), - # jnp.array(start_offset, dtype=jnp.int32), - jnp.array(ab_thread_num, dtype=jnp.int32), - jnp.array(ab_thread_offset, dtype=jnp.int32), - pgto_coeff, - pgto_normalization_factor, - pgto_idx_to_cgto_idx, - rdm1, - jnp.array([cgto.n_cgtos], dtype=jnp.int32), - jnp.array([cgto.n_pgtos], dtype=jnp.int32)) t1 = time.time() output = har_jit(N, @@ -575,11 +558,10 @@ def tensorize( abcd_idx = idx_counts[:, :4] har_jit = jax.jit(hartree_uncontracted) - t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms) t1 = time.time() - for cnt in range(10): - t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms) - jax.block_until_ready(t_abcd) + # for cnt in range(10): + t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms) + jax.block_until_ready(t_abcd) t2 = time.time() print(len(t_abcd)) print("abab time =", t2-t1)