[FR] add Gumbel distribution #548
Comments
Also commenting to add references: the Paper, and a (deprecated?) PyTorch version with straight-through forward-mode support.
Wow, thanks to @daydreamt! That was fast. I'd be happy to close this, or rename it "Gumbel-Softmax" if that functionality is also wanted. Based on the PyTorch link above, it seems they use some PyTorch-specific detachments to implement the straight-through version (argmax in the forward/sampling direction, softmax in the backward/gradient direction, if I'm understanding correctly). I'm not sure how such a thing could be done with the JAX backend.
@tbsexton A similar version can be implemented using jax.lax.stop_gradient, which plays the role of PyTorch's detach. I would like to close this issue for now and discuss on the new one. :)
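For readers landing here later, a minimal sketch of the straight-through trick discussed above, using jax.lax.stop_gradient in place of PyTorch's detach. The function name and the temperature parameter are illustrative, not NumPyro API:

```python
import jax
import jax.numpy as jnp

def gumbel_softmax_straight_through(key, logits, temperature=1.0):
    # Sample standard Gumbel noise and form the soft (relaxed) sample.
    gumbels = jax.random.gumbel(key, logits.shape)
    soft = jax.nn.softmax((logits + gumbels) / temperature)
    # Hard one-hot sample taken at the argmax of the soft sample.
    hard = jax.nn.one_hot(jnp.argmax(soft), soft.shape[-1])
    # Straight-through: the forward pass returns `hard`, while gradients
    # flow through `soft`, since the residual (hard - soft) is detached.
    return soft + jax.lax.stop_gradient(hard - soft)
```

In the forward direction the detached residual cancels the soft sample exactly, so the output is the one-hot `hard`; in the backward direction the residual contributes no gradient, so gradients are those of `soft`.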
As requested by @tbsexton in #545, it would be nice to have the Gumbel distribution in NumPyro.
References: Wikipedia, the PyTorch distribution, the JAX sampler.
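A rough sketch of what the requested distribution's core pieces could look like on the JAX backend, using the standard location-scale parameterization; the function names are hypothetical, not the eventual NumPyro interface:

```python
import jax
import jax.numpy as jnp

def gumbel_sample(key, loc=0.0, scale=1.0, shape=()):
    # jax.random.gumbel draws standard Gumbel(0, 1) samples;
    # shift and scale to obtain Gumbel(loc, scale).
    return loc + scale * jax.random.gumbel(key, shape)

def gumbel_log_prob(x, loc=0.0, scale=1.0):
    # Log density of Gumbel(loc, scale): -(z + exp(-z)) - log(scale),
    # where z = (x - loc) / scale.
    z = (x - loc) / scale
    return -(z + jnp.exp(-z)) - jnp.log(scale)
```

At x = loc with scale = 1, the log density is -(0 + 1) = -1, which gives a quick sanity check against the Wikipedia formula.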