Skip to content

Commit

Permalink
CI: skip cotengra numpy install for some
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Sep 7, 2023
1 parent 7e11612 commit 11f4c15
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
install: |
apt-get update
apt-get install -y --no-install-recommends python3 python3-pip
pip3 install -U pip pytest numpy cotengra
pip3 install -U pip pytest # numpy cotengra
run: |
set -e
pip3 install cotengrust --find-links dist --force-reinstall
Expand Down
59 changes: 39 additions & 20 deletions tests/test_cotengrust.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
import pytest
import numpy as np
from numpy.testing import assert_allclose
import cotengra as ctg

try:
import cotengra as ctg

ctg_missing = False
except ImportError:
ctg_missing = True
ctg = None

import cotengrust as ctgr


requires_cotengra = pytest.mark.skipif(ctg_missing, reason="requires cotengra")


@pytest.mark.parametrize("which", ["greedy", "optimal"])
def test_basic_call(which):
inputs = [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'a')]
output = ('b', 'd')
size_dict = {'a': 2, 'b': 3, 'c': 4, 'd': 5}
path = {
"greedy": ctgr.optimize_greedy,
"optimal": ctgr.optimize_optimal,
}[
which
](inputs, output, size_dict)
assert all(len(con) <= 2 for con in path)


def find_output_str(lhs):
tmp_lhs = lhs.replace(",", "")
return "".join(s for s in sorted(set(tmp_lhs)) if tmp_lhs.count(s) == 1)
Expand All @@ -21,20 +44,16 @@ def eq_to_inputs_output(eq):


def get_rand_size_dict(inputs, d_min=2, d_max=3):
import random

size_dict = {}
for term in inputs:
for ix in term:
if ix not in size_dict:
size_dict[ix] = np.random.randint(d_min, d_max + 1)
size_dict[ix] = random.randint(d_min, d_max)
return size_dict


def build_arrays(inputs, size_dict):
return [
np.random.randn(*[size_dict[ix] for ix in term]) for term in inputs
]


# these are taken from opt_einsum
test_case_eqs = [
# Test scalar-like operations
Expand Down Expand Up @@ -120,24 +139,26 @@ def build_arrays(inputs, size_dict):
]


@requires_cotengra
@pytest.mark.parametrize("eq", test_case_eqs)
@pytest.mark.parametrize("which", ["greedy", "optimal"])
def test_manual_cases(eq, which):
inputs, output = eq_to_inputs_output(eq)
size_dict = get_rand_size_dict(inputs)
arrays = build_arrays(inputs, size_dict)
expected = np.einsum(eq, *arrays, optimize=True)
path = {
"greedy": ctgr.optimize_greedy,
"optimal": ctgr.optimize_optimal,
}[
which
](inputs, output, size_dict)
assert all(len(con) <= 2 for con in path)
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
assert_allclose(tree.contract(arrays), expected)
tree = ctg.ContractionTree.from_path(
inputs, output, size_dict, path=path, check=True
)
assert tree.is_complete()


@requires_cotengra
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("which", ["greedy", "optimal"])
def test_basic_rand(seed, which):
Expand All @@ -151,22 +172,20 @@ def test_basic_rand(seed, which):
d_max=3,
seed=seed,
)
eq = ",".join(map("".join, inputs)) + "->" + "".join(output)

path = {
"greedy": ctgr.optimize_greedy,
"optimal": ctgr.optimize_optimal,
}[
which
](inputs, output, size_dict)
assert all(len(con) <= 2 for con in path)
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
arrays = [np.random.randn(*s) for s in shapes]
assert_allclose(
tree.contract(arrays), np.einsum(eq, *arrays, optimize=True)
tree = ctg.ContractionTree.from_path(
inputs, output, size_dict, path=path, check=True
)
assert tree.is_complete()


@requires_cotengra
def test_optimal_lattice_eq():
inputs, output, _, size_dict = ctg.utils.lattice_equation(
[4, 5], d_max=3, seed=42
Expand Down

0 comments on commit 11f4c15

Please sign in to comment.