-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BlockSparseArrays] More general broadcasting and slicing (#1332)
- Loading branch information
Showing
25 changed files
with
824 additions
and
126 deletions.
There are no files selected for viewing
78 changes: 63 additions & 15 deletions
78
NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,74 @@ | ||
@eval module $(gensym()) | ||
using Compat: Returns | ||
using Test: @test, @testset, @test_broken | ||
using BlockArrays: Block, blocksize | ||
using NDTensors.BlockSparseArrays: BlockSparseArray | ||
using NDTensors.GradedAxes: gradedrange | ||
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored | ||
using NDTensors.GradedAxes: GradedUnitRange, gradedrange | ||
using NDTensors.LabelledNumbers: label | ||
using NDTensors.Sectors: U1 | ||
using NDTensors.SparseArrayInterface: nstored | ||
using NDTensors.TensorAlgebra: fusedims, splitdims | ||
using Random: randn! | ||
function blockdiagonal!(f, a::AbstractArray) | ||
for i in 1:minimum(blocksize(a)) | ||
b = Block(ntuple(Returns(i), ndims(a))) | ||
a[b] = f(a[b]) | ||
end | ||
return a | ||
end | ||
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) | ||
@testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts | ||
d1 = gradedrange([U1(0) => 1, U1(1) => 1]) | ||
d2 = gradedrange([U1(1) => 1, U1(0) => 1]) | ||
a = BlockSparseArray{elt}(d1, d2, d1, d2) | ||
for i in 1:minimum(blocksize(a)) | ||
b = Block(i, i, i, i) | ||
a[b] = randn!(a[b]) | ||
@testset "map" begin | ||
d1 = gradedrange([U1(0) => 2, U1(1) => 2]) | ||
d2 = gradedrange([U1(0) => 2, U1(1) => 2]) | ||
a = BlockSparseArray{elt}(d1, d2, d1, d2) | ||
blockdiagonal!(randn!, a) | ||
|
||
for b in (a + a, 2 * a) | ||
@test size(b) == (4, 4, 4, 4) | ||
@test blocksize(b) == (2, 2, 2, 2) | ||
@test nstored(b) == 32 | ||
@test block_nstored(b) == 2 | ||
# TODO: Have to investigate why this fails | ||
# on Julia v1.6, or drop support for v1.6. | ||
for i in 1:ndims(a) | ||
@test axes(b, i) isa GradedUnitRange | ||
end | ||
@test label(axes(b, 1)[Block(1)]) == U1(0) | ||
@test label(axes(b, 1)[Block(2)]) == U1(1) | ||
@test Array(a) isa Array{elt} | ||
@test Array(a) == a | ||
@test 2 * Array(a) == b | ||
end | ||
|
||
b = a[2:3, 2:3, 2:3, 2:3] | ||
@test size(b) == (2, 2, 2, 2) | ||
@test blocksize(b) == (2, 2, 2, 2) | ||
@test nstored(b) == 2 | ||
@test block_nstored(b) == 2 | ||
for i in 1:ndims(a) | ||
@test axes(b, i) isa GradedUnitRange | ||
end | ||
@test label(axes(b, 1)[Block(1)]) == U1(0) | ||
@test label(axes(b, 1)[Block(2)]) == U1(1) | ||
@test Array(a) isa Array{elt} | ||
@test Array(a) == a | ||
end | ||
# TODO: Add tests for various slicing operations. | ||
@testset "fusedims" begin | ||
d1 = gradedrange([U1(0) => 1, U1(1) => 1]) | ||
d2 = gradedrange([U1(0) => 1, U1(1) => 1]) | ||
a = BlockSparseArray{elt}(d1, d2, d1, d2) | ||
blockdiagonal!(randn!, a) | ||
m = fusedims(a, (1, 2), (3, 4)) | ||
@test axes(m, 1) isa GradedUnitRange | ||
@test axes(m, 2) isa GradedUnitRange | ||
@test a[1, 1, 1, 1] == m[1, 1] | ||
@test a[2, 2, 2, 2] == m[4, 4] | ||
# TODO: Current `fusedims` doesn't merge | ||
# common sectors, need to fix. | ||
@test_broken blocksize(m) == (3, 3) | ||
@test a == splitdims(m, (d1, d2), (d1, d2)) | ||
end | ||
m = fusedims(a, (1, 2), (3, 4)) | ||
@test a[1, 1, 1, 1] == m[2, 2] | ||
@test a[2, 2, 2, 2] == m[3, 3] | ||
# TODO: Current `fusedims` doesn't merge | ||
# common sectors, need to fix. | ||
@test_broken blocksize(m) == (3, 3) | ||
@test a == splitdims(m, (d1, d2), (d1, d2)) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 6 additions & 0 deletions
6
...SparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
using BlockArrays: AbstractBlockArray, BlocksView | ||
using ..SparseArrayInterface: SparseArrayInterface, nstored | ||
|
||
function SparseArrayInterface.nstored(a::AbstractBlockArray) | ||
return sum(b -> nstored(b), blocks(a); init=zero(Int)) | ||
end | ||
|
||
# TODO: Handle `BlocksView` wrapping a sparse array? | ||
function SparseArrayInterface.storage_indices(a::BlocksView) | ||
return CartesianIndices(a) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,48 @@ | ||
using BlockArrays: BlockedUnitRange, BlockSlice | ||
using Base.Broadcast: Broadcast | ||
|
||
function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike}) | ||
return BlockSparseArrayStyle{ndims(arraytype)}() | ||
end | ||
|
||
# Fix ambiguity error with `BlockArrays`. | ||
function Broadcast.BroadcastStyle( | ||
arraytype::Type{ | ||
<:SubArray{ | ||
<:Any, | ||
<:Any, | ||
<:AbstractBlockSparseArray, | ||
<:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}}, | ||
}, | ||
}, | ||
) | ||
return BlockSparseArrayStyle{ndims(arraytype)}() | ||
end | ||
function Broadcast.BroadcastStyle( | ||
arraytype::Type{ | ||
<:SubArray{ | ||
<:Any, | ||
<:Any, | ||
<:AbstractBlockSparseArray, | ||
<:Tuple{ | ||
BlockSlice{<:Any,<:BlockedUnitRange}, | ||
BlockSlice{<:Any,<:BlockedUnitRange}, | ||
Vararg{Any}, | ||
}, | ||
}, | ||
}, | ||
) | ||
return BlockSparseArrayStyle{ndims(arraytype)}() | ||
end | ||
function Broadcast.BroadcastStyle( | ||
arraytype::Type{ | ||
<:SubArray{ | ||
<:Any, | ||
<:Any, | ||
<:AbstractBlockSparseArray, | ||
<:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}}, | ||
}, | ||
}, | ||
) | ||
return BlockSparseArrayStyle{ndims(arraytype)}() | ||
end |
Oops, something went wrong.