Skip to content

Commit

Permalink
Add jac_batch_size=, quality-of-life improvements for variable
Browse files Browse the repository at this point in the history
batch axes
  • Loading branch information
brentyi committed Oct 29, 2024
1 parent dace546 commit a81af8a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
15 changes: 13 additions & 2 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
)(jnp.zeros((val_subset._get_tangent_dim(),)))

# Compute Jacobian for each factor.
stacked_jac = jax.vmap(compute_jac_with_perturb)(factor)
stacked_jac = jax.lax.map(
compute_jac_with_perturb, factor, batch_size=factor.jac_batch_size
)
(num_factor,) = factor._get_batch_axes()
assert stacked_jac.shape == (
num_factor,
Expand Down Expand Up @@ -396,11 +398,20 @@ def _sort_key(x: Any) -> str:

@jdc.pytree_dataclass
class Factor[*Args]:
"""A single cost in our factor graph."""
"""A single cost in our factor graph. Costs with the same pytree structure
will automatically be paralellized."""

compute_residual: jdc.Static[Callable[[VarValues, *Args], jax.Array]]
args: tuple[*Args]
jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]] = "auto"
"""Depending on the function being differentiated, it may be faster to use
forward-mode or reverse-mode autodiff."""
jac_batch_size: jdc.Static[int | None] = None
"""Batch size for computing Jacobians that can be parallelized. Can be set
to make tradeoffs between runtime and memory usage.
If None, we compute all Jacobians in parallel. If 1, we compute Jacobians
one at a time."""

@staticmethod
@deprecated("Use Factor() directly instead of Factor.make()")
Expand Down
12 changes: 10 additions & 2 deletions src/jaxls/_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from functools import total_ordering
from typing import Any, Callable, ClassVar, cast, overload
from typing import Any, Callable, ClassVar, Self, cast, overload

import jax
import jax_dataclasses as jdc
Expand Down Expand Up @@ -83,6 +83,11 @@ def with_value(self, value: T) -> VarWithValue[T]:
for `VarValues.make()`."""
return VarWithValue(self, value)

def __getitem__(self, index_or_slice: int | slice) -> Self:
"""Shorthand for slicing the variable ID."""
assert not isinstance(self.id, int)
return self.__class__(self.id[index_or_slice])

@overload
def __init_subclass__[T_](
cls,
Expand Down Expand Up @@ -174,7 +179,10 @@ class VarValues:
"""Variable ID for each value, sorted in ascending order."""

def get_value[T](self, var: Var[T]) -> T:
"""Get the value of a specific variable."""
"""Get the value of a specific variable or variables."""
if not isinstance(var.id, int) and var.id.ndim > 0:
return jax.vmap(self.get_value)(var)

assert getattr(var.id, "shape", None) == () or isinstance(var.id, int)
var_type = type(var)
index = jnp.searchsorted(self.ids_from_type[var_type], var.id)
Expand Down

0 comments on commit a81af8a

Please sign in to comment.