
Allow einsum to support naive contraction strategy #24915

Open

ryan112358 opened this issue Nov 15, 2024 · 12 comments
Labels: enhancement (New feature or request)

ryan112358 commented Nov 15, 2024

I would like to compute an einsum according to the following formula:

n = 8192
arrays = [jax.random.normal(key=jax.random.PRNGKey(0), shape=(n, n)) for _ in range(6)]
formula = 'ij,ik,il,jk,jl,kl->ij'

I want to express the computation as 4 nested for loops over indices i, j, k, l, without creating any intermediate arrays. As far as einsum_path is concerned, I can request this by passing the contraction path directly as [(0, 1, 2, 3, 4, 5)] via the optimize kwarg.

>>> jax.numpy.einsum_path(formula, *arrays, optimize=[(0,1,2,3,4,5)])
Complete contraction:  ij,ik,il,jk,jl,kl->ij
          Naive scaling:  4
      Optimized scaling:  4
       Naive FLOP count:  2.702e+16
   Optimized FLOP count:  2.702e+16
    Theoretical speedup:  1.000e+0
   Largest intermediate:  6.711e+7 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    4              0  kl,jl,jk,il,ik,ij->ij                                ij->ij
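
For reference, here is the computation I have in mind, written as plain nested loops. (A NumPy sketch only to pin down the semantics; at n = 8192 these Python loops would of course be far too slow to actually run.)

import numpy as np

def naive_einsum(A, B, C, D, E, F):
  # ij,ik,il,jk,jl,kl->ij as four nested loops: no intermediate
  # arrays beyond the (n, n) output itself.
  n = A.shape[0]
  out = np.zeros((n, n))
  for i in range(n):
    for j in range(n):
      acc = 0.0
      for k in range(n):
        for l in range(n):
          acc += B[i, k] * C[i, l] * D[j, k] * E[j, l] * F[k, l]
      out[i, j] = A[i, j] * acc
  return out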

However, when I try to do the einsum, I get this NotImplementedError with a comment that says "# if this is actually reachable, open an issue!"

https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9775

>>> ans = jnp.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
>>> ans.block_until_ready()
jakevdp (Collaborator) commented Nov 21, 2024

I think your path specification is invalid. For example, if you pass it to NumPy, you get this error:

np.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
Traceback (most recent call last):
  File "/Users/vanderplas/github/google/jax/tmp.py", line 9, in <module>
    np.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
  File "/Users/vanderplas/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.12/site-packages/numpy/_core/einsumfunc.py", line 1441, in einsum
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vanderplas/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.12/site-packages/numpy/_core/einsumfunc.py", line 878, in einsum_path
    raise TypeError("Did not understand the path: %s" % str(path_type))
TypeError: Did not understand the path: [(0, 1, 2, 3, 4, 5)]

ryan112358 (Author) commented Nov 22, 2024

Thank you for taking a look! My understanding is that this path is the default behavior for numpy. I.e., it corresponds to the basic implementation that you have in

https://github.com/jax-ml/jax/blob/main/tests/lax_numpy_einsum_test.py#L295

The naive strategy is much more memory-efficient in this case than evaluating the einsum as a sequence of pairwise dot_general contractions, which from my investigation is hard-coded into the JAX implementation. That choice makes sense because dot_general is very highly optimized, but being able to opt into the more memory-efficient behavior seems desirable in some settings.
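
To make the memory point concrete: for this formula, any pairwise contraction order has to materialize a rank-3 intermediate. For example (a sketch of one possible first pairwise step; the exact order chosen may differ):

intermediate = jnp.einsum('ij,ik->ijk', arrays[0], arrays[1])
# shape (n, n, n): at n = 8192 that is ~5.5e11 float32 elements,
# roughly 2 TB, whereas the naive loop strategy only ever holds
# the (n, n) output.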

ryan112358 (Author) commented Nov 22, 2024

I prototyped a version of this using a sequence of nested jax.lax.scan calls, but it was ugly and, I suspect, not the most performant option. I also played around with jax.vmap over the output indices (i, j), calling jnp.einsum with the per-element formula:

Complete contraction: ij,ik,il,jk,jl,kl->ij
[vmap] Per-Row contraction: j,k,l,jk,jl,kl->j
[double vmap] Per-element contraction: ,k,l,k,l,kl->

It was pretty cool to use JAX's abstractions to achieve this, and the vmap implementation did have better performance characteristics than jnp.einsum in this case, but I still think it uses more memory than the naive approach.
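
Concretely, the double-vmap version I tried looked roughly like this (a sketch, not the exact code I ran):

per_element = lambda a, b, c, d, e, f: jnp.einsum(',k,l,k,l,kl->', a, b, c, d, e, f)
per_row = jax.vmap(per_element, in_axes=(0, None, None, 0, 0, None))  # map over j
vmapped = jax.vmap(per_row, in_axes=(0, 0, 0, None, None, None))      # map over i
result = vmapped(*arrays)  # shape (n, n)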

If jax.lax.map supported the in_axes argument, I think that would help, since I could just replace my usage of vmap with map and trade parallelism for memory; a possible workaround is sketched below.
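
For what it's worth, in_axes can be emulated on top of the current jax.lax.map by moving each mapped axis to the front and closing over the unmapped operands. A minimal sketch (map_with_in_axes is a hypothetical helper, not a JAX API):

def map_with_in_axes(f, in_axes):
  # jax.lax.map always maps over the leading axis of every input,
  # so move each mapped axis to the front and pass unmapped args
  # through the closure instead.
  def mapped(*args):
    xs = tuple(jnp.moveaxis(a, ax, 0)
               for a, ax in zip(args, in_axes) if ax is not None)
    def body(slices):
      it = iter(slices)
      full = [next(it) if ax is not None else a
              for a, ax in zip(args, in_axes)]
      return f(*full)
    return jax.lax.map(body, xs)
  return mapped

# e.g. a sequential replacement for the inner vmap in the next comment:
# map_with_in_axes(inner_einsum, (0, None, None, 0, 0, None))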

ryan112358 (Author) commented Nov 23, 2024

Here is a basic implementation of the naive strategy in terms of jax.vmap and jax.lax.scan, specialized to the formula 'ij,ik,il,jk,jl,kl->ij'.

import jax
import jax.numpy as jnp
import time

def inner_einsum(*arrays):
  # Computes the per-element einsum ,k,l,k,l,kl->
  # Does not create any intermediate arrays: two nested scans
  # accumulate a running scalar sum over k and l.

  A, B, C, D, E, F = arrays
  K, L = B.size, C.size

  def outer_step(partial_k, k):
    def inner_step(partial_l, l):
      return partial_l + C[l] * E[l] * F[k, l], ()
    return partial_k + B[k] * D[k] * jax.lax.scan(inner_step, 0.0, jnp.arange(L))[0], ()
  return A * jax.lax.scan(outer_step, 0.0, jnp.arange(K))[0]


@jax.jit
def vmap_einsum(*arrays):
  # computes einsum for ij,ik,il,jk,jl,kl->ij naively
  # No memory overhead.  Vectorized across output cells.

  return jax.vmap(
      jax.vmap(inner_einsum, in_axes=(0, None, None, 0, 0, None)),
      in_axes=(0, 0, 0, None, None, None)
  )(*arrays)

@jax.jit
def default_einsum(*arrays):
  return jnp.einsum('ij,ik,il,jk,jl,kl->ij', *arrays)

When I benchmark it using n x n arrays for n in [128, 256, 512, 1024], here is the timing information I get (measured in seconds, not counting JIT compilation). The story is that jnp.einsum is faster up to n=512 but fails with an XlaRuntimeError at n=1024, while the naive approach implemented above still runs, though it takes more time than I'd like.
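
(For reference, a hedged reconstruction of the benchmark harness; EINSUM_IMPLS and the double-call pattern appear in the traceback below, but the details of this loop are assumed:)

import time

EINSUM_IMPLS = [vmap_einsum, default_einsum]

for n in [128, 256, 512, 1024]:
  print(f'n={n}')
  arrays = [jax.random.normal(jax.random.PRNGKey(i), (n, n)) for i in range(6)]
  for einsum_fn in EINSUM_IMPLS:
    jax.block_until_ready(einsum_fn(*arrays))  # warm up (JIT compile)
    t0 = time.time()
    jax.block_until_ready(einsum_fn(*arrays))
    print(einsum_fn.__name__, time.time() - t0)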

n=128
vmap_einsum 0.14367246627807617
default_einsum 0.002198457717895508

n=256
vmap_einsum 0.7639327049255371
default_einsum 0.017670154571533203

n=512
vmap_einsum 4.290320158004761
default_einsum 0.24642205238342285

n=1024
vmap_einsum 35.70246410369873
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-5-cc97c95bcc5f> in <cell line: 0>()
      4 
      5   for einsum_fn in EINSUM_IMPLS:
----> 6     jax.block_until_ready(einsum_fn(*arrays))
      7     t0 = time.time()
      8     jax.block_until_ready(einsum_fn(*arrays))

    [... skipping hidden 5 frame]

ryan112358 (Author) commented Nov 23, 2024

Here's another impl one can throw into the mix: scan_einsum, where we strip out a non-output axis and sequentially compute and add up the resulting smaller einsums, as follows:

@jax.jit
def scan_einsum(*arrays):
  # we will scan over k and build up a running sum

  A, B, C, D, E, F = arrays
  K = B.shape[1]
  zeros = jnp.zeros(A.shape)

  def add_small_einsum(partial, k):
    # einsum over the remaining indices with k fixed:
    # ij,i,il,j,jl,l->ij
    return partial + jnp.einsum('ij,i,il,j,jl,l->ij', A, B[:,k], C, D[:,k], E, F[k,:]), ()

  return jax.lax.scan(add_small_einsum, zeros, jnp.arange(K))[0]

Benchmarks show that this is significantly faster than the vmap_einsum above, and it even beats jnp.einsum beyond n=256:

n=128
vmap_einsum 0.13236498832702637
scan_einsum 0.0034575462341308594
default_einsum 0.0014224052429199219

n=256
vmap_einsum 0.7413990497589111
scan_einsum 0.011484861373901367
default_einsum 0.018535137176513672

n=512
vmap_einsum 4.2713000774383545
scan_einsum 0.04682159423828125
default_einsum 0.23777055740356445

n=1024
vmap_einsum 35.49849033355713
scan_einsum 0.47335124015808105
default_einsum XlaRuntimeError

ryan112358 (Author) commented:

If anyone is interested, I typed up this exploration on my blog:

https://www.ryanhmckenna.com/2024/11/exploring-multi-input-einsums-in-jax.html

jakevdp (Collaborator) commented Dec 2, 2024

Thanks for exploring this – are you running benchmarks on GPU/TPU as well, or just CPU? The reason I ask is that scan has a pretty big performance penalty on accelerators (essentially, each iteration is its own kernel launch), so I expect any efficiency gains on CPU will not transfer to GPU or TPU.

ryan112358 (Author) commented:

These tests were done in a Colab sandbox with a GPU. Happy to do some more benchmarking if there's something specific you'd like to see.

jakevdp (Collaborator) commented Dec 2, 2024

OK, thanks.

Overall, I tend to be -1 on changes like this. It greatly complicates things on the JAX side in order to make up for deficiencies in the compiler. The compiler behavior may be improved in the future, at which point we would needlessly be generating more complicated code with no clear way of alerting ourselves that this is the case.

ryan112358 (Author) commented:

Is this a compiler deficiency, though? My understanding is that it is a JAX implementation choice that leads to this behavior, specifically https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9773, which implements einsum in terms of a "_dot_general" primitive; I believe this means the einsum is calculated as a sequence of pairwise contractions. Even if the compiler were better at _dot_general, it wouldn't get around the intractability of storing the required n^3-sized intermediates in this case.

Happy to keep this alternate implementation local to where I need it, though, to keep the JAX implementation simpler.

jakevdp (Collaborator) commented Dec 2, 2024

The compiler often fuses sequences of operations into single kernels to avoid storing intermediates. There may already be fusion paths for sequences of dot_general in some situations, but I'm not sure. scan is a much less specific primitive than dot_general, so emitting scan would hamper the ability of the compiler to make such optimizations in the future.

I'm not saying your code is not useful; I think the approach probably makes sense in some situations. I just don't think it's a good fit for JAX's einsum implementation. (If @mattjj disagrees though, I'm happy to defer to his judgment here).

ryan112358 (Author) commented:

Ah, I see, that makes sense. Do you think I should open an issue at https://github.com/openxla/xla in that case?
