Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 13, 2024
1 parent a66fc4b commit 5fef8c8
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,6 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
vals = initial_vals
residual_vector = graph.compute_residual_vector(vals)

cg_state = None
if isinstance(self.linear_solver, ConjugateGradientConfig):
cg_state = _ConjugateGradientState(
ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max
)
elif self.linear_solver == "conjugate_gradient":
cg_state = _ConjugateGradientState(
ATb_norm_prev=0.0, eta=ConjugateGradientConfig().tolerance_max
)

state = NonlinearSolverState(
iterations=0,
vals=vals,
Expand All @@ -215,7 +205,16 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
lambd=self.trust_region.lambda_initial
if self.trust_region is not None
else 0.0,
cg_state=cg_state,
cg_state=None
if self.linear_solver != "conjugate_gradient"
else _ConjugateGradientState(
ATb_norm_prev=0.0,
eta=(
ConjugateGradientConfig()
if self.conjugate_gradient_config is None
else self.conjugate_gradient_config
).tolerance_max,
),
)

# Optimization.
Expand Down

0 comments on commit 5fef8c8

Please sign in to comment.