From b079fd2873f58c947875ab94cf4a7b1a6a58fdd9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 2 Dec 2024 15:04:48 -0500 Subject: [PATCH] Rewrite --- README.md | 7 +- examples/README.jl | 7 +- src/SparseArraysBaseNext.jl | 7 +- src/abstractarrayinterface.jl | 111 ----------- src/derive.jl | 16 -- .../InterfaceImplementations.jl | 9 + .../abstractarrayinterface.jl | 183 ++++++++++++++++++ .../abstractinterface.jl | 0 .../defaultarrayinterface.jl | 0 src/lib/InterfaceImplementations/derive.jl | 29 +++ .../InterfaceImplementations/wrappedarrays.jl | 14 ++ src/sparsearraydok.jl | 32 +-- src/sparsearrayinterface.jl | 34 ++-- src/wrappedarrays.jl | 9 - 14 files changed, 288 insertions(+), 170 deletions(-) delete mode 100644 src/abstractarrayinterface.jl delete mode 100644 src/derive.jl create mode 100644 src/lib/InterfaceImplementations/InterfaceImplementations.jl create mode 100644 src/lib/InterfaceImplementations/abstractarrayinterface.jl rename src/{ => lib/InterfaceImplementations}/abstractinterface.jl (100%) rename src/{ => lib/InterfaceImplementations}/defaultarrayinterface.jl (100%) create mode 100644 src/lib/InterfaceImplementations/derive.jl create mode 100644 src/lib/InterfaceImplementations/wrappedarrays.jl delete mode 100644 src/wrappedarrays.jl diff --git a/README.md b/README.md index 3b3568b..0761df8 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ julia> Pkg.add(url="https://github.com/ITensor/SparseArraysBaseNext.jl") ````julia using SparseArraysBaseNext: - SparseArrayDOK, eachstoredindex, isstored, storedlength, storedvalues + SparseArrayDOK, eachstoredindex, isstored, storedlength, storedpairs, storedvalues using Test: @test a = SparseArrayDOK{Float64}(2, 2) @@ -35,9 +35,10 @@ b = a .+ 2 .* a' @test b[2, 1] == 21 + 2 * 12 @test b[1, 2] == 12 + 2 * 21 @test b[2, 2] == 0 -@test storedvalues(b) == - Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21) +@test issetequal(storedvalues(b), [21 + 2 * 12, 12 + 2 * 21]) @test issetequal(eachstoredindex(b), [CartesianIndex(2, 1), CartesianIndex(1, 2)]) +@test storedpairs(b) == + Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21) @test !isstored(b, 1, 1) @test isstored(b, 2, 1) @test isstored(b, 1, 2) diff --git a/examples/README.jl b/examples/README.jl index ff0d3e5..cdc4e1e 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -20,7 +20,7 @@ julia> Pkg.add(url="https://github.com/ITensor/SparseArraysBaseNext.jl") # ## Examples using SparseArraysBaseNext: - SparseArrayDOK, eachstoredindex, isstored, storedlength, storedvalues + SparseArrayDOK, eachstoredindex, isstored, storedlength, storedpairs, storedvalues using Test: @test a = SparseArrayDOK{Float64}(2, 2) @@ -36,9 +36,10 @@ b = a .+ 2 .* a' @test b[2, 1] == 21 + 2 * 12 @test b[1, 2] == 12 + 2 * 21 @test b[2, 2] == 0 -@test storedvalues(b) == - Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21) +@test issetequal(storedvalues(b), [21 + 2 * 12, 12 + 2 * 21]) @test issetequal(eachstoredindex(b), [CartesianIndex(2, 1), CartesianIndex(1, 2)]) +@test storedpairs(b) == + Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21) @test !isstored(b, 1, 1) @test isstored(b, 2, 1) @test isstored(b, 1, 2) diff --git a/src/SparseArraysBaseNext.jl b/src/SparseArraysBaseNext.jl index 73ee183..035ccf5 100644 --- a/src/SparseArraysBaseNext.jl +++ b/src/SparseArraysBaseNext.jl @@ -1,10 +1,7 @@ module SparseArraysBaseNext -include("derive.jl") -include("wrappedarrays.jl") -include("abstractinterface.jl") -include("abstractarrayinterface.jl") -include("defaultarrayinterface.jl") +include("lib/InterfaceImplementations/InterfaceImplementations.jl") +using .InterfaceImplementations: InterfaceImplementations include("sparsearrayinterface.jl") include("sparsearraydok.jl") diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl deleted file mode 100644 index 963bd97..0000000 --- a/src/abstractarrayinterface.jl +++ /dev/null @@ -1,111 +0,0 @@ -# Like the equivalent `Base` functions but allow overloading -# by interface in the first argument. -# Overloading `Base` directly leads to too many method ambiguities. -# TODO: Define generic fallbacks that use `invoke`. -function getindex end -function setindex! end -function similar end -function map end -function map! end -function BroadcastStyle end -function MemoryLayout end -function mul! end - -# TODO: Add `ndims` type parameter. -abstract type AbstractArrayInterface <: AbstractInterface end - -function AbstractArrayOps() - return ( - AbstractArrayGetIndex(), - AbstractArraySetIndex(), - AbstractArraySimilar(), - AbstractArrayMap(), - AbstractArrayBroadcast(), - AbstractArrayMemoryLayout(), - AbstractArrayMul(), - ) -end - -struct AbstractArrayGetIndex <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArrayGetIndex) - return quote - # TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize` - # to handle slicing (implemented by copying SubArray). - function Base.getindex(a::$type, I::Int...) - return $getindex($AbstractInterface(a), a, I...) - end - end -end - -struct AbstractArraySetIndex <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArraySetIndex) - return quote - function Base.setindex!(a::$type, value, I::Int...) - $setindex!($AbstractInterface(a), a, value, I...) - return a - end - end -end - -struct AbstractArraySimilar <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArraySimilar) - return quote - # TODO: Generalize to axes. - function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}}) - return $similar($AbstractInterface(a), a, T, I) - end - end -end - -struct AbstractArrayMap <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArrayMap) - return quote - function Base.map(f, as::$type...) - # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. - # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules - return $map($AbstractInterface($AbstractInterface.(as)...), f, as...) - end - function Base.map!(f, dest::$type, as::AbstractArray...) - # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. - # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules - $map!($AbstractInterface($AbstractInterface.(as)...), f, dest, as...) - return dest - end - end -end - -struct AbstractArrayBroadcast <: AbstractArrayInterface end - -function derive(type::Type, ::AbstractArrayBroadcast) - return quote - function Base.BroadcastStyle(type::Type{<:$type}) - return $BroadcastStyle($AbstractInterface(type), type) - end - end -end - -struct AbstractArrayMemoryLayout <: AbstractArrayInterface end - -using ArrayLayouts: ArrayLayouts -function derive(type::Type, ::AbstractArrayMemoryLayout) - return quote - function $ArrayLayouts.MemoryLayout(type::Type{<:$type}) - return $MemoryLayout($AbstractInterface(type), type) - end - end -end - -struct AbstractArrayMul <: AbstractArrayInterface end - -using LinearAlgebra: LinearAlgebra -function derive(type::Type, ::AbstractArrayMul) - return quote - function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number) - return $mul!($AbstractInterface(a), dest, a, b, α, β) - end - end -end diff --git a/src/derive.jl b/src/derive.jl deleted file mode 100644 index be53343..0000000 --- a/src/derive.jl +++ /dev/null @@ -1,16 +0,0 @@ -macro derive(type, interface) - return esc(derive(type, interface)) -end - -function derive(type::Union{Symbol,Expr}, interface::Union{Symbol,Expr}) - return derive(eval(type), eval(interface)) -end - -function derive(type::Type, interfaces::Tuple) - expr = Expr(:block) - for interface in interfaces - subexpr = derive(type, interface) - expr = Expr(:block, expr.args..., subexpr.args...) - end - return expr -end diff --git a/src/lib/InterfaceImplementations/InterfaceImplementations.jl b/src/lib/InterfaceImplementations/InterfaceImplementations.jl new file mode 100644 index 0000000..8ebf7fe --- /dev/null +++ b/src/lib/InterfaceImplementations/InterfaceImplementations.jl @@ -0,0 +1,9 @@ +module InterfaceImplementations + +include("abstractinterface.jl") +include("derive.jl") +include("wrappedarrays.jl") +include("abstractarrayinterface.jl") +include("defaultarrayinterface.jl") + +end diff --git a/src/lib/InterfaceImplementations/abstractarrayinterface.jl b/src/lib/InterfaceImplementations/abstractarrayinterface.jl new file mode 100644 index 0000000..8bc45bf --- /dev/null +++ b/src/lib/InterfaceImplementations/abstractarrayinterface.jl @@ -0,0 +1,183 @@ +# Like the equivalent `Base` functions but allow overloading +# by interface in the first argument. +# Overloading `Base` directly leads to too many method ambiguities. +# TODO: Define generic fallbacks that use `invoke`. +function getindex end +function setindex! end +function similar end +function map end +function map! end +function BroadcastStyle end +function MemoryLayout end +function mul! end + +# TODO: Add `ndims` type parameter. +abstract type AbstractArrayInterface <: AbstractInterface end + +function AbstractArrayOps() + return ( + Base.getindex, + Base.setindex!, + Base.similar, + Base.map, + Base.map!, + Base.Broadcast.BroadcastStyle, + ArrayLayouts.MemoryLayout, + LinearAlgebra.mul!, + ) +end + +function derive(op::typeof(Base.getindex), interface::AbstractArrayInterface, type::Type) + return quote + # TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize` + # to handle slicing (implemented by copying SubArray). + function Base.getindex(a::$type, I::Int...) + return $getindex($interface, a, I...) + end + end +end +function derive(op::typeof(Base.getindex), type::Type) + return quote + function Base.getindex(a::$type, I::Int...) + return $getindex($AbstractInterface(a), a, I...) + end + end +end + +function derive(op::typeof(Base.setindex!), interface::AbstractArrayInterface, type::Type) + return quote + function Base.setindex!(a::$type, value, I::Int...) + $setindex!($interface, a, value, I...) + return a + end + end +end +function derive(op::typeof(Base.setindex!), type::Type) + return quote + function Base.setindex!(a::$type, value, I::Int...) + $setindex!($AbstractInterface(a), a, value, I...) + return a + end + end +end + +function derive(op::typeof(Base.similar), interface::AbstractArrayInterface, type::Type) + return quote + # TODO: Generalize to axes. + function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}}) + return $similar($interface, a, T, I) + end + end +end +function derive(op::typeof(Base.similar), type::Type) + return quote + # TODO: Generalize to axes. + function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}}) + return $similar($AbstractInterface(a), a, T, I) + end + end +end + +function derive(op::typeof(Base.map), interface::AbstractArrayInterface, type::Type) + return quote + function Base.map(f, as::$type...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + return $map($interface, f, as...) + end + end +end +function derive(op::typeof(Base.map), type::Type) + return quote + function Base.map(f, as::$type...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + return $map($AbstractInterface(AbstractInterface.(as)...), f, as...) + end + end +end + +function derive(op::typeof(Base.map!), interface::AbstractArrayInterface, type::Type) + return quote + function Base.map!(f, dest::$type, as::AbstractArray...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + $map!($interface, f, dest, as...) + return dest + end + end +end +function derive(op::typeof(Base.map!), type::Type) + return quote + function Base.map!(f, dest::$type, as::AbstractArray...) + # TODO: Define interface promotion, maybe just use `BroadcastStyle` directly. + # https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules + $map!($AbstractInterface(AbstractInterface.(as)...), f, dest, as...) + return dest + end + end +end + +function derive( + op::Type{Base.Broadcast.BroadcastStyle}, interface::AbstractArrayInterface, type::Type +) + return quote + function Base.Broadcast.BroadcastStyle(type::Type{<:$type}) + return $BroadcastStyle($interface, type) + end + end +end +function derive(op::Type{Base.Broadcast.BroadcastStyle}, type::Type) + return quote + function Base.Broadcast.BroadcastStyle(type::Type{<:$type}) + return $BroadcastStyle($AbstractInterface(type), type) + end + end +end + +using ArrayLayouts: ArrayLayouts +function derive( + op::Type{ArrayLayouts.MemoryLayout}, interface::AbstractArrayInterface, type::Type +) + return quote + function $ArrayLayouts.MemoryLayout(type::Type{<:$type}) + return $MemoryLayout($interface, type) + end + end +end +function derive(op::Type{ArrayLayouts.MemoryLayout}, type::Type) + return quote + function $ArrayLayouts.MemoryLayout(type::Type{<:$type}) + return $MemoryLayout($AbstractInterface(type), type) + end + end +end + +using LinearAlgebra: LinearAlgebra +function derive( + op::typeof(LinearAlgebra.mul!), interface::AbstractArrayInterface, type::Type +) + return quote + function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number) + # TODO: Determine from `a` and `b`. + return $mul!($interface, dest, a, b, α, β) + end + end +end +function derive(op::typeof(LinearAlgebra.mul!), type::Type) + return quote + function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number) + # TODO: Determine from `a` and `b`. + return $mul!( + $AbstractInterface( + $AbstractInterface(dest), $AbstractInterface(a), $AbstractInterface(b) + ), + dest, + a, + b, + α, + β, + ) + end + end +end diff --git a/src/abstractinterface.jl b/src/lib/InterfaceImplementations/abstractinterface.jl similarity index 100% rename from src/abstractinterface.jl rename to src/lib/InterfaceImplementations/abstractinterface.jl diff --git a/src/defaultarrayinterface.jl b/src/lib/InterfaceImplementations/defaultarrayinterface.jl similarity index 100% rename from src/defaultarrayinterface.jl rename to src/lib/InterfaceImplementations/defaultarrayinterface.jl diff --git a/src/lib/InterfaceImplementations/derive.jl b/src/lib/InterfaceImplementations/derive.jl new file mode 100644 index 0000000..2619a5e --- /dev/null +++ b/src/lib/InterfaceImplementations/derive.jl @@ -0,0 +1,29 @@ +macro derive(type, interface, ops) + return esc(derive(__module__, ops, interface, type)) +end + +macro derive(type, ops) + return esc(derive(__module__, ops, type)) +end + +function derive(mod::Module, ops::Union{Symbol,Expr}, type::Union{Symbol,Expr}) + return derive(Base.eval(mod, ops), Base.eval(mod, type)) +end + +function derive( + mod::Module, + ops::Union{Symbol,Expr}, + interface::Union{Symbol,Expr}, + type::Union{Symbol,Expr}, +) + return derive(Base.eval(mod, ops), Base.eval(mod, interface), Base.eval(mod, type)) +end + +function derive(ops::Tuple, args...) + expr = Expr(:block) + for op in ops + subexpr = derive(op, args...) + expr = Expr(:block, expr.args..., subexpr.args...) + end + return expr +end diff --git a/src/lib/InterfaceImplementations/wrappedarrays.jl b/src/lib/InterfaceImplementations/wrappedarrays.jl new file mode 100644 index 0000000..ea87dd5 --- /dev/null +++ b/src/lib/InterfaceImplementations/wrappedarrays.jl @@ -0,0 +1,14 @@ +using Adapt: WrappedArray + +macro wrappedtype(type) + return esc(wrappedtype(type)) +end + +function wrappedtype(type::Symbol) + wrappedtype = Symbol(:Wrapped, type) + anytype = Symbol(:Any, type) + return quote + const $wrappedtype{T,N} = $WrappedArray{T,N,$type,$type{T,N}} + const $anytype{T,N} = Union{$type{T,N},$wrappedtype{T,N}} + end +end diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index 6bb9cbb..f35143c 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -1,6 +1,6 @@ # TODO: Define `AbstractSparseArray`, make this a subtype. struct SparseArrayDOK{T,N} <: AbstractArray{T,N} - storedvalues::Dict{CartesianIndex{N},T} + storage::Dict{CartesianIndex{N},T} size::NTuple{N,Int} end @@ -9,36 +9,46 @@ function SparseArrayDOK{T}(size::Int...) where {T} return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) end -AbstractInterface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() +using .InterfaceImplementations: @wrappedtype +# Define `WrappedSparseArrayDOK` and `AnySparseArrayDOK`. +@wrappedtype SparseArrayDOK -# TODO: Define `WrappedSparseArrayDOK`, `AnySparseArrayDOK`, define macros -# to make those easier to define. -@derive AnyArrayType(SparseArrayDOK) AbstractArrayOps() +using .InterfaceImplementations: InterfaceImplementations +function InterfaceImplementations.AbstractInterface(::Type{<:SparseArrayDOK}) + return SparseArrayInterface() +end + +using .InterfaceImplementations: AbstractArrayOps, @derive +@derive AnySparseArrayDOK AbstractArrayOps() +storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size -storedvalues(a::SparseArrayDOK) = a.storedvalues +storedvalues(a::SparseArrayDOK) = values(storage(a)) function isstored(a::SparseArrayDOK, I::Int...) - return CartesianIndex(I) in keys(storedvalues(a)) + return CartesianIndex(I) in keys(storage(a)) end function eachstoredindex(a::SparseArrayDOK) - return keys(storedvalues(a)) + return keys(storage(a)) end function getstoredindex(a::SparseArrayDOK, I::Int...) - return storedvalues(a)[CartesianIndex(I)] + return storage(a)[CartesianIndex(I)] end function getunstoredindex(a::SparseArrayDOK, I::Int...) return zero(eltype(a)) end function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - storedvalues(a)[CartesianIndex(I)] = value + storage(a)[CartesianIndex(I)] = value return a end function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - storedvalues(a)[CartesianIndex(I)] = value + storage(a)[CartesianIndex(I)] = value return a end +# Optional, but faster than the default. +storedpairs(a::SparseArrayDOK) = storage(a) + using LinearAlgebra: Adjoint storedvalues(a::Adjoint) = storedvalues(parent(a)) function isstored(a::Adjoint, i::Int, j::Int) diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index 157b509..d0beb43 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -1,4 +1,6 @@ # Minimal interface for `SparseArrayInterface`. +# TODO: Define default definitions for these based +# on the dense case. storedvalues(a) = error() isstored(a, I::Int...) = error() eachstoredindex(a) = error() @@ -9,6 +11,7 @@ setunstoredindex!(a, value, I::Int...) = error() # Derived interface. storedlength(a) = length(storedvalues(a)) +storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a)) function eachstoredindex(a1, a2, a_rest...) # TODO: Make this more customizable, say with a function @@ -18,16 +21,17 @@ end # TODO: Add `ndims` type parameter. # TODO: Define `AbstractSparseArrayInterface`, make this a subtype. +using .InterfaceImplementations: AbstractArrayInterface struct SparseArrayInterface <: AbstractArrayInterface end # TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize` # to handle slicing (implemented by copying SubArray). -function getindex(::SparseArrayInterface, a, I::Int...) +function InterfaceImplementations.getindex(::SparseArrayInterface, a, I::Int...) !isstored(a, I...) && return getunstoredindex(a, I...) return getstoredindex(a, I...) end -function setindex!(::SparseArrayInterface, a, value, I::Int...) +function InterfaceImplementations.setindex!(::SparseArrayInterface, a, value, I::Int...) iszero(value) && return a if !isstored(a, I...) setunstoredindex!(a, value, I...) @@ -39,23 +43,25 @@ end # TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK` # is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`. -function similar(::SparseArrayInterface, a, T::Type, size::Tuple{Vararg{Int}}) +function InterfaceImplementations.similar( + ::SparseArrayInterface, a, T::Type, size::Tuple{Vararg{Int}} +) return SparseArrayDOK{T}(size...) end ## TODO: Make this more general, handle mixtures of integers and ranges. ## TODO: Make this logic generic to all `similar(::AbstractInterface, ...)`. -## function similar(interface::SparseArrayInterface, a, T::Type, dims::Tuple{Vararg{Base.OneTo}}) -## return similar(interface, a, T, Base.to_shape(dims)) +## function InterfaceImplementations.similar(interface::SparseArrayInterface, a, T::Type, dims::Tuple{Vararg{Base.OneTo}}) +## return InterfaceImplementations.similar(interface, a, T, Base.to_shape(dims)) ## end -function map(::SparseArrayInterface, f, as...) +function InterfaceImplementations.map(::SparseArrayInterface, f, as...) # This is defined in this way so we can rely on the Broadcast logic # for determining the destination of the operation (element type, shape, etc.). return f.(as...) end -function map!(::SparseArrayInterface, f, dest, as...) +function InterfaceImplementations.map!(::SparseArrayInterface, f, dest, as...) # Check `f` preserves zeros. # Define as `map_stored!`. # Define `eachstoredindex` promotion. @@ -70,20 +76,24 @@ struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() -function BroadcastStyle(::SparseArrayInterface, type::Type) +function InterfaceImplementations.BroadcastStyle(::SparseArrayInterface, type::Type) return SparseArrayStyle{ndims(type)}() end function Base.similar(bc::Broadcast.Broadcasted{<:SparseArrayStyle}, T::Type, axes::Tuple) # TODO: Allow `similar` to accept `axes` directly. - return SparseArraysBaseNext.similar(SparseArrayInterface(), bc, T, Int.(length.(axes))) + return InterfaceImplementations.similar( + SparseArrayInterface(), bc, T, Int.(length.(axes)) + ) end using BroadcastMapConversion: map_function, map_args # TODO: Look into `SparseArrays.capturescalars`: # https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 function Base.copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{<:SparseArrayStyle}) - SparseArraysBaseNext.map!(SparseArrayInterface(), map_function(bc), dest, map_args(bc)...) + InterfaceImplementations.map!( + SparseArrayInterface(), map_function(bc), dest, map_args(bc)... + ) return dest end @@ -93,9 +103,9 @@ abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end struct SparseLayout <: AbstractSparseLayout end -MemoryLayout(::SparseArrayInterface, type::Type) = SparseLayout() +InterfaceImplementations.MemoryLayout(::SparseArrayInterface, type::Type) = SparseLayout() -function mul!(::SparseArrayInterface, a_dest, a1, a2, α, β) +function InterfaceImplementations.mul!(::SparseArrayInterface, a_dest, a1, a2, α, β) return ArrayLayouts.mul!(a_dest, a1, a2, α, β) end diff --git a/src/wrappedarrays.jl b/src/wrappedarrays.jl deleted file mode 100644 index 01f5527..0000000 --- a/src/wrappedarrays.jl +++ /dev/null @@ -1,9 +0,0 @@ -using Adapt: WrappedArray - -function WrappedArrayType(type::Type) - return WrappedArray{<:Any,<:Any,type,type{<:Any,<:Any}} -end - -function AnyArrayType(type::Type) - return Union{type{<:Any,<:Any},WrappedArrayType(type)} -end