From f98083de8df3b178eea357c35b830c2be80e1ad4 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Fri, 3 Jan 2020 13:17:31 -0800 Subject: [PATCH] Add Buffer type, improve Datatype handling (#329) This contains two related changes: 1. Defines a specific `Buffer` type, which contains the reference to the storage buffer, its count and datatype. This allows us to simplify the type signatures of various functions, as `count` and `datatype` no longer need to be arguments to the functions. This also adds default conversion methods for `Array`s and `SubArray`s (creating the derived datatypes where necessary, and determining the appropriate `count`s), and moves the point-to-point operations to use these conversions. 2. Improves the handling of `Datatype` handles, by making them garbage-collected objects (like other MPI handles), moves lower-level functions to a submodule, and defines consistent interfaces. Also fixes #327. I still need to move the collective calls over as well; however, that will require more thought on how to handle the "chunked" operations like scatter/gather. I also removed the inverse dictionary mappings from MPI Datatype -> Julia Type, as that is no longer so easy to determine. 
--- Project.toml | 3 +- deps/consts_msmpi.jl | 4 + deps/gen_consts.jl | 4 + docs/src/advanced.md | 18 +- src/MPI.jl | 3 + src/buffers.jl | 127 ++++++++++++ src/collective.jl | 26 +-- src/cuda.jl | 15 +- src/datatypes.jl | 438 ++++++++++++++++++++++++------------------ src/deprecated.jl | 50 +++++ src/onesided.jl | 10 +- src/operators.jl | 10 +- src/pointtopoint.jl | 160 ++++++--------- test/test_datatype.jl | 191 ++++++++++-------- test/test_sendrecv.jl | 14 +- test/test_subarray.jl | 106 +++++++--- 16 files changed, 748 insertions(+), 431 deletions(-) create mode 100644 src/buffers.jl diff --git a/Project.toml b/Project.toml index de51d2623..805768933 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.11.0" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -12,8 +13,8 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" [compat] -julia = "1" Requires = "~0.5" +julia = "1" [extras] DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78" diff --git a/deps/consts_msmpi.jl b/deps/consts_msmpi.jl index 4f6caf42d..31d54c074 100644 --- a/deps/consts_msmpi.jl +++ b/deps/consts_msmpi.jl @@ -1,5 +1,9 @@ # From https://github.com/microsoft/Microsoft-MPI/blob/v10.0/src/include/mpi.h +const MPI_Aint = Int +const MPI_Offset = Int64 +const MPI_Count = Int64 + for T in [:MPI_Comm, :MPI_Info, :MPI_Win, :MPI_Request, :MPI_Op, :MPI_Datatype] @eval begin primitive type $T 32 end diff --git a/deps/gen_consts.jl b/deps/gen_consts.jl index 9e59d0012..52f720857 100644 --- a/deps/gen_consts.jl +++ b/deps/gen_consts.jl @@ -125,6 +125,10 @@ int main(int argc, char *argv[]) { fprintf(fptr, "# Do not edit\\n"); """) + println(f," fprintf(fptr, \"const MPI_Aint = Int%d\\n\", 8*(int)sizeof(MPI_Aint));") + 
println(f," fprintf(fptr, \"const MPI_Offset = Int%d\\n\", 8*(int)sizeof(MPI_Offset));") + println(f," fprintf(fptr, \"const MPI_Count = Int%d\\n\", 8*(int)sizeof(MPI_Count));") + println(f," fprintf(fptr, \"const MPI_Status_size = %d\\n\", (int)sizeof(MPI_Status));") println(f," fprintf(fptr, \"const MPI_Status_Source_offset = %d\\n\", (int)offsetof(MPI_Status, MPI_SOURCE));") println(f," fprintf(fptr, \"const MPI_Status_Tag_offset = %d\\n\", (int)offsetof(MPI_Status, MPI_TAG));") diff --git a/docs/src/advanced.md b/docs/src/advanced.md index 2b56ffcd5..bfe66771a 100644 --- a/docs/src/advanced.md +++ b/docs/src/advanced.md @@ -8,11 +8,25 @@ MPI.refcount_inc MPI.refcount_dec ``` +## Buffers + +```@docs +MPI.Buffer +MPI.Buffer_send +MPI.MPIPtr +``` + ## Datatype objects ```@docs -MPI.mpitype -MPI.Type_Create_Subarray +MPI.Datatype +MPI.Types.extent +MPI.Types.create_contiguous +MPI.Types.create_vector +MPI.Types.create_subarray +MPI.Types.create_struct +MPI.Types.create_resized +MPI.Types.commit! 
``` ## Operator objects diff --git a/src/MPI.jl b/src/MPI.jl index 5d9802837..704892b1e 100644 --- a/src/MPI.jl +++ b/src/MPI.jl @@ -2,6 +2,7 @@ module MPI using Libdl, Serialization using Requires +using DocStringExtensions macro mpichk(expr) @assert expr isa Expr && expr.head == :call && expr.args[1] == :ccall @@ -38,11 +39,13 @@ function _doc_external(fname) end include(joinpath(@__DIR__, "..", "deps", "deps.jl")) + include("handle.jl") include("info.jl") include("comm.jl") include("environment.jl") include("datatypes.jl") +include("buffers.jl") include("operators.jl") include("pointtopoint.jl") include("collective.jl") diff --git a/src/buffers.jl b/src/buffers.jl new file mode 100644 index 000000000..ea82643f1 --- /dev/null +++ b/src/buffers.jl @@ -0,0 +1,127 @@ +const MPIInteger = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64} +const MPIFloatingPoint = Union{Float32, Float64} +const MPIComplex = Union{ComplexF32, ComplexF64} + +const MPIDatatype = Union{Char, + Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, + UInt64, + Float32, Float64, ComplexF32, ComplexF64} +MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}} + +MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr} + +Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x) +function Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T + Base.cconvert(Ptr{T}, x) +end +function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T + ptr = Base.unsafe_convert(Ptr{T}, x) + reinterpret(MPIPtr, ptr) +end +function Base.cconvert(::Type{MPIPtr}, ::Nothing) + reinterpret(MPIPtr, C_NULL) +end + +macro assert_minlength(buffer, count) + quote + if $(esc(buffer)) isa AbstractArray + @assert length($(esc(buffer))) >= $(esc(count)) + end + end +end + +""" + MPI.MPIPtr + +A pointer to an MPI buffer. 
This type is used only as part of the implicit conversion in +`ccall`: a Julia object can be passed to MPI by defining methods for +`Base.cconvert(::Type{MPIPtr}, ...)`/`Base.unsafe_convert(::Type{MPIPtr}, ...)`. + +Currently supported are: + - `Ptr` + - `Ref` + - `Array` + - `SubArray` + - `CuArray` if CuArrays.jl is loaded. + +Additionally, certain sentinel values can be used, e.g. `MPI_IN_PLACE` or `MPI_BOTTOM`. +""" +MPIPtr + + +""" + MPI.Buffer + +An MPI buffer for communication operations. + +# Fields +$(DocStringExtensions.FIELDS) + +# Usage + + Buffer(data, count::Integer, datatype::Datatype) + +Generic constructor. + + Buffer(data) + +Construct a `Buffer` backed by `data`, automatically determining the appropriate `count` +and `datatype`. Methods are provided for + + - `Ref` + - `Array` + - `CuArray` if CuArrays.jl is loaded + - `SubArray`s of an `Array` or `CuArray` where the layout is contiguous, sequential or + blocked. + +""" +struct Buffer{A} + """a Julia object referencing a region of memory to be used for communication. It is + required that the object can be `cconvert`ed to an [`MPIPtr`](@ref).""" + data::A + + """the number of elements of `datatype` in the buffer. 
Note that this may not + correspond to the number of elements in the array if derived types are used.""" + count::Cint + + """the [`MPI.Datatype`](@ref) stored in the buffer.""" + datatype::Datatype +end +Buffer(buf::Buffer) = buf +Buffer(data, count::Integer, datatype::Datatype) = Buffer(data, Cint(count), datatype) + +function Buffer(arr::Array) + Buffer(arr, Cint(length(arr)), Datatype(eltype(arr))) +end +function Buffer(ref::Ref) + Buffer(ref, Cint(1), Datatype(eltype(ref))) +end + +# SubArray +function Buffer(sub::Base.FastContiguousSubArray) + Buffer(sub, Cint(length(sub)), Datatype(eltype(sub))) +end +function Buffer(sub::Base.FastSubArray) + datatype = Types.create_vector(length(sub), 1, sub.stride1, + Datatype(eltype(sub); commit=false)) + Types.commit!(datatype) + Buffer(sub, Cint(1), datatype) +end +function Buffer(sub::SubArray{T,N,P,I,false}) where {T,N,P,I<:Tuple{Vararg{Union{Base.ScalarIndex, Base.Slice, AbstractUnitRange}}}} + datatype = Types.create_subarray(size(parent(sub)), + map(length, sub.indices), + map(i -> first(i)-1, sub.indices), + Datatype(eltype(sub), commit=false)) + Types.commit!(datatype) + Buffer(parent(sub), Cint(1), datatype) +end + +""" + Buffer_send(data) + +Construct a [`Buffer`](@ref) object for a send operation from `data`, allowing cases where +`isbits(data)`. +""" +Buffer_send(data) = isbits(data) ? 
Buffer(Ref(data)) : Buffer(data) + +const BUFFER_NULL = Buffer(C_NULL, 0, DATATYPE_NULL) diff --git a/src/collective.jl b/src/collective.jl index 114911251..700b209c1 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -33,7 +33,7 @@ function Bcast!(buffer, count::Integer, # MPI_Comm comm) @mpichk ccall((:MPI_Bcast, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm), - buffer, count, mpitype(eltype(buffer)), root, comm) + buffer, count, Datatype(eltype(buffer)), root, comm) buffer end @@ -105,7 +105,7 @@ function Scatter!(sendbuf, recvbuf, count::Integer, root::Integer, comm::Comm) # MPI_Comm comm) @mpichk ccall((:MPI_Scatter, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm), - sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), root, comm) + sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), root, comm) recvbuf end @@ -174,7 +174,7 @@ function Scatterv!(sendbuf, recvbuf, counts::Vector, root::Integer, comm::Comm) # int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) @mpichk ccall((:MPI_Scatterv, libmpi), Cint, (MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm), - sendbuf, counts, disps, mpitype(T), recvbuf, recvcnt, mpitype(T), root, comm) + sendbuf, counts, disps, Datatype(T), recvbuf, recvcnt, Datatype(T), root, comm) recvbuf end @@ -245,7 +245,7 @@ function Gather!(sendbuf, recvbuf, count::Integer, root::Integer, comm::Comm) # MPI_Comm comm) @mpichk ccall((:MPI_Gather, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm), - sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), root, comm) + sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), root, comm) isroot ? 
recvbuf : nothing end function Gather!(sendbuf, recvbuf, root::Integer, comm::Comm) @@ -305,7 +305,7 @@ function Allgather!(sendbuf, recvbuf, count::Integer, comm::Comm) # MPI_Datatype recvtype, MPI_Comm comm) @mpichk ccall((:MPI_Allgather, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, MPI_Comm), - sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), comm) + sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), comm) recvbuf end function Allgather!(sendrecvbuf, count::Integer, comm::Comm) @@ -381,7 +381,7 @@ function Gatherv!(sendbuf, recvbuf, counts::Vector{Cint}, root::Integer, comm::C # MPI_Datatype recvtype, int root, MPI_Comm comm) @mpichk ccall((:MPI_Gatherv, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, Cint, MPI_Comm), - sendbuf, sendcnt, mpitype(T), recvbuf, counts, displs, mpitype(T), root, comm) + sendbuf, sendcnt, Datatype(T), recvbuf, counts, displs, Datatype(T), root, comm) isroot ? recvbuf : nothing end @@ -436,7 +436,7 @@ function Allgatherv!(sendbuf, recvbuf, counts::Vector{Cint}, comm::Comm) # const int displs[], MPI_Datatype recvtype, MPI_Comm comm) @mpichk ccall((:MPI_Allgatherv, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPI_Comm), - sendbuf, sendcnt, mpitype(T), recvbuf, counts, displs, mpitype(T), comm) + sendbuf, sendcnt, Datatype(T), recvbuf, counts, displs, Datatype(T), comm) recvbuf end function Allgatherv!(sendrecvbuf, counts::Vector{Cint}, comm::Comm) @@ -499,7 +499,7 @@ function Alltoall!(sendbuf, recvbuf, count::Integer, comm::Comm) # MPI_Comm comm) @mpichk ccall((:MPI_Alltoall, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, MPI_Comm), - sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), comm) + sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), comm) recvbuf end function Alltoall!(sendrecvbuf, count::Integer, comm::Comm) @@ -558,7 +558,7 @@ function Alltoallv!(sendbuf, recvbuf, 
scounts::Vector{Cint}, rcounts::Vector{Cin # MPI_Datatype recvtype, MPI_Comm comm) @mpichk ccall((:MPI_Alltoallv, libmpi), Cint, (MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPI_Comm), - sendbuf, scounts, sdispls, mpitype(T), recvbuf, rcounts, rdispls, mpitype(T), comm) + sendbuf, scounts, sdispls, Datatype(T), recvbuf, rcounts, rdispls, Datatype(T), comm) recvbuf end @@ -616,7 +616,7 @@ function Reduce!(sendbuf, recvbuf, count::Integer, op::Union{Op,MPI_Op}, root::I # MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm) @mpichk ccall((:MPI_Reduce, libmpi), Cint, (MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, Cint, MPI_Comm), - sendbuf, recvbuf, count, mpitype(T), op, root, comm) + sendbuf, recvbuf, count, Datatype(T), op, root, comm) recvbuf end @@ -699,7 +699,7 @@ function Allreduce!(sendbuf, recvbuf, count::Integer, op::Union{Op,MPI_Op}, comm # MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) @mpichk ccall((:MPI_Allreduce, libmpi), Cint, (MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm), - sendbuf, recvbuf, count, mpitype(T), op, comm) + sendbuf, recvbuf, count, Datatype(T), op, comm) recvbuf end function Allreduce!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm) @@ -766,7 +766,7 @@ function Scan!(sendbuf, recvbuf, count::Integer, # MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) @mpichk ccall((:MPI_Scan, libmpi), Cint, (MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm), - sendbuf, recvbuf, count, mpitype(T), op, comm) + sendbuf, recvbuf, count, Datatype(T), op, comm) recvbuf end function Scan!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm) @@ -840,7 +840,7 @@ function Exscan!(sendbuf, recvbuf, count::Integer, # MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) @mpichk ccall((:MPI_Exscan, libmpi), Cint, (MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm), - sendbuf, recvbuf, count, mpitype(T), op, comm) + sendbuf, recvbuf, count, Datatype(T), op, comm) recvbuf end function Exscan!(sendbuf, 
recvbuf, count::Integer, opfunc, comm::Comm) diff --git a/src/cuda.jl b/src/cuda.jl index 7e985dca4..93e22635a 100644 --- a/src/cuda.jl +++ b/src/cuda.jl @@ -11,6 +11,17 @@ function Base.unsafe_convert(::Type{MPIPtr}, buf::DeviceBuffer) reinterpret(MPIPtr, buf.ptr) end # CuArrays > v1.3 -function Base.unsafe_convert(::Type{MPIPtr}, buf::CuArray{T}) where T - reinterpret(MPIPtr, Base.unsafe_convert(CuPtr{T}, buf)) +function Base.unsafe_convert(::Type{MPIPtr}, X::CuArray{T}) where T + reinterpret(MPIPtr, Base.unsafe_convert(CuPtr{T}, X)) +end +# only need to define this for strided arrays: all others can be handled by generic machinery +function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CuArray,I} + X = parent(V) + pX = Base.unsafe_convert(CuPtr{T}, X) + pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) + return reinterpret(MPIPtr, pV) +end + +function Buffer(arr::CuArray) + Buffer(arr, Cint(length(arr)), Datatype(eltype(arr))) end diff --git a/src/datatypes.jl b/src/datatypes.jl index 315105751..07eecdc74 100644 --- a/src/datatypes.jl +++ b/src/datatypes.jl @@ -1,3 +1,18 @@ +""" + Datatype + +A `Datatype` represents the layout of the data in memory. + +# Usage + + Datatype(T; commit=true) + +Either return the predefined `Datatype` or create a new `Datatype` for the Julia type +`T`. If `commit=true`, then the [`Types.commit!`](@ref) operation will also be applied so +that it can be used for communication operations. + +Note that this can only be called on types for which `isbitstype(T)` is `true`. 
+""" @mpi_handle Datatype const DATATYPE_NULL = _Datatype(MPI_DATATYPE_NULL) @@ -5,249 +20,306 @@ Datatype() = Datatype(DATATYPE_NULL.val) function free(dt::Datatype) if dt.val != DATATYPE_NULL.val - @mpichk ccall((:MPI_Datatype_free, libmpi), Cint, (Ptr{MPI_Datatype},), dt) + @mpichk ccall((:MPI_Type_free, libmpi), Cint, (Ptr{MPI_Datatype},), dt) refcount_dec() end return nothing end - - -macro assert_minlength(buffer, count) - quote - if $(esc(buffer)) isa AbstractArray - @assert length($(esc(buffer))) >= $(esc(count)) +for (mpiname, T) in [ + :INT8_T => Int8 + :UINT8_T => UInt8 + :INT16_T => Int16 + :UINT16_T => UInt16 + :INT32_T => Int32 + :UINT32_T => UInt32 + :INT64_T => Int64 + :UINT64_T => UInt64 + :BYTE => UInt8 + :SHORT => Cshort + :UNSIGNED_SHORT => Cushort + :INT => Cint + :UNSIGNED => Cuint + :LONG => Clong + :UNSIGNED_LONG => Culong + :CHAR => Cchar + :SIGNED_CHAR => Cchar + :UNSIGNED_CHAR => Cuchar + :WCHAR => Cwchar_t + :FLOAT => Float32 + :DOUBLE => Float64 + :C_FLOAT_COMPLEX => ComplexF32 + :C_DOUBLE_COMPLEX => ComplexF64] + + @eval if @isdefined($(Symbol(:MPI_,mpiname))) + const $mpiname = _Datatype($(Symbol(:MPI_,mpiname))) + if !hasmethod(Datatype, Tuple{Type{$T}}) + Datatype(::Type{$T}; commit=true) = $mpiname end end end + +module Types -const MPIInteger = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64} -const MPIFloatingPoint = Union{Float32, Float64} -const MPIComplex = Union{ComplexF32, ComplexF64} +import MPI +import MPI: @mpichk, libmpi, _doc_external, + Datatype, MPI_Datatype, MPI_Aint, + refcount_inc, refcount_dec, free -const MPIDatatype = Union{Char, - Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, - UInt64, - Float32, Float64, ComplexF32, ComplexF64} -MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}} +""" + lb, extent = MPI.Types.extent(dt::MPI.Datatype) -MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr} +Gets the lowerbound `lb` and the extent `extent` in bytes. 
-Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x) -function Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T - @assert Base.iscontiguous(x) - Base.cconvert(Ptr{T}, x) -end -function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T - ptr = Base.unsafe_convert(Ptr{T}, x) - reinterpret(MPIPtr, ptr) -end -function Base.cconvert(::Type{MPIPtr}, ::Nothing) - reinterpret(MPIPtr, C_NULL) +# External links +$(_doc_external("MPI_Type_get_extent")) +""" +function extent(dt::Datatype) + lb = Ref{MPI_Aint}() + extent = Ref{MPI_Aint}() + # int MPI_Type_get_extent(MPI_Datatype datatype, MPI_Aint *lb, + # MPI_Aint *extent) + @mpichk ccall((:MPI_Type_get_extent, libmpi), Cint, + (MPI_Datatype, Ptr{MPI_Aint}, Ptr{MPI_Aint}), + dt, lb, extent) + return lb[], extent[] end +""" + MPI.Types.create_contiguous(count::Integer, oldtype::MPI.Datatype) +Create a derived [`Datatype`](@ref) that replicates `oldtype` into `count` contiguous locations. -fieldoffsets(::Type{T}) where {T} = Int[fieldoffset(T, i) for i in 1:length(fieldnames(T))] +Note that [`MPI.Types.commit!`](@ref) must be used before the datatype can be used for +communication. +# External links +$(_doc_external("MPI_Type_contiguous")) """ - mpitype(T) +function create_contiguous(count::Integer, oldtype::Datatype) + newtype = Datatype() + @mpichk ccall((:MPI_Type_contiguous, libmpi), Cint, + (Cint, MPI_Datatype, Ptr{MPI_Datatype}), + count, oldtype, newtype) + refcount_inc() + finalizer(free, newtype) + return newtype +end + -Returns the MPI `Datatype` code for a given type `T`. In the case the the type does not -exist, it is created and then returned. 
The dictonary is defined in `__init__` so the -module can be precompiled """ -function mpitype(::Type{T}) where T - get!(mpitype_dict, T) do - if !isbitstype(T) - throw(ArgumentError("Type must be isbitstype()")) - end + MPI.Types.create_vector(count::Integer, blocklength::Integer, stride::Integer, oldtype::MPI.Datatype) - # get the data from the type - fieldtypes = T.types - offsets = fieldoffsets(T) - nfields = Cint(length(fieldtypes)) +Create a derived [`Datatype`](@ref) that replicates `oldtype` into locations that +consist of equally spaced blocks. - if nfields == 0 # primitive type - if sizeof(T) == 0 - error("Can't convert 0-size type to MPI") - end - nfields, blocklengths, displacements, types = factorPrimitiveType(T) - else # struct - # put data in MPI format - blocklengths = ones(Cint, nfields) - displacements = zeros(Cptrdiff_t, nfields) # size_t == MPI_Aint ? - types = Array{MPI_Datatype}(undef, nfields) - for i=1:nfields - displacements[i] = offsets[i] - # create an MPI_Datatype for the current field if it does not exist yet - types[i] = mpitype(fieldtypes[i]) - end +Note that [`MPI.Types.commit!`](@ref) must be used before the datatype can be used for +communication. - end - newtype = Type_Create_Struct(nfields, blocklengths, displacements, types) +# Example - # commit the datatatype - Type_Commit!(newtype) - return newtype.val - end -end +```julia +datatype = MPI.Types.create_vector(3, 2, 5, MPI.Datatype(Int64)) +MPI.Types.commit!(datatype) +``` +will create a datatype with the following layout +``` +|<----->| block length ++---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ +| X | X | | | | X | X | | | | X | X | | | | ++---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ -function factorPrimitiveType(::Type{T}) where {T} - - tsize = sizeof(T) # size in bytes - displacements = zeros(Cptrdiff_t, 0) # size_t == MPI_Aint ? 
- types = MPI_Datatype[] - - # put largest sizes first - mpi_types = [mpitype(Int64), mpitype(Int32), mpitype(Int16), mpitype(Int8)] - mpi_sizes = [8, 4, 2, 1] - curr_disp = 0 - - while curr_disp != tsize - remsize = tsize - curr_disp - for i=1:length(mpi_types) - size_i = mpi_sizes[i] - # because each size is a multiple of the smaller sizes, taking the largest - # size always results in the smallest number of types and doesn't result - # in avoidable small remainders - if remsize >= size_i - push!(types, mpi_types[i]) - push!(displacements, curr_disp) - curr_disp += size_i - break - end - end - end +|<---- stride ----->| +``` +where each segment represents an `Int64`. - nfields = length(types) - blocklengths = ones(Cint, nfields) +(image by Jonathan Dursi, ) - return nfields, blocklengths, displacements, types +# External links +$(_doc_external("MPI_Type_vector")) +""" +function create_vector(count::Integer, blocklength::Integer, stride::Integer, oldtype::Datatype) + newtype = Datatype() + # int MPI_Type_vector(int count, int blocklength, int stride, + # MPI_Datatype oldtype, MPI_Datatype *newtype) + @mpichk ccall((:MPI_Type_vector, libmpi), Cint, + (Cint, Cint, Cint, MPI_Datatype, Ptr{MPI_Datatype}), + count, blocklength, stride, oldtype, newtype) + refcount_inc() + finalizer(free, newtype) + return newtype end +""" + MPI.Types.create_subarray(sizes, subsizes, offset, oldtype::Datatype; + rowmajor=false) + +Creates a derived [`Datatype`](@ref) describing an `N`-dimensional subarray of size +`subsizes` of an `N`-dimensional array of size `sizes` and element type `oldtype`, with +the first element offset by `offset` (i.e. the 0-based index of the first element). 
-function Type_Create_Struct(nfields::Integer, blocklengths::MPIBuffertype{Cint}, - displacements::MPIBuffertype{Cptrdiff_t}, - types::MPIBuffertype{MPI_Datatype}) +Column-major indexing (used by Julia and Fortran) is assumed; use the keyword +`rowmajor=true` to specify row-major layout (used by C and numpy). - newtype = Datatype() +Note that [`MPI.Types.commit!`](@ref) must be used before the datatype can be used for +communication. - # int MPI_Type_create_struct(int count, const int array_of_blocklengths[], - # const MPI_Aint array_of_displacements[], - # const MPI_Datatype array_of_types[], MPI_Datatype *newtype) - @mpichk ccall((:MPI_Type_create_struct, libmpi), Cint, - (Cint, Ptr{Cint}, Ptr{Cptrdiff_t}, Ptr{MPI_Datatype}, Ptr{MPI_Datatype}), - nfields, blocklengths, displacements, types, newtype) +# External links +$(_doc_external("MPI_Type_create_subarray")) +""" +function create_subarray(sizes, + subsizes, + offset, + oldtype::Datatype; + rowmajor=false) + @assert (N = length(sizes)) == length(subsizes) == length(offset) + sizes = sizes isa Vector{Cint} ? sizes : Cint[s for s in sizes] + subsizes = subsizes isa Vector{Cint} ? subsizes : Cint[s for s in subsizes] + offset = offset isa Vector{Cint} ? offset : Cint[s for s in offset] + + newtype = Datatype() + @mpichk ccall((:MPI_Type_create_subarray, libmpi), Cint, + (Cint, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Cint, MPI_Datatype, Ptr{MPI_Datatype}), + N, sizes, subsizes, offset, + rowmajor ? MPI.MPI_ORDER_C : MPI.MPI_ORDER_FORTRAN, + oldtype, newtype) + refcount_inc() + finalizer(free, newtype) return newtype end """ - Type_Create_Subarray(ndims::Integer, array_of_sizes::MPIBuffertype{Cint}, - array_of_subsizes::MPIBuffertype{Cint}, - array_of_starts::MPIBuffertype{Cint}, order::Integer, oldtype) + MPI.Types.create_struct(blocklengths, displacements, types) + +Creates a derived [`Datatype`](@ref) describing a struct layout. 
-Creates a data type describing an `ndims`-dimensional subarray of size `array_of_subsizes` -of an `ndims-dimensional` array of size `array_of_sizes` and element type `oldtype`, -starting at the top-left location `array_of_starts`. Zero-based indexing is assumed. The -parameter `order` refers to the memory layout of the parent array, and can be either -`MPI_ORDER_C` or `MPI_ORDER_FORTRAN`. Note that, like other MPI data types, the type -returned by this function should be committed with `MPI_Type_commit`. +Note that [`MPI.Types.commit!`](@ref) must be used before the datatype can be used for +communication. + +# External links +$(_doc_external("MPI_Type_create_struct")) """ -function Type_Create_Subarray(ndims::Integer, - array_of_sizes::MPIBuffertype{Cint}, - array_of_subsizes::MPIBuffertype{Cint}, - array_of_starts::MPIBuffertype{Cint}, - order::Integer, - oldtype) +function create_struct(blocklengths, displacements, types) + @assert (N = length(blocklengths)) == length(displacements) == length(types) + blocklengths = blocklengths isa Vector{Cint} ? blocklengths : Cint[s for s in blocklengths] + displacements = displacements isa Vector{MPI_Aint} ? 
displacements : MPI_Aint[s for s in displacements] newtype = Datatype() - @mpichk ccall((:MPI_Type_create_subarray, libmpi), Cint, - (Cint, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Cint, MPI_Datatype, Ptr{MPI_Datatype}), - ndims, array_of_sizes, array_of_subsizes, array_of_starts, - order, mpitype(oldtype), newtype) + # int MPI_Type_create_struct(int count, + # const int array_of_blocklengths[], + # const MPI_Aint array_of_displacements[], + # const MPI_Datatype array_of_types[], + # MPI_Datatype *newtype) + GC.@preserve types begin + mpi_types = [t.val for t in types] + @mpichk ccall((:MPI_Type_create_struct, libmpi), Cint, + (Cint, Ptr{Cint}, Ptr{MPI_Aint}, Ptr{MPI_Datatype}, Ptr{MPI_Datatype}), + N, blocklengths, displacements, mpi_types, newtype) + end return newtype end -function Type_Contiguous(count::Integer, oldtype) + + +""" + MPI.Types.create_resized(oldtype::Datatype, lb::Integer, extent::Integer) + +Creates a new [`Datatype`](@ref) that is identical to `oldtype`, except that the lower +bound of this new datatype is set to be `lb`, and its upper bound is set to be `lb + +extent`. + +Note that [`MPI.Types.commit!`](@ref) must be used before the datatype can be used for +communication. 
+ +# See also +- [`MPI.Types.extent`](@ref) + +# External links +$(_doc_external("MPI_Type_create_resized")) +""" +function create_resized(oldtype::Datatype, lb::Integer, extent::Integer) newtype = Datatype() - @mpichk ccall((:MPI_Type_contiguous, libmpi), Cint, - (Cint, MPI_Datatype, Ptr{MPI_Datatype}), - count, oldtype, newtype) + # int MPI_Type_create_resized(MPI_Datatype oldtype, MPI_Aint lb, + # MPI_Aint extent, MPI_Datatype *newtype) + @mpichk ccall((:MPI_Type_create_resized, libmpi), Cint, + (MPI_Datatype, Cptrdiff_t, Cptrdiff_t, Ptr{MPI_Datatype}), + oldtype, lb, extent, newtype) + refcount_inc() + finalizer(free, newtype) return newtype end -function Type_Commit!(newtype::Datatype) + +""" + MPI.Types.commit!(newtype::Datatype) + +Commits a [`Datatype`](@ref) so that it can be used for communication. + +# External links +$(_doc_external("MPI_Type_commit")) +""" +function commit!(newtype::Datatype) # int MPI_Type_commit(MPI_Datatype *datatype) @mpichk ccall((:MPI_Type_commit, libmpi), Cint, (Ptr{MPI_Datatype},), newtype) end -# Setter function for mpitype_dict and mpitype_dict_inverse -function recordDataType(T::DataType, mpiT::MPI_Datatype) - - if !haskey(mpitype_dict, T) - mpitype_dict[T] = mpiT - end - - if !haskey(mpitype_dict_inverse, mpiT) - mpitype_dict_inverse[mpiT] = T - end - - return nothing -end -recordDataType(T::DataType, dtyp::Datatype) = recordDataType(T, dtyp.val) - - -const mpitype_dict = Dict{DataType, MPI_Datatype}() -const mpitype_dict_inverse = Dict{MPI_Datatype, DataType}() - -function init_datatypes() - if Sys.iswindows() || MPI_VERSION >= v"2.2" - # use specific-width types if available - for (T, c) in [ - Int8 => MPI_INT8_T, - UInt8 => MPI_UINT8_T, - Int16 => MPI_INT16_T, - UInt16 => MPI_UINT16_T, - Int32 => MPI_INT32_T, - UInt32 => MPI_UINT32_T, - Int64 => MPI_INT64_T, - UInt64 => MPI_UINT64_T, - ComplexF32 => MPI_C_FLOAT_COMPLEX, - ComplexF64 => MPI_C_DOUBLE_COMPLEX, - ] - recordDataType(T, MPI_Datatype(c)) +function 
Datatype(::Type{T}; commit=true) where T + if !isbitstype(T) + throw(ArgumentError("Type must be isbitstype")) + end + blocklengths = Cint[] + displacements = MPI_Aint[] + types = Datatype[] + + if isprimitivetype(T) + # primitive type + szrem = sz = sizeof(T) + disp = 0 + for (i,basetype) in (8 => Datatype(UInt64), 4 => Datatype(UInt32), 2 => Datatype(UInt16), 1 => Datatype(UInt8)) + if sz == i + return basetype + end + blk, szrem = divrem(szrem, i) + if blk != 0 + push!(blocklengths, blk) + push!(displacements, disp) + push!(types, basetype) + disp += i * blk + end + end + else + # struct + Fprev = nothing + for i in 1:fieldcount(T) + F = fieldtype(T,i) + offset = fieldoffset(T,i) + if sizeof(F) == 0 + continue + elseif F == Fprev + blocklengths[end] += 1 + else + push!(blocklengths, 1) + push!(displacements, offset) + push!(types, Datatype(F; commit=false)) + Fprev = F + end end end - - for (T, c) in [ - UInt8 => MPI_BYTE, - Cshort => MPI_SHORT, - Cushort => MPI_UNSIGNED_SHORT, - Cint => MPI_INT, - Cuint => MPI_UNSIGNED, - Clong => MPI_LONG, - Culong => MPI_UNSIGNED_LONG, - Cchar => MPI_CHAR, - Cchar => MPI_SIGNED_CHAR, - Cuchar => MPI_UNSIGNED_CHAR, - Cwchar_t => MPI_WCHAR, - Float32 => MPI_FLOAT, - Float64 => MPI_DOUBLE, - - Char => MPI_UNSIGNED, - ] - recordDataType(T, MPI_Datatype(c)) + dt = create_struct(blocklengths, displacements, types) + if commit + commit!(dt) end + return dt end -push!(mpi_init_hooks, init_datatypes) +end # module + -function Get_address(location::MPIBuffertype{T}) where T +function Get_address(location) addr = Ref{Cptrdiff_t}(0) - @mpichk ccall((:MPI_Get_address, libmpi), Cint, (Ptr{T}, Ref{Cptrdiff_t}), location, addr) + @mpichk ccall((:MPI_Get_address, libmpi), Cint, (Ptr{Cvoid}, Ref{MPI_Aint}), location, addr) return addr[] end diff --git a/src/deprecated.jl b/src/deprecated.jl index 02b450fe1..1196c4957 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -38,3 +38,53 @@ end false Gatherv!(buf, nothing, counts, root, comm) 
end end false + +@deprecate(mpitype(T), Datatype(T), false) + +@deprecate(Type_Create_Subarray(ndims::Integer, sizes::MPIBuffertype{Cint}, subsizes::MPIBuffertype{Cint}, + starts::MPIBuffertype{Cint}, order::Integer, oldtype), + Types.create_subarray(sizes, subsizes, starts, Datatype(oldtype); rowmajor = order == MPI_ORDER_C), false) +@deprecate(Type_Create_Struct(nfields::Integer, blocklengths::MPIBuffertype{Cint}, + displacements::MPIBuffertype{Cptrdiff_t}, types::MPIBuffertype{MPI_Datatype}), + Types.create_struct(blocklengths, displacements, types), false) +@deprecate(Type_Commit!(datatype), Types.commit!(datatype), false) + + +@deprecate(Send(buf, count::Integer, datatype::Datatype, dest::Integer, tag::Integer, comm::Comm), + Send(Buffer(buf, count, datatype), dest, tag, comm), false) +@deprecate(Send(buf::AbstractArray, count::Integer, dest::Integer, tag::Integer, comm::Comm), + Send(view(buf, 1:count), dest, tag, comm), false) +@deprecate(Send(buf::Ref, count::Integer, dest::Integer, tag::Integer, comm::Comm), + Send(buf, dest, tag, comm), false) + +@deprecate(Isend(buf, count::Integer, datatype::Datatype, dest::Integer, tag::Integer, comm::Comm), + Isend(Buffer(buf,count,datatype), dest, tag, comm), false) +@deprecate(Isend(buf::AbstractArray, count::Integer, dest::Integer, tag::Integer, comm::Comm), + Isend(view(buf,1:count), dest, tag, comm), false) +@deprecate(Isend(buf::Ref, count::Integer, dest::Integer, tag::Integer, comm::Comm), + Isend(buf, dest, tag, comm), false) + +@deprecate(Recv!(buf, count::Integer, datatype::Datatype, src::Integer, tag::Integer, comm::Comm), + Recv!(Buffer(buf, count, datatype), src, tag, comm), false) +@deprecate(Recv!(buf::AbstractArray, count::Integer, src::Integer, tag::Integer, comm::Comm), + Recv!(view(buf, 1:count), src, tag, comm), false) +@deprecate(Recv!(buf::Ref, count::Integer, src::Integer, tag::Integer, comm::Comm), + Recv!(buf, src, tag, comm), false) + +@deprecate(Irecv!(buf, count::Integer, 
datatype::Datatype, src::Integer, tag::Integer, comm::Comm), + Irecv!(Buffer(buf,count,datatype), src, tag, comm), false) +@deprecate(Irecv!(buf::AbstractArray, count::Integer, src::Integer, tag::Integer, comm::Comm), + Irecv!(view(buf,1:count), src, tag, comm), false) +@deprecate(Irecv!(buf::Ref, count::Integer, src::Integer, tag::Integer, comm::Comm), + Irecv!(buf, src, tag, comm), false) + +@deprecate(Sendrecv!(sendbuf, sendcount::Integer, sendtype, dest::Integer, sendtag::Integer, + recvbuf, recvcount::Integer, recvtype, source::Integer, recvtag::Integer, + comm::Comm), + Sendrecv!(Buffer(sendbuf, sendcount, sendtype), dest, sendtag, + Buffer(recvbuf, recvcount, recvtype), source, recvtag, comm), false) +@deprecate(Sendrecv!(sendbuf, sendcount::Integer, dest::Integer, sendtag::Integer, + recvbuf, recvcount::Integer, source::Integer, recvtag::Integer, + comm::Comm), + Sendrecv!(view(sendbuf, 1:sendcount), dest, sendtag, + view(recvbuf, 1:recvcount), source, recvtag, comm), false) diff --git a/src/onesided.jl b/src/onesided.jl index 26c64446e..c3c33795a 100644 --- a/src/onesided.jl +++ b/src/onesided.jl @@ -155,7 +155,7 @@ function Get(origin_buffer, count::Integer, target_rank::Integer, target_disp::I # MPI_Datatype target_datatype, MPI_Win win) @mpichk ccall((:MPI_Get, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win), - origin_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), win) + origin_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), win) end function Get(origin_buffer::AbstractArray{T}, target_rank::Integer, win::Win) where T count = length(origin_buffer) @@ -173,7 +173,7 @@ function Put(origin_buffer, count::Integer, target_rank::Integer, target_disp::I T = eltype(origin_buffer) @mpichk ccall((:MPI_Put, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win), - origin_buffer, count, mpitype(T), target_rank, 
Cptrdiff_t(target_disp), count, mpitype(T), win) + origin_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), win) end function Put(origin_buffer::AbstractArray{T}, target_rank::Integer, win::Win) where T count = length(origin_buffer) @@ -191,7 +191,7 @@ function Fetch_and_op(sourceval, returnval, target_rank::Integer, target_disp::I T = eltype(sourceval) @mpichk ccall((:MPI_Fetch_and_op, libmpi), Cint, (MPIPtr, MPIPtr, MPI_Datatype, Cint, Cptrdiff_t, MPI_Op, MPI_Win), - sourceval, returnval, mpitype(T), target_rank, target_disp, op, win) + sourceval, returnval, Datatype(T), target_rank, target_disp, op, win) end function Accumulate(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win) @@ -202,7 +202,7 @@ function Accumulate(origin_buffer, count::Integer, target_rank::Integer, target_ T = eltype(origin_buffer) @mpichk ccall((:MPI_Accumulate, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win), - origin_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), op, win) + origin_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), op, win) end function Get_accumulate(origin_buffer, result_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win) @@ -215,5 +215,5 @@ function Get_accumulate(origin_buffer, result_buffer, count::Integer, target_ran T = eltype(origin_buffer) @mpichk ccall((:MPI_Get_accumulate, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win), - origin_buffer, count, mpitype(T), result_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), op, win) + origin_buffer, count, Datatype(T), result_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), op, win) end diff --git a/src/operators.jl b/src/operators.jl index 
88ee80a65..78a04682b 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -59,13 +59,9 @@ end function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T} len = unsafe_load(_len) - if isconcretetype(T) - S = T - else - S = mpitype_dict_inverse[unsafe_load(t)] - end - a = Ptr{S}(_a) - b = Ptr{S}(_b) + @assert isconcretetype(T) + a = Ptr{T}(_a) + b = Ptr{T}(_b) for i = 1:len unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i) end diff --git a/src/pointtopoint.jl b/src/pointtopoint.jl index f4a82589d..8b8791b79 100644 --- a/src/pointtopoint.jl +++ b/src/pointtopoint.jl @@ -157,43 +157,35 @@ If the number of entries received exceeds the limits of the count parameter, the # External links $(_doc_external("MPI_Get_count")) """ -function Get_count(stat::Status, datatype::Union{MPI_Datatype, Datatype}) +function Get_count(stat::Status, datatype::Datatype) count = Ref{Cint}() @mpichk ccall((:MPI_Get_count, libmpi), Cint, (Ptr{Status}, MPI_Datatype, Ptr{Cint}), Ref(stat), datatype, count) Int(count[]) end -Get_count(stat::Status, ::Type{T}) where {T} = Get_count(stat, mpitype(T)) +Get_count(stat::Status, ::Type{T}) where {T} = Get_count(stat, Datatype(T)) """ - Send(buf, [count::Integer, [datatype::Datatype,]] - dest::Integer, tag::Integer, comm::Comm) where T + Send(buf, dest::Integer, tag::Integer, comm::Comm) -Perform a blocking send of `count` elements of type `datatype` from `buf` to MPI -rank `dest` of communicator `comm` using the message tag `tag`. - -If not provided, `datatype` and `count` are derived from the element type and length of -`buf`, respectively. +Perform a blocking send from the buffer `buf` to MPI rank `dest` of communicator `comm` +using the message tag `tag`. 
# External links $(_doc_external("MPI_Send")) """ -function Send(buf, count::Integer, datatype::Union{Datatype, MPI_Datatype}, - dest::Integer, tag::Integer, comm::Comm) +function Send(buf::Buffer, dest::Integer, tag::Integer, comm::Comm) # int MPI_Send(const void* buf, int count, MPI_Datatype datatype, int dest, # int tag, MPI_Comm comm) @mpichk ccall((:MPI_Send, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm), - buf, count, datatype, dest, tag, comm) + buf.data, buf.count, buf.datatype, dest, tag, comm) return nothing end - -Send(buf, count::Integer, dest::Integer, tag::Integer, comm::Comm) = - Send(buf, count, mpitype(eltype(buf)), dest, tag, comm) -Send(buf::AbstractArray, dest::Integer, tag::Integer, comm::Comm) = - Send(buf, length(buf), dest, tag, comm) +Send(arr::Union{Ref,AbstractArray}, dest::Integer, tag::Integer, comm::Comm) = + Send(Buffer(arr), dest, tag, comm) """ @@ -203,7 +195,7 @@ Complete a blocking send of `obj` to MPI rank `dest` of communicator `comm` using with the message tag `tag`. """ function Send(obj::T, dest::Integer, tag::Integer, comm::Comm) where T - buf = [obj] + buf = Ref{T}(obj) Send(buf, dest, tag, comm) end @@ -219,51 +211,32 @@ function send(obj, dest::Integer, tag::Integer, comm::Comm) end """ - Isend(buf, [count::Integer, [datatype::Datatype,]] - dest::Integer, tag::Integer, comm::Comm) + Isend(data, dest::Integer, tag::Integer, comm::Comm) -Starts a nonblocking send of `count` elements of type `datatype` from `buf` to -MPI rank `dest` of communicator `comm` using with the message tag `tag`. +Starts a nonblocking send of `data` to MPI rank `dest` of communicator `comm` using with +the message tag `tag`. -If not provided, `datatype` and `count` are derived from the element type and length of -`buf`, respectively. +`data` can be a `Buffer`, or any object for which [`Buffer_send`](@ref) is defined. Returns the [`Request`](@ref) object for the nonblocking send. 
# External links $(_doc_external("MPI_Isend")) """ -function Isend(buf, count::Integer, datatype::Union{Datatype, MPI_Datatype}, - dest::Integer, tag::Integer, comm::Comm) +function Isend(buf::Buffer, dest::Integer, tag::Integer, comm::Comm) req = Request() # int MPI_Isend(const void* buf, int count, MPI_Datatype datatype, int dest, # int tag, MPI_Comm comm, MPI_Request *request) @mpichk ccall((:MPI_Isend, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{MPI_Request}), - buf, count, datatype, dest, tag, comm, req) + buf.data, buf.count, buf.datatype, dest, tag, comm, req) req.buffer = buf refcount_inc() finalizer(free, req) return req end - -Isend(buf, count::Integer, dest::Integer, tag::Integer, comm::Comm) = - Isend(buf, count, mpitype(eltype(buf)), dest, tag, comm) -Isend(buf::AbstractArray, dest::Integer, tag::Integer, comm::Comm) = - Isend(buf, length(buf), dest, tag, comm) - -""" - Isend(obj::T, dest::Integer, tag::Integer, comm::Comm) where T - -Starts a nonblocking send of `obj` to MPI rank `dest` of communicator `comm` -using with the message tag `tag`. - -Returns the commication `Request` for the nonblocking send. -""" -function Isend(obj::T, dest::Integer, tag::Integer, comm::Comm) where T - buf = [obj] - Isend(buf, dest, tag, comm) -end +Isend(data, dest::Integer, tag::Integer, comm::Comm) = + Isend(Buffer_send(data), dest, tag, comm) """ isend(obj, dest::Integer, tag::Integer, comm::Comm) @@ -279,47 +252,52 @@ function isend(obj, dest::Integer, tag::Integer, comm::Comm) end """ - Recv!(buf, [count::Integer, [datatype::Datatype,]] - src::Integer, tag::Integer, comm::Comm) + Recv!(data, src::Integer, tag::Integer, comm::Comm) -Completes a blocking receive of up to `count` elements of type `datatype` into `buf` -from MPI rank `src` of communicator `comm` using with the message tag `tag`. +Completes a blocking receive into the buffer `data` from MPI rank `src` of communicator +`comm` using with the message tag `tag`. 
-If not provided, `datatype` and `count` are derived from the element type and length of -`buf`, respectively. +`data` can be a [`Buffer`](@ref), or any object for which `Buffer(data)` is defined. Returns the [`Status`](@ref) of the receive. +# See also +- [`Recv`](@ref) +- [`recv`](@ref) + # External links $(_doc_external("MPI_Recv")) """ -function Recv!(buf, count::Integer, datatype::Union{Datatype,MPI_Datatype}, src::Integer, - tag::Integer, comm::Comm) +function Recv!(buf::Buffer, src::Integer, tag::Integer, comm::Comm) stat_ref = Ref{Status}(MPI.STATUS_EMPTY) # int MPI_Recv(void* buf, int count, MPI_Datatype datatype, int source, # int tag, MPI_Comm comm, MPI_Status *status) @mpichk ccall((:MPI_Recv, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{Status}), - buf, count, datatype, src, tag, comm, stat_ref) + buf.data, buf.count, buf.datatype, src, tag, comm, stat_ref) return stat_ref[] end - -Recv!(buf, count::Integer, src::Integer, tag::Integer, comm::Comm) = - Recv!(buf, count, mpitype(eltype(buf)), src, tag, comm) -Recv!(buf::AbstractArray, src::Integer, tag::Integer, comm::Comm) = - Recv!(buf, length(buf), src, tag, comm) +Recv!(buf, src::Integer, tag::Integer, comm::Comm) = + Recv!(Buffer(buf), src, tag, comm) """ Recv(::Type{T}, src::Integer, tag::Integer, comm::Comm) -Completes a blocking receive of a buffer of type `T` from MPI rank `src` of communicator +Completes a blocking receive of an object of type `T` from MPI rank `src` of communicator `comm` using with the message tag `tag`. -Returns the buffer of type `T` and the [`Status`](@ref) of the receive. +Returns a tuple of the object of type `T` and the [`Status`](@ref) of the receive. 
+ +# See also +- [`Recv!`](@ref) +- [`recv`](@ref) + +# External links +$(_doc_external("MPI_Recv")) """ function Recv(::Type{T}, src::Integer, tag::Integer, comm::Comm) where T buf = Ref{T}() - stat = Recv!(buf, 1, src, tag, comm) + stat = Recv!(buf, src, tag, comm) (buf[], stat) end @@ -340,35 +318,32 @@ function recv(src::Integer, tag::Integer, comm::Comm) end """ - Irecv!(buf, [count::Integer, [datatype::Datatype,]] - src::Integer, tag::Integer, comm::Comm) where T + Irecv!(data, src::Integer, tag::Integer, comm::Comm) -Starts a nonblocking receive of up to `count` elements of type `datatype` into `buf` -from MPI rank `src` of communicator `comm` using with the message tag `tag` +Starts a nonblocking receive into the buffer `data` from MPI rank `src` of communicator +`comm` using with the message tag `tag`. + +`data` can be a [`Buffer`](@ref), or any object for which `Buffer(data)` is defined. Returns the [`Request`](@ref) for the nonblocking receive. # External links $(_doc_external("MPI_Irecv")) """ -function Irecv!(buf, count::Integer, datatype::Union{Datatype, MPI_Datatype}, - src::Integer, tag::Integer, comm::Comm) where T - req = Request() +function Irecv!(buf::Buffer, src::Integer, tag::Integer, comm::Comm) + req = Request() # int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, int source, # int tag, MPI_Comm comm, MPI_Request *request) @mpichk ccall((:MPI_Irecv, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{MPI_Request}), - buf, count, datatype, src, tag, comm, req) + buf.data, buf.count, buf.datatype, src, tag, comm, req) req.buffer = buf refcount_inc() finalizer(free, req) return req end - -Irecv!(buf, count::Integer, src::Integer, tag::Integer, comm::Comm) = - Irecv!(buf, count, mpitype(eltype(buf)), src, tag, comm) -Irecv!(buf::AbstractArray, src::Integer, tag::Integer, comm::Comm) = - Irecv!(buf, length(buf), src, tag, comm) +Irecv!(data, src::Integer, tag::Integer, comm::Comm) = + Irecv!(Buffer(data), src, tag, comm) 
function irecv(src::Integer, tag::Integer, comm::Comm) @@ -383,11 +358,9 @@ function irecv(src::Integer, tag::Integer, comm::Comm) end """ - Sendrecv!(sendbuf, [sendcount::Integer, [sendtype::Union{Datatype, MPI_Datatype}]], - dest::Integer, sendtag::Integer, - recvbuf, [recvcount::Integer, [recvtype::Union{Datatype, MPI_Datatype}]], - source::Integer, recvtag::Integer, - comm::Comm) + Sendrecv!(sendbuf, dest::Integer, sendtag::Integer, + recvbuf, source::Integer, recvtag::Integer, + comm::Comm) Complete a blocking send-receive operation over the MPI communicator `comm`. Send `sendcount` elements of type `sendtype` from `sendbuf` to the MPI rank `dest` using message @@ -400,9 +373,9 @@ element type and length of `sendbuf`/`recvbuf`, respectively. # External links $(_doc_external("MPI_Sendrecv")) """ -function Sendrecv!(sendbuf, sendcount::Integer, sendtype::Union{Datatype, MPI_Datatype}, dest::Integer, sendtag::Integer, - recvbuf, recvcount::Integer, recvtype::Union{Datatype, MPI_Datatype}, source::Integer, recvtag::Integer, - comm::Comm) +function Sendrecv!(sendbuf::Buffer, dest::Integer, sendtag::Integer, + recvbuf::Buffer, source::Integer, recvtag::Integer, + comm::Comm) # int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, int dest, int sendtag, # void *recvbuf, int recvcount, MPI_Datatype recvtype, int source, int recvtag, # MPI_Comm comm, MPI_Status *status) @@ -410,24 +383,15 @@ function Sendrecv!(sendbuf, sendcount::Integer, sendtype::Union{Datatype, MPI_Da @mpichk ccall((:MPI_Sendrecv, libmpi), Cint, (MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPIPtr, Cint, MPI_Datatype, Cint, Cint, - MPI_Comm, Ptr{Status}), - sendbuf, sendcount, sendtype, dest, sendtag, - recvbuf, recvcount, recvtype, source, recvtag, comm, stat_ref) + MPI_Comm, Ptr{Status}), + sendbuf.data, sendbuf.count, sendbuf.datatype, dest, sendtag, + recvbuf.data, recvbuf.count, recvbuf.datatype, source, recvtag, + comm, stat_ref) return stat_ref[] end +Sendrecv!(sendbuf, 
dest::Integer, sendtag::Integer, recvbuf, source::Integer, recvtag::Integer, comm::Comm) = + Sendrecv!(Buffer(sendbuf), dest, sendtag, Buffer(recvbuf), source, recvtag, comm) -function Sendrecv!(sendbuf, sendcount::Integer, dest::Integer, sendtag::Integer, - recvbuf, recvcount::Integer, source::Integer, recvtag::Integer, - comm::Comm) - return Sendrecv!(sendbuf, sendcount, mpitype(eltype(sendbuf)), dest, sendtag, - recvbuf, recvcount, mpitype(eltype(recvbuf)), source, recvtag, comm) -end -function Sendrecv!(sendbuf::AbstractArray, dest::Integer, sendtag::Integer, - recvbuf::AbstractArray, source::Integer, recvtag::Integer, - comm::Comm) - return Sendrecv!(sendbuf, length(sendbuf), dest, sendtag, - recvbuf, length(recvbuf), source, recvtag, comm) -end """ status = Wait!(req::Request) diff --git a/test/test_datatype.jl b/test/test_datatype.jl index d7aa8478e..f1ee4a1cb 100644 --- a/test/test_datatype.jl +++ b/test/test_datatype.jl @@ -3,113 +3,146 @@ using MPI MPI.Init() -#MPI.mpitype_dict[Boundary] = MPI.mpitype_dict[Int] comm_size = MPI.Comm_size(MPI.COMM_WORLD) -comm_rank = MPI.Comm_rank(MPI.COMM_WORLD) + 1 +comm_rank = MPI.Comm_rank(MPI.COMM_WORLD) # send to next higher process, with wraparound -dest = (comm_rank % comm_size) + 1 -if comm_rank > 1 - src = comm_rank - 1 -else - src = comm_size -end - +dest = mod(comm_rank+1, comm_size) +src = mod(comm_rank-1, comm_size) # test simple type - mutable struct NotABits - a::Any + a::Any end -@test_throws ArgumentError MPI.mpitype(NotABits) +@testset "Non bitstype" begin + @test_throws ArgumentError MPI.Datatype(NotABits) +end struct Boundary - c::UInt16 # force some padding to be inserted - a::Int - b::UInt8 + c::UInt16 # force some padding to be inserted + a::Int + b::UInt8 end - -MPI.mpitype(Boundary) - -arr = [Boundary( (comm_rank + i) % 127, i + comm_rank, i % 64) for i = 1:3] -req_send = MPI.Isend(arr, dest - 1, 1, MPI.COMM_WORLD) - -# receive the message -arr_recv = Array{Boundary}(undef, 3) -req_recv = 
MPI.Irecv!(arr_recv, src - 1, 1, MPI.COMM_WORLD) - -MPI.Wait!(req_send) -MPI.Wait!(req_recv) - -# check received array -for i=1:3 - bndry_i = arr_recv[i] - @test bndry_i.a == (src + i) - @test bndry_i.b == i % 64 - @test bndry_i.c == (src + i) % 127 +@testset "Compound type" begin + sz = sizeof(Boundary) + al = Base.datatype_alignment(Boundary) + @test MPI.Types.extent(MPI.Datatype(Boundary)) == (0, cld(sz,al)*al) + + arr = [Boundary( (comm_rank + i) % 127, i + comm_rank, i % 64) for i = 1:3] + req_send = MPI.Isend(arr, dest, 1, MPI.COMM_WORLD) + + # receive the message + arr_recv = Array{Boundary}(undef, 3) + req_recv = MPI.Irecv!(arr_recv, src, 1, MPI.COMM_WORLD) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + # check received array + for i=1:3 + bndry_i = arr_recv[i] + @test bndry_i.a == (src + i) + @test bndry_i.b == i % 64 + @test bndry_i.c == (src + i) % 127 + end end - -# test nested types struct Boundary2 - a::UInt32 - b::Tuple{Int, UInt8} + a::UInt32 + b::Tuple{Int, UInt8} + c::Nothing +end +@testset "nested types" begin + sz = sizeof(Boundary2) + al = Base.datatype_alignment(Boundary2) + @test MPI.Types.extent(MPI.Datatype(Boundary2)) == (0, cld(sz,al)*al) + + arr = [Boundary2( (comm_rank + i) % 127, ( Int(i + comm_rank), UInt8(i % 64)), nothing) for i = 1:3] + arr_recv = Array{Boundary2}(undef,3) + + req_send = MPI.Isend(arr, dest, 1, MPI.COMM_WORLD) + req_recv = MPI.Irecv!(arr_recv, src, 1, MPI.COMM_WORLD) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + # check received array + for i=1:3 + bndry_i = arr_recv[i] + @test bndry_i.a == (src + i) % 127 + @test bndry_i.b[1] == (src + i) + @test bndry_i.b[2] == (i % 64) + @test bndry_i.c === nothing + end end -MPI.mpitype(Boundary2) - -arr = Array{Boundary2}(undef,3) -arr_recv = Array{Boundary2}(undef,3) - -for i=1:3 - arr[i] = Boundary2( (comm_rank + i) % 127, ( Int(i + comm_rank), UInt8(i % 64) ) ) +primitive type Primitive16 16 end +primitive type Primitive24 24 end +primitive type Primitive80 80 end + 
+@testset for PrimitiveType in (Primitive16, Primitive24, Primitive80) + sz = sizeof(PrimitiveType) + al = Base.datatype_alignment(PrimitiveType) + @test MPI.Types.extent(MPI.Datatype(PrimitiveType)) == (0, cld(sz,al)*al) + + if VERSION < v"1.3" && PrimitiveType == Primitive80 + # alignment is broken on earlier Julia versions + continue + end + + arr = [Core.Intrinsics.trunc_int(PrimitiveType, UInt128(comm_rank + i)) for i = 1:4] + arr_recv = Array{PrimitiveType}(undef,4) + + recv_req = MPI.Irecv!(arr_recv, src, 2, MPI.COMM_WORLD) + send_req = MPI.Isend(arr, dest, 2, MPI.COMM_WORLD) + + MPI.Wait!(recv_req) + MPI.Wait!(send_req) + + @test arr_recv == [Core.Intrinsics.trunc_int(PrimitiveType, UInt128(src + i)) for i = 1:4] end -req_send = MPI.Isend(arr, dest - 1, 1, MPI.COMM_WORLD) -req_recv = MPI.Irecv!(arr_recv, src - 1, 1, MPI.COMM_WORLD) +@testset "packed non-aligned tuples" begin + T = NTuple{3,UInt8} -MPI.Wait!(req_send) -MPI.Wait!(req_recv) + sz = sizeof(T) + al = Base.datatype_alignment(T) + @test MPI.Types.extent(MPI.Datatype(T)) == (0, cld(sz,al)*al) -# check received array -for i=1:3 - bndry_i = arr_recv[i] - @test bndry_i.a == (src + i) % 127 - @test bndry_i.b[1] == (src + i) - @test bndry_i.b[2] == (i % 64) -end + arr = [(UInt8(comm_rank),UInt8(i),UInt8(0)) for i = 1:8] + arr_recv = Array{T}(undef,8) + req_send = MPI.Isend(arr, dest, 1, MPI.COMM_WORLD) + req_recv = MPI.Irecv!(arr_recv, src, 1, MPI.COMM_WORLD) -# test a primitive type -primitive type Primitive16 16 end -primitive type Primitive24 24 end + MPI.Wait!(req_send) + MPI.Wait!(req_recv) -nfields, blocklengths, displacements, types = MPI.factorPrimitiveType(Primitive16) -@test nfields == 1 -@test displacements[1] == 0 -@test types[1] == MPI.mpitype(Int16) -@test blocklengths[1] == 1 + # check received array + @test arr_recv == [(UInt8(src),UInt8(i),UInt8(0)) for i = 1:8] +end -nfields, blocklengths, displacements, types = MPI.factorPrimitiveType(Primitive24) -@test nfields == 2 -@test 
displacements[1] == 0 -@test displacements[2] == 2 -@test types[1] == MPI.mpitype(Int16) -@test types[2] == MPI.mpitype(Int8) -@test blocklengths[1] == 1 -@test blocklengths[2] == 1 +@testset "0-sized type" begin + sz = sizeof(Nothing) + al = Base.datatype_alignment(Nothing) + # OpenMPI gives incorrect values + # see https://github.com/open-mpi/ompi/issues/7266 + # @test MPI.Types.extent(MPI.Datatype(Nothing)) == (0, cld(sz,al)*al) -obj = [Ptr{Int}(comm_rank)] -obj_recv = Array{Ptr{Int}}(undef, 1) -recv_req = MPI.Irecv!(obj_recv, src - 1, 2, MPI.COMM_WORLD) -send_req = MPI.Isend(obj, dest - 1, 2, MPI.COMM_WORLD) + arr = [nothing for i = 1:100] + arr_recv = Array{Nothing}(undef,100) -MPI.Wait!(recv_req) -MPI.Wait!(send_req) + req_send = MPI.Isend(arr, dest, 1, MPI.COMM_WORLD) + req_recv = MPI.Irecv!(arr_recv, src, 1, MPI.COMM_WORLD) -@test obj_recv[1] == Ptr{Int}(src) + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + # check received array + @test arr_recv == [nothing for i = 1:100] +end MPI.Barrier(MPI.COMM_WORLD) diff --git a/test/test_sendrecv.jl b/test/test_sendrecv.jl index 564a126d9..e09da0b43 100644 --- a/test/test_sendrecv.jl +++ b/test/test_sendrecv.jl @@ -102,12 +102,6 @@ comm_rank = MPI.Comm_rank(comm) comm_size = MPI.Comm_size(comm) a = Float64[comm_rank, comm_rank, comm_rank] -# construct subarray type -subarr_send = MPI.Type_Create_Subarray(1, Cint[3], Cint[1], Cint[0], MPI.MPI_ORDER_FORTRAN, Float64) -subarr_recv = MPI.Type_Create_Subarray(1, Cint[3], Cint[1], Cint[2], MPI.MPI_ORDER_FORTRAN, Float64) -MPI.Type_Commit!(subarr_send) -MPI.Type_Commit!(subarr_recv) - # construct cartesian communicator with 1D topology comm_cart = MPI.Cart_create(comm, 1, Cint[comm_size], Cint[1], false) @@ -115,8 +109,8 @@ comm_cart = MPI.Cart_create(comm, 1, Cint[comm_size], Cint[1], false) src_rank, dest_rank = MPI.Cart_shift(comm_cart, 0, -1) # execute left shift using subarrays -MPI.Sendrecv!(a, 1, subarr_send, dest_rank, 0, - a, 1, subarr_recv, src_rank, 0, comm_cart) 
+MPI.Sendrecv!(@view(a[1]), dest_rank, 0, + @view(a[3]), src_rank, 0, comm_cart) @test a == [comm_rank, comm_rank, (comm_rank+1) % comm_size] @@ -124,8 +118,8 @@ MPI.Sendrecv!(a, 1, subarr_send, dest_rank, 0, # --------------------------- a = Float64[comm_rank, comm_rank, comm_rank] b = Float64[ -1, -1, -1] -MPI.Sendrecv!(a, 2, dest_rank, 1, - b, 2, src_rank, 1, comm_cart) +MPI.Sendrecv!(@view(a[1:2]), dest_rank, 1, + @view(b[1:2]), src_rank, 1, comm_cart) @test b == [(comm_rank+1) % comm_size, (comm_rank+1) % comm_size, -1] diff --git a/test/test_subarray.jl b/test/test_subarray.jl index aefec5a98..bf1c57bdd 100644 --- a/test/test_subarray.jl +++ b/test/test_subarray.jl @@ -1,44 +1,88 @@ using Test using MPI +if get(ENV,"JULIA_MPI_TEST_ARRAYTYPE","") == "CuArray" + using CuArrays + ArrayType = CuArray +else + ArrayType = Array +end + MPI.Init() comm = MPI.COMM_WORLD -size = MPI.Comm_size(comm) +comm_size = MPI.Comm_size(comm) rank = MPI.Comm_rank(comm) -# assuming there are at least two processes -x = rank == 0 ? 
collect(reshape(1.0:16.0, 4, 4)) : zeros(4, 4) - -subarray = MPI.Type_Create_Subarray(2, - Cint[4, 4], - Cint[2, 2], - Cint[0, 0], - MPI.MPI_ORDER_FORTRAN, - Float64) -MPI.Type_Commit!(subarray) - -# test blocking send -if rank == 0 - MPI.Send(x, 1, subarray, 1, 0, comm) -elseif rank == 1 - MPI.Recv!(x, 1, subarray, 0, 0, comm) - @test x == [1 5 0 0; - 2 6 0 0; - 0 0 0 0; - 0 0 0 0] +dest = mod(rank+1, comm_size) +src = mod(rank-1, comm_size) + +@testset "contiguous" begin + X = ArrayType(rank .+ collect(reshape(1.0:16.0, 4, 4))) + Y = ArrayType(zeros(4)) + req_send = MPI.Isend(@view(X[:,1]), dest, 0, comm) + req_recv = MPI.Irecv!(Y, src, 0, comm) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + @test Y == X[:,1] .- rank .+ src + + Y = ArrayType(zeros(2)) + + req_send = MPI.Isend(Y, dest, 1, comm) + req_recv = MPI.Irecv!(@view(X[3:4,1]), src, 1, comm) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + @test X[3:4,1] == Y +end + +@testset "strided" begin + X = ArrayType(rank .+ collect(reshape(1.0:16.0, 4, 4))) + Y = ArrayType(zeros(4)) + req_send = MPI.Isend(@view(X[2,:]), dest, 0, comm) + req_recv = MPI.Irecv!(Y, src, 0, comm) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + @test Y == X[2,:] .- rank .+ src + + Y = ArrayType(zeros(2)) + + req_send = MPI.Isend(Y, dest, 1, comm) + req_recv = MPI.Irecv!(@view(X[3,1:2]), src, 1, comm) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + @test X[3,1:2] == Y end -# test non blocking send -if rank == 0 - MPI.Isend(x, 1, subarray, 1, 0, comm) -elseif rank == 1 - req = MPI.Irecv!(x, 1, subarray, 0, 0, comm) - MPI.Wait!(req) - @test x == [1 5 0 0; - 2 6 0 0; - 0 0 0 0; - 0 0 0 0] +@testset "dense subarray" begin + X = ArrayType(rank .+ collect(reshape(1.0:16.0, 4, 4))) + Y = ArrayType(zeros(2,2)) + req_send = MPI.Isend(@view(X[2:3,3:4]), dest, 0, comm) + req_recv = MPI.Irecv!(Y, src, 0, comm) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + @test Y == X[2:3,3:4] .- rank .+ src + + Y = ArrayType(zeros(2,2)) + + req_send = 
MPI.Isend(Y, dest, 1, comm) + req_recv = MPI.Irecv!(@view(X[3:4,1:2]), src, 1, comm) + + MPI.Wait!(req_send) + MPI.Wait!(req_recv) + + @test X[3:4,1:2] == Y end +GC.gc() MPI.Finalize() +@test MPI.Finalized()