diff --git a/src/mpi-base.jl b/src/mpi-base.jl index 81f153299..b88e7ad2f 100644 --- a/src/mpi-base.jl +++ b/src/mpi-base.jl @@ -584,12 +584,44 @@ function Reduce{T}(object::T, op::Op, root::Integer, comm::Comm) isroot ? recvbuf[1] : nothing end -function Scatter{T}(sendbuf::MPIBuffertype{T}, - count::Integer, root::Integer, comm::Comm) +function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T}, + count::Integer, op::Op, comm::Comm) + flag = Ref{Cint}() + ccall(MPI_ALLREDUCE, Void, (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, + Ptr{Cint}, Ptr{Cint}), sendbuf, recvbuf, &count, &mpitype(T), + &op.val, &comm.val, flag) + + recvbuf +end + +function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T}, + op::Op, comm::Comm) + Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm) +end + +function Allreduce{T}(obj::T, op::Op, comm::Comm) + objref = Ref(obj) + outref = Ref{T}() + Allreduce!(objref, outref, 1, op, comm) + + outref[] +end + +# allocate receive buffer automatically +function allreduce{T}(sendbuf::MPIBuffertype{T}, op::Op, comm::Comm) + + recvbuf = similar(sendbuf) + Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm) +end + + +function Scatter{T}(sendbuf::MPIBuffertype{T},count::Integer, root::Integer, + comm::Comm) recvbuf = Array(T, count) ccall(MPI_SCATTER, Void, - (Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), - sendbuf, &count, &mpitype(T), recvbuf, &count, &mpitype(T), &root, &comm.val, &0) + (Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, + Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), sendbuf, &count, &mpitype(T), + recvbuf, &count, &mpitype(T), &root, &comm.val, &0) recvbuf end diff --git a/src/win_mpiconstants.jl b/src/win_mpiconstants.jl index 9daecd3bd..f5023c7f3 100644 --- a/src/win_mpiconstants.jl +++ b/src/win_mpiconstants.jl @@ -86,5 +86,6 @@ const MPI_WAIT = (:MPI_WAIT, "msmpi.dll") const MPI_WAITSOME = (:MPI_WAITSOME, "msmpi.dll") const MPI_WAITANY = (:MPI_WAITANY, "msmpi.dll") const MPI_CANCEL = (:MPI_CANCEL, "msmpi.dll") +const MPI_ALLREDUCE = (:MPI_ALLREDUCE, "msmpi.dll") bitstype 32 CComm diff --git a/test/test_allreduce.jl b/test/test_allreduce.jl new file mode 100644 index 000000000..442568d4f --- /dev/null +++ b/test/test_allreduce.jl @@ -0,0 +1,29 @@ +using Base.Test +using MPI + +MPI.Init() + +comm_size = MPI.Comm_size(MPI.COMM_WORLD) + +send_arr = Int[1, 2, 3] +recv_arr = zeros(Int, 3) + +MPI.Allreduce!(send_arr, recv_arr, MPI.SUM, MPI.COMM_WORLD) + +for i=1:3 + @test recv_arr[i] == comm_size*send_arr[i] +end + + +val = MPI.Allreduce(2, MPI.SUM, MPI.COMM_WORLD) +@test val == comm_size*2 + +vals = MPI.allreduce(send_arr, MPI.SUM, MPI.COMM_WORLD) +for i=1:3 + @test vals[i] == comm_size*send_arr[i] + @test length(vals) == 3 + @test eltype(vals) == Int +end + +MPI.Barrier( MPI.COMM_WORLD ) +MPI.Finalize()