-
Here's some more background and supplementary information:
Moving construction-time shape inference to compile-time is necessary to guarantee that we'll always have the most shape information possible. It also serves to reduce duplication, because shape inference logic that applies at construction-time necessarily applies at compile-time, but the latter is capable of handling much more than the former. See #1275 for one such change in this direction.

Now, consider the following example:

```python
import aesara
import aesara.tensor as at

s = at.lscalar("s")
x = at.ones((s,))

x.type
# TensorType(float64, (None,))
```

In this case, the shape information for `x` is incomplete: the graph knows the shape is exactly `(s,)`, but the type only records `(None,)`. How could this be handled? To start, we know that we ultimately need to add a constraint on the value of `s`. The point we're talking about is at the […]. Anyway, all of this work needs to happen at compile-time, and that's another reason to focus on moving all shape inference to a single, coherent set of compile-time operations. For instance, at compile-time we're able to "parse" expressions like […].
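To make that concrete (a quick check, continuing the snippet above), the shape graph still carries `s` even though the type dropped it:

```python
# Compiling the shape graph recovers the exact symbolic shape `s`
f = aesara.function([s], x.shape)
f(3)  # array([3])
```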
-
Another option, instead of […]:

```python
import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse

x = at.vector("x")
y = at.vector("y")
y = at.specify_shape(y, x.shape)

r = ifelse(
    at.eq(x.shape[0], y.shape[0]),
    at.ones_like(y),
    at.ones_like(y).sum(keepdims=True),
)
f = aesara.function([x, y], r)
aesara.dprint(f)
```

This currently does not get rid of the `IfElse`, and actually removes the `SpecifyShape`:

```python
f([1], [1])  # array([1.])
f([1, 1, 1, 1], [1, 1, 1, 1])  # array([1., 1., 1., 1.])
f([1], [1, 1, 1, 1])  # array([4.]), but should have raised!
```
-
It seems it might be?

```python
import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse

x = at.vector("x")
y = at.vector("y")

r = ifelse(
    at.eq(x.shape[0], y.shape[0]),
    at.ones_like(y),
    at.ones_like(y).sum(keepdims=True),
)
f = aesara.function([x, y], r, mode="JAX")

f([1], [1, 1, 1, 1])
# TypeError: true_fun and false_fun output must have identical types got
# ('DIFFERENT ShapedArray(float64[4]) vs. ShapedArray(float64[1])',).
```

It also fails without JIT.
-
In the following I will try to summarize the situation with shapes & broadcasting in Aesara and lay it out in simple terms. None of this is news; most of the content of this post has been mentioned here and there in comments. It just hasn't been spelled out completely in one place.
Shapes in Aesara today
Users can currently initialize tensors without providing any information about the shape. Aesara's shape system currently represents the lack of information about the shape with `None`:
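A minimal sketch of the current behavior:

```python
import aesara.tensor as at

# No shape information is provided, so every dimension is None
x = at.vector("x")
x.type
# TensorType(float64, (None,))
```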
If the user has more information about the shape of a tensor, they can specify it in different ways: either using the built-in type, initializing the tensor with a `shape`, or using `specify_shape`:
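For instance (a sketch; constructor details may differ between Aesara versions):

```python
import aesara.tensor as at

# Via the type, with a fully specified shape
x = at.TensorType("float64", shape=(3,))("x")

# Or by adding the information to an existing variable
y = at.specify_shape(at.vector("y"), (3,))
```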
If they only have partial information, they can also specify it:
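Again as a sketch:

```python
import aesara.tensor as at

# Only the second dimension is known
x = at.TensorType("float64", shape=(None, 3))("x")
x.type
# TensorType(float64, (None, 3))
```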
This shape information is currently propagated through the graph using `TensorType.shape`.

Adding extra information: shape different from 1
There are situations in which this type-level information is not sufficient, for instance in `Elemwise.grad`, where the `Op` needs to know whether broadcasting occurred between inputs to know whether to sum or not. We thus need to encode more information in our type system, namely the knowledge that the shape is different from 1.

Let's assume here that we represent this information with `-1`: if the shape of a dimension is specified as `-1`, this means that we know it is different from `1`. It is not as specific as a full shape specification, but carries more information than `None`.

Users can specify this information when creating a tensor:
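Hypothetically, this could look as follows (`-1` is a proposal and not accepted by `TensorType` today, so this is illustration only):

```python
import aesara.tensor as at

# Hypothetical: the single dimension is known to be different from 1
x = at.TensorType("float64", shape=(-1,))("x")
```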
Again, this information can be propagated through the graph the same way `None` is. There are still places where we would need to add the correct logic to propagate this constraint, but these are `Elemwise` and the other `Op`s we are trying to fix by doing this.

To summarize, the user will be able to specify the shape at the type level using, from the greatest to the least amount of information:

- an integer, which gives the exact shape of the dimension;
- `-1`, which means that we only know the shape is not 1;
- `None`, which means that we have no information about the shape.

That is all the information Aesara needs about shape at compile-time.
What do we do when this information is not available?
As stated above, there are situations in which the behavior of an `Op` differs depending on whether `shape == 1` or `shape != 1`. What to do when this information is not available in the graph, i.e. when there is ambiguity, is where potential tradeoffs appear. Let us go through these tradeoffs methodically.

There are three moments at which we can decide to handle the lack of information: construction-time, compile-time, and runtime. In the following we will only consider compile-time and runtime; handling shape inference at construction-time is not necessary, and it implies a lot of duplicated logic with compile-time inference for little added value, if any.
Handle the issue at compile-time
It is impossible to resolve the ambiguity at compile-time; if the information is not there, it is simply not there. We thus only have two ways forward: assuming or failing.
First, assuming. Remember that Aesara currently represents unspecified dimensions with `None`, e.g. when `at.vector` is used. This is just Aesara (correctly) stating that we have no shape information about the tensors that are created this way. We could change this behavior and assume on behalf of the user that unspecified dimensions default to `-1`; this will obviously lead to surprising behavior when the user inputs e.g. `np.ones((1,))`, and it is error-prone. We can always add a runtime assertion, but that does not make the behavior less surprising, and it adds a (small) runtime cost.

Then, failing. Here Aesara does not make any assumption about the user's intents, but instead asks the user to resolve any ambiguity. That means failing at compile-time, explaining the ambiguity, and asking the user to provide more information. This can however quickly become frustrating for users, make existing code fail, and completely forbid cases where they want to pass tensors where one dimension is sometimes equal to 1 and sometimes different from 1.
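To make the "assuming" tradeoff concrete, compare with NumPy's broadcasting semantics (plain NumPy, for illustration):

```python
import numpy as np

x = np.ones((1,))
y = np.ones((4,))
(x + y).shape  # (4,): NumPy broadcasts the length-1 dimension
# If Aesara had assumed shape != 1 for x's dimension, the compiled graph
# would not perform this broadcast, silently diverging from NumPy here.
```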
In either case, trying to handle the lack of information at compile-time involves important tradeoffs with the user interface.
Handle the issue at runtime
The ambiguity is necessarily resolved at runtime: users call compiled functions with arrays that have a concrete shape. We could thus defer handling insufficient in-graph information to runtime. There are two ways to do this.
The first possibility to consider is a custom `Op`, let's call it `SumDims`, which is called with booleans that indicate whether a dimension should be summed or not. This `Op` basically encapsulates a multi-branch conditional. However, none of this structure is explicit in the graph, and we would need to implement rewrites specific to this `Op`. This is a fundamental limitation of every custom `Op` approach.

Or we can make this branching structure explicit and use the existing `IfElse` directly: whenever there is not enough information to determine `shape != 1` at compile-time, we enumerate the distinct conditions and their graphs so that each branch represents a distinct `x.sum`. The advantage of this approach over the custom `Op` is that we can leverage existing rewrites; if new rewrites need to be implemented, they will be more widely applicable in other scenarios.
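As a rough sketch (hypothetical variable names; the actual graphs would be generated internally, e.g. by `Elemwise.grad`), the branching could look like:

```python
import aesara.tensor as at
from aesara.ifelse import ifelse

x = at.vector("x")   # dimension unknown: could be 1 or not
gz = at.vector("gz") # incoming gradient, shaped like the broadcast output

# If x's dimension is 1 it was broadcast, so the gradient contributions
# must be summed back to shape (1,); otherwise the gradient passes through.
gx = ifelse(
    at.eq(x.shape[0], 1),
    gz.sum(keepdims=True),
    gz,
)
```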
There are still legitimate concerns with the `IfElse` approach[^1]: performance and graph complexity. The first thing to note is that the `IfElse` can be removed at compile-time, when shape inference is performed, if enough information was provided by the user. Whenever we cannot resolve the shapes at compile-time, however, branching logic is the price to pay for the information that was not provided by the user: as far as the Aesara compiler knows, either scenario can happen. The runtime cost would be small, and a user interested in small performance gains can always avoid the cost of this branching logic by giving more information about the shapes of their tensors.

What is not in this write-up
I left out some of the concerns given in the other issues because they are not related to the representation in the IR and should be addressed downstream:

- […] the `IfElse` solution.

I also purposefully did not use the expression "dynamic broadcasting", which is not specific enough to unambiguously identify a set of solutions.
Footnotes
[^1]: These concerns also apply to the custom `Op` approach.