Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes and tweaks for v3.4.0 #192

Merged
merged 6 commits into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devtools/conda-envs/full-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ dependencies:
- pytest
- codecov
- pytest-cov
- mypy ==0.812
- mypy
2 changes: 1 addition & 1 deletion devtools/conda-envs/min-deps-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ dependencies:
- pytest
- codecov
- pytest-cov
- mypy ==0.812
- mypy
2 changes: 1 addition & 1 deletion devtools/conda-envs/min-ver-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ dependencies:
- pytest
- codecov
- pytest-cov
- mypy ==0.812
- mypy
6 changes: 3 additions & 3 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from . import backends, blas, helpers, parser, paths, sharing
from .typing import ArrayIndexType, ArrayType, Collection, ContractionListType, PathType
from .typing import ArrayIndexType, ArrayType, ContractionListType, PathType

__all__ = [
"contract_path",
Expand Down Expand Up @@ -75,7 +75,7 @@ def __repr__(self) -> str:
]

for n, contraction in enumerate(self.contraction_list):
inds, idx_rm, einsum_str, remaining, do_blas = contraction
_, _, einsum_str, remaining, do_blas = contraction

if remaining is not None:
remaining_str = ",".join(remaining) + "->" + self.output_subscript
Expand Down Expand Up @@ -865,7 +865,7 @@ def __str__(self) -> str:
Shaped = namedtuple("Shaped", ["shape"])


def shape_only(shape: Collection[Tuple[int, ...]]) -> Shaped:
def shape_only(shape: PathType) -> Shaped:
"""Dummy ``numpy.ndarray`` which has a shape only - for generating
contract expressions.
"""
Expand Down
10 changes: 6 additions & 4 deletions opt_einsum/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# coding: utf-8
"""
A functionally equivalent parser of the numpy.einsum input parser
"""
Expand Down Expand Up @@ -61,7 +59,7 @@ def has_valid_einsum_chars_only(einsum_str: str) -> bool:

def get_symbol(i: int) -> str:
"""Get the symbol corresponding to int ``i`` - runs through the usual 52
letters before resorting to unicode characters, starting at ``chr(192)``.
letters before resorting to unicode characters, starting at ``chr(192)`` and skipping surrogates.

**Examples:**

Expand All @@ -78,7 +76,11 @@ def get_symbol(i: int) -> str:
"""
if i < 52:
return _einsum_symbols_base[i]
return chr(i + 140)
elif i >= 55296:
# Skip chr(57343) - chr(55296) as surrogates
return chr(i + 2048)
else:
return chr(i + 140)


def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
Expand Down
7 changes: 5 additions & 2 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def __init__(
minimize="flops",
cost_fn="memory-removed",
):
if (nbranch is not None) and nbranch < 1:
raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}`.")

self.nbranch = nbranch
self.cutoff_flops_factor = cutoff_flops_factor
self.minimize = minimize
Expand Down Expand Up @@ -827,9 +830,9 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType:
s[0] += (sum(1 for q in t if q < i),)
t.insert(s[0][-1], i)

for i in [i for i in j if type(i) != int]:
for i_tup in [i_tup for i_tup in j if type(i_tup) != int]:
s[0] += (len(t) + len(c),)
c.append(i)
c.append(i_tup)

return s

Expand Down
14 changes: 14 additions & 0 deletions opt_einsum/tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Directly tests various parser utility functions.
"""

from opt_einsum.parser import get_symbol


def test_get_symbol():
assert get_symbol(2) == "c"
assert get_symbol(200000) == "\U00031540"
# Ensure we skip surrogates '[\uD800-\uDFFF]'
assert get_symbol(55295) == "\ud88b"
assert get_symbol(55296) == "\ue000"
assert get_symbol(57343) == "\ue7ff"
5 changes: 5 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ def test_custom_branchbound():
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)


def test_branchbound_validation():
with pytest.raises(ValueError):
oe.BranchBound(nbranch=0)


@pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher")
def test_parallel_random_greedy():
from concurrent.futures import ProcessPoolExecutor
Expand Down
2 changes: 1 addition & 1 deletion scripts/compare_random_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
alpha = list("abcdefghijklmnopqrstuvwyxz")
alpha_dict = {num: x for num, x in enumerate(alpha)}

print("Maximum term size is %d" % (max_size ** max_dims))
print("Maximum term size is %d" % (max_size**max_dims))


def make_term():
Expand Down