From 8c43999460c25fe22dd44be40c23c2d2560621d2 Mon Sep 17 00:00:00 2001 From: brentyi Date: Wed, 20 Nov 2024 19:57:21 -0800 Subject: [PATCH] Fix performance problems with `jac_batch_size=None` --- src/jaxls/_factor_graph.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 04e0bdf..aa41a00 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -145,9 +145,14 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array: )(jnp.zeros((val_subset._get_tangent_dim(),))) # Compute Jacobian for each factor. - stacked_jac = jax.lax.map( - compute_jac_with_perturb, factor, batch_size=factor.jac_batch_size - ) + if factor.jac_batch_size is None: + stacked_jac = jax.vmap(compute_jac_with_perturb)(factor) + else: + # When `batch_size` is `None`, jax.lax.map reduces to a scan + # (similar to `batch_size=1`). + stacked_jac = jax.lax.map( + compute_jac_with_perturb, factor, batch_size=factor.jac_batch_size + ) (num_factor,) = factor._get_batch_axes() assert stacked_jac.shape == ( num_factor,