Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add macro to create custom Ops also on aarch64 #871

Merged
merged 23 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ MPI.Types.duplicate

```@docs
MPI.Op
MPI.@Op
```

## Info objects
Expand Down
53 changes: 53 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,56 @@ function Op(f, T=Any; iscommutative=false)
finalizer(free, op)
return op
end

"""
@Op(f, T)
vchuravy marked this conversation as resolved.
Show resolved Hide resolved

Define a custom operator [`Op`](@ref) using the function `f`.
On platfroms like AArch53, Julia does not support runtime closures,
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
being passed to C. The generic version of [`Op`](@ref) uses that
to support arbitrary function being passed as MPI reduction operators.
In contrast `@Op` can be used to statically declare a function to
be passed as an MPI operator.

```julia
function my_reduce(x, y)
2x+y-x
end
MPI.@Op(my_reduce, Int)
# ...
MPI.Reduce!(send_arr, recv_arr, my_reduce, MPI.COMM_WORLD; root=root)
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
#...

vchuravy marked this conversation as resolved.
Show resolved Hide resolved
!!! warning
Note that `@Op` works be introducing a new method to `Op`, potentially invalidating other users of `Op`.
```
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
"""
macro Op(f, T)
name_wrapper = gensym(Symbol(f, :_, T, :_wrapper))
name_fptr = gensym(Symbol(f, :_, T, :_ptr))
name_module = gensym(Symbol(f, :_, T, :_module))
# The gist is that we can use a method very similar to how we handle `min`/`max`
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
# but since this might be used from user code we can't use add_load_time_hook!
# this is why we introduce a new module that has a `__init__` function.
expr = quote
module $(name_module)
# import ..$f, ..$T
$(Expr(:import, Expr(:., :., :., f), Expr(:., :., :., T))) # Julia 1.6 strugles with import ..$f, ..$T
const $(name_wrapper) = $OpWrapper{typeof($f),$T}($f)
const $(name_fptr) = Ref(@cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype})))
function __init__()
$(name_fptr)[] = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
end
import MPI: Op
function Op(::typeof($f), ::Type{$T}; iscommutative=true)
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
op = Op($OP_NULL.val, $(name_fptr)[])
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
$API.MPI_Op_create($(name_fptr)[], iscommutative, op)

finalizer($free, op)
end
end
end
expr.head = :toplevel
esc(expr)
end
9 changes: 7 additions & 2 deletions test/test_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ if isroot
@test sum_mesg == sz .* mesg
end

function my_reduce(x, y)
2x+y-x
end
MPI.@Op(my_reduce, Int)

if can_do_closures
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x]
else
operators = [MPI.SUM, +]
operators = [MPI.SUM, +, my_reduce]
end

for T = [Int]
Expand Down
Loading