Skip to content

Commit

Permalink
Update exchanger API
Browse files Browse the repository at this point in the history
  • Loading branch information
utkinis committed Sep 12, 2023
1 parent 4240862 commit cca98e5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 35 deletions.
32 changes: 16 additions & 16 deletions scripts_future_API/bench3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,8 @@ function main(backend=CPU(), T::DataType=Float64, dims=(0, 0, 0))
### to be hidden later
ranges = split_ndrange(A, b_width)

exchangers = ntuple(Val(length(neighbors))) do dim
ntuple(2) do side
rank = neighbors[dim][side]
halo = get_recv_view(Val(side), Val(dim), A_new)
border = get_send_view(Val(side), Val(dim), A_new)
range = ranges[2*(dim-1) + side]
offset, ndrange = first(range), size(range)
Exchanger(backend, comm, rank, halo, border) do compute_bc
NVTX.@range "borders" diffusion_kernel!(backend, 256)(A_new, A, h, _dx, _dy, _dz, offset; ndrange)
if compute_bc
# apply_bcs!(Val(dim), fields, bcs.velocity)
end
KernelAbstractions.synchronize(backend)
end
end
exchangers = ntuple(Val(length(neighbors))) do _
ntuple(_ -> Exchanger(backend), Val(2))
end
### to be hidden later

Expand All @@ -73,7 +60,20 @@ function main(backend=CPU(), T::DataType=Float64, dims=(0, 0, 0))
NVTX.@range "step $it" begin
NVTX.@range "inner" diffusion_kernel!(backend, 256)(A_new, A, h, _dx, _dy, _dz, first(ranges[end]); ndrange=size(ranges[end]))
for dim in reverse(eachindex(neighbors))
notify.(exchangers[dim])
ntuple(Val(2)) do side
rank = neighbors[dim][side]
halo = get_recv_view(Val(side), Val(dim), A_new)
border = get_send_view(Val(side), Val(dim), A_new)
range = ranges[2*(dim-1) + side]
offset, ndrange = first(range), size(range)
start_exchange(exchangers[dim], comm, rank, halo, border) do compute_bc
NVTX.@range "borders" diffusion_kernel!(backend, 256)(A_new, A, h, _dx, _dy, _dz, offset; ndrange)
if compute_bc
# apply_bcs!(Val(dim), fields, bcs.velocity)
end
KernelAbstractions.synchronize(backend)
end
end
wait.(exchangers[dim])
end
KernelAbstractions.synchronize(backend)
Expand Down
24 changes: 12 additions & 12 deletions scripts_future_API/exchanger2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,27 @@ function main(backend = CPU(), T::DataType = Float64, dims = (0, 0, 0))

ranges = split_ndrange(A, b_width)

exchangers = ntuple(Val(length(neighbors))) do dim
ntuple(2) do side
exchangers = ntuple(Val(length(neighbors))) do _
ntuple(_ -> Exchanger(backend), Val(2))
end

do_work!(backend, 256)(A, me, first(ranges[end]); ndrange=size(ranges[end]))

for dim in reverse(eachindex(neighbors))
ntuple(Val(2)) do side
rank = neighbors[dim][side]
halo = get_recv_view(Val(side), Val(dim), A)
border = get_send_view(Val(side), Val(dim), A)
halo = get_recv_view(Val(side), Val(dim), A_new)
border = get_send_view(Val(side), Val(dim), A_new)
range = ranges[2*(dim-1) + side]
offset, ndrange = first(range), size(range)
Exchanger(backend, comm, rank, halo, border) do compute_bc
do_work!(backend, 256)(A, me, offset; ndrange)
start_exchange(exchangers[dim], comm, rank, halo, border) do compute_bc
NVTX.@range "borders" do_work!(backend, 256)(A, me, offset; ndrange)
if compute_bc
# apply_bcs!(Val(dim), fields, bcs.velocity)
end
KernelAbstractions.synchronize(backend)
end
end
end

do_work!(backend, 256)(A, me, first(ranges[end]); ndrange=size(ranges[end]))

for dim in reverse(eachindex(neighbors))
notify.(exchangers[dim])
wait.(exchangers[dim])
end

Expand Down
15 changes: 8 additions & 7 deletions scripts_future_API/mpi_utils2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ end
# exchanger
mutable struct Exchanger
@atomic done::Bool
top::Base.Event
channel::Channel
bottom::Base.Event
@atomic err
task::Task

function Exchanger(f::F, backend::Backend, comm, rank, halo, border) where F
top = Base.Event(#=autoreset=# true)
bottom = Base.Event(#=autoreset=# true)
function Exchanger(backend::Backend)
channel = Channel()
bottom = Base.Event(true)

send_buf = similar(border)
recv_buf = similar(halo)
Expand All @@ -56,7 +56,7 @@ mutable struct Exchanger
KernelAbstractions.priority!(backend, :high)
try
while !(@atomic this.done)
wait(top)
f, comm, rank, halo, border = take!(channel)
NVTX.@mark "after wait(top)"
if has_neighbor
recv = MPI.Irecv!(recv_buf, comm; source=rank)
Expand Down Expand Up @@ -86,9 +86,10 @@ end
setdone!(exc::Exchanger) = @atomic exc.done = true

Base.isdone(exc::Exchanger) = @atomic exc.done
function Base.notify(exc::Exchanger)

function start_exchange(f, exc::Exchanger, comm, rank, halo, border)
if !(@atomic exc.done)
notify(exc.top)
put!(exc.channel, (f, comm, rank, halo, border))
else
error("notify: Exchanger is not running")
end
Expand Down

0 comments on commit cca98e5

Please sign in to comment.