update screening in cuda
JasonLeeJSL committed Dec 4, 2023
1 parent f783a1b commit 96b5e09
Showing 5 changed files with 419 additions and 39 deletions.
14 changes: 11 additions & 3 deletions d4ft/integral/gto/tensorization.py
@@ -22,15 +22,19 @@
from d4ft.integral.gto import symmetry
from d4ft.integral.gto.cgto import CGTO, PGTO
from d4ft.types import IdxCount2C, IdxCount4C
-from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_64
+from d4ft.native.obara_saika.eri_kernel import _Hartree_32, _Hartree_32_uncontracted, _Hartree_64, _Hartree_64_uncontracted
from d4ft.native.xla.custom_call import CustomCallMeta

Hartree_64 = CustomCallMeta("Hartree_64", (_Hartree_64,), {})
+Hartree_64_uncontracted = CustomCallMeta("Hartree_64_uncontracted", (_Hartree_64_uncontracted,), {})
Hartree_32 = CustomCallMeta("Hartree_32", (_Hartree_32,), {})
+Hartree_32_uncontracted = CustomCallMeta("Hartree_32_uncontracted", (_Hartree_32_uncontracted,), {})
if jax.config.jax_enable_x64:
  hartree = Hartree_64()
+  hartree_uncontracted = Hartree_64_uncontracted()
else:
  hartree = Hartree_32()
+  hartree_uncontracted = Hartree_32_uncontracted()

def tensorize_2c_cgto(f: Callable, static_args, cgto: bool = True):
"""2c centers tensorization with provided index set,
@@ -129,25 +133,29 @@ def tensorize(
  cgto_seg_id,
  n_segs: int,
):
-  Ns = gtos.N
+  N = gtos.n_pgtos
+  Ns = gtos.N

  # Why: reshaping n, r, z to 1-D significantly reduces computing time
  n = jnp.array(gtos.pgto.angular.T, dtype=jnp.int32)
  r = jnp.array(gtos.pgto.center.T)
  z = jnp.array(gtos.pgto.exponent)

  min_a = jnp.array(static_args.min_a, dtype=jnp.int32)
  min_c = jnp.array(static_args.min_c, dtype=jnp.int32)
  max_ab = jnp.array(static_args.max_ab, dtype=jnp.int32)
  max_cd = jnp.array(static_args.max_cd, dtype=jnp.int32)
  Ms = jnp.array([static_args.max_xyz+1, static_args.max_yz+1, static_args.max_z+1], dtype=jnp.int32)
  abcd_idx = idx_counts[:, :4]

  gtos_abcd, coeffs_abcd = zip(
    *[
      gtos.map_pgto_params(lambda gto_param, i=i: gto_param[abcd_idx[:, i]])
      for i in range(4)
    ]
  )
-  t_abcd = hartree(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx, dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
+  t_abcd = hartree_uncontracted(jnp.array([N], dtype=jnp.int32), jnp.array(abcd_idx, dtype=jnp.int32), n, r, z, min_a, min_c, max_ab, max_cd, Ms)
  jax.block_until_ready(t_abcd)
  if not cgto:
    return t_abcd
  counts_abcd_i = idx_counts[:, 4]
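A note on the flattening above: the kernels read the transposed angular/center arrays through linear offsets, so the memory layout matters. The following standalone NumPy sketch (not part of the commit; N and the values are hypothetical) shows why element (a, dim) of gtos.pgto.angular is read as n.ptr[dim * N + a] on the device:

import numpy as np

N = 4                                        # hypothetical number of PGTOs
angular = np.arange(N * 3).reshape(N, 3)     # stand-in for gtos.pgto.angular, shape (N, 3)
n = angular.T                                # shape (3, N), as built in tensorize()
flat = n.ravel()                             # the contiguous buffer the kernel sees as n.ptr
a, dim = 2, 1                                # PGTO index a, component dim (x=0, y=1, z=2)
assert flat[dim * N + a] == angular[a, dim]  # matches n.ptr[dim * N + a] in eri_kernel.cu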
2 changes: 2 additions & 0 deletions d4ft/native/obara_saika/eri_kernel.cc
@@ -7,5 +7,7 @@ PYBIND11_MODULE(eri_kernel, m) {
// py::class_<Parent>(m, "Parent").def(py::init<>());
REGISTER_XLA_FUNCTION(m, Hartree_32);
REGISTER_XLA_FUNCTION(m, Hartree_64);
+REGISTER_XLA_FUNCTION(m, Hartree_32_uncontracted);
+REGISTER_XLA_FUNCTION(m, Hartree_64_uncontracted);
// REGISTER_XLA_MEMBER(m, Parent, ExampleMember);
}
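For reference, the two new registrations surface on the Python side exactly as in the tensorization.py hunk above; a condensed sketch of that wiring:

from d4ft.native.obara_saika.eri_kernel import _Hartree_64_uncontracted
from d4ft.native.xla.custom_call import CustomCallMeta

# CustomCallMeta turns the pybind11-exported kernel into a JAX-callable custom call.
Hartree_64_uncontracted = CustomCallMeta(
  "Hartree_64_uncontracted", (_Hartree_64_uncontracted,), {}
)
hartree_uncontracted = Hartree_64_uncontracted()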
127 changes: 126 additions & 1 deletion d4ft/native/obara_saika/eri_kernel.cu
@@ -31,6 +31,7 @@ void Hartree_32::Gpu(cudaStream_t stream,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
+Array<const int>& ab_range,
Array<float>& output) {
std::cout<<index_4c.spec->shape[0]<<std::endl;
hemi::ExecutionPolicy ep;
@@ -59,6 +60,130 @@
}

void Hartree_64::Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& screened_length,
Array<const int>& n,
Array<const double>& r,
Array<const double>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<const int>& sorted_ab_idx,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
Array<int>& output) {
// Prescreening
int* idx_4c;
int idx_length;
cudaMemcpy(&idx_length, screened_length.ptr, sizeof(int), cudaMemcpyDeviceToHost);
cudaMalloc((void **)&idx_4c, 2 * idx_length * sizeof(int));
std::cout<<idx_length<<std::endl;
int num_cd = sorted_cd_idx.spec->shape[0];

// Prescreen: each surviving entry is an (ab_index, cd_index) pair, i.e. (ab, cd)
hemi::ExecutionPolicy ep;
ep.setStream(stream);
hemi::parallel_for(ep, 0, screened_cd_idx_start.spec->shape[0], [=] HEMI_LAMBDA(int index) {
for(int i = screened_cd_idx_start.ptr[index]; i < num_cd; i++ ){
int loc;
loc = screened_idx_offset.ptr[index] + i - screened_cd_idx_start.ptr[index];
idx_4c[loc] = sorted_ab_idx.ptr[index]; // ab
idx_4c[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
output.ptr[loc] = sorted_ab_idx.ptr[index]; // ab
output.ptr[loc + screened_length.ptr[0]] = sorted_cd_idx.ptr[i]; // cd
}
__syncthreads();
});

// Now that we have the (ab, cd) pairs, we can compute the ERIs and contract them into the output.
// Contraction needs: 1. counts, 2. PGTO normalization coefficients, 3. PGTO contraction coefficients, 4. rdm1 (MO coefficients).
hemi::parallel_for(ep, 0, idx_length, [=] HEMI_LAMBDA(int index) {
int a, b, c, d; // pgto 4c idx
int i, j, k, l; // cgto 4c idx
double eri_result;
triu_ij_from_index(N.ptr[0], idx_4c[index], &a, &b);
triu_ij_from_index(N.ptr[0], idx_4c[index + screened_length.ptr[0]], &c, &d);
eri_result = eri<double>(n.ptr[0 * N.ptr[0] + a], n.ptr[1 * N.ptr[0] + a], n.ptr[2 * N.ptr[0] + a], // a
n.ptr[0 * N.ptr[0] + b], n.ptr[1 * N.ptr[0] + b], n.ptr[2 * N.ptr[0] + b], // b
n.ptr[0 * N.ptr[0] + c], n.ptr[1 * N.ptr[0] + c], n.ptr[2 * N.ptr[0] + c], // c
n.ptr[0 * N.ptr[0] + d], n.ptr[1 * N.ptr[0] + d], n.ptr[2 * N.ptr[0] + d], // d
r.ptr[0 * N.ptr[0] + a], r.ptr[1 * N.ptr[0] + a], r.ptr[2 * N.ptr[0] + a], // a
r.ptr[0 * N.ptr[0] + b], r.ptr[1 * N.ptr[0] + b], r.ptr[2 * N.ptr[0] + b], // b
r.ptr[0 * N.ptr[0] + c], r.ptr[1 * N.ptr[0] + c], r.ptr[2 * N.ptr[0] + c], // c
r.ptr[0 * N.ptr[0] + d], r.ptr[1 * N.ptr[0] + d], r.ptr[2 * N.ptr[0] + d], // d
z.ptr[a], z.ptr[b], z.ptr[c], z.ptr[d], // z
min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
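    // NOTE: eri_result is only computed here; the contraction into output
    // (using the counts, PGTO normalization/contraction coefficients, and
    // rdm1 named in the comment above) is not applied in this kernel as
    // committed, and output still holds the screened index pairs.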
});

// std::cout<<index_4c.spec->shape[0]<<std::endl;
// hemi::ExecutionPolicy ep;
// ep.setStream(stream);
// hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int index) {
// int i, j, k, l, ij, kl;
// // triu_ij_from_index(num_unique_ij(N.ptr[0]), index_4c.ptr[index], &ij, &kl);
// // triu_ij_from_index(N.ptr[0], ij, &i, &j);
// // triu_ij_from_index(N.ptr[0], kl, &k, &l);
// // output.ptr[index] = index_4c.ptr[index];
// i = index_4c.ptr[4*index + 0];
// j = index_4c.ptr[4*index + 1];
// k = index_4c.ptr[4*index + 2];
// l = index_4c.ptr[4*index + 3];
// output.ptr[index] = eri<double>(n.ptr[0 * N.ptr[0] + i], n.ptr[1 * N.ptr[0] + i], n.ptr[2 * N.ptr[0] + i], // a
// n.ptr[0 * N.ptr[0] + j], n.ptr[1 * N.ptr[0] + j], n.ptr[2 * N.ptr[0] + j], // b
// n.ptr[0 * N.ptr[0] + k], n.ptr[1 * N.ptr[0] + k], n.ptr[2 * N.ptr[0] + k], // c
// n.ptr[0 * N.ptr[0] + l], n.ptr[1 * N.ptr[0] + l], n.ptr[2 * N.ptr[0] + l], // d
// r.ptr[0 * N.ptr[0] + i], r.ptr[1 * N.ptr[0] + i], r.ptr[2 * N.ptr[0] + i], // a
// r.ptr[0 * N.ptr[0] + j], r.ptr[1 * N.ptr[0] + j], r.ptr[2 * N.ptr[0] + j], // b
// r.ptr[0 * N.ptr[0] + k], r.ptr[1 * N.ptr[0] + k], r.ptr[2 * N.ptr[0] + k], // c
// r.ptr[0 * N.ptr[0] + l], r.ptr[1 * N.ptr[0] + l], r.ptr[2 * N.ptr[0] + l], // d
// z.ptr[i], z.ptr[j], z.ptr[k], z.ptr[l], // z
// min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
// });
}
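To make the index bookkeeping above concrete, here is a host-side Python sketch (toy values, not part of the commit) that reproduces the expansion the first parallel_for performs. It assumes, consistently with the loop, that screened_idx_offset is the exclusive prefix sum of the per-ab survivor counts and screened_length is their total:

import numpy as np

sorted_ab_idx = np.array([0, 1, 3])             # hypothetical surviving ab pair indices
sorted_cd_idx = np.array([0, 2, 4, 5])          # hypothetical sorted cd pair indices
screened_cd_idx_start = np.array([0, 1, 3])     # first surviving cd position per ab row
counts = len(sorted_cd_idx) - screened_cd_idx_start
screened_idx_offset = np.concatenate([[0], np.cumsum(counts)[:-1]])
screened_length = int(counts.sum())

# Expand to the flat (ab, cd) list exactly as the kernel does.
idx_4c = np.empty(2 * screened_length, dtype=int)
for row, start in enumerate(screened_cd_idx_start):
  for i in range(start, len(sorted_cd_idx)):
    loc = screened_idx_offset[row] + i - start
    idx_4c[loc] = sorted_ab_idx[row]                  # ab half of the pair list
    idx_4c[loc + screened_length] = sorted_cd_idx[i]  # cd half of the pair list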

// template <typename FLOAT>
void Hartree_32_uncontracted::Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& index_4c,
Array<const int>& n,
Array<const float>& r,
Array<const float>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<float>& output) {
// std::cout<<index_4c.spec->shape[0]<<std::endl;
hemi::ExecutionPolicy ep;
ep.setStream(stream);
hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int index) {
int i, j, k, l, ij, kl;
// triu_ij_from_index(num_unique_ij(N.ptr[0]), index_4c.ptr[index], &ij, &kl);
// triu_ij_from_index(N.ptr[0], ij, &i, &j);
// triu_ij_from_index(N.ptr[0], kl, &k, &l);
// output.ptr[index] = index_4c.ptr[index];
i = index_4c.ptr[4*index + 0];
j = index_4c.ptr[4*index + 1];
k = index_4c.ptr[4*index + 2];
l = index_4c.ptr[4*index + 3];
output.ptr[index] = eri<float>(n.ptr[0 * N.ptr[0] + i], n.ptr[1 * N.ptr[0] + i], n.ptr[2 * N.ptr[0] + i], // a
n.ptr[0 * N.ptr[0] + j], n.ptr[1 * N.ptr[0] + j], n.ptr[2 * N.ptr[0] + j], // b
n.ptr[0 * N.ptr[0] + k], n.ptr[1 * N.ptr[0] + k], n.ptr[2 * N.ptr[0] + k], // c
n.ptr[0 * N.ptr[0] + l], n.ptr[1 * N.ptr[0] + l], n.ptr[2 * N.ptr[0] + l], // d
r.ptr[0 * N.ptr[0] + i], r.ptr[1 * N.ptr[0] + i], r.ptr[2 * N.ptr[0] + i], // a
r.ptr[0 * N.ptr[0] + j], r.ptr[1 * N.ptr[0] + j], r.ptr[2 * N.ptr[0] + j], // b
r.ptr[0 * N.ptr[0] + k], r.ptr[1 * N.ptr[0] + k], r.ptr[2 * N.ptr[0] + k], // c
r.ptr[0 * N.ptr[0] + l], r.ptr[1 * N.ptr[0] + l], r.ptr[2 * N.ptr[0] + l], // d
z.ptr[i], z.ptr[j], z.ptr[k], z.ptr[l], // z
min_a.ptr, min_c.ptr, max_ab.ptr, max_cd.ptr, Ms.ptr);
});
}
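For reference, triu_ij_from_index (used live in the screened Hartree_64 kernel above and visible in the commented-out code here) unflattens a linear index over the upper triangle, diagonal included, of an N-by-N pair matrix. A Python sketch of the expected behavior, assuming row-major upper-triangular ordering; the actual device function's convention may differ:

def triu_ij_from_index(N, index):
  # Walk the rows of the upper triangle; row i contributes N - i entries.
  i = 0
  while index >= N - i:
    index -= N - i
    i += 1
  return i, i + index

def num_unique_ij(N):
  return N * (N + 1) // 2

assert [triu_ij_from_index(3, t) for t in range(num_unique_ij(3))] == \
  [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]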

void Hartree_64_uncontracted::Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& index_4c,
Array<const int>& n,
@@ -70,7 +195,7 @@ void Hartree_64::Gpu(cudaStream_t stream,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<double>& output) {
-std::cout<<index_4c.spec->shape[0]<<std::endl;
+// std::cout<<index_4c.spec->shape[0]<<std::endl;
hemi::ExecutionPolicy ep;
ep.setStream(stream);
hemi::parallel_for(ep, 0, index_4c.spec->shape[0], [=] HEMI_LAMBDA(int index) {
126 changes: 125 additions & 1 deletion d4ft/native/obara_saika/eri_kernel.h
@@ -28,6 +28,130 @@
#include "d4ft/native/xla/specs.h"

class Hartree_32 {
public:
// template <typename FLOAT>
static auto ShapeInference(const Spec<int>& shape1,
const Spec<int>& shape10,
const Spec<int>& shape2,
const Spec<float>& shape3,
const Spec<float>& shape4,
const Spec<int>& shape5,
const Spec<int>& shape6,
const Spec<int>& shape7,
const Spec<int>& shape8,
const Spec<int>& shape9,
const Spec<int>& shape11) {
// float n2 = shape4.shape[0]*(shape4.shape[0]+1)/2;
// float n4 = n2*(n2+1)/2;
// int n4_int = static_cast<int>(n4);
std::vector<int> outshape={shape1.shape[0]};
Spec<float> out(outshape);
return std::make_tuple(out);
}
// static void Cpu(Array<const float>& arg1, Array<const int>& arg2,
// Array<float>& out) {
// std::memcpy(out.ptr, arg1.ptr, sizeof(float) * arg1.spec->Size());
// }
// template <typename FLOAT>
static void Cpu(Array<const int>& N,
Array<const int>& index_4c,
Array<const int>& n,
Array<const float>& r,
Array<const float>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<const int>& ab_range,
Array<float>& output){
// std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size());
}

// template <typename FLOAT>
static void Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& index_4c,
Array<const int>& n,
Array<const float>& r,
Array<const float>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<const int>& ab_range,
Array<float>& output);
};

class Hartree_64 {
public:
// template <typename FLOAT>
static auto ShapeInference(const Spec<int>& shape1,
const Spec<int>& shape10,
const Spec<int>& shape2,
const Spec<double>& shape3,
const Spec<double>& shape4,
const Spec<int>& shape5,
const Spec<int>& shape6,
const Spec<int>& shape7,
const Spec<int>& shape8,
const Spec<int>& shape9,
const Spec<int>& shape11,
const Spec<int>& shape12,
const Spec<int>& shape13,
const Spec<int>& shape14) {
// double n2 = shape4.shape[0]*(shape4.shape[0]+1)/2;
// double n4 = n2*(n2+1)/2;
// int n4_int = static_cast<int>(n4);
std::vector<int> outshape={2*shape11.shape[0]*shape12.shape[0]};
// std::vector<int> outshape={shape1.shape[0]};
Spec<int> out(outshape);
return std::make_tuple(out);
}
// static void Cpu(Array<const float>& arg1, Array<const int>& arg2,
// Array<float>& out) {
// std::memcpy(out.ptr, arg1.ptr, sizeof(float) * arg1.spec->Size());
// }
// template <typename FLOAT>
static void Cpu(Array<const int>& N,
Array<const int>& screened_length,
Array<const int>& n,
Array<const double>& r,
Array<const double>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<const int>& sorted_ab_idx,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
Array<int>& output){
// std::memcpy(output.ptr, outshape.ptr, sizeof(float) * outshape.spec->Size());
}

// template <typename FLOAT>
static void Gpu(cudaStream_t stream,
Array<const int>& N,
Array<const int>& screened_length,
Array<const int>& n,
Array<const double>& r,
Array<const double>& z,
Array<const int>& min_a,
Array<const int>& min_c,
Array<const int>& max_ab,
Array<const int>& max_cd,
Array<const int>& Ms,
Array<const int>& sorted_ab_idx,
Array<const int>& sorted_cd_idx,
Array<const int>& screened_cd_idx_start,
Array<const int>& screened_idx_offset,
Array<int>& output);
};
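Note the sizing logic: shape11 and shape12 appear to correspond positionally to sorted_ab_idx and sorted_cd_idx, so ShapeInference reserves the worst case in which no (ab, cd) candidate is screened out, two ints per candidate pair. A one-line sanity check of that bound, with hypothetical lengths:

n_ab, n_cd = 10, 12        # hypothetical lengths of sorted_ab_idx / sorted_cd_idx
out_len = 2 * n_ab * n_cd  # matches outshape = {2 * shape11.shape[0] * shape12.shape[0]}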

class Hartree_32_uncontracted {
public:
// template <typename FLOAT>
static auto ShapeInference(const Spec<int>& shape1,
@@ -81,7 +205,7 @@ class Hartree_32 {
Array<float>& output);
};

-class Hartree_64 {
+class Hartree_64_uncontracted {
public:
// template <typename FLOAT>
static auto ShapeInference(const Spec<int>& shape1,
