Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Adjoints to be ComponentArrays #170

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

nrontsis
Copy link
Contributor

@nrontsis nrontsis commented Oct 23, 2022

This PR aims to avoid issues like:

using ComponentArrays
A = ComponentMatrix(ones(2, 2), Axis(:a, :b), FlatAxis())
A[:b, :] # works
A'[:, :b] # fails!

by wrapping adjoint operations in the underlying data of the ComponentArray structure.

By not having to care about adjoints of ComponentArray, we arguably also reduce overall cognitive load.

This PR aims to avoid issues like:
```
using ComponentArrays
A = ComponentMatrix(ones(2, 2), Axis(:a, :b), FlatAxis())
A[:b, :] # works
A'[:, :b] # fails!
```
by wrapping adjoint operations in the underlying data of the ComponentArray structure.

By not having to care about adjoints of ComponentArray, we arguably also reduce overall cognitive load.
@nrontsis
Copy link
Contributor Author

@jonniedie sorry for the broken tests, these are passing now in my machine.

@@ -46,7 +46,7 @@ end
@test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4))
end
@testset "mul!" begin
A = jlca .* jlca';
Copy link
Contributor Author

@nrontsis nrontsis Oct 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and in many other tests, I think we were testing the wrong behaviour: I think we should have used * instead of .* (like this PR does).

.* is different than * in that it performs broadcasting on both sides and then multiplies elementwise. However, the axis of the (broadcasted) left and right hand side are not identical, so a non-ComponentMatrix must be returned.

Copy link
Contributor

@YichengDWu YichengDWu Nov 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Broadcasting on GPU is a behaviour that we particularly want to test.

The axes of the two sides are not required to be identical in broadcasting (and in general).

Copy link
Contributor Author

@nrontsis nrontsis Nov 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right about this test; I have thus restored the diff here.

However, in other places like in this one, the testing logic was actually incorrect. In this particular example, doing broadcasting should not yield the axes that the test was specifying, as detailed in my comment above.

To correct these, I could either change the operation to be matrix multiplication (instead of broadcasting), or change the testing logic. I chose the former, but happy to change if you think otherwise.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Broadcasting should extrude singleton dimensions to the non-singleton axes of that dimension within the broadcasted operation. We're following the same behavior here that the other array packages follow:

julia> using StaticArrays

julia> a = SA[1, 2, 3]; b = SA[4, 5, 6];

julia> a .* b'
3×3 SMatrix{3, 3, Int64, 9} with indices SOneTo(3)×SOneTo(3):
  4   5   6
  8  10  12
 12  15  18

Copy link
Collaborator

@jonniedie jonniedie Dec 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the case of a ::AbstractVector .* ::Adjoint{_, <:AbstractVector}, it just happens to give the same result as plain matrix multiplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me what the StaticArrays example demonstrates? For example, I might argue that the result of the example with StaticArrays is compatible with the current behaviour of this PR, as the indices of b and a are already compatible, in the sense that:

vcat(b', b', b')

gives a 3×3 SMatrix{3, 3, Int64, 9} with indices SOneTo(3)×SOneTo(3) as a .* b' does.

Also, in my understanding the broadcasting behaviour of master is inconsistent (essentially a special case only for transposes?):

julia> ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (FlatAxis(), Axis(:c, :d)))
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> ComponentVector(a=0.0, b=0.0).*ComponentVector(c=0.0, d=0.0)'
2×2 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2) × Axis(c = 1, d = 2)
 0.0  0.0
 0.0  0.0

In general, I find it hard to imagine a way to implement the broadcasting described in way that will result in a consistent operation both for multiplying 2x1 .* 1x1 matrices and 1x1 .* 1x1 matrices. If we want to continue this chat, it would help me a lot if we could clarify the behaviour of:

ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (FlatAxis(), Axis(:c, :d)))
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (Axis(:e), Axis(:c, :d)))
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), Axis(:f))) .* ComponentMatrix(zeros(1, 2), (Axis(:e), Axis(:c, :d)))
ComponentVector(a=0.0, b=0.0).*ComponentVector(c=0.0, d=0.0)'
ComponentVector(a=0.0).*ComponentVector(c=0.0)'
ComponentMatrix(zeros(1, 1), (Axis(:a), FlatAxis())) .* ComponentMatrix(zeros(1, 1), (FlatAxis(), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), FlatAxis())) .* ComponentMatrix(zeros(1, 1), (Axis(:d), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), Axis(:b))) .* ComponentMatrix(zeros(1, 1), (FlatAxis(), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), Axis(:b))) .* ComponentMatrix(zeros(1, 1), (Axis(:d), Axis(:c)))

as I find the result of master on these inconsistent.

Copy link
Collaborator

@jonniedie jonniedie Dec 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me what the StaticArrays example demonstrates?

The StaticArrays example demonstrates that Broadcast.combine_axes is choosing the non-singleton axes of each dimension from the arrays it's given rather than falling back to the default Base.OneTo axes.

Also, in my understanding the broadcasting behaviour of master is inconsistent (essentially a special case only for transposes?):

There's kinda a complicated story to why we have to special case vector transposes. Before ArrayInterfaces.jl existed, the way DifferentialEquations.jl and some other packages created matrix cache arrays from vectors (say, for a Jacobian cache) was to call u .* u' .* false. Now ArrayInterfaces.jl lets you overload a zeromatrix function for this purpose. We want the cache matrices to in these cases to be ComponentMatrix so users that are interacting with them directly can address blocks by named indices. We could probably at this point switch to using zeromatrix. It may break things for some packages that are still doing u .* u' to expand matrices, though, so I've kinda put off the idea for fear of having to deal with that.

But with that said, I actually want the current behavior we have for the vector and vector transpose case to be the behavior for all cases. Unfortunately its just inherently type unstable to do this in general because it would mean that the size of a dimension (which generally isn't known at compile time) would determine the type of axes we choose. Consider the following examples (I'm using vectors instead of matrices here because it makes the issue clearer and removes the temptation for us to think of this as having anything to do with matrix multiplication):

# Group 1
ComponentVector(a=4, b=2) .* [3]    # Case 1
ComponentVector(a=4, b=2) .* [3, 1] # Case 2

# Group 2
ComponentVector(a=4) .* [3]    # Case 1
ComponentVector(a=4) .* [3, 1] # Case 2

Within each group, the argument types are the same, so for type stability, we need to have the return types the same as well.

In Group 1, Case 1 (I'll just abbreviate it to G1C1), it's clear that we should want use the axis of the ComponentVector, since the regular Vector has a single element and will thus be extruded into the axis of the other array for that dimension. If we want to have this, however, we also need to make it so the ComponentVector axis "wins" for G1C2, since the compiler doesn't know how many elements the plain Vector has.

In G2C1, we're faced with a similar choice and it seems like we should similarly let the ComponentVector "win" here for consistency. But clearly we can't do that because in G2C2, which looks the same to the compiler as G2C1, we will be extruding the single element ComponentVector onto the axes of the plain Vector. So for this group, we have to let both results fall back to giving a plain Vector as an output, despite the fact that we would probably prefer to have the G1C1 be a ComponentVector for consistency.

For inspiration into how it should be handled, I think we should look back at StaticArrays and see how they do it:

julia> SA[4, 2] .* [3] # G1C1
2-element SizedVector{2, Int64, Vector{Int64}} with indices SOneTo(2):
 12
  6

julia> SA[4, 2] .* [3, 1] # G1C2
2-element SizedVector{2, Int64, Vector{Int64}} with indices SOneTo(2):
 12
  2

julia> SA[4] .* [3] # G2C1
1-element Vector{Int64}:
 12

julia> SA[4] .* [3, 1] # G2C2
2-element Vector{Int64}:
 12
  4

Since StaticArrays have knowledge of their size, they make the decision to always use the static array axis when the length of the dimension of the StaticArray is > 1 and always use the other array's axis when it's == 1. This makes sense and probably gives the best tradeoff between consistency and usefulness.

Could we do the same thing for ComponentArrays? Well... sorta. In most usage of ComponentArrays, the named indices correspond 1:1 to the array elements so we can statically know the size of the ComponentArray by looking at the last index of each dimension of the component axes. But there is this undocumented but still very useful use of Axes that don't follow this behavior (see #163). So we can't necessarily guarantee that a dimension of aComponentArray's is length 1 just because it's component axis type information seems to indicate it is. But maybe we could just assume they are and throw an error in cases where this isn't true?

Copy link
Contributor Author

@nrontsis nrontsis Jan 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed explanation.

There's kinda a complicated story to why we have to special case vector transposes ... We could probably at this point switch to using zeromatrix. It may break things for some packages that are still doing u .* u' to expand matrices, though, so I've kinda put off the idea for fear of having to deal with that.

This sounds ideal, as I am all in to eliminating special cases. Assuming that u .* u'(.*false?) is used to expand matrices in some packages (that we don't test against?) sounds fragile to me and, as far as I understand, it already "fails to work" in very popular packages that I happen to use:

julia> using ForwardDiff, ComponentArrays
julia> ForwardDiff.jacobian(x -> x, ComponentVector(a=1.0))
1×1 Matrix{Float64}:
 1.0

The potential behaviour that you described seems complete and well thought (at a first look), but I am of the opinion that unless or until we have a complete and documented behaviour the resulting complexity and confusion will make the simplicity of a uniform rule like "broadcasting results in flataxis" (as per this PR) will be preferable, and that's before taking maintenance overheads into mind.

Nevertheless, I think this chat might be expanding beyond the scope of this PR. If you want to maintain the special case for transposed broadcast, I will include it in this PR. I tried to do that today by adding the following lines in broadcasted.jl:

Broadcast.combine_axes(
    a::ComponentArray, b::ComponentArray{T, 2, <:Union{LinearAlgebra.Adjoint{T, A}, LinearAlgebra.Transpose{T, A}}, Ax}
) where {T, A <: AbstractVector, Ax} = (axes(a)[1], axes(b)[2])
Broadcast.combine_axes(
    a::ComponentArray{T, 2, <:Union{LinearAlgebra.Adjoint{T, A}, LinearAlgebra.Transpose{T, A}}, Ax}, b::ComponentArray
) where {T, A <: AbstractVector, Ax} = (axes(b)[2], axes(a)[1])

to substitute the lines in that file currently deleted by this PR. This works for
ComponentVector(a=0) .* ComponentVector(b=0)'
but not for
false .* ComponentVector(a=0) .* ComponentVector(b=0)'.
I failed to find why it does not work (while it works in master) even after spending a couple of hours on this today, so I would appreciate some help on this, assuming that's the route we are taking.

@YichengDWu
Copy link
Contributor

Why not just overload getindex for the adjoint?

@nrontsis
Copy link
Contributor Author

nrontsis commented Nov 18, 2022

Why not just overload getindex for the adjoint?

  • To reduce the size of this library and the number of concepts that it introduces
  • ComponentArrays either way allows for ComponentMatrix{Float64, Adjoint{Float64, Matrix{Float64}} that can be created like: ComponentMatrix(ones(2, 2)', Axis(:a, :b), FlatAxis()). There were thus two kind of adjoints: ComponentMatrix{Float64, Adjoint{Float64, Matrix{Float64}} and adjoint(::ComponentMatrix{Float64, Matrix{Float64}, resulting in confusion about which one to be used.
  • Besides the getindex issue, broadcasting seems to give incorrect results in the following example, as detailed in this comment:
julia> a = ComponentVector(ones(2), Axis(:a, :b))
julia> a .* a'
2×2 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2) × Axis(a = 1, b = 2)

It's unclear what other edge cases we might be missing.

Copy link
Collaborator

@jonniedie jonniedie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! Sorry it’s taken so long to get to reviewing. The only thing is we should keep all of the dot multiplies in the broadcasting tests because they are testing that broadcasting is working correctly. We already have a section in the math tests that test whether outer and inner products work.

@@ -369,11 +369,11 @@ end
@testset "Broadcasting" begin
temp = deepcopy(ca)
@test eltype(Float32.(ca)) == Float32
@test ca .* ca' == cmat
@test ca * ca' == cmat
@test 1 .* (ca .+ ca) == ComponentArray(a .+ a, getaxes(ca))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don’t want to get rid of tests for broadcasted extrusion. All of these should stay as dot multiplication. We already have tests for linear algebra. These are equivalent mathematically, but not programmatically.

@test getaxes(false .* ca .* ca') == (ax, ax)
@test getaxes(false .* ca' .* ca) == (ax, ax)
@test getaxes(false .* ca * ca') == (ax, ax)
@test isa(ca' * ca, Float64)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to make sure both forms of broadcasting work.

@nrontsis
Copy link
Contributor Author

Thanks a lot for the review @jonniedie!

I am happy to add any tests you feel sensible. However I think some of the old tests were actually incorrect, as I argue in this thread. Can you reply to this thread with what you think on my argument?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants