Skip to content

Commit

Permalink
Add support for determining the interface from specified inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 5, 2024
1 parent ce2cb30 commit abf9f57
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
16 changes: 8 additions & 8 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ using SparseArraysBaseNext: SparseArrayInterface

## @derive SparseArrayInterface (T=SparseArrayDOK) Base.getindex(::T, ::Int...)

@derive SparseArrayInterface begin
Base.getindex(::SparseArrayDOK, ::Int...)
Base.setindex!(::SparseArrayDOK, ::Any, ::Int...)
end

## @derive (T=SparseArrayDOK,) begin
## Base.getindex(::T, ::Int...)
## Base.setindex!(::T, ::Any, ::Int...)
## @derive SparseArrayInterface begin
## Base.getindex(::SparseArrayDOK, ::Int...)
## Base.setindex!(::SparseArrayDOK, ::Any, ::Int...)
## end

@derive (T=SparseArrayDOK,) begin
Base.getindex(::T, ::Int...)
Base.setindex!(::T, ::Any, ::Int...)
end

## @derive SparseArrayInterface (T=SparseArrayDOK) begin
## Base.getindex(::T, ::Int...)
## Base.setindex!(::T, ::Any, ::Int...)
Expand Down
38 changes: 20 additions & 18 deletions src/lib/Derive/derive_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end

function globalref_derive(expr)
return @match expr begin
:(Derive.$f($(x...))) => :($(GlobalRef(Derive, :($f)))($(x...)))
:(Derive.$f) => :($(GlobalRef(Derive, :($f))))
e::Expr => Expr(e.head, map(globalref_derive, e.args)...)
a => a
end
Expand Down Expand Up @@ -93,7 +93,7 @@ end

argname(i::Int) = Symbol(:arg, i)

function derive_interface_func(interface::Expr, func::Expr)
function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr)
name, args, kwargs, whereparams, rettype = split_function_head(func)
argnames = map(argname, 1:length(args))
named_args = map(1:length(args)) do i
Expand Down Expand Up @@ -122,29 +122,31 @@ function derive_interface_func(interface::Expr, func::Expr)
jlfn = JLFunction(; name, args=named_args, kwargs, whereparams, rettype, body)
# Use `globalref_derive` to not require having `Derive` in the
# namespace when `@derive` is called.
return @show globalref_derive(codegen_ast(jlfn))
return globalref_derive(codegen_ast(jlfn))
end

function derive_func(types::Expr, func::Expr)
Meta.isexpr(types, :tuple) && all(arg -> Meta.isexpr(arg, :(=)), types.args) ||
error("Wrong types format.")
name, args, kwargs, whereparams, rettype = split_function_head(func)

@show types
@show func
@show args

args = map
for type in types.args
@show type
@show type.args[1]
@show type.args[2]

args = Expr
new_args = args
for type_expr in types.args
typevar, type = @match type_expr begin
:($x = $y) => (x, y)
end
new_args = map(args) do arg
return @match arg begin
:(::$T) => T == typevar ? :(::$type) : :(::$T)
:(::$T...) => T == typevar ? :(::$type...) : :(::$T...)
end
end
end

error()
return derive_func(interface, func)
active_argnames = map(argname, findall(args .≠ new_args))
interface = globalref_derive(:(Derive.AbstractInterface(Derive.AbstractInterface.(($(active_argnames...),))...)))
_, func, _ = split_function(
codegen_ast(JLFunction(; name, args=new_args, kwargs, whereparams, rettype))
)
return derive_interface_func(interface, func)
end

function derive(::Val{:AbstractArrayOps})
Expand Down

0 comments on commit abf9f57

Please sign in to comment.