Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JaredCrean2 committed Jun 1, 2016
1 parent 52af66f commit 0435deb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
14 changes: 5 additions & 9 deletions src/mpi-base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,29 +529,25 @@ function Reduce{T}(object::T, op::Op, root::Integer, comm::Comm)
isroot ? recvbuf[1] : nothing
end

function Allreduce{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
count::Integer, op::Op, comm::Comm)
flag = Ref{Cint}()
ccall(MPI_ALLREDUCE, Void, (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint},
Ptr{Cint}, Ptr{Cint}), sendbuf, recvbuf, &count, &mpitype(T),
&op.val, &comm.val, flag)

if flag[] != 0
throw(ErrorException("Allreduce returned non-zero exit status"))
end

recvbuf
end

function Allreduce{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
op::Op, comm::Comm)
Allreduce(sendbuf, recvbuf, length(recvbuf), op, comm)
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
end

function Allreduce{T}(obj::T, op::Op, comm::Comm)
objref = Ref(obj)
outref = Ref{T}()
Allreduce(objref, outref, 1, op, comm)
Allreduce!(objref, outref, 1, op, comm)

outref[]
end
Expand All @@ -560,7 +556,7 @@ end
function allreduce{T}(sendbuf::MPIBuffertype{T}, op::Op, comm::Comm)

recvbuf = similar(sendbuf)
Allreduce(sendbuf, recvbuf, length(recvbuf), op, comm)
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
end


Expand Down
6 changes: 3 additions & 3 deletions test/test_allreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ using MPI

MPI.Init()

#MPI.mpitype_dict[Boundary] = MPI.mpitype_dict[Int]
comm_size = MPI.Comm_size(MPI.COMM_WORLD)
comm_rank = MPI.Comm_rank(MPI.COMM_WORLD) + 1

send_arr = Int[1, 2, 3]
recv_arr = zeros(Int, 3)

MPI.Allreduce(send_arr, recv_arr, MPI.SUM, MPI.COMM_WORLD)
MPI.Allreduce!(send_arr, recv_arr, MPI.SUM, MPI.COMM_WORLD)

for i=1:3
@test recv_arr[i] == comm_size*send_arr[i]
Expand All @@ -23,6 +21,8 @@ val = MPI.Allreduce(2, MPI.SUM, MPI.COMM_WORLD)
vals = MPI.allreduce(send_arr, MPI.SUM, MPI.COMM_WORLD)
for i=1:3
@test vals[i] == comm_size*send_arr[i]
@test length(vals) == 3
@test eltype(vals) == Int64
end

MPI.Barrier( MPI.COMM_WORLD )
Expand Down

0 comments on commit 0435deb

Please sign in to comment.