Skip to content

Commit

Permalink
add flag for activating robust calculation of expand_derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
Karl Wessel committed Nov 8, 2024
1 parent 8c518c2 commit 0ea30bf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
38 changes: 19 additions & 19 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ julia> dfx=expand_derivatives(Dx(f))
(k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y
```
"""
function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing)
if iscall(O) && isa(operation(O), Differential)
arg = only(arguments(O))
arg = expand_derivatives(arg, false)
arg = expand_derivatives(arg, false; robust)

if occurrences == nothing
if robust || occurrences == nothing
occurrences = occursin_info(operation(O).x, arg)
end

Expand All @@ -202,14 +202,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
return D(arg) # base case if any argument is directly equal to the i.v.
else
return sum(inner_args, init=0) do a
return expand_derivatives(Differential(a)(arg)) *
expand_derivatives(D(a))
return expand_derivatives(Differential(a)(arg); robust) *
expand_derivatives(D(a); robust)
end
end
elseif op === (IfElse.ifelse)
args = arguments(arg)
O = op(args[1], D(args[2]), D(args[3]))
return expand_derivatives(O, simplify; occurrences)
return expand_derivatives(O, simplify; robust, occurrences)
elseif isa(op, Differential)
# The recursive expand_derivatives was not able to remove
# a nested Differential. We can attempt to differentiate the
Expand All @@ -218,20 +218,20 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
if isequal(op.x, D.x)
return D(arg)
else
inner = expand_derivatives(D(arguments(arg)[1]), false)
inner = expand_derivatives(D(arguments(arg)[1]), false; robust)
# if the inner expression is not expandable either, return
if iscall(inner) && operation(inner) isa Differential
return D(arg)
else
return expand_derivatives(op(inner), simplify)
return expand_derivatives(op(inner), simplify; robust)
end
end
elseif isa(op, Integral)
if isa(op.domain.domain, AbstractInterval)
domain = op.domain.domain
a, b = DomainSets.endpoints(domain)
c = 0
inner_function = expand_derivatives(arguments(arg)[1])
inner_function = expand_derivatives(arguments(arg)[1]; robust)
if iscall(value(a))
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
t2 = D(a)
Expand All @@ -242,7 +242,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
t2 = D(b)
c += t1*t2
end
inner = expand_derivatives(D(arguments(arg)[1]))
inner = expand_derivatives(D(arguments(arg)[1]); robust)
c += op(inner)
return value(c)
end
Expand All @@ -254,7 +254,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
c = 0

for i in 1:l
t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i])
t2 = expand_derivatives(D(inner_args[i]),false; robust, occurrences=arguments(occurrences)[i])

x = if _iszero(t2)
t2
Expand Down Expand Up @@ -286,23 +286,23 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
return simplify ? SymbolicUtils.simplify(x) : x
end
elseif iscall(O) && isa(operation(O), Integral)
return operation(O)(expand_derivatives(arguments(O)[1]))
return operation(O)(expand_derivatives(arguments(O)[1]; robust))
elseif !hasderiv(O)
return O
else
args = map(a->expand_derivatives(a, false), arguments(O))
args = map(a->expand_derivatives(a, false; robust), arguments(O))
O1 = operation(O)(args...)
return simplify ? SymbolicUtils.simplify(O1) : O1
end
end
function expand_derivatives(n::Num, simplify=false; occurrences=nothing)
wrap(expand_derivatives(value(n), simplify; occurrences=occurrences))
function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing)
wrap(expand_derivatives(value(n), simplify; robust, occurrences))
end
function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences),
expand_derivatives(imag(n), simplify; occurrences=occurrences)))
function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences),
expand_derivatives(imag(n), simplify; robust, occurrences)))
end
expand_derivatives(x, simplify=false; occurrences=nothing) = x
expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x

_iszero(x) = false
_isone(x) = false
Expand Down
30 changes: 30 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,36 @@ let
@test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im)
end

# 1262
#
let
@variables t b(t)
D = Differential(t)
expr = b - ((D(b))^2) * D(D(b))
expr2 = D(expr)
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
@test_throws BoundsError expand_derivatives(expr2)
@test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
end

# 1126
#
let
@syms y f(y) g(y) h(y)
D = Differential(y)

expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))

expr = expr_gen(g(y))
@test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
expr = expr_gen(h(y))
@test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))

expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y))
expr = expr_gen(f(y))
@test_throws BoundsError expand_derivatives(expr)
@test isequal(expand(expand_derivatives(expr; robust=true)), expected)
end

# Check `is_derivative` function
let
Expand Down

0 comments on commit 0ea30bf

Please sign in to comment.