diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 7984dda..11fefc3 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -996,7 +996,7 @@ def assert_numerical_grads(f: Callable[..., Array], difference gradients. """ # Correct scaling. - # Remove after https://github.com/google/jax/issues/3130 is fixed. + # Remove after https://github.com/jax-ml/jax/issues/3130 is fixed. atol *= f_args[0].size # Mock `jax.lax.stop_gradient` because finite diff. method does not honour it. diff --git a/chex/_src/fake.py b/chex/_src/fake.py index 7d3d622..9fe2a57 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -57,7 +57,7 @@ def set_n_cpu_devices(n: Optional[int] = None) -> None: This allows `jax.pmap` to be tested on a single-CPU platform. This utility only takes effect before XLA backends are initialized, i.e. before any JAX operation is executed (including `jax.devices()` etc.). - See https://github.com/google/jax/issues/1408. + See https://github.com/jax-ml/jax/issues/1408. Args: n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by