Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
zekun-shi committed Oct 27, 2023
1 parent ce6f1c1 commit e2c1007
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 79 deletions.
2 changes: 1 addition & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ build --@rules_cuda//cuda:enable_cuda
build --@rules_cuda//cuda:copts=-std=c++17
build --repo_env=CUDA_DIR=/opt/cuda
build --repo_env=CUDA_PATH=/opt/cuda
#build --repo_env=CC=gcc-12
build --repo_env=CC=gcc-12
# TODO: check whether both env var are needed
build --action_env=CUDA_PATH=/opt/cuda
#build --action_env=CC=gcc-12
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ To visualize the run:
python main.py --run viz --config.save_dir _exp/bh76,6-31g+lda_x,3aahmyt0
```

# CUDA dev guide
``` shell
bazel build //...
bazel run //d4ft/native/obara_saika:eri_test
bazel test --test_output=all //tests/native/xla:example_test
```

## Tutorial and Documentation

### Viewing in the Browser
Expand Down
2 changes: 1 addition & 1 deletion d4ft/native/gamma/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ cc_library(
name = "constants",
hdrs = ["constants.h"],
deps = [
# "@cuda//:cudart_static",
"@cuda//:cudart_static",
"@hemi",
],
)
Expand Down
2 changes: 1 addition & 1 deletion d4ft/native/gamma/igamma.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ HEMI_DEV_CALLABLE FLOAT IgammacContinuedFraction(FLOAT ax, FLOAT x, FLOAT a,
dqkm2_da = dqkm2_da * HEMI_CONSTANT(eps)<FLOAT>;
dqkm1_da = dqkm1_da * HEMI_CONSTANT(eps)<FLOAT>;
}
FLOAT conditional;
// FLOAT conditional;
if (mode == VALUE) {
enabled = (enabled && t > HEMI_CONSTANT(eps) < FLOAT >);
} else {
Expand Down
68 changes: 34 additions & 34 deletions d4ft/native/obara_saika/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

load("@rules_cc//cc:defs.bzl", "cc_library")
# load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@rules_cuda//cuda:defs.bzl", "cuda_library")

package(default_visibility = ["//visibility:public"])

Expand Down Expand Up @@ -42,38 +42,38 @@ cc_library(
],
)

# cuda_library(
# name = "eri_kernel",
# srcs = ["eri_kernel.cu"],
# hdrs = ["eri_kernel.h"],
# copts = [
# "--expt-extended-lambda",
# "--expt-relaxed-constexpr",
# ],
# linkopts = [
# "-lstdc++",
# ],
# deps = [
# ":eri",
# "@hemi",
# ],
# )
cuda_library(
name = "eri_kernel",
srcs = ["eri_kernel.cu"],
hdrs = ["eri_kernel.h"],
copts = [
"--expt-extended-lambda",
"--expt-relaxed-constexpr",
],
linkopts = [
"-lstdc++",
],
deps = [
":eri",
"@hemi",
],
)

# cc_binary(
# name = "eri_test",
# srcs = ["eri_test.cc"],
# deps = [
# ":eri_kernel",
# # "@cuda//:cudart_static",
# "@hemi",
# ],
# )
cc_binary(
name = "eri_test",
srcs = ["eri_test.cc"],
deps = [
":eri_kernel",
"@cuda//:cudart_static",
"@hemi",
],
)

# cc_binary(
# name = "boys_test",
# srcs = ["boys_test.cc"],
# deps = [
# ":boys",
# # "@cuda//:cudart_static",
# ],
# )
cc_binary(
name = "boys_test",
srcs = ["boys_test.cc"],
deps = [
":boys",
"@cuda//:cudart_static",
],
)
2 changes: 1 addition & 1 deletion d4ft/native/xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ cc_library(
hdrs = ["custom_call.h"],
deps = [
":specs",
# "@cuda//:cudart_static",
"@cuda//:cudart_static",
],
)

Expand Down
81 changes: 42 additions & 39 deletions tests/native/xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@pip_requirements//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_python//python:defs.bzl", "py_test")

package(default_visibility = ["//visibility:public"])

Expand All @@ -27,43 +30,43 @@ cc_library(
],
)

# cuda_library(
# name = "example_cu",
# srcs = ["example.cu"],
# copts = [
# "--std=c++17",
# ],
# deps = [
# ":example_h",
# # "@cuda//:cudart_static",
# ],
# )
cuda_library(
name = "example_cu",
srcs = ["example.cu"],
copts = [
"--std=c++17",
],
deps = [
":example_h",
"@cuda//:cudart_static",
],
)

# pybind_extension(
# name = "example",
# srcs = [
# "example.cc",
# ],
# copts = [
# "--std=c++17",
# ],
# deps = [
# ":example_cu",
# "//d4ft/native/xla:custom_call_h",
# ],
# )
pybind_extension(
name = "example",
srcs = [
"example.cc",
],
copts = [
"--std=c++17",
],
deps = [
":example_cu",
"//d4ft/native/xla:custom_call_h",
],
)

# py_test(
# name = "example_test",
# srcs = ["example_test.py"],
# data = [
# ":example.so",
# # "@cuda//:bin",
# ],
# deps = [
# "//d4ft/native/xla:custom_call",
# requirement("absl-py"),
# requirement("jax"),
# requirement("jaxlib"),
# ],
# )
py_test(
name = "example_test",
srcs = ["example_test.py"],
data = [
":example.so",
"@cuda//:bin",
],
deps = [
"//d4ft/native/xla:custom_call",
requirement("absl-py"),
requirement("jax"),
requirement("jaxlib"),
],
)
4 changes: 2 additions & 2 deletions third_party/pip_requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ bs4==0.0.1
dm-haiku>=0.0.10
einops==0.6.1
jax-xc>=0.0.7
# --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# jax[cuda12_local]==0.4.13
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_local]==0.4.13
jax>=0.4.13
jaxlib>=0.3.25
jaxtyping==0.2.15
Expand Down

0 comments on commit e2c1007

Please sign in to comment.