Skip to content

Commit

Permalink
vendor agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
bjarthur committed Apr 24, 2023
1 parent d96d556 commit 259892b
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 140 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ authors = ["Ben Arthur <[email protected]>"]
version = "0.2.3"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

[compat]
CUDA = "3, 4"
julia = "1.6"
39 changes: 24 additions & 15 deletions bench/runbench.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ macro belapsed_median(args...)
esc(:(time(median(@benchmark $(args...))) / 1e9))
end

macro sync(ex)
quote
local ret = $(esc(ex))
KernelAbstractions.synchronize(CUDABackend())
ret

end
end


function doit(L,N)
x2 = CuArray(rand(L,N));
Expand All @@ -14,8 +23,8 @@ function doit(L,N)
o1 = CuArray(rand(N));
o3 = CuArray(rand(1,1,N));

tbgemm = @belapsed_median CUDA.@sync batched_mul!($o3, batched_transpose($x3), $y3)
tbdot = @belapsed_median CUDA.@sync batched_dot!($o1, $x2, $y2)
tbgemm = @belapsed_median @sync batched_mul!($o3, batched_transpose($x3), $y3)
tbdot = @belapsed_median @sync batched_dot!($o1, $x2, $y2)

CUDA.unsafe_free!.((x2, x3, y2, y3, o1, o3))
CUDA.memory_status()
Expand Down Expand Up @@ -51,18 +60,18 @@ function doit(L,N)
y2 = CuArray(rand(L,N));
y3 = CuArray(rand(L,1,N));

tbgemm = @belapsed_median CUDA.@sync batched_mul!($y3, $A3, $x3)
tbgemm = @belapsed_median @sync batched_mul!($y3, $A3, $x3)

tbgemvn = @belapsed_median CUDA.@sync batched_gemv!('N', 1.0, $A3, $x2, 0.0, $y2)
tbgemvt = @belapsed_median CUDA.@sync batched_gemv!('T', 1.0, $A3, $x2, 0.0, $y2)
tbgemvn = @belapsed_median @sync batched_gemv!('N', 1.0, $A3, $x2, 0.0, $y2)
tbgemvt = @belapsed_median @sync batched_gemv!('T', 1.0, $A3, $x2, 0.0, $y2)

tbsymvu = @belapsed_median CUDA.@sync batched_symv!('U', 1.0, $A3, $x2, 0.0, $y2)
tbsymvl = @belapsed_median CUDA.@sync batched_symv!('L', 1.0, $A3, $x2, 0.0, $y2)
tbsymvu = @belapsed_median @sync batched_symv!('U', 1.0, $A3, $x2, 0.0, $y2)
tbsymvl = @belapsed_median @sync batched_symv!('L', 1.0, $A3, $x2, 0.0, $y2)

AP = CuArray(hcat([SymmetricPacked(x, :U).tri for x in eachslice(_A, dims=3)]...));
tbspmvu = @belapsed_median CUDA.@sync batched_spmv!('U', 1.0, $AP, $x2, 0.0, $y2)
tbspmvu = @belapsed_median @sync batched_spmv!('U', 1.0, $AP, $x2, 0.0, $y2)
AP = CuArray(hcat([SymmetricPacked(x, :L).tri for x in eachslice(_A, dims=3)]...));
tbspmvl = @belapsed_median CUDA.@sync batched_spmv!('L', 1.0, $AP, $x2, 0.0, $y2)
tbspmvl = @belapsed_median @sync batched_spmv!('L', 1.0, $AP, $x2, 0.0, $y2)

CUDA.unsafe_free!.((A3, AP, x2, x3, y2, y3))
CUDA.memory_status()
Expand Down Expand Up @@ -108,17 +117,17 @@ function doit(L,N)
y2 = CuArray(rand(L,N));
y3 = CuArray(rand(L,1,N));

tbgemm = @belapsed_median CUDA.@sync batched_mul!($A3, $x3, batched_transpose($x3), -1.0, 1.0)
tbgemm = @belapsed_median @sync batched_mul!($A3, $x3, batched_transpose($x3), -1.0, 1.0)

tbger = @belapsed_median CUDA.@sync batched_ger!(-1.0, $x2, $y2, $A3)
tbger = @belapsed_median @sync batched_ger!(-1.0, $x2, $y2, $A3)

tbsyru = @belapsed_median CUDA.@sync batched_syr!('U', -1.0, $x2, $A3)
tbsyrl = @belapsed_median CUDA.@sync batched_syr!('L', -1.0, $x2, $A3)
tbsyru = @belapsed_median @sync batched_syr!('U', -1.0, $x2, $A3)
tbsyrl = @belapsed_median @sync batched_syr!('L', -1.0, $x2, $A3)

AP = CuArray(hcat([SymmetricPacked(x, :U).tri for x in eachslice(_A, dims=3)]...));
tbspru = @belapsed_median CUDA.@sync batched_spr!('U', -1.0, $x2, $AP)
tbspru = @belapsed_median @sync batched_spr!('U', -1.0, $x2, $AP)
AP = CuArray(hcat([SymmetricPacked(x, :L).tri for x in eachslice(_A, dims=3)]...));
tbsprl = @belapsed_median CUDA.@sync batched_spr!('L', -1.0, $x2, $AP)
tbsprl = @belapsed_median @sync batched_spr!('L', -1.0, $x2, $AP)

CUDA.unsafe_free!.((A3, AP, x2, x3, y2, y3))
CUDA.memory_status()
Expand Down
Loading

0 comments on commit 259892b

Please sign in to comment.