Skip to content

Commit

Permalink
dispatch to configurator
Browse files Browse the repository at this point in the history
  • Loading branch information
bjarthur committed Apr 19, 2022
1 parent 4ad8954 commit 97239d7
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions src/BatchedBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ maybe_cast(::Type, x) = x
maybe_cast(::Type{T}, x::AbstractFloat) where T<:Integer =
round(T, clamp(x, typemin(T), typemax(T)))

function configurator(config, dims)
if length(dims)==1
xthreads = min(32, dims[1])
xblocks = cld(dims[1], xthreads)
return (xthreads,), (xblocks,)
elseif length(dims)==2
xthreads = min(32, dims[1])
ythreads = min(fld(config.threads, xthreads), cld(prod(dims), xthreads))
xblocks = cld(dims[1], xthreads)
yblocks = cld(dims[2], ythreads)
return (xthreads, ythreads), (xblocks, yblocks)
end
function configurator(config, dim1)
xthreads = min(32, dim1)
xblocks = cld(dim1, xthreads)
return (xthreads,), (xblocks,)
end

function configurator(config, dim1, dim2)
xthreads = min(32, dim1)
ythreads = min(fld(config.threads, xthreads), cld(dim1*dim2, xthreads))
xblocks = cld(dim1, xthreads)
yblocks = cld(dim2, ythreads)
return (xthreads, ythreads), (xblocks, yblocks)
end

"""
Expand All @@ -51,7 +51,7 @@ function batched_dot!(o::CuVector{To}, x::CuMatrix{Tx}, y::CuMatrix{Ty}) where {
T = promote_type(To, Tx, Ty)
kernel = @cuda name="batched_dot!" launch=false kernel(T, o, x, y)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(o,1),))
threads, blocks = configurator(config, size(o,1))
kernel(T, o, x, y; threads=threads, blocks=blocks)
end

Expand Down Expand Up @@ -115,7 +115,7 @@ function batched_gemv!(tA::AbstractChar,
T = promote_type(TA, Tx, Ty)
kernel = @cuda name="batched_gemv!" launch=false kernel(T, tA, alpha, A, x, beta, y)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(y,1),size(y,2)))
threads, blocks = configurator(config, size(y,1), size(y,2))
kernel(T, tA, alpha, A, x, beta, y; threads=threads, blocks=blocks)
end

Expand Down Expand Up @@ -171,7 +171,7 @@ function batched_symv!(uplo::AbstractChar,
T = promote_type(TA, Tx, Ty)
kernel = @cuda name="batched_symv!" launch=false kernel(T, uplo, alpha, A, x, beta, y)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(y,1),size(y,2)))
threads, blocks = configurator(config, size(y,1), size(y,2))
kernel(T, uplo, alpha, A, x, beta, y; threads=threads, blocks=blocks)
end

Expand Down Expand Up @@ -230,7 +230,7 @@ function batched_spmv!(uplo::AbstractChar,
T = promote_type(TA, Tx, Ty)
kernel = @cuda name="batched_spmv!" launch=false kernel(T, uplo, alpha, A, x, beta, y)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(y,1),size(y,2)))
threads, blocks = configurator(config, size(y,1), size(y,2))
kernel(T, uplo, alpha, A, x, beta, y; threads=threads, blocks=blocks)
end

Expand Down Expand Up @@ -259,7 +259,7 @@ function batched_ger!(alpha::Talpha, x::CuMatrix{Tx}, y::CuMatrix{Ty}, A::CuArra

kernel = @cuda name="batched_ger_vector!" launch=false kernel(alpha, x, y, A)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(x,1),size(x,2)))
threads, blocks = configurator(config, size(x,1), size(x,2))
kernel(alpha, x, y, A; threads=threads, blocks=blocks)
end

Expand Down Expand Up @@ -304,7 +304,7 @@ function batched_syr!(uplo::AbstractChar,

kernel = @cuda name="batched_syr_vector!" launch=false kernel(uplo, alpha, x, A)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(x,1),size(x,2)))
threads, blocks = configurator(config, size(x,1), size(x,2))
kernel(uplo, alpha, x, A; threads=threads, blocks=blocks)
end

Expand Down Expand Up @@ -352,7 +352,7 @@ function batched_spr!(uplo::AbstractChar,

kernel = @cuda name="batched_spr_vector!" launch=false kernel(uplo, alpha, x, A)
config = launch_configuration(kernel.fun)
threads, blocks = configurator(config, (size(x,1),size(x,2)))
threads, blocks = configurator(config, size(x,1), size(x,2))
kernel(uplo, alpha, x, A; threads=threads, blocks=blocks)
end

Expand Down

0 comments on commit 97239d7

Please sign in to comment.