From 2985e9baa203c25cd92fbc5ee91aa79eb65d87d2 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 27 Jun 2024 16:03:11 -0400 Subject: [PATCH] [BlockSparseArrays] Redesign block views again (#1513) * [BlockSparseArrays] Redesign block views again * [NDTensors] Bump to v0.3.38 --- NDTensors/Project.toml | 2 +- .../BlockArraysExtensions.jl | 27 +++++++++++++++++++ .../abstractblocksparsearray.jl | 15 +++++++++++ .../src/abstractblocksparsearray/views.jl | 18 ++++++++++++- .../lib/BlockSparseArrays/test/test_basics.jl | 12 ++++----- .../src/lib/GradedAxes/src/unitrangedual.jl | 15 +++++++++++ .../LabelledNumbers/src/labelledunitrange.jl | 3 +++ 7 files changed, 84 insertions(+), 8 deletions(-) diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 82ac74c82b..b831edc8a9 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.37" +version = "0.3.38" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 76ff94eb1b..499fd42089 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -396,6 +396,33 @@ function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks) end end +# Represents a view of a block of a blocked array. +struct BlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N} + array::Array + block::Tuple{Vararg{Block{1,Int},N}} +end +function Base.axes(a::BlockView) + # TODO: Try to avoid conversion to `Base.OneTo{Int}`, or just convert + # the element type to `Int` with `Int.(...)`. + # When the axes of `a.array` are `GradedOneTo`, the block is `LabelledUnitRange`, + # which has element type `LabelledInteger`. That causes conversion problems + # in some generic Base Julia code, for example when printing `BlockView`. + return ntuple(ndims(a)) do dim + return Base.OneTo{Int}(only(axes(axes(a.array, dim)[a.block[dim]]))) + end +end +function Base.size(a::BlockView) + return length.(axes(a)) +end +function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N} + return blocks(a.array)[Int.(a.block)...][index...] +end +function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) where {N} + blocks(a.array)[Int.(a.block)...] = blocks(a.array)[Int.(a.block)...] + blocks(a.array)[Int.(a.block)...][index...] = value + return a +end + function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N} return view!(a, Tuple(index)...) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl index 40c15b7d05..1cf1b21cdc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -42,3 +42,18 @@ function Base.setindex!( blocksparse_setindex!(a, value, I...) return a end + +function Base.setindex!( + a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N} +) where {N} + blocksize = ntuple(dim -> length(axes(a, dim)[I[dim]]), N) + if size(value) ≠ blocksize + throw( + DimensionMismatch( + "Trying to set block $(Block(Int.(I)...)), which has a size $blocksize, with data of size $(size(value)).", + ), + ) + end + blocks(a)[Int.(I)...] = value + return a +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index c09ac27890..456bb81827 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,4 +1,4 @@ -using BlockArrays: Block, BlockSlices +using BlockArrays: BlockArrays, Block, BlockSlices, viewblock function blocksparse_view(a, I...) return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...) @@ -22,3 +22,19 @@ function Base.view( ) return blocksparse_view(a, I) end + +# Specialized code for getting the view of a block. +function BlockArrays.viewblock( + a::AbstractBlockSparseArray{<:Any,N}, block::Block{N} +) where {N} + return viewblock(a, Tuple(block)...) +end +function BlockArrays.viewblock( + a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N} +) where {N} + I = CartesianIndex(Int.(block)) + if I ∈ stored_indices(blocks(a)) + return blocks(a)[I] + end + return BlockView(a, block) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index d11fcf7f9c..05799afebc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -17,7 +17,7 @@ using BlockArrays: using Compat: @compat using LinearAlgebra: mul! using NDTensors.BlockSparseArrays: - @view!, BlockSparseArray, block_nstored, block_reshape, view! + @view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view! using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: contract using Test: @test, @test_broken, @test_throws, @testset @@ -362,10 +362,10 @@ include("TestBlockSparseArraysUtils.jl") b = @view a[Block(2, 2)] @test size(b) == (3, 4) for i in parentindices(b) - @test i isa BlockSlice{<:Block{1}} + @test i isa Base.OneTo{Int} end - @test parentindices(b)[1] == BlockSlice(Block(2), 3:5) - @test parentindices(b)[2] == BlockSlice(Block(2), 4:7) + @test parentindices(b)[1] == 1:3 + @test parentindices(b)[2] == 1:4 a = BlockSparseArray{elt}([2, 3], [3, 4]) b = @view a[Block(2, 2)[1:2, 2:2]] @@ -392,9 +392,9 @@ include("TestBlockSparseArraysUtils.jl") a = BlockSparseArray{elt}([2, 3], [3, 4]) b = @views a[Block(2, 2)][1:2, 2:3] - @test b isa SubArray{<:Any,<:Any,<:BlockSparseArray} + @test b isa SubArray{<:Any,<:Any,<:BlockView} for i in parentindices(b) - @test i isa BlockSlice{<:BlockIndexRange{1}} + @test i isa UnitRange{Int} end x = randn(elt, 2, 2) b .= x diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 358542856c..495c90a239 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -68,6 +68,21 @@ end using NDTensors.LabelledNumbers: LabelledNumbers, label LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a))) +using NDTensors.LabelledNumbers: LabelledUnitRange +# The Base version of `length(::AbstractUnitRange)` drops the label. +function Base.length(a::UnitRangeDual{<:Any,<:LabelledUnitRange}) + return dual(length(nondual(a))) +end +function Base.iterate(a::UnitRangeDual, i) + i == last(a) && return nothing + return dual.(iterate(nondual(a), i)) +end +# TODO: Is this a good definition? +Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a + +using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel +dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) + using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a)) BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a))) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index 62a0ddebdf..2e4379daba 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -28,6 +28,9 @@ function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer} return OrdinalRange{T,T}(unlabel(a)) end +# TODO: Is this a good definition? +Base.unitrange(a::LabelledUnitRange) = a + for f in [:first, :getindex, :last, :length, :step] @eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a)) end