From a87bb4879a0ddd55bcdece903f6abfa2082c00df Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 4 Oct 2023 16:41:41 -0400 Subject: [PATCH] Fix binomial registration --- src/extra_functions.jl | 2 +- src/register.jl | 11 ++++++++--- test/overloads.jl | 2 ++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/extra_functions.jl b/src/extra_functions.jl index d68f1807f..5c4f6cccf 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -1,4 +1,4 @@ -@register_symbolic Base.binomial(n,k) +@register_symbolic Base.binomial(n, k)::Int true [Integer] @register_symbolic Base.sign(x)::Int derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0 diff --git a/src/register.jl b/src/register.jl index dc385101a..979bc6dfe 100644 --- a/src/register.jl +++ b/src/register.jl @@ -1,7 +1,7 @@ using SymbolicUtils: Symbolic """ - @register_symbolic(expr, define_promotion = true, Ts = [Num, Symbolic, Real]) + @register_symbolic(expr, define_promotion = true, Ts = [Real]) Overload appropriate methods so that Symbolics can stop tracing into the registered function. If `define_promotion` is true, then a promotion method in @@ -22,7 +22,7 @@ overwriting. @register_symbolic hoo(x, y)::Int # `hoo` returns `Int` ``` """ -macro register_symbolic(expr, define_promotion = true, Ts = []) +macro register_symbolic(expr, define_promotion = true, Ts = :([])) if expr.head === :(::) ret_type = expr.args[2] expr = expr.args[1] @@ -31,6 +31,8 @@ macro register_symbolic(expr, define_promotion = true, Ts = []) end @assert expr.head === :call + @assert Ts.head === :vect + Ts = Ts.args f = expr.args[1] args = expr.args[2:end] @@ -41,7 +43,10 @@ macro register_symbolic(expr, define_promotion = true, Ts = []) types = map(args) do x if x isa Symbol - :(($Real, $wrapper_type($Real), $Symbolic{<:$Real})) + if isempty(Ts) + Ts = [Real] + end + :(($(Ts...), $wrapper_type($Real), $Symbolic{<:$Real})) elseif Meta.isexpr(x, :(::)) T = x.args[2] :($has_symwrapper($T) ? diff --git a/test/overloads.jl b/test/overloads.jl index c728679c7..7218467e6 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -235,3 +235,5 @@ stringcontent = string(d.content) for f in [<, <=, >, >=, isless] @test_nowarn f(t, 1.0) end + +@test_nowarn binomial(t, 1)