diff --git a/src/BatchedBLAS.jl b/src/BatchedBLAS.jl
index 0f71aba..ff19eb8 100644
--- a/src/BatchedBLAS.jl
+++ b/src/BatchedBLAS.jl
@@ -13,18 +13,18 @@
 maybe_cast(::Type, x) = x
 maybe_cast(::Type{T}, x::AbstractFloat) where T<:Integer = round(T, clamp(x, typemin(T), typemax(T)))
 
-function configurator(config, dims)
-    if length(dims)==1
-        xthreads = min(32, dims[1])
-        xblocks = cld(dims[1], xthreads)
-        return (xthreads,), (xblocks,)
-    elseif length(dims)==2
-        xthreads = min(32, dims[1])
-        ythreads = min(fld(config.threads, xthreads), cld(prod(dims), xthreads))
-        xblocks = cld(dims[1], xthreads)
-        yblocks = cld(dims[2], ythreads)
-        return (xthreads, ythreads), (xblocks, yblocks)
-    end
+function configurator(config, dim1)
+    xthreads = min(32, dim1)
+    xblocks = cld(dim1, xthreads)
+    return (xthreads,), (xblocks,)
+end
+
+function configurator(config, dim1, dim2)
+    xthreads = min(32, dim1)
+    ythreads = min(fld(config.threads, xthreads), cld(dim1*dim2, xthreads))
+    xblocks = cld(dim1, xthreads)
+    yblocks = cld(dim2, ythreads)
+    return (xthreads, ythreads), (xblocks, yblocks)
 end
 
 """
@@ -51,7 +51,7 @@ function batched_dot!(o::CuVector{To}, x::CuMatrix{Tx}, y::CuMatrix{Ty}) where {
     T = promote_type(To, Tx, Ty)
     kernel = @cuda name="batched_dot!" launch=false kernel(T, o, x, y)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(o,1),))
+    threads, blocks = configurator(config, size(o,1))
     kernel(T, o, x, y; threads=threads, blocks=blocks)
 end
 
@@ -115,7 +115,7 @@ function batched_gemv!(tA::AbstractChar,
     T = promote_type(TA, Tx, Ty)
     kernel = @cuda name="batched_gemv!" launch=false kernel(T, tA, alpha, A, x, beta, y)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(y,1),size(y,2)))
+    threads, blocks = configurator(config, size(y,1), size(y,2))
     kernel(T, tA, alpha, A, x, beta, y; threads=threads, blocks=blocks)
 end
 
@@ -171,7 +171,7 @@ function batched_symv!(uplo::AbstractChar,
     T = promote_type(TA, Tx, Ty)
     kernel = @cuda name="batched_symv!" launch=false kernel(T, uplo, alpha, A, x, beta, y)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(y,1),size(y,2)))
+    threads, blocks = configurator(config, size(y,1), size(y,2))
     kernel(T, uplo, alpha, A, x, beta, y; threads=threads, blocks=blocks)
 end
 
@@ -230,7 +230,7 @@ function batched_spmv!(uplo::AbstractChar,
     T = promote_type(TA, Tx, Ty)
     kernel = @cuda name="batched_spmv!" launch=false kernel(T, uplo, alpha, A, x, beta, y)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(y,1),size(y,2)))
+    threads, blocks = configurator(config, size(y,1), size(y,2))
     kernel(T, uplo, alpha, A, x, beta, y; threads=threads, blocks=blocks)
 end
 
@@ -259,7 +259,7 @@ function batched_ger!(alpha::Talpha, x::CuMatrix{Tx}, y::CuMatrix{Ty}, A::CuArra
 
     kernel = @cuda name="batched_ger_vector!" launch=false kernel(alpha, x, y, A)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(x,1),size(x,2)))
+    threads, blocks = configurator(config, size(x,1), size(x,2))
     kernel(alpha, x, y, A; threads=threads, blocks=blocks)
 end
 
@@ -304,7 +304,7 @@ function batched_syr!(uplo::AbstractChar,
 
     kernel = @cuda name="batched_syr_vector!" launch=false kernel(uplo, alpha, x, A)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(x,1),size(x,2)))
+    threads, blocks = configurator(config, size(x,1), size(x,2))
     kernel(uplo, alpha, x, A; threads=threads, blocks=blocks)
 end
 
@@ -352,7 +352,7 @@ function batched_spr!(uplo::AbstractChar,
 
     kernel = @cuda name="batched_spr_vector!" launch=false kernel(uplo, alpha, x, A)
     config = launch_configuration(kernel.fun)
-    threads, blocks = configurator(config, (size(x,1),size(x,2)))
+    threads, blocks = configurator(config, size(x,1), size(x,2))
     kernel(uplo, alpha, x, A; threads=threads, blocks=blocks)
 end
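
As a sanity check of the refactor above, here is a minimal sketch that copies the two new configurator methods out of the diff and drives them with a stand-in launch configuration. The NamedTuple config = (threads = 1024,) is an assumption for illustration only; in the package the real value comes from launch_configuration(kernel.fun) in CUDA.jl, of which only the .threads field is used here.

    # Minimal sketch: the two configurator methods added by this diff,
    # selected by multiple dispatch instead of branching on tuple length.
    function configurator(config, dim1)
        xthreads = min(32, dim1)         # up to 32 threads along x
        xblocks = cld(dim1, xthreads)    # enough blocks to cover dim1
        return (xthreads,), (xblocks,)
    end

    function configurator(config, dim1, dim2)
        xthreads = min(32, dim1)
        # spend the remaining threads-per-block budget along y
        ythreads = min(fld(config.threads, xthreads), cld(dim1*dim2, xthreads))
        xblocks = cld(dim1, xthreads)
        yblocks = cld(dim2, ythreads)
        return (xthreads, ythreads), (xblocks, yblocks)
    end

    config = (threads = 1024,)           # hypothetical stand-in for launch_configuration(kernel.fun)

    @show configurator(config, 100)      # ((32,), (4,))
    @show configurator(config, 100, 8)   # ((32, 25), (4, 1))

One side effect of replacing the if/elseif on length(dims) with two methods: a call with an unsupported number of dimensions now raises a MethodError at the call site, whereas the old single-method version would fall through both branches and silently return nothing, only failing later when the result was destructured into threads and blocks.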