From 503b6f5778010578201119b86160084a23682c17 Mon Sep 17 00:00:00 2001 From: Zekun Shi Date: Fri, 27 Oct 2023 15:50:16 +0800 Subject: [PATCH] bump jaxlib version --- README.md | 6 ++++++ d4ft/native/xla/custom_call.py | 5 +++-- third_party/pip_requirements/requirements-dev.txt | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f1eab0d..ba9b45a 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,12 @@ python main.py --run viz --config.save_dir _exp/bh76,6-31g+lda_x,3aahmyt0 ``` # CUDA dev guide +First create the proper pip +``` shell +cd third_party/pip_requirements +ln -sf requirements-dev.txt requirements.txt +``` + ``` shell bazel build //... bazel run //d4ft/native/obara_saika:eri_test diff --git a/d4ft/native/xla/custom_call.py b/d4ft/native/xla/custom_call.py index 9005be3..d20ca46 100644 --- a/d4ft/native/xla/custom_call.py +++ b/d4ft/native/xla/custom_call.py @@ -18,7 +18,7 @@ import jax from jax import core, dtypes -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray from jax.interpreters import xla from jax.lib import xla_client @@ -61,9 +61,10 @@ def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: cpu_capsule, gpu_capsule = base._capsules # Register the XLA custom calls - xla_client.register_cpu_custom_call_target( + xla_client.register_custom_call_target( f"{name}_cpu".encode(), fn=cpu_capsule, + platform="cpu", ) xla_client.register_custom_call_target( f"{name}_gpu".encode(), diff --git a/third_party/pip_requirements/requirements-dev.txt b/third_party/pip_requirements/requirements-dev.txt index fdc05db..9706171 100644 --- a/third_party/pip_requirements/requirements-dev.txt +++ b/third_party/pip_requirements/requirements-dev.txt @@ -4,8 +4,8 @@ dm-haiku>=0.0.10 einops==0.6.1 jax-xc>=0.0.7 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -jax[cuda12_local]==0.4.13 -jax>=0.4.13 +jax[cuda12_pip]==0.4.19 +# jax>=0.4.13 jaxlib>=0.3.25 jaxtyping==0.2.15 matplotlib==3.7.2