
[NDTensors][NDTensorsCUDAExt] Improve performance of GPU backends #1194

Merged: 147 commits, Oct 17, 2023

Changes from 133 commits

Commits (147)
6438473
Create generic randn for CUDA and use randn to make on device
kmp5VT Sep 13, 2023
682bcf3
Make generic_zero for CUDA
kmp5VT Sep 13, 2023
4f879d5
Format
kmp5VT Sep 13, 2023
b567ecd
remove monorepo
kmp5VT Sep 13, 2023
a4baeb7
Convert tests to CPU to not perform scalar operations
kmp5VT Sep 14, 2023
03a9b27
import -> using. Can use NDTensors randn! instead of another function
kmp5VT Sep 14, 2023
538cf76
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Sep 14, 2023
7314b2f
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Sep 20, 2023
154fac2
Temporarily get contract working by making a `mul!!` function
kmp5VT Sep 21, 2023
3da5e96
format
kmp5VT Sep 21, 2023
cb36f23
remove change to combiner
kmp5VT Sep 21, 2023
aef93a5
Use adapt to not copy elements
kmp5VT Sep 21, 2023
89d2837
Remove commented code
kmp5VT Sep 21, 2023
3e93d61
Remove CUDA fill functions and use more generic in abstractarray/fill.jl
kmp5VT Sep 21, 2023
6b42956
Remove NDTensors.
kmp5VT Sep 21, 2023
fe470b7
format
kmp5VT Sep 21, 2023
65d7174
Add comment about gpus
kmp5VT Sep 21, 2023
7232d31
Do elementwise operations on data to avoid scalar indexing
kmp5VT Sep 25, 2023
6197213
Force dot to return value on CPU
kmp5VT Sep 25, 2023
0ed96cb
remove copy code
kmp5VT Sep 25, 2023
fa67f7c
format
kmp5VT Sep 25, 2023
db0a4fc
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Sep 26, 2023
674c40c
Bootleg fix for a conversion to UnifiedMemory issue
kmp5VT Sep 26, 2023
79ec79f
force itensor to use data to speed up computation
kmp5VT Sep 26, 2023
3a142b4
Merge branch 'kmp5/debug/cuda_rand' of github.com:kmp5VT/ITensors.jl …
kmp5VT Sep 26, 2023
68aae9c
remove unnecessary code
kmp5VT Sep 27, 2023
3b808e1
Use all of linearalgebra
kmp5VT Sep 27, 2023
de4b1f9
grab data on cpu not gpu
kmp5VT Sep 27, 2023
f69389f
[experiment] remove strided. There's an issue with broadcast here duri…
kmp5VT Sep 27, 2023
418043b
[experiment] Add todo here
kmp5VT Sep 27, 2023
b876aa1
[experiment] Truncate! does many scalar operations. Potentially could …
kmp5VT Sep 27, 2023
f03d5cf
[experiment] Add comment for dot, its still broken for GPU
kmp5VT Sep 27, 2023
dcb4466
Merge commit 'f03d5cf5cb502b670924c6ce7d8c5bff7167024d' into GPU_fixe…
kmp5VT Sep 28, 2023
c7732a3
Fix printing for GPU no scalar indexing
kmp5VT Sep 28, 2023
2c486b3
dot function calls permute and linearalgebra dot instead of gemm [exp…
kmp5VT Sep 28, 2023
dcf2477
Revert axpy and working on broadcast
kmp5VT Sep 29, 2023
72e0df9
add comments about broken code
kmp5VT Oct 2, 2023
1cfaa29
Force truncate to be done on CPU [experimental]
kmp5VT Oct 2, 2023
442fe95
remove show
kmp5VT Oct 2, 2023
74791ac
unified memory hack [experimental]
kmp5VT Oct 2, 2023
ca023c6
`any` performs scalar operations [experimental]
kmp5VT Oct 2, 2023
e1ef74e
Make a sweepup function to deal with CUDA memory [experimental]
kmp5VT Oct 2, 2023
5569fca
format
kmp5VT Oct 2, 2023
37e00e4
Merge remote-tracking branch 'origin/GPU_fixes_temp' into kmp5/debug/…
kmp5VT Oct 3, 2023
67d1a6b
Revert to main
kmp5VT Oct 3, 2023
b1aafd7
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Oct 3, 2023
4ccc092
Remove factorize and truncate->sweepup [experimental]
kmp5VT Oct 3, 2023
03ccc0a
Merge branch 'kmp5/debug/cuda_rand' of github.com:kmp5VT/ITensors.jl …
kmp5VT Oct 3, 2023
c3cf6fb
remove sweepup function
kmp5VT Oct 3, 2023
d3a7b24
Don't return anonymous function
kmp5VT Oct 3, 2023
7efb9dd
remove old line
kmp5VT Oct 3, 2023
5519053
revert changes to itensors.jl
kmp5VT Oct 3, 2023
69664cf
remove dot from NDTensors
kmp5VT Oct 3, 2023
2500571
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Oct 5, 2023
d9d4302
Merge branch 'kmp5/debug/cuda_rand' of github.com:kmp5VT/ITensors.jl …
kmp5VT Oct 5, 2023
d5f5dad
Broadcast mostly working for GPU. Make more fmap functions and increa…
kmp5VT Oct 6, 2023
e607404
Fix matrix decomp
kmp5VT Oct 6, 2023
7779d12
format
kmp5VT Oct 6, 2023
fdd8e1c
We don't actually need fmap; just don't use bc.f directly
kmp5VT Oct 6, 2023
9316868
format
kmp5VT Oct 6, 2023
8bc1f3c
Fix broadcast
kmp5VT Oct 6, 2023
4d4e10d
Don't need to remake bc
kmp5VT Oct 6, 2023
1b60b78
revert, don't use `data(T)` in print_tensor
kmp5VT Oct 6, 2023
d9fa5d6
Make sure to return the functions + and -
kmp5VT Oct 6, 2023
cdf7731
Call any on the parent of matrixT
kmp5VT Oct 6, 2023
4824984
Add comment about `f = bc.f`
kmp5VT Oct 6, 2023
c02dad8
Revert to using fill!
kmp5VT Oct 6, 2023
d96f32e
Don't need to convert positive twice because these functions don't wo…
kmp5VT Oct 6, 2023
aac8681
Fix `[]` for GPU
kmp5VT Oct 6, 2023
4f941b7
Add a comment
kmp5VT Oct 6, 2023
d859624
format
kmp5VT Oct 6, 2023
d69f758
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Oct 8, 2023
f19e52d
Make a better comment and add back the original line
kmp5VT Oct 8, 2023
0f310e9
itensor.throws -> throws
kmp5VT Oct 8, 2023
46e00ce
Add `iscu` function
kmp5VT Oct 8, 2023
f212536
import iscu from NDTensors not setparameters
kmp5VT Oct 8, 2023
02d84c7
Throw an error if trying to compute qr or ql positive
kmp5VT Oct 8, 2023
c770277
Push CPU conversion to NDTensors and make iscu for Tensors
kmp5VT Oct 8, 2023
3fadfcd
Use the [] to convert `T1` and `T2` to CPU
kmp5VT Oct 8, 2023
2b9ba99
format
kmp5VT Oct 8, 2023
499a616
Print warning because they do work when scalar allowed
kmp5VT Oct 8, 2023
cffb20e
Merge branch 'main' into kmp5/debug/cuda_rand
mtfishman Oct 8, 2023
86c8c32
Use recursion to unwrap types
kmp5VT Oct 9, 2023
2122ae3
Only call leaf_parenttype once
kmp5VT Oct 9, 2023
0788770
Fix spelling error
kmp5VT Oct 9, 2023
78ecb1b
Use iscu to just adapt CUDA computations. skip on CPU
kmp5VT Oct 9, 2023
2d46f5a
format
kmp5VT Oct 9, 2023
628c367
Merge branch 'main' into kmp5/debug/cuda_rand
kmp5VT Oct 9, 2023
3e86fe9
Move iscu to its own file
kmp5VT Oct 9, 2023
e65af67
Merge branch 'kmp5/debug/cuda_rand' of github.com:kmp5VT/ITensors.jl …
kmp5VT Oct 9, 2023
10e26cc
Update mul!! system
kmp5VT Oct 9, 2023
64c2c9a
format
kmp5VT Oct 9, 2023
394d9d8
Remove space
kmp5VT Oct 9, 2023
8668e0f
convert to CPU for qr_positive. And throw error if using ql with CUDA
kmp5VT Oct 9, 2023
2e9ecac
add iscu for itensors
kmp5VT Oct 9, 2023
8d89f7e
define parenttype for Tensor and TensorStorage
kmp5VT Oct 9, 2023
9b1dee0
call ndtensors.iscu
kmp5VT Oct 9, 2023
d055726
Create set_types.jl for leaf_parenttype
kmp5VT Oct 9, 2023
4bb0444
create leaf_parenttype for itensor
kmp5VT Oct 9, 2023
227f7ea
Update this code per Matt's comments
kmp5VT Oct 9, 2023
595d533
format
kmp5VT Oct 9, 2023
1631b41
Launch mul!! on all 3 tensor types
kmp5VT Oct 10, 2023
f7d5e44
update itensor leaf_parenttype
kmp5VT Oct 10, 2023
015686d
move leaf to import not using
kmp5VT Oct 10, 2023
379a6c5
Fix qr_positive conversion and convert cu to cpu for ql
kmp5VT Oct 10, 2023
93096dd
ql test no longer broken
kmp5VT Oct 10, 2023
8da4495
format
kmp5VT Oct 10, 2023
e2f83e8
create a permutedims!! function and implement in code
kmp5VT Oct 10, 2023
3c6f073
format
kmp5VT Oct 10, 2023
3aecf80
Remove cuda permutedims
kmp5VT Oct 11, 2023
5634b1b
More robust permutedims functions
kmp5VT Oct 11, 2023
be69b13
Add a comment/idea
kmp5VT Oct 11, 2023
86bd294
remove include
kmp5VT Oct 11, 2023
0dd361c
update permutedim calls
kmp5VT Oct 11, 2023
b08342d
Make abstract mul!! call, not CuArray mul!!
kmp5VT Oct 11, 2023
383e51f
format
kmp5VT Oct 11, 2023
7ce845b
Move array-specific functions to permutedims
kmp5VT Oct 12, 2023
e4afbad
Make a choke point function for permutedims
kmp5VT Oct 12, 2023
b066065
Add array/permutedims.jl
kmp5VT Oct 12, 2023
d9caa44
create mul.jl functions
kmp5VT Oct 12, 2023
573500c
Move abstract array functions out of densetensor
kmp5VT Oct 12, 2023
0a8e895
Create choke point for permutedims! function
kmp5VT Oct 12, 2023
09e83dd
Remove this function as it created an infinite loop
kmp5VT Oct 12, 2023
06d6a04
rewrite this permutedims function
kmp5VT Oct 12, 2023
062efeb
Move abstract array contract functions to different file.
kmp5VT Oct 12, 2023
d579f5e
remove unnecessary call
kmp5VT Oct 12, 2023
851614a
format
kmp5VT Oct 12, 2023
cb96243
These functions could cause conflicts
kmp5VT Oct 12, 2023
6c902bd
format
kmp5VT Oct 12, 2023
5c74d2a
Updated and working NDTensors.permutedims
kmp5VT Oct 13, 2023
211eb84
Permute in place
kmp5VT Oct 13, 2023
23e7659
Call NDTensors.permute in tests
kmp5VT Oct 13, 2023
bf82369
Format
kmp5VT Oct 13, 2023
6445af6
Make `Base.permutedims(Tensor) = NDTensors.permutedims`
kmp5VT Oct 14, 2023
2c18c13
Make permutedims match bangbang code flow
kmp5VT Oct 14, 2023
f1a1b61
Have mul call mul! then return
kmp5VT Oct 14, 2023
b5fd9d1
format
kmp5VT Oct 14, 2023
50eb9a8
Use base.permutedims not NDTensors
kmp5VT Oct 14, 2023
bf86c83
Use simplified functions to dispatch later on parenttype
kmp5VT Oct 16, 2023
a836b89
format
kmp5VT Oct 16, 2023
709a83b
Update NDTensors/src/dense/densetensor.jl
kmp5VT Oct 16, 2023
eb26ef1
Update NDTensors/src/array/mul.jl
kmp5VT Oct 16, 2023
e032199
Update NDTensors/src/array/permutedims.jl
kmp5VT Oct 16, 2023
7d93e0c
Update NDTensors/src/array/permutedims.jl
kmp5VT Oct 16, 2023
cc92767
Add
kmp5VT Oct 16, 2023
ebcea17
appended to previous commit
kmp5VT Oct 16, 2023
262e924
format
kmp5VT Oct 16, 2023
3 changes: 2 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
@@ -4,7 +4,7 @@ using NDTensors
using NDTensors.SetParameters
using Adapt
using Functors
-using LinearAlgebra: BlasFloat
+using LinearAlgebra

if isdefined(Base, :get_extension)
  using CUDA
@@ -18,6 +18,7 @@ end

include("imports.jl")
include("set_types.jl")
+include("iscu.jl")
include("adapt.jl")
include("linearalgebra.jl")
end
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/adapt.jl
@@ -16,7 +16,7 @@ buffertype(::NDTensorCuArrayAdaptor{B}) where {B} = B
function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::AbstractArray)
  ElT = eltype(xs)
  BufT = buffertype(adaptor)
-  return isbits(xs) ? xs : CuArray{ElT,1,BufT}(xs)
+  return isbits(xs) ? xs : adapt(CuArray{ElT,1,BufT}, xs)
end

function NDTensors.adapt_storagetype(
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/imports.jl
@@ -1,6 +1,6 @@
import NDTensors: cu, set_ndims, set_eltype, set_eltype_if_unspecified, similartype
import NDTensors:
-  ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!
+  ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter

import .CUDA: CuArrayAdaptor
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/iscu.jl
@@ -0,0 +1 @@
iscu(::Type{<:CuArray}) = true
5 changes: 5 additions & 0 deletions NDTensors/src/NDTensors.jl
@@ -50,8 +50,12 @@ include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
+include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
+include("abstractarray/mul.jl")
include("array/set_types.jl")
+include("array/permutedims.jl")
+include("array/mul.jl")
include("tupletools.jl")
include("emptynumber.jl")
include("nodata.jl")
@@ -66,6 +70,7 @@ include("tensor/similar.jl")
include("adapt.jl")
include("tensoralgebra/generic_tensor_operations.jl")
include("tensoralgebra/contraction_logic.jl")
+include("abstractarray/tensoralgebra/contract.jl")

#####################################
# DenseTensor and DiagTensor
10 changes: 4 additions & 6 deletions NDTensors/src/abstractarray/fill.jl
@@ -1,13 +1,11 @@
-function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0)
+function generic_randn(
+  arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
+)
  arraytype_specified = set_unspecified_parameters(
    leaf_parenttype(arraytype), DefaultParameters()
  )
  data = similar(arraytype_specified, dim)
-  ElT = eltype(data)
-  for i in 1:length(data)
-    data[i] = randn(ElT)
-  end
-  return data
+  return randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dim::Integer=0)
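A minimal usage sketch of the new `rng` keyword (not part of the diff; `Vector{Float64}` stands in for any leaf array type, e.g. a `CuVector` when the CUDA extension is loaded):

```julia
using NDTensors, Random

# A seeded RNG makes the generated data reproducible; `randn!` fills the
# array in one call, on-device for GPU array types, instead of assigning
# elements one at a time from the CPU.
rng = Xoshiro(1234)
v = NDTensors.generic_randn(Vector{Float64}, 8; rng=rng)
```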
18 changes: 18 additions & 0 deletions NDTensors/src/abstractarray/mul.jl
@@ -0,0 +1,18 @@
function mul!!(CM::AbstractArray, AM::AbstractArray, BM::AbstractArray, α, β)
  return mul!!(
    leaf_parenttype(CM), CM, leaf_parenttype(AM), AM, leaf_parenttype(BM), BM, α, β
  )
end

function mul!!(
  ::Type{<:AbstractArray},
  CM,
  ::Type{<:AbstractArray},
  AM,
  ::Type{<:AbstractArray},
  BM,
  α,
  β,
)
  return mul!(CM, AM, BM, α, β)
end
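The point of the two-stage dispatch is that a backend only needs to add one method on its own leaf type. A hypothetical sketch of what the CUDA hook could look like (the real method lives in NDTensorsCUDAExt and may differ):

```julia
# Hypothetical backend method: once leaf_parenttype unwraps to CuArray,
# this method is selected and mul! dispatches to CUBLAS via CUDA.jl.
function mul!!(
  ::Type{<:CuArray}, CM, ::Type{<:CuArray}, AM, ::Type{<:CuArray}, BM, α, β
)
  return mul!(CM, AM, BM, α, β)
end
```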
29 changes: 29 additions & 0 deletions NDTensors/src/abstractarray/permutedims.jl
@@ -0,0 +1,29 @@
## NOTICE!!: Here we are not importing Base.permutedims or Base.permutedims! but
## are writing our own implementation. This allows us to dispatch on the
## leaf parent type of the input arrays.
# NDTensors.permutedims
function permutedims(M::AbstractArray, perm)
  return permutedims(leaf_parenttype(M), M, perm)
end

# NDTensors.permutedims
function permutedims(::Type{<:AbstractArray}, M, perm)
  return Base.permutedims(M, perm)
end

# NDTensors.permutedims!
function permutedims!(Mdest::AbstractArray, M::AbstractArray, perm)
  return permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
end

# NDTensors.permutedims!
function permutedims!(::Type{<:AbstractArray}, Mdest, ::Type{<:AbstractArray}, M, perm)
  return Base.permutedims!(Mdest, M, perm)
end

function permutedims!!(B::AbstractArray, A::AbstractArray, perm, f)
  return permutedims!!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm, f)
end

function permutedims!!(::Type{<:AbstractArray}, B, ::Type{<:AbstractArray}, A, perm, f)
  return B .= f.(B, Base.permutedims(A, perm))
end
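Because the entry points dispatch on `leaf_parenttype`, wrapped arrays (e.g. a `Transpose` of a GPU array) still reach the right backend method. A small CPU-only sketch (not part of the diff):

```julia
A = randn(2, 3)
# leaf_parenttype(A) is Matrix{Float64}, so this hits the ::Type{<:Array}
# method in array/permutedims.jl rather than the AbstractArray fallback.
B = NDTensors.permutedims(A, (2, 1))
```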
5 changes: 5 additions & 0 deletions NDTensors/src/abstractarray/similar.jl
@@ -56,6 +56,11 @@ function similar(arraytype::Type{<:AbstractArray}, dims::Tuple)
  return similartype(arraytype, shape)(undef, NDTensors.to_shape(arraytype, shape))
end

+# For when there are CuArray-specific issues inline
+iscu(A::AbstractArray) = iscu(typeof(A))
+function iscu(A::Type{<:AbstractArray})
+  return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
+end
# This function actually allocates the data.
# Catches conversions of dimensions specified by ranges
# dimensions specified by integers with `Base.to_shape`.
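The recursion bottoms out when `leaf_parenttype(A) == A`, i.e. at an unwrapped array type, so `iscu` sees through wrappers such as `Transpose` or `SubArray`. A sketch (not part of the diff):

```julia
# Plain Array: leaf_parenttype(Matrix{Float64}) == Matrix{Float64}, so false.
iscu(randn(3, 3))  # false
# With NDTensorsCUDAExt loaded, iscu(::Type{<:CuArray}) = true, so e.g.
# iscu(transpose(CUDA.cu(randn(3, 3)))) would return true.
```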
178 changes: 178 additions & 0 deletions NDTensors/src/abstractarray/tensoralgebra/contract.jl
@@ -0,0 +1,178 @@
using LinearAlgebra: BlasFloat
export backend_auto, backend_blas, backend_generic

@eval struct GemmBackend{T}
(f::Type{<:GemmBackend})() = $(Expr(:new, :f))
end
GemmBackend(s) = GemmBackend{Symbol(s)}()
macro GemmBackend_str(s)
return :(GemmBackend{$(Expr(:quote, Symbol(s)))})
end

const gemm_backend = Ref(:Auto)
function backend_auto()
return gemm_backend[] = :Auto
end
function backend_blas()
return gemm_backend[] = :BLAS
end
function backend_generic()
return gemm_backend[] = :Generic
end

@inline function auto_select_backend(
::Type{<:StridedVecOrMat{<:BlasFloat}},
::Type{<:StridedVecOrMat{<:BlasFloat}},
::Type{<:StridedVecOrMat{<:BlasFloat}},
)
return GemmBackend(:BLAS)
end

@inline function auto_select_backend(
::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}
)
return GemmBackend(:Generic)
end

function _gemm!(
tA, tB, alpha, A::TA, B::TB, beta, C::TC
) where {TA<:AbstractVecOrMat,TB<:AbstractVecOrMat,TC<:AbstractVecOrMat}
if gemm_backend[] == :Auto
_gemm!(auto_select_backend(TA, TB, TC), tA, tB, alpha, A, B, beta, C)
else
_gemm!(GemmBackend(gemm_backend[]), tA, tB, alpha, A, B, beta, C)
end
end

# BLAS matmul
function _gemm!(
::GemmBackend{:BLAS},
tA,
tB,
alpha,
A::AbstractVecOrMat,
B::AbstractVecOrMat,
beta,
C::AbstractVecOrMat,
)
#@timeit_debug timer "BLAS.gemm!" begin
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
#end # @timeit
end

# generic matmul
function _gemm!(
::GemmBackend{:Generic},
tA,
tB,
alpha::AT,
A::AbstractVecOrMat,
B::AbstractVecOrMat,
beta::BT,
C::AbstractVecOrMat,
) where {AT,BT}
mul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta)
return C
end

# Non-trivial permutation
function _contract_scalar_perm!(
Rᵃ::AbstractArray{ElR}, Tᵃ::AbstractArray, perm, α, β=zero(ElR)
) where {ElR}
if iszero(β)
if iszero(α)
fill!(Rᵃ, 0)
else
Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> α * t)
end
elseif isone(β)
if iszero(α)
# Rᵃ .= Rᵃ
# No-op
else
Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> r + α * t)
end
else
if iszero(α)
# Rᵃ .= β .* Rᵃ
LinearAlgebra.scal!(length(Rᵃ), β, Rᵃ, 1)
else
Rᵃ .= α .* permutedims(Tᵃ, perm) .+ β .* Rᵃ
end
end
return Rᵃ
end

function _contract!(
CT::AbstractArray{El,NC},
AT::AbstractArray{El,NA},
BT::AbstractArray{El,NB},
props::ContractionProperties,
α::Number=one(El),
β::Number=zero(El),
) where {El,NC,NA,NB}
tA = 'N'
if props.permuteA
#@timeit_debug timer "_contract!: permutedims A" begin
Ap = permutedims(leaf_parenttype(AT), AT, props.PA)
#end # @timeit
AM = transpose(reshape(Ap, (props.dmid, props.dleft)))
else
#A doesn't have to be permuted
if Atrans(props)
AM = transpose(reshape(AT, (props.dmid, props.dleft)))
else
AM = reshape(AT, (props.dleft, props.dmid))
end
end

tB = 'N'
if props.permuteB
#@timeit_debug timer "_contract!: permutedims B" begin
Bp = permutedims(leaf_parenttype(BT), BT, props.PB)
#end # @timeit
BM = reshape(Bp, (props.dmid, props.dright))
else
if Btrans(props)
BM = transpose(reshape(BT, (props.dright, props.dmid)))
else
BM = reshape(BT, (props.dmid, props.dright))
end
end

# TODO: this logic may be wrong
if props.permuteC
# if we are computing C = α * A B + β * C
# we need to make sure C is permuted to the same
# ordering as A B which is the inverse of props.PC
if β ≠ 0
CM = reshape(
permutedims(leaf_parenttype(CT), CT, invperm(props.PC)), (props.dleft, props.dright)
)
else
# Need to copy here since we will be permuting
# into C later
CM = reshape(copy(CT), (props.dleft, props.dright))
end
else
if Ctrans(props)
CM = transpose(reshape(CT, (props.dright, props.dleft)))
else
CM = reshape(CT, (props.dleft, props.dright))
end
end

#tC = similar(CM)
#_gemm!(tA, tB, El(α), AM, BM, El(β), CM)
CM = mul!!(CM, AM, BM, El(α), El(β))

if props.permuteC
Cr = reshape(CM, props.newCrange)
# TODO: use invperm(pC) here?
#@timeit_debug timer "_contract!: permutedims C" begin
CT .= permutedims(leaf_parenttype(Cr), Cr, props.PC)
#end # @timeit
end

return CT
end
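The gemm backend switch is a global `Ref`; a sketch of forcing the generic `mul!` path (useful for element types BLAS cannot handle), using the setters this file exports:

```julia
backend_generic()  # gemm_backend[] = :Generic; _gemm! now routes through mul!
# ... run contractions with, say, a BigFloat element type ...
backend_auto()     # restore automatic BLAS/generic selection
```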
3 changes: 3 additions & 0 deletions NDTensors/src/array/mul.jl
@@ -0,0 +1,3 @@
function mul!!(::Type{<:Array}, CM, ::Type{<:Array}, AM, ::Type{<:Array}, BM, α, β)
  return @strided mul!(CM, AM, BM, α, β)
end
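For plain `Array`s, `@strided` (from Strided.jl, an existing NDTensors dependency) wraps the operands in `StridedView`s so the in-place multiply can use Strided's multithreaded kernels. Usage sketch (not part of the diff):

```julia
using LinearAlgebra, Strided

C, A, B = zeros(4, 4), randn(4, 4), randn(4, 4)
# Same contract as mul!(C, A, B, 1.0, 0.0), but on StridedView operands.
@strided mul!(C, A, B, 1.0, 0.0)
```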
13 changes: 13 additions & 0 deletions NDTensors/src/array/permutedims.jl
@@ -0,0 +1,13 @@
# NDTensors.permutedims
function permutedims(::Type{<:Array}, M, perm)
  return @strided Mdest = Base.permutedims(M, perm)
end

# NDTensors.permutedims!
function permutedims!(::Type{<:Array}, Mdest, ::Type{<:Array}, M, perm)
  return @strided Mdest .= Base.permutedims(M, perm)
end

function permutedims!!(::Type{<:Array}, B, ::Type{<:Array}, A, perm, f)
  @strided B .= f.(B, Base.permutedims(A, perm))
end
4 changes: 3 additions & 1 deletion NDTensors/src/arraytensor/array.jl
@@ -61,6 +61,8 @@ end
function permutedims!(
  output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
)
-  @strided output_array .= f.(output_array, permutedims(array, perm))
+  output_array = permutedims!!(
+    leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f
+  )
  return output_array
end
1 change: 0 additions & 1 deletion NDTensors/src/dense/dense.jl
@@ -1,7 +1,6 @@
#
# Dense storage
#
-using LinearAlgebra: BlasFloat

struct Dense{ElT,DataT<:AbstractArray} <: TensorStorage{ElT}
  data::DataT
9 changes: 6 additions & 3 deletions NDTensors/src/dense/densetensor.jl
@@ -80,6 +80,10 @@ end
# Single index
#

+@propagate_inbounds function getindex(T::DenseTensor{<:Number})
+  return (iscu(T) ? NDTensors.cpu(data(T))[] : data(T)[])
+end
+
@propagate_inbounds function getindex(T::DenseTensor{<:Number}, I::Integer...)
  Base.@_inline_meta
  return getindex(data(T), Base._sub2ind(T, I...))
@@ -195,7 +199,7 @@ function permutedims!(
) where {N,StoreT<:StridedArray}
  RA = array(R)
  TA = array(T)
-  @strided RA .= permutedims(TA, perm)
+  RA = permutedims!(leaf_parenttype(RA), RA, leaf_parenttype(TA), TA, perm)
  return R
end

@@ -243,8 +247,7 @@ function permutedims!(
end
  RA = array(R)
  TA = array(T)
-  @strided RA .= f.(RA, permutedims(TA, perm))
-  return R
+  return permutedims!!(RA, TA, perm, f)
end

"""
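Taken together, the pattern this PR converges on is the "bang-bang" convention: `mul!!` and `permutedims!!` mutate in place when the destination type allows it and otherwise return a fresh array, so callers always rebind the result. A CPU-only sketch (not part of the diff):

```julia
using NDTensors

B = zeros(3, 2)
A = randn(2, 3)
# f combines destination and permuted source; (b, a) -> a simply overwrites,
# so after the call B holds transpose(A). Rebinding B is what keeps the
# bang-bang contract safe for immutable or GPU-resident destinations.
B = NDTensors.permutedims!!(B, A, (2, 1), (b, a) -> a)
```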