diff --git a/docs/src/reference/library.md b/docs/src/reference/library.md index 94e3942e0..5ce9b5299 100644 --- a/docs/src/reference/library.md +++ b/docs/src/reference/library.md @@ -14,5 +14,6 @@ MPI.MPI_LIBRARY_VERSION_STRING ```@docs MPI.versioninfo MPI.has_cuda +MPI.has_rocm MPI.identify_implementation ``` diff --git a/src/environment.jl b/src/environment.jl index 52597c430..e472d57be 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -320,7 +320,7 @@ Wtime() = API.MPI_Wtime() Check if the MPI implementation is known to have CUDA support. Currently only Open MPI provides a mechanism to check, so it will return `false` with other implementations -(unless overriden). +(unless overriden). For "IBMSpectrumMPI" it will return `true`. This can be overriden by setting the `JULIA_MPI_HAS_CUDA` environment variable to `true` or `false`. @@ -334,7 +334,7 @@ function has_cuda() # Only Open MPI provides a function to check CUDA support @static if MPI_LIBRARY == "OpenMPI" # int MPIX_Query_cuda_support(void) - return 0 != ccall((:MPIX_Query_cuda_support, libmpi), Cint, ()) + return @ccall libmpi.MPIX_Query_cuda_support()::Bool elseif MPI_LIBRARY == "IBMSpectrumMPI" return true else @@ -344,3 +344,28 @@ function has_cuda() return parse(Bool, flag) end end + +""" + MPI.has_rocm() + +Check if the MPI implementation is known to have ROCm support. Currently only Open MPI +provides a mechanism to check, so it will return `false` with other implementations +(unless overriden). + +This can be overriden by setting the `JULIA_MPI_HAS_ROCM` environment variable to `true` +or `false`. +""" +function has_rocm() + flag = get(ENV, "JULIA_MPI_HAS_ROCM", nothing) + if flag === nothing + # Only Open MPI provides a function to check ROCm support + @static if MPI_LIBRARY == "OpenMPI" && MPI_LIBRARY_VERSION ≥ v"5" + # int MPIX_Query_rocm_support(void) + return @ccall libmpi.MPIX_Query_rocm_support()::Bool + else + return false + end + else + return parse(Bool, flag) + end +end diff --git a/test/test_basic.jl b/test/test_basic.jl index 495c9f83b..57b6a65e5 100644 --- a/test/test_basic.jl +++ b/test/test_basic.jl @@ -12,6 +12,12 @@ if get(ENV,"JULIA_MPI_TEST_ARRAYTYPE","") == "CuArray" @test MPI.has_cuda() end +@test MPI.has_rocm() isa Bool + +if get(ENV,"JULIA_MPI_TEST_ARRAYTYPE","") == "ROCArray" + @test MPI.has_rocm() +end + @test !MPI.Finalized() MPI.Finalize() @test MPI.Finalized()