From 0b51e88ff06b325745791670cc6ee4e34c57057f Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 18 Oct 2024 21:35:21 +0000 Subject: [PATCH] Expose option to switch between sparse matrix representations --- src/jaxls/_factor_graph.py | 23 +++++++++++++++++++++-- src/jaxls/_solvers.py | 32 ++++++++++++++++++++++++++------ src/jaxls/_sparse_matrices.py | 8 ++++++++ 3 files changed, 55 insertions(+), 8 deletions(-) diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 0581648..a740a75 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -60,10 +60,24 @@ def solve( | ConjugateGradientConfig = "conjugate_gradient", trust_region: TrustRegionConfig | None = TrustRegionConfig(), termination: TerminationConfig = TerminationConfig(), + sparse_mode: Literal["blockrow", "coo", "csr"] = "blockrow", verbose: bool = True, ) -> VarValues: """Solve the nonlinear least squares problem using either Gauss-Newton - or Levenberg-Marquardt.""" + or Levenberg-Marquardt. + + Args: + initial_vals: Initial values for the variables. If None, default values will be used. + linear_solver: The linear solver to use. + trust_region: Configuration for Levenberg-Marquardt trust region. + termination: Configuration for termination criteria. + sparse_mode: The representation to use for sparse matrix + multiplication. Can be "blockrow", "coo", or "csr". + verbose: Whether to print verbose output during optimization. + + Returns: + Optimized variable values. + """ if initial_vals is None: initial_vals = VarValues.make( var_type(ids) for var_type, ids in self.sorted_ids_from_var_type.items() @@ -79,7 +93,12 @@ def solve( linear_solver = "conjugate_gradient" solver = NonlinearSolver( - linear_solver, trust_region, termination, conjugate_gradient_config, verbose + linear_solver, + trust_region, + termination, + conjugate_gradient_config, + sparse_mode, + verbose, ) return solver.solve(graph=self, initial_vals=initial_vals) diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index 1971173..d53a73d 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -16,7 +16,7 @@ make_point_jacobi_precoditioner, ) -from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix +from ._sparse_matrices import BlockRowSparseMatrix, SparseCooMatrix, SparseCsrMatrix from ._variables import VarTypeOrdering, VarValues from .utils import jax_log @@ -191,6 +191,7 @@ class NonlinearSolver: trust_region: TrustRegionConfig | None termination: TerminationConfig conjugate_gradient_config: ConjugateGradientConfig | None + sparse_mode: jdc.Static[Literal["blockrow", "coo", "csr"]] verbose: jdc.Static[bool] @jdc.jit @@ -254,11 +255,30 @@ def step( ) # linear_transpose() will return a tuple, with one element per primal. - A_multiply = A_blocksparse.multiply - AT_multiply_ = jax.linear_transpose( - A_multiply, jnp.zeros((A_blocksparse.shape[1],)) - ) - AT_multiply = lambda vec: AT_multiply_(vec)[0] + if self.sparse_mode == "blockrow": + A_multiply = A_blocksparse.multiply + AT_multiply_ = jax.linear_transpose( + A_multiply, jnp.zeros((A_blocksparse.shape[1],)) + ) + AT_multiply = lambda vec: AT_multiply_(vec)[0] + elif self.sparse_mode == "coo": + A_coo = SparseCooMatrix( + values=jac_values, coords=graph.jac_coords_coo + ).as_jax_bcoo() + AT_coo = A_coo.transpose() + A_multiply = lambda vec: A_coo @ vec + AT_multiply = lambda vec: AT_coo @ vec + elif self.sparse_mode == "csr": + A_csr = SparseCsrMatrix( + values=jac_values, coords=graph.jac_coords_csr + ).as_jax_bcsr() + A_multiply = lambda vec: A_csr @ vec + AT_multiply_ = jax.linear_transpose( + A_multiply, jnp.zeros((A_blocksparse.shape[1],)) + ) + AT_multiply = lambda vec: AT_multiply_(vec)[0] + else: + assert_never(self.sparse_mode) # Compute right-hand side of normal equation. ATb = -AT_multiply(state.residual_vector) diff --git a/src/jaxls/_sparse_matrices.py b/src/jaxls/_sparse_matrices.py index 5824b9e..c07524e 100644 --- a/src/jaxls/_sparse_matrices.py +++ b/src/jaxls/_sparse_matrices.py @@ -145,6 +145,14 @@ class SparseCsrMatrix: coords: SparseCsrCoordinates """Indices describing non-zero entries.""" + def as_jax_bcsr(self) -> jax.experimental.sparse.BCSR: + return jax.experimental.sparse.BCSR( + args=(self.values, self.coords.indices, self.coords.indptr), + shape=self.coords.shape, + indices_sorted=True, + unique_indices=True, + ) + @jdc.pytree_dataclass class SparseCooCoordinates: