Skip to content

Commit

Permalink
Set up package extensions, improve permutedims
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 16, 2024
1 parent d1958cc commit c9702e7
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 38 deletions.
16 changes: 12 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"

[extensions]
BlockSparseArraysAdaptExt = "Adapt"
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]

[compat]
Adapt = "4.1.1"
Aqua = "0.8.9"
Expand All @@ -29,16 +33,20 @@ BlockArrays = "1.2.0"
Derive = "0.3.1"
Dictionaries = "0.4.3"
GPUArraysCore = "0.1.0"
GradedUnitRanges = "0.1.0"
LabelledNumbers = "0.1.0"
LinearAlgebra = "1.10"
MacroTools = "0.5.13"
SparseArraysBase = "0.2"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.1.0"
TypeParameterAccessors = "0.1.0"
Test = "1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module BlockSparseArraysAdaptExt
using Adapt: Adapt, adapt
using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
using BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x)
end
3 changes: 0 additions & 3 deletions ext/BlockSparseArraysGradedUnitRangesExt/test/Project.toml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
module BlockSparseArraysTensorAlgebraExt
using BlockArrays: AbstractBlockedUnitRange
using ..BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using GradedUnitRanges: tensor_product
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

function TensorAlgebra.:(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
return tensor_product(a1, a2)
end

TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()

function TensorAlgebra.fusedims(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
module BlockSparseArraysGradedUnitRangesExt
module BlockSparseArraysTensorAlgebraExt
using BlockArrays: AbstractBlockedUnitRange
using GradedUnitRanges: tensor_product
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

function TensorAlgebra.:(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
return tensor_product(a1, a2)

Check warning on line 7 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
end

using BlockArrays: AbstractBlockedUnitRange
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()

Check warning on line 14 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L14

Added line #L14 was not covered by tests

function TensorAlgebra.fusedims(

Check warning on line 16 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L16

Added line #L16 was not covered by tests
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
)
return blockreshape(a, axes)

Check warning on line 19 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L19

Added line #L19 was not covered by tests
end

function TensorAlgebra.splitdims(

Check warning on line 22 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L22

Added line #L22 was not covered by tests
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
)
return blockreshape(a, axes)

Check warning on line 25 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L25

Added line #L25 was not covered by tests
end

using BlockArrays:
AbstractBlockVector,
AbstractBlockedUnitRange,
Block,
BlockIndexRange,
blockedrange,
blocks
using ..BlockSparseArrays:
using BlockSparseArrays:
BlockSparseArrays,
AbstractBlockSparseArray,
AbstractBlockSparseArrayInterface,
Expand Down
3 changes: 0 additions & 3 deletions ext/BlockSparseArraysTensorAlgebraExt/test/Project.toml

This file was deleted.

5 changes: 0 additions & 5 deletions src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,4 @@ include("abstractblocksparsearray/cat.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
include(
"../ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl"
)
include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl")
end
16 changes: 16 additions & 0 deletions src/abstractblocksparsearray/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,22 @@ function Base.copyto!(
return @interface interface(a_src) copyto!(a_dest, a_src)
end

# This avoids going through the generic version that calls `Base.permutedims!`,
# which eventually calls block sparse `map!`, which involves slicing operations
# that are not friendly to GPU (since they involve `SubArray` wrapping
# `PermutedDimsArray`).
# TODO: Handle slicing better in `map!` so that this can be removed.
function Base.permutedims(a::AnyAbstractBlockSparseArray, perm)
@interface interface(a) permutedims(a, perm)
end

# The `::AbstractBlockSparseArrayInterface` version
# has a special case for when `a_dest` and `PermutedDimsArray(a_src, perm)`
# have the same blocking, and therefore can just use:
# ```julia
# permutedims!(blocks(a_dest), blocks(a_src), perm)
# ```
# TODO: Handle slicing better in `map!` so that this can be removed.
function Base.permutedims!(a_dest, a_src::AnyAbstractBlockSparseArray, perm)
return @interface interface(a_src) permutedims!(a_dest, a_src, perm)
end
Expand Down
71 changes: 60 additions & 11 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ using BlockArrays:
blocklengths,
blocks,
findblockindex
using Derive: Derive, @interface
using Derive: Derive, @interface, DefaultArrayInterface
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase:
AbstractSparseArrayInterface, eachstoredindex, perm, iperm, storedlength, storedvalues
AbstractSparseArrayInterface,
getstoredindex,
getunstoredindex,
eachstoredindex,
perm,
iperm,
storedlength,
storedvalues

# Like `SparseArraysBase.eachstoredindex` but
# at the block level, i.e. iterates over the
Expand Down Expand Up @@ -154,6 +161,39 @@ end
return a
end

# Version of `permutedims!` that assumes the destination and source
# have the same blocking.
# TODO: Delete this and handle this logic in block sparse `map!`.
function blocksparse_permutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm)
blocks(a_dest) .= blocks(PermutedDimsArray(a_src, perm))
return a_dest
end

# We overload `permutedims` here so that we can assume the destination and source
# have the same blocking and avoid non-GPU friendly slicing operations in block sparse `map!`.
# TODO: Delete this and handle this logic in block sparse `map!`.
@interface ::AbstractBlockSparseArrayInterface function Base.permutedims(
a::AbstractArray, perm
)
a_dest = similar(PermutedDimsArray(a, perm))
blocksparse_permutedims!(a_dest, a, perm)
return a_dest
end

# We overload `permutedims!` here so that we can special case when the destination and source
# have the same blocking and avoid non-GPU friendly slicing operations in block sparse `map!`.
# TODO: Delete this and handle this logic in block sparse `map!`.
@interface ::AbstractBlockSparseArrayInterface function Base.permutedims!(

Check warning on line 186 in src/blocksparsearrayinterface/blocksparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksparsearrayinterface/blocksparsearrayinterface.jl#L186

Added line #L186 was not covered by tests
a_dest::AbstractArray, a_src::AbstractArray, perm
)
if all(blockisequal.(axes(a_dest), axes(PermutedDimsArray(a_src, perm))))
blocksparse_permutedims!(a_dest, a_src, perm)
return a_dest

Check warning on line 191 in src/blocksparsearrayinterface/blocksparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksparsearrayinterface/blocksparsearrayinterface.jl#L189-L191

Added lines #L189 - L191 were not covered by tests
end
@interface DefaultArrayInterface() permutedims!(a_dest, a_src, perm)
return a_dest

Check warning on line 194 in src/blocksparsearrayinterface/blocksparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksparsearrayinterface/blocksparsearrayinterface.jl#L193-L194

Added lines #L193 - L194 were not covered by tests
end

@interface ::AbstractBlockSparseArrayInterface function Base.fill!(a::AbstractArray, value)
# TODO: Only do this check if `value isa Number`?
if iszero(value)
Expand Down Expand Up @@ -190,6 +230,7 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i

# Represents the array of arrays of a `PermutedDimsArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
# TODO: Delete this in favor of `NestedPermutedDimsArrays.NestedPermutedDimsArray`.
struct SparsePermutedDimsArrayBlocks{
T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N}
} <: AbstractSparseArray{BlockType,N}
Expand All @@ -203,23 +244,31 @@ end
function Base.size(a::SparsePermutedDimsArrayBlocks)
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
end
function Base.getindex(
function SparseArraysBase.isstored(
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
) where {N}
return isstored(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...)
end
function SparseArraysBase.getstoredindex(
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
) where {N}
return PermutedDimsArray(
blocks(parent(a.array))[_getindices(index, _invperm(a.array))...], _perm(a.array)
getstoredindex(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...),
_perm(a.array),
)
end
function SparseArraysBase.getunstoredindex(
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
) where {N}
return PermutedDimsArray(
getunstoredindex(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...),
_perm(a.array),
)
end
function SparseArraysBase.eachstoredindex(a::SparsePermutedDimsArrayBlocks)
return map(I -> _getindices(I, _perm(a.array)), eachstoredindex(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArraysBase.sparse_storage`, which is used
# to defined this.
function SparseArraysBase.storedlength(a::SparsePermutedDimsArrayBlocks)
return length(eachstoredindex(a))
end
## TODO: Delete.
## TODO: Define `storedvalues` instead.
## function SparseArraysBase.sparse_storage(a::SparsePermutedDimsArrayBlocks)
## return error("Not implemented")
## end
Expand Down
2 changes: 1 addition & 1 deletion test/basics/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ arrayts = (Array, JLArray)
a[Block(3, 2, 2, 3)] = dev(randn(elt, 1, 2, 2, 1))
perm = (2, 3, 4, 1)
for b in (PermutedDimsArray(a, perm), permutedims(a, perm))
@test Array(b) == permutedims(Array(a), perm)
@test @allowscalar(Array(b)) == permutedims(Array(a), perm)
@test issetequal(eachblockstoredindex(b), [Block(2, 2, 3, 3)])
@test @allowscalar b[Block(2, 2, 3, 3)] == permutedims(a[Block(3, 2, 2, 3)], perm)
end
Expand Down
2 changes: 0 additions & 2 deletions test/basics/test_extensions.jl

This file was deleted.

File renamed without changes.
File renamed without changes.

0 comments on commit c9702e7

Please sign in to comment.