Suboptimal order when einsum contains non-repeating indices #112
At the moment there is no way to do these non-linear types of contraction, only those that can be broken up into pairwise contractions.
Why doesn't opt_einsum discover the following schedule? It's 1000x faster (colab).
BTW, this einsum comes up when trying to extend https://arxiv.org/abs/1510.01799 to conv2d layers.
Happy to chat about adding this, but at the moment I do lean a bit towards downstream technologies implementing this kind of optimization themselves. I think we can do a decent job in NumPy, but less so with GPU backends.
I was just about to comment pretty much the same as @dgasmith. The only thing I'll add is that in this specific case, it might be worth considering these kinds of trivial axis reductions: i.e. the schedule 'nlq->nl', 'nlp->nl', 'nl,nl->l' would only be ~twice as slow as the 'fast' version (since with the numpy backend at least it still wouldn't know …).
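The reduction schedule discussed here can be written out explicitly; a minimal sketch (the dimension sizes are made-up, just for illustration):

```python
import numpy as np

# hypothetical small dimensions, purely illustrative
A = np.random.rand(6, 5, 4)   # indexed 'nlp'
B = np.random.rand(6, 5, 3)   # indexed 'nlq'

# single-shot contraction touches all four indices at once: O(n*l*p*q)
slow = np.einsum('nlp,nlq->l', A, B)

# trivial axis reductions first, then an elementwise product: O(n*l*(p+q))
a = np.einsum('nlp->nl', A)        # sum out p
b = np.einsum('nlq->nl', B)        # sum out q
fast = np.einsum('nl,nl->l', a, b)

assert np.allclose(slow, fast)
```

The factorization works because p and q each appear on only one input and not the output, so the sums over them distribute out of the product.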
@jcmgray Good point there, this could fall under the "things we should always check for", like the Hadamard product issue that is coming up. This would be cheap. Hey, in 3-4 years we can use the walrus operator here :)
Yes exactly, might be worth compiling a list of such steps in another issue.
Yup, let's split this out into an issue and see how hard it would be to add in a uniform manner.
I'm wondering if this can be handled in a general fashion by trying to minimize the scaling order. Without knowledge of the underlying backend, an O(n^3) schedule should be preferable to an O(n^4) one. Some examples:
A similar problem comes up in the graphical models literature and is typically handled with a two-stage approach (the Junction Tree algorithm). Make a graph with each index corresponding to a vertex, with indices connected if they co-occur in the same factor, then:
Step 1. Triangulate the graph using a greedy heuristic.
Step 2. Extract the maximal cliques of the triangulated graph.
Each clique corresponds to an intermediate term, while a spanning clique-tree gives a reduction schedule. An example is doing this for the partition function of a 3x3 grid Ising model.
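The index-graph construction and greedy triangulation sketched above can be illustrated in a few lines. This is a toy implementation, not anything in opt_einsum; `minfill_order` is a made-up helper name, and it only returns an elimination order (each elimination step would correspond to summing out that index):

```python
from itertools import combinations

def minfill_order(terms):
    """Greedy min-fill elimination order over the index graph.

    Vertices are indices; two indices are adjacent if they co-occur
    in some term. Repeatedly eliminate the vertex whose elimination
    adds the fewest fill-in edges, connecting its neighbours.
    """
    # build the index adjacency graph
    adj = {}
    for term in terms:
        idxs = set(term)
        for i in idxs:
            adj.setdefault(i, set())
        for i, j in combinations(idxs, 2):
            adj[i].add(j)
            adj[j].add(i)

    order = []
    while adj:
        # fill-in cost: neighbour pairs not already connected
        def fill_in(v):
            return sum(1 for a, b in combinations(adj[v], 2)
                       if b not in adj[a])
        v = min(adj, key=fill_in)
        # triangulate: connect v's neighbours, then remove v
        for a, b in combinations(adj[v], 2):
            adj[a].add(b)
            adj[b].add(a)
        for u in adj[v]:
            adj[u].discard(v)
        del adj[v]
        order.append(v)
    return order

# for 'nlp,nlq->l', eliminating p or q first keeps cliques small
order = minfill_order(['nlp', 'nlq'])
```

For the thread's example, p and q have zero fill-in cost, so one of them is eliminated first, which is exactly the cheap schedule discussed earlier.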
In both those cases I think it might just be as simple as performing the individual reductions in any order? E.g. in the first case 'a->', 'b->', ',->' (i.e. scalar multiplication), or in the second case 'abcd->a', 'aefg->a', 'a,a->a', as all the indices fall under the category of 'appear on a single input and not the output'. Maybe there is a more general example?
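The second case can be checked numerically; a quick sketch with made-up shapes:

```python
import numpy as np

x = np.random.rand(3, 2, 2, 2)   # indexed 'abcd'
y = np.random.rand(3, 4, 4, 4)   # indexed 'aefg'

# b, c, d, e, f, g each appear on a single input and not the output,
# so they can all be summed out independently, in any order
direct = np.einsum('abcd,aefg->a', x, y)
stepwise = np.einsum('a,a->a',
                     np.einsum('abcd->a', x),
                     np.einsum('aefg->a', y))

assert np.allclose(direct, stepwise)
```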
A slightly more general example is one where there's a choice over which indices to reduce last.
So I think the 'single index reduction' preprocessing step I'm imagining would handle that fine. If there are more terms, I'm pretty sure the same approach still applies.
Summing out leaf indices may create new leaf indices, so this preprocessing step may need to be repeated to convergence. A more general example is a binary tree, which is doable with O(n^2) scaling but currently gets O(n^3).
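A sketch of that preprocessing pass, repeated until no single-appearance ('leaf') indices remain. This is illustrative only (it assumes a scalar output, and `reduce_single_indices` is a made-up name, not opt_einsum API):

```python
import numpy as np
from collections import Counter

def reduce_single_indices(terms, arrays):
    """Repeatedly sum out indices appearing in exactly one term
    (assuming a scalar output), until none remain."""
    terms = list(terms)
    arrays = list(arrays)
    while True:
        freqs = Counter(ix for t in terms for ix in t)
        leaves = {ix for ix, c in freqs.items() if c == 1}
        if not leaves:
            return terms, arrays
        for i, t in enumerate(terms):
            new_t = ''.join(ix for ix in t if ix not in leaves)
            if new_t != t:
                arrays[i] = np.einsum(t + '->' + new_t, arrays[i])
                terms[i] = new_t

# b, c, d, e, f, g are all leaves and get summed out immediately
terms, arrays = reduce_single_indices(
    ['abcd', 'aefg'],
    [np.ones((2, 2, 2, 2)), np.ones((2, 3, 3, 3))])
```

The remaining terms can then be handed to the usual pairwise path optimizer.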
Ah yes but again (and not trying to be contrary, it's good to think about these edge cases!) if you just perform the single axis reductions first you get the n^2 scaling again:

```python
import opt_einsum as oe
import numpy as np
from collections import Counter
from itertools import chain

def binary_tree_einsum(depth):
    edges = []

    def tc(num):
        # map a node number to an index character, skipping einsum's
        # special characters (' ', ',', '-', '>', '.')
        return chr(num + 100000) if chr(num + 100) in ' ,->.' else chr(num + 100)

    def rec(parent, child, depth):
        edges.append(tc(parent) + tc(child))
        if depth > 0:
            rec(child, 2 * child, depth - 1)
            rec(child, 2 * child + 1, depth - 1)

    rec(0, 1, depth)
    views = [np.ones((2, 2))] * len(edges)

    # explicitly reduce indices that appear on only one term
    freqs = Counter(chain(*edges))
    new_terms = []
    new_views = []
    for term, view in zip(edges, views):
        new_term = "".join(ix for ix in term if freqs[ix] != 1)
        new_terms.append(new_term)
        new_views.append(np.einsum(term + '->' + new_term, view))

    eq = ','.join(new_terms) + '->'
    print(oe.contract_path(eq, *new_views))

binary_tree_einsum(3)
# Optimized scaling: 2
```

I'm pretty certain this is optimal (sorry the figure may not be totally clear: on the left the dashed lines are the contraction order, on the right the graph is deformed into the tree, with the light grey nodes the intermediates). The problem at the moment is that the path finders assume that all edges appear more than once, so they don't 'annihilate' any leaf indices at the beginning, which they should. But after that, if ever the final two indices meet they are indeed contracted, so there will never be a leftover singleton.
Ah nice! That solution seems to work. PS: I was curious to check whether the optimizer performs well for graphs with bounded treewidth but unbounded pathwidth, but being limited to ASCII makes it a bit hard to generate such graphs programmatically.
Yes, I'm not really sure what that might look like, to be honest! There is a general result by Markov & Shi linking the asymptotic scaling to the treewidth of the line graph, and whilst that is an optimal result, practically speaking when this is bounded the graphs are also easier with heuristic methods. I might mention that there are actually several different optimisers, with the default chosen automatically based on the size of the problem.
BTW, duplicating each factor seemed like an easier workaround; however, it doesn't always fix the scaling problem. It recovers n^2 scaling for a tree with 8 elements, but still gives me n^3 scaling for a tree with 16 elements. Is this an issue of the suboptimal greedy optimizer kicking in?
Tree-like structures should be easy to discover: greedy triangulation with the min-fill heuristic should just work. If I add even more edges, it drops back to O(n^2) scaling ❓
BTW, results like Markov & Shi's seem to come up in many places. Small treewidth is the most basic condition to guarantee fast computation. A more general condition is for the problem to reduce to a computation on a "minor-excluded" class of graphs (Ch. 17 of Grohe's "Descriptive Complexity, Canonisation, and Definable Graph Structure Theory"). For instance, counting perfect matchings is an einsum which is computable in polynomial time for planar graphs using the FKT algorithm.

Personally I'm interested in ways of computing large einsums approximately, since things are already inexact due to measurement noise and floating-point round-off. For minor-excluded classes with bounded degree there's a polynomial-time approximation algorithm by Jung & Shah. A related, simpler heuristic is the Generalized Distributive Law: basically you reformulate the einsum in terms of equations which give the exact result after k updates when the factor graph is a tree. When it is not a tree, you update n*k times, and get a good result for small n when edge interactions are not "too strong". This lets you deal with problems that have high treewidth either due to the original einsum structure or due to large factors; the latter case can be handled by approximating large factors as products of smaller factors.

In statistical physics, the Generalized Distributive Law comes up in approximating the Ising free energy. When factors are restricted to pairwise edge potentials, this algorithm gives what is known as the Bethe-Peierls approximation; when nearby vertices are merged into larger factors, it gives the higher-quality Kikuchi approximation.
Yes, if you try … Interestingly, the essentially optimal … To be honest, I am really not that familiar with the wider graph theory literature and classical applications of what I might call hyper tensor networks. In many-body quantum physics, approximate contraction is always based on how much 'entanglement' is in the network, essentially whether the tensors are approximately low-rank across certain partitions. And for quantum circuit simulation, essentially nothing is low-rank, so the contractions are performed exactly. In that exact case, the two following statements seem to be most practically relevant:
Anyway, I dunno if you are planning on working on any of the things you mention, but I'd certainly be interested to see any results, and especially whether there are other approximation schemes that are relevant to classical simulation of quantum systems.
Aha, optimize='dp' seems appropriate here. I've been using Carl Woll's "TensorSimplify" package to come up with tractable formulas for various neural-network-related quantities, but opt_einsum seems like a more flexible tool. Note that using optimize='dp' fixes the order for the binary tree, but it's still suboptimal for 'abcd,aefg->a'. However, adding an "all ones" 1-d factor for every dimension seems to recover good scaling in all the instances I found to be suboptimal.
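The all-ones workaround can be sketched as follows. This is a toy helper (`pad_with_ones` is a made-up name); whether it improves the discovered path for a given optimizer is as reported in the thread, not guaranteed:

```python
import numpy as np

def pad_with_ones(eq, *arrays):
    """Append an all-ones 1-d factor for every index, so that no
    index appears in only a single term of the contraction."""
    lhs, rhs = eq.split('->')
    terms = lhs.split(',')
    sizes = {}
    for term, arr in zip(terms, arrays):
        for ix, d in zip(term, arr.shape):
            sizes[ix] = d
    extra_terms = list(sizes)
    extra_ops = [np.ones(sizes[ix]) for ix in extra_terms]
    return ','.join(terms + extra_terms) + '->' + rhs, list(arrays) + extra_ops

A = np.random.rand(4, 5, 6)
B = np.random.rand(4, 5, 7)
eq, ops = pad_with_ones('nlp,nlq->l', A, B)

# multiplying by ones leaves the result unchanged
assert np.allclose(np.einsum(eq, *ops), np.einsum('nlp,nlq->l', A, B))
```

The padded equation and operands can then be passed to `opt_einsum.contract_path(eq, *ops, optimize='dp')` as described in the comment above.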
Can we think of edge cases here that #114 will not fix the scaling for?
It seems to work for all my examples... but I'm curious why preprocessing is even needed when using an exact algorithm; shouldn't optimize='dp' handle these cases automatically?
BTW, it is now 2022 |
Is there a way to get an optimized path for an expression with some tensors repeated?
For instance, einsum('nlp,nlq->l', B, B) can be done in O(n^3) time, but opt_einsum gives a schedule that takes O(n^4) time.
Faster way to do this: …