-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors][NDTensorsCUDAExt] Improve performance of CUDA backend (#1194)
- Loading branch information
Showing
31 changed files
with
387 additions
and
230 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import NDTensors: cu, set_ndims, set_eltype, set_eltype_if_unspecified, similartype | ||
import NDTensors: | ||
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm! | ||
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu | ||
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter | ||
|
||
import .CUDA: CuArrayAdaptor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Trait method: any `CuArray` type reports `true`.
function iscu(::Type{<:CuArray})
  return true
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Entry point for `mul!!` on arrays.
#
# Looks up `leaf_parenttype` of each operand and forwards to the `mul!!` method
# that dispatches on those types, so storage-specific implementations can be
# selected. The result of that method is returned as-is.
function mul!!(CM::AbstractArray, AM::AbstractArray, BM::AbstractArray, α, β)
  return mul!!(
    leaf_parenttype(CM), CM, leaf_parenttype(AM), AM, leaf_parenttype(BM), BM, α, β
  )
end
|
||
# Generic fallback used when no storage-specific `mul!!` method applies.
#
# Runs the 5-argument `mul!(CM, AM, BM, α, β)` and returns the destination `CM`.
# The three `Type` arguments are only used for dispatch.
function mul!!(
  ::Type{<:AbstractArray}, CM, ::Type{<:AbstractArray}, AM, ::Type{<:AbstractArray}, BM, α, β
)
  mul!(CM, AM, BM, α, β)
  return CM
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
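## NOTICE!!: Here we are not importing Base.permutedims or Base.permutedims! but | ||
## are writing our own implementation. This allows us to add methods that dispatch | ||
## on the type returned by `leaf_parenttype` without extending the Base functions. | ||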
# NDTensors.permutedims
# Entry point: forwards to the method selected by `leaf_parenttype(M)`.
permutedims(M::AbstractArray, perm) = permutedims(leaf_parenttype(M), M, perm)
|
||
# NDTensors.permutedims
# Generic fallback: delegates to `Base.permutedims`. The `Type` argument is
# only used for dispatch.
permutedims(::Type{<:AbstractArray}, M, perm) = Base.permutedims(M, perm)
|
||
# NDTensors.permutedims!
# Entry point: looks up `leaf_parenttype` of destination and source, then
# forwards to the method that dispatches on those two types.
function permutedims!(Mdest::AbstractArray, M::AbstractArray, perm)
  dest_leaf = leaf_parenttype(Mdest)
  src_leaf = leaf_parenttype(M)
  return permutedims!(dest_leaf, Mdest, src_leaf, M, perm)
end
|
||
# NDTensors.permutedims!
# Generic fallback: delegates to `Base.permutedims!` and returns its result.
# The two `Type` arguments are only used for dispatch.
function permutedims!(
  ::Type{<:AbstractArray}, Mdest, ::Type{<:AbstractArray}, M, perm
)
  return Base.permutedims!(Mdest, M, perm)
end
|
||
# Entry point for the combining permute: looks up `leaf_parenttype` of `B` and
# `A`, then forwards to the method that dispatches on those two types.
function permutedims!!(B::AbstractArray, A::AbstractArray, perm, f)
  dest_leaf = leaf_parenttype(B)
  src_leaf = leaf_parenttype(A)
  return permutedims!!(dest_leaf, B, src_leaf, A, perm, f)
end
|
||
# Generic fallback for the combining permute.
#
# Hands the work to the 6-argument `permutedims!` (passing the leaf types
# through unchanged) and returns the destination `B`.
function permutedims!!(
  Bleaftype::Type{<:AbstractArray},
  B,
  Aleaftype::Type{<:AbstractArray},
  A,
  perm,
  f,
)
  permutedims!(Bleaftype, B, Aleaftype, A, perm, f)
  return B
end
|
||
# Generic combining permute: overwrites each element of `B` with
# `f(b, a)`, where `b` is the current element of `B` and `a` is the matching
# element of `Base.permutedims(A, perm)`. Returns `B`.
function permutedims!(
  ::Type{<:AbstractArray}, B, ::Type{<:AbstractArray}, A, perm, f
)
  permuted = Base.permutedims(A, perm)
  broadcast!(f, B, B, permuted)
  return B
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
using LinearAlgebra: BlasFloat | ||
export backend_auto, backend_blas, backend_generic | ||
|
||
# Field-less tag type; the parameter `T` carries the backend name so that
# `_gemm!` can dispatch on it. The inner constructor is spliced in as a raw
# `:new` expression.
@eval struct GemmBackend{T}
  (backend_type::Type{<:GemmBackend})() = $(Expr(:new, :backend_type))
end

# Build the tag from anything `Symbol` accepts, e.g. `GemmBackend(:BLAS)`.
function GemmBackend(name)
  return GemmBackend{Symbol(name)}()
end

# String macro: `GemmBackend"BLAS"` expands to the type `GemmBackend{:BLAS}`.
macro GemmBackend_str(name)
  tag = Expr(:quote, Symbol(name))
  return :(GemmBackend{$tag})
end
|
||
# Currently selected matrix-multiplication backend, stored as a `Symbol`.
# Starts as `:Auto`; changed through the `backend_*` functions below.
const gemm_backend = Ref(:Auto)

# Each selector stores its symbol in `gemm_backend` and returns that symbol.
backend_auto() = (gemm_backend[] = :Auto)
backend_blas() = (gemm_backend[] = :BLAS)
backend_generic() = (gemm_backend[] = :Generic)
|
||
# When all three operand types are strided vectors/matrices with a BLAS float
# element type, pick the BLAS backend.
@inline function auto_select_backend(
  ::Type{TA}, ::Type{TB}, ::Type{TC}
) where {
  TA<:StridedVecOrMat{<:BlasFloat},
  TB<:StridedVecOrMat{<:BlasFloat},
  TC<:StridedVecOrMat{<:BlasFloat},
}
  return GemmBackend(:BLAS)
end
|
||
# Catch-all for any other combination of vector/matrix types: pick the generic
# backend.
@inline function auto_select_backend(
  ::Type{<:AbstractVecOrMat},
  ::Type{<:AbstractVecOrMat},
  ::Type{<:AbstractVecOrMat},
)
  return GemmBackend(:Generic)
end
|
||
# Backend-selecting `_gemm!`.
#
# Reads the global `gemm_backend` setting. With `:Auto` the backend is chosen
# from the operand types via `auto_select_backend`; otherwise the configured
# symbol is wrapped in a `GemmBackend` tag. The call is then forwarded to the
# backend-specific `_gemm!` method and its result returned.
function _gemm!(
  tA, tB, alpha, A::TA, B::TB, beta, C::TC
) where {TA<:AbstractVecOrMat,TB<:AbstractVecOrMat,TC<:AbstractVecOrMat}
  selected = gemm_backend[]
  backend = selected == :Auto ? auto_select_backend(TA, TB, TC) : GemmBackend(selected)
  return _gemm!(backend, tA, tB, alpha, A, B, beta, C)
end
|
||
# BLAS matmul
# Forwards all arguments, in the same order, to `BLAS.gemm!` and returns its
# result.
function _gemm!(
  ::GemmBackend{:BLAS},
  tA, tB, alpha,
  A::AbstractVecOrMat, B::AbstractVecOrMat,
  beta,
  C::AbstractVecOrMat,
)
  return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
|
||
# generic matmul
# Computes `C = alpha * op(A) * op(B) + beta * C` in place with the
# 5-argument `mul!` and returns `C`.
#
# `tA` / `tB` follow the BLAS character convention: 'T' wraps the operand in a
# lazy `transpose`, 'C' wraps it in a lazy `adjoint`, and any other value
# (normally 'N') uses the operand as-is. Previously 'C' was silently treated
# like 'N'.
function _gemm!(
  ::GemmBackend{:Generic},
  tA,
  tB,
  alpha::AT,
  A::AbstractVecOrMat,
  B::AbstractVecOrMat,
  beta::BT,
  C::AbstractVecOrMat,
) where {AT,BT}
  opA = tA == 'T' ? transpose(A) : (tA == 'C' ? adjoint(A) : A)
  opB = tB == 'T' ? transpose(B) : (tB == 'C' ? adjoint(B) : B)
  mul!(C, opA, opB, alpha, beta)
  return C
end
|
||
# Non-trivial permutation
# Updates `Rᵃ` to `α * permutedims(Tᵃ, perm) + β * Rᵃ` and returns the result.
#
# Special values of `α` and `β` are handled separately so that no unnecessary
# work is done:
#   β == 0: `Rᵃ` is overwritten (zero-filled when α == 0 as well).
#   β == 1: the scaled, permuted `Tᵃ` is accumulated into `Rᵃ`
#           (nothing to do when α == 0).
#   otherwise: `Rᵃ` is rescaled by `β`, plus the scaled permuted `Tᵃ` when
#           α != 0.
# `β` defaults to `zero(ElR)`, i.e. overwrite.
function _contract_scalar_perm!(
  Rᵃ::AbstractArray{ElR}, Tᵃ::AbstractArray, perm, α, β=zero(ElR)
) where {ElR}
  if iszero(β)
    if iszero(α)
      fill!(Rᵃ, 0)
    else
      Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> α * t)
    end
  elseif isone(β)
    # With α == 0 this would be Rᵃ .= Rᵃ, so there is nothing to do.
    if !iszero(α)
      Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> r + α * t)
    end
  else
    if iszero(α)
      # Broadcast instead of a BLAS `scal!` call so that the rescale works
      # for any element type, scalar type and array storage.
      Rᵃ .= β .* Rᵃ
    else
      Rᵃ .= α .* permutedims(Tᵃ, perm) .+ β .* Rᵃ
    end
  end
  return Rᵃ
end
|
||
# Contract `AT` with `BT` into `CT` as a single matrix multiplication:
# `CT = α * AT * BT + β * CT`, with the index bookkeeping taken from `props`.
#
# Each tensor is viewed as a matrix (`AM`: dleft × dmid, `BM`: dmid × dright,
# `CM`: dleft × dright). A tensor is permuted first when `props` requests it;
# otherwise it is reshaped directly, with a lazy `transpose` when the
# corresponding `Atrans` / `Btrans` / `Ctrans` predicate holds. The product is
# computed with `mul!!`, and when `CT` needed a permutation the result is
# permuted back into `CT`. Returns `CT`.
#
# `α` defaults to `one(El)` and `β` to `zero(El)`; both are converted to `El`
# before the multiplication.
function _contract!(
  CT::AbstractArray{El,NC},
  AT::AbstractArray{El,NA},
  BT::AbstractArray{El,NB},
  props::ContractionProperties,
  α::Number=one(El),
  β::Number=zero(El),
) where {El,NC,NA,NB}
  # Matricize A.
  if props.permuteA
    Ap = permutedims(AT, props.PA)
    AM = transpose(reshape(Ap, (props.dmid, props.dleft)))
  elseif Atrans(props)
    # A doesn't have to be permuted, only transposed.
    AM = transpose(reshape(AT, (props.dmid, props.dleft)))
  else
    AM = reshape(AT, (props.dleft, props.dmid))
  end

  # Matricize B.
  if props.permuteB
    Bp = permutedims(BT, props.PB)
    BM = reshape(Bp, (props.dmid, props.dright))
  elseif Btrans(props)
    BM = transpose(reshape(BT, (props.dright, props.dmid)))
  else
    BM = reshape(BT, (props.dmid, props.dright))
  end

  # Matricize C.
  # TODO: this logic may be wrong
  if props.permuteC
    if iszero(β)
      # Need to copy here since we will be permuting
      # into C later
      CM = reshape(copy(CT), (props.dleft, props.dright))
    else
      # if we are computing C = α * A B + β * C
      # we need to make sure C is permuted to the same
      # ordering as A B which is the inverse of props.PC
      CM = reshape(permutedims(CT, invperm(props.PC)), (props.dleft, props.dright))
    end
  elseif Ctrans(props)
    CM = transpose(reshape(CT, (props.dright, props.dleft)))
  else
    CM = reshape(CT, (props.dleft, props.dright))
  end

  CM = mul!!(CM, AM, BM, El(α), El(β))

  if props.permuteC
    # CM is a separate buffer in this case, so write the permuted result back.
    Cr = reshape(CM, props.newCrange)
    # TODO: use invperm(pC) here?
    CT .= permutedims(Cr, props.PC)
  end

  return CT
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# `Array`-backed specialization: runs the 5-argument `mul!` under `@strided`
# and returns the destination `CM`. The `Type` arguments are only used for
# dispatch.
function mul!!(
  ::Type{<:Array},
  CM,
  ::Type{<:Array},
  AM,
  ::Type{<:Array},
  BM,
  α,
  β,
)
  @strided mul!(CM, AM, BM, α, β)
  return CM
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# NDTensors.permutedims
# `Array`-backed specialization: evaluates `Base.permutedims(M, perm)` under
# `@strided` and returns the result.
function permutedims(::Type{<:Array}, M, perm)
  permuted = @strided Base.permutedims(M, perm)
  return permuted
end
|
||
# NDTensors.permutedims!
# `Array`-backed specialization: assigns the permuted `M` into `Mdest`
# element-wise under `@strided`, then returns `Mdest`.
function permutedims!(
  ::Type{<:Array},
  Mdest,
  ::Type{<:Array},
  M,
  perm,
)
  @strided Mdest .= Base.permutedims(M, perm)
  return Mdest
end
|
||
# `Array`-backed combining permute: under `@strided`, overwrites each element
# of `B` with `f(b, a)`, where `a` comes from `Base.permutedims(A, perm)`.
# Returns `B`.
function permutedims!(
  ::Type{<:Array},
  B,
  ::Type{<:Array},
  A,
  perm,
  f,
)
  @strided B .= f.(B, Base.permutedims(A, perm))
  return B
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.