Here's a small case where `np.einsum` works but `jnp.einsum` does not:
```python
import numpy as np
import jax.numpy as jnp

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
# One random operand per input term, with every axis of length 2
arrays = [np.random.rand(*(2,)*len(key)) for key in formula.split('->')[0].split(',')]

np.einsum(formula, *arrays)
# array([[[6.26532636e-05, 9.94054312e-04],
#         [3.24902199e-05, 2.90052489e-03]],
#        [[1.21862902e-05, 9.85561040e-05],
#         [2.81959491e-06, 1.77314102e-04]]])

jnp.einsum(formula, *arrays)  # this hangs and does not complete
```
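The operands are tiny (every axis has length 2), so the arithmetic itself is negligible; what grows is the number of operands (ten here). Assuming the hang comes from an exhaustive contraction-order search (the thread points at opt_einsum's 'optimal' strategy), the sketch below counts the distinct sequences of pairwise contractions that reduce n operands to one. This is only an upper bound on the search space: opt_einsum prunes branches, so the real work is smaller, but the factorial-style growth is the point.

```python
from math import comb, prod

def count_pairwise_orders(n: int) -> int:
    """Count sequences of pairwise contractions that reduce n operands to one:
    with k operands left there are comb(k, 2) pairs to choose from, for k = n..2."""
    return prod(comb(k, 2) for k in range(2, n + 1))

for n in (4, 6, 8, 10):
    print(n, count_pairwise_orders(n))
# 4 18
# 6 2700
# 8 1587600
# 10 2571912000
```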
Ah, I see, interesting. I guess in that case I can get unblocked immediately by just changing the `optimize` kwarg for now. I went ahead and reported dgasmith/opt_einsum#243.
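A minimal sketch of that workaround, assuming (as the thread implies) that the default path search in this JAX version is the exhaustive 'optimal' one. The value 'greedy' is an illustrative choice of opt_einsum strategy; the thread does not say which value was used. The result is checked against NumPy with a loose tolerance because JAX computes in float32 by default.

```python
import numpy as np
import jax.numpy as jnp

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
arrays = [np.random.rand(*(2,) * len(key)) for key in formula.split('->')[0].split(',')]

# Ask for a greedy contraction order instead of an exhaustive search.
# 'greedy' is one illustrative strategy name; 'auto' is another.
result = jnp.einsum(formula, *arrays, optimize='greedy')

# NumPy works in float64 here while JAX defaults to float32, hence the tolerance.
expected = np.einsum(formula, *arrays)
print(result.shape, np.allclose(result, expected, rtol=1e-4, atol=1e-8))
```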
According to dgasmith/opt_einsum#243, setting path='auto' might be a preferable default. As far as I understand, 'auto' uses 'optimal' when the number of operands is small and switches to a cheaper strategy when an exhaustive search would not finish in a reasonable amount of time.
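To see what 'auto' does with this expression, one can call opt_einsum directly. This sketch assumes the `opt_einsum` package is installed (JAX uses it for path finding, but importing it standalone is an assumption here). `contract_path` returns the chosen pairwise order plus a `PathInfo` summary of estimated costs, and the explicit path can then be passed back to `contract`.

```python
import numpy as np
import opt_einsum as oe

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
arrays = [np.random.rand(*(2,) * len(key)) for key in formula.split('->')[0].split(',')]

# Let opt_einsum choose a path-finding strategy based on the number of operands.
path, info = oe.contract_path(formula, *arrays, optimize='auto')
print(len(arrays), 'operands; contraction order:', path)
print(info)  # table of estimated FLOP counts and intermediate sizes

# Evaluate along that explicit path and compare with NumPy.
result = oe.contract(formula, *arrays, optimize=path)
print(np.allclose(result, np.einsum(formula, *arrays)))
```

With ten operands, the expression falls outside the small-operand range where, per the opt_einsum documentation, 'auto' selects 'optimal', so the path search should return almost immediately.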
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='849fd340451c', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')