Skip to content

Commit

Permalink
bump jaxlib version
Browse files Browse the repository at this point in the history
  • Loading branch information
zekun-shi committed Oct 27, 2023
1 parent e2c1007 commit 503b6f5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,12 @@ python main.py --run viz --config.save_dir _exp/bh76,6-31g+lda_x,3aahmyt0
```

# CUDA dev guide
First create the proper pip
``` shell
cd third_party/pip_requirements
ln -sf requirements-dev.txt requirements.txt
```

``` shell
bazel build //...
bazel run //d4ft/native/obara_saika:eri_test
Expand Down
5 changes: 3 additions & 2 deletions d4ft/native/xla/custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import jax
from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.interpreters import xla
from jax.lib import xla_client

Expand Down Expand Up @@ -61,9 +61,10 @@ def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any:
cpu_capsule, gpu_capsule = base._capsules

# Register the XLA custom calls
xla_client.register_cpu_custom_call_target(
xla_client.register_custom_call_target(
f"{name}_cpu".encode(),
fn=cpu_capsule,
platform="cpu",
)
xla_client.register_custom_call_target(
f"{name}_gpu".encode(),
Expand Down
4 changes: 2 additions & 2 deletions third_party/pip_requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ dm-haiku>=0.0.10
einops==0.6.1
jax-xc>=0.0.7
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_local]==0.4.13
jax>=0.4.13
jax[cuda12_pip]==0.4.19
# jax>=0.4.13
jaxlib>=0.3.25
jaxtyping==0.2.15
matplotlib==3.7.2
Expand Down

0 comments on commit 503b6f5

Please sign in to comment.