The feature, motivation and pitch
Problem
The solver is implemented with jax.lax.while_loop, which prevents computing gradients through the environment step during gradient-based trajectory optimization. The problem occurs whenever the solver runs with iterations > 1.
Calling the jax.jit-compiled grad function raises:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
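The error can be reproduced without the simulator. This minimal sketch uses a toy Newton iteration for sqrt(a) as a stand-in for the solver. The function name `newton_sqrt_while` and its `tol`/`max_iter` parameters are illustrative and not part of any library. The loop exits on a tolerance test, so its trip count is only known at run time. The forward pass works, but reverse-mode differentiation fails with the same ValueError.

```python
import jax
import jax.numpy as jnp


def newton_sqrt_while(a, tol=1e-6, max_iter=100):
    """Toy stand-in for an iterative solver: Newton's method for sqrt(a).

    The loop stops on a data-dependent condition (tolerance or max_iter),
    so the number of iterations is dynamic.
    """
    a = jnp.asarray(a, jnp.float32)

    def cond(state):
        i, x, x_prev = state
        return (i < max_iter) & (jnp.abs(x - x_prev) > tol)

    def body(state):
        i, x, _ = state
        return i + 1, 0.5 * (x + a / x), x

    # x_prev starts at +inf so the first iteration always runs.
    init = (jnp.int32(0), jnp.ones_like(a), jnp.full_like(a, jnp.inf))
    _, x, _ = jax.lax.while_loop(cond, body, init)
    return x


print(newton_sqrt_while(4.0))  # forward pass works

try:
    jax.jit(jax.grad(newton_sqrt_while))(4.0)
    error_message = None
except ValueError as err:
    # Reverse mode cannot transpose a while_loop with a dynamic trip count.
    error_message = str(err)
    print(error_message)
```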
The current workaround, setting opt.iteration=1, avoids the error but can make both the simulation and its gradients inaccurate, because the solver is cut off after a single iteration.
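To illustrate why truncating the solver hurts gradients as well as values, here is a toy sketch (again Newton's method for sqrt(a), not the actual solver). It runs a fixed number of iterations in a plain Python loop, which `jax.grad` can differentiate because the loop length is a trace-time constant. The name `newton_sqrt_fixed` is hypothetical, and the size of the error in the real solver will differ.

```python
import jax
import jax.numpy as jnp


def newton_sqrt_fixed(a, iterations):
    """Toy solver: `iterations` Newton steps for sqrt(a), starting at x = 1.

    `iterations` is a Python int, so the loop is unrolled at trace time
    and jax.grad can differentiate through it.
    """
    a = jnp.asarray(a, jnp.float32)
    x = jnp.ones_like(a)
    for _ in range(iterations):
        x = 0.5 * (x + a / x)
    return x


a = 4.0  # exact: sqrt(4) = 2 and d sqrt(a)/da = 1 / (2 * sqrt(a)) = 0.25
for n in (1, 4):
    value = newton_sqrt_fixed(a, n)
    grad = jax.grad(newton_sqrt_fixed)(a, n)
    print(f"iterations={n}: value={float(value):.6f}, grad={float(grad):.6f}")
```

With one iteration the toy returns 2.5 with gradient 0.5, so both are far from the exact 2 and 0.25. After four iterations both are within about 1e-6 of the exact values.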
Proposed Solution
Add an option to run the solver for a fixed number of iterations (e.g., 4), implemented with either lax.scan or lax.fori_loop with static (Python int) bounds. Both support reverse-mode differentiation.
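Both proposed loop forms can be sketched on the same toy Newton update for sqrt(a). The helper names `newton_step`, `sqrt_scan`, and `sqrt_fori` are hypothetical, and `newton_step` only stands in for one solver iteration. The key requirement is that the iteration count is a Python int at trace time, which is why it is marked static when jitting. With static bounds, `lax.fori_loop` is lowered to a scan, so it becomes differentiable in reverse mode.

```python
import jax
import jax.numpy as jnp


def newton_step(x, a):
    # One Newton update for sqrt(a); a stand-in for one solver iteration.
    return 0.5 * (x + a / x)


def sqrt_scan(a, iterations: int):
    """Fixed-count loop with lax.scan; `length` must be a Python int."""
    a = jnp.asarray(a, jnp.float32)

    def body(x, _):
        return newton_step(x, a), None

    x, _ = jax.lax.scan(body, jnp.ones_like(a), xs=None, length=iterations)
    return x


def sqrt_fori(a, iterations: int):
    """Fixed-count loop with lax.fori_loop and static (Python int) bounds."""
    a = jnp.asarray(a, jnp.float32)
    return jax.lax.fori_loop(0, iterations, lambda i, x: newton_step(x, a), jnp.ones_like(a))


# The iteration count must stay a concrete Python int, so mark it static.
grad_scan = jax.jit(jax.grad(sqrt_scan), static_argnums=1)
grad_fori = jax.jit(jax.grad(sqrt_fori), static_argnums=1)
print(grad_scan(4.0, 4), grad_fori(4.0, 4))
```

Reverse-mode differentiation through a scan stores intermediates for every iteration, so memory grows with the iteration count. Wrapping the loop body in `jax.checkpoint` is one way to trade recomputation for memory if that becomes an issue.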
Alternatives
No response
Additional context
No response
I like this suggestion and have labeled it as a good candidate for an external contributor. If no one picks it up, we'll eventually implement it ourselves.
If someone would like to try it, I'd recommend first briefly proposing (in this issue) how to modify the API to expose this functionality and then, if we all agree, opening a PR.
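One possible shape for such an API proposal, offered only as a sketch on the toy Newton solver: a static flag selects between the existing early-exit `lax.while_loop` (forward only) and a fixed-count `lax.scan` (differentiable). The names `solve`, `iterations`, `tolerance`, and `fixed_iterations` are hypothetical and do not correspond to any existing option. In the scan path, steps after convergence leave the state unchanged via `jnp.where`, so the forward result matches the early-exit path.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("iterations", "fixed_iterations"))
def solve(a, iterations=100, tolerance=1e-6, fixed_iterations=False):
    """Hypothetical solver entry point with an opt-in fixed-iteration mode.

    fixed_iterations=False: lax.while_loop with early exit (not reverse-mode
    differentiable). fixed_iterations=True: lax.scan over exactly
    `iterations` steps, freezing x once an update falls below `tolerance`.
    """
    a = jnp.asarray(a, jnp.float32)
    x0 = jnp.ones_like(a)

    def step(x):
        # Toy stand-in for one solver iteration (Newton step for sqrt(a)).
        return 0.5 * (x + a / x)

    if not fixed_iterations:
        def cond(state):
            i, x, x_prev = state
            return (i < iterations) & (jnp.abs(x - x_prev) > tolerance)

        def body(state):
            i, x, _ = state
            return i + 1, step(x), x

        init = (jnp.int32(0), x0, jnp.full_like(a, jnp.inf))
        _, x, _ = jax.lax.while_loop(cond, body, init)
        return x

    def scan_body(carry, _):
        x, done = carry
        x_new = step(x)
        # Once converged, keep the converged value for the remaining steps.
        x_next = jnp.where(done, x, x_new)
        done = done | (jnp.abs(x_new - x) <= tolerance)
        return (x_next, done), None

    (x, _), _ = jax.lax.scan(scan_body, (x0, jnp.array(False)), xs=None, length=iterations)
    return x


print(solve(4.0, iterations=50), solve(4.0, iterations=50, fixed_iterations=True))
print(jax.grad(solve)(4.0, iterations=50, fixed_iterations=True))
```

Because `fixed_iterations` is static, each setting compiles its own program. The masked scan always pays for all `iterations` steps, even after convergence, in exchange for differentiability.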