-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors] Finish implementation of array storage combiner contracti…
…on (#1237)
- Loading branch information
Showing
15 changed files
with
666 additions
and
551 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
NDTensors/src/arraystorage/blocksparsearray/storage/combiner/contract.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
function contract(a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb) | ||
# TODO: Special cases for index replacement, need | ||
# to check for trivial block permutations. | ||
return if is_combining(a_src, labels_src, a_comb, labels_comb) | ||
contract_combine(a_src, labels_src, a_comb, labels_comb) | ||
else | ||
# TODO: Check this is actually uncombining. | ||
contract_uncombine(a_src, labels_src, a_comb, labels_comb) | ||
end | ||
end | ||
|
||
function contract(a_comb::CombinerArray, labels_comb, a_src::BlockSparseArray, labels_src) | ||
return contract(a_src, labels_src, a_comb, labels_comb) | ||
end |
142 changes: 142 additions & 0 deletions
142
NDTensors/src/arraystorage/blocksparsearray/storage/combiner/contract_combine.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
function contract_combine( | ||
a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb | ||
) | ||
labels_dest = contract_labels(labels_comb, labels_src) | ||
axes_dest = contract_inds(axes(a_comb), labels_comb, axes(a_src), labels_src, labels_dest) | ||
|
||
## TODO: Add this back. | ||
## #<fermions>: | ||
## a_src = before_combiner_signs( | ||
## a_src, | ||
## labels_src, | ||
## axes(a_src), | ||
## a_comb, | ||
## labels_comb, | ||
## axes(a_comb), | ||
## labels_dest, | ||
## axes_dest, | ||
## ) | ||
|
||
# Account for permutation of data. | ||
cpos_in_labels_comb = 1 | ||
clabel = labels_comb[cpos_in_labels_comb] | ||
labels_uc = deleteat(labels_comb, cpos_in_labels_comb) | ||
cpos_in_labels_dest = findfirst(==(clabel), labels_dest) | ||
labels_dest_uc = insertat(labels_dest, labels_uc, cpos_in_labels_dest) | ||
perm = getperm(labels_dest_uc, labels_src) | ||
ucpos_in_labels_src = Tuple(findall(x -> x in labels_uc, labels_src)) | ||
a_dest = permutedims_combine( | ||
a_src, axes_dest, perm, ucpos_in_labels_src, blockperm(a_comb), blockcomb(a_comb) | ||
) | ||
|
||
return a_dest, labels_dest | ||
end | ||
|
||
function permutedims_combine( | ||
a_src::BlockSparseArray, | ||
axes_dest, | ||
perm::Tuple, | ||
combdims::Tuple, | ||
blockperm::Vector{Int}, | ||
blockcomb::Vector{Int}, | ||
) | ||
a_dest = permutedims_combine_output( | ||
a_src, axes_dest, perm, combdims, blockperm, blockcomb | ||
) | ||
|
||
# Permute the indices | ||
axes_perm = permute(axes(a_src), perm) | ||
|
||
# Now that the indices are permuted, compute | ||
# which indices are now combined | ||
combdims_perm = sort(_permute_combdims(combdims, perm)) | ||
comb_ind_loc = minimum(combdims_perm) | ||
|
||
# Determine the new index before combining | ||
axes_to_combine = getindices(axes_perm, combdims_perm) | ||
axis_comb = ⊗(axes_to_combine...) | ||
axis_comb = BlockArrays.blockedrange(length.(BlockArrays.blocks(axis_comb)[blockperm])) | ||
|
||
for b in nzblocks(a_src) | ||
a_src_b = @view a_src[b] | ||
b_perm = permute(b, perm) | ||
b_perm_comb = combine_dims(b_perm, axes_perm, combdims_perm) | ||
b_perm_comb = perm_block(b_perm_comb, comb_ind_loc, blockperm) | ||
# TODO: Wrap this in `BlockArrays.Block`? | ||
b_in_combined_dim = b_perm_comb.n[comb_ind_loc] | ||
new_b_in_combined_dim = blockcomb[b_in_combined_dim] | ||
offset = 0 | ||
pos_in_new_combined_block = 1 | ||
while b_in_combined_dim - pos_in_new_combined_block > 0 && | ||
blockcomb[b_in_combined_dim - pos_in_new_combined_block] == new_b_in_combined_dim | ||
# offset += blockdim(axis_comb, b_in_combined_dim - pos_in_new_combined_block) | ||
offset += length( | ||
axis_comb[BlockArrays.Block(b_in_combined_dim - pos_in_new_combined_block)] | ||
) | ||
pos_in_new_combined_block += 1 | ||
end | ||
b_dest = setindex(b_perm_comb, new_b_in_combined_dim, comb_ind_loc) | ||
a_dest_b_total = @view a_dest[b_dest] | ||
# dimsa_dest_b_tot = size(a_dest_b_total) | ||
|
||
# TODO: Simplify this code. | ||
subind = ntuple(ndims(a_src) - length(combdims) + 1) do i | ||
if i == comb_ind_loc | ||
range( | ||
1 + offset; | ||
stop=offset + length(axis_comb[BlockArrays.Block(b_in_combined_dim)]), | ||
) | ||
else | ||
range(1; stop=size(a_dest_b_total)[i]) | ||
end | ||
end | ||
|
||
a_dest_b = @view a_dest_b_total[subind...] | ||
a_dest_b = reshape(a_dest_b, permute(size(a_src_b), perm)) | ||
# TODO: Make this `convert` call more general | ||
# for GPUs using `unwrap_type`. | ||
a_src_bₐ = convert(Array, a_src_b) | ||
# TODO: Use `expose` to make more efficient and robust. | ||
permutedims!(a_dest_b, a_src_bₐ, perm) | ||
end | ||
|
||
return a_dest | ||
end | ||
|
||
function permutedims_combine_output( | ||
a_src::BlockSparseArray, | ||
axes_dest, | ||
perm::Tuple, | ||
combdims::Tuple, | ||
blockperm::Vector{Int}, | ||
blockcomb::Vector{Int}, | ||
) | ||
# Permute the indices | ||
axes_src = axes(a_src) | ||
axes_perm = permute(axes_src, perm) | ||
|
||
# Now that the indices are permuted, compute | ||
# which indices are now combined | ||
combdims_perm = sort(_permute_combdims(combdims, perm)) | ||
|
||
# Permute the nonzero blocks (dimension-wise) | ||
blocks = nzblocks(a_src) | ||
|
||
# TODO: Use `permute.(blocks, perm)`. | ||
blocks_perm = BlockArrays.Block.(permute.(getfield.(blocks, :n), Ref(perm))) | ||
|
||
# Combine the nonzero blocks (dimension-wise) | ||
blocks_perm_comb = combine_dims(blocks_perm, axes_perm, combdims_perm) | ||
|
||
# Permute the blocks (within the newly combined dimension) | ||
comb_ind_loc = minimum(combdims_perm) | ||
blocks_perm_comb = perm_blocks(blocks_perm_comb, comb_ind_loc, blockperm) | ||
blocks_perm_comb = sort(blocks_perm_comb; lt=isblockless) | ||
|
||
# Combine the blocks (within the newly combined and permuted dimension) | ||
blocks_perm_comb = combine_blocks(blocks_perm_comb, comb_ind_loc, blockcomb) | ||
T = eltype(a_src) | ||
N = length(axes_dest) | ||
B = set_ndims(unwrap_type(a_src), length(axes_dest)) | ||
return BlockSparseArray{T,N,B}(undef, blocks_perm_comb, axes_dest) | ||
end |
155 changes: 155 additions & 0 deletions
155
NDTensors/src/arraystorage/blocksparsearray/storage/combiner/contract_uncombine.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
function contract_inds_uncombine(inds_src::Tuple, labels_src, inds_comb::Tuple, labels_comb) | ||
cpos_in_labels_comb = 1 | ||
clabel = labels_comb[cpos_in_labels_comb] | ||
labels_uc = deleteat(labels_comb, cpos_in_labels_comb) | ||
labels_dest = labels_src | ||
cpos_in_labels_dest = findfirst(==(clabel), labels_dest) | ||
# Move combined index to first position | ||
perm = ntuple(identity, length(inds_src)) | ||
if cpos_in_labels_dest != 1 | ||
labels_dest_orig = labels_dest | ||
labels_dest = deleteat(labels_dest, cpos_in_labels_dest) | ||
labels_dest = insertafter(labels_dest, clabel, 0) | ||
cpos_in_labels_dest = 1 | ||
perm = getperm(labels_dest, labels_dest_orig) | ||
inds_src = permute(inds_src, perm) | ||
labels_src = permute(labels_src, perm) | ||
end | ||
labels_dest = insertat(labels_dest, labels_uc, cpos_in_labels_dest) | ||
inds_dest = contract_inds(inds_comb, labels_comb, inds_src, labels_src, labels_dest) | ||
return inds_dest, labels_dest, perm, cpos_in_labels_dest | ||
end | ||
|
||
function contract_uncombine( | ||
a_src::BlockSparseArray, labels_src, a_comb::CombinerArray, labels_comb | ||
) | ||
axes_dest, labels_dest, perm, cpos_in_labels_dest = contract_inds_uncombine( | ||
axes(a_src), labels_src, axes(a_comb), labels_comb | ||
) | ||
a_src = permutedims(a_src, perm) | ||
|
||
## TODO: Add this back. | ||
## # <fermions>: | ||
## a_src = before_combiner_signs( | ||
## a_src, | ||
## labels_src, | ||
## axes(a_src), | ||
## a_comb, | ||
## labels_comb, | ||
## axes(a_comb), | ||
## labels_dest, | ||
## axes_dest, | ||
## ) | ||
|
||
a_dest = uncombine( | ||
a_src, | ||
labels_src, | ||
axes_dest, | ||
labels_dest, | ||
cpos_in_labels_dest, | ||
blockperm(a_comb), | ||
blockcomb(a_comb), | ||
) | ||
|
||
## TODO: Add this back. | ||
## # <fermions>: | ||
## a_dest = after_combiner_signs( | ||
## a_dest, | ||
## labels_dest, | ||
## axes_dest, | ||
## a_comb, | ||
## labels_comb, | ||
## axes(a_comb), | ||
## ) | ||
|
||
return a_dest, labels_dest | ||
end | ||
|
||
function uncombine( | ||
a_src::BlockSparseArray, | ||
labels_src, | ||
axes_dest, | ||
labels_dest, | ||
combdim::Int, | ||
blockperm::Vector{Int}, | ||
blockcomb::Vector{Int}, | ||
) | ||
a_dest = uncombine_output( | ||
a_src, labels_src, axes_dest, labels_dest, combdim, blockperm, blockcomb | ||
) | ||
invblockperm = invperm(blockperm) | ||
# This is needed for reshaping the block | ||
# TODO: It is already calculated in uncombine_output, use it from there | ||
labels_uncomb_perm = setdiff(labels_dest, labels_src) | ||
ind_uncomb_perm = ⊗( | ||
axes_dest[map(x -> findfirst(==(x), labels_dest), labels_uncomb_perm)]... | ||
) | ||
ind_uncomb = BlockArrays.blockedrange( | ||
length.(BlockArrays.blocks(ind_uncomb_perm)[blockperm]) | ||
) | ||
# Same as axes(a_src) but with the blocks uncombined | ||
axes_uncomb = insertat(axes(a_src), ind_uncomb, combdim) | ||
axes_uncomb_perm = insertat(axes(a_src), ind_uncomb_perm, combdim) | ||
for b in nzblocks(a_src) | ||
a_src_b_tot = @view a_src[b] | ||
bs_uncomb = uncombine_block(b, combdim, blockcomb) | ||
offset = 0 | ||
for i in 1:length(bs_uncomb) | ||
b_uncomb = bs_uncomb[i] | ||
b_uncomb_perm = perm_block(b_uncomb, combdim, invblockperm) | ||
b_uncomb_perm_reshape = reshape(b_uncomb_perm, axes_uncomb_perm, axes_dest) | ||
a_dest_b = @view a_dest[b_uncomb_perm_reshape] | ||
b_uncomb_in_combined_dim = b_uncomb_perm.n[combdim] | ||
start = offset + 1 | ||
stop = offset + length(ind_uncomb_perm[BlockArrays.Block(b_uncomb_in_combined_dim)]) | ||
subind = ntuple( | ||
i -> i == combdim ? range(start; stop=stop) : range(1; stop=size(a_src_b_tot)[i]), | ||
ndims(a_src), | ||
) | ||
offset = stop | ||
a_src_b = @view a_src_b_tot[subind...] | ||
|
||
# Alternative (but maybe slower): | ||
#copyto!(a_dest_b, a_src_b) | ||
|
||
if length(a_src_b) == 1 | ||
# Call `cpu` to avoid allowscalar error on GPU. | ||
# TODO: a_desteplace with `@allowscalar`, requires adding | ||
# `GPUArraysCore.jl` as a dependency, or use `expose`. | ||
a_dest_b[] = cpu(a_src_b)[] | ||
else | ||
# TODO: Use `unspecify_parameters(unwrap_type(a_src))` intead of `Array`. | ||
a_dest_bₐ = convert(Array, a_dest_b) | ||
a_dest_bₐᵣ = reshape(a_dest_bₐ, size(a_src_b)) | ||
copyto!(expose(a_dest_bₐᵣ), expose(a_src_b)) | ||
end | ||
end | ||
end | ||
return a_dest | ||
end | ||
|
||
function uncombine_output( | ||
a_src::BlockSparseArray, | ||
labels_src, | ||
axes_dest, | ||
labels_dest, | ||
combdim::Int, | ||
blockperm::Vector{Int}, | ||
blockcomb::Vector{Int}, | ||
) | ||
labels_uncomb_perm = setdiff(labels_dest, labels_src) | ||
ind_uncomb_perm = ⊗( | ||
axes_dest[map(x -> findfirst(==(x), labels_dest), labels_uncomb_perm)]... | ||
) | ||
axes_uncomb_perm = insertat(axes(a_src), ind_uncomb_perm, combdim) | ||
# Uncombine the blocks of a_src | ||
blocks_uncomb = uncombine_blocks(nzblocks(a_src), combdim, blockcomb) | ||
blocks_uncomb_perm = perm_blocks(blocks_uncomb, combdim, invperm(blockperm)) | ||
|
||
# TODO: Should this be zero data instead of undef? | ||
T = eltype(a_src) | ||
N = length(axes_uncomb_perm) | ||
B = unwrap_type(a_src) | ||
a_uncomb_perm = BlockSparseArray{T,N,B}(undef, blocks_uncomb_perm, axes_uncomb_perm) | ||
return reshape(a_uncomb_perm, axes_dest) | ||
end |
Oops, something went wrong.