diff --git a/docs/src/reference/pointtopoint.md b/docs/src/reference/pointtopoint.md index 64d96bd32..8da05d96a 100644 --- a/docs/src/reference/pointtopoint.md +++ b/docs/src/reference/pointtopoint.md @@ -55,6 +55,7 @@ MPI.Testall MPI.Testany MPI.Testsome MPI.Wait +Base.wait(req::MPI.Request) MPI.Waitall MPI.Waitany MPI.Waitsome diff --git a/src/nonblocking.jl b/src/nonblocking.jl index fef29ea47..bfb9da7b3 100644 --- a/src/nonblocking.jl +++ b/src/nonblocking.jl @@ -774,3 +774,14 @@ function Cancel!(req::AbstractRequest) API.MPI_Cancel(req) nothing end + +""" + Base.wait(req::MPI.Request) + +Wait for an MPI request to complete. Unlike [`MPI.Wait`](@ref), it will yield to other Julia tasks resulting in a cooperative wait. +""" +function Base.wait(req::MPI.Request) + while !MPI.Test(req) + yield() + end +end diff --git a/test/test_cooperative_wait.jl b/test/test_cooperative_wait.jl new file mode 100644 index 000000000..3219305ac --- /dev/null +++ b/test/test_cooperative_wait.jl @@ -0,0 +1,37 @@ +# tests for the various kinds of waits +include("common.jl") + +MPI.Init() + +myrank = MPI.Comm_rank(MPI.COMM_WORLD) +commsize = MPI.Comm_rank(MPI.COMM_WORLD) + +nsends = 2 +send_arr = [ArrayType{Int}([i]) for i = 1:nsends] +recv_arr = [ArrayType{Int}(undef,1) for i = 1:nsends] +synchronize() + +send_check = zeros(Int, nsends) +recv_check = zeros(Int, nsends) + +@sync for i = 1:nsends + Threads.@spawn begin + recv_req = MPI.Irecv!(recv_arr[i], MPI.COMM_WORLD; source=myrank, tag=i) + wait(recv_req) + @test MPI.isnull(recv_req) + recv_check[i] += 1 + end + Threads.@spawn begin + send_req = MPI.Isend(send_arr[i], MPI.COMM_WORLD; dest=myrank, tag=i) + wait(send_req) + @test MPI.isnull(send_req) + send_check[i] += 1 + end +end + +@test recv_check == ones(Int, nsends) +@test send_check == ones(Int, nsends) + +MPI.Barrier(MPI.COMM_WORLD) +MPI.Finalize() +@test MPI.Finalized()