Skip to content

Commit

Permalink
[NDTensors] Add DiagonalArrays submodule (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Oct 30, 2023
1 parent 690b219 commit c8664bf
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 12 deletions.
1 change: 0 additions & 1 deletion NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
# TODO: Make this `Zeros`?
## zero = zeros(eltype(block_arr), block_size)
return block_arr.blocks[blks...] # Fails because zero isn't defined
## return get_nonzero(block_arr.blocks, blks, zero)
end

function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N}
Expand Down
13 changes: 2 additions & 11 deletions NDTensors/src/BlockSparseArrays/src/sparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# TODO: Define a constructor with a default `zero`.
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
data::Dictionary{CartesianIndex{N},T}
dims::NTuple{N,Int64}
dims::NTuple{N,Int}
zero::Zero
end

Expand All @@ -20,13 +21,3 @@ end
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
return getindex(a, CartesianIndex(I))
end

## # `getindex` but uses a default if the value is
## # structurally zero.
## function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N}
## @boundscheck checkbounds(a, I)
## return get(a.data, I, zero)
## end
## function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N}
## return get_nonzero(a, CartesianIndex(I), zero)
## end
49 changes: 49 additions & 0 deletions NDTensors/src/DiagonalArrays/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# DiagonalArrays.jl

A Julia `DiagonalArray` type.

````julia
using NDTensors.DiagonalArrays:
DiagonalArray,
densearray,
diagview,
diaglength,
getdiagindex,
setdiagindex!,
setdiag!,
diagcopyto!

d = DiagonalArray([1., 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
@show d[2, 2, 2] == 2
@show d[1, 2, 1] == 0

d[2, 2, 2] = 22
@show d[2, 2, 2] == 22

@show diaglength(d) == 3
@show densearray(d) == d
@show getdiagindex(d, 2) == d[2, 2, 2]

setdiagindex!(d, 222, 2)
@show d[2, 2, 2] == 222

a = randn(3, 4, 5)
new_diag = randn(3)
setdiag!(a, new_diag)
diagcopyto!(d, a)

@show diagview(a) == new_diag
@show diagview(d) == new_diag
````

You can generate this README with:
```julia
using Literate
Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
```

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

42 changes: 42 additions & 0 deletions NDTensors/src/DiagonalArrays/examples/README.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# # DiagonalArrays.jl
#
# A Julia `DiagonalArray` type.

using NDTensors.DiagonalArrays:
DiagonalArray,
densearray,
diagview,
diaglength,
getdiagindex,
setdiagindex!,
setdiag!,
diagcopyto!

d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
@show d[2, 2, 2] == 2
@show d[1, 2, 1] == 0

d[2, 2, 2] = 22
@show d[2, 2, 2] == 22

@show diaglength(d) == 3
@show densearray(d) == d
@show getdiagindex(d, 2) == d[2, 2, 2]

setdiagindex!(d, 222, 2)
@show d[2, 2, 2] == 222

a = randn(3, 4, 5)
new_diag = randn(3)
setdiag!(a, new_diag)
diagcopyto!(d, a)

@show diagview(a) == new_diag
@show diagview(d) == new_diag

# You can generate this README with:
# ```julia
# using Literate
# Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
# ```
110 changes: 110 additions & 0 deletions NDTensors/src/DiagonalArrays/src/DiagonalArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
module DiagonalArrays

using Compat # allequal
using LinearAlgebra

export DiagonalArray

include("diagview.jl")

struct DefaultZero end

function (::DefaultZero)(eltype::Type, I::CartesianIndex)
return zero(eltype)
end

struct DiagonalArray{T,N,Diag<:AbstractVector{T},Zero} <: AbstractArray{T,N}
diag::Diag
dims::NTuple{N,Int}
zero::Zero
end

function DiagonalArray{T,N}(
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
return DiagonalArray{T,N,typeof(diag),typeof(zero)}(diag, d, zero)
end

function DiagonalArray{T,N}(
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
return DiagonalArray{T,N}(T.(diag), d, zero)
end

function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray{T}(
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
return DiagonalArray{T,N}(diag, d, zero)
end

function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray(diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray(diag::AbstractVector{T}, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

# undef
function DiagonalArray{T,N}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d)
end

function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

function DiagonalArray{T}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

Base.size(a::DiagonalArray) = a.dims

diagview(a::DiagonalArray) = a.diag
LinearAlgebra.diag(a::DiagonalArray) = copy(diagview(a))

function Base.getindex(a::DiagonalArray{T,N}, I::CartesianIndex{N}) where {T,N}
i = diagindex(a, I)
isnothing(i) && return a.zero(T, I)
return getdiagindex(a, i)
end

function Base.getindex(a::DiagonalArray{T,N}, I::Vararg{Int,N}) where {T,N}
return getindex(a, CartesianIndex(I))
end

function Base.setindex!(a::DiagonalArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
i = diagindex(a, I)
isnothing(i) && return error("Can't set off-diagonal element of DiagonalArray")
setdiagindex!(a, v, i)
return a
end

function Base.setindex!(a::DiagonalArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
a[CartesianIndex(I)] = v
return a
end

# Make dense.
function densearray(a::DiagonalArray)
# TODO: Check this works on GPU.
# TODO: Make use of `a.zero`?
d = similar(diagview(a), size(a))
fill!(d, zero(eltype(a)))
diagcopyto!(d, a)
return d
end

end
54 changes: 54 additions & 0 deletions NDTensors/src/DiagonalArrays/src/diagview.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Convert to an offset along the diagonal.
# Otherwise, return `nothing`.
function diagindex(a::AbstractArray{T,N}, I::CartesianIndex{N}) where {T,N}
!allequal(Tuple(I)) && return nothing
return first(Tuple(I))
end

function diagindex(a::AbstractArray{T,N}, I::Vararg{Int,N}) where {T,N}
return diagindex(a, CartesianIndex(I))
end

function getdiagindex(a::AbstractArray, i::Integer)
return diagview(a)[i]
end

function setdiagindex!(a::AbstractArray, v, i::Integer)
diagview(a)[i] = v
return a
end

function setdiag!(a::AbstractArray, v)
copyto!(diagview(a), v)
return a
end

function diaglength(a::AbstractArray)
# length(diagview(a))
return minimum(size(a))
end

function diagstride(A::AbstractArray)
s = 1
p = 1
for i in 1:(ndims(A) - 1)
p *= size(A, i)
s += p
end
return s
end

function diagindices(A::AbstractArray)
diaglength = minimum(size(A))
maxdiag = LinearIndices(A)[CartesianIndex(ntuple(Returns(diaglength), ndims(A)))]
return 1:diagstride(A):maxdiag
end

function diagview(A::AbstractArray)
return @view A[diagindices(A)]
end

function diagcopyto!(dest::AbstractArray, src::AbstractArray)
copyto!(diagview(dest), diagview(src))
return dest
end
10 changes: 10 additions & 0 deletions NDTensors/src/DiagonalArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Test
using NDTensors.DiagonalArrays

@testset "Test NDTensors.DiagonalArrays" begin
@testset "README" begin
@test include(
joinpath(pkgdir(DiagonalArrays), "src", "DiagonalArrays", "examples", "README.jl")
) isa Any
end
end
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("DiagonalArrays/src/DiagonalArrays.jl")
using .DiagonalArrays
include("BlockSparseArrays/src/BlockSparseArrays.jl")
using .BlockSparseArrays
include("SmallVectors/src/SmallVectors.jl")
Expand Down
4 changes: 4 additions & 0 deletions NDTensors/test/DiagonalArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
using Test
using NDTensors

include(joinpath(pkgdir(NDTensors), "src", "DiagonalArrays", "test", "runtests.jl"))
1 change: 1 addition & 0 deletions NDTensors/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ end
@safetestset "NDTensors" begin
@testset "$filename" for filename in [
"BlockSparseArrays.jl",
"DiagonalArrays.jl",
"SetParameters.jl",
"SmallVectors.jl",
"SortedSets.jl",
Expand Down

0 comments on commit c8664bf

Please sign in to comment.