Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 14, 2024
1 parent 67c93d9 commit f434eb9
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
using ArrayLayouts: ArrayLayouts
using ArrayLayouts: ArrayLayouts, MatMulMatAdd

function ArrayLayouts.MemoryLayout(arraytype::Type{<:SparseArrayLike})
return SparseLayout()
end

function ArrayLayouts.materialize!(
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
)
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
sparse_mul!(a_dest, a1, a2, α, β)
return a_dest
end
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ end
# where there is not a stored value.
# Some types (like `Diagonal`) may not support this.
function setindex_notstored!(a::AbstractArray, value, I)
iszero(value) && return a
return throw(ArgumentError("Can't set nonzero values of $(typeof(a))."))
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module AbstractSparseArrays
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout, MulAdd
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd

struct SparseArray{T,N} <: AbstractSparseArray{T,N}
data::Vector{T}
Expand All @@ -26,6 +26,13 @@ ArrayLayouts.MemoryLayout(::Type{<:SparseArray}) = SparseLayout()
function Base.similar(::MulAdd{<:SparseLayout,<:SparseLayout}, elt::Type, axes)
return similar(SparseArray{elt}, axes)
end
function ArrayLayouts.materialize!(
m::MatMulMatAdd{<:SparseLayout,<:SparseLayout,<:SparseLayout}
)
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
SparseArrayInterface.sparse_mul!(a_dest, a1, a2, α, β)
return a_dest
end

# AbstractArray interface
Base.size(a::SparseArray) = a.dims
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module SparseArrays
using LinearAlgebra: LinearAlgebra
using NDTensors.SparseArrayInterface: SparseArrayInterface

struct SparseArray{T,N} <: AbstractArray{T,N}
Expand All @@ -19,6 +20,18 @@ function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
end
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)

# LinearAlgebra interface
function LinearAlgebra.mul!(
a_dest::AbstractMatrix,
a1::SparseArray{<:Any,2},
a2::SparseArray{<:Any,2},
α::Number,
β::Number,
)
SparseArrayInterface.sparse_mul!(a_dest, a1, a2, α, β)
return a_dest
end

# AbstractArray interface
Base.size(a::SparseArray) = a.dims
function Base.similar(a::SparseArray, elt::Type, dims::Tuple{Vararg{Int}})
Expand All @@ -28,8 +41,11 @@ end
function Base.getindex(a::SparseArray, I...)
return SparseArrayInterface.sparse_getindex(a, I...)
end
function Base.setindex!(a::SparseArray, I...)
return SparseArrayInterface.sparse_setindex!(a, I...)
function Base.setindex!(a::SparseArray, value, I...)
return SparseArrayInterface.sparse_setindex!(a, value, I...)
end
function Base.fill!(a::SparseArray, value)
return SparseArrayInterface.sparse_fill!(a, value)
end

# Minimal interface
Expand Down

0 comments on commit f434eb9

Please sign in to comment.