Change Adjoints to be ComponentArrays #170

Open · wants to merge 3 commits into base: main

Changes from 2 commits
1 change: 0 additions & 1 deletion src/array_interface.jl
Original file line number Diff line number Diff line change
@@ -59,7 +59,6 @@ function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat)
return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...)
end
end
Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1]))
Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...)
Base.vcat(x::ComponentVector, args::Union{Number, UniformScaling, AbstractVecOrMat}...) = vcat(getdata(x), getdata.(args)...)
Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...)
4 changes: 0 additions & 4 deletions src/broadcasting.jl
@@ -1,9 +1,5 @@
Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = Broadcast.BroadcastStyle(A)

# Need special case here for adjoint vectors in order to avoid type instability in axistype
Broadcast.combine_axes(a::ComponentArray, b::AdjOrTransComponentVector) = (axes(a)[1], axes(b)[2])
Broadcast.combine_axes(a::AdjOrTransComponentVector, b::ComponentArray) = (axes(b)[2], axes(a)[1])

Broadcast.axistype(a::CombinedAxis, b::AbstractUnitRange) = a
Broadcast.axistype(a::AbstractUnitRange, b::CombinedAxis) = b
Broadcast.axistype(a::CombinedAxis, b::CombinedAxis) = CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b)))
203 changes: 17 additions & 186 deletions src/compat/gpuarrays.jl
@@ -1,6 +1,7 @@
const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax}
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
const AbstractGPUArrayOrAdj = Union{<:GPUArrays.AbstractGPUArray{T, N}, Adjoint{T, <:GPUArrays.AbstractGPUArray{T, N}}, Transpose{T, <:GPUArrays.AbstractGPUArray{T, N}}} where {T, N}
const GPUComponentArray = ComponentArray{T,N,<:AbstractGPUArrayOrAdj{T, N},Ax} where {T,N,Ax}
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:AbstractGPUArrayOrAdj{T, 1},Ax}
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:AbstractGPUArrayOrAdj{T, 2},Ax}
const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}}

GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))
@@ -25,7 +26,10 @@ end

LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y))
LinearAlgebra.norm(ca::GPUComponentArray, p::Real) = norm(getdata(ca), p)
LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) = GPUArrays.generic_rmul!(ca, b)
function LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number)
GPUArrays.generic_rmul!(getdata(ca), b)
return ca
end

function Base.map(f, x::GPUComponentArray, args...)
data = map(f, getdata(x), getdata.(args)...)
@@ -78,196 +82,23 @@ end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
return GPUArrays.generic_matmatmul!(C, getdata(A), getdata(B), a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
A::AbstractGPUArrayOrAdj,
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, B::GPUComponentVecorMat,
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
return GPUArrays.generic_matmatmul!(C, A, getdata(B), a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, B::GPUComponentVecorMat,
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
B::AbstractGPUArrayOrAdj, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, getdata(A), B, a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
A::AbstractGPUArrayOrAdj,
B::AbstractGPUArrayOrAdj, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
end
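The interleaved diff above is hard to read in this view. Reconstructed from it, the net effect is that the combinatorial explosion of `mul!` overloads (one per placement of an `Adjoint`/`Transpose` wrapper, duplicated for `Number` and `Real` scalars) collapses onto the single `AbstractGPUArrayOrAdj` union. A sketch of the resulting method set, with `getdata` unwrapping the `ComponentArray` arguments before dispatching to the generic GPU kernel (treat this as a reconstruction of the mangled diff, not the verbatim final file):

```julia
# Reconstruction of the surviving mul! methods in src/compat/gpuarrays.jl.
# AbstractGPUArrayOrAdj covers a plain GPU array plus its lazy wrappers:
const AbstractGPUArrayOrAdj = Union{GPUArrays.AbstractGPUArray{T, N},
                                    Adjoint{T, <:GPUArrays.AbstractGPUArray{T, N}},
                                    Transpose{T, <:GPUArrays.AbstractGPUArray{T, N}}} where {T, N}

# Four methods now cover every ComponentArray/wrapper combination:
function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::GPUComponentVecorMat,
                            B::GPUComponentVecorMat, a::Number, b::Number)
    return GPUArrays.generic_matmatmul!(C, getdata(A), getdata(B), a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::AbstractGPUArrayOrAdj,
                            B::GPUComponentVecorMat, a::Number, b::Number)
    return GPUArrays.generic_matmatmul!(C, A, getdata(B), a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::GPUComponentVecorMat,
                            B::AbstractGPUArrayOrAdj, a::Number, b::Number)
    return GPUArrays.generic_matmatmul!(C, getdata(A), B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::AbstractGPUArrayOrAdj,
                            B::AbstractGPUArrayOrAdj, a::Number, b::Number)
    return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
```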
18 changes: 4 additions & 14 deletions src/componentarray.jl
@@ -118,17 +118,11 @@ const CArray = ComponentArray
const CVector = ComponentVector
const CMatrix = ComponentMatrix

const AdjOrTrans{T, A} = Union{Adjoint{T, A}, Transpose{T, A}}
const AdjOrTransComponentArray{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentArray
const AdjOrTransComponentVector{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentVector
const AdjOrTransComponentMatrix{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentMatrix

const ComponentVecOrMat = Union{ComponentVector, ComponentMatrix}
const AdjOrTransComponentVecOrMat = AdjOrTrans{T, <:ComponentVecOrMat} where T
const AbstractComponentArray = Union{ComponentArray, AdjOrTransComponentArray}
const AbstractComponentVecOrMat = Union{ComponentVecOrMat, AdjOrTransComponentVecOrMat}
const AbstractComponentVector = Union{ComponentVector, AdjOrTransComponentVector}
const AbstractComponentMatrix = Union{ComponentMatrix, AdjOrTransComponentMatrix}
const AbstractComponentArray = ComponentArray
const AbstractComponentVecOrMat = ComponentVecOrMat
const AbstractComponentVector = ComponentVector
const AbstractComponentMatrix = ComponentMatrix


## Constructor helpers
@@ -288,12 +282,8 @@ julia> getaxes(ca)
```
"""
@inline getaxes(x::ComponentArray) = getfield(x, :axes)
@inline getaxes(x::AdjOrTrans{T, <:ComponentVector}) where T = (FlatAxis(), getaxes(x.parent)[1])
@inline getaxes(x::AdjOrTrans{T, <:ComponentMatrix}) where T = reverse(getaxes(x.parent))

@inline getaxes(::Type{<:ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = map(x->x(), (Axes.types...,))
@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof
@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentMatrix} = reverse(getaxes(CA)) |> typeof

## Field access through these functions to reserve dot-getting for keys
@inline getaxes(x::VarAxes) = getaxes(typeof(x))
25 changes: 13 additions & 12 deletions src/linear_algebra.jl
@@ -13,7 +13,6 @@ _first_axis(x::AbstractComponentVecOrMat) = getaxes(x)[1]

_second_axis(x::AbstractMatrix) = FlatAxis()
_second_axis(x::ComponentMatrix) = getaxes(x)[2]
_second_axis(x::AdjOrTransComponentVecOrMat) = getaxes(x)[2]

_out_axes(::typeof(*), a, b::AbstractVector) = (_first_axis(a), )
_out_axes(::typeof(*), a, b::AbstractMatrix) = (_first_axis(a), _second_axis(b))
@@ -27,19 +26,21 @@ for op in [:*, :\, :/]
function Base.$op(A::AbstractComponentVecOrMat, B::AbstractComponentVecOrMat)
C = $op(getdata(A), getdata(B))
ax = _out_axes($op, A, B)
return ComponentArray(C, ax)
return ComponentArray(C, ax...)
end
end
for (adj, Adj) in zip([:adjoint, :transpose], [:Adjoint, :Transpose])
@eval begin
function Base.$op(aᵀ::$Adj{T,<:ComponentVector}, B::AbstractComponentMatrix) where {T}
cᵀ = $op(getdata(aᵀ), getdata(B))
ax2 = _out_axes($op, aᵀ, B)[2]
return $adj(ComponentArray(cᵀ', ax2))
end
function Base.$op(A::$Adj{T,<:CV}, B::CV) where {T<:Real, CV<:ComponentVector{T}}
return $op(getdata(A), getdata(B))
end
end


for op in [:adjoint, :transpose]
@eval begin
function LinearAlgebra.$op(M::ComponentMatrix{T,A,Tuple{Ax1,Ax2}}) where {T,A,Ax1,Ax2}
data = $op(getdata(M))
return ComponentArray(data, (Ax2(), Ax1())[1:ndims(data)]...)
end

function LinearAlgebra.$op(M::ComponentVector{T,A,Tuple{Ax1}}) where {T,A,Ax1}
return ComponentMatrix($op(getdata(M)), FlatAxis(), Ax1())
end
end
end
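Per the PR title, the new `adjoint`/`transpose` methods above return eager `ComponentArray`s instead of lazy `Adjoint`/`Transpose` wrappers. A hypothetical REPL sketch of the intended behaviour (types and axes inferred from the methods above, not copied from a real session):

```julia
using ComponentArrays, LinearAlgebra

v = ComponentVector(a = 1.0, b = 2.0)
# Previously: v' isa Adjoint{Float64, <:ComponentVector}.
# With this PR, the ComponentVector method above builds a 1×2 ComponentMatrix
# whose axes are (FlatAxis(), Axis(a = 1, b = 2)):
v' isa ComponentMatrix

M = ComponentMatrix(reshape(collect(1.0:4.0), 2, 2),
                    Axis(a = 1, b = 2), Axis(c = 1, d = 2))
# The ComponentMatrix method transposes the data and swaps the two axes,
# so M' is again a ComponentMatrix with axes (c, d) × (a, b):
M' isa ComponentMatrix
```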
2 changes: 1 addition & 1 deletion src/show.jl
@@ -79,4 +79,4 @@ function Base.show(io::IO, ::MIME"text/plain", x::ComponentMatrix{T,A,Axes}) whe
println(io, " with axes $(axs[1]) × $(axs[2])")
Base.print_matrix(io, getdata(x))
return nothing
end
end
2 changes: 1 addition & 1 deletion test/gpu_tests.jl
@@ -46,7 +46,7 @@ end
@test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4))
end
@testset "mul!" begin
A = jlca .* jlca';
@nrontsis (Contributor Author) · Oct 27, 2022:

Here and in many other tests, I think we were testing the wrong behaviour: I think we should have used * instead of .* (like this PR does).

.* is different from * in that it broadcasts both sides and then multiplies elementwise. However, the axes of the (broadcasted) left- and right-hand sides are not identical, so a non-ComponentMatrix must be returned.
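The distinction can be seen with plain arrays before any ComponentArray axes enter the picture; a minimal illustrative sketch:

```julia
using LinearAlgebra

a = [1.0, 2.0]
b = [3.0, 4.0]

a * b'   # matrix multiplication of a 2-vector with a 1×2 adjoint: a 2×2 outer product
a .* b'  # broadcasting: a is extruded along columns and b' along rows,
         # then the elements are multiplied pairwise

# Numerically the two agree for this vector/adjoint case, which is why the old
# tests passed; the difference is how the result's axes are derived. Matmul
# takes the first axis of `a` and the second of `b'`, while broadcasting
# combines the axes of both sides, which for ComponentArrays need not yield
# component axes at all.
a * b' == a .* b'  # true
```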

@YichengDWu (Contributor) · Nov 17, 2022:

Broadcasting on GPU is a behaviour that we particularly want to test.

The axes of the two sides are not required to be identical in broadcasting (and in general).

@nrontsis (Contributor Author) · Nov 18, 2022:

You are right about this test; I have thus restored the diff here.

However, in other places like in this one, the testing logic was actually incorrect. In this particular example, doing broadcasting should not yield the axes that the test was specifying, as detailed in my comment above.

To correct these, I could either change the operation to be matrix multiplication (instead of broadcasting), or change the testing logic. I chose the former, but happy to change if you think otherwise.

Collaborator:

Broadcasting should extrude singleton dimensions to the non-singleton axes of that dimension within the broadcasted operation. We're following the same behavior here that the other array packages follow:

julia> using StaticArrays

julia> a = SA[1, 2, 3]; b = SA[4, 5, 6];

julia> a .* b'
3×3 SMatrix{3, 3, Int64, 9} with indices SOneTo(3)×SOneTo(3):
  4   5   6
  8  10  12
 12  15  18

@jonniedie (Collaborator) · Dec 29, 2022:

For the case of a ::AbstractVector .* ::Adjoint{_, <:AbstractVector}, it just happens to give the same result as plain matrix multiplication.

@nrontsis (Contributor Author):

It's unclear to me what the StaticArrays example demonstrates. For example, I might argue that the result of the example with StaticArrays is compatible with the current behaviour of this PR, as the indices of b and a are already compatible, in the sense that:

vcat(b', b', b')

gives a 3×3 SMatrix{3, 3, Int64, 9} with indices SOneTo(3)×SOneTo(3), just as a .* b' does.

Also, in my understanding the broadcasting behaviour of master is inconsistent (essentially a special case only for transposes?):

julia> ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (FlatAxis(), Axis(:c, :d)))
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> ComponentVector(a=0.0, b=0.0).*ComponentVector(c=0.0, d=0.0)'
2×2 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2) × Axis(c = 1, d = 2)
 0.0  0.0
 0.0  0.0

In general, I find it hard to imagine a way to implement the broadcasting described above that results in a consistent operation both for multiplying 2×1 .* 1×2 matrices and 1×1 .* 1×1 matrices. If we want to continue this chat, it would help me a lot if we could clarify the behaviour of:

ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (FlatAxis(), Axis(:c, :d)))
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), FlatAxis())) .* ComponentMatrix(zeros(1, 2), (Axis(:e), Axis(:c, :d)))
ComponentMatrix(zeros(2, 1), (Axis(:a, :b), Axis(:f))) .* ComponentMatrix(zeros(1, 2), (Axis(:e), Axis(:c, :d)))
ComponentVector(a=0.0, b=0.0).*ComponentVector(c=0.0, d=0.0)'
ComponentVector(a=0.0).*ComponentVector(c=0.0)'
ComponentMatrix(zeros(1, 1), (Axis(:a), FlatAxis())) .* ComponentMatrix(zeros(1, 1), (FlatAxis(), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), FlatAxis())) .* ComponentMatrix(zeros(1, 1), (Axis(:d), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), Axis(:b))) .* ComponentMatrix(zeros(1, 1), (FlatAxis(), Axis(:c)))
ComponentMatrix(zeros(1, 1), (Axis(:a), Axis(:b))) .* ComponentMatrix(zeros(1, 1), (Axis(:d), Axis(:c)))

as I find the result of master on these inconsistent.

@jonniedie (Collaborator) · Dec 30, 2022:

> It's unclear to me what the StaticArrays example demonstrates?

The StaticArrays example demonstrates that Broadcast.combine_axes is choosing the non-singleton axes of each dimension from the arrays it's given rather than falling back to the default Base.OneTo axes.

> Also, in my understanding the broadcasting behaviour of master is inconsistent (essentially a special case only for transposes?):

There's kinda a complicated story to why we have to special case vector transposes. Before ArrayInterfaces.jl existed, the way DifferentialEquations.jl and some other packages created matrix cache arrays from vectors (say, for a Jacobian cache) was to call u .* u' .* false. Now ArrayInterfaces.jl lets you overload a zeromatrix function for this purpose. We want the cache matrices to in these cases to be ComponentMatrix so users that are interacting with them directly can address blocks by named indices. We could probably at this point switch to using zeromatrix. It may break things for some packages that are still doing u .* u' to expand matrices, though, so I've kinda put off the idea for fear of having to deal with that.

But with that said, I actually want the current behavior we have for the vector and vector transpose case to be the behavior for all cases. Unfortunately its just inherently type unstable to do this in general because it would mean that the size of a dimension (which generally isn't known at compile time) would determine the type of axes we choose. Consider the following examples (I'm using vectors instead of matrices here because it makes the issue clearer and removes the temptation for us to think of this as having anything to do with matrix multiplication):

# Group 1
ComponentVector(a=4, b=2) .* [3]    # Case 1
ComponentVector(a=4, b=2) .* [3, 1] # Case 2

# Group 2
ComponentVector(a=4) .* [3]    # Case 1
ComponentVector(a=4) .* [3, 1] # Case 2

Within each group, the argument types are the same, so for type stability, we need to have the return types the same as well.

In Group 1, Case 1 (I'll just abbreviate it to G1C1), it's clear that we should want to use the axis of the ComponentVector, since the regular Vector has a single element and will thus be extruded into the axis of the other array for that dimension. If we want to have this, however, we also need to make it so the ComponentVector axis "wins" for G1C2, since the compiler doesn't know how many elements the plain Vector has.

In G2C1, we're faced with a similar choice and it seems like we should similarly let the ComponentVector "win" here for consistency. But clearly we can't do that because in G2C2, which looks the same to the compiler as G2C1, we will be extruding the single-element ComponentVector onto the axes of the plain Vector. So for this group, we have to let both results fall back to giving a plain Vector as an output, despite the fact that we would probably prefer to have G2C1 be a ComponentVector for consistency.

For inspiration into how it should be handled, I think we should look back at StaticArrays and see how they do it:

julia> SA[4, 2] .* [3] # G1C1
2-element SizedVector{2, Int64, Vector{Int64}} with indices SOneTo(2):
 12
  6

julia> SA[4, 2] .* [3, 1] # G1C2
2-element SizedVector{2, Int64, Vector{Int64}} with indices SOneTo(2):
 12
  2

julia> SA[4] .* [3] # G2C1
1-element Vector{Int64}:
 12

julia> SA[4] .* [3, 1] # G2C2
2-element Vector{Int64}:
 12
  4

Since StaticArrays have knowledge of their size, they make the decision to always use the static array axis when the length of the dimension of the StaticArray is > 1 and always use the other array's axis when it's == 1. This makes sense and probably gives the best tradeoff between consistency and usefulness.

Could we do the same thing for ComponentArrays? Well... sorta. In most usage of ComponentArrays, the named indices correspond 1:1 to the array elements, so we can statically know the size of the ComponentArray by looking at the last index of each dimension of the component axes. But there is this undocumented but still very useful use of Axes that don't follow this behavior (see #163). So we can't necessarily guarantee that a dimension of a ComponentArray is length 1 just because its component axis type information seems to indicate it is. But maybe we could just assume they are and throw an error in cases where this isn't true?

@nrontsis (Contributor Author) · Jan 3, 2023:

Thanks for the detailed explanation.

> There's kinda a complicated story to why we have to special case vector transposes ... We could probably at this point switch to using zeromatrix. It may break things for some packages that are still doing u .* u' to expand matrices, though, so I've kinda put off the idea for fear of having to deal with that.

This sounds ideal, as I am all in on eliminating special cases. Assuming that u .* u' (.* false?) is used to expand matrices in some packages (that we don't test against?) sounds fragile to me and, as far as I understand, it already "fails to work" in very popular packages that I happen to use:

julia> using ForwardDiff, ComponentArrays
julia> ForwardDiff.jacobian(x -> x, ComponentVector(a=1.0))
1×1 Matrix{Float64}:
 1.0

The potential behaviour that you described seems complete and well thought out (at first look), but I am of the opinion that, unless or until we have a complete and documented behaviour, the resulting complexity and confusion will make the simplicity of a uniform rule like "broadcasting results in FlatAxis" (as per this PR) preferable; and that's before taking maintenance overhead into account.

Nevertheless, I think this chat might be expanding beyond the scope of this PR. If you want to maintain the special case for transposed broadcast, I will include it in this PR. I tried to do that today by adding the following lines in broadcasting.jl:

Broadcast.combine_axes(
    a::ComponentArray, b::ComponentArray{T, 2, <:Union{LinearAlgebra.Adjoint{T, A}, LinearAlgebra.Transpose{T, A}}, Ax}
) where {T, A <: AbstractVector, Ax} = (axes(a)[1], axes(b)[2])
Broadcast.combine_axes(
    a::ComponentArray{T, 2, <:Union{LinearAlgebra.Adjoint{T, A}, LinearAlgebra.Transpose{T, A}}, Ax}, b::ComponentArray
) where {T, A <: AbstractVector, Ax} = (axes(b)[2], axes(a)[1])

to substitute the lines in that file currently deleted by this PR. This works for
ComponentVector(a=0) .* ComponentVector(b=0)'
but not for
false .* ComponentVector(a=0) .* ComponentVector(b=0)'.
I failed to find why it does not work (while it works in master) even after spending a couple of hours on this today, so I would appreciate some help on this, assuming that's the route we are taking.

A = jlca * jlca';
@test_nowarn mul!(deepcopy(A), A, A, 1, 2);
@test_nowarn mul!(deepcopy(A), A', A', 1, 2);
@test_nowarn mul!(deepcopy(A), A', A, 1, 2);