Skip to content

Commit

Permalink
Always ensure consistency of new MPI datatypes (#877)
Browse files Browse the repository at this point in the history
* Always ensure consistency of new MPI datatypes

* Create a datatype with size a multiple of its alignment

This should help ensure memory allocations of the MPI datatype have the same
alignment as the Julia counterpart.
  • Loading branch information
giordano authored Sep 13, 2024
1 parent 9584ac8 commit aac9688
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
17 changes: 14 additions & 3 deletions src/datatypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ end

function Datatype(::Type{T}) where {T}
global created_datatypes
get!(created_datatypes, T) do
datatype = get!(created_datatypes, T) do
datatype = Datatype()
# lazily initialize so that it can be safely precompiled
function init()
Expand All @@ -162,6 +162,15 @@ function Datatype(::Type{T}) where {T}
init()
datatype
end

# Make sure the "aligned" size of the type matches the MPI "extent".
sz = sizeof(T)
al = Base.datatype_alignment(T)
mpi_extent = Types.extent(datatype)
aligned_size = (0, cld(sz,al)*al)
@assert mpi_extent == aligned_size "The MPI extent of type $(T) ($(mpi_extent[2])) does not match the size expected by Julia ($(aligned_size[2]))"

return datatype
end

function Base.show(io::IO, datatype::Datatype)
Expand Down Expand Up @@ -437,8 +446,10 @@ function create!(newtype::Datatype, ::Type{T}) where {T}
types = Datatype[]

if isprimitivetype(T)
# primitive type
szrem = sz = sizeof(T)
# This is a primitive type. Create a type which has size an integer multiple of its
# alignment on the Julia side: <https://github.com/JuliaParallel/MPI.jl/issues/853>.
al = Base.datatype_alignment(T)
szrem = sz = cld(sizeof(T), al) * al
disp = 0
for (i,basetype) in (8 => Datatype(UInt64), 4 => Datatype(UInt32), 2 => Datatype(UInt16), 1 => Datatype(UInt8))
if sz == i
Expand Down
13 changes: 6 additions & 7 deletions test/test_datatype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,25 @@ end
primitive type Primitive16 16 end
primitive type Primitive24 24 end
primitive type Primitive80 80 end
primitive type Primitive104 104 end
primitive type Primitive136 136 end

@testset for PrimitiveType in (Primitive16, Primitive24, Primitive80)
@testset for PrimitiveType in (Primitive16, Primitive24, Primitive80, Primitive104, Primitive136)
sz = sizeof(PrimitiveType)
al = Base.datatype_alignment(PrimitiveType)
@test MPI.Types.extent(MPI.Datatype(PrimitiveType)) == (0, cld(sz,al)*al)

if VERSION < v"1.3" && PrimitiveType == Primitive80
# alignment is broken on earlier Julia versions
continue
end
conv = sizeof(PrimitiveType) <= sizeof(UInt128) ? Core.Intrinsics.trunc_int : Core.Intrinsics.sext_int

arr = [Core.Intrinsics.trunc_int(PrimitiveType, UInt128(comm_rank + i)) for i = 1:4]
arr = [conv(PrimitiveType, UInt128(comm_rank + i)) for i = 1:4]
arr_recv = Array{PrimitiveType}(undef,4)

recv_req = MPI.Irecv!(arr_recv, src, 2, MPI.COMM_WORLD)
send_req = MPI.Isend(arr, dest, 2, MPI.COMM_WORLD)

MPI.Waitall([recv_req, send_req])

@test arr_recv == [Core.Intrinsics.trunc_int(PrimitiveType, UInt128(src + i)) for i = 1:4]
@test arr_recv == [conv(PrimitiveType, UInt128(src + i)) for i = 1:4]
end

@testset "packed non-aligned tuples" begin
Expand Down

0 comments on commit aac9688

Please sign in to comment.