From 0ad4f8173e8c24a003d46f5978d8d561bf5be93a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 16:29:20 +0530 Subject: [PATCH] feat: add iteration limit for `fixpoint_sub` --- src/variable.jl | 16 ++++++++++++---- test/utils.jl | 8 ++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index b5230f07b..4e7f23adb 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -503,21 +503,29 @@ function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symboli end """ - fixpoint_sub(expr, dict; operator = Nothing) + fixpoint_sub(expr, dict; operator = Nothing, maxiters = 10000) Given a symbolic expression, equation or inequality `expr` perform the substitutions in `dict` recursively until the expression does not change. Substitutions that depend on one another will thus be recursively expanded. For example, `fixpoint_sub(x, Dict(x => y, y => 3))` will return `3`. The `operator` keyword can be -specified to prevent substitution of expressions inside operators of the given type. +specified to prevent substitution of expressions inside operators of the given type. The +`maxiters` keyword is used to limit the number of times the substitution can occur to avoid +infinite loops in cases where the substitutions in `dict` are circular +(e.g. `[x => y, y => x]`). See also: [`fast_substitute`](@ref). """ -function fixpoint_sub(x, dict; operator = Nothing) +function fixpoint_sub(x, dict; operator = Nothing, maxiters = 10000) y = fast_substitute(x, dict; operator) - while !isequal(x, y) + while !isequal(x, y) && maxiters > 0 y = x x = fast_substitute(y, dict; operator) + maxiters -= 1 + end + + if !isequal(x, y) + @warn "Did not converge after `maxiters = $maxiters` substitutions. Either there is a cycle in the rules or `maxiters` needs to be higher." end return x diff --git a/test/utils.jl b/test/utils.jl index 616dee642..977bed6b9 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -38,3 +38,11 @@ end @test var_from_nested_derivative(p) == (p, 0) @test var_from_nested_derivative(D(p(x))) == (p(x), 1) end + +@testset "fixpoint_sub maxiters" begin + @variables x y + expr = Symbolics.fixpoint_sub(x, Dict(x => y, y => x)) + @test isequal(expr, x) + expr = Symbolics.fixpoint_sub(x, Dict(x => y, y => x); maxiters = 9) + @test isequal(expr, y) +end