From a274599c87c6224382e2ce328c7439b7b713e363 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Tue, 5 Jul 2022 18:11:47 -0700 Subject: [PATCH] Bug fixes and tweaks for v3.4.0 (#192) * Unpins MyPy, now appears stable * Minor typing tweaks * Skips unicode surrogates, fixes #182 * Adds simple validation to BranchBound, fixes #136 * More accurate branch bound exception handler * Update opt_einsum/paths.py --- devtools/conda-envs/full-environment.yaml | 2 +- devtools/conda-envs/min-deps-environment.yaml | 2 +- devtools/conda-envs/min-ver-environment.yaml | 2 +- opt_einsum/contract.py | 6 +++--- opt_einsum/parser.py | 10 ++++++---- opt_einsum/paths.py | 7 +++++-- opt_einsum/tests/test_parser.py | 14 ++++++++++++++ opt_einsum/tests/test_paths.py | 5 +++++ scripts/compare_random_paths.py | 2 +- 9 files changed, 37 insertions(+), 13 deletions(-) create mode 100644 opt_einsum/tests/test_parser.py diff --git a/devtools/conda-envs/full-environment.yaml b/devtools/conda-envs/full-environment.yaml index 33abe651..12863c3f 100644 --- a/devtools/conda-envs/full-environment.yaml +++ b/devtools/conda-envs/full-environment.yaml @@ -20,4 +20,4 @@ dependencies: - pytest - codecov - pytest-cov - - mypy ==0.812 + - mypy diff --git a/devtools/conda-envs/min-deps-environment.yaml b/devtools/conda-envs/min-deps-environment.yaml index ac87bcfd..9d2f4112 100644 --- a/devtools/conda-envs/min-deps-environment.yaml +++ b/devtools/conda-envs/min-deps-environment.yaml @@ -11,4 +11,4 @@ dependencies: - pytest - codecov - pytest-cov - - mypy ==0.812 + - mypy diff --git a/devtools/conda-envs/min-ver-environment.yaml b/devtools/conda-envs/min-ver-environment.yaml index 0d52d0f6..04c59067 100644 --- a/devtools/conda-envs/min-ver-environment.yaml +++ b/devtools/conda-envs/min-ver-environment.yaml @@ -15,4 +15,4 @@ dependencies: - pytest - codecov - pytest-cov - - mypy ==0.812 + - mypy diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 14ccfaff..0df94d03 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple from . import backends, blas, helpers, parser, paths, sharing -from .typing import ArrayIndexType, ArrayType, Collection, ContractionListType, PathType +from .typing import ArrayIndexType, ArrayType, ContractionListType, PathType __all__ = [ "contract_path", @@ -75,7 +75,7 @@ def __repr__(self) -> str: ] for n, contraction in enumerate(self.contraction_list): - inds, idx_rm, einsum_str, remaining, do_blas = contraction + _, _, einsum_str, remaining, do_blas = contraction if remaining is not None: remaining_str = ",".join(remaining) + "->" + self.output_subscript @@ -865,7 +865,7 @@ def __str__(self) -> str: Shaped = namedtuple("Shaped", ["shape"]) -def shape_only(shape: Collection[Tuple[int, ...]]) -> Shaped: +def shape_only(shape: PathType) -> Shaped: """Dummy ``numpy.ndarray`` which has a shape only - for generating contract expressions. """ diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 76b095a2..8e3a00b2 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 """ A functionally equivalent parser of the numpy.einsum input parser """ @@ -61,7 +59,7 @@ def has_valid_einsum_chars_only(einsum_str: str) -> bool: def get_symbol(i: int) -> str: """Get the symbol corresponding to int ``i`` - runs through the usual 52 - letters before resorting to unicode characters, starting at ``chr(192)``. + letters before resorting to unicode characters, starting at ``chr(192)`` and skipping surrogates. **Examples:** @@ -78,7 +76,11 @@ def get_symbol(i: int) -> str: """ if i < 52: return _einsum_symbols_base[i] - return chr(i + 140) + elif i >= 55296: + # Skip chr(57343) - chr(55296) as surrogates + return chr(i + 2048) + else: + return chr(i + 140) def gen_unused_symbols(used: str, n: int) -> Iterator[str]: diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index 23a2c49d..eac36301 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -352,6 +352,9 @@ def __init__( minimize="flops", cost_fn="memory-removed", ): + if (nbranch is not None) and nbranch < 1: + raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}`.") + self.nbranch = nbranch self.cutoff_flops_factor = cutoff_flops_factor self.minimize = minimize @@ -827,9 +830,9 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: s[0] += (sum(1 for q in t if q < i),) t.insert(s[0][-1], i) - for i in [i for i in j if type(i) != int]: + for i_tup in [i_tup for i_tup in j if type(i_tup) != int]: s[0] += (len(t) + len(c),) - c.append(i) + c.append(i_tup) return s diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py new file mode 100644 index 00000000..e81fafdd --- /dev/null +++ b/opt_einsum/tests/test_parser.py @@ -0,0 +1,14 @@ +""" +Directly tests various parser utility functions. +""" + +from opt_einsum.parser import get_symbol + + +def test_get_symbol(): + assert get_symbol(2) == "c" + assert get_symbol(200000) == "\U00031540" + # Ensure we skip surrogates '[\uD800-\uDFFF]' + assert get_symbol(55295) == "\ud88b" + assert get_symbol(55296) == "\ue000" + assert get_symbol(57343) == "\ue7ff" diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index 03fd1049..c6fef4f5 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -394,6 +394,11 @@ def test_custom_branchbound(): path, path_info = oe.contract_path(eq, *views, optimize=optimizer) +def test_branchbound_validation(): + with pytest.raises(ValueError): + oe.BranchBound(nbranch=0) + + @pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher") def test_parallel_random_greedy(): from concurrent.futures import ProcessPoolExecutor diff --git a/scripts/compare_random_paths.py b/scripts/compare_random_paths.py index c7917cda..b6d4bf7a 100644 --- a/scripts/compare_random_paths.py +++ b/scripts/compare_random_paths.py @@ -33,7 +33,7 @@ alpha = list("abcdefghijklmnopqrstuvwyxz") alpha_dict = {num: x for num, x in enumerate(alpha)} -print("Maximum term size is %d" % (max_size ** max_dims)) +print("Maximum term size is %d" % (max_size**max_dims)) def make_term():