Skip to content

Commit

Permalink
[NDTensors] Finish implementation of array storage combiner contracti…
Browse files Browse the repository at this point in the history
…on (#1237)
  • Loading branch information
mtfishman authored Nov 4, 2023
1 parent 9a7380f commit 351770e
Show file tree
Hide file tree
Showing 15 changed files with 666 additions and 551 deletions.
16 changes: 14 additions & 2 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ include("empty/adapt.jl")
#####################################
# Array Tensor (experimental)
#

# TODO: Turn this into a module `CombinerArray`.
include("arraystorage/combiner/storage/combinerarray.jl")

include("arraystorage/arraystorage/storage/arraystorage.jl")
include("arraystorage/arraystorage/storage/conj.jl")
include("arraystorage/arraystorage/storage/permutedims.jl")
Expand All @@ -157,14 +161,22 @@ include("arraystorage/diagonalarray/tensor/contract.jl")
include("arraystorage/blocksparsearray/storage/unwrap.jl")
include("arraystorage/blocksparsearray/storage/contract.jl")

include("arraystorage/blocksparsearray/tensor/combiner/contract.jl")
include("arraystorage/blocksparsearray/tensor/combiner/contract_uncombine.jl")
## TODO: Delete once it is rewritten for array storage types.
## include("arraystorage/blocksparsearray/tensor/combiner/contract_uncombine.jl")

# Combiner storage
include("arraystorage/combiner/storage/promote_rule.jl")
include("arraystorage/combiner/storage/contract_utils.jl")
include("arraystorage/combiner/storage/contract.jl")

include("arraystorage/combiner/tensor/to_arraystorage.jl")
include("arraystorage/combiner/tensor/contract.jl")

include("arraystorage/blocksparsearray/storage/combiner/contract.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_utils.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_combine.jl")
include("arraystorage/blocksparsearray/storage/combiner/contract_uncombine.jl")

#####################################
# Deprecations
#
Expand Down
21 changes: 21 additions & 0 deletions NDTensors/src/arraystorage/arraystorage/storage/arraystorage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const ArrayStorage{T,N} = Union{
StridedView{T,N},
DiagonalArray{T,N},
BlockSparseArray{T,N},
CombinerArray{N},
}

const MatrixStorage{T} = Union{
Expand All @@ -25,6 +26,26 @@ const MatrixStorage{T} = Union{

const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}}

# TODO: Delete this, it is a hack to decide
# if an Index is blocked.
function is_blocked_ind(i)
return try
blockdim(i, 1)
true
catch
false
end
end

# TODO: Delete once `TensorStorage` is removed.
function to_axes(inds::Tuple)
if any(is_blocked_ind, inds)
return BlockArrays.blockedrange.(map(i -> [blockdim(i, b) for b in 1:nblocks(i)], inds))
else
return Base.OneTo.(dim.(inds))
end
end

# TODO: Delete once `Dense` is removed.
function to_arraystorage(x::DenseTensor)
return tensor(reshape(data(x), size(x)), inds(x))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function contract(a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb)
# TODO: Special cases for index replacement, need
# to check for trivial block permutations.
return if is_combining(a_src, labels_src, a_comb, labels_comb)
contract_combine(a_src, labels_src, a_comb, labels_comb)
else
# TODO: Check this is actually uncombining.
contract_uncombine(a_src, labels_src, a_comb, labels_comb)
end
end

function contract(a_comb::CombinerArray, labels_comb, a_src::BlockSparseArray, labels_src)
return contract(a_src, labels_src, a_comb, labels_comb)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
function contract_combine(
a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb
)
labels_dest = contract_labels(labels_comb, labels_src)
axes_dest = contract_inds(axes(a_comb), labels_comb, axes(a_src), labels_src, labels_dest)

## TODO: Add this back.
## #<fermions>:
## a_src = before_combiner_signs(
## a_src,
## labels_src,
## axes(a_src),
## a_comb,
## labels_comb,
## axes(a_comb),
## labels_dest,
## axes_dest,
## )

# Account for permutation of data.
cpos_in_labels_comb = 1
clabel = labels_comb[cpos_in_labels_comb]
labels_uc = deleteat(labels_comb, cpos_in_labels_comb)
cpos_in_labels_dest = findfirst(==(clabel), labels_dest)
labels_dest_uc = insertat(labels_dest, labels_uc, cpos_in_labels_dest)
perm = getperm(labels_dest_uc, labels_src)
ucpos_in_labels_src = Tuple(findall(x -> x in labels_uc, labels_src))
a_dest = permutedims_combine(
a_src, axes_dest, perm, ucpos_in_labels_src, blockperm(a_comb), blockcomb(a_comb)
)

return a_dest, labels_dest
end

function permutedims_combine(
a_src::BlockSparseArray,
axes_dest,
perm::Tuple,
combdims::Tuple,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
a_dest = permutedims_combine_output(
a_src, axes_dest, perm, combdims, blockperm, blockcomb
)

# Permute the indices
axes_perm = permute(axes(a_src), perm)

# Now that the indices are permuted, compute
# which indices are now combined
combdims_perm = sort(_permute_combdims(combdims, perm))
comb_ind_loc = minimum(combdims_perm)

# Determine the new index before combining
axes_to_combine = getindices(axes_perm, combdims_perm)
axis_comb = (axes_to_combine...)
axis_comb = BlockArrays.blockedrange(length.(BlockArrays.blocks(axis_comb)[blockperm]))

for b in nzblocks(a_src)
a_src_b = @view a_src[b]
b_perm = permute(b, perm)
b_perm_comb = combine_dims(b_perm, axes_perm, combdims_perm)
b_perm_comb = perm_block(b_perm_comb, comb_ind_loc, blockperm)
# TODO: Wrap this in `BlockArrays.Block`?
b_in_combined_dim = b_perm_comb.n[comb_ind_loc]
new_b_in_combined_dim = blockcomb[b_in_combined_dim]
offset = 0
pos_in_new_combined_block = 1
while b_in_combined_dim - pos_in_new_combined_block > 0 &&
blockcomb[b_in_combined_dim - pos_in_new_combined_block] == new_b_in_combined_dim
# offset += blockdim(axis_comb, b_in_combined_dim - pos_in_new_combined_block)
offset += length(
axis_comb[BlockArrays.Block(b_in_combined_dim - pos_in_new_combined_block)]
)
pos_in_new_combined_block += 1
end
b_dest = setindex(b_perm_comb, new_b_in_combined_dim, comb_ind_loc)
a_dest_b_total = @view a_dest[b_dest]
# dimsa_dest_b_tot = size(a_dest_b_total)

# TODO: Simplify this code.
subind = ntuple(ndims(a_src) - length(combdims) + 1) do i
if i == comb_ind_loc
range(
1 + offset;
stop=offset + length(axis_comb[BlockArrays.Block(b_in_combined_dim)]),
)
else
range(1; stop=size(a_dest_b_total)[i])
end
end

a_dest_b = @view a_dest_b_total[subind...]
a_dest_b = reshape(a_dest_b, permute(size(a_src_b), perm))
# TODO: Make this `convert` call more general
# for GPUs using `unwrap_type`.
a_src_bₐ = convert(Array, a_src_b)
# TODO: Use `expose` to make more efficient and robust.
permutedims!(a_dest_b, a_src_bₐ, perm)
end

return a_dest
end

function permutedims_combine_output(
a_src::BlockSparseArray,
axes_dest,
perm::Tuple,
combdims::Tuple,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
# Permute the indices
axes_src = axes(a_src)
axes_perm = permute(axes_src, perm)

# Now that the indices are permuted, compute
# which indices are now combined
combdims_perm = sort(_permute_combdims(combdims, perm))

# Permute the nonzero blocks (dimension-wise)
blocks = nzblocks(a_src)

# TODO: Use `permute.(blocks, perm)`.
blocks_perm = BlockArrays.Block.(permute.(getfield.(blocks, :n), Ref(perm)))

# Combine the nonzero blocks (dimension-wise)
blocks_perm_comb = combine_dims(blocks_perm, axes_perm, combdims_perm)

# Permute the blocks (within the newly combined dimension)
comb_ind_loc = minimum(combdims_perm)
blocks_perm_comb = perm_blocks(blocks_perm_comb, comb_ind_loc, blockperm)
blocks_perm_comb = sort(blocks_perm_comb; lt=isblockless)

# Combine the blocks (within the newly combined and permuted dimension)
blocks_perm_comb = combine_blocks(blocks_perm_comb, comb_ind_loc, blockcomb)
T = eltype(a_src)
N = length(axes_dest)
B = set_ndims(unwrap_type(a_src), length(axes_dest))
return BlockSparseArray{T,N,B}(undef, blocks_perm_comb, axes_dest)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
function contract_inds_uncombine(inds_src::Tuple, labels_src, inds_comb::Tuple, labels_comb)
cpos_in_labels_comb = 1
clabel = labels_comb[cpos_in_labels_comb]
labels_uc = deleteat(labels_comb, cpos_in_labels_comb)
labels_dest = labels_src
cpos_in_labels_dest = findfirst(==(clabel), labels_dest)
# Move combined index to first position
perm = ntuple(identity, length(inds_src))
if cpos_in_labels_dest != 1
labels_dest_orig = labels_dest
labels_dest = deleteat(labels_dest, cpos_in_labels_dest)
labels_dest = insertafter(labels_dest, clabel, 0)
cpos_in_labels_dest = 1
perm = getperm(labels_dest, labels_dest_orig)
inds_src = permute(inds_src, perm)
labels_src = permute(labels_src, perm)
end
labels_dest = insertat(labels_dest, labels_uc, cpos_in_labels_dest)
inds_dest = contract_inds(inds_comb, labels_comb, inds_src, labels_src, labels_dest)
return inds_dest, labels_dest, perm, cpos_in_labels_dest
end

function contract_uncombine(
a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb
)
axes_dest, labels_dest, perm, cpos_in_labels_dest = contract_inds_uncombine(
axes(a_src), labels_src, axes(a_comb), labels_comb
)
a_src = permutedims(a_src, perm)

## TODO: Add this back.
## # <fermions>:
## a_src = before_combiner_signs(
## a_src,
## labels_src,
## axes(a_src),
## a_comb,
## labels_comb,
## axes(a_comb),
## labels_dest,
## axes_dest,
## )

a_dest = uncombine(
a_src,
labels_src,
axes_dest,
labels_dest,
cpos_in_labels_dest,
blockperm(a_comb),
blockcomb(a_comb),
)

## TODO: Add this back.
## # <fermions>:
## a_dest = after_combiner_signs(
## a_dest,
## labels_dest,
## axes_dest,
## a_comb,
## labels_comb,
## axes(a_comb),
## )

return a_dest, labels_dest
end

function uncombine(
a_src::BlockSparseArray,
labels_src,
axes_dest,
labels_dest,
combdim::Int,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
a_dest = uncombine_output(
a_src, labels_src, axes_dest, labels_dest, combdim, blockperm, blockcomb
)
invblockperm = invperm(blockperm)
# This is needed for reshaping the block
# TODO: It is already calculated in uncombine_output, use it from there
labels_uncomb_perm = setdiff(labels_dest, labels_src)
ind_uncomb_perm = (
axes_dest[map(x -> findfirst(==(x), labels_dest), labels_uncomb_perm)]...
)
ind_uncomb = BlockArrays.blockedrange(
length.(BlockArrays.blocks(ind_uncomb_perm)[blockperm])
)
# Same as axes(a_src) but with the blocks uncombined
axes_uncomb = insertat(axes(a_src), ind_uncomb, combdim)
axes_uncomb_perm = insertat(axes(a_src), ind_uncomb_perm, combdim)
for b in nzblocks(a_src)
a_src_b_tot = @view a_src[b]
bs_uncomb = uncombine_block(b, combdim, blockcomb)
offset = 0
for i in 1:length(bs_uncomb)
b_uncomb = bs_uncomb[i]
b_uncomb_perm = perm_block(b_uncomb, combdim, invblockperm)
b_uncomb_perm_reshape = reshape(b_uncomb_perm, axes_uncomb_perm, axes_dest)
a_dest_b = @view a_dest[b_uncomb_perm_reshape]
b_uncomb_in_combined_dim = b_uncomb_perm.n[combdim]
start = offset + 1
stop = offset + length(ind_uncomb_perm[BlockArrays.Block(b_uncomb_in_combined_dim)])
subind = ntuple(
i -> i == combdim ? range(start; stop=stop) : range(1; stop=size(a_src_b_tot)[i]),
ndims(a_src),
)
offset = stop
a_src_b = @view a_src_b_tot[subind...]

# Alternative (but maybe slower):
#copyto!(a_dest_b, a_src_b)

if length(a_src_b) == 1
# Call `cpu` to avoid allowscalar error on GPU.
# TODO: a_desteplace with `@allowscalar`, requires adding
# `GPUArraysCore.jl` as a dependency, or use `expose`.
a_dest_b[] = cpu(a_src_b)[]
else
# TODO: Use `unspecify_parameters(unwrap_type(a_src))` intead of `Array`.
a_dest_bₐ = convert(Array, a_dest_b)
a_dest_bₐᵣ = reshape(a_dest_bₐ, size(a_src_b))
copyto!(expose(a_dest_bₐᵣ), expose(a_src_b))
end
end
end
return a_dest
end

function uncombine_output(
a_src::BlockSparseArray,
labels_src,
axes_dest,
labels_dest,
combdim::Int,
blockperm::Vector{Int},
blockcomb::Vector{Int},
)
labels_uncomb_perm = setdiff(labels_dest, labels_src)
ind_uncomb_perm = (
axes_dest[map(x -> findfirst(==(x), labels_dest), labels_uncomb_perm)]...
)
axes_uncomb_perm = insertat(axes(a_src), ind_uncomb_perm, combdim)
# Uncombine the blocks of a_src
blocks_uncomb = uncombine_blocks(nzblocks(a_src), combdim, blockcomb)
blocks_uncomb_perm = perm_blocks(blocks_uncomb, combdim, invperm(blockperm))

# TODO: Should this be zero data instead of undef?
T = eltype(a_src)
N = length(axes_uncomb_perm)
B = unwrap_type(a_src)
a_uncomb_perm = BlockSparseArray{T,N,B}(undef, blocks_uncomb_perm, axes_uncomb_perm)
return reshape(a_uncomb_perm, axes_dest)
end
Loading

0 comments on commit 351770e

Please sign in to comment.