From aa3dd9acc3e46fe6e0a08a07270dc1eeb5d437df Mon Sep 17 00:00:00 2001 From: Roman Novak <44512421+romanngg@users.noreply.github.com> Date: Sat, 9 Jul 2022 09:01:16 -0700 Subject: [PATCH] Fix `dp` to work for contractions with no sized axes (scalars only). (#195) * Fix `dp` to work for contractions with no sized axes (scalars only). * Add comma for linter * Bug fixes and tweaks for v3.4.0 (#192) * Unpins MyPy, now appears stable * Minor typing tweaks * Skips unicode surrogates, fixes #182 * Adds simple validation to BranchBound, fixes #136 * More accurate branch bound exception handler * Update opt_einsum/paths.py Co-authored-by: Daniel Smith --- opt_einsum/paths.py | 2 +- opt_einsum/tests/test_contract.py | 1 + opt_einsum/tests/test_paths.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index eac36301..ac8ccd5d 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -1217,7 +1217,7 @@ def __call__( output = frozenset(symbol2int[c] for c in output_) size_dict_canonical = {symbol2int[c]: v for c, v in size_dict_.items() if c in symbol2int} size_dict = [size_dict_canonical[j] for j in range(len(size_dict_canonical))] - naive_cost = naive_scale * len(inputs) * functools.reduce(operator.mul, size_dict) + naive_cost = naive_scale * len(inputs) * functools.reduce(operator.mul, size_dict, 1) inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts) diff --git a/opt_einsum/tests/test_contract.py b/opt_einsum/tests/test_contract.py index 453a7790..ef3ee5de 100644 --- a/opt_einsum/tests/test_contract.py +++ b/opt_einsum/tests/test_contract.py @@ -13,6 +13,7 @@ "a,->a", "ab,->ab", ",ab,->ab", + ",,->", # Test hadamard-like products "a,ab,abc->abc", "a,b,ab->ab", diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index c6fef4f5..eb4b18bb 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -58,6 +58,7 @@ ["ab,->ab", 1], [",a,->a", 2], [",,a,->a", 3], + [",,->", 2], ]