Change Adjoints to be ComponentArrays #170
Conversation
This PR aims to avoid issues like:

```julia
using ComponentArrays

A = ComponentMatrix(ones(2, 2), Axis(:a, :b), FlatAxis())
A[:b, :]   # works
A'[:, :b]  # fails!
```

by wrapping adjoint operations in the underlying data of the ComponentArray structure. By not having to care about adjoints of ComponentArray, we arguably also reduce overall cognitive load.
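A minimal sketch of the idea (hypothetical code, not the exact implementation in this PR; `my_adjoint` and its axis handling are illustrative assumptions):

```julia
using ComponentArrays, LinearAlgebra

# Sketch: put the Adjoint in the *data* of a ComponentArray instead of
# wrapping the ComponentArray in an Adjoint, so named indexing keeps working.
my_adjoint(x::ComponentVector) =
    ComponentArray(adjoint(getdata(x)), FlatAxis(), getaxes(x)[1])

x = ComponentVector(a=1.0, b=2.0)
xt = my_adjoint(x)
xt[:, :b]  # named indexing works on the transposed axis
```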
@jonniedie sorry for the broken tests; these are passing now on my machine.
test/gpu_tests.jl
```diff
@@ -46,7 +46,7 @@ end
     @test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4))
 end
 @testset "mul!" begin
-    A = jlca .* jlca';
```
Here and in many other tests, I think we were testing the wrong behaviour: I think we should have used `*` instead of `.*` (like this PR does). `.*` is different from `*` in that it broadcasts both sides and then multiplies elementwise. However, the axes of the (broadcasted) left- and right-hand sides are not identical, so a non-ComponentMatrix must be returned.
Broadcasting on GPU is a behaviour that we particularly want to test.
The axes of the two sides are not required to be identical in broadcasting (and in general).
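A plain-`Base` illustration of the point:

```julia
julia> [1, 2] .* [3 4 5]  # sizes (2,) and (1, 3) broadcast to (2, 3)
2×3 Matrix{Int64}:
 3  4   5
 6  8  10
```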
You are right about this test; I have thus restored the diff here.
However, in other places, like in this one, the testing logic was actually incorrect: in this particular example, broadcasting should not yield the axes that the test was specifying, as detailed in my comment above.
To correct these, I could either change the operation to matrix multiplication (instead of broadcasting) or change the testing logic. I chose the former, but I am happy to change it if you think otherwise.
Broadcasting should extrude singleton dimensions to the non-singleton axes of that dimension within the broadcasted operation. We're following the same behavior here that the other array packages follow:
```julia
julia> using StaticArrays

julia> a = SA[1, 2, 3]; b = SA[4, 5, 6];

julia> a .* b'
3×3 SMatrix{3, 3, Int64, 9} with indices SOneTo(3)×SOneTo(3):
  4   5   6
  8  10  12
 12  15  18
```
For the case of a `::AbstractVector .* ::Adjoint{_, <:AbstractVector}`, it just happens to give the same result as plain matrix multiplication.
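With plain `Vector`s, for instance:

```julia
julia> a = [1.0, 2.0]; b = [3.0, 4.0];

julia> a .* b' == a * b'  # broadcasting against the adjoint coincides with the outer product
true
```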
It's unclear to me what the `StaticArrays` example demonstrates. For example, I might argue that the result of the example with `StaticArrays` is compatible with the current behaviour of this PR, as the indices of `b` and `a` are already compatible, in the sense that `vcat(b', b', b')` gives a `3×3 SMatrix{3, 3, Int64, 9} with indices SOneTo(3)×SOneTo(3)`, just as `a .* b'` does.
Also, in my understanding, the broadcasting behaviour of `master` is inconsistent (essentially a special case only for transposes?):

```julia
julia> ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (FlatAxis(), Axis(:c, :d)))
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> ComponentVector(a=0.0, b=0.0) .* ComponentVector(c=0.0, d=0.0)'
2×2 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2) × Axis(c = 1, d = 2)
 0.0  0.0
 0.0  0.0
```
In general, I find it hard to imagine a way to implement the broadcasting described that results in a consistent operation both for multiplying `2x1 .* 1x2` matrices and `1x1 .* 1x1` matrices. If we want to continue this chat, it would help me a lot if we could clarify the behaviour of:

```julia
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (FlatAxis(), Axis(:c, :d)))
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (Axis(:e), Axis(:c, :d)))
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), Axis(:f))) .* ComponentMatrix(zeros(1, 2), (Axis(:e), Axis(:c, :d)))
ComponentVector(a=0.0, b=0.0) .* ComponentVector(c=0.0, d=0.0)'
ComponentVector(a=0.0) .* ComponentVector(c=0.0)'
ComponentMatrix(zeros(1, 1), (Axis(:a), FlatAxis())) .* ComponentMatrix(zeros(1, 1), (FlatAxis(), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), FlatAxis())) .* ComponentMatrix(zeros(1, 1), (Axis(:d), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), Axis(:b))) .* ComponentMatrix(zeros(1, 1), (FlatAxis(), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), Axis(:b))) .* ComponentMatrix(zeros(1, 1), (Axis(:d), Axis(:c)))
```

as I find the results of `master` on these inconsistent.
> It's unclear to me what the `StaticArrays` example demonstrates?

The `StaticArrays` example demonstrates that `Broadcast.combine_axes` is choosing the non-singleton axes of each dimension from the arrays it's given, rather than falling back to the default `Base.OneTo` axes.
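The same choice can be observed with plain arrays:

```julia
julia> a = [1, 2, 3];

julia> Broadcast.combine_axes(a, a')  # the singleton row axis defers to the length-3 axis
(Base.OneTo(3), Base.OneTo(3))
```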
> Also, in my understanding the broadcasting behaviour of master is inconsistent (essentially a special case only for transposes?):
There's kind of a complicated story behind why we have to special-case vector transposes. Before `ArrayInterface.jl` existed, the way `DifferentialEquations.jl` and some other packages created matrix cache arrays from vectors (say, for a Jacobian cache) was to call `u .* u' .* false`. Now `ArrayInterface.jl` lets you overload a `zeromatrix` function for this purpose. We want the cache matrices in these cases to be `ComponentMatrix` so users that are interacting with them directly can address blocks by named indices. We could probably switch to using `zeromatrix` at this point. It may break things for some packages that are still doing `u .* u'` to expand matrices, though, so I've kind of put off the idea for fear of having to deal with that.
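For reference, the `zeromatrix` route would look roughly like this (a sketch assuming `ArrayInterface.jl`'s `zeromatrix` function, not code from this PR):

```julia
using ArrayInterface

u = [1.0, 2.0, 3.0]
J = ArrayInterface.zeromatrix(u)  # 3×3 zero matrix, usable as a Jacobian cache
```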
But with that said, I actually want the current behavior we have for the vector and vector-transpose case to be the behavior for all cases. Unfortunately, it's just inherently type-unstable to do this in general, because it would mean that the size of a dimension (which generally isn't known at compile time) would determine the type of axes we choose. Consider the following examples (I'm using vectors instead of matrices here because it makes the issue clearer and removes the temptation to think of this as having anything to do with matrix multiplication):
```julia
# Group 1
ComponentVector(a=4, b=2) .* [3]     # Case 1
ComponentVector(a=4, b=2) .* [3, 1]  # Case 2

# Group 2
ComponentVector(a=4) .* [3]     # Case 1
ComponentVector(a=4) .* [3, 1]  # Case 2
```
Within each group, the argument types are the same, so for type stability, we need the return types to be the same as well.

In Group 1, Case 1 (I'll abbreviate it to G1C1), it's clear that we should want to use the axis of the `ComponentVector`, since the regular `Vector` has a single element and will thus be extruded into the axis of the other array for that dimension. If we want this, however, we also need to make the `ComponentVector` axis "win" for G1C2, since the compiler doesn't know how many elements the plain `Vector` has.

In G2C1, we're faced with a similar choice, and it seems like we should similarly let the `ComponentVector` "win" here for consistency. But clearly we can't do that, because in G2C2, which looks the same to the compiler as G2C1, we will be extruding the single-element `ComponentVector` onto the axes of the plain `Vector`. So for this group, we have to let both results fall back to giving a plain `Vector` as an output, despite the fact that we would probably prefer G2C1 to be a `ComponentVector` for consistency.
For inspiration into how it should be handled, I think we should look back at `StaticArrays` and see how they do it:
```julia
julia> SA[4, 2] .* [3] # G1C1
2-element SizedVector{2, Int64, Vector{Int64}} with indices SOneTo(2):
 12
  6

julia> SA[4, 2] .* [3, 1] # G1C2
2-element SizedVector{2, Int64, Vector{Int64}} with indices SOneTo(2):
 12
  2

julia> SA[4] .* [3] # G2C1
1-element Vector{Int64}:
 12

julia> SA[4] .* [3, 1] # G2C2
2-element Vector{Int64}:
 12
  4
```
Since `StaticArray`s have knowledge of their size, they make the decision to always use the static array's axis when the length of that dimension of the `StaticArray` is `> 1`, and to always use the other array's axis when it's `== 1`. This makes sense and probably gives the best tradeoff between consistency and usefulness.

Could we do the same thing for `ComponentArrays`? Well... sorta. In most usage of `ComponentArrays`, the named indices correspond 1:1 to the array elements, so we can statically know the size of the `ComponentArray` by looking at the last index of each dimension of the component axes. But there is an undocumented but still very useful use of `Axis`es that don't follow this behavior (see #163). So we can't necessarily guarantee that a dimension of a `ComponentArray` is length 1 just because its component axis type information seems to indicate it is. But maybe we could just assume they are and throw an error in cases where this isn't true?
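A rough sketch of that last idea, with hypothetical helper logic (nothing below is existing ComponentArrays API):

```julia
# Hypothetical rule: trust the statically-claimed axis length, let the
# component axis "win" when it claims length > 1, and throw if the runtime
# size contradicts the claim.
function choose_axis(ca_axis, ca_static_len::Int, ca_runtime_len::Int, other_axis)
    ca_static_len == ca_runtime_len ||
        error("component axis claims length $ca_static_len but data has length $ca_runtime_len")
    return ca_static_len > 1 ? ca_axis : other_axis
end
```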
Thanks for the detailed explanation.
> There's kinda a complicated story to why we have to special case vector transposes ... We could probably at this point switch to using zeromatrix. It may break things for some packages that are still doing u .* u' to expand matrices, though, so I've kinda put off the idea for fear of having to deal with that.
This sounds ideal, as I am all in for eliminating special cases. Assuming that `u .* u'` (`.* false`?) is used to expand matrices in some packages (that we don't test against?) sounds fragile to me, and, as far as I understand, it already "fails to work" in very popular packages that I happen to use:
```julia
julia> using ForwardDiff, ComponentArrays

julia> ForwardDiff.jacobian(x -> x, ComponentVector(a=1.0))
1×1 Matrix{Float64}:
 1.0
```
The potential behaviour that you described seems complete and well thought out (at first look), but I am of the opinion that, unless or until we have a complete and documented behaviour, the resulting complexity and confusion will make the simplicity of a uniform rule like "broadcasting results in FlatAxis" (as per this PR) preferable, and that's before taking maintenance overheads into account.
Nevertheless, I think this chat might be expanding beyond the scope of this PR. If you want to keep the special case for transposed broadcasts, I will include it in this PR. I tried to do that today by adding the following lines in `broadcasted.jl`:
```julia
Broadcast.combine_axes(
    a::ComponentArray, b::ComponentArray{T, 2, <:Union{LinearAlgebra.Adjoint{T, A}, LinearAlgebra.Transpose{T, A}}, Ax}
) where {T, A <: AbstractVector, Ax} = (axes(a)[1], axes(b)[2])

Broadcast.combine_axes(
    a::ComponentArray{T, 2, <:Union{LinearAlgebra.Adjoint{T, A}, LinearAlgebra.Transpose{T, A}}, Ax}, b::ComponentArray
) where {T, A <: AbstractVector, Ax} = (axes(b)[2], axes(a)[1])
```
to substitute for the lines in that file currently deleted by this PR. This works for `ComponentVector(a=0) .* ComponentVector(b=0)'`, but not for `false .* ComponentVector(a=0) .* ComponentVector(b=0)'`.
I failed to find why it does not work (while it works on `master`), even after spending a couple of hours on this today, so I would appreciate some help on this, assuming that's the route we are taking.
Why not just overload …?

```julia
julia> a = ComponentVector(ones(2), Axis(:a, :b))

julia> a .* a'
2×2 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2) × Axis(a = 1, b = 2)
 1.0  1.0
 1.0  1.0
```

It's unclear what other edge cases we might be missing.
Thanks for this! Sorry it's taken so long to get to reviewing. The only thing is we should keep all of the dot multiplies in the broadcasting tests, because they are testing that broadcasting is working correctly. We already have a section in the math tests that tests whether outer and inner products work.
```diff
@@ -369,11 +369,11 @@ end
 @testset "Broadcasting" begin
     temp = deepcopy(ca)
     @test eltype(Float32.(ca)) == Float32
-    @test ca .* ca' == cmat
+    @test ca * ca' == cmat
     @test 1 .* (ca .+ ca) == ComponentArray(a .+ a, getaxes(ca))
```
We don’t want to get rid of tests for broadcasted extrusion. All of these should stay as dot multiplication. We already have tests for linear algebra. These are equivalent mathematically, but not programmatically.
```julia
@test getaxes(false .* ca .* ca') == (ax, ax)
@test getaxes(false .* ca' .* ca) == (ax, ax)
@test getaxes(false .* ca * ca') == (ax, ax)
@test isa(ca' * ca, Float64)
```
We want to make sure both forms of broadcasting work.
Thanks a lot for the review @jonniedie! I am happy to add any tests you feel are sensible. However, I think some of the old tests were actually incorrect, as I argue in this thread. Can you reply to that thread with what you think of my argument?