Skip to content

Commit

Permalink
Adds ellipses test case, closes #235, 236
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Jun 26, 2024
1 parent 3b627b6 commit 602e61b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def possibly_convert_to_numpy(x: Any) -> Any:
<__main__.Shape object at 0x10f850710>
"""

# TODO
if not hasattr(x, "shape"):
# TODO : fix the raw NumPy import
import numpy as np

return np.asanyarray(x)
Expand Down
14 changes: 11 additions & 3 deletions opt_einsum/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy
from opt_einsum.testing import build_arrays_from_tuples, using_numpy
from opt_einsum.testing import build_arrays_from_tuples


def test_get_symbol() -> None:
Expand Down Expand Up @@ -34,13 +34,21 @@ def test_parse_einsum_input_shapes_error() -> None:
_ = parse_einsum_input([eq, *ops], shapes=True)


@using_numpy
def test_parse_einsum_input_shapes() -> None:
import numpy as np
np = pytest.importorskip("numpy")

eq = "ab,bc,cd"
shapes = [(2, 3), (3, 4), (4, 5)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == eq
assert output_subscript == "ad"
assert np.allclose([possibly_convert_to_numpy(shp) for shp in shapes], operands)


def test_parse_with_ellisis() -> None:
eq = "...a,ab"
shapes = [(2, 3), (3, 4)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == "da,ab"
assert output_subscript == "db"
assert shapes == operands
2 changes: 2 additions & 0 deletions opt_einsum/tests/test_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from opt_einsum.testing import build_views
from opt_einsum.typing import BackendType

pytest.importorskip("numpy")

try:
import numpy as np # noqa # type: ignore

Expand Down

0 comments on commit 602e61b

Please sign in to comment.