From 19de34edec00914489adc7adee53917730e390ae Mon Sep 17 00:00:00 2001 From: Norman Date: Fri, 14 Jun 2024 10:56:58 -0700 Subject: [PATCH 1/2] Fixes `operands could not be broadcast` error when the subscripts contain ellipsis and the operands are shapes. --- opt_einsum/parser.py | 4 ++-- opt_einsum/tests/test_parser.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 47567ae..5c6ac53 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -309,8 +309,8 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L else: subscripts, operands = convert_interleaved_input(operands) - if shapes: - operand_shapes = operands + if shapes: + operand_shapes = [list(s) for s in operands] else: operand_shapes = [o.shape for o in operands] diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py index d582ca4..ed1a30b 100644 --- a/opt_einsum/tests/test_parser.py +++ b/opt_einsum/tests/test_parser.py @@ -41,3 +41,12 @@ def test_parse_einsum_input_shapes() -> None: assert input_subscripts == eq assert output_subscript == "ad" assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands) + + +def test_parse_with_ellisis(): + eq = "...a,ab" + shps = [(2, 3), (3, 4)] + input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shps], shapes=True) + assert input_subscripts == "da,ab" + assert output_subscript == "db" + assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands) From 9be3d95895b18f78f212b7980af1fcdf1a6ce6fa Mon Sep 17 00:00:00 2001 From: Norman Date: Fri, 14 Jun 2024 12:04:46 -0700 Subject: [PATCH 2/2] fixed indent --- opt_einsum/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 5c6ac53..d24e9be 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -309,7 +309,7 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L else: subscripts, operands = convert_interleaved_input(operands) - if shapes: + if shapes: operand_shapes = [list(s) for s in operands] else: operand_shapes = [o.shape for o in operands]