Skip to content

Commit

Permalink
Update references to JAX's GitHub repo
Browse files Browse the repository at this point in the history
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886821
  • Loading branch information
jakeharmon8 authored and ChexDev committed Dec 11, 2024
1 parent 8af2c9e commit 0f8960f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def assert_numerical_grads(f: Callable[..., Array],
difference gradients.
"""
# Correct scaling.
# Remove after https://github.com/google/jax/issues/3130 is fixed.
# Remove after https://github.com/jax-ml/jax/issues/3130 is fixed.
atol *= f_args[0].size

# Mock `jax.lax.stop_gradient` because finite diff. method does not honour it.
Expand Down
2 changes: 1 addition & 1 deletion chex/_src/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def set_n_cpu_devices(n: Optional[int] = None) -> None:
This allows `jax.pmap` to be tested on a single-CPU platform.
This utility only takes effect before XLA backends are initialized, i.e.
before any JAX operation is executed (including `jax.devices()` etc.).
See https://github.com/google/jax/issues/1408.
See https://github.com/jax-ml/jax/issues/1408.
Args:
n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by
Expand Down

0 comments on commit 0f8960f

Please sign in to comment.