Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BlockSparseArrays] Redesign block views again #1513

Merged
merged 4 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.37"
version = "0.3.38"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,33 @@ function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks)
end
end

# Represents a view of a block of a blocked array.
struct BlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
block::Tuple{Vararg{Block{1,Int},N}}
end
function Base.axes(a::BlockView)
# TODO: Try to avoid conversion to `Base.OneTo{Int}`, or just convert
# the element type to `Int` with `Int.(...)`.
# When the axes of `a.array` are `GradedOneTo`, the block is `LabelledUnitRange`,
# which has element type `LabelledInteger`. That causes conversion problems
# in some generic Base Julia code, for example when printing `BlockView`.
return ntuple(ndims(a)) do dim
return Base.OneTo{Int}(only(axes(axes(a.array, dim)[a.block[dim]])))
end
end
function Base.size(a::BlockView)
return length.(axes(a))
end
function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
return blocks(a.array)[Int.(a.block)...][index...]
end
function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) where {N}
blocks(a.array)[Int.(a.block)...] = blocks(a.array)[Int.(a.block)...]
blocks(a.array)[Int.(a.block)...][index...] = value
return a
end

function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ function Base.setindex!(
blocksparse_setindex!(a, value, I...)
return a
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
blocksize = ntuple(dim -> length(axes(a, dim)[I[dim]]), N)
if size(value) ≠ blocksize
throw(
DimensionMismatch(
"Trying to set block $(Block(Int.(I)...)), which has a size $blocksize, with data of size $(size(value)).",
),
)
end
blocks(a)[Int.(I)...] = value
return a
end
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using BlockArrays: Block, BlockSlices
using BlockArrays: BlockArrays, Block, BlockSlices, viewblock

function blocksparse_view(a, I...)
return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...)
Expand All @@ -22,3 +22,19 @@ function Base.view(
)
return blocksparse_view(a, I)
end

# Specialized code for getting the view of a block.
function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, block::Block{N}
) where {N}
return viewblock(a, Tuple(block)...)
end
function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
) where {N}
I = CartesianIndex(Int.(block))
if I ∈ stored_indices(blocks(a))
return blocks(a)[I]
end
return BlockView(a, block)
end
12 changes: 6 additions & 6 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using BlockArrays:
using Compat: @compat
using LinearAlgebra: mul!
using NDTensors.BlockSparseArrays:
@view!, BlockSparseArray, block_nstored, block_reshape, view!
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
Expand Down Expand Up @@ -362,10 +362,10 @@ include("TestBlockSparseArraysUtils.jl")
b = @view a[Block(2, 2)]
@test size(b) == (3, 4)
for i in parentindices(b)
@test i isa BlockSlice{<:Block{1}}
@test i isa Base.OneTo{Int}
end
@test parentindices(b)[1] == BlockSlice(Block(2), 3:5)
@test parentindices(b)[2] == BlockSlice(Block(2), 4:7)
@test parentindices(b)[1] == 1:3
@test parentindices(b)[2] == 1:4

a = BlockSparseArray{elt}([2, 3], [3, 4])
b = @view a[Block(2, 2)[1:2, 2:2]]
Expand All @@ -392,9 +392,9 @@ include("TestBlockSparseArraysUtils.jl")

a = BlockSparseArray{elt}([2, 3], [3, 4])
b = @views a[Block(2, 2)][1:2, 2:3]
@test b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
@test b isa SubArray{<:Any,<:Any,<:BlockView}
for i in parentindices(b)
@test i isa BlockSlice{<:BlockIndexRange{1}}
@test i isa UnitRange{Int}
end
x = randn(elt, 2, 2)
b .= x
Expand Down
15 changes: 15 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/unitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ end
using NDTensors.LabelledNumbers: LabelledNumbers, label
LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a)))

using NDTensors.LabelledNumbers: LabelledUnitRange
# The Base version of `length(::AbstractUnitRange)` drops the label.
function Base.length(a::UnitRangeDual{<:Any,<:LabelledUnitRange})
return dual(length(nondual(a)))
end
function Base.iterate(a::UnitRangeDual, i)
i == last(a) && return nothing
return dual.(iterate(nondual(a), i))
end
# TODO: Is this a good definition?
Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a

using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))

using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock
BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a))
BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
Expand Down
3 changes: 3 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer}
return OrdinalRange{T,T}(unlabel(a))
end

# TODO: Is this a good definition?
Base.unitrange(a::LabelledUnitRange) = a

for f in [:first, :getindex, :last, :length, :step]
@eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a))
end
Expand Down
Loading