Skip to content

Commit

Permalink
Rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 2, 2024
1 parent a414f00 commit b079fd2
Show file tree
Hide file tree
Showing 14 changed files with 288 additions and 170 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ julia> Pkg.add(url="https://github.com/ITensor/SparseArraysBaseNext.jl")

````julia
using SparseArraysBaseNext:
SparseArrayDOK, eachstoredindex, isstored, storedlength, storedvalues
SparseArrayDOK, eachstoredindex, isstored, storedlength, storedpairs, storedvalues
using Test: @test

a = SparseArrayDOK{Float64}(2, 2)
Expand All @@ -35,9 +35,10 @@ b = a .+ 2 .* a'
@test b[2, 1] == 21 + 2 * 12
@test b[1, 2] == 12 + 2 * 21
@test b[2, 2] == 0
@test storedvalues(b) ==
Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21)
@test issetequal(storedvalues(b), [21 + 2 * 12, 12 + 2 * 21])
@test issetequal(eachstoredindex(b), [CartesianIndex(2, 1), CartesianIndex(1, 2)])
@test storedpairs(b) ==
Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21)
@test !isstored(b, 1, 1)
@test isstored(b, 2, 1)
@test isstored(b, 1, 2)
Expand Down
7 changes: 4 additions & 3 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ julia> Pkg.add(url="https://github.com/ITensor/SparseArraysBaseNext.jl")
# ## Examples

using SparseArraysBaseNext:
SparseArrayDOK, eachstoredindex, isstored, storedlength, storedvalues
SparseArrayDOK, eachstoredindex, isstored, storedlength, storedpairs, storedvalues
using Test: @test

a = SparseArrayDOK{Float64}(2, 2)
Expand All @@ -36,9 +36,10 @@ b = a .+ 2 .* a'
@test b[2, 1] == 21 + 2 * 12
@test b[1, 2] == 12 + 2 * 21
@test b[2, 2] == 0
@test storedvalues(b) ==
Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21)
@test issetequal(storedvalues(b), [21 + 2 * 12, 12 + 2 * 21])
@test issetequal(eachstoredindex(b), [CartesianIndex(2, 1), CartesianIndex(1, 2)])
@test storedpairs(b) ==
Dict(CartesianIndex(2, 1) => 21 + 2 * 12, CartesianIndex(1, 2) => 12 + 2 * 21)
@test !isstored(b, 1, 1)
@test isstored(b, 2, 1)
@test isstored(b, 1, 2)
Expand Down
7 changes: 2 additions & 5 deletions src/SparseArraysBaseNext.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
module SparseArraysBaseNext

include("derive.jl")
include("wrappedarrays.jl")
include("abstractinterface.jl")
include("abstractarrayinterface.jl")
include("defaultarrayinterface.jl")
include("lib/InterfaceImplementations/InterfaceImplementations.jl")
using .InterfaceImplementations: InterfaceImplementations
include("sparsearrayinterface.jl")
include("sparsearraydok.jl")

Expand Down
111 changes: 0 additions & 111 deletions src/abstractarrayinterface.jl

This file was deleted.

16 changes: 0 additions & 16 deletions src/derive.jl

This file was deleted.

9 changes: 9 additions & 0 deletions src/lib/InterfaceImplementations/InterfaceImplementations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module InterfaceImplementations

include("abstractinterface.jl")
include("derive.jl")
include("wrappedarrays.jl")
include("abstractarrayinterface.jl")
include("defaultarrayinterface.jl")

end
183 changes: 183 additions & 0 deletions src/lib/InterfaceImplementations/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Like the equivalent `Base` functions but allow overloading
# by interface in the first argument.
# Overloading `Base` directly leads to too many method ambiguities.
# TODO: Define generic fallbacks that use `invoke`.
function getindex end
function setindex! end
function similar end
function map end
function map! end
function BroadcastStyle end
function MemoryLayout end
function mul! end

# TODO: Add `ndims` type parameter.
abstract type AbstractArrayInterface <: AbstractInterface end

function AbstractArrayOps()
return (
Base.getindex,
Base.setindex!,
Base.similar,
Base.map,
Base.map!,
Base.Broadcast.BroadcastStyle,
ArrayLayouts.MemoryLayout,
LinearAlgebra.mul!,
)
end

function derive(op::typeof(Base.getindex), interface::AbstractArrayInterface, type::Type)
return quote
# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize`
# to handle slicing (implemented by copying SubArray).
function Base.getindex(a::$type, I::Int...)
return $getindex($interface, a, I...)
end
end
end
function derive(op::typeof(Base.getindex), type::Type)
return quote
function Base.getindex(a::$type, I::Int...)
return $getindex($AbstractInterface(a), a, I...)
end
end
end

function derive(op::typeof(Base.setindex!), interface::AbstractArrayInterface, type::Type)
return quote
function Base.setindex!(a::$type, value, I::Int...)
$setindex!($interface, a, value, I...)
return a
end
end
end
function derive(op::typeof(Base.setindex!), type::Type)
return quote
function Base.setindex!(a::$type, value, I::Int...)
$setindex!($AbstractInterface(a), a, value, I...)
return a
end
end
end

function derive(op::typeof(Base.similar), interface::AbstractArrayInterface, type::Type)
return quote
# TODO: Generalize to axes.
function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}})
return $similar($interface, a, T, I)
end
end
end
function derive(op::typeof(Base.similar), type::Type)
return quote
# TODO: Generalize to axes.
function Base.similar(a::$type, T::Type, I::Tuple{Vararg{Int}})
return $similar($AbstractInterface(a), a, T, I)
end
end
end

function derive(op::typeof(Base.map), interface::AbstractArrayInterface, type::Type)
return quote
function Base.map(f, as::$type...)
# TODO: Define interface promotion, maybe just use `BroadcastStyle` directly.
# https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules
return $map($interface, f, as...)
end
end
end
function derive(op::typeof(Base.map), type::Type)
return quote
function Base.map(f, as::$type...)
# TODO: Define interface promotion, maybe just use `BroadcastStyle` directly.
# https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules
return $map($AbstractInterface(AbstractInterface.(as)...), f, as...)
end
end
end

function derive(op::typeof(Base.map!), interface::AbstractArrayInterface, type::Type)
return quote
function Base.map!(f, dest::$type, as::AbstractArray...)
# TODO: Define interface promotion, maybe just use `BroadcastStyle` directly.
# https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules
$map!($interface, f, dest, as...)
return dest
end
end
end
function derive(op::typeof(Base.map!), type::Type)
return quote
function Base.map!(f, dest::$type, as::AbstractArray...)
# TODO: Define interface promotion, maybe just use `BroadcastStyle` directly.
# https://docs.julialang.org/en/v1/manual/interfaces/#writing-binary-broadcasting-rules
$map!($AbstractInterface(AbstractInterface.(as)...), f, dest, as...)
return dest
end
end
end

function derive(
op::Type{Base.Broadcast.BroadcastStyle}, interface::AbstractArrayInterface, type::Type
)
return quote
function Base.Broadcast.BroadcastStyle(type::Type{<:$type})
return $BroadcastStyle($interface, type)
end
end
end
function derive(op::Type{Base.Broadcast.BroadcastStyle}, type::Type)
return quote
function Base.Broadcast.BroadcastStyle(type::Type{<:$type})
return $BroadcastStyle($AbstractInterface(type), type)
end
end
end

using ArrayLayouts: ArrayLayouts
function derive(
op::Type{ArrayLayouts.MemoryLayout}, interface::AbstractArrayInterface, type::Type
)
return quote
function $ArrayLayouts.MemoryLayout(type::Type{<:$type})
return $MemoryLayout($interface, type)
end
end
end
function derive(op::Type{ArrayLayouts.MemoryLayout}, type::Type)
return quote
function $ArrayLayouts.MemoryLayout(type::Type{<:$type})
return $MemoryLayout($AbstractInterface(type), type)
end
end
end

using LinearAlgebra: LinearAlgebra
function derive(
op::typeof(LinearAlgebra.mul!), interface::AbstractArrayInterface, type::Type
)
return quote
function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number)
# TODO: Determine from `a` and `b`.
return $mul!($interface, dest, a, b, α, β)
end
end
end
function derive(op::typeof(LinearAlgebra.mul!), type::Type)
return quote
function $LinearAlgebra.mul!(dest::$type, a::$type, b::$type, α::Number, β::Number)
# TODO: Determine from `a` and `b`.
return $mul!(
$AbstractInterface(
$AbstractInterface(dest), $AbstractInterface(a), $AbstractInterface(b)
),
dest,
a,
b,
α,
β,
)
end
end
end
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit b079fd2

Please sign in to comment.