From fc88fcac55355895da0397fc0460954ac57e1ac7 Mon Sep 17 00:00:00 2001 From: "Daniel G. A. Smith" Date: Sun, 3 Jul 2022 10:07:53 -0700 Subject: [PATCH 1/6] Unpins MyPy, now appears stable --- devtools/conda-envs/full-environment.yaml | 2 +- devtools/conda-envs/min-deps-environment.yaml | 2 +- devtools/conda-envs/min-ver-environment.yaml | 2 +- opt_einsum/paths.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/devtools/conda-envs/full-environment.yaml b/devtools/conda-envs/full-environment.yaml index 33abe651..12863c3f 100644 --- a/devtools/conda-envs/full-environment.yaml +++ b/devtools/conda-envs/full-environment.yaml @@ -20,4 +20,4 @@ dependencies: - pytest - codecov - pytest-cov - - mypy ==0.812 + - mypy diff --git a/devtools/conda-envs/min-deps-environment.yaml b/devtools/conda-envs/min-deps-environment.yaml index ac87bcfd..9d2f4112 100644 --- a/devtools/conda-envs/min-deps-environment.yaml +++ b/devtools/conda-envs/min-deps-environment.yaml @@ -11,4 +11,4 @@ dependencies: - pytest - codecov - pytest-cov - - mypy ==0.812 + - mypy diff --git a/devtools/conda-envs/min-ver-environment.yaml b/devtools/conda-envs/min-ver-environment.yaml index 0d52d0f6..04c59067 100644 --- a/devtools/conda-envs/min-ver-environment.yaml +++ b/devtools/conda-envs/min-ver-environment.yaml @@ -15,4 +15,4 @@ dependencies: - pytest - codecov - pytest-cov - - mypy ==0.812 + - mypy diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index 23a2c49d..5340dbdc 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -827,9 +827,9 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: s[0] += (sum(1 for q in t if q < i),) t.insert(s[0][-1], i) - for i in [i for i in j if type(i) != int]: + for i_tup in [i_tup for i_tup in j if type(i) != int]: s[0] += (len(t) + len(c),) - c.append(i) + c.append(i_tup) return s From 6704495fc90514b4944a264f73c8e2b59d64b994 Mon Sep 17 00:00:00 2001 From: "Daniel G. A. Smith" Date: Sun, 3 Jul 2022 10:08:47 -0700 Subject: [PATCH 2/6] Minor typing tweaks --- opt_einsum/contract.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 14ccfaff..0df94d03 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple from . import backends, blas, helpers, parser, paths, sharing -from .typing import ArrayIndexType, ArrayType, Collection, ContractionListType, PathType +from .typing import ArrayIndexType, ArrayType, ContractionListType, PathType __all__ = [ "contract_path", @@ -75,7 +75,7 @@ def __repr__(self) -> str: ] for n, contraction in enumerate(self.contraction_list): - inds, idx_rm, einsum_str, remaining, do_blas = contraction + _, _, einsum_str, remaining, do_blas = contraction if remaining is not None: remaining_str = ",".join(remaining) + "->" + self.output_subscript @@ -865,7 +865,7 @@ def __str__(self) -> str: Shaped = namedtuple("Shaped", ["shape"]) -def shape_only(shape: Collection[Tuple[int, ...]]) -> Shaped: +def shape_only(shape: PathType) -> Shaped: """Dummy ``numpy.ndarray`` which has a shape only - for generating contract expressions. """ From 13973fb1b81f1fb6eb722d9dada69fc5d5d0f11a Mon Sep 17 00:00:00 2001 From: "Daniel G. A. Smith" Date: Sun, 3 Jul 2022 10:12:41 -0700 Subject: [PATCH 3/6] Skips unicode surrogates, fixes #182 --- opt_einsum/parser.py | 10 ++++++---- opt_einsum/tests/test_parser.py | 14 ++++++++++++++ scripts/compare_random_paths.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 opt_einsum/tests/test_parser.py diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 76b095a2..8e3a00b2 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 """ A functionally equivalent parser of the numpy.einsum input parser """ @@ -61,7 +59,7 @@ def has_valid_einsum_chars_only(einsum_str: str) -> bool: def get_symbol(i: int) -> str: """Get the symbol corresponding to int ``i`` - runs through the usual 52 - letters before resorting to unicode characters, starting at ``chr(192)``. + letters before resorting to unicode characters, starting at ``chr(192)`` and skipping surrogates. **Examples:** @@ -78,7 +76,11 @@ def get_symbol(i: int) -> str: """ if i < 52: return _einsum_symbols_base[i] - return chr(i + 140) + elif i >= 55296: + # Skip chr(57343) - chr(55296) as surrogates + return chr(i + 2048) + else: + return chr(i + 140) def gen_unused_symbols(used: str, n: int) -> Iterator[str]: diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py new file mode 100644 index 00000000..e81fafdd --- /dev/null +++ b/opt_einsum/tests/test_parser.py @@ -0,0 +1,14 @@ +""" +Directly tests various parser utility functions. +""" + +from opt_einsum.parser import get_symbol + + +def test_get_symbol(): + assert get_symbol(2) == "c" + assert get_symbol(200000) == "\U00031540" + # Ensure we skip surrogates '[\uD800-\uDFFF]' + assert get_symbol(55295) == "\ud88b" + assert get_symbol(55296) == "\ue000" + assert get_symbol(57343) == "\ue7ff" diff --git a/scripts/compare_random_paths.py b/scripts/compare_random_paths.py index c7917cda..b6d4bf7a 100644 --- a/scripts/compare_random_paths.py +++ b/scripts/compare_random_paths.py @@ -33,7 +33,7 @@ alpha = list("abcdefghijklmnopqrstuvwyxz") alpha_dict = {num: x for num, x in enumerate(alpha)} -print("Maximum term size is %d" % (max_size ** max_dims)) +print("Maximum term size is %d" % (max_size**max_dims)) def make_term(): From e96e2b0207ee3e2da4ed5b77d8531e11e8484b69 Mon Sep 17 00:00:00 2001 From: "Daniel G. A. Smith" Date: Sun, 3 Jul 2022 10:19:44 -0700 Subject: [PATCH 4/6] Adds simple validation to BranchBound, fixes #136 --- opt_einsum/paths.py | 3 +++ opt_einsum/tests/test_paths.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index 5340dbdc..74dc96ff 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -352,6 +352,9 @@ def __init__( minimize="flops", cost_fn="memory-removed", ): + if nbranch < 1: + raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}.") + self.nbranch = nbranch self.cutoff_flops_factor = cutoff_flops_factor self.minimize = minimize diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index 03fd1049..c6fef4f5 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -394,6 +394,11 @@ def test_custom_branchbound(): path, path_info = oe.contract_path(eq, *views, optimize=optimizer) +def test_branchbound_validation(): + with pytest.raises(ValueError): + oe.BranchBound(nbranch=0) + + @pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher") def test_parallel_random_greedy(): from concurrent.futures import ProcessPoolExecutor From 8a3a46c1bd9636a2309303f4cb797c291021ecf8 Mon Sep 17 00:00:00 2001 From: "Daniel G. A. Smith" Date: Sun, 3 Jul 2022 10:49:57 -0700 Subject: [PATCH 5/6] More accurate branch bound exception handler --- opt_einsum/paths.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index 74dc96ff..f7a0a5bd 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -352,7 +352,7 @@ def __init__( minimize="flops", cost_fn="memory-removed", ): - if nbranch < 1: + if (nbranch is not None) and nbranch < 1: raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}.") self.nbranch = nbranch @@ -830,7 +830,7 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: s[0] += (sum(1 for q in t if q < i),) t.insert(s[0][-1], i) - for i_tup in [i_tup for i_tup in j if type(i) != int]: + for i_tup in [i_tup for i_tup in j if type(i_tup) != int]: s[0] += (len(t) + len(c),) c.append(i_tup) From 30d304b9821c45618d92ae0119b067c85e64a25e Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Sun, 3 Jul 2022 21:50:36 -0700 Subject: [PATCH 6/6] Update opt_einsum/paths.py --- opt_einsum/paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index f7a0a5bd..eac36301 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -353,7 +353,7 @@ def __init__( cost_fn="memory-removed", ): if (nbranch is not None) and nbranch < 1: - raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}.") + raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}`.") self.nbranch = nbranch self.cutoff_flops_factor = cutoff_flops_factor