Skip to content

Commit

Permalink
Merge branch 'main' into kmp5/debug/issue_1450
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored May 28, 2024
2 parents 0821a51 + 0e6c219 commit 5c6bd99
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 9 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.12"
version = "0.3.13"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
4 changes: 1 addition & 3 deletions NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ end

BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)

function blockedperm(length::Val, permblocks_maybe_empty::Tuple{Vararg{Int}}...)
# Drop empty blocks
permblocks = filter(!isempty, permblocks_maybe_empty)
function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
@assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
blockedperm = _BlockedPermutation(permblocks)
@assert isperm(blockedperm)
Expand Down
36 changes: 34 additions & 2 deletions NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,43 @@ function output_axes(
α::Number=true,
)
axes_contracted = blockpermute(axes(a1), perm1)
axes_contracted2 = blockpermute(axes(a2), perm2)
@assert axes_contracted == axes_contracted2
axes_contracted′ = blockpermute(axes(a2), perm2)
@assert axes_contracted == axes_contracted′
return ()
end

# Vec-mat.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
)
(axes_contracted,) = blockpermute(axes(a1), perm1)
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
@assert axes_contracted == axes_contracted′
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# Mat-vec.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{1},
α::Number=true,
)
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
@assert axes_contracted == axes_contracted′
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
Expand Down
9 changes: 6 additions & 3 deletions NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

biperm_dest = blockedperm(perm_codomain_dest, perm_domain_dest)
biperm1 = blockedperm(perm_codomain1, perm_domain1)
biperm2 = blockedperm(perm_codomain2, perm_domain2)
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedperm(filter(!isempty, permblocks1)...)
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedperm(filter(!isempty, permblocks2)...)
return biperm_dest, biperm1, biperm2
end
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,19 @@ function _mul!(
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
return a_dest
end

# Vec-mat.
function _mul!(
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
)
mul!(transpose(a_dest), transpose(a1), a2, α, β)
return a_dest
end

# Mat-vec.
function _mul!(
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end
14 changes: 14 additions & 0 deletions NDTensors/src/lib/TensorAlgebra/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test blocklasts(p) == (3, 5)
@test invperm(p) == blockedperm((5, 4, 1), (2, 3))

# Empty block.
p = blockedperm((3, 2), (), (1,))
@test Tuple(p) === (3, 2, 1)
@test isperm(p)
@test length(p) == 3
@test blocks(p) == ((3, 2), (), (1,))
@test blocklength(p) == 3
@test blocklengths(p) == (2, 0, 1)
@test blockfirsts(p) == (1, 3, 3)
@test blocklasts(p) == (2, 2, 3)
@test invperm(p) == blockedperm((3, 2), (), (1,))

# Split collection into `BlockedPermutation`.
p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d"))
@test p == blockedperm((3, 1), (2, 4))
Expand Down Expand Up @@ -120,6 +132,8 @@ end
for (d1s, d2s, d_dests) in (
((1, 2), (1, 2), ()),
((1, 2), (2, 1), ()),
((1, 2), (2, 1, 3), (3,)),
((1, 2, 3), (2, 1), (3,)),
((1, 2), (2, 3), (1, 3)),
((1, 2), (2, 3), (3, 1)),
((2, 1), (2, 3), (3, 1)),
Expand Down

0 comments on commit 5c6bd99

Please sign in to comment.