Skip to content

Commit

Permalink
[TensorAlgebra] Empty blockedperm blocks, mat-vecs in contract
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed May 24, 2024
1 parent e08e131 commit ef9e1a5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
4 changes: 1 addition & 3 deletions NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ end

BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)

function blockedperm(length::Val, permblocks_maybe_empty::Tuple{Vararg{Int}}...)
# Drop empty blocks
permblocks = filter(!isempty, permblocks_maybe_empty)
function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
@assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
blockedperm = _BlockedPermutation(permblocks)
@assert isperm(blockedperm)
Expand Down
36 changes: 34 additions & 2 deletions NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,43 @@ function output_axes(
α::Number=true,
)
axes_contracted = blockpermute(axes(a1), perm1)
axes_contracted2 = blockpermute(axes(a2), perm2)
@assert axes_contracted == axes_contracted2
axes_contracted′ = blockpermute(axes(a2), perm2)
@assert axes_contracted == axes_contracted′
return ()
end

# Vec-mat.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
)
(axes_contracted,) = blockpermute(axes(a1), perm1)
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
@assert axes_contracted == axes_contracted′
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# Mat-vec.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{1},
α::Number=true,
)
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
@assert axes_contracted == axes_contracted′
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
Expand Down
9 changes: 6 additions & 3 deletions NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

biperm_dest = blockedperm(perm_codomain_dest, perm_domain_dest)
biperm1 = blockedperm(perm_codomain1, perm_domain1)
biperm2 = blockedperm(perm_codomain2, perm_domain2)
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedperm(filter(!isempty, permblocks1)...)
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedperm(filter(!isempty, permblocks2)...)
return biperm_dest, biperm1, biperm2
end
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,19 @@ function _mul!(
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
return a_dest
end

# Vec-mat.
function _mul!(
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
)
mul!(transpose(a_dest), transpose(a1), a2, α, β)
return a_dest
end

# Mat-vec.
function _mul!(
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

0 comments on commit ef9e1a5

Please sign in to comment.