Skip to content

Commit

Permalink
Expose option to switch between sparse matrix representations
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 18, 2024
1 parent 1613d1d commit 0b51e88
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
23 changes: 21 additions & 2 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,24 @@ def solve(
| ConjugateGradientConfig = "conjugate_gradient",
trust_region: TrustRegionConfig | None = TrustRegionConfig(),
termination: TerminationConfig = TerminationConfig(),
sparse_mode: Literal["blockrow", "coo", "csr"] = "blockrow",
verbose: bool = True,
) -> VarValues:
"""Solve the nonlinear least squares problem using either Gauss-Newton
or Levenberg-Marquardt."""
or Levenberg-Marquardt.
Args:
initial_vals: Initial values for the variables. If None, default values will be used.
linear_solver: The linear solver to use.
trust_region: Configuration for Levenberg-Marquardt trust region.
termination: Configuration for termination criteria.
sparse_mode: The representation to use for sparse matrix
multiplication. Can be "blockrow", "coo", or "csr".
verbose: Whether to print verbose output during optimization.
Returns:
Optimized variable values.
"""
if initial_vals is None:
initial_vals = VarValues.make(
var_type(ids) for var_type, ids in self.sorted_ids_from_var_type.items()
Expand All @@ -79,7 +93,12 @@ def solve(
linear_solver = "conjugate_gradient"

solver = NonlinearSolver(
linear_solver, trust_region, termination, conjugate_gradient_config, verbose
linear_solver,
trust_region,
termination,
conjugate_gradient_config,
sparse_mode,
verbose,
)
return solver.solve(graph=self, initial_vals=initial_vals)

Expand Down
32 changes: 26 additions & 6 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
make_point_jacobi_precoditioner,
)

from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix
from ._sparse_matrices import BlockRowSparseMatrix, SparseCooMatrix, SparseCsrMatrix
from ._variables import VarTypeOrdering, VarValues
from .utils import jax_log

Expand Down Expand Up @@ -191,6 +191,7 @@ class NonlinearSolver:
trust_region: TrustRegionConfig | None
termination: TerminationConfig
conjugate_gradient_config: ConjugateGradientConfig | None
sparse_mode: jdc.Static[Literal["blockrow", "coo", "csr"]]
verbose: jdc.Static[bool]

@jdc.jit
Expand Down Expand Up @@ -254,11 +255,30 @@ def step(
)

# linear_transpose() will return a tuple, with one element per primal.
A_multiply = A_blocksparse.multiply
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]
if self.sparse_mode == "blockrow":
A_multiply = A_blocksparse.multiply
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]
elif self.sparse_mode == "coo":
A_coo = SparseCooMatrix(
values=jac_values, coords=graph.jac_coords_coo
).as_jax_bcoo()
AT_coo = A_coo.transpose()
A_multiply = lambda vec: A_coo @ vec
AT_multiply = lambda vec: AT_coo @ vec
elif self.sparse_mode == "csr":
A_csr = SparseCsrMatrix(
values=jac_values, coords=graph.jac_coords_csr
).as_jax_bcsr()
A_multiply = lambda vec: A_csr @ vec
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]
else:
assert_never(self.sparse_mode)

# Compute right-hand side of normal equation.
ATb = -AT_multiply(state.residual_vector)
Expand Down
8 changes: 8 additions & 0 deletions src/jaxls/_sparse_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ class SparseCsrMatrix:
coords: SparseCsrCoordinates
"""Indices describing non-zero entries."""

def as_jax_bcsr(self) -> jax.experimental.sparse.BCSR:
return jax.experimental.sparse.BCSR(
args=(self.values, self.coords.indices, self.coords.indptr),
shape=self.coords.shape,
indices_sorted=True,
unique_indices=True,
)


@jdc.pytree_dataclass
class SparseCooCoordinates:
Expand Down

0 comments on commit 0b51e88

Please sign in to comment.