Allow einsum to support naive contraction strategy #24915
I think your path specification is invalid. For example, if you pass it to NumPy, you get this error:

np.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
Traceback (most recent call last):
File "/Users/vanderplas/github/google/jax/tmp.py", line 9, in <module>
np.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
File "/Users/vanderplas/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.12/site-packages/numpy/_core/einsumfunc.py", line 1441, in einsum
operands, contraction_list = einsum_path(*operands, optimize=optimize,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/vanderplas/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.12/site-packages/numpy/_core/einsumfunc.py", line 878, in einsum_path
raise TypeError("Did not understand the path: %s" % str(path_type))
TypeError: Did not understand the path: [(0, 1, 2, 3, 4, 5)]
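For reference, NumPy does accept explicit contraction paths, but the list must begin with the literal string 'einsum_path'; a single step listing all six operands then requests one naive all-at-once contraction. A minimal sketch (operand names A–F and the small n are my own, not from the thread):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
# Six n x n operands for the formula discussed in this thread.
A, B, C, D, E, F = (rng.standard_normal((n, n)) for _ in range(6))

formula = 'ij,ik,il,jk,jl,kl->ij'

# An explicit path must start with the string 'einsum_path'; one step
# covering operands 0..5 asks for a single six-way contraction.
out = np.einsum(formula, A, B, C, D, E, F,
                optimize=['einsum_path', (0, 1, 2, 3, 4, 5)])

# Sanity check against NumPy's default evaluation.
ref = np.einsum(formula, A, B, C, D, E, F)
assert np.allclose(out, ref)
```

This suggests the path in the report may simply have been missing the 'einsum_path' prefix.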
Thank you for taking a look! My understanding is that this path is the default behavior for NumPy, i.e., it corresponds to the basic implementation you have in https://github.com/jax-ml/jax/blob/main/tests/lax_numpy_einsum_test.py#L295. In this case it is much more memory efficient than computing the einsum as a sequence of dot_general calls, which from my investigation is hard-coded into the JAX implementation. That choice makes sense because dot_general is very highly optimized, but being able to get the more memory-efficient behavior seems desirable in some settings.
I prototyped a version of this using a sequence of nested jax.lax.scan calls, but it was ugly and, I suspect, not the most performant. I also played around with jax.vmap over the indices (i, j), using jnp.einsum with the per-element path (Complete contraction: ij,ik,il,jk,jl,kl->ij). It was pretty cool to use JAX's abstractions to achieve this, and the vmap implementation did have better performance characteristics than jnp.einsum in this case, but I still think it uses more memory than the naive approach. If jax.lax.map supported the in_axes argument, I think that would help, since I could simply replace my usage of vmap with map.
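The vmap-over-(i, j) idea described above can be sketched as follows; this is my own illustration, assuming six n x n operands named A–F (names and sizes are not from the thread):

```python
import jax
import jax.numpy as jnp
import numpy as np

n = 8
rng = np.random.default_rng(0)
A, B, C, D, E, F = (jnp.asarray(rng.standard_normal((n, n))) for _ in range(6))

def element(i, j):
    # out[i, j] = A[i, j] * sum_{k, l} B[i, k] C[i, l] D[j, k] E[j, l] F[k, l]
    return A[i, j] * jnp.einsum('k,l,k,l,kl->', B[i], C[i], D[j], E[j], F)

# vmap over j (inner) and i (outer), mapping over index arrays.
idx = jnp.arange(n)
vmap_einsum = jax.vmap(jax.vmap(element, in_axes=(None, 0)), in_axes=(0, None))
out = vmap_einsum(idx, idx)

# Check against a plain NumPy einsum on the same data.
ref = np.einsum('ij,ik,il,jk,jl,kl->ij',
                *(np.asarray(x) for x in (A, B, C, D, E, F)))
assert np.allclose(out, ref, rtol=1e-4, atol=1e-4)
```

Whether this actually lowers to something memory-frugal depends on how the compiler batches the per-element einsum, which matches the caveat in the comment above.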
Here is a basic implementation of the naive strategy in terms of jax.vmap and jax.lax.scan, specialized to the formula 'ij,ik,il,jk,jl,kl->ij'.
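The snippet itself wasn't preserved in this copy of the thread. As a hedged reconstruction (operand names A–F, the scan-over-k structure, and the inner dot over l are my guesses, not necessarily the original code), one way to write a naive strategy with jax.vmap and jax.lax.scan is:

```python
import jax
import jax.numpy as jnp
import numpy as np

n = 8
rng = np.random.default_rng(0)
A, B, C, D, E, F = (jnp.asarray(rng.standard_normal((n, n))) for _ in range(6))

def element(Bi, Ci, Dj, Ej):
    # Accumulate over k with lax.scan; the sum over l is a length-n dot
    # product, so each element only ever touches O(n) scratch.
    def body(acc, slices):
        Bik, Djk, Fk = slices
        return acc + Bik * Djk * jnp.dot(Ci * Ej, Fk), None
    acc, _ = jax.lax.scan(body, jnp.zeros((), Bi.dtype), (Bi, Dj, F))
    return acc

inner = jax.vmap(element, in_axes=(None, None, 0, 0))   # map over j
naive = jax.vmap(inner, in_axes=(0, 0, None, None))     # map over i
out = A * naive(B, C, D, E)

ref = np.einsum('ij,ik,il,jk,jl,kl->ij',
                *(np.asarray(x) for x in (A, B, C, D, E, F)))
assert np.allclose(out, ref, rtol=1e-4, atol=1e-4)
```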
When I benchmark it using n x n arrays for n = [128, 256, 512, 1024], here is the timing information I get (measured in seconds, not counting JIT compilation). The story is that jnp.einsum is faster up to n=512 but fails at n=1024, while the naive approach implemented above still runs, albeit more slowly than I'd like.
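For anyone reproducing these numbers, a minimal timing pattern of the kind described (one warm-up call so JIT compilation is excluded, plus block_until_ready so JAX's async dispatch doesn't skew the measurement) might look like this; the helper name and sizes are my own:

```python
import time
import jax
import jax.numpy as jnp

def bench(fn, *args, repeats=3):
    fn_jit = jax.jit(fn)
    fn_jit(*args).block_until_ready()        # warm-up: compile once
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn_jit(*args).block_until_ready()    # wait out async dispatch
    return (time.perf_counter() - t0) / repeats

n = 64
xs = [jnp.ones((n, n)) for _ in range(6)]
f = lambda *ops: jnp.einsum('ij,ik,il,jk,jl,kl->ij', *ops)
print(f'{bench(f, *xs):.4f} s per call')
```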
Here's another impl one can throw into the mix:
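(The implementation itself wasn't captured in this copy of the thread. Purely as an illustration, and not necessarily the original code, a scan-over-k variant that keeps every intermediate at O(n^2) while letting each step use optimized contractions could look like:)

```python
import jax
import jax.numpy as jnp
import numpy as np

n = 8
rng = np.random.default_rng(0)
A, B, C, D, E, F = (jnp.asarray(rng.standard_normal((n, n))) for _ in range(6))

def scan_einsum(A, B, C, D, E, F):
    # Scan over k; each step materializes only n x n intermediates.
    def body(acc, slices):
        Bk, Dk, Fk = slices                      # B[:, k], D[:, k], F[k, :]
        G = jnp.einsum('il,jl,l->ij', C, E, Fk)  # contract over l
        return acc + jnp.outer(Bk, Dk) * G, None
    acc, _ = jax.lax.scan(body, jnp.zeros_like(A), (B.T, D.T, F))
    return A * acc

out = scan_einsum(A, B, C, D, E, F)
ref = np.einsum('ij,ik,il,jk,jl,kl->ij',
                *(np.asarray(x) for x in (A, B, C, D, E, F)))
assert np.allclose(out, ref, rtol=1e-4, atol=1e-4)
```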
Benchmarks show that this is significantly better than the vmap_einsum above, and it's even better than jnp.einsum beyond n=256.
If anyone is interested, I typed up this exploration on my blog: https://www.ryanhmckenna.com/2024/11/exploring-multi-input-einsums-in-jax.html
Thanks for exploring this – are you running benchmarks on GPU/TPU as well, or just CPU? The reason I ask is that
These tests were done in a Colab sandbox with GPU. Happy to do some more benchmarking if there's something specific you'd like to see.
OK, thanks. Overall, I tend to be -1 on changes like this. It greatly complicates things on the JAX side in order to make up for deficiencies in the compiler. The compiler behavior may be improved in the future, at which point we would needlessly be generating more complicated code, with no clear way of alerting ourselves that this is the case.
Is this a compiler deficiency, though? My understanding is that it is a JAX implementation choice that leads to this behavior, specifically https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9773, which implements einsum in terms of the dot_general primitive, meaning the einsum is calculated as a sequence of pairwise contractions. Even if the compiler were better at dot_general, it wouldn't get around the intractability of storing the required n^3-sized intermediates in this case. Happy to keep this alternate implementation local to where I need it, though, to keep the JAX implementation simpler.
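The n^3 claim can be checked with NumPy's planner: np.einsum_path reports the largest intermediate its pairwise plan would allocate. For this formula every pairwise contraction of two 2-index operands leaves three live indices, so the report shows an n^3-sized term (n and the trivial operands below are my own choices):

```python
import numpy as np

n = 32
ops = [np.ones((n, n)) for _ in range(6)]
path, info = np.einsum_path('ij,ik,il,jk,jl,kl->ij', *ops, optimize='optimal')
# The info string includes a 'Largest intermediate' line; for this formula
# it is a three-index (n^3-element) array even under the optimal plan.
print(info)
```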
The compiler often fuses sequences of operations into single kernels to avoid storing intermediates. There may already be fusion paths for sequences of

I'm not saying your code is not useful; I think the approach probably makes sense in some situations. I just don't think it's a good fit for JAX's einsum implementation. (If @mattjj disagrees, though, I'm happy to defer to his judgment here.)
Ah, I see, that makes sense. Do you think I should open up an issue at https://github.com/openxla/xla in that case?
I would like to compute an einsum according to the following formula:
I want to express the computation as 4 nested for loops over the indices i, j, k, l without creating any intermediate arrays. As far as einsum_path is concerned, I can do this by passing the einsum path directly as [(0, 1, 2, 3, 4, 5)] via the optimize kwarg.
However, when I try to do the einsum, I get a NotImplementedError from a line whose comment says "# if this is actually reachable, open an issue!":
https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9775