From 0ea30bf4b89099be98c7a24638391b131e96775a Mon Sep 17 00:00:00 2001 From: Karl Wessel Date: Fri, 8 Nov 2024 01:14:54 +0100 Subject: [PATCH] add flag for activating robust calculation of expand_derivatives --- src/diff.jl | 38 +++++++++++++++++++------------------- test/diff.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index 891e2cde2..f9da8c071 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -180,12 +180,12 @@ julia> dfx=expand_derivatives(Dx(f)) (k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y ``` """ -function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) +function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing) if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) - arg = expand_derivatives(arg, false) + arg = expand_derivatives(arg, false; robust) - if occurrences == nothing + if robust || occurrences == nothing occurrences = occursin_info(operation(O).x, arg) end @@ -202,14 +202,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) return D(arg) # base case if any argument is directly equal to the i.v. else return sum(inner_args, init=0) do a - return expand_derivatives(Differential(a)(arg)) * - expand_derivatives(D(a)) + return expand_derivatives(Differential(a)(arg); robust) * + expand_derivatives(D(a); robust) end end elseif op === (IfElse.ifelse) args = arguments(arg) O = op(args[1], D(args[2]), D(args[3])) - return expand_derivatives(O, simplify; occurrences) + return expand_derivatives(O, simplify; robust, occurrences) elseif isa(op, Differential) # The recursive expand_derivatives was not able to remove # a nested Differential. We can attempt to differentiate the @@ -218,12 +218,12 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) if isequal(op.x, D.x) return D(arg) else - inner = expand_derivatives(D(arguments(arg)[1]), false) + inner = expand_derivatives(D(arguments(arg)[1]), false; robust) # if the inner expression is not expandable either, return if iscall(inner) && operation(inner) isa Differential return D(arg) else - return expand_derivatives(op(inner), simplify) + return expand_derivatives(op(inner), simplify; robust) end end elseif isa(op, Integral) @@ -231,7 +231,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) domain = op.domain.domain a, b = DomainSets.endpoints(domain) c = 0 - inner_function = expand_derivatives(arguments(arg)[1]) + inner_function = expand_derivatives(arguments(arg)[1]; robust) if iscall(value(a)) t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) t2 = D(a) @@ -242,7 +242,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) t2 = D(b) c += t1*t2 end - inner = expand_derivatives(D(arguments(arg)[1])) + inner = expand_derivatives(D(arguments(arg)[1]); robust) c += op(inner) return value(c) end @@ -254,7 +254,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) c = 0 for i in 1:l - t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i]) + t2 = expand_derivatives(D(inner_args[i]),false; robust, occurrences=arguments(occurrences)[i]) x = if _iszero(t2) t2 @@ -286,23 +286,23 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) return simplify ? SymbolicUtils.simplify(x) : x end elseif iscall(O) && isa(operation(O), Integral) - return operation(O)(expand_derivatives(arguments(O)[1])) + return operation(O)(expand_derivatives(arguments(O)[1]; robust)) elseif !hasderiv(O) return O else - args = map(a->expand_derivatives(a, false), arguments(O)) + args = map(a->expand_derivatives(a, false; robust), arguments(O)) O1 = operation(O)(args...) return simplify ? SymbolicUtils.simplify(O1) : O1 end end -function expand_derivatives(n::Num, simplify=false; occurrences=nothing) - wrap(expand_derivatives(value(n), simplify; occurrences=occurrences)) +function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing) + wrap(expand_derivatives(value(n), simplify; robust, occurrences)) end -function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing) - wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences), - expand_derivatives(imag(n), simplify; occurrences=occurrences))) +function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing) + wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences), + expand_derivatives(imag(n), simplify; robust, occurrences))) end -expand_derivatives(x, simplify=false; occurrences=nothing) = x +expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x _iszero(x) = false _isone(x) = false diff --git a/test/diff.jl b/test/diff.jl index d40fa1185..76833b8bc 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -349,6 +349,36 @@ let @test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im) end +# 1262 +# +let + @variables t b(t) + D = Differential(t) + expr = b - ((D(b))^2) * D(D(b)) + expr2 = D(expr) + @test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + @test_throws BoundsError expand_derivatives(expr2) + @test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2)) +end + +# 1126 +# +let + @syms y f(y) g(y) h(y) + D = Differential(y) + + expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y)))) + + expr = expr_gen(g(y)) + @test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + expr = expr_gen(h(y)) + @test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true)) + + expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y)) + expr = expr_gen(f(y)) + @test_throws BoundsError expand_derivatives(expr) + @test isequal(expand(expand_derivatives(expr; robust=true)), expected) +end # Check `is_derivative` function let