diff --git a/Project.toml b/Project.toml index c750cd2..d41074f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,8 @@ authors = ["ITensor developers and contributors"] version = "0.1.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" @@ -13,7 +15,9 @@ BroadcastMapConversion = {url = "https://github.com/ITensor/BroadcastMapConversi TypeParameterAccessors = {url = "https://github.com/ITensor/TypeParameterAccessors.jl"} [compat] +Adapt = "4.1.1" Aqua = "0.8.9" +ArrayLayouts = "1.10.4" BroadcastMapConversion = "0.1.0" LinearAlgebra = "1.10" Test = "1.10" diff --git a/README.md b/README.md index 779e283..0761df8 100644 --- a/README.md +++ b/README.md @@ -19,20 +19,35 @@ julia> Pkg.add(url="https://github.com/ITensor/SparseArraysBaseNext.jl") ````julia using SparseArraysBaseNext: - SparseArrayDOK, eachstoredindex, isstored, storedlength, storedvalues + SparseArrayDOK, eachstoredindex, isstored, storedlength, storedpairs, storedvalues +using Test: @test a = SparseArrayDOK{Float64}(2, 2) a[1, 2] = 12 -@show a[1, 1] -@show a[1, 2] +a[2, 1] = 21 +@test a[1, 1] == 0 +@test a[2, 1] == 21 +@test a[1, 2] == 12 +@test a[2, 2] == 0 + b = a .+ 2 .* a' -@show storedvalues(b) -@show eachstoredindex(b) -@show isstored(b, 1, 1) -@show isstored(b, 2, 1) -@show isstored(b, 1, 2) -@show isstored(b, 2, 2) -@show storedlength(b) +@test b[1, 1] == 0 +@test b[2, 1] == 21 + 2 * 12 +@test b[1, 2] == 12 + 2 * 21 +@test b[2, 2] == 0 +@test issetequal(storedvalues(b), [21 + 2 * 12, 12 + 2 * 21]) +@test issetequal(eachstoredindex(b), [CartesianIndex(2, 1), CartesianIndex(1, 2)]) +@test storedpairs(b) == + Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21) +@test !isstored(b, 1, 1) +@test isstored(b, 2, 1) +@test isstored(b, 1, 2) +@test !isstored(b, 2, 2) +@test storedlength(b) == 2 + +c = a * a' +@test storedlength(c) == 2 +@test c == [12*12 0; 0 21*21] ```` --- diff --git a/examples/README.jl b/examples/README.jl index d387cb5..cdc4e1e 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -20,17 +20,32 @@ julia> Pkg.add(url="https://github.com/ITensor/SparseArraysBaseNext.jl") # ## Examples using SparseArraysBaseNext: - SparseArrayDOK, eachstoredindex, isstored, storedlength, storedvalues + SparseArrayDOK, eachstoredindex, isstored, storedlength, storedpairs, storedvalues +using Test: @test a = SparseArrayDOK{Float64}(2, 2) a[1, 2] = 12 -@show a[1, 1] -@show a[1, 2] +a[2, 1] = 21 +@test a[1, 1] == 0 +@test a[2, 1] == 21 +@test a[1, 2] == 12 +@test a[2, 2] == 0 + b = a .+ 2 .* a' -@show storedvalues(b) -@show eachstoredindex(b) -@show isstored(b, 1, 1) -@show isstored(b, 2, 1) -@show isstored(b, 1, 2) -@show isstored(b, 2, 2) -@show storedlength(b) +@test b[1, 1] == 0 +@test b[2, 1] == 21 + 2 * 12 +@test b[1, 2] == 12 + 2 * 21 +@test b[2, 2] == 0 +@test issetequal(storedvalues(b), [21 + 2 * 12, 12 + 2 * 21]) +@test issetequal(eachstoredindex(b), [CartesianIndex(2, 1), CartesianIndex(1, 2)]) +@test storedpairs(b) == + Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21) +@test !isstored(b, 1, 1) +@test isstored(b, 2, 1) +@test isstored(b, 1, 2) +@test !isstored(b, 2, 2) +@test storedlength(b) == 2 + +c = a * a' +@test storedlength(c) == 2 +@test c == [12*12 0; 0 21*21] diff --git a/src/SparseArraysBaseNext.jl b/src/SparseArraysBaseNext.jl index 73ee183..035ccf5 100644 --- a/src/SparseArraysBaseNext.jl +++ b/src/SparseArraysBaseNext.jl @@ -1,10 +1,7 @@ module SparseArraysBaseNext -include("derive.jl") -include("wrappedarrays.jl") -include("abstractinterface.jl") -include("abstractarrayinterface.jl") -include("defaultarrayinterface.jl") +include("lib/InterfaceImplementations/InterfaceImplementations.jl") +using .InterfaceImplementations: InterfaceImplementations include("sparsearrayinterface.jl") include("sparsearraydok.jl") diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl deleted file mode 100644 index eb94970..0000000 --- a/src/abstractarrayinterface.jl +++ /dev/null @@ -1,85 +0,0 @@ -# Like the equivalent `Base` functions but allow overloading -# by interface in the first argument. -# Overloading `Base` directly leads to too many method ambiguities. -# Maybe we can make `Interface.f` instead of `interface_f` in the future -# but that is tricky with namespacing issues. -# TODO: Define generic fallbacks that use `invoke`. -function interface_getindex end -function interface_setindex! end -function interface_similar end -function interface_map end -function interface_map! end -function interface_BroadcastStyle end - -# TODO: Add `ndims` type parameter. -abstract type AbstractArrayInterface <: AbstractInterface end - -function AbstractArrayOps() - return ( - AbstractArrayGetIndex(), - AbstractArraySetIndex(), - AbstractArraySimilar(), - AbstractArrayMap(), - AbstractArrayBroadcast(), - ) -end - -struct AbstractArrayGetIndex <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArrayGetIndex) - return quote - function Base.getindex(a::$type, I::Int...) - return interface_getindex($AbstractInterface(a), a, I...) - end - end -end - -struct AbstractArraySetIndex <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArraySetIndex) - return quote - function Base.setindex!(a::$type, value, I::Int...) - interface_setindex!($AbstractInterface(a), a, value, I...) - return a - end - end -end - -struct AbstractArraySimilar <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArraySimilar) - return quote - # TODO: Generalize to axes. - function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}}) - return interface_similar($AbstractInterface(a), a, T, I) - end - end -end - -struct AbstractArrayMap <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArrayMap) - return quote - function Base.map(f, as::$type...) - # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. - # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules - return interface_map($AbstractInterface($AbstractInterface.(as)...), f, as...) - end - function Base.map!(f, dest::$type, as::AbstractArray...) - # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. - # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules - interface_map!($AbstractInterface($AbstractInterface.(as)...), f, dest, as...) - return dest - end - end -end - -struct AbstractArrayBroadcast <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArrayBroadcast) - return quote - function Base.BroadcastStyle(type::Type{<:$type}) - return interface_BroadcastStyle($AbstractInterface(type), type) - end - end -end diff --git a/src/derive.jl b/src/derive.jl deleted file mode 100644 index f15ad21..0000000 --- a/src/derive.jl +++ /dev/null @@ -1,27 +0,0 @@ -macro derive(type, interfaces...) - return esc(derive(type, interfaces)) -end - -# TODO: Use `TupleTools.jl`. -tuple_cat(t1::Tuple, t2::Tuple) = (t1..., t2...) -tuple_cat(t1::Tuple, t2) = (t1..., t2) -tuple_cat(t1, t2::Tuple) = (t1, t2...) -tuple_cat(t1, t2) = (t1, t2) -tuple_flatten(t::Tuple) = reduce(tuple_cat, t) - -function derive(type::Union{Symbol,Expr}, interfaces::Tuple{Vararg{Symbol}}) - return derive(eval(type), tuple_flatten(map(interface -> eval(interface)(), interfaces))) -end - -function derive(types::Tuple{Vararg{Type}}, interfaces::Tuple) - return derive(Union{types...}, interfaces::Tuple) -end - -function derive(type_or_types::Union{Type,Tuple{Vararg{Type}}}, interfaces::Tuple) - expr = Expr(:block) - for interface in interfaces - subexpr = derive(type_or_types, interface) - expr = Expr(:block, expr.args..., subexpr.args...) - end - return expr -end diff --git a/src/lib/InterfaceImplementations/InterfaceImplementations.jl b/src/lib/InterfaceImplementations/InterfaceImplementations.jl new file mode 100644 index 0000000..8ebf7fe --- /dev/null +++ b/src/lib/InterfaceImplementations/InterfaceImplementations.jl @@ -0,0 +1,9 @@ +module InterfaceImplementations + +include("abstractinterface.jl") +include("derive.jl") +include("wrappedarrays.jl") +include("abstractarrayinterface.jl") +include("defaultarrayinterface.jl") + +end diff --git a/src/lib/InterfaceImplementations/abstractarrayinterface.jl b/src/lib/InterfaceImplementations/abstractarrayinterface.jl new file mode 100644 index 0000000..8bc45bf --- /dev/null +++ b/src/lib/InterfaceImplementations/abstractarrayinterface.jl @@ -0,0 +1,183 @@ +# Like the equivalent `Base` functions but allow overloading +# by interface in the first argument. +# Overloading `Base` directly leads to too many method ambiguities. +# TODO: Define generic fallbacks that use `invoke`. +function getindex end +function setindex! end +function similar end +function map end +function map! end +function BroadcastStyle end +function MemoryLayout end +function mul! end + +# TODO: Add `ndims` type parameter. +abstract type AbstractArrayInterface <: AbstractInterface end + +function AbstractArrayOps() + return ( + Base.getindex, + Base.setindex!, + Base.similar, + Base.map, + Base.map!, + Base.Broadcast.BroadcastStyle, + ArrayLayouts.MemoryLayout, + LinearAlgebra.mul!, + ) +end + +function derive(op::typeof(Base.getindex), interface::AbstractArrayInterface, type::Type) + return quote + # TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize` + # to handle slicing (implemented by copying SubArray). + function Base.getindex(a::$type, I::Int...) + return $getindex($interface, a, I...) + end + end +end +function derive(op::typeof(Base.getindex), type::Type) + return quote + function Base.getindex(a::$type, I::Int...) + return $getindex($AbstractInterface(a), a, I...) + end + end +end + +function derive(op::typeof(Base.setindex!), interface::AbstractArrayInterface, type::Type) + return quote + function Base.setindex!(a::$type, value, I::Int...) + $setindex!($interface, a, value, I...) + return a + end + end +end +function derive(op::typeof(Base.setindex!), type::Type) + return quote + function Base.setindex!(a::$type, value, I::Int...) + $setindex!($AbstractInterface(a), a, value, I...) + return a + end + end +end + +function derive(op::typeof(Base.similar), interface::AbstractArrayInterface, type::Type) + return quote + # TODO: Generalize to axes. + function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}}) + return $similar($interface, a, T, I) + end + end +end +function derive(op::typeof(Base.similar), type::Type) + return quote + # TODO: Generalize to axes. + function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}}) + return $similar($AbstractInterface(a), a, T, I) + end + end +end + +function derive(op::typeof(Base.map), interface::AbstractArrayInterface, type::Type) + return quote + function Base.map(f, as::$type...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + return $map($interface, f, as...) + end + end +end +function derive(op::typeof(Base.map), type::Type) + return quote + function Base.map(f, as::$type...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + return $map($AbstractInterface(AbstractInterface.(as)...), f, as...) + end + end +end + +function derive(op::typeof(Base.map!), interface::AbstractArrayInterface, type::Type) + return quote + function Base.map!(f, dest::$type, as::AbstractArray...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + $map!($interface, f, dest, as...) + return dest + end + end +end +function derive(op::typeof(Base.map!), type::Type) + return quote + function Base.map!(f, dest::$type, as::AbstractArray...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + $map!($AbstractInterface(AbstractInterface.(as)...), f, dest, as...) + return dest + end + end +end + +function derive( + op::Type{Base.Broadcast.BroadcastStyle}, interface::AbstractArrayInterface, type::Type +) + return quote + function Base.Broadcast.BroadcastStyle(type::Type{<:$type}) + return $BroadcastStyle($interface, type) + end + end +end +function derive(op::Type{Base.Broadcast.BroadcastStyle}, type::Type) + return quote + function Base.Broadcast.BroadcastStyle(type::Type{<:$type}) + return $BroadcastStyle($AbstractInterface(type), type) + end + end +end + +using ArrayLayouts: ArrayLayouts +function derive( + op::Type{ArrayLayouts.MemoryLayout}, interface::AbstractArrayInterface, type::Type +) + return quote + function $ArrayLayouts.MemoryLayout(type::Type{<:$type}) + return $MemoryLayout($interface, type) + end + end +end +function derive(op::Type{ArrayLayouts.MemoryLayout}, type::Type) + return quote + function $ArrayLayouts.MemoryLayout(type::Type{<:$type}) + return $MemoryLayout($AbstractInterface(type), type) + end + end +end + +using LinearAlgebra: LinearAlgebra +function derive( + op::typeof(LinearAlgebra.mul!), interface::AbstractArrayInterface, type::Type +) + return quote + function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number) + # TODO: Determine from `a` and `b`. + return $mul!($interface, dest, a, b, α, β) + end + end +end +function derive(op::typeof(LinearAlgebra.mul!), type::Type) + return quote + function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number) + # TODO: Determine from `a` and `b`. + return $mul!( + $AbstractInterface( + $AbstractInterface(dest), $AbstractInterface(a), $AbstractInterface(b) + ), + dest, + a, + b, + α, + β, + ) + end + end +end diff --git a/src/abstractinterface.jl b/src/lib/InterfaceImplementations/abstractinterface.jl similarity index 100% rename from src/abstractinterface.jl rename to src/lib/InterfaceImplementations/abstractinterface.jl diff --git a/src/defaultarrayinterface.jl b/src/lib/InterfaceImplementations/defaultarrayinterface.jl similarity index 100% rename from src/defaultarrayinterface.jl rename to src/lib/InterfaceImplementations/defaultarrayinterface.jl diff --git a/src/lib/InterfaceImplementations/derive.jl b/src/lib/InterfaceImplementations/derive.jl new file mode 100644 index 0000000..2619a5e --- /dev/null +++ b/src/lib/InterfaceImplementations/derive.jl @@ -0,0 +1,29 @@ +macro derive(type, interface, ops) + return esc(derive(__module__, ops, interface, type)) +end + +macro derive(type, ops) + return esc(derive(__module__, ops, type)) +end + +function derive(mod::Module, ops::Union{Symbol,Expr}, type::Union{Symbol,Expr}) + return derive(Base.eval(mod, ops), Base.eval(mod, type)) +end + +function derive( + mod::Module, + ops::Union{Symbol,Expr}, + interface::Union{Symbol,Expr}, + type::Union{Symbol,Expr}, +) + return derive(Base.eval(mod, ops), Base.eval(mod, interface), Base.eval(mod, type)) +end + +function derive(ops::Tuple, args...) + expr = Expr(:block) + for op in ops + subexpr = derive(op, args...) + expr = Expr(:block, expr.args..., subexpr.args...) + end + return expr +end diff --git a/src/lib/InterfaceImplementations/wrappedarrays.jl b/src/lib/InterfaceImplementations/wrappedarrays.jl new file mode 100644 index 0000000..ea87dd5 --- /dev/null +++ b/src/lib/InterfaceImplementations/wrappedarrays.jl @@ -0,0 +1,14 @@ +using Adapt: WrappedArray + +macro wrappedtype(type) + return esc(wrappedtype(type)) +end + +function wrappedtype(type::Symbol) + wrappedtype = Symbol(:Wrapped, type) + anytype = Symbol(:Any, type) + return quote + const $wrappedtype{T,N} = $WrappedArray{T,N,$type,$type{T,N}} + const $anytype{T,N} = Union{$type{T,N},$wrappedtype{T,N}} + end +end diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index 8e27f6f..f35143c 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -1,5 +1,6 @@ +# TODO: Define `AbstractSparseArray`, make this a subtype. struct SparseArrayDOK{T,N} <: AbstractArray{T,N} - storedvalues::Dict{CartesianIndex{N},T} + storage::Dict{CartesianIndex{N},T} size::NTuple{N,Int} end @@ -8,34 +9,46 @@ function SparseArrayDOK{T}(size::Int...) where {T} return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) end -AbstractInterface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() +using .InterfaceImplementations: @wrappedtype +# Define `WrappedSparseArrayDOK` and `AnySparseArrayDOK`. +@wrappedtype SparseArrayDOK -@derive AnyArrays(SparseArrayDOK) AbstractArrayOps +using .InterfaceImplementations: InterfaceImplementations +function InterfaceImplementations.AbstractInterface(::Type{<:SparseArrayDOK}) + return SparseArrayInterface() +end + +using .InterfaceImplementations: AbstractArrayOps, @derive +@derive AnySparseArrayDOK AbstractArrayOps() +storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size -storedvalues(a::SparseArrayDOK) = a.storedvalues +storedvalues(a::SparseArrayDOK) = values(storage(a)) function isstored(a::SparseArrayDOK, I::Int...) - return CartesianIndex(I) in keys(storedvalues(a)) + return CartesianIndex(I) in keys(storage(a)) end function eachstoredindex(a::SparseArrayDOK) - return keys(storedvalues(a)) + return keys(storage(a)) end function getstoredindex(a::SparseArrayDOK, I::Int...) - return storedvalues(a)[CartesianIndex(I)] + return storage(a)[CartesianIndex(I)] end function getunstoredindex(a::SparseArrayDOK, I::Int...) return zero(eltype(a)) end function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - storedvalues(a)[CartesianIndex(I)] = value + storage(a)[CartesianIndex(I)] = value return a end function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - storedvalues(a)[CartesianIndex(I)] = value + storage(a)[CartesianIndex(I)] = value return a end +# Optional, but faster than the default. +storedpairs(a::SparseArrayDOK) = storage(a) + using LinearAlgebra: Adjoint storedvalues(a::Adjoint) = storedvalues(parent(a)) function isstored(a::Adjoint, i::Int, j::Int) @@ -43,7 +56,7 @@ function isstored(a::Adjoint, i::Int, j::Int) end function eachstoredindex(a::Adjoint) # TODO: Make lazy with `Iterators.map`. - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) + return Base.map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) end function getstoredindex(a::Adjoint, i::Int, j::Int) return getstoredindex(parent(a), j, i)' diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index c2d0565..d0beb43 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -1,4 +1,6 @@ # Minimal interface for `SparseArrayInterface`. +# TODO: Define default definitions for these based +# on the dense case. storedvalues(a) = error() isstored(a, I::Int...) = error() eachstoredindex(a) = error() @@ -9,6 +11,7 @@ setunstoredindex!(a, value, I::Int...) = error() # Derived interface. storedlength(a) = length(storedvalues(a)) +storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a)) function eachstoredindex(a1, a2, a_rest...) # TODO: Make this more customizable, say with a function @@ -17,14 +20,18 @@ function eachstoredindex(a1, a2, a_rest...) end # TODO: Add `ndims` type parameter. +# TODO: Define `AbstractSparseArrayInterface`, make this a subtype. +using .InterfaceImplementations: AbstractArrayInterface struct SparseArrayInterface <: AbstractArrayInterface end -function interface_getindex(::SparseArrayInterface, a, I::Int...) +# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize` +# to handle slicing (implemented by copying SubArray). +function InterfaceImplementations.getindex(::SparseArrayInterface, a, I::Int...) !isstored(a, I...) && return getunstoredindex(a, I...) return getstoredindex(a, I...) end -function interface_setindex!(::SparseArrayInterface, a, value, I::Int...) +function InterfaceImplementations.setindex!(::SparseArrayInterface, a, value, I::Int...) iszero(value) && return a if !isstored(a, I...) setunstoredindex!(a, value, I...) @@ -36,49 +43,119 @@ end # TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK` # is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`. -function interface_similar(::SparseArrayInterface, a, T::Type, size::Tuple{Vararg{Int}}) +function InterfaceImplementations.similar( + ::SparseArrayInterface, a, T::Type, size::Tuple{Vararg{Int}} +) return SparseArrayDOK{T}(size...) end ## TODO: Make this more general, handle mixtures of integers and ranges. ## TODO: Make this logic generic to all `similar(::AbstractInterface, ...)`. -## function interface_similar(interface::SparseArrayInterface, a, T::Type, dims::Tuple{Vararg{Base.OneTo}}) -## return similar(interface, a, T, Base.to_shape(dims)) +## function InterfaceImplementations.similar(interface::SparseArrayInterface, a, T::Type, dims::Tuple{Vararg{Base.OneTo}}) +## return InterfaceImplementations.similar(interface, a, T, Base.to_shape(dims)) ## end -function interface_map(::SparseArrayInterface, f, as...) +function InterfaceImplementations.map(::SparseArrayInterface, f, as...) # This is defined in this way so we can rely on the Broadcast logic # for determining the destination of the operation (element type, shape, etc.). return f.(as...) end -function interface_map!(::SparseArrayInterface, f, dest, as...) +function InterfaceImplementations.map!(::SparseArrayInterface, f, dest, as...) # Check `f` preserves zeros. # Define as `map_stored!`. # Define `eachstoredindex` promotion. for I in eachstoredindex(as...) - dest[I] = f(map(a -> a[I], as)...) + dest[I] = f(Base.map(a -> a[I], as)...) end return dest end +# TODO: Define `AbstractSparseArrayStyle`, make this a subtype. struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() -function interface_BroadcastStyle(::SparseArrayInterface, type::Type) +function InterfaceImplementations.BroadcastStyle(::SparseArrayInterface, type::Type) return SparseArrayStyle{ndims(type)}() end function Base.similar(bc::Broadcast.Broadcasted{<:SparseArrayStyle}, T::Type, axes::Tuple) # TODO: Allow `similar` to accept `axes` directly. - return interface_similar(SparseArrayInterface(), bc, T, Int.(length.(axes))) + return InterfaceImplementations.similar( + SparseArrayInterface(), bc, T, Int.(length.(axes)) + ) end using BroadcastMapConversion: map_function, map_args # TODO: Look into `SparseArrays.capturescalars`: # https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 function Base.copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{<:SparseArrayStyle}) - interface_map!(SparseArrayInterface(), map_function(bc), dest, map_args(bc)...) + InterfaceImplementations.map!( + SparseArrayInterface(), map_function(bc), dest, map_args(bc)... + ) return dest end + +using ArrayLayouts: ArrayLayouts, MatMulMatAdd + +abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end + +struct SparseLayout <: AbstractSparseLayout end + +InterfaceImplementations.MemoryLayout(::SparseArrayInterface, type::Type) = SparseLayout() + +function InterfaceImplementations.mul!(::SparseArrayInterface, a_dest, a1, a2, α, β) + return ArrayLayouts.mul!(a_dest, a1, a2, α, β) +end + +function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2}) + if I1[2] ≠ I2[1] + return nothing + end + return CartesianIndex(I1[1], I2[2]) +end + +function default_mul!!( + a_dest::AbstractMatrix, + a1::AbstractMatrix, + a2::AbstractMatrix, + α::Number=true, + β::Number=false, +) + mul!(a_dest, a1, a2, α, β) + return a_dest +end + +function default_mul!!( + a_dest::Number, a1::Number, a2::Number, α::Number=true, β::Number=false +) + return a1 * a2 * α + a_dest * β +end + +# a1 * a2 * α + a_dest * β +function sparse_mul!( + a_dest::AbstractArray, + a1::AbstractArray, + a2::AbstractArray, + α::Number=true, + β::Number=false; + (mul!!)=(default_mul!!), +) + for I1 in eachstoredindex(a1) + for I2 in eachstoredindex(a2) + I_dest = mul_indices(I1, I2) + if !isnothing(I_dest) + a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β) + end + end + end + return a_dest +end + +function ArrayLayouts.materialize!( + m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout} +) + sparse_mul!(m.C, m.A, m.B, m.α, m.β) + return m.C +end diff --git a/src/wrappedarrays.jl b/src/wrappedarrays.jl deleted file mode 100644 index 3369345..0000000 --- a/src/wrappedarrays.jl +++ /dev/null @@ -1,17 +0,0 @@ -using LinearAlgebra: Adjoint, Transpose -function WrappedArrays(type::Type) - return ( - Adjoint{<:Any,<:type}, - PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:type}, - SubArray{<:Any,<:Any,<:type}, - Transpose{<:Any,<:type}, - ) -end - -function AnyArrays(type::Type) - return (type, WrappedArrays(type)...) -end - -function AnyArrays(types::Tuple) - return tuple_flatten(AnyArrays.(types)) -end