diff --git a/examples/README.jl b/examples/README.jl index 48273fc..b619abf 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -28,16 +28,16 @@ using SparseArraysBaseNext: SparseArrayInterface ## @derive SparseArrayInterface (T=SparseArrayDOK) Base.getindex(::T, ::Int...) -@derive SparseArrayInterface begin - Base.getindex(::SparseArrayDOK, ::Int...) - Base.setindex!(::SparseArrayDOK, ::Any, ::Int...) -end - -## @derive (T=SparseArrayDOK,) begin -## Base.getindex(::T, ::Int...) -## Base.setindex!(::T, ::Any, ::Int...) +## @derive SparseArrayInterface begin +## Base.getindex(::SparseArrayDOK, ::Int...) +## Base.setindex!(::SparseArrayDOK, ::Any, ::Int...) ## end +@derive (T=SparseArrayDOK,) begin + Base.getindex(::T, ::Int...) + Base.setindex!(::T, ::Any, ::Int...) +end + ## @derive SparseArrayInterface (T=SparseArrayDOK) begin ## Base.getindex(::T, ::Int...) ## Base.setindex!(::T, ::Any, ::Int...) diff --git a/src/lib/Derive/derive_macro.jl b/src/lib/Derive/derive_macro.jl index 781a8b4..e303041 100644 --- a/src/lib/Derive/derive_macro.jl +++ b/src/lib/Derive/derive_macro.jl @@ -15,7 +15,7 @@ end function globalref_derive(expr) return @match expr begin - :(Derive.$f($(x...))) => :($(GlobalRef(Derive, :($f)))($(x...))) + :(Derive.$f) => :($(GlobalRef(Derive, :($f)))) e::Expr => Expr(e.head, map(globalref_derive, e.args)...) a => a end @@ -93,7 +93,7 @@ end argname(i::Int) = Symbol(:arg, i) -function derive_interface_func(interface::Expr, func::Expr) +function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr) name, args, kwargs, whereparams, rettype = split_function_head(func) argnames = map(argname, 1:length(args)) named_args = map(1:length(args)) do i @@ -122,29 +122,31 @@ function derive_interface_func(interface::Expr, func::Expr) jlfn = JLFunction(; name, args=named_args, kwargs, whereparams, rettype, body) # Use `globalref_derive` to not require having `Derive` in the # namespace when `@derive` is called. - return @show globalref_derive(codegen_ast(jlfn)) + return globalref_derive(codegen_ast(jlfn)) end function derive_func(types::Expr, func::Expr) Meta.isexpr(types, :tuple) && all(arg -> Meta.isexpr(arg, :(=)), types.args) || error("Wrong types format.") name, args, kwargs, whereparams, rettype = split_function_head(func) - - @show types - @show func - @show args - - args = map - for type in types.args - @show type - @show type.args[1] - @show type.args[2] - - args = Expr + new_args = args + for type_expr in types.args + typevar, type = @match type_expr begin + :($x = $y) => (x, y) + end + new_args = map(args) do arg + return @match arg begin + :(::$T) => T == typevar ? :(::$type) : :(::$T) + :(::$T...) => T == typevar ? :(::$type...) : :(::$T...) + end + end end - - error() - return derive_func(interface, func) + active_argnames = map(argname, findall(args .≠ new_args)) + interface = globalref_derive(:(Derive.AbstractInterface(Derive.AbstractInterface.(($(active_argnames...),))...))) + _, func, _ = split_function( + codegen_ast(JLFunction(; name, args=new_args, kwargs, whereparams, rettype)) + ) + return derive_interface_func(interface, func) end function derive(::Val{:AbstractArrayOps})