Skip to content

Commit

Permalink
Merge pull request #1343 from AayushSabharwal/as/sub-called-var
Browse files Browse the repository at this point in the history
feat: support substituting `CallWithMetadata` in expressions
  • Loading branch information
ChrisRackauckas authored Nov 4, 2024
2 parents 85a06f9 + cc4ab73 commit bff7352
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
13 changes: 11 additions & 2 deletions src/num.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,18 @@ end
substitute(expr, s::Pair; kw...) = substituter([s[1] => s[2]])(expr; kw...)
substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...)

substituter(pair::Pair) = substituter((pair,))
function _unwrap_callwithmeta(x)
x = value(x)
return x isa CallWithMetadata ? x.f : x
end
function subrules_to_dict(pairs)
if pairs isa Pair
pairs = (pairs,)
end
return Dict(_unwrap_callwithmeta(k) => value(v) for (k, v) in pairs)
end
function substituter(pairs)
dict = Dict(value(k) => value(v) for (k, v) in pairs)
dict = subrules_to_dict(pairs)
(expr; kw...) -> SymbolicUtils.substitute(value(expr), dict; kw...)
end

Expand Down
1 change: 1 addition & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ infinite loops in cases where the substitutions in `dict` are circular
See also: [`fast_substitute`](@ref).
"""
function fixpoint_sub(x, dict; operator = Nothing, maxiters = 10000)
dict = subrules_to_dict(dict)
y = fast_substitute(x, dict; operator)
while !isequal(x, y) && maxiters > 0
y = x
Expand Down
11 changes: 10 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Symbolics
using Symbolics: symbolic_to_float, var_from_nested_derivative
using Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap

@testset "get_variables" begin
@variables t x y z(t)
Expand Down Expand Up @@ -46,3 +46,12 @@ end
expr = Symbolics.fixpoint_sub(x, Dict(x => y, y => x); maxiters = 9)
@test isequal(expr, y)
end

@testset "Issue#1342 substitute working on called symbolics" begin
@variables p(..) x y
arg = unwrap(substitute(p(x), [p => identity]))
@test iscall(arg) && operation(arg) == identity && isequal(only(arguments(arg)), x)
@test unwrap(substitute(p(x), [p => sqrt, x => 4.0])) 2.0
arg = Symbolics.fixpoint_sub(p(x), [p => sqrt, x => 2y + 3, y => 1.0 + p(4)])
@test arg 3.0
end

0 comments on commit bff7352

Please sign in to comment.