
[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients #2259

Open
EGalahad opened this issue Nov 29, 2024 · 2 comments
Labels
enhancement New feature or request good first issue Good for newcomers

Comments

@EGalahad

The feature, motivation and pitch

Problem

The solver's jax.lax.while_loop prevents computing gradients through the environment step during gradient-based trajectory optimization. This happens whenever the solver runs with iterations > 1: a while_loop has a dynamic trip count, and JAX cannot reverse-differentiate through it.

Error raised when calling a jax.jit-compiled grad function:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
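A minimal, self-contained reproduction of the error. It uses a toy Babylonian square-root iteration rather than MJX's solver; `sqrt_while` is a hypothetical function that only mimics a solver loop that stops early once a residual falls below a tolerance. Forward evaluation works, but reverse-mode differentiation through the while_loop raises the same ValueError.

```python
import jax
import jax.numpy as jnp


def sqrt_while(a, tol=1e-6, max_iter=50):
    """Toy stand-in for an early-terminating solver: Newton iteration for sqrt(a).

    The loop stops as soon as the residual is below `tol`, so the number of
    iterations is only known at run time (a dynamic trip count).
    """
    x0 = jnp.asarray(a, dtype=jnp.float32)

    def cond(state):
        i, x = state
        # Continue while under the iteration cap AND not yet converged.
        return (i < max_iter) & (jnp.abs(x * x - a) > tol)

    def body(state):
        i, x = state
        return i + 1, 0.5 * (x + a / x)

    _, x = jax.lax.while_loop(cond, body, (jnp.int32(0), x0))
    return x


print(sqrt_while(2.0))  # forward evaluation works

try:
    jax.jit(jax.grad(sqrt_while))(2.0)
    error_message = None
except ValueError as err:
    # Reverse mode cannot transpose a while_loop with a dynamic trip count.
    error_message = str(err)
print(error_message)
```

Forward-mode differentiation (jax.jvp) through a while_loop is supported; only reverse mode fails, because it must store residuals for every iteration and the number of iterations is not known at trace time.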

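The current workaround, setting opt.iterations = 1, avoids the error but can produce inaccurate simulation results and gradients, because the solver stops after a single iteration whether or not it has converged.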
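A toy illustration of the accuracy cost, again using a hypothetical Babylonian square-root iteration rather than MJX's solver: a single step involves no loop at all, so it is differentiable, but both the value and its gradient are far from the converged answer.

```python
import jax
import jax.numpy as jnp


def sqrt_one_step(a):
    # One Newton step from the initial guess x0 = a; no while_loop, so
    # reverse-mode differentiation works.
    x0 = jnp.asarray(a, dtype=jnp.float32)
    return 0.5 * (x0 + a / x0)


one_step_value = sqrt_one_step(2.0)           # 1.5, versus sqrt(2) ≈ 1.41421
one_step_grad = jax.grad(sqrt_one_step)(2.0)  # 0.5, versus 1 / (2 * sqrt(2)) ≈ 0.35355
print(one_step_value, one_step_grad)
```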

Proposed Solution

Add an option to run the solver for a fixed number of iterations (e.g., 4), implemented with lax.scan or with lax.fori_loop using static bounds, so that it is compatible with reverse-mode differentiation.
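A minimal sketch of the proposed fix on the same toy problem, assuming the iteration count is a static Python int; `NUM_ITERATIONS` and the helper functions are hypothetical names, not MJX options. When its bounds are static, lax.fori_loop is lowered to a lax.scan, and both are reverse-mode differentiable. After 4 iterations the gradient is close to the exact value 1 / (2 * sqrt(2)).

```python
import jax
import jax.numpy as jnp

NUM_ITERATIONS = 4  # static Python int, known at trace time


def babylonian_step(x, a):
    return 0.5 * (x + a / x)


def sqrt_fori(a):
    x0 = jnp.asarray(a, dtype=jnp.float32)
    # Static start/stop: JAX implements this fori_loop as a scan.
    return jax.lax.fori_loop(0, NUM_ITERATIONS, lambda i, x: babylonian_step(x, a), x0)


def sqrt_scan(a):
    x0 = jnp.asarray(a, dtype=jnp.float32)

    def body(x, _):
        return babylonian_step(x, a), None

    # Fixed-length scan with no per-step inputs or outputs.
    x, _ = jax.lax.scan(body, x0, xs=None, length=NUM_ITERATIONS)
    return x


grad_fori = jax.jit(jax.grad(sqrt_fori))(2.0)
grad_scan = jax.jit(jax.grad(sqrt_scan))(2.0)
print(grad_fori, grad_scan)  # both close to 1 / (2 * sqrt(2)) ≈ 0.35355
```

The trade-off is that the loop always runs NUM_ITERATIONS times, even if it has already converged, and reverse mode stores residuals for every iteration.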

Alternatives

No response

Additional context

No response

@EGalahad EGalahad added the enhancement New feature or request label Nov 29, 2024
@erikfrey erikfrey added the good first issue Good for newcomers label Dec 3, 2024
@erikfrey
Collaborator

erikfrey commented Dec 3, 2024

I like this suggestion and have labeled it as a good one for someone to take on externally. If no one does, we'll eventually implement it ourselves.

If someone would like to try it, I'd recommend first briefly proposing (in this issue) how to modify the API to expose this functionality and then, if we all agree, opening a PR.

@jaraujo98

@erikfrey are you still looking for a volunteer to tackle this? I'd like to give it a shot.

Projects
None yet
Development

No branches or pull requests

4 participants