diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index aa41a00..78317b9 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -134,7 +134,7 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array: "reverse": jax.jacrev, "auto": jax.jacrev if factor.residual_dim < val_subset._get_tangent_dim() - else jax.jacrev, + else jax.jacfwd, }[factor.jac_mode] return jacfunc( # The residual function, with respect to to some local delta.