The issue tracker should only be used to report bugs or feature requests. If you are looking for support from other library users, please ask a question on StackOverflow.
Describe the bug
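When I try to import the PGD attack from the JAX module, I get the following error: ModuleNotFoundError: No module named 'jax.experimental.stax'
The cause is that the FGSM implementation in jax/attacks/fast_gradient_method.py imports the logsoftmax function from the experimental stax package (jax.experimental.stax), which does not exist in the installed JAX version. Because the PGD attack imports the FGSM implementation, importing PGD fails as well.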
To Reproduce
Steps to reproduce the behavior:
1. Open a Google Colab IPython notebook.
2. Install cleverhans with: !pip install git+https://github.com/cleverhans-lab/cleverhans.git#egg=cleverhans
3. Try to import the attack: from cleverhans.jax.attacks.projected_gradient_descent import projected_gradient_descent
4. Run the cell and see the error.
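To check whether the installed JAX still provides jax.experimental.stax before importing cleverhans, one option is a small helper around importlib.util.find_spec. This is a minimal sketch; module_available is a hypothetical name and is not part of cleverhans or JAX.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the dotted module `name` can be located.

    find_spec imports the parent packages of a dotted name, so a missing
    parent (e.g. JAX not installed at all) raises ModuleNotFoundError;
    that case is treated as "not available" too.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False


# The module the FGSM code depends on, and the suggested replacement.
for candidate in ("jax.experimental.stax", "jax.nn"):
    print(candidate, module_available(candidate))
```

In an environment that reproduces this bug, module_available("jax.experimental.stax") should return False while module_available("jax.nn") returns True.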
Expected behavior
The logsoftmax function should be imported from the jax.nn package instead. Changing the import to from jax.nn import log_softmax as logsoftmax makes the error go away.
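A minimal sketch of the suggested import, checked against the definition log_softmax(x) = x - logsumexp(x) over the class axis. It assumes only that JAX is installed; it does not reproduce the actual contents of cleverhans's fast_gradient_method.py, and reference_log_softmax is a hypothetical helper used only for the comparison.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Suggested replacement for the failing `jax.experimental.stax` import.
from jax.nn import log_softmax as logsoftmax


def reference_log_softmax(logits):
    # log_softmax(x) = x - logsumexp(x), computed along the last (class) axis.
    return logits - logsumexp(logits, axis=-1, keepdims=True)


logits = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]])
log_probs = logsoftmax(logits)  # jax.nn.log_softmax defaults to axis=-1
```

Because the alias keeps the name logsoftmax, the rest of the FGSM code should not need to change if the fix is applied this way.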
Screenshots
If applicable, add screenshots to help explain your problem.
System configuration
Google Colab's default runtime