Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Array storage combiner contraction refactor #1237

Merged
merged 18 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ include("empty/adapt.jl")
#####################################
# Array Tensor (experimental)
#

# TODO: Turn this into a module `CombinerArray`.
include("arraystorage/combiner/storage/combinerarray.jl")

include("arraystorage/arraystorage/storage/arraystorage.jl")
include("arraystorage/arraystorage/storage/conj.jl")
include("arraystorage/arraystorage/storage/permutedims.jl")
Expand All @@ -157,14 +161,22 @@ include("arraystorage/diagonalarray/tensor/contract.jl")
include("arraystorage/blocksparsearray/storage/unwrap.jl")
include("arraystorage/blocksparsearray/storage/contract.jl")

include("arraystorage/blocksparsearray/tensor/combiner/contract.jl")
include("arraystorage/blocksparsearray/tensor/combiner/contract_uncombine.jl")
## TODO: Delete once it is rewritten for array storage types.
## include("arraystorage/blocksparsearray/tensor/combiner/contract_uncombine.jl")

# Combiner storage
include("arraystorage/combiner/storage/promote_rule.jl")
include("arraystorage/combiner/storage/contract_utils.jl")
include("arraystorage/combiner/storage/contract.jl")

include("arraystorage/combiner/tensor/to_arraystorage.jl")
include("arraystorage/combiner/tensor/contract.jl")

include("arraystorage/blocksparsearray/storage/combiner/contract.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_utils.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_combine.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_uncombine.jl")

#####################################
# Deprecations
#
Expand Down
21 changes: 21 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/storage/arraystorage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const ArrayStorage{T,N} = Union{
StridedView{T,N},
DiagonalArray{T,N},
BlockSparseArray{T,N},
CombinerArray{N},
}

const MatrixStorage{T} = Union{
Expand All @@ -25,6 +26,26 @@ const MatrixStorage{T} = Union{

const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}}

# TODO: Delete this, it is a hack to decide
# if an Index is blocked.
function is_blocked_ind(i)
return try
blockdim(i, 1)
true
catch
false
end
end

# TODO: Delete once `TensorStorage` is removed.
function to_axes(inds::Tuple)
if any(is_blocked_ind, inds)
return BlockArrays.blockedrange.(map(i -> [blockdim(i, b) for b in 1:nblocks(i)], inds))
else
return Base.OneTo.(dim.(inds))
end
end

# TODO: Delete once `Dense` is removed.
function to_arraystorage(x::DenseTensor)
return tensor(reshape(data(x), size(x)), inds(x))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function contract(a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb)
# TODO: Special cases for index replacement, need
# to check for trivial block permutations.
return if is_combining(a_src, labels_src, a_comb, labels_comb)
contract_combine(a_src, labels_src, a_comb, labels_comb)
else
# TODO: Check this is actually uncombining.
contract_uncombine(a_src, labels_src, a_comb, labels_comb)
end
end

function contract(a_comb::CombinerArray, labels_comb, a_src::BlockSparseArray, labels_src)
return contract(a_src, labels_src, a_comb, labels_comb)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
function contract_combine(
a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb
)
labels_dest = contract_labels(labels_comb, labels_src)
axes_dest = contract_inds(axes(a_comb), labels_comb, axes(a_src), labels_src, labels_dest)

## TODO: Add this back.
## #<fermions>:
## a_src = before_combiner_signs(
## a_src,
## labels_src,
## axes(a_src),
## a_comb,
## labels_comb,
## axes(a_comb),
## labels_dest,
## axes_dest,
## )

# Account for permutation of data.
cpos_in_labels_comb = 1
clabel = labels_comb[cpos_in_labels_comb]
labels_uc = deleteat(labels_comb, cpos_in_labels_comb)
cpos_in_labels_dest = findfirst(==(clabel), labels_dest)
labels_dest_uc = insertat(labels_dest, labels_uc, cpos_in_labels_dest)
perm = getperm(labels_dest_uc, labels_src)
ucpos_in_labels_src = Tuple(findall(x -> x in labels_uc, labels_src))
a_dest = permutedims_combine(
a_src, axes_dest, perm, ucpos_in_labels_src, blockperm(a_comb), blockcomb(a_comb)
)

return a_dest, labels_dest
end

function permutedims_combine(
a_src::BlockSparseArray,
axes_dest,
perm::Tuple,
combdims::Tuple,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
a_dest = permutedims_combine_output(
a_src, axes_dest, perm, combdims, blockperm, blockcomb
)

# Permute the indices
axes_perm = permute(axes(a_src), perm)

# Now that the indices are permuted, compute
# which indices are now combined
combdims_perm = sort(_permute_combdims(combdims, perm))
comb_ind_loc = minimum(combdims_perm)

# Determine the new index before combining
axes_to_combine = getindices(axes_perm, combdims_perm)
axis_comb = ⊗(axes_to_combine...)
axis_comb = BlockArrays.blockedrange(length.(BlockArrays.blocks(axis_comb)[blockperm]))

for b in nzblocks(a_src)
a_src_b = @view a_src[b]
b_perm = permute(b, perm)
b_perm_comb = combine_dims(b_perm, axes_perm, combdims_perm)
b_perm_comb = perm_block(b_perm_comb, comb_ind_loc, blockperm)
# TODO: Wrap this in `BlockArrays.Block`?
b_in_combined_dim = b_perm_comb.n[comb_ind_loc]
new_b_in_combined_dim = blockcomb[b_in_combined_dim]
offset = 0
pos_in_new_combined_block = 1
while b_in_combined_dim - pos_in_new_combined_block > 0 &&
blockcomb[b_in_combined_dim - pos_in_new_combined_block] == new_b_in_combined_dim
# offset += blockdim(axis_comb, b_in_combined_dim - pos_in_new_combined_block)
offset += length(
axis_comb[BlockArrays.Block(b_in_combined_dim - pos_in_new_combined_block)]
)
pos_in_new_combined_block += 1
end
b_dest = setindex(b_perm_comb, new_b_in_combined_dim, comb_ind_loc)
a_dest_b_total = @view a_dest[b_dest]
# dimsa_dest_b_tot = size(a_dest_b_total)

# TODO: Simplify this code.
subind = ntuple(ndims(a_src) - length(combdims) + 1) do i
if i == comb_ind_loc
range(
1 + offset;
stop=offset + length(axis_comb[BlockArrays.Block(b_in_combined_dim)]),
)
else
range(1; stop=size(a_dest_b_total)[i])
end
end

a_dest_b = @view a_dest_b_total[subind...]
a_dest_b = reshape(a_dest_b, permute(size(a_src_b), perm))
# TODO: Make this `convert` call more general
# for GPUs using `unwrap_type`.
a_src_bₐ = convert(Array, a_src_b)
# TODO: Use `expose` to make more efficient and robust.
permutedims!(a_dest_b, a_src_bₐ, perm)
end

return a_dest
end

function permutedims_combine_output(
a_src::BlockSparseArray,
axes_dest,
perm::Tuple,
combdims::Tuple,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
# Permute the indices
axes_src = axes(a_src)
axes_perm = permute(axes_src, perm)

# Now that the indices are permuted, compute
# which indices are now combined
combdims_perm = sort(_permute_combdims(combdims, perm))

# Permute the nonzero blocks (dimension-wise)
blocks = nzblocks(a_src)

# TODO: Use `permute.(blocks, perm)`.
blocks_perm = BlockArrays.Block.(permute.(getfield.(blocks, :n), Ref(perm)))

# Combine the nonzero blocks (dimension-wise)
blocks_perm_comb = combine_dims(blocks_perm, axes_perm, combdims_perm)

# Permute the blocks (within the newly combined dimension)
comb_ind_loc = minimum(combdims_perm)
blocks_perm_comb = perm_blocks(blocks_perm_comb, comb_ind_loc, blockperm)
blocks_perm_comb = sort(blocks_perm_comb; lt=isblockless)

# Combine the blocks (within the newly combined and permuted dimension)
blocks_perm_comb = combine_blocks(blocks_perm_comb, comb_ind_loc, blockcomb)
T = eltype(a_src)
N = length(axes_dest)
B = set_ndims(unwrap_type(a_src), length(axes_dest))
return BlockSparseArray{T,N,B}(undef, blocks_perm_comb, axes_dest)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
function contract_inds_uncombine(inds_src::Tuple, labels_src, inds_comb::Tuple, labels_comb)
cpos_in_labels_comb = 1
clabel = labels_comb[cpos_in_labels_comb]
labels_uc = deleteat(labels_comb, cpos_in_labels_comb)
labels_dest = labels_src
cpos_in_labels_dest = findfirst(==(clabel), labels_dest)
# Move combined index to first position
perm = ntuple(identity, length(inds_src))
if cpos_in_labels_dest != 1
labels_dest_orig = labels_dest
labels_dest = deleteat(labels_dest, cpos_in_labels_dest)
labels_dest = insertafter(labels_dest, clabel, 0)
cpos_in_labels_dest = 1
perm = getperm(labels_dest, labels_dest_orig)
inds_src = permute(inds_src, perm)
labels_src = permute(labels_src, perm)
end
labels_dest = insertat(labels_dest, labels_uc, cpos_in_labels_dest)
inds_dest = contract_inds(inds_comb, labels_comb, inds_src, labels_src, labels_dest)
return inds_dest, labels_dest, perm, cpos_in_labels_dest
end

function contract_uncombine(
a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb
)
axes_dest, labels_dest, perm, cpos_in_labels_dest = contract_inds_uncombine(
axes(a_src), labels_src, axes(a_comb), labels_comb
)
a_src = permutedims(a_src, perm)

## TODO: Add this back.
## # <fermions>:
## a_src = before_combiner_signs(
## a_src,
## labels_src,
## axes(a_src),
## a_comb,
## labels_comb,
## axes(a_comb),
## labels_dest,
## axes_dest,
## )

a_dest = uncombine(
a_src,
labels_src,
axes_dest,
labels_dest,
cpos_in_labels_dest,
blockperm(a_comb),
blockcomb(a_comb),
)

## TODO: Add this back.
## # <fermions>:
## a_dest = after_combiner_signs(
## a_dest,
## labels_dest,
## axes_dest,
## a_comb,
## labels_comb,
## axes(a_comb),
## )

return a_dest, labels_dest
end

function uncombine(
a_src::BlockSparseArray,
labels_src,
axes_dest,
labels_dest,
combdim::Int,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
a_dest = uncombine_output(
a_src, labels_src, axes_dest, labels_dest, combdim, blockperm, blockcomb
)
invblockperm = invperm(blockperm)
# This is needed for reshaping the block
# TODO: It is already calculated in uncombine_output, use it from there
labels_uncomb_perm = setdiff(labels_dest, labels_src)
ind_uncomb_perm = ⊗(
axes_dest[map(x -> findfirst(==(x), labels_dest), labels_uncomb_perm)]...
)
ind_uncomb = BlockArrays.blockedrange(
length.(BlockArrays.blocks(ind_uncomb_perm)[blockperm])
)
# Same as axes(a_src) but with the blocks uncombined
axes_uncomb = insertat(axes(a_src), ind_uncomb, combdim)
axes_uncomb_perm = insertat(axes(a_src), ind_uncomb_perm, combdim)
for b in nzblocks(a_src)
a_src_b_tot = @view a_src[b]
bs_uncomb = uncombine_block(b, combdim, blockcomb)
offset = 0
for i in 1:length(bs_uncomb)
b_uncomb = bs_uncomb[i]
b_uncomb_perm = perm_block(b_uncomb, combdim, invblockperm)
b_uncomb_perm_reshape = reshape(b_uncomb_perm, axes_uncomb_perm, axes_dest)
a_dest_b = @view a_dest[b_uncomb_perm_reshape]
b_uncomb_in_combined_dim = b_uncomb_perm.n[combdim]
start = offset + 1
stop = offset + length(ind_uncomb_perm[BlockArrays.Block(b_uncomb_in_combined_dim)])
subind = ntuple(
i -> i == combdim ? range(start; stop=stop) : range(1; stop=size(a_src_b_tot)[i]),
ndims(a_src),
)
offset = stop
a_src_b = @view a_src_b_tot[subind...]

# Alternative (but maybe slower):
#copyto!(a_dest_b, a_src_b)

if length(a_src_b) == 1
# Call `cpu` to avoid allowscalar error on GPU.
# TODO: a_desteplace with `@allowscalar`, requires adding
# `GPUArraysCore.jl` as a dependency, or use `expose`.
a_dest_b[] = cpu(a_src_b)[]
else
# TODO: Use `unspecify_parameters(unwrap_type(a_src))` intead of `Array`.
a_dest_bₐ = convert(Array, a_dest_b)
a_dest_bₐᵣ = reshape(a_dest_bₐ, size(a_src_b))
copyto!(expose(a_dest_bₐᵣ), expose(a_src_b))
end
end
end
return a_dest
end

function uncombine_output(
a_src::BlockSparseArray,
labels_src,
axes_dest,
labels_dest,
combdim::Int,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
labels_uncomb_perm = setdiff(labels_dest, labels_src)
ind_uncomb_perm = ⊗(
axes_dest[map(x -> findfirst(==(x), labels_dest), labels_uncomb_perm)]...
)
axes_uncomb_perm = insertat(axes(a_src), ind_uncomb_perm, combdim)
# Uncombine the blocks of a_src
blocks_uncomb = uncombine_blocks(nzblocks(a_src), combdim, blockcomb)
blocks_uncomb_perm = perm_blocks(blocks_uncomb, combdim, invperm(blockperm))

# TODO: Should this be zero data instead of undef?
T = eltype(a_src)
N = length(axes_uncomb_perm)
B = unwrap_type(a_src)
a_uncomb_perm = BlockSparseArray{T,N,B}(undef, blocks_uncomb_perm, axes_uncomb_perm)
return reshape(a_uncomb_perm, axes_dest)
end
Loading
Loading