From 3fcb2e6003303dc0f8b903a411a9b22a12230137 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 30 May 2024 20:44:30 -0400 Subject: [PATCH 01/10] Format --- .../src/abstractblocksparsearray/arraylayouts.jl | 1 - .../src/abstractblocksparsearray/linearalgebra.jl | 12 ------------ .../src/blocksparsearrayinterface/arraylayouts.jl | 9 ++++++++- 3 files changed, 8 insertions(+), 14 deletions(-) delete mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/linearalgebra.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl index d8e79ba743..c9cbf33cdc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl @@ -1,7 +1,6 @@ using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd using BlockArrays: BlockLayout using ..SparseArrayInterface: SparseLayout -using LinearAlgebra: mul! function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike}) outer_layout = typeof(MemoryLayout(blockstype(arraytype))) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/linearalgebra.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/linearalgebra.jl deleted file mode 100644 index 47914daaf8..0000000000 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/linearalgebra.jl +++ /dev/null @@ -1,12 +0,0 @@ -using LinearAlgebra: LinearAlgebra, mul! - -function LinearAlgebra.mul!( - a_dest::AbstractMatrix, - a1::AbstractBlockSparseMatrix, - a2::AbstractBlockSparseMatrix, - α::Number=true, - β::Number=false, -) - mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β) - return a_dest -end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl index bf4d515a34..7543ad434d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl @@ -3,6 +3,13 @@ using BlockArrays: BlockLayout using ..SparseArrayInterface: SparseLayout using LinearAlgebra: mul! 
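The first patch reroutes block sparse matrix multiplication from a dedicated `LinearAlgebra.mul!` overload onto the `ArrayLayouts.materialize!` path: the hunk that continues below adds a `blocksparse_muladd!` helper whose body is `mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)`, so the product is accumulated block-by-block over the sparse arrays of blocks rather than entry-by-entry. As a rough, dense reference for what such a block-level multiply-add computes — `naive_block_muladd!` is a hypothetical name and this loop is not the package's sparse-aware implementation, which avoids touching structurally zero blocks — a sketch:

# Illustrative only: `A`, `B`, `C` stand in for `blocks(a1)`, `blocks(a2)`,
# `blocks(a_dest)` and here are plain matrices whose entries are themselves matrices.
function naive_block_muladd!(C, A, B, α::Number, β::Number)
  for i in axes(C, 1), j in axes(C, 2)
    Cij = β * C[i, j]                  # scale the destination block
    for k in axes(A, 2)
      Cij += α * (A[i, k] * B[k, j])   # accumulate block-by-block products
    end
    C[i, j] = Cij
  end
  return C
end

# Example: 2×2 grids of 2×2 blocks, C = A*B (α = true, β = false).
# A = [randn(2, 2) for _ in 1:2, _ in 1:2]
# B = [randn(2, 2) for _ in 1:2, _ in 1:2]
# C = [zeros(2, 2) for _ in 1:2, _ in 1:2]
# naive_block_muladd!(C, A, B, true, false)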
+function blocksparse_muladd!( + α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix +) + mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β) + return a_dest +end + function ArrayLayouts.materialize!( m::MatMulMatAdd{ <:BlockLayout{<:SparseLayout}, @@ -11,6 +18,6 @@ function ArrayLayouts.materialize!( }, ) α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C - mul!(a_dest, a1, a2, α, β) + blocksparse_muladd!(α, a1, a2, β, a_dest) return a_dest end From 3dbaa8623fa7d4963a1420411533025361e5d21b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 11:04:59 -0400 Subject: [PATCH 02/10] Remove include for deleted file --- NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index d0430732fb..514b73cee7 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -1,6 +1,5 @@ module BlockSparseArrays include("blocksparsearrayinterface/blocksparsearrayinterface.jl") -include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/arraylayouts.jl") From f7bdaf9812414e7e1c9060b021e85c86bb883af2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 11:07:32 -0400 Subject: [PATCH 03/10] Fix includes --- NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index 514b73cee7..09064999e4 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -1,5 +1,6 @@ module BlockSparseArrays include("blocksparsearrayinterface/blocksparsearrayinterface.jl") +include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/arraylayouts.jl") @@ -10,7 +11,6 @@ include("abstractblocksparsearray/abstractblocksparsevector.jl") include("abstractblocksparsearray/view.jl") include("abstractblocksparsearray/arraylayouts.jl") include("abstractblocksparsearray/sparsearrayinterface.jl") -include("abstractblocksparsearray/linearalgebra.jl") include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") include("blocksparsearray/defaults.jl") From 49e7019f2bd2a7633a8d3503e1e643a75fbbec40 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 11:26:00 -0400 Subject: [PATCH 04/10] Fix transposed matmul --- .../blocksparsearrayinterface.jl | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 56e6e7c96b..b93fb5b9b2 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -141,6 +141,15 @@ end function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2}) return 
transpose(blocks(parent(a.array))[reverse(index)...]) end +# TODO: This should be handled by generic `AbstractSparseArray` code. +function Base.getindex(a::SparseTransposeBlocks, index::CartesianIndex{2}) + return a[Tuple(index)...] +end +# TODO: Create a generic `parent_index` function to map an index +# a parent index. +function Base.isassigned(a::SparseTransposeBlocks, index::Vararg{Int,2}) + return isassigned(blocks(parent(a.array)), reverse(index)...) +end function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks) return map(reverse_index, stored_indices(blocks(parent(a.array)))) end @@ -163,9 +172,22 @@ end function Base.size(a::SparseAdjointBlocks) return reverse(size(blocks(parent(a.array)))) end +# TODO: Create a generic `parent_index` function to map an index +# a parent index. function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2}) return blocks(parent(a.array))[reverse(index)...]' end +# TODO: Create a generic `parent_index` function to map an index +# a parent index. +# TODO: This should be handled by generic `AbstractSparseArray` code. +function Base.getindex(a::SparseAdjointBlocks, index::CartesianIndex{2}) + return a[Tuple(index)...] +end +# TODO: Create a generic `parent_index` function to map an index +# a parent index. +function Base.isassigned(a::SparseAdjointBlocks, index::Vararg{Int,2}) + return isassigned(blocks(parent(a.array)), reverse(index)...) +end function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks) return map(reverse_index, stored_indices(blocks(parent(a.array)))) end @@ -229,9 +251,6 @@ end function Base.size(a::SparseSubArrayBlocks) return length.(axes(a)) end -function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} - return a[Tuple(I)...] -end function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} parent_blocks = @view blocks(parent(a.array))[blockrange(a)...] parent_block = parent_blocks[I...] @@ -239,6 +258,10 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a))) return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...] end +# TODO: This should be handled by generic `AbstractSparseArray` code. +function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} + return a[Tuple(I)...] +end function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N} parent_blocks = view(blocks(parent(a.array)), axes(a)...) return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] 
= From 80708b12ab5fde6c4af2fcd60e1ff9d06ba360bf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 11:45:01 -0400 Subject: [PATCH 05/10] Dual axes with adjoint --- .../src/BlockSparseArraysGradedAxesExt.jl | 57 ++++++++++++++++--- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl index a5752f35d8..9976a034b6 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl @@ -1,15 +1,21 @@ module BlockSparseArraysGradedAxesExt using BlockArrays: AbstractBlockVector, Block, BlockedUnitRange, blocks using ..BlockSparseArrays: - BlockSparseArrays, AbstractBlockSparseArray, BlockSparseArray, block_merge + BlockSparseArrays, + AbstractBlockSparseArray, + BlockSparseArray, + BlockSparseMatrix, + block_merge using ...GradedAxes: GradedUnitRange, OneToOne, blockmergesortperm, blocksortperm, + dual, invblockperm, nondual, tensor_product +using LinearAlgebra: Adjoint, Transpose using ...TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims @@ -61,19 +67,56 @@ function Base.eachindex(a::AbstractBlockSparseArray) return CartesianIndices(nondual.(axes(a))) end +function Base.adjoint(a::BlockSparseMatrix) + return Adjoint(BlockSparseArray(blocks(a), dual.(axes(a)))) +end + # This is a temporary fix for `show` being broken for BlockSparseArrays # with mixed dual and non-dual axes. This shouldn't be needed once # GradedAxes is rewritten using BlockArrays v1. # TODO: Delete this once GradedAxes is rewritten. -function Base.show(io::IO, mime::MIME"text/plain", a::BlockSparseArray; kwargs...) - a_nondual = BlockSparseArray(blocks(a), nondual.(axes(a))) - println(io, "typeof(axes) = ", typeof(axes(a)), "\n") +function blocksparse_show( + io::IO, mime::MIME"text/plain", a::AbstractArray, axes_a::Tuple; kwargs... +) + println(io, "typeof(axes) = ", typeof(axes_a), "\n") println( io, "Warning: To temporarily circumvent a bug in printing BlockSparseArrays with mixtures of dual and non-dual axes, the types of the dual axes printed below might not be accurate. The types printed above this message are the correct ones.\n", ) - return invoke( - show, Tuple{IO,MIME"text/plain",AbstractArray}, io, mime, a_nondual; kwargs... - ) + return invoke(show, Tuple{IO,MIME"text/plain",AbstractArray}, io, mime, a; kwargs...) +end + +# This is a temporary fix for `show` being broken for BlockSparseArrays +# with mixed dual and non-dual axes. This shouldn't be needed once +# GradedAxes is rewritten using BlockArrays v1. +# TODO: Delete this once GradedAxes is rewritten. +function Base.show(io::IO, mime::MIME"text/plain", a::BlockSparseArray; kwargs...) + axes_a = axes(a) + a_nondual = BlockSparseArray(blocks(a), nondual.(axes(a))) + return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...) +end + +# This is a temporary fix for `show` being broken for BlockSparseArrays +# with mixed dual and non-dual axes. This shouldn't be needed once +# GradedAxes is rewritten using BlockArrays v1. +# TODO: Delete this once GradedAxes is rewritten. +function Base.show( + io::IO, mime::MIME"text/plain", a::Adjoint{<:Any,<:BlockSparseMatrix}; kwargs... 
+) + axes_a = axes(a) + a_nondual = BlockSparseArray(blocks(a'), dual.(nondual.(axes(a))))' + return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...) +end + +# This is a temporary fix for `show` being broken for BlockSparseArrays +# with mixed dual and non-dual axes. This shouldn't be needed once +# GradedAxes is rewritten using BlockArrays v1. +# TODO: Delete this once GradedAxes is rewritten. +function Base.show( + io::IO, mime::MIME"text/plain", a::Transpose{<:Any,<:BlockSparseMatrix}; kwargs... +) + axes_a = axes(a) + a_nondual = transpose(BlockSparseArray(transpose(blocks(a)), nondual.(axes(a)))) + return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...) +end end From ae84ba0a6fad8ef20915703913c7c34a1e0729ad Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 12:12:37 -0400 Subject: [PATCH 06/10] Add tests for block sparse matrix multiplication --- .../wrappedabstractblocksparsearray.jl | 18 ++++++++++++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 784b4e27c8..a98353e152 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -88,6 +88,24 @@ function Base.similar( return similar(arraytype, eltype(arraytype), axes) end +# Needed by `BlockArrays` matrix multiplication interface +# Fixes ambiguity error with `BlockArrays.jl`. +function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, + axes::Tuple{BlockedUnitRange,Vararg{AbstractUnitRange{Int}}}, +) + return similar(arraytype, eltype(arraytype), axes) +end + +# Needed by `BlockArrays` matrix multiplication interface +# Fixes ambiguity error with `BlockArrays.jl`. 
+function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, + axes::Tuple{AbstractUnitRange{Int},BlockedUnitRange,Vararg{AbstractUnitRange{Int}}}, +) + return similar(arraytype, eltype(arraytype), axes) +end + # Needed for disambiguation function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{BlockedUnitRange}} diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 62cc23d9db..565d349c33 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -266,6 +266,18 @@ include("TestBlockSparseArraysUtils.jl") @test a_dest isa BlockSparseArray{elt} @test block_nstored(a_dest) == 1 end + @testset "Matrix multiplication" begin + a1 = BlockSparseArray{elt}([2, 3], [2, 3]) + a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)]))) + a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)]))) + a2 = BlockSparseArray{elt}([2, 3], [2, 3]) + a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)]))) + a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)]))) + @test Array(a1 * a2) ≈ Array(a1) * Array(a2) + @test Array(a1' * a2) ≈ Array(a1') * Array(a2) + @test Array(a1 * a2') ≈ Array(a1) * Array(a2') + @test Array(a1' * a2') ≈ Array(a1') * Array(a2') + end @testset "TensorAlgebra" begin a1 = BlockSparseArray{elt}([2, 3], [2, 3]) a1[Block(1, 1)] = randn(elt, size(@view(a1[Block(1, 1)]))) From 876495f49019216f9fc9b2074e26f0fd96adde23 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 12:18:54 -0400 Subject: [PATCH 07/10] Add tests for matrix multiplication with dual axes --- .../test/runtests.jl | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index fe9cb7c13d..e66c45aaef 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -3,7 +3,7 @@ using Compat: Returns using Test: @test, @testset, @test_broken using BlockArrays: Block, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored -using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, dual, gradedrange +using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, UnitRangeDual, dual, gradedrange using NDTensors.LabelledNumbers: label using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims @@ -87,8 +87,28 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) for I in eachindex(a) @test a[I] == a_dense[I] end - + @test axes(a') == dual.(reverse(axes(a))) + # TODO: Define and use `isdual` here. 
+ @test axes(a', 1) isa UnitRangeDual + @test !(axes(a', 2) isa UnitRangeDual) @test isnothing(show(devnull, MIME("text/plain"), a)) end + @testset "Matrix multiplication" begin + r = gradedrange([U1(0) => 2, U1(1) => 3]) + a1 = BlockSparseArray{elt}(dual(r), r) + a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)]))) + a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)]))) + a2 = BlockSparseArray{elt}(dual(r), r) + a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)]))) + a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)]))) + @test Array(a1 * a2) ≈ Array(a1) * Array(a2) + @test Array(a1' * a2') ≈ Array(a1') * Array(a2') + + a2 = BlockSparseArray{elt}(r, dual(r)) + a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)]))) + a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)]))) + @test Array(a1' * a2) ≈ Array(a1') * Array(a2) + @test Array(a1 * a2') ≈ Array(a1) * Array(a2') + end end end From 13796dbe4b85d4c0443d8cf09877c24878f1d9bb Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 12:19:37 -0400 Subject: [PATCH 08/10] Bump to v0.3.17 --- NDTensors/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index e6e2f687bd..23f556c2a1 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.16" +version = "0.3.17" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" From 41b54935dbef7964954f9f439ab7d57ac42ec2bd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 14:16:12 -0400 Subject: [PATCH 09/10] Fix ambiguity error with OffsetArrays in older Julia versions --- .../wrappedabstractblocksparsearray.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index a98353e152..286847992e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -88,6 +88,17 @@ function Base.similar( return similar(arraytype, eltype(arraytype), axes) end +# Needed by `BlockArrays` matrix multiplication interface +# TODO: This fixes an ambiguity error with `OffsetArrays.jl`, but +# is only appears to be needed in older versions of Julia like v1.6. +# Delete once we drop support for older versions of Julia. +function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, + axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}}, +) + return similar(arraytype, eltype(arraytype), axes) +end + # Needed by `BlockArrays` matrix multiplication interface # Fixes ambiguity error with `BlockArrays.jl`. 
function Base.similar( From 5ab1564c3eaf29f095ecc1c0e865a1f583d97b66 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 14:34:10 -0400 Subject: [PATCH 10/10] Dual the axes in adjoint in a more elegant way --- .../src/BlockSparseArraysGradedAxesExt.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl index 9976a034b6..55723344e1 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl @@ -3,6 +3,7 @@ using BlockArrays: AbstractBlockVector, Block, BlockedUnitRange, blocks using ..BlockSparseArrays: BlockSparseArrays, AbstractBlockSparseArray, + AbstractBlockSparseMatrix, BlockSparseArray, BlockSparseMatrix, block_merge @@ -67,8 +68,11 @@ function Base.eachindex(a::AbstractBlockSparseArray) return CartesianIndices(nondual.(axes(a))) end -function Base.adjoint(a::BlockSparseMatrix) - return Adjoint(BlockSparseArray(blocks(a), dual.(axes(a)))) +# TODO: Handle this through some kind of trait dispatch, maybe +# a `SymmetryStyle`-like trait to check if the block sparse +# matrix has graded axes. +function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix}) + return dual.(reverse(axes(a'))) end # This is a temporary fix for `show` being broken for BlockSparseArrays
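Taken together, the series makes `*` on block sparse matrices — including `adjoint`/`transpose` wrappers and graded dual axes — dispatch through the block sparse `materialize!` path. A short end-to-end sketch, assembled from the tests added in patches 06 and 07; module paths follow the test files in this PR, and the graded/dual-axes variant additionally needs a sector label type like the `U1` used in the extension's test suite, which is not shown here:

using Test: @test
using BlockArrays: Block
using NDTensors.BlockSparseArrays: BlockSparseArray

elt = Float64
a1 = BlockSparseArray{elt}([2, 3], [2, 3])  # block sizes along each dimension
a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)])))
a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)])))
a2 = BlockSparseArray{elt}([2, 3], [2, 3])
a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)])))
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))

# Plain and adjoint products agree with the corresponding dense results.
@test Array(a1 * a2) ≈ Array(a1) * Array(a2)
@test Array(a1' * a2) ≈ Array(a1') * Array(a2)
@test Array(a1 * a2') ≈ Array(a1) * Array(a2')
@test Array(a1' * a2') ≈ Array(a1') * Array(a2')

Patch 10 then drops the earlier `Base.adjoint` overload in favor of overloading `Base.axes` on `Adjoint`-wrapped block sparse matrices, so the lazy `Adjoint` wrapper is kept while its axes are still reversed and dualized, which is what the `axes(a') == dual.(reverse(axes(a)))` test from patch 07 exercises.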