diff --git a/.bazelrc b/.bazelrc index 25e891c..4396e0e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -3,7 +3,7 @@ build --@rules_cuda//cuda:enable_cuda build --@rules_cuda//cuda:copts=-std=c++17 build --repo_env=CUDA_DIR=/opt/cuda build --repo_env=CUDA_PATH=/opt/cuda -#build --repo_env=CC=gcc-12 +build --repo_env=CC=gcc-12 # TODO: check whether both env var are needed build --action_env=CUDA_PATH=/opt/cuda #build --action_env=CC=gcc-12 diff --git a/README.md b/README.md index 1343ddf..f1eab0d 100644 --- a/README.md +++ b/README.md @@ -239,6 +239,13 @@ To visualize the run: python main.py --run viz --config.save_dir _exp/bh76,6-31g+lda_x,3aahmyt0 ``` +# CUDA dev guide +``` shell +bazel build //... +bazel run //d4ft/native/obara_saika:eri_test +bazel test --test_output=all //tests/native/xla:example_test +``` + ## Tutorial and Documentation ### Viewing in the Browser diff --git a/d4ft/native/gamma/BUILD b/d4ft/native/gamma/BUILD index 3984b36..73abe29 100644 --- a/d4ft/native/gamma/BUILD +++ b/d4ft/native/gamma/BUILD @@ -19,7 +19,7 @@ cc_library( name = "constants", hdrs = ["constants.h"], deps = [ - # "@cuda//:cudart_static", + "@cuda//:cudart_static", "@hemi", ], ) diff --git a/d4ft/native/gamma/igamma.h b/d4ft/native/gamma/igamma.h index 4fb8996..89ddc7e 100644 --- a/d4ft/native/gamma/igamma.h +++ b/d4ft/native/gamma/igamma.h @@ -129,7 +129,7 @@ HEMI_DEV_CALLABLE FLOAT IgammacContinuedFraction(FLOAT ax, FLOAT x, FLOAT a, dqkm2_da = dqkm2_da * HEMI_CONSTANT(eps); dqkm1_da = dqkm1_da * HEMI_CONSTANT(eps); } - FLOAT conditional; + // FLOAT conditional; if (mode == VALUE) { enabled = (enabled && t > HEMI_CONSTANT(eps) < FLOAT >); } else { diff --git a/d4ft/native/obara_saika/BUILD b/d4ft/native/obara_saika/BUILD index 1f00221..2cdd116 100644 --- a/d4ft/native/obara_saika/BUILD +++ b/d4ft/native/obara_saika/BUILD @@ -13,7 +13,7 @@ # limitations under the License. load("@rules_cc//cc:defs.bzl", "cc_library") -# load("@rules_cuda//cuda:defs.bzl", "cuda_library") +load("@rules_cuda//cuda:defs.bzl", "cuda_library") package(default_visibility = ["//visibility:public"]) @@ -42,38 +42,38 @@ cc_library( ], ) -# cuda_library( -# name = "eri_kernel", -# srcs = ["eri_kernel.cu"], -# hdrs = ["eri_kernel.h"], -# copts = [ -# "--expt-extended-lambda", -# "--expt-relaxed-constexpr", -# ], -# linkopts = [ -# "-lstdc++", -# ], -# deps = [ -# ":eri", -# "@hemi", -# ], -# ) +cuda_library( + name = "eri_kernel", + srcs = ["eri_kernel.cu"], + hdrs = ["eri_kernel.h"], + copts = [ + "--expt-extended-lambda", + "--expt-relaxed-constexpr", + ], + linkopts = [ + "-lstdc++", + ], + deps = [ + ":eri", + "@hemi", + ], +) -# cc_binary( -# name = "eri_test", -# srcs = ["eri_test.cc"], -# deps = [ -# ":eri_kernel", -# # "@cuda//:cudart_static", -# "@hemi", -# ], -# ) +cc_binary( + name = "eri_test", + srcs = ["eri_test.cc"], + deps = [ + ":eri_kernel", + "@cuda//:cudart_static", + "@hemi", + ], +) -# cc_binary( -# name = "boys_test", -# srcs = ["boys_test.cc"], -# deps = [ -# ":boys", -# # "@cuda//:cudart_static", -# ], -# ) +cc_binary( + name = "boys_test", + srcs = ["boys_test.cc"], + deps = [ + ":boys", + "@cuda//:cudart_static", + ], +) diff --git a/d4ft/native/xla/BUILD b/d4ft/native/xla/BUILD index 65a10ae..7e61822 100644 --- a/d4ft/native/xla/BUILD +++ b/d4ft/native/xla/BUILD @@ -26,7 +26,7 @@ cc_library( hdrs = ["custom_call.h"], deps = [ ":specs", - # "@cuda//:cudart_static", + "@cuda//:cudart_static", ], ) diff --git a/tests/native/xla/BUILD b/tests/native/xla/BUILD index f857a52..8a452bf 100644 --- a/tests/native/xla/BUILD +++ b/tests/native/xla/BUILD @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -# load("@rules_cuda//cuda:defs.bzl", "cuda_library") +load("@rules_cuda//cuda:defs.bzl", "cuda_library") +load("@pip_requirements//:requirements.bzl", "requirement") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@rules_python//python:defs.bzl", "py_test") package(default_visibility = ["//visibility:public"]) @@ -27,43 +30,43 @@ cc_library( ], ) -# cuda_library( -# name = "example_cu", -# srcs = ["example.cu"], -# copts = [ -# "--std=c++17", -# ], -# deps = [ -# ":example_h", -# # "@cuda//:cudart_static", -# ], -# ) +cuda_library( + name = "example_cu", + srcs = ["example.cu"], + copts = [ + "--std=c++17", + ], + deps = [ + ":example_h", + "@cuda//:cudart_static", + ], +) -# pybind_extension( -# name = "example", -# srcs = [ -# "example.cc", -# ], -# copts = [ -# "--std=c++17", -# ], -# deps = [ -# ":example_cu", -# "//d4ft/native/xla:custom_call_h", -# ], -# ) +pybind_extension( + name = "example", + srcs = [ + "example.cc", + ], + copts = [ + "--std=c++17", + ], + deps = [ + ":example_cu", + "//d4ft/native/xla:custom_call_h", + ], +) -# py_test( -# name = "example_test", -# srcs = ["example_test.py"], -# data = [ -# ":example.so", -# # "@cuda//:bin", -# ], -# deps = [ -# "//d4ft/native/xla:custom_call", -# requirement("absl-py"), -# requirement("jax"), -# requirement("jaxlib"), -# ], -# ) +py_test( + name = "example_test", + srcs = ["example_test.py"], + data = [ + ":example.so", + "@cuda//:bin", + ], + deps = [ + "//d4ft/native/xla:custom_call", + requirement("absl-py"), + requirement("jax"), + requirement("jaxlib"), + ], +) diff --git a/third_party/pip_requirements/requirements-dev.txt b/third_party/pip_requirements/requirements-dev.txt index 9a8d525..fdc05db 100644 --- a/third_party/pip_requirements/requirements-dev.txt +++ b/third_party/pip_requirements/requirements-dev.txt @@ -3,8 +3,8 @@ bs4==0.0.1 dm-haiku>=0.0.10 einops==0.6.1 jax-xc>=0.0.7 -# --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -# jax[cuda12_local]==0.4.13 +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +jax[cuda12_local]==0.4.13 jax>=0.4.13 jaxlib>=0.3.25 jaxtyping==0.2.15