Skip to content

Commit

Permalink
[NDTensors][NDTensorsCUDAExt] Improve performance of CUDA backend (#1194
Browse files Browse the repository at this point in the history
)
  • Loading branch information
kmp5VT authored Oct 17, 2023
1 parent 9608ffc commit e571300
Show file tree
Hide file tree
Showing 31 changed files with 387 additions and 230 deletions.
3 changes: 2 additions & 1 deletion NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using NDTensors
using NDTensors.SetParameters
using Adapt
using Functors
using LinearAlgebra: BlasFloat
using LinearAlgebra

if isdefined(Base, :get_extension)
using CUDA
Expand All @@ -18,6 +18,7 @@ end

include("imports.jl")
include("set_types.jl")
include("iscu.jl")
include("adapt.jl")
include("linearalgebra.jl")
end
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ buffertype(::NDTensorCuArrayAdaptor{B}) where {B} = B
function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::AbstractArray)
ElT = eltype(xs)
BufT = buffertype(adaptor)
return isbits(xs) ? xs : CuArray{ElT,1,BufT}(xs)
return isbits(xs) ? xs : adapt(CuArray{ElT,1,BufT}, xs)
end

function NDTensors.adapt_storagetype(
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/imports.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import NDTensors: cu, set_ndims, set_eltype, set_eltype_if_unspecified, similartype
import NDTensors:
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter

import .CUDA: CuArrayAdaptor
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/iscu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
iscu(::Type{<:CuArray}) = true
6 changes: 6 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
include("abstractarray/mul.jl")
include("array/set_types.jl")
include("array/permutedims.jl")
include("array/mul.jl")
include("tupletools.jl")
include("emptynumber.jl")
include("nodata.jl")
Expand All @@ -63,9 +67,11 @@ include("tensor/tensor.jl")
include("dims.jl")
include("tensor/set_types.jl")
include("tensor/similar.jl")
include("tensor/permutedims.jl")
include("adapt.jl")
include("tensoralgebra/generic_tensor_operations.jl")
include("tensoralgebra/contraction_logic.jl")
include("abstractarray/tensoralgebra/contract.jl")

#####################################
# DenseTensor and DiagTensor
Expand Down
10 changes: 4 additions & 6 deletions NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0)
function generic_randn(
arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
)
arraytype_specified = set_unspecified_parameters(
leaf_parenttype(arraytype), DefaultParameters()
)
data = similar(arraytype_specified, dim)
ElT = eltype(data)
for i in 1:length(data)
data[i] = randn(ElT)
end
return data
return randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dim::Integer=0)
Expand Down
20 changes: 20 additions & 0 deletions NDTensors/src/abstractarray/mul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
function mul!!(CM::AbstractArray, AM::AbstractArray, BM::AbstractArray, α, β)
return mul!!(
leaf_parenttype(CM), CM, leaf_parenttype(AM), AM, leaf_parenttype(BM), BM, α, β
)
return CM
end

function mul!!(
::Type{<:AbstractArray},
CM,
::Type{<:AbstractArray},
AM,
::Type{<:AbstractArray},
BM,
α,
β,
)
mul!(CM, AM, BM, α, β)
return CM
end
37 changes: 37 additions & 0 deletions NDTensors/src/abstractarray/permutedims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
## NOTICE!!: Here we are not importing Base.permutedims or Base.permutedims! but
## are writing our own implementation. This allows us to
# NDTensors.permutedims
function permutedims(M::AbstractArray, perm)
return permutedims(leaf_parenttype(M), M, perm)
end

# NDTensors.permutedims
function permutedims(::Type{<:AbstractArray}, M, perm)
return Base.permutedims(M, perm)
end

# NDTensors.permutedims!
function permutedims!(Mdest::AbstractArray, M::AbstractArray, perm)
return permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
end

# NDTensors.permutedims!
function permutedims!(::Type{<:AbstractArray}, Mdest, ::Type{<:AbstractArray}, M, perm)
return Base.permutedims!(Mdest, M, perm)
end

function permutedims!!(B::AbstractArray, A::AbstractArray, perm, f)
return permutedims!!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm, f)
end

function permutedims!!(
Bleaftype::Type{<:AbstractArray}, B, Aleaftype::Type{<:AbstractArray}, A, perm, f
)
permutedims!(Bleaftype, B, Aleaftype, A, perm, f)
return B
end

function permutedims!(::Type{<:AbstractArray}, B, ::Type{<:AbstractArray}, A, perm, f)
B .= f.(B, Base.permutedims(A, perm))
return B
end
5 changes: 5 additions & 0 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ function similar(arraytype::Type{<:AbstractArray}, dims::Tuple)
return similartype(arraytype, shape)(undef, NDTensors.to_shape(arraytype, shape))
end

# For when there are CUArray specific issues inline
iscu(A::AbstractArray) = iscu(typeof(A))
function iscu(A::Type{<:AbstractArray})
return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
end
# This function actually allocates the data.
# Catches conversions of dimensions specified by ranges
# dimensions specified by integers with `Base.to_shape`.
Expand Down
176 changes: 176 additions & 0 deletions NDTensors/src/abstractarray/tensoralgebra/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
using LinearAlgebra: BlasFloat
export backend_auto, backend_blas, backend_generic

@eval struct GemmBackend{T}
(f::Type{<:GemmBackend})() = $(Expr(:new, :f))
end
GemmBackend(s) = GemmBackend{Symbol(s)}()
macro GemmBackend_str(s)
return :(GemmBackend{$(Expr(:quote, Symbol(s)))})
end

const gemm_backend = Ref(:Auto)
function backend_auto()
return gemm_backend[] = :Auto
end
function backend_blas()
return gemm_backend[] = :BLAS
end
function backend_generic()
return gemm_backend[] = :Generic
end

@inline function auto_select_backend(
::Type{<:StridedVecOrMat{<:BlasFloat}},
::Type{<:StridedVecOrMat{<:BlasFloat}},
::Type{<:StridedVecOrMat{<:BlasFloat}},
)
return GemmBackend(:BLAS)
end

@inline function auto_select_backend(
::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}
)
return GemmBackend(:Generic)
end

function _gemm!(
tA, tB, alpha, A::TA, B::TB, beta, C::TC
) where {TA<:AbstractVecOrMat,TB<:AbstractVecOrMat,TC<:AbstractVecOrMat}
if gemm_backend[] == :Auto
_gemm!(auto_select_backend(TA, TB, TC), tA, tB, alpha, A, B, beta, C)
else
_gemm!(GemmBackend(gemm_backend[]), tA, tB, alpha, A, B, beta, C)
end
end

# BLAS matmul
function _gemm!(
::GemmBackend{:BLAS},
tA,
tB,
alpha,
A::AbstractVecOrMat,
B::AbstractVecOrMat,
beta,
C::AbstractVecOrMat,
)
#@timeit_debug timer "BLAS.gemm!" begin
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
#end # @timeit
end

# generic matmul
function _gemm!(
::GemmBackend{:Generic},
tA,
tB,
alpha::AT,
A::AbstractVecOrMat,
B::AbstractVecOrMat,
beta::BT,
C::AbstractVecOrMat,
) where {AT,BT}
mul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta)
return C
end

# Non-trivial permutation
function _contract_scalar_perm!(
Rᵃ::AbstractArray{ElR}, Tᵃ::AbstractArray, perm, α, β=zero(ElR)
) where {ElR}
if iszero(β)
if iszero(α)
fill!(Rᵃ, 0)
else
Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> α * t)
end
elseif isone(β)
if iszero(α)
# Rᵃ .= Rᵃ
# No-op
else
Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> r + α * t)
end
else
if iszero(α)
# Rᵃ .= β .* Rᵃ
LinearAlgebra.scal!(length(Rᵃ), β, Rᵃ, 1)
else
Rᵃ .= α .* permutedims(Tᵃ, perm) .+ β .* Rᵃ
end
end
return Rᵃ
end

function _contract!(
CT::AbstractArray{El,NC},
AT::AbstractArray{El,NA},
BT::AbstractArray{El,NB},
props::ContractionProperties,
α::Number=one(El),
β::Number=zero(El),
) where {El,NC,NA,NB}
tA = 'N'
if props.permuteA
#@timeit_debug timer "_contract!: permutedims A" begin
Ap = permutedims(AT, props.PA)
#end # @timeit
AM = transpose(reshape(Ap, (props.dmid, props.dleft)))
else
#A doesn't have to be permuted
if Atrans(props)
AM = transpose(reshape(AT, (props.dmid, props.dleft)))
else
AM = reshape(AT, (props.dleft, props.dmid))
end
end

tB = 'N'
if props.permuteB
#@timeit_debug timer "_contract!: permutedims B" begin
Bp = permutedims(BT, props.PB)
#end # @timeit
BM = reshape(Bp, (props.dmid, props.dright))
else
if Btrans(props)
BM = transpose(reshape(BT, (props.dright, props.dmid)))
else
BM = reshape(BT, (props.dmid, props.dright))
end
end

# TODO: this logic may be wrong
if props.permuteC
# if we are computing C = α * A B + β * C
# we need to make sure C is permuted to the same
# ordering as A B which is the inverse of props.PC
if β 0
CM = reshape(permutedims(CT, invperm(props.PC)), (props.dleft, props.dright))
else
# Need to copy here since we will be permuting
# into C later
CM = reshape(copy(CT), (props.dleft, props.dright))
end
else
if Ctrans(props)
CM = transpose(reshape(CT, (props.dright, props.dleft)))
else
CM = reshape(CT, (props.dleft, props.dright))
end
end

#tC = similar(CM)
#_gemm!(tA, tB, El(α), AM, BM, El(β), CM)
CM = mul!!(CM, AM, BM, El(α), El(β))

if props.permuteC
Cr = reshape(CM, props.newCrange)
# TODO: use invperm(pC) here?
#@timeit_debug timer "_contract!: permutedims C" begin
CT .= permutedims(Cr, props.PC)
#end # @timeit
end

return CT
end
4 changes: 4 additions & 0 deletions NDTensors/src/array/mul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
function mul!!(::Type{<:Array}, CM, ::Type{<:Array}, AM, ::Type{<:Array}, BM, α, β)
@strided mul!(CM, AM, BM, α, β)
return CM
end
15 changes: 15 additions & 0 deletions NDTensors/src/array/permutedims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# NDTensors.permutedims
function permutedims(::Type{<:Array}, M, perm)
return @strided Base.permutedims(M, perm)
end

# NDTensors.permutedims!
function permutedims!(::Type{<:Array}, Mdest, ::Type{<:Array}, M, perm)
@strided Mdest .= Base.permutedims(M, perm)
return Mdest
end

function permutedims!(::Type{<:Array}, B, ::Type{<:Array}, A, perm, f)
@strided B .= f.(B, Base.permutedims(A, perm))
return B
end
4 changes: 3 additions & 1 deletion NDTensors/src/arraytensor/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ end
function permutedims!(
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
)
@strided output_array .= f.(output_array, permutedims(array, perm))
output_array = permutedims!!(
leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f
)
return output_array
end
1 change: 0 additions & 1 deletion NDTensors/src/dense/dense.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#
# Dense storage
#
using LinearAlgebra: BlasFloat

struct Dense{ElT,DataT<:AbstractArray} <: TensorStorage{ElT}
data::DataT
Expand Down
9 changes: 6 additions & 3 deletions NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ end
# Single index
#

@propagate_inbounds function getindex(T::DenseTensor{<:Number})
return (iscu(T) ? NDTensors.cpu(data(T))[] : data(T)[])
end

@propagate_inbounds function getindex(T::DenseTensor{<:Number}, I::Integer...)
Base.@_inline_meta
return getindex(data(T), Base._sub2ind(T, I...))
Expand Down Expand Up @@ -195,7 +199,7 @@ function permutedims!(
) where {N,StoreT<:StridedArray}
RA = array(R)
TA = array(T)
@strided RA .= permutedims(TA, perm)
permutedims!(RA, TA, perm)
return R
end

Expand Down Expand Up @@ -243,8 +247,7 @@ function permutedims!(
end
RA = array(R)
TA = array(T)
@strided RA .= f.(RA, permutedims(TA, perm))
return R
return permutedims!!(RA, TA, perm, f)
end

"""
Expand Down
Loading

0 comments on commit e571300

Please sign in to comment.