diff --git a/d4ft/integral/gto/tensorization.py b/d4ft/integral/gto/tensorization.py
index be1c076..5cc2b45 100644
--- a/d4ft/integral/gto/tensorization.py
+++ b/d4ft/integral/gto/tensorization.py
@@ -22,15 +22,19 @@ from d4ft.integral.gto import symmetry
 from d4ft.integral.gto.cgto import CGTO, PGTO
 from d4ft.types import IdxCount2C, IdxCount4C
-from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_64
+from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_32_uncontracted, _Hartree_64, _Hartree_64_uncontracted
 from d4ft.native.xla.custom_call import CustomCallMeta
 
 Hartree_64 = CustomCallMeta("Hartree_64", (_Hartree_64,), {})
+Hartree_64_uncontracted = CustomCallMeta("Hartree_64_uncontracted", (_Hartree_64_uncontracted,), {})
 Hartree_32 = CustomCallMeta("Hartree_32", (_Hartree_32,), {})
+Hartree_32_uncontracted = CustomCallMeta("Hartree_32_uncontracted", (_Hartree_32_uncontracted,), {})
 if jax.config.jax_enable_x64:
   hartree = Hartree_64()
+  hartree_uncontracted = Hartree_64_uncontracted()
 else:
   hartree = Hartree_32()
+  hartree_uncontracted = Hartree_32_uncontracted()
 
 def tensorize_2c_cgto(f: Callable, static_args, cgto: bool = True):
   """2c centers tensorization with provided index set,
@@ -129,25 +133,29 @@ def tensorize(
     cgto_seg_id,
     n_segs: int,
   ):
-    Ns = gtos.N
     N = gtos.n_pgtos
+    Ns = gtos.N
+    # Reshaping n, r, z into flat 1D buffers significantly reduces compute time
     n = jnp.array(gtos.pgto.angular.T, dtype=jnp.int32)
     r = jnp.array(gtos.pgto.center.T)
     z = jnp.array(gtos.pgto.exponent)
+
     min_a = jnp.array(static_args.min_a, dtype=jnp.int32)
     min_c = jnp.array(static_args.min_c, dtype=jnp.int32)
     max_ab = jnp.array(static_args.max_ab, dtype=jnp.int32)
     max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32)
     Ms = jnp.array([static_args.max_xyz+1, static_args.max_yz+1, static_args.max_z+1], dtype=jnp.int32)
     abcd_idx = idx_counts[:, :4]
+
     gtos_abcd, coeffs_abcd = zip(
       *[
         gtos.map_pgto_params(lambda gto_param, i=i: gto_param[abcd_idx[:, i]])
         for i in range(4)
       ]
     )
-    t_abcd = hartree(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx,dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
+    t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx, dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
+    jax.block_until_ready(t_abcd)  # force the kernel to finish (timing/debug aid)
     if not cgto:
       return t_abcd
     counts_abcd_i = idx_counts[:, 4]
diff --git a/d4ft/native/obara_saika/eri_kernel.cc b/d4ft/native/obara_saika/eri_kernel.cc
index abf125f..0db6728 100644
--- a/d4ft/native/obara_saika/eri_kernel.cc
+++ b/d4ft/native/obara_saika/eri_kernel.cc
@@ -7,5 +7,7 @@ PYBIND11_MODULE(eri_kernel, m) {
   // py::class_<Parent>(m, "Parent").def(py::init<>());
   REGISTER_XLA_FUNCTION(m, Hartree_32);
   REGISTER_XLA_FUNCTION(m, Hartree_64);
+  REGISTER_XLA_FUNCTION(m, Hartree_32_uncontracted);
+  REGISTER_XLA_FUNCTION(m, Hartree_64_uncontracted);
   // REGISTER_XLA_MEMBER(m, Parent, ExampleMember);
 }
diff --git a/d4ft/native/obara_saika/eri_kernel.cu b/d4ft/native/obara_saika/eri_kernel.cu
index 832e95f..9b02dea 100644
--- a/d4ft/native/obara_saika/eri_kernel.cu
+++ b/d4ft/native/obara_saika/eri_kernel.cu
@@ -31,6 +31,7 @@ void Hartree_32::Gpu(cudaStream_t stream,
                     Array<int>& max_ab,
                     Array<int>& max_cd,
                     Array<int>& Ms,
+                    Array<int>& ab_range,
                     Array<float>& output) {
   std::cout<<index_4c.spec->shape[0]<<std::endl;
@@ ... @@
+void Hartree_64::Gpu(cudaStream_t stream,
+                     Array<int>& N,
+                     Array<int>& screened_length,
+                     Array<int>& n,
+                     Array<double>& r,
+                     Array<double>& z,
+                     Array<int>& min_a,
+                     Array<int>& min_c,
+                     Array<int>& max_ab,
+                     Array<int>& max_cd,
+                     Array<int>& Ms,
+                     Array<int>& sorted_ab_idx,
+                     Array<int>& sorted_cd_idx,
+                     Array<int>& screened_cd_idx_start,
+                     Array<int>& screened_idx_offset,
+                     Array<double>& output) {
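+  // Screened-pair layout (illustrative sketch): sorted_ab_idx and sorted_cd_idx
+  // hold pair indices sorted by their diagonal integral (ab|ab);
+  // screened_cd_idx_start.ptr[k] is the first position in sorted_cd_idx that
+  // survives screening against pair k, and screened_idx_offset.ptr[k] is the
+  // number of survivors of all rows before k. Row k thus writes to the
+  // contiguous slots [offset[k], offset[k] + num_cd - start[k]).
+  // Example: num_cd = 4, start = [1, 3] -> offset = [0, 3]; row 0 fills slots
+  // 0..2 with cd positions 1..3, row 1 fills slot 3 with cd position 3.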
+  // Prescreening
+  int* idx_4c;
+  int idx_length;
+  cudaMemcpy(&idx_length, screened_length.ptr, sizeof(int), cudaMemcpyDeviceToHost);
+  cudaMalloc((void **)&idx_4c, 2 * idx_length * sizeof(int));  // TODO: cudaFree once consumed
+  std::cout<<idx_length<<std::endl;
+  int num_cd = sorted_cd_idx.spec->shape[0];
+
+  // Pre-screen: the result is the list of surviving (ab_index, cd_index) pairs
+  hemi::ExecutionPolicy ep;
+  ep.setStream(stream);
+  hemi::parallel_for(ep, 0, screened_cd_idx_start.spec->shape[0], [=] HEMI_LAMBDA(int index) {
+    for (int i = screened_cd_idx_start.ptr[index]; i < num_cd; i++) {
+      int loc = screened_idx_offset.ptr[index] + i - screened_cd_idx_start.ptr[index];
+      idx_4c[loc] = sorted_ab_idx.ptr[index];                           // ab
+      idx_4c[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i];      // cd
+      output.ptr[loc] = sorted_ab_idx.ptr[index];                       // ab (stored as double)
+      output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i];  // cd (stored as double)
+    }
+    // no __syncthreads() needed: each iteration writes disjoint slots
+  });
+
+  // Now that we have the (ab, cd) pairs we can compute the ERIs and contract
+  // them into output. The contraction needs: 1. counts, 2. pgto normalization
+  // coefficients, 3. pgto coefficients, 4. rdm1 (MO coefficients).
+  hemi::parallel_for(ep, 0, idx_length, [=] HEMI_LAMBDA(int index) {
+    int a, b, c, d;  // pgto 4c idx
+    double eri_result;
+    triu_ij_from_index(N.ptr[0], idx_4c[index], &a, &b);
+    triu_ij_from_index(N.ptr[0], idx_4c[index + screened_length.ptr[0]], &c, &d);
+    eri_result = eri(n.ptr[0 * N.ptr[0] + a], n.ptr[1 * N.ptr[0] + a], n.ptr[2 * N.ptr[0] + a],  // a
+                     n.ptr[0 * N.ptr[0] + b], n.ptr[1 * N.ptr[0] + b], n.ptr[2 * N.ptr[0] + b],  // b
+                     n.ptr[0 * N.ptr[0] + c], n.ptr[1 * N.ptr[0] + c], n.ptr[2 * N.ptr[0] + c],  // c
+                     n.ptr[0 * N.ptr[0] + d], n.ptr[1 * N.ptr[0] + d], n.ptr[2 * N.ptr[0] + d],  // d
+                     r.ptr[0 * N.ptr[0] + a], r.ptr[1 * N.ptr[0] + a], r.ptr[2 * N.ptr[0] + a],  // a
+                     r.ptr[0 * N.ptr[0] + b], r.ptr[1 * N.ptr[0] + b], r.ptr[2 * N.ptr[0] + b],  // b
+                     r.ptr[0 * N.ptr[0] + c], r.ptr[1 * N.ptr[0] + c], r.ptr[2 * N.ptr[0] + c],  // c
+                     r.ptr[0 * N.ptr[0] + d], r.ptr[1 * N.ptr[0] + d], r.ptr[2 * N.ptr[0] + d],  // d
+                     z.ptr[a], z.ptr[b], z.ptr[c], z.ptr[d],  // z
+                     min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
+    // TODO: eri_result is currently discarded; output only carries the
+    // screened index pairs written above.
+  });
+}
+
+// template <typename FLOAT>
+void Hartree_32_uncontracted::Gpu(cudaStream_t stream,
+                                  Array<int>& N,
+                                  Array<int>& index_4c,
+                                  Array<int>& n,
+                                  Array<float>& r,
+                                  Array<float>& z,
+                                  Array<int>& min_a,
+                                  Array<int>& min_c,
+                                  Array<int>& max_ab,
+                                  Array<int>& max_cd,
+                                  Array<int>& Ms,
+                                  Array<float>& output) {
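+  // Data layout sketch (illustrative): n and r are the flattened transposes of
+  // (N, 3) arrays, so the x/y/z component of pgto i lives at ptr[0*N + i],
+  // ptr[1*N + i], ptr[2*N + i]. E.g. N = 2: n = [nx0, nx1, ny0, ny1, nz0, nz1],
+  // so the y angular of pgto 1 is n.ptr[1 * 2 + 1] = n.ptr[3].
+  // index_4c is row-major (M, 4): quadruple m occupies slots 4*m .. 4*m + 3.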
+  // std::cout<<index_4c.spec->shape[0]<<std::endl;
+  hemi::ExecutionPolicy ep;
+  ep.setStream(stream);
+  hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int index) {
+    int i, j, k, l;
+    // Alternatively, recover the quadruple from a symmetry-reduced linear index:
+    // triu_ij_from_index(num_unique_ij(N.ptr[0]), index_4c.ptr[index], &ij, &kl);
+    // triu_ij_from_index(N.ptr[0], ij, &i, &j);
+    // triu_ij_from_index(N.ptr[0], kl, &k, &l);
+    // output.ptr[index] = index_4c.ptr[index];
+    i = index_4c.ptr[4 * index + 0];
+    j = index_4c.ptr[4 * index + 1];
+    k = index_4c.ptr[4 * index + 2];
+    l = index_4c.ptr[4 * index + 3];
+    output.ptr[index] = eri(n.ptr[0 * N.ptr[0] + i], n.ptr[1 * N.ptr[0] + i], n.ptr[2 * N.ptr[0] + i],  // a
+                            n.ptr[0 * N.ptr[0] + j], n.ptr[1 * N.ptr[0] + j], n.ptr[2 * N.ptr[0] + j],  // b
+                            n.ptr[0 * N.ptr[0] + k], n.ptr[1 * N.ptr[0] + k], n.ptr[2 * N.ptr[0] + k],  // c
+                            n.ptr[0 * N.ptr[0] + l], n.ptr[1 * N.ptr[0] + l], n.ptr[2 * N.ptr[0] + l],  // d
+                            r.ptr[0 * N.ptr[0] + i], r.ptr[1 * N.ptr[0] + i], r.ptr[2 * N.ptr[0] + i],  // a
+                            r.ptr[0 * N.ptr[0] + j], r.ptr[1 * N.ptr[0] + j], r.ptr[2 * N.ptr[0] + j],  // b
+                            r.ptr[0 * N.ptr[0] + k], r.ptr[1 * N.ptr[0] + k], r.ptr[2 * N.ptr[0] + k],  // c
+                            r.ptr[0 * N.ptr[0] + l], r.ptr[1 * N.ptr[0] + l], r.ptr[2 * N.ptr[0] + l],  // d
+                            z.ptr[i], z.ptr[j], z.ptr[k], z.ptr[l],  // z
+                            min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
+  });
+}
+
+// template <typename FLOAT>
-void Hartree_64::Gpu(cudaStream_t stream,
+void Hartree_64_uncontracted::Gpu(cudaStream_t stream,
                     Array<int>& N,
                     Array<int>& index_4c,
                     Array<int>& n,
@@ -70,7 +195,7 @@ void Hartree_64::Gpu(cudaStream_t stream,
                     Array<int>& max_cd,
                     Array<int>& Ms,
                     Array<double>& output) {
-  std::cout<<index_4c.spec->shape[0]<<std::endl;
+  // std::cout<<index_4c.spec->shape[0]<<std::endl;
   hemi::ExecutionPolicy ep;
   ep.setStream(stream);
   hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int index) {
diff --git a/d4ft/native/obara_saika/eri_kernel.h b/d4ft/native/obara_saika/eri_kernel.h
index 84fff6f..6033d28 100644
--- a/d4ft/native/obara_saika/eri_kernel.h
+++ b/d4ft/native/obara_saika/eri_kernel.h
@@ -28,6 +28,130 @@
 #include "d4ft/native/xla/specs.h"
 
 class Hartree_32 {
+ public:
+  // template <typename FLOAT>
+  static auto ShapeInference(const Spec& shape1,
+                             const Spec& shape10,
+                             const Spec& shape2,
+                             const Spec& shape3,
+                             const Spec& shape4,
+                             const Spec& shape5,
+                             const Spec& shape6,
+                             const Spec& shape7,
+                             const Spec& shape8,
+                             const Spec& shape9,
+                             const Spec& shape11) {
+    // float n2 = shape4.shape[0]*(shape4.shape[0]+1)/2;
+    // float n4 = n2*(n2+1)/2;
+    // int n4_int = static_cast<int>(n4);
+    std::vector outshape={shape1.shape[0]};
+    Spec out(outshape);
+    return std::make_tuple(out);
+  }
+  // static void Cpu(Array<float>& arg1, Array<float>& arg2,
+  //                 Array<float>& out) {
+  //   std::memcpy(out.ptr, arg1.ptr, sizeof(float) * arg1.spec->Size());
+  // }
+  // template <typename FLOAT>
+  static void Cpu(Array<int>& N,
+                  Array<int>& index_4c,
+                  Array<int>& n,
+                  Array<float>& r,
+                  Array<float>& z,
+                  Array<int>& min_a,
+                  Array<int>& min_c,
+                  Array<int>& max_ab,
+                  Array<int>& max_cd,
+                  Array<int>& Ms,
+                  Array<int>& ab_range,
+                  Array<float>& output) {
+    // std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size());
+  }
+
+  // template <typename FLOAT>
+  static void Gpu(cudaStream_t stream,
+                  Array<int>& N,
+                  Array<int>& index_4c,
+                  Array<int>& n,
+                  Array<float>& r,
+                  Array<float>& z,
+                  Array<int>& min_a,
+                  Array<int>& min_c,
+                  Array<int>& max_ab,
+                  Array<int>& max_cd,
+                  Array<int>& Ms,
+                  Array<int>& ab_range,
+                  Array<float>& output);
+};
+
+class Hartree_64 {
+ public:
+  // template <typename FLOAT>
+  static auto ShapeInference(const Spec& shape1,
+                             const Spec& shape10,
+                             const Spec& shape2,
+                             const Spec& shape3,
+                             const Spec& shape4,
+                             const Spec& shape5,
+                             const Spec& shape6,
+                             const Spec& shape7,
+                             const Spec& shape8,
+                             const Spec& shape9,
+                             const Spec& shape11,
+                             const Spec& shape12,
+                             const Spec& shape13,
+                             const Spec& shape14) {
+    // double n2 = shape4.shape[0]*(shape4.shape[0]+1)/2;
+    // double n4 = n2*(n2+1)/2;
+    // int n4_int = static_cast<int>(n4);
+    std::vector outshape={2*shape11.shape[0]*shape12.shape[0]};
+    // std::vector outshape={shape1.shape[0]};
+    Spec out(outshape);
+    return std::make_tuple(out);
+  }
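+  // Size note (illustrative): 2 * |sorted_ab_idx| * |sorted_cd_idx| is the
+  // worst case in which screening drops nothing, e.g. 3 ab pairs and 4 cd
+  // pairs give at most 12 surviving (ab, cd) pairs, stored as 12 ab slots
+  // followed by 12 cd slots. The true survivor count is only known at runtime
+  // (screened_length), so the XLA output buffer is sized for the bound.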
+  // static void Cpu(Array<float>& arg1, Array<float>& arg2,
+  //                 Array<float>& out) {
+  //   std::memcpy(out.ptr, arg1.ptr, sizeof(float) * arg1.spec->Size());
+  // }
+  // template <typename FLOAT>
+  static void Cpu(Array<int>& N,
+                  Array<int>& screened_length,
+                  Array<int>& n,
+                  Array<double>& r,
+                  Array<double>& z,
+                  Array<int>& min_a,
+                  Array<int>& min_c,
+                  Array<int>& max_ab,
+                  Array<int>& max_cd,
+                  Array<int>& Ms,
+                  Array<int>& sorted_ab_idx,
+                  Array<int>& sorted_cd_idx,
+                  Array<int>& screened_cd_idx_start,
+                  Array<int>& screened_idx_offset,
+                  Array<double>& output) {
+    // std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size());
+  }
+
+  // template <typename FLOAT>
+  static void Gpu(cudaStream_t stream,
+                  Array<int>& N,
+                  Array<int>& screened_length,
+                  Array<int>& n,
+                  Array<double>& r,
+                  Array<double>& z,
+                  Array<int>& min_a,
+                  Array<int>& min_c,
+                  Array<int>& max_ab,
+                  Array<int>& max_cd,
+                  Array<int>& Ms,
+                  Array<int>& sorted_ab_idx,
+                  Array<int>& sorted_cd_idx,
+                  Array<int>& screened_cd_idx_start,
+                  Array<int>& screened_idx_offset,
+                  Array<double>& output);
+};
+
+class Hartree_32_uncontracted {
  public:
   // template <typename FLOAT>
   static auto ShapeInference(const Spec& shape1,
@@ -81,7 +205,7 @@ class Hartree_32 {
                   Array<float>& output);
 };
 
-class Hartree_64 {
+class Hartree_64_uncontracted {
  public:
   // template <typename FLOAT>
   static auto ShapeInference(const Spec& shape1,
diff --git a/tests/native/xla/eri_kernel_test.py b/tests/native/xla/eri_kernel_test.py
index 1da221b..9570958 100644
--- a/tests/native/xla/eri_kernel_test.py
+++ b/tests/native/xla/eri_kernel_test.py
@@ -11,7 +11,7 @@ from absl.testing import absltest
 from d4ft.native.xla.custom_call import CustomCallMeta
-from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_64
+from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_32_uncontracted, _Hartree_64, _Hartree_64_uncontracted
 from d4ft.integral.obara_saika.electron_repulsion_integral import (
   electron_repulsion_integral,
@@ -21,15 +21,20 @@ from d4ft.integral.gto.cgto import CGTO
 from d4ft.integral.gto import symmetry, tensorization
 from copy import deepcopy
+from d4ft.types import AngularStats
 
 # from obsa.obara_saika import get_coulomb, get_kinetic, get_nuclear, get_overlap
 # from jax.interpreters import ad, batching, mlir, xla
 Hartree_64 = CustomCallMeta("Hartree_64", (_Hartree_64,), {})
+Hartree_64_uncontracted = CustomCallMeta("Hartree_64_uncontracted", (_Hartree_64_uncontracted,), {})
 Hartree_32 = CustomCallMeta("Hartree_32", (_Hartree_32,), {})
+Hartree_32_uncontracted = CustomCallMeta("Hartree_32_uncontracted", (_Hartree_32_uncontracted,), {})
 if jax.config.jax_enable_x64:
   hartree = Hartree_64()
+  hartree_uncontracted = Hartree_64_uncontracted()
 else:
   hartree = Hartree_32()
+  hartree_uncontracted = Hartree_32_uncontracted()
 
 # TODO
 # def _example_batch_rule(args, axes):
@@ -64,9 +69,10 @@ def num_unique_ijkl(n):
     # To support higher angular, first adjust constants in eri.h: MAX_XYZ, MAX_YZ..
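+    # Size sketch (illustrative): with N pgtos there are num_unique_ij(N) =
+    # N*(N+1)/2 unique (i <= j) pairs and num_unique_ijkl(N) unique quadruples;
+    # e.g. N = 4 gives 10 pairs and 10*11/2 = 55 quadruples, the "original
+    # length" that the Schwarz screening in test_abab tries to cut down.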
+ # pyscf_mol = get_pyscf_mol("C180-0", "sto-3g") # pyscf_mol = get_pyscf_mol("C60-Ih", "sto-3g") - pyscf_mol = get_pyscf_mol("C180-0", "6-31G") - # pyscf_mol = get_pyscf_mol("O2", "sto-3g") + # pyscf_mol = get_pyscf_mol("O2", "6-31G") + pyscf_mol = get_pyscf_mol("O2", "sto-3g") mol = Mol.from_pyscf_mol(pyscf_mol) cgto = CGTO.from_mol(mol) self.s = angular_stats.angular_static_args(*[cgto.pgto.angular] * 4) @@ -81,6 +87,7 @@ def num_unique_ijkl(n): num_4c_idx = symmetry.num_unique_ij(n_2c_idx) self.num_4c_idx = num_4c_idx + # self.num_4c_idx = num_4c_idx # batch_size: int = 2**23 # i = 0 # start = batch_size * i @@ -107,6 +114,7 @@ def num_unique_ijkl(n): # self.cgto_seg_id = symmetry.get_cgto_segment_id_sym( # self.abcd_idx_counts[:, :-1], cgto.cgto_splits, four_center=True # ) + # self.a, self.b, self.c, self.d = gtos_abcd # self.N = cgto.n_pgtos # self.n = jnp.array(deepcopy(cgto.pgto.angular.T.reshape((3*self.N,))), dtype=jnp.int32) @@ -120,8 +128,8 @@ def num_unique_ijkl(n): # self.max_cd = jnp.array(self.s.max_cd, dtype=jnp.int32) # self.Ms = jnp.array([self.s.max_xyz+1, self.s.max_yz+1, self.s.max_z+1], dtype=jnp.int32) - def test_example(self) -> None: - logging.info(jax.devices()) + # def test_example(self) -> None: + # logging.info(jax.devices()) # def f_curry(*args): # return electron_repulsion_integral(*args, static_args=self.s) # vmap_f = jax.vmap(f_curry, in_axes=(0, 0, 0, 0)) @@ -132,19 +140,19 @@ def test_example(self) -> None: # T_2 = time.time() # print(T_2-T_1) - # cgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(self.s) - # T_1 = time.time() - # # e2 = hartree(jnp.array([self.N], dtype=jnp.int32),jnp.array(range(self.N), dtype=jnp.int32), self.n, self.r, self.z, self.min_a, - # # self.min_c, self.max_ab, self.max_cd, self.Ms) - # cgto_abcd_2 = cgto_4c_fn( - # self.cgto, self.abcd_idx_counts, self.cgto_seg_id, self.n_segs - # ) - # T_2 = time.time() - # print(T_2-T_1) - # # abcd_2 = jnp.einsum("k,k,k,k,k,k->k", e2, self.N_abcd, *self.coeffs_abcd) - # # cgto_abcd_2 = jax.ops.segment_sum(abcd_2, self.cgto_seg_id, self.n_segs) - # # np.testing.assert_allclose(e1,e2,atol=2e-5) - # np.testing.assert_allclose(cgto_abcd_1,cgto_abcd_2,atol=1e-5) + # cgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(self.s) + # T_1 = time.time() + # # e2 = hartree(jnp.array([self.N], dtype=jnp.int32),jnp.array(range(self.N), dtype=jnp.int32), self.n, self.r, self.z, self.min_a, + # # self.min_c, self.max_ab, self.max_cd, self.Ms) + # cgto_abcd_2 = cgto_4c_fn( + # self.cgto, self.abcd_idx_counts, self.cgto_seg_id, self.n_segs + # ) + # T_2 = time.time() + # print(T_2-T_1) + # abcd_2 = jnp.einsum("k,k,k,k,k,k->k", e2, self.N_abcd, *self.coeffs_abcd) + # cgto_abcd_2 = jax.ops.segment_sum(abcd_2, self.cgto_seg_id, self.n_segs) + # np.testing.assert_allclose(e1,e2,atol=2e-5) + # np.testing.assert_allclose(cgto_abcd_1,cgto_abcd_2,atol=1e-5) # out_vmap = jax.vmap(example_fn)(self.a_b, self.b_b) @@ -156,26 +164,26 @@ def test_example(self) -> None: # np.testing.assert_array_equal(self.outshape, out) def test_abab(self) -> None: - - pgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(self.s, cgto=False) + compute_hartree_test(self.cgto, self.s) + # pgto_4c_fn = tensorization.tensorize_4c_cgto_cuda(self.s, cgto=False) # pgto_4c_fn_gt = tensorization.tensorize_4c_cgto(electron_repulsion_integral, self.s, cgto=False) # cgto_4c_fn = tensorization.tensorize_4c_cgto_range(eri_fn, s4) - eri_abab = pgto_4c_fn(self.cgto, self.abab_idx_count, None, None) + # eri_abab = pgto_4c_fn(self.cgto, self.abab_idx_count, 
     # eri_abab_gt = pgto_4c_fn_gt(self.cgto, self.abab_idx_count, None, None)
-    sorted_idx = jnp.argsort(eri_abab)
-    sorted_abab = eri_abab[sorted_idx]
-    eps = 1e-10
-    sorted_cd_thres = eps / jnp.sqrt(sorted_abab)
-    cnt = jnp.array([e for e in range(len(self.abab_idx_count))])
-    idx = jnp.maximum(cnt, jnp.searchsorted(sorted_abab, sorted_cd_thres))
-    idx = idx[idx < len(sorted_idx)]
-
-    abab_len = len(sorted_idx)
-    screened_cnt = jnp.sum(abab_len-idx)
-    print(len(self.abab_idx_count))
-    print("original length =", self.num_4c_idx)
-    print("screened length =", screened_cnt)
+    # sorted_idx = jnp.argsort(eri_abab)
+    # sorted_abab = eri_abab[sorted_idx]
+    # eps = 1e-10
+    # sorted_cd_thres = (eps / jnp.sqrt(sorted_abab))**2
+    # cnt = jnp.array([e for e in range(len(self.abab_idx_count))])
+    # idx = jnp.maximum(cnt, jnp.searchsorted(sorted_abab, sorted_cd_thres))
+    # idx = idx[idx < len(sorted_idx)]
+
+    # abab_len = len(sorted_idx)
+    # screened_cnt = jnp.sum(abab_len-idx)
+    # print(len(self.abab_idx_count))
+    # print("original length =", self.num_4c_idx)
+    # print("screened length =", screened_cnt)
     # abcd = [jnp.array([sorted_idx[cnt]* jnp.ones(len(sorted_idx)-idx[cnt])], sorted_idx[cd_idx:]).T for cd_idx, cnt in zip(idx,range(len(idx)))]
     # abs = [sorted_idx[cnt] * jnp.ones(len(sorted_idx)-idx[cnt]) for cnt in range(len(idx))]
     # cds = sorted_idx[idx:]
@@ -215,6 +223,119 @@ def test_abab(self) -> None:
   # logging.info(f"block diag (ab|ab) computed, size: {eri_abab.shape}")
 
 
+def compute_hartree_test(cgto: CGTO, static_args: AngularStats):
+  # Sort pgtos by total angular momentum so that each s/p/d class is contiguous.
+  l_xyz = jnp.sum(cgto.pgto.angular, 1)
+  orig_idx = jnp.argsort(l_xyz)
+
+  # currently supports s, p, d
+  s_num = jnp.count_nonzero(l_xyz == 0)
+  p_num = jnp.count_nonzero(l_xyz == 1)
+  d_num = jnp.count_nonzero(l_xyz == 2)
+  max_angular = jnp.max(l_xyz)
+
+  N = jnp.array([cgto.n_pgtos], dtype=jnp.int32)
+  # Reorder the (n_pgtos, 3) rows with orig_idx first, then transpose and flatten.
+  n = jnp.array(cgto.pgto.angular[orig_idx].T, dtype=jnp.int32)
+  r = jnp.array(cgto.pgto.center[orig_idx].T)
+  z = jnp.array(cgto.pgto.exponent)[orig_idx]
+
+  min_a = jnp.array(static_args.min_a, dtype=jnp.int32)
+  min_c = jnp.array(static_args.min_c, dtype=jnp.int32)
+  max_ab = jnp.array(static_args.max_ab, dtype=jnp.int32)
+  max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32)
+  Ms = jnp.array([static_args.max_xyz+1, static_args.max_yz+1, static_args.max_z+1], dtype=jnp.int32)
+
+  ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos)
+  rank_ab_idx = jnp.arange(ab_idx_counts.shape[0])
+  # Classify each unique (a <= b) pair by the angular classes of its two pgtos.
+  ss_mask = (ab_idx_counts[:, 1] < s_num)
+  sp_mask = (ab_idx_counts[:, 1] >= s_num) & (ab_idx_counts[:, 1] < s_num + p_num) & (ab_idx_counts[:, 0] < s_num)
+  sd_mask = (ab_idx_counts[:, 1] >= s_num + p_num) & (ab_idx_counts[:, 0] < s_num)
+  pp_mask = (ab_idx_counts[:, 1] < s_num + p_num) & (ab_idx_counts[:, 0] >= s_num)
+  pd_mask = (ab_idx_counts[:, 1] >= s_num + p_num) & (ab_idx_counts[:, 0] >= s_num) & (ab_idx_counts[:, 0] < s_num + p_num)
+  dd_mask = (ab_idx_counts[:, 0] >= s_num + p_num)
+
+  ss_idx = rank_ab_idx[ss_mask]
+  sp_idx = rank_ab_idx[sp_mask]
+  sd_idx = rank_ab_idx[sd_mask]
+  pp_idx = rank_ab_idx[pp_mask]
+  pd_idx = rank_ab_idx[pd_mask]
+  dd_idx = rank_ab_idx[dd_mask]
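+  # Worked example (illustrative): with s_num = 2, p_num = 1 and pgtos ordered
+  # s, s, p after the argsort, the unique (a <= b) pairs classify as
+  # (0,0), (0,1), (1,1) -> ss; (0,2), (1,2) -> sp; (2,2) -> pp. Each class is
+  # then sorted and screened on its own below, so a single angular pattern can
+  # eventually be dispatched per kernel launch.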
+  # ab_idx_counts = jnp.vstack([ab_idx_counts[ss_mask], ab_idx_counts[sp_mask], ab_idx_counts[sd_mask],
+  #                             ab_idx_counts[pp_mask], ab_idx_counts[pd_mask], ab_idx_counts[dd_mask]])
+
+  # ss_num = jnp.count_nonzero(ss_mask)
+  # sp_num = jnp.count_nonzero(sp_mask)
+  # sd_num = jnp.count_nonzero(sd_mask)
+  # pp_num = jnp.count_nonzero(pp_mask)
+  # pd_num = jnp.count_nonzero(pd_mask)
+  # dd_num = jnp.count_nonzero(dd_mask)
+  # ss_start = 0
+  # ss_end = ss_start + ss_num
+  # sp_start = ss_end
+  # sp_end = sp_start + sp_num
+  # sd_start = sp_end
+  # sd_end = sd_start + sd_num
+  # pp_start = sd_end
+  # pp_end = pp_start + pp_num
+  # pd_start = pp_end
+  # pd_end = pd_start + pd_num
+  # dd_start = pd_end
+  # dd_end = dd_start + dd_num
+  # ab_range = jnp.array([[ss_start, sp_start, sd_start, pp_start, pd_start, dd_start],
+  #                       [ss_end, sp_end, sd_end, pp_end, pd_end, dd_end]], dtype=jnp.int32)
+
+  ab_idx, counts_ab = ab_idx_counts[:, :2], ab_idx_counts[:, 2]
+  # Diagonal (ab, ab) quadruples for the Schwarz bound; count = counts_ab**2.
+  abab_idx_counts = jnp.hstack([ab_idx, ab_idx,
+                                counts_ab[:, None]*counts_ab[:, None]]).astype(int)
+  abab_idx = jnp.array(abab_idx_counts[:, :4], dtype=jnp.int32)
+
+  # Compute the diagonal ERIs (ab|ab)
+  eri_abab = jnp.array(hartree_uncontracted(N, abab_idx, n, r, z, min_a, min_c, max_ab, max_cd, Ms))
+
+  # Sort each angular class separately by its (ab|ab) value.
+  sorted_idx = [ss_idx[jnp.argsort(eri_abab[ss_idx])],
+                sp_idx[jnp.argsort(eri_abab[sp_idx])],
+                sd_idx[jnp.argsort(eri_abab[sd_idx])],
+                pp_idx[jnp.argsort(eri_abab[pp_idx])],
+                pd_idx[jnp.argsort(eri_abab[pd_idx])],
+                dd_idx[jnp.argsort(eri_abab[dd_idx])]]
+  sorted_eri = [eri_abab[sorted_idx[0]],
+                eri_abab[sorted_idx[1]],
+                eri_abab[sorted_idx[2]],
+                eri_abab[sorted_idx[3]],
+                eri_abab[sorted_idx[4]],
+                eri_abab[sorted_idx[5]]]
+
+  # (ss, ss) block only, for now.
+  # For same-class pairs such as (ss, ss), (pp, pp), (dd, dd), (sp, sp), ... we
+  # must enforce cd_idx >= cnt so that each unique (ab <= cd) pair is emitted
+  # exactly once; for mixed class pairs this is not needed.
+  eps = 1e-10
+  sorted_ab_idx = sorted_idx[0]
+  sorted_cd_idx = sorted_idx[0]
+  sorted_eri_abab = sorted_eri[0]
+  sorted_eri_cdcd = sorted_eri[0]
+  # Schwarz screening: |(ab|cd)| <= sqrt((ab|ab) * (cd|cd)), so a pair (ab, cd)
+  # can be dropped once (cd|cd) < eps**2 / (ab|ab), hence the squared threshold.
+  sorted_ab_thres = (eps / jnp.sqrt(sorted_eri_abab))**2
+  cnt = jnp.arange(len(sorted_eri_abab))
+  cd_idx = jnp.searchsorted(sorted_eri_cdcd, sorted_ab_thres)
+  cd_idx = jnp.maximum(cnt, cd_idx)
+  cdcd_len = len(sorted_eri_cdcd)
+  start_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(cdcd_len - cd_idx)[:-1]), dtype=jnp.int32)
+  screened_cnt = jnp.sum(cdcd_len - cd_idx)
+  # N already has shape (1,), so it is passed through directly.
+  output = hartree(N, jnp.array([screened_cnt], dtype=jnp.int32),
+                   n, r, z, min_a, min_c, max_ab, max_cd, Ms,
+                   jnp.array(sorted_ab_idx, dtype=jnp.int32),
+                   jnp.array(sorted_cd_idx, dtype=jnp.int32),
+                   jnp.array(cd_idx, dtype=jnp.int32),
+                   jnp.array(start_offset, dtype=jnp.int32))
+  # TODO: assert the screened (ab, cd) pairs against an unscreened reference.
+
+  # print(s_num, p_num)
+  # abcd_idx = output[:2*screened_cnt].reshape((2, screened_cnt))
+  # print(abcd_idx[:, -100:])
+
 
 if __name__ == "__main__":
   absltest.main()