Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add non-blocking wait #766

Merged
merged 7 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference/pointtopoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ MPI.Testall
MPI.Testany
MPI.Testsome
MPI.Wait
Base.wait(req::MPI.Request)
MPI.Waitall
MPI.Waitany
MPI.Waitsome
Expand Down
11 changes: 11 additions & 0 deletions src/nonblocking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,14 @@ function Cancel!(req::AbstractRequest)
API.MPI_Cancel(req)
nothing
end

"""
Base.wait(req::MPI.Request)

Wait for an MPI request to complete. Unlike [`MPI.Wait`](@ref), it will yield to other Julia tasks resulting in a cooperative wait.
"""
function Base.wait(req::MPI.Request)
while !MPI.Test(req)
yield()
end
end
37 changes: 37 additions & 0 deletions test/test_cooperative_wait.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# tests for the various kinds of waits
include("common.jl")

MPI.Init()

myrank = MPI.Comm_rank(MPI.COMM_WORLD)
commsize = MPI.Comm_rank(MPI.COMM_WORLD)

nsends = 2
send_arr = [ArrayType{Int}([i]) for i = 1:nsends]
recv_arr = [ArrayType{Int}(undef,1) for i = 1:nsends]
synchronize()

send_check = zeros(Int, nsends)
recv_check = zeros(Int, nsends)

@sync for i = 1:nsends
Threads.@spawn begin
recv_req = MPI.Irecv!(recv_arr[i], MPI.COMM_WORLD; source=myrank, tag=i)
wait(recv_req)
@test MPI.isnull(recv_req)
recv_check[i] += 1
end
Threads.@spawn begin
send_req = MPI.Isend(send_arr[i], MPI.COMM_WORLD; dest=myrank, tag=i)
wait(send_req)
@test MPI.isnull(send_req)
send_check[i] += 1
end
end

@test recv_check == ones(Int, nsends)
@test send_check == ones(Int, nsends)

MPI.Barrier(MPI.COMM_WORLD)
MPI.Finalize()
@test MPI.Finalized()
Loading