Skip to content

Commit

Permalink
Updates MyPy
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Jun 26, 2024
1 parent 1bdc70f commit 3b627b6
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 58 deletions.
4 changes: 3 additions & 1 deletion opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L
"shapes is set to True but given at least one operand looks like an array"
" (at least one operand has a shape attribute). "
)
operands = [possibly_convert_to_numpy(x) for x in operands[1:]]
operands = operands[1:]
else:
operands = [possibly_convert_to_numpy(x) for x in operands[1:]]
else:
subscripts, operands = convert_interleaved_input(operands)

Expand Down
47 changes: 36 additions & 11 deletions opt_einsum/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from opt_einsum.parser import get_symbol
from opt_einsum.typing import GenericArrayType, PathType
from opt_einsum.typing import GenericArrayType, PathType, TensorShapeType

_valid_chars = "abcdefghijklmopqABC"
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
Expand All @@ -30,36 +30,61 @@ def import_numpy_or_skip() -> Any:
return import_module("numpy")


def build_views(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> List[GenericArrayType]:
def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[TensorShapeType, ...]:
"""
Builds random numpy arrays for testing.
Builds random tensor shapes for testing.
Parameters:
string: List of tensor strings to build
dimension_dict: Dictionary of index _sizes
Returns
The resulting views.
The resulting shapes.
Examples:
```python
>>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
>>> view[0].shape
(2, 3, 3, 5)
>>> shapes = build_shapes('abbc', {'a': 2, 'b':3, 'c':5})
>>> shapes
[(2, 3), (3, 3, 5), (5,)]
```
"""
np = import_numpy_or_skip()

if dimension_dict is None:
dimension_dict = _default_dim_dict

views = []
shapes = []
terms = string.split("->")[0].split(",")
for term in terms:
dims = [dimension_dict[x] for x in term]
views.append(np.random.rand(*dims))
return views
shapes.append(tuple(dims))
return tuple(shapes)


def build_views(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[GenericArrayType]:
"""
Builds random numpy arrays for testing.
Parameters:
string: List of tensor strings to build
dimension_dict: Dictionary of index _sizes
Returns
The resulting views.
Examples:
```python
>>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
>>> view[0].shape
(2, 3, 3, 5)
```
"""
np = import_numpy_or_skip()
views = []
for shape in build_shapes(string, dimension_dict=dimension_dict):
views.append(np.random.rand(*shape))
return tuple(views)


@overload
Expand Down
14 changes: 7 additions & 7 deletions opt_einsum/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string: str) -> None:
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
opt = np.empty_like(ein)

Expand Down Expand Up @@ -128,7 +128,7 @@ def test_tensorflow_with_constants(constants: Set[int]) -> None:
@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string: str) -> None:
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)

shps = [v.shape for v in views]
Expand All @@ -153,7 +153,7 @@ def test_tensorflow_with_sharing(string: str) -> None:
@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano(string: str) -> None:
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]

Expand Down Expand Up @@ -197,7 +197,7 @@ def test_theano_with_constants(constants: Set[int]) -> None:
@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string: str) -> None:
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)

shps = [v.shape for v in views]
Expand All @@ -220,7 +220,7 @@ def test_theano_with_sharing(string: str) -> None:
@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("string", tests)
def test_cupy(string: str) -> None: # pragma: no cover
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]

Expand Down Expand Up @@ -267,7 +267,7 @@ def test_cupy_with_constants(constants: Set[int]) -> None: # pragma: no cover
@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("string", tests)
def test_jax(string: str) -> None: # pragma: no cover
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]

Expand Down Expand Up @@ -457,7 +457,7 @@ def test_auto_backend_custom_array_no_tensordot() -> None:

@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string: str) -> None:
views = helpers.build_views(string)
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
assert ein.dtype != object

Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: boo

def test_contract_expression_interleaved_input() -> None:
x, y, z = (np.random.randn(2, 2) for _ in "xyz")
expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) # type: ignore
expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])
xshp, yshp, zshp = ((2, 2) for _ in "xyz")
expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0])
out = expr(x, y, z)
Expand Down
1 change: 1 addition & 0 deletions opt_einsum/tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.typing import PathType

# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")


Expand Down
60 changes: 31 additions & 29 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@
"""

import itertools
import sys
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, List, Optional

import pytest

import opt_einsum as oe
from opt_einsum.testing import build_views, rand_equation
from opt_einsum.testing import build_shapes, rand_equation, using_numpy
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType

np = pytest.importorskip("numpy")

explicit_path_tests = {
"GEMM1": (
[set("abd"), set("ac"), set("bdc")],
Expand Down Expand Up @@ -129,10 +127,11 @@ def test_flop_cost() -> None:


def test_bad_path_option() -> None:
with pytest.raises(KeyError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall") # type: ignore
with pytest.raises(TypeError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore


@using_numpy
def test_explicit_path() -> None:
x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
assert x.item() == 6
Expand Down Expand Up @@ -160,39 +159,39 @@ def test_memory_paths() -> None:

expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"

views = build_views(expression)
views = build_shapes(expression)

# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5)
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5, shapes=True)
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])

path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5)
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5, shapes=True)
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])

# Check the possibilities, greedy is capped
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1)
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1, shapes=True)
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])

path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1)
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1, shapes=True)
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])


@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
views = build_views(expression)
views = build_shapes(expression)

# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize=alg)
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
assert check_path(path_ret[0], order)


@pytest.mark.parametrize("expression,order", path_scalar_tests)
@pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS)
def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
views = build_views(expression)
views = build_shapes(expression)

# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize=alg)
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
# print(path_ret[0])
assert len(path_ret[0]) == order

Expand All @@ -201,24 +200,24 @@ def test_optimal_edge_cases() -> None:

# Edge test5
expression = "a,ac,ab,ad,cd,bd,bc->"
edge_test4 = build_views(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input")
edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])

path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input")
path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])


def test_greedy_edge_cases() -> None:

expression = "abc,cfd,dbe,efa"
dim_dict = {k: 20 for k in expression.replace(",", "")}
tensors = build_views(expression, dimension_dict=dim_dict)
tensors = build_shapes(expression, dimension_dict=dim_dict)

path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input")
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1, 2, 3)])

path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1)
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1, shapes=True)
assert check_path(path, [(0, 1), (0, 2), (0, 1)])


Expand Down Expand Up @@ -315,9 +314,10 @@ def test_dp_errors_when_no_contractions_found() -> None:
@pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"])
def test_can_optimize_outer_products(optimize: OptimizeKind) -> None:

a, b, c = [np.random.randn(10, 10) for _ in range(3)]
d = np.random.randn(10, 2)
assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize)[0] == [
a, b, c = [(10, 10) for _ in range(3)]
d = (10, 2)

assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [
(2, 3),
(0, 2),
(0, 1),
Expand All @@ -329,13 +329,14 @@ def test_large_path(num_symbols: int) -> None:
symbols = "".join(oe.get_symbol(i) for i in range(num_symbols))
dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
expression = ",".join(symbols[t : t + 2] for t in range(num_symbols - 1))
tensors = build_views(expression, dimension_dict=dimension_dict)
tensors = build_shapes(expression, dimension_dict=dimension_dict)

# Check that path construction does not crash
oe.contract_path(expression, *tensors, optimize="greedy")
oe.contract_path(expression, *tensors, optimize="greedy", shapes=True)


def test_custom_random_greedy() -> None:
np = pytest.importorskip("numpy")

eq, shapes = rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))
Expand Down Expand Up @@ -375,6 +376,7 @@ def test_custom_random_greedy() -> None:


def test_custom_branchbound() -> None:
np = pytest.importorskip("numpy")

eq, shapes = rand_equation(8, 4, seed=42)
views = list(map(np.ones, shapes))
Expand Down Expand Up @@ -407,10 +409,8 @@ def test_branchbound_validation() -> None:
oe.BranchBound(nbranch=0)


@pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher")
def test_parallel_random_greedy() -> None:

from concurrent.futures import ProcessPoolExecutor
np = pytest.importorskip("numpy")

pool = ProcessPoolExecutor(2)

Expand Down Expand Up @@ -454,6 +454,7 @@ def test_parallel_random_greedy() -> None:


def test_custom_path_optimizer() -> None:
np = pytest.importorskip("numpy")

class NaiveOptimizer(oe.paths.PathOptimizer):
def __call__(
Expand All @@ -478,6 +479,7 @@ def __call__(


def test_custom_random_optimizer() -> None:
np = pytest.importorskip("numpy")

class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
@staticmethod
Expand Down
Loading

0 comments on commit 3b627b6

Please sign in to comment.