
pad_with_graphs written with numpy #52

Open
GrantMcConachie opened this issue Jul 5, 2024 · 1 comment

Comments

@GrantMcConachie

Hello! I was wondering whether there is any particular reason that the pad_with_graphs function uses the numpy library rather than jax.numpy. It looks like every numpy function in there could be replaced with its jax.numpy equivalent without any issues, but I could be missing something.
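To make the question concrete, here is a minimal sketch of a padding step written once against an array module `xp`, so the same code runs with either numpy or jax.numpy. `pad_node_features` is a hypothetical, heavily simplified helper, not jraph's actual pad_with_graphs (which pads the whole GraphsTuple); it only appends zero node rows and records them as one extra padding graph.

```python
import numpy as np
import jax.numpy as jnp


def pad_node_features(nodes, n_node, pad_to, xp):
    """Hypothetical, simplified padding step (not jraph's pad_with_graphs).

    Appends zero rows so `nodes` has `pad_to` rows and records them as one
    extra padding graph in `n_node`. `xp` is either `numpy` or `jax.numpy`.
    """
    num_pad = pad_to - nodes.shape[0]
    if num_pad < 1:
        raise ValueError("pad_to must exceed the current number of nodes")
    pad_rows = xp.zeros((num_pad, nodes.shape[1]), dtype=nodes.dtype)
    padded_nodes = xp.concatenate([nodes, pad_rows], axis=0)
    padded_n_node = xp.concatenate([n_node, xp.array([num_pad], dtype=n_node.dtype)])
    return padded_nodes, padded_n_node


nodes = np.arange(6, dtype=np.float32).reshape(3, 2)  # 3 nodes, 2 features each
n_node = np.array([1, 2], dtype=np.int32)             # two graphs: 1 node + 2 nodes

np_nodes, np_n_node = pad_node_features(nodes, n_node, pad_to=5, xp=np)
jnp_nodes, jnp_n_node = pad_node_features(nodes, n_node, pad_to=5, xp=jnp)

# Same values either way; only the array type and where the work runs differ.
assert np.array_equal(np_nodes, np.asarray(jnp_nodes))
assert np.array_equal(np_n_node, np.asarray(jnp_n_node))
print(np_n_node)  # [1 2 2]
```

The values match, so swapping numpy for jax.numpy here is a question of overhead rather than correctness: the numpy version stays as host-side ndarrays, while the jax.numpy version dispatches each zeros/concatenate call to JAX and returns JAX arrays.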

@tisabe

tisabe commented Jul 12, 2024

I have wondered about this too. Another place where numpy is used, or could be used, is batch/batch_np. IIRC, the numpy version was much faster in some situations, which hints that some unwanted JIT compilation happens when jnp functions are used. The same might happen if the numpy functions in pad_with_graphs were replaced with jnp functions. It seems to me that jax.numpy.sum always triggers some JIT compilation, which is not what you want when array sizes keep changing, since each new shape then needs a fresh compilation. It would be nice to get some clarification on this, though.
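A minimal sketch of the per-shape compilation behaviour behind this observation. It assumes an explicit jax.jit so the compilations can be counted; whether jax.numpy.sum is wrapped in jit internally depends on the JAX version and is not verified here. `make_counted_sum`, `next_power_of_two`, and `count_compilations` are hypothetical helpers, not jraph APIs. The power-of-two bucketing is one illustrative way to pad to a small, fixed set of sizes.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_counted_sum():
    """Return a jitted sum plus a counter of how many times it was traced.

    The Python body only runs when JAX traces the function, which happens
    once per new input shape/dtype; cached calls skip it.
    """
    counter = {"traces": 0}

    @jax.jit
    def total(x):
        counter["traces"] += 1  # side effect executes at trace time only
        return jnp.sum(x)

    return total, counter


def next_power_of_two(n: int) -> int:
    """Hypothetical bucketing rule: smallest power of two >= n (n >= 1)."""
    return 1 << (n - 1).bit_length()


def count_compilations(sizes, pad):
    """Sum arrays of the given sizes, optionally padded to power-of-two buckets."""
    total, counter = make_counted_sum()
    for n in sizes:
        length = next_power_of_two(n) if pad else n
        x = np.zeros(length, dtype=np.float32)  # padding built on the host with NumPy
        x[:n] = 1.0                              # zero padding leaves the sum unchanged
        assert float(total(x)) == n
    return counter["traces"]


sizes = [3, 5, 6, 7, 9, 12, 13]
print(count_compilations(sizes, pad=False))  # 7: one trace per distinct size
print(count_compilations(sizes, pad=True))   # 3: one trace per bucket (4, 8, 16)
```

Bucketing trades some wasted memory and compute on padding for far fewer compilations. Doing the size arithmetic and building the padded arrays in NumPy on the host also avoids dispatching JAX ops just to prepare the inputs.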
