From 602e61ba38f43d4ab06501156ad6abd7ce002382 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Tue, 25 Jun 2024 21:28:52 -0400 Subject: [PATCH] Adds ellipses test case, closes #235, 236 --- opt_einsum/parser.py | 2 +- opt_einsum/tests/test_parser.py | 14 +++++++++++--- opt_einsum/tests/test_sharing.py | 2 ++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 2697c0d..b2ae20a 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -196,8 +196,8 @@ def possibly_convert_to_numpy(x: Any) -> Any: <__main__.Shape object at 0x10f850710> """ - # TODO if not hasattr(x, "shape"): + # TODO : fix the raw NumPy import import numpy as np return np.asanyarray(x) diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py index 3bbbee0..0fe7a69 100644 --- a/opt_einsum/tests/test_parser.py +++ b/opt_einsum/tests/test_parser.py @@ -5,7 +5,7 @@ import pytest from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy -from opt_einsum.testing import build_arrays_from_tuples, using_numpy +from opt_einsum.testing import build_arrays_from_tuples def test_get_symbol() -> None: @@ -34,9 +34,8 @@ def test_parse_einsum_input_shapes_error() -> None: _ = parse_einsum_input([eq, *ops], shapes=True) -@using_numpy def test_parse_einsum_input_shapes() -> None: - import numpy as np + np = pytest.importorskip("numpy") eq = "ab,bc,cd" shapes = [(2, 3), (3, 4), (4, 5)] @@ -44,3 +43,12 @@ def test_parse_einsum_input_shapes() -> None: assert input_subscripts == eq assert output_subscript == "ad" assert np.allclose([possibly_convert_to_numpy(shp) for shp in shapes], operands) + + +def test_parse_with_ellisis() -> None: + eq = "...a,ab" + shapes = [(2, 3), (3, 4)] + input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True) + assert input_subscripts == "da,ab" + assert output_subscript == "db" + assert shapes == operands diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index f5c1cb2..42717fb 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -13,6 +13,8 @@ from opt_einsum.testing import build_views from opt_einsum.typing import BackendType +pytest.importorskip("numpy") + try: import numpy as np # noqa # type: ignore