Skip to content

Commit

Permalink
Fix Factor._get_batch_axes() edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 14, 2024
1 parent 9a53fa4 commit d2b7b7b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
4 changes: 2 additions & 2 deletions examples/pose_graph_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
),
# "Between" factor.
jaxls.Factor(
lambda vals, var0, var1, delta: (
lambda vals, delta, var0, var1: (
(vals[var0].inverse() @ vals[var1]) @ delta.inverse()
).log(),
(vars[0], vars[1], jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
(jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0), vars[0], vars[1]),
),
]

Expand Down
17 changes: 11 additions & 6 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,17 +398,22 @@ def make[*Args_](
return Factor(compute_residual, args, jac_mode)

def _get_batch_axes(self) -> tuple[int, ...]:
def traverse_args(current: Any) -> tuple[int, ...]:
def traverse_args(current: Any) -> tuple[int, ...] | None:
children_and_meta = default_registry.flatten_one_level(current)
assert children_and_meta is not None
if children_and_meta is None:
return None
for child in children_and_meta[0]:
if isinstance(child, Var):
return () if isinstance(child.id, int) else child.id.shape
else:
return traverse_args(child)
assert False, "No variables found in factor!"

return traverse_args(self.args)
batch_axes = traverse_args(child)
if batch_axes is not None:
return batch_axes
return None

batch_axes = traverse_args(self.args)
assert batch_axes is not None, "No variables found in factor!"
return batch_axes


@jdc.pytree_dataclass
Expand Down

0 comments on commit d2b7b7b

Please sign in to comment.