Hello! I was wondering if there is any particular reason that the pad_with_graphs function uses the numpy library rather than the jax.numpy library. It looks like every numpy function in there can just be replaced with jax.numpy without any issues, but I could be missing something.
I have wondered about this too. Another place where numpy is used (or could be used) is batch/batch_np. IIRC the numpy version was much faster in some situations, which hints that unwanted jit compilation happens when jnp functions are used. The same could apply if the numpy functions in pad_with_graphs were replaced with jnp functions. As far as I can tell, jax.numpy.sum always triggers some jit compilation, which is undesirable when array sizes change from call to call. It would be nice to have some clarification on this, though.
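To make the concern about changing array sizes concrete, here is a minimal sketch of how a jitted JAX function is re-traced for every new input shape, and how padding on the host with plain numpy avoids that. It is not taken from pad_with_graphs: the names total, pad_to and trace_count are hypothetical and used only for illustration, and the explicit jax.jit stands in for whatever compilation happens inside jnp functions.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # hypothetical counter, only for this illustration


@jax.jit
def total(x):
    # This Python side effect runs only while JAX traces the function,
    # i.e. once per new input shape/dtype, not on cached calls.
    global trace_count
    trace_count += 1
    return jnp.sum(x)


def pad_to(x, size):
    # Host-side padding with plain numpy: no tracing or compilation involved.
    out = np.zeros(size, dtype=x.dtype)
    out[: len(x)] = x
    return out


# Case 1: a different shape on each call forces a new trace each time.
for n in (3, 4, 5):
    total(jnp.ones(n, dtype=jnp.float32))
total(jnp.ones(5, dtype=jnp.float32))  # shape already seen, served from cache
unpadded_traces = trace_count

# Case 2: pad every input to one fixed size first, so only one shape is seen.
trace_count = 0
padded_sums = []
for n in (3, 4, 5):
    padded = pad_to(np.ones(n, dtype=np.float32), 8)
    padded_sums.append(float(total(padded)))
padded_traces = trace_count

print("traces without padding:", unpadded_traces)
print("traces with padding:", padded_traces)
```

If the padding step itself were written with jnp operations, it would run on inputs whose sizes vary, which is the situation where per-shape compilation would be expected to hurt. That may be one reason to keep it in numpy, though this is a guess and not something confirmed in this thread.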