diff --git a/src/num.jl b/src/num.jl index 5d0a3bac0..c3d3101e8 100644 --- a/src/num.jl +++ b/src/num.jl @@ -78,9 +78,18 @@ end substitute(expr, s::Pair; kw...) = substituter([s[1] => s[2]])(expr; kw...) substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...) -substituter(pair::Pair) = substituter((pair,)) +function _unwrap_callwithmeta(x) + x = value(x) + return x isa CallWithMetadata ? x.f : x +end +function subrules_to_dict(pairs) + if pairs isa Pair + pairs = (pairs,) + end + return Dict(_unwrap_callwithmeta(k) => value(v) for (k, v) in pairs) +end function substituter(pairs) - dict = Dict(value(k) => value(v) for (k, v) in pairs) + dict = subrules_to_dict(pairs) (expr; kw...) -> SymbolicUtils.substitute(value(expr), dict; kw...) end diff --git a/src/variable.jl b/src/variable.jl index 3dfae335e..15db8d35f 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -526,6 +526,7 @@ infinite loops in cases where the substitutions in `dict` are circular See also: [`fast_substitute`](@ref). """ function fixpoint_sub(x, dict; operator = Nothing, maxiters = 10000) + dict = subrules_to_dict(dict) y = fast_substitute(x, dict; operator) while !isequal(x, y) && maxiters > 0 y = x diff --git a/test/utils.jl b/test/utils.jl index 977bed6b9..727c53366 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,5 +1,5 @@ using Symbolics -using Symbolics: symbolic_to_float, var_from_nested_derivative +using Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap @testset "get_variables" begin @variables t x y z(t) @@ -46,3 +46,12 @@ end expr = Symbolics.fixpoint_sub(x, Dict(x => y, y => x); maxiters = 9) @test isequal(expr, y) end + +@testset "Issue#1342 substitute working on called symbolics" begin + @variables p(..) x y + arg = unwrap(substitute(p(x), [p => identity])) + @test iscall(arg) && operation(arg) == identity && isequal(only(arguments(arg)), x) + @test unwrap(substitute(p(x), [p => sqrt, x => 4.0])) ≈ 2.0 + arg = Symbolics.fixpoint_sub(p(x), [p => sqrt, x => 2y + 3, y => 1.0 + p(4)]) + @test arg ≈ 3.0 +end