diff --git a/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl b/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl
index 160cd3aec8..c13b68a642 100644
--- a/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl
+++ b/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl
@@ -160,9 +160,9 @@ end
 
 BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)
 
-function blockedperm(length::Val, permblocks_maybe_empty::Tuple{Vararg{Int}}...)
-  # Drop empty blocks
-  permblocks = filter(!isempty, permblocks_maybe_empty)
+# Empty blocks are kept as given (they are no longer dropped here), so callers
+# that want them removed must filter `permblocks` before calling.
+function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
   @assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
   blockedperm = _BlockedPermutation(permblocks)
   @assert isperm(blockedperm)
diff --git a/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl b/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl
index 7732fa1258..2beff4c5bc 100644
--- a/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl
+++ b/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl
@@ -30,11 +30,49 @@ function output_axes(
   α::Number=true,
 )
   axes_contracted = blockpermute(axes(a1), perm1)
-  axes_contracted2 = blockpermute(axes(a2), perm2)
-  @assert axes_contracted == axes_contracted2
+  axes_contracted′ = blockpermute(axes(a2), perm2)
+  @assert axes_contracted == axes_contracted′
   return ()
 end
 
+# Vec-mat: `perm1` gives the single, contracted block of `a1`, and `biperm2` splits
+# the axes of `a2` into a contracted block followed by the destination block.
+# Asserts that the contracted axes agree and returns the destination axes passed
+# through `genperm` with the inverse of `perm_dest`. `α` is not used here.
+function output_axes(
+  ::typeof(contract),
+  perm_dest::BlockedPermutation{1},
+  a1::AbstractArray,
+  perm1::BlockedPermutation{1},
+  a2::AbstractArray,
+  biperm2::BlockedPermutation{2},
+  α::Number=true,
+)
+  (axes_contracted,) = blockpermute(axes(a1), perm1)
+  axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
+  @assert axes_contracted == axes_contracted′
+  return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
+end
+
+# Mat-vec: `biperm1` splits the axes of `a1` into the destination block followed by
+# a contracted block, and `perm2` gives the single, contracted block of `a2`.
+# Asserts that the contracted axes agree and returns the destination axes passed
+# through `genperm` with the inverse of `perm_dest`. `α` is not used here.
+function output_axes(
+  ::typeof(contract),
+  perm_dest::BlockedPermutation{1},
+  a1::AbstractArray,
+  biperm1::BlockedPermutation{2},
+  a2::AbstractArray,
+  perm2::BlockedPermutation{1},
+  α::Number=true,
+)
+  axes_dest, axes_contracted = blockpermute(axes(a1), biperm1)
+  (axes_contracted′,) = blockpermute(axes(a2), perm2)
+  @assert axes_contracted == axes_contracted′
+  return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
+end
+
 # TODO: Use `ArrayLayouts`-like `MulAdd` object,
 # i.e. `ContractAdd`?
 function allocate_output(
diff --git a/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl b/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl
index 22d103c293..60009de9ef 100644
--- a/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl
+++ b/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl
@@ -22,8 +22,12 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
   perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
   perm_domain2 = BaseExtensions.indexin(domain, dimnames2)
 
-  biperm_dest = blockedperm(perm_codomain_dest, perm_domain_dest)
-  biperm1 = blockedperm(perm_codomain1, perm_domain1)
-  biperm2 = blockedperm(perm_codomain2, perm_domain2)
+  # `blockedperm` keeps empty blocks, so drop them here before constructing.
+  permblocks_dest = (perm_codomain_dest, perm_domain_dest)
+  biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
+  permblocks1 = (perm_codomain1, perm_domain1)
+  biperm1 = blockedperm(filter(!isempty, permblocks1)...)
+  permblocks2 = (perm_codomain2, perm_domain2)
+  biperm2 = blockedperm(filter(!isempty, permblocks2)...)
   return biperm_dest, biperm1, biperm2
 end
diff --git a/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl b/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl
index 3ddc38ff76..beb70104bb 100644
--- a/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl
+++ b/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl
@@ -39,3 +39,22 @@ function _mul!(
   a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
   return a_dest
 end
+
+# Vec-mat. Calls 5-argument `mul!` on `transpose(a_dest)` so the result is written
+# into `a_dest` itself, i.e. (assuming `LinearAlgebra.mul!`)
+# `transpose(a_dest) = transpose(a1) * a2 * α + transpose(a_dest) * β`.
+function _mul!(
+  a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
+)
+  mul!(transpose(a_dest), transpose(a1), a2, α, β)
+  return a_dest
+end
+
+# Mat-vec. Calls 5-argument `mul!` directly, i.e. (assuming `LinearAlgebra.mul!`)
+# `a_dest = a1 * a2 * α + a_dest * β`, updating `a_dest` in place.
+function _mul!(
+  a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
+)
+  mul!(a_dest, a1, a2, α, β)
+  return a_dest
+end