Merge pull request JuliaParallel#160 from JuliaParallel/op_create
support arbitrary Julia functions in reduction operations
eschnett authored Jun 10, 2016
2 parents 6a0c164 + 60d130c commit 171569e
Showing 6 changed files with 126 additions and 50 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -186,6 +186,10 @@ ccomm = MPI.CComm(juliacomm)
### Currently wrapped MPI functions
Convention: `MPI_Fun => MPI.Fun`

Constants like `MPI_SUM` are wrapped as `MPI.SUM`. Note also that
arbitrary Julia functions `f(x,y)` can be passed as reduction operations
to the MPI `Allreduce` and `Reduce` functions.

#### Administrative functions
Julia Function (assuming `import MPI`) | Fortran Function
---------------------------------------|--------------------------------------------------------
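To illustrate the README note above, here is a minimal usage sketch (not part of this commit); it assumes the script is launched under an MPI launcher such as `mpirun`:

```julia
# run with e.g.: mpirun -np 4 julia allreduce_example.jl
import MPI
MPI.Init()
comm = MPI.COMM_WORLD
# a predefined MPI operation and an equivalent Julia function agree
a = MPI.Allreduce(2, MPI.SUM, comm)  # built-in reduction op
b = MPI.Allreduce(2, +, comm)        # arbitrary Julia function
@assert a == b == 2 * MPI.Comm_size(comm)
MPI.Finalize()
```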
47 changes: 21 additions & 26 deletions src/MPI.jl
@@ -18,6 +18,9 @@ include(depfile)
include("mpi-base.jl")
include("cman.jl")

const mpitype_dict = Dict{DataType, Cint}()
const mpitype_dict_inverse = Dict{Cint, DataType}()

function __init__()
@static if is_unix()
# need to open libmpi with RTLD_GLOBAL flag for Linux, before
@@ -32,32 +35,24 @@ function __init__()
end
end

global const mpitype_dict = Dict{DataType, Cint}(
# Older versions of OpenMPI (such as those used by default in
# Travis) do not define MPI_WCHAR and the MPI_*INT*_T types for
# Fortran. We thus don't require them (yet).
# Char => MPI_WCHAR,
# Int8 => MPI_INT8_T,
# UInt8 => MPI_UINT8_T,
# Int16 => MPI_INT16_T,
# UInt16 => MPI_UINT16_T,
# Int32 => MPI_INT32_T,
# UInt32 => MPI_UINT32_T,
# Int64 => MPI_INT64_T,
# UInt64 => MPI_UINT64_T,
Char => MPI_INTEGER4,
Int8 => MPI_INTEGER1,
UInt8 => MPI_INTEGER1,
Int16 => MPI_INTEGER2,
UInt16 => MPI_INTEGER2,
Int32 => MPI_INTEGER4,
UInt32 => MPI_INTEGER4,
Int64 => MPI_INTEGER8,
UInt64 => MPI_INTEGER8,
Float32 => MPI_REAL4,
Float64 => MPI_REAL8,
Complex64 => MPI_COMPLEX8,
Complex128 => MPI_COMPLEX16)
# Note: older versions of OpenMPI (e.g. the version on Travis) do not
# define MPI_CHAR and MPI_*INT*_T for Fortran, so we don't use them (yet).
for (T,mpiT) in (Char => MPI_INTEGER4, # => MPI_WCHAR, (note: wchar_t is 16 bits in Windows)
Int8 => MPI_INTEGER1, # => MPI_INT8_T,
UInt8 => MPI_INTEGER1, # => MPI_UINT8_T,
Int16 => MPI_INTEGER2, # => MPI_INT16_T,
UInt16 => MPI_INTEGER2, # => MPI_UINT16_T,
Int32 => MPI_INTEGER4, # => MPI_INT32_T,
UInt32 => MPI_INTEGER4, # => MPI_UINT32_T,
Int64 => MPI_INTEGER8, # => MPI_INT64_T,
UInt64 => MPI_INTEGER8, # => MPI_UINT64_T,
Float32 => MPI_REAL4,
Float64 => MPI_REAL8,
Complex64 => MPI_COMPLEX8,
Complex128 => MPI_COMPLEX16)
mpitype_dict[T] = mpiT
mpitype_dict_inverse[mpiT] = T
end
end

end
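The two dictionaries are deliberately kept as mutual inverses: the user-defined-op callback in src/mpi-op.jl (below) receives only an MPI datatype handle and must map it back to a Julia type. A minimal sketch of the round trip, using these internal (unexported) names once the module is loaded:

```julia
import MPI
# forward: Julia type -> MPI datatype handle (a Cint)
handle = MPI.mpitype_dict[Float64]
# inverse: handle -> Julia type, as _mpi_user_function does
@assert MPI.mpitype_dict_inverse[handle] === Float64
```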
26 changes: 14 additions & 12 deletions src/mpi-base.jl
@@ -61,6 +61,7 @@ function mpitype{T}(::Type{T})

# add it to the dictionary of known types
mpitype_dict[T] = newtype_ref[]
mpitype_dict_inverse[newtype_ref[]] = T

return mpitype_dict[T]
end
@@ -205,8 +206,8 @@ function type_create(T::DataType)
# create the datatype
newtype_ref = Ref{Cint}()
flag = Ref{Cint}()
ccall(MPI_TYPE_CREATE_STRUCT, Void, (Ptr{Cint}, Ptr{Cint}, Ptr{Cptrdiff_t},
Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), &nfields, blocklengths, displacements,
types, newtype_ref, flag)

if flag[] != 0
@@ -573,11 +574,11 @@ function Reduce{T}(sendbuf::MPIBuffertype{T}, count::Integer,
isroot ? recvbuf : nothing
end

function Reduce{T}(sendbuf::Array{T}, op::Op, root::Integer, comm::Comm)
function Reduce{T}(sendbuf::Array{T}, op::Union{Op,Function}, root::Integer, comm::Comm)
Reduce(sendbuf, length(sendbuf), op, root, comm)
end

function Reduce{T}(object::T, op::Op, root::Integer, comm::Comm)
function Reduce{T}(object::T, op::Union{Op,Function}, root::Integer, comm::Comm)
isroot = Comm_rank(comm) == root
sendbuf = T[object]
recvbuf = Reduce(sendbuf, op, root, comm)
@@ -588,18 +589,18 @@ function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
count::Integer, op::Op, comm::Comm)
flag = Ref{Cint}()
ccall(MPI_ALLREDUCE, Void, (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint},
Ptr{Cint}, Ptr{Cint}), sendbuf, recvbuf, &count, &mpitype(T),
&op.val, &comm.val, flag)

recvbuf
end

function Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
op::Op, comm::Comm)
op::Union{Op,Function}, comm::Comm)
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
end

function Allreduce{T}(obj::T, op::Op, comm::Comm)
function Allreduce{T}(obj::T, op::Union{Op,Function}, comm::Comm)
objref = Ref(obj)
outref = Ref{T}()
Allreduce!(objref, outref, 1, op, comm)
@@ -608,19 +609,20 @@ function Allreduce{T}(obj::T, op::Op, comm::Comm)
end

# allocate receive buffer automatically
function allreduce{T}(sendbuf::MPIBuffertype{T}, op::Op, comm::Comm)
function allreduce{T}(sendbuf::MPIBuffertype{T}, op::Union{Op,Function}, comm::Comm)

recvbuf = similar(sendbuf)
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
end

include("mpi-op.jl")

function Scatter{T}(sendbuf::MPIBuffertype{T},count::Integer, root::Integer,
comm::Comm)
recvbuf = Array(T, count)
ccall(MPI_SCATTER, Void,
(Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{T}, Ptr{Cint}, Ptr{Cint},
Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), sendbuf, &count, &mpitype(T),
recvbuf, &count, &mpitype(T), &root, &comm.val, &0)
recvbuf
end
70 changes: 70 additions & 0 deletions src/mpi-op.jl
@@ -0,0 +1,70 @@
# Implement user-defined MPI reduction operations, by passing Julia
# functions as callbacks to MPI.

# for defining thread-local variables; can use Compat
# once JuliaLang/Compat.jl#223 is resolved.
if isdefined(Base, :Threads)
import Base.Threads: nthreads, threadid
else
nthreads() = 1
threadid() = 1
end

# Unfortunately, MPI_Op_create takes a function that does not accept
# a void* "thunk" parameter, making it impossible to fully simulate
# a closure. So, we have to use a global variable instead. (Since
# the reduction functions are barriers, being re-entrant is probably
# not important in practice, fortunately.) For MPI_THREAD_MULTIPLE
# using Julia native threading, however, we do make this global thread-local.
const _user_functions = Array(Function, 1) # resized to nthreads() at runtime
const _user_op = Op(MPI_OP_NULL) # _mpi_user_function operation, initialized below

# C callback function corresponding to MPI_User_function
function _mpi_user_function(_a::Ptr{Void}, _b::Ptr{Void}, _len::Ptr{Cint}, t::Ptr{Cint})
len = unsafe_load(_len)
T = mpitype_dict_inverse[unsafe_load(t)]
a = Ptr{T}(_a)
b = Ptr{T}(_b)
f = _user_functions[threadid()]
for i = 1:len
unsafe_store!(b, f(unsafe_load(a,i), unsafe_load(b,i)), i)
end
return nothing
end

function user_op(opfunc::Function)
# we must initialize these at runtime, but it can't be done in __init__
# since MPI.Init is not called yet. So we do it lazily here:
if _user_op.val == MPI_OP_NULL
# FIXME: to be thread-safe, there should really be a mutex lock
# of some sort so that this initialization only occurs once.
# To do when native threading in Julia stabilizes (and is documented).
resize!(_user_functions, nthreads())
user_function = cfunction(_mpi_user_function, Void, (Ptr{Void}, Ptr{Void}, Ptr{Cint}, Ptr{Cint}))
opnum = Ref{Cint}()
ccall(MPI_OP_CREATE, Void, (Ptr{Void}, Ref{Cint}, Ref{Cint}, Ptr{Cint}),
user_function, false, opnum, &0)
_user_op.val = opnum[]
end

_user_functions[threadid()] = opfunc
return _user_op
end

# use function types in Julia 0.5 to automatically use built-in
# MPI operations for the corresponding Julia functions.
if VERSION >= v"0.5.0-dev+2396"
for (f,op) in ((+,SUM), (*,PROD),
(min,MIN), (max,MAX),
(&, BAND), (|, BOR), ($, BXOR))
@eval user_op(::$(typeof(f))) = $op
end
end

Allreduce!{T}(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
count::Integer, opfunc::Function, comm::Comm) =
Allreduce!(sendbuf, recvbuf, count, user_op(opfunc), comm)

Reduce{T}(sendbuf::MPIBuffertype{T}, count::Integer,
opfunc::Function, root::Integer, comm::Comm) =
Reduce(sendbuf, count, user_op(opfunc), root, comm)
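Putting the pieces together: a function-valued op is routed through `user_op`, which lazily registers `_mpi_user_function` via `MPI_OP_CREATE` and stashes the Julia function in `_user_functions[threadid()]` for the callback to retrieve. A hedged end-to-end sketch (not part of this commit):

```julia
import MPI
MPI.Init()
comm = MPI.COMM_WORLD
# an anonymous function forces the custom-op path (passing `max` itself
# would hit the built-in MPI.MAX mapping on Julia 0.5); max is associative
# and commutative, so every rank must end up with the largest rank number
m = MPI.Allreduce(MPI.Comm_rank(comm), (x,y) -> max(x,y), comm)
@assert m == MPI.Comm_size(comm) - 1
MPI.Finalize()
```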
2 changes: 2 additions & 0 deletions src/win_mpiconstants.jl
@@ -64,6 +64,8 @@ const MPI_ALLTOALL = (:MPI_ALLTOALL, "msmpi.dll")
const MPI_ALLTOALLV = (:MPI_ALLTOALLV, "msmpi.dll")
const MPI_INITIALIZED = (:MPI_INITIALIZED, "msmpi.dll")
const MPI_FINALIZED = (:MPI_FINALIZED, "msmpi.dll")
const MPI_OP_CREATE = (:MPI_OP_CREATE, "msmpi.dll")
const MPI_OP_FREE = (:MPI_OP_FREE, "msmpi.dll")
const MPI_SCATTER = (:MPI_SCATTER, "msmpi.dll")
const MPI_SCATTERV = (:MPI_SCATTERV, "msmpi.dll")
const MPI_SEND = (:MPI_SEND, "msmpi.dll")
27 changes: 15 additions & 12 deletions test/test_allreduce.jl
@@ -6,23 +6,26 @@ MPI.Init()
comm_size = MPI.Comm_size(MPI.COMM_WORLD)

send_arr = Int[1, 2, 3]
recv_arr = zeros(Int, 3)

MPI.Allreduce!(send_arr, recv_arr, MPI.SUM, MPI.COMM_WORLD)
for op in (MPI.SUM, +, (x,y) -> 2x+y-x)
recv_arr = zeros(Int, 3)

for i=1:3
@test recv_arr[i] == comm_size*send_arr[i]
end
MPI.Allreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)

for i=1:3
@test recv_arr[i] == comm_size*send_arr[i]
end


val = MPI.Allreduce(2, MPI.SUM, MPI.COMM_WORLD)
@test val == comm_size*2
val = MPI.Allreduce(2, op, MPI.COMM_WORLD)
@test val == comm_size*2

vals = MPI.allreduce(send_arr, MPI.SUM, MPI.COMM_WORLD)
for i=1:3
@test vals[i] == comm_size*send_arr[i]
@test length(vals) == 3
@test eltype(vals) == Int
vals = MPI.allreduce(send_arr, op, MPI.COMM_WORLD)
for i=1:3
@test vals[i] == comm_size*send_arr[i]
@test length(vals) == 3
@test eltype(vals) == Int
end
end

MPI.Barrier( MPI.COMM_WORLD )
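The loop exercises a predefined op (`MPI.SUM`), a built-in Julia function (`+`), and a closure (`(x,y) -> 2x+y-x`, which simplifies to `x+y`, so all three must produce identical sums). For `comm_size > 1` to actually test the reduction, launch the script under several ranks, e.g. `mpirun -np 3 julia test_allreduce.jl` (assuming an MPI launcher is available).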
