diff --git a/README.md b/README.md
index 8f428b342..953cc7f8a 100644
--- a/README.md
+++ b/README.md
@@ -186,6 +186,10 @@ ccomm = MPI.CComm(juliacomm)
 ### Currently wrapped MPI functions
 Convention: `MPI_Fun => MPI.Fun`
 
+Constants like `MPI_SUM` are wrapped as `MPI.SUM`. Note also that
+arbitrary Julia functions `f(x,y)` can be passed as reduction operations
+to the MPI `Allreduce` and `Reduce` functions.
+
 #### Administrative functions
 Julia Function (assuming `import MPI`) | Fortran Function
 ---------------------------------------|--------------------------------------------------------
diff --git a/src/MPI.jl b/src/MPI.jl
index cc64e2f97..5beb73221 100644
--- a/src/MPI.jl
+++ b/src/MPI.jl
@@ -18,6 +18,9 @@ include(depfile)
 include("mpi-base.jl")
 include("cman.jl")
 
+const mpitype_dict = Dict{DataType, Cint}()
+const mpitype_dict_inverse = Dict{Cint, DataType}()
+
 function __init__()
     @static if is_unix()
         # need to open libmpi with RTLD_GLOBAL flag for Linux, before
@@ -32,32 +35,24 @@ function __init__()
         end
     end
 
-    global const mpitype_dict = Dict{DataType, Cint}(
-        # Older versions of OpenMPI (such as those used by default in
-        # Travis) do not define MPI_WCHAR and the MPI_*INT*_T types for
-        # Fortran. We thus don't require them (yet).
-        # Char => MPI_WCHAR,
-        # Int8 => MPI_INT8_T,
-        # UInt8 => MPI_UINT8_T,
-        # Int16 => MPI_INT16_T,
-        # UInt16 => MPI_UINT16_T,
-        # Int32 => MPI_INT32_T,
-        # UInt32 => MPI_UINT32_T,
-        # Int64 => MPI_INT64_T,
-        # UInt64 => MPI_UINT64_T,
-        Char => MPI_INTEGER4,
-        Int8 => MPI_INTEGER1,
-        UInt8 => MPI_INTEGER1,
-        Int16 => MPI_INTEGER2,
-        UInt16 => MPI_INTEGER2,
-        Int32 => MPI_INTEGER4,
-        UInt32 => MPI_INTEGER4,
-        Int64 => MPI_INTEGER8,
-        UInt64 => MPI_INTEGER8,
-        Float32 => MPI_REAL4,
-        Float64 => MPI_REAL8,
-        Complex64 => MPI_COMPLEX8,
-        Complex128 => MPI_COMPLEX16)
+    # Note: older versions of OpenMPI (e.g. the version on Travis) do not
+    # define MPI_CHAR and MPI_*INT*_T for Fortran, so we don't use them (yet).
+    for (T,mpiT) in (Char => MPI_INTEGER4, # => MPI_WCHAR, (note: wchar_t is 16 bits in Windows)
+                     Int8 => MPI_INTEGER1, # => MPI_INT8_T,
+                     UInt8 => MPI_INTEGER1, # => MPI_UINT8_T,
+                     Int16 => MPI_INTEGER2, # => MPI_INT16_T,
+                     UInt16 => MPI_INTEGER2, # => MPI_UINT16_T,
+                     Int32 => MPI_INTEGER4, # => MPI_INT32_T,
+                     UInt32 => MPI_INTEGER4, # => MPI_UINT32_T,
+                     Int64 => MPI_INTEGER8, # => MPI_INT64_T,
+                     UInt64 => MPI_INTEGER8, # => MPI_UINT64_T,
+                     Float32 => MPI_REAL4,
+                     Float64 => MPI_REAL8,
+                     Complex64 => MPI_COMPLEX8,
+                     Complex128 => MPI_COMPLEX16)
+        mpitype_dict[T] = mpiT
+        mpitype_dict_inverse[mpiT] = T
+    end
 end
 
 end
diff --git a/src/mpi-base.jl b/src/mpi-base.jl
index b88e7ad2f..23fb64f2b 100644
--- a/src/mpi-base.jl
+++ b/src/mpi-base.jl
@@ -61,6 +61,7 @@ function mpitype{T}(::Type{T})
 
     # add it to the dictonary of known types
     mpitype_dict[T] = newtype_ref[]
+    mpitype_dict_inverse[newtype_ref[]] = T
 
     return mpitype_dict[T]
 end
@@ -205,8 +206,8 @@ function type_create(T::DataType)
     # create the datatype
     newtype_ref = Ref{Cint}()
     flag = Ref{Cint}()
-    ccall(MPI_TYPE_CREATE_STRUCT, Void, (Ptr{Cint}, Ptr{Cint}, Ptr{Cptrdiff_t}, 
-          Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), &nfields, blocklengths, displacements, 
+    ccall(MPI_TYPE_CREATE_STRUCT, Void, (Ptr{Cint}, Ptr{Cint}, Ptr{Cptrdiff_t},
+          Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), &nfields, blocklengths, displacements,
           types, newtype_ref, flag)
 
     if flag[] != 0
@@ -573,11 +574,11 @@ function Reduce{T}(sendbuf::MPIBuffertype{T}, count::Integer,
     isroot ? recvbuf : nothing
 end
 
-function Reduce{T}(sendbuf::Array{T}, op::Op, root::Integer, comm::Comm)
+function Reduce{T}(sendbuf::Array{T}, op::Union{Op,Function}, root::Integer, comm::Comm)
     Reduce(sendbuf, length(sendbuf), op, root, comm)
 end
 
-function Reduce{T}(object::T, op::Op, root::Integer, comm::Comm)
+function Reduce{T}(object::T, op::Union{Op,Function}, root::Integer, comm::Comm)
     isroot = Comm_rank(comm) == root
     sendbuf = T[object]
     recvbuf = Reduce(sendbuf, op, root, comm)
@@ -588,18 +589,18 @@ function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
                        count::Integer, op::Op, comm::Comm)
     flag = Ref{Cint}()
     ccall(MPI_ALLREDUCE, Void, (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint},
-          Ptr{Cint}, Ptr{Cint}), sendbuf, recvbuf, &count, &mpitype(T), 
+          Ptr{Cint}, Ptr{Cint}), sendbuf, recvbuf, &count, &mpitype(T),
           &op.val, &comm.val, flag)
     recvbuf
 end
-  
+
 function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
-                       op::Op, comm::Comm)
+                       op::Union{Op,Function}, comm::Comm)
     Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
 end
 
-function Allreduce{T}(obj::T, op::Op, comm::Comm)
+function Allreduce{T}(obj::T, op::Union{Op,Function}, comm::Comm)
     objref = Ref(obj)
     outref = Ref{T}()
     Allreduce!(objref, outref, 1, op, comm)
@@ -608,19 +609,20 @@ function Allreduce{T}(obj::T, op::Op, comm::Comm)
 end
 
 # allocate receive buffer automatically
-function allreduce{T}(sendbuf::MPIBuffertype{T}, op::Op, comm::Comm)
+function allreduce{T}(sendbuf::MPIBuffertype{T}, op::Union{Op,Function}, comm::Comm)
     recvbuf = similar(sendbuf)
     Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
 end
 
+include("mpi-op.jl")
 
-function Scatter{T}(sendbuf::MPIBuffertype{T},count::Integer, root::Integer, 
+function Scatter{T}(sendbuf::MPIBuffertype{T},count::Integer, root::Integer,
                     comm::Comm)
     recvbuf = Array(T, count)
     ccall(MPI_SCATTER, Void,
-          (Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, 
-           Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), sendbuf, &count, &mpitype(T), 
+          (Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{T}, Ptr{Cint}, Ptr{Cint},
+           Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), sendbuf, &count, &mpitype(T),
           recvbuf, &count, &mpitype(T), &root, &comm.val, &0)
     recvbuf
 end
diff --git a/src/mpi-op.jl b/src/mpi-op.jl
new file mode 100644
index 000000000..c960a913b
--- /dev/null
+++ b/src/mpi-op.jl
@@ -0,0 +1,70 @@
+# Implement user-defined MPI reduction operations, by passing Julia
+# functions as callbacks to MPI.
+
+# for defining thread-local variables; can use Compat
+# once JuliaLang/Compat.jl#223 is resolved.
+if isdefined(Base, :Threads)
+    import Base.Threads: nthreads, threadid
+else
+    nthreads() = 1
+    threadid() = 1
+end
+
+# Unfortunately, MPI_Op_create takes a function that does not accept
+# a void* "thunk" parameter, making it impossible to fully simulate
+# a closure. So, we have to use a global variable instead. (Since
+# the reduction functions are barriers, being re-entrant is probably
+# not important in practice, fortunately.) For MPI_THREAD_MULTIPLE
+# using Julia native threading, however, we do make this global thread-local.
+const _user_functions = Array(Function, 1) # resized to nthreads() at runtime
+const _user_op = Op(MPI_OP_NULL) # _mpi_user_function operation, initialized below
+
+# C callback function corresponding to MPI_User_function
+function _mpi_user_function(_a::Ptr{Void}, _b::Ptr{Void}, _len::Ptr{Cint}, t::Ptr{Cint})
+    len = unsafe_load(_len)
+    T = mpitype_dict_inverse[unsafe_load(t)]
+    a = Ptr{T}(_a)
+    b = Ptr{T}(_b)
+    f = _user_functions[threadid()]
+    for i = 1:len
+        unsafe_store!(b, f(unsafe_load(a,i), unsafe_load(b,i)), i)
+    end
+    return nothing
+end
+
+function user_op(opfunc::Function)
+    # we must initialize these at runtime, but it can't be done in __init__
+    # since MPI.Init is not called yet. So we do it lazily here:
+    if _user_op.val == MPI_OP_NULL
+        # FIXME: to be thread-safe, there should really be a mutex lock
+        # of some sort so that this initialization only occurs once.
+        # To do when native threading in Julia stabilizes (and is documented).
+        resize!(_user_functions, nthreads())
+        user_function = cfunction(_mpi_user_function, Void, (Ptr{Void}, Ptr{Void}, Ptr{Cint}, Ptr{Cint}))
+        opnum = Ref{Cint}()
+        ccall(MPI_OP_CREATE, Void, (Ptr{Void}, Ref{Cint}, Ref{Cint}, Ptr{Cint}),
+              user_function, false, opnum, &0)
+        _user_op.val = opnum[]
+    end
+
+    _user_functions[threadid()] = opfunc
+    return _user_op
+end
+
+# use function types in Julia 0.5 to automatically use built-in
+# MPI operations for the corresponding Julia functions.
+if VERSION >= v"0.5.0-dev+2396"
+    for (f,op) in ((+,SUM), (*,PROD),
+                   (min,MIN), (max,MAX),
+                   (&, BAND), (|, BOR), ($, BXOR))
+        @eval user_op(::$(typeof(f))) = $op
+    end
+end
+
+Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
+              count::Integer, opfunc::Function, comm::Comm) =
+    Allreduce!(sendbuf, recvbuf, count, user_op(opfunc), comm)
+
+Reduce{T}(sendbuf::MPIBuffertype{T}, count::Integer,
+          opfunc::Function, root::Integer, comm::Comm) =
+    Reduce(sendbuf, count, user_op(opfunc), root, comm)
diff --git a/src/win_mpiconstants.jl b/src/win_mpiconstants.jl
index f5023c7f3..56cf4d8c0 100644
--- a/src/win_mpiconstants.jl
+++ b/src/win_mpiconstants.jl
@@ -64,6 +64,8 @@ const MPI_ALLTOALL = (:MPI_ALLTOALL, "msmpi.dll")
 const MPI_ALLTOALLV = (:MPI_ALLTOALLV, "msmpi.dll")
 const MPI_INITIALIZED = (:MPI_INITIALIZED, "msmpi.dll")
 const MPI_FINALIZED = (:MPI_FINALIZED, "msmpi.dll")
+const MPI_OP_CREATE = (:MPI_OP_CREATE, "msmpi.dll")
+const MPI_OP_FREE = (:MPI_OP_FREE, "msmpi.dll")
 const MPI_SCATTER = (:MPI_SCATTER, "msmpi.dll")
 const MPI_SCATTERV = (:MPI_SCATTERV, "msmpi.dll")
 const MPI_SEND = (:MPI_SEND, "msmpi.dll")
diff --git a/test/test_allreduce.jl b/test/test_allreduce.jl
index 442568d4f..e0db92fde 100644
--- a/test/test_allreduce.jl
+++ b/test/test_allreduce.jl
@@ -6,23 +6,26 @@ MPI.Init()
 
 comm_size = MPI.Comm_size(MPI.COMM_WORLD)
 
 send_arr = Int[1, 2, 3]
-recv_arr = zeros(Int, 3)
-MPI.Allreduce!(send_arr, recv_arr, MPI.SUM, MPI.COMM_WORLD)
+for op in (MPI.SUM, +, (x,y) -> 2x+y-x)
+    recv_arr = zeros(Int, 3)
 
-for i=1:3
-    @test recv_arr[i] == comm_size*send_arr[i]
-end
+    MPI.Allreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)
+
+    for i=1:3
+        @test recv_arr[i] == comm_size*send_arr[i]
+    end
 
-val = MPI.Allreduce(2, MPI.SUM, MPI.COMM_WORLD)
-@test val == comm_size*2
+    val = MPI.Allreduce(2, op, MPI.COMM_WORLD)
+    @test val == comm_size*2
 
-vals = MPI.allreduce(send_arr, MPI.SUM, MPI.COMM_WORLD)
-for i=1:3
-    @test vals[i] == comm_size*send_arr[i]
-    @test length(vals) == 3
-    @test eltype(vals) == Int
+    vals = MPI.allreduce(send_arr, op, MPI.COMM_WORLD)
+    for i=1:3
+        @test vals[i] == comm_size*send_arr[i]
+        @test length(vals) == 3
+        @test eltype(vals) == Int
+    end
 end
 
 MPI.Barrier( MPI.COMM_WORLD )
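
Usage sketch for the user-defined reductions added by this patch (illustrative only: the script name `example.jl` and the `mpirun` invocation are assumptions, not part of the change):

    # example.jl -- run under MPI, e.g. `mpirun -np 4 julia example.jl`
    import MPI

    MPI.Init()
    comm = MPI.COMM_WORLD
    rank = MPI.Comm_rank(comm)

    # A predefined MPI.Op and an arbitrary Julia function are now interchangeable:
    a = MPI.Allreduce(rank, MPI.SUM, comm)          # built-in MPI_SUM
    b = MPI.Allreduce(rank, +, comm)                # `+` maps to MPI.SUM on Julia 0.5, callback otherwise
    c = MPI.Allreduce(rank, (x,y) -> x + y, comm)   # arbitrary function, goes through MPI_Op_create

    rank == 0 && println((a, b, c))                 # all three give the same sum of ranks

    MPI.Finalize()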