Skip to content

Commit

Permalink
Merge pull request #1353 from karlwessel/master
Browse files Browse the repository at this point in the history
add flag for activating robust calculation of expand_derivatives
  • Loading branch information
ChrisRackauckas authored Dec 22, 2024
2 parents ab3fcd6 + 3bf685a commit a119204
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 111 deletions.
240 changes: 129 additions & 111 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,127 @@ function recursive_hasoperator(op, O)
end
end

"""
executediff(D, arg, simplify=false; occurrences=nothing)
Apply the passed Differential D on the passed argument.
This function differs to `expand_derivatives` in that in only expands the
passed differential and not any other Differentials it encounters.
# Arguments
- `D::Differential`: The differential to apply
- `arg::Symbolic`: The symbolic expression to apply the differential on.
- `simplify::Bool=false`: Whether to simplify the resulting expression using
[`SymbolicUtils.simplify`](@ref).
- `occurrences=nothing`: Information about the occurrences of the independent
variable in the argument of the derivative. This is used internally for
optimization purposes.
"""
function executediff(D, arg, simplify=false; occurrences=nothing)
if occurrences == nothing
occurrences = occursin_info(D.x, arg)
end

_isfalse(occurrences) && return 0
occurrences isa Bool && return 1 # means it's a `true`

if !iscall(arg)
return D(arg) # Cannot expand
elseif (op = operation(arg); issym(op))
inner_args = arguments(arg)
if any(isequal(D.x), inner_args)
return D(arg) # base case if any argument is directly equal to the i.v.
else
return sum(inner_args, init=0) do a
return executediff(Differential(a), arg) *
executediff(D, a)
end
end
elseif op === (IfElse.ifelse)
args = arguments(arg)
O = op(args[1],
executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]),
executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3]))
return O
elseif isa(op, Differential)
# The recursive expand_derivatives was not able to remove
# a nested Differential. We can attempt to differentiate the
# inner expression wrt to the outer iv. And leave the
# unexpandable Differential outside.
if isequal(op.x, D.x)
return D(arg)
else
inner = executediff(D, arguments(arg)[1], false)
# if the inner expression is not expandable either, return
if iscall(inner) && operation(inner) isa Differential
return D(arg)
else
# otherwise give the nested Differential another try
return executediff(op, inner, simplify)
end
end
elseif isa(op, Integral)
if isa(op.domain.domain, AbstractInterval)
domain = op.domain.domain
a, b = DomainSets.endpoints(domain)
c = 0
inner_function = arguments(arg)[1]
if iscall(value(a))
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
t2 = D(a)
c -= t1*t2
end
if iscall(value(b))
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
t2 = D(b)
c += t1*t2
end
inner = executediff(D, arguments(arg)[1])
c += op(inner)
return value(c)
end
end

inner_args = arguments(arg)
l = length(inner_args)
exprs = []
c = 0

for i in 1:l
t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i])

x = if _iszero(t2)
t2
elseif _isone(t2)
d = derivative_idx(arg, i)
d isa NoDeriv ? D(arg) : d
else
t1 = derivative_idx(arg, i)
t1 = t1 isa NoDeriv ? D(arg) : t1
t1 * t2
end

if _iszero(x)
continue
elseif x isa Symbolic
push!(exprs, x)
else
c += x
end
end

if isempty(exprs)
return c
elseif length(exprs) == 1
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
return _iszero(c) ? term : c + term
else
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
return simplify ? SymbolicUtils.simplify(x) : x
end
end

"""
$(SIGNATURES)
Expand All @@ -162,9 +283,6 @@ and other derivative rules to expand any derivatives it encounters.
- `O::Symbolic`: The symbolic expression to expand.
- `simplify::Bool=false`: Whether to simplify the resulting expression using
[`SymbolicUtils.simplify`](@ref).
- `occurrences=nothing`: Information about the occurrences of the independent
variable in the argument of the derivative. This is used internally for
optimization purposes.
# Examples
```jldoctest
Expand All @@ -180,111 +298,11 @@ julia> dfx = expand_derivatives(Dx(f))
(k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y
```
"""
function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
function expand_derivatives(O::Symbolic, simplify=false)
if iscall(O) && isa(operation(O), Differential)
arg = only(arguments(O))
arg = expand_derivatives(arg, false)

if occurrences == nothing
occurrences = occursin_info(operation(O).x, arg)
end

_isfalse(occurrences) && return 0
occurrences isa Bool && return 1 # means it's a `true`

D = operation(O)

if !iscall(arg)
return D(arg) # Cannot expand
elseif (op = operation(arg); issym(op))
inner_args = arguments(arg)
if any(isequal(D.x), inner_args)
return D(arg) # base case if any argument is directly equal to the i.v.
else
return sum(inner_args, init=0) do a
return expand_derivatives(Differential(a)(arg)) *
expand_derivatives(D(a))
end
end
elseif op === (IfElse.ifelse)
args = arguments(arg)
O = op(args[1], D(args[2]), D(args[3]))
return expand_derivatives(O, simplify; occurrences)
elseif isa(op, Differential)
# The recursive expand_derivatives was not able to remove
# a nested Differential. We can attempt to differentiate the
# inner expression wrt to the outer iv. And leave the
# unexpandable Differential outside.
if isequal(op.x, D.x)
return D(arg)
else
inner = expand_derivatives(D(arguments(arg)[1]), false)
# if the inner expression is not expandable either, return
if iscall(inner) && operation(inner) isa Differential
return D(arg)
else
return expand_derivatives(op(inner), simplify)
end
end
elseif isa(op, Integral)
if isa(op.domain.domain, AbstractInterval)
domain = op.domain.domain
a, b = DomainSets.endpoints(domain)
c = 0
inner_function = expand_derivatives(arguments(arg)[1])
if iscall(value(a))
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
t2 = D(a)
c -= t1*t2
end
if iscall(value(b))
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
t2 = D(b)
c += t1*t2
end
inner = expand_derivatives(D(arguments(arg)[1]))
c += op(inner)
return value(c)
end
end

inner_args = arguments(arg)
l = length(inner_args)
exprs = []
c = 0

for i in 1:l
t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i])

x = if _iszero(t2)
t2
elseif _isone(t2)
d = derivative_idx(arg, i)
d isa NoDeriv ? D(arg) : d
else
t1 = derivative_idx(arg, i)
t1 = t1 isa NoDeriv ? D(arg) : t1
t1 * t2
end

if _iszero(x)
continue
elseif x isa Symbolic
push!(exprs, x)
else
c += x
end
end

if isempty(exprs)
return c
elseif length(exprs) == 1
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
return _iszero(c) ? term : c + term
else
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
return simplify ? SymbolicUtils.simplify(x) : x
end
return executediff(operation(O), arg, simplify)
elseif iscall(O) && isa(operation(O), Integral)
return operation(O)(expand_derivatives(arguments(O)[1]))
elseif !hasderiv(O)
Expand All @@ -295,14 +313,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
return simplify ? SymbolicUtils.simplify(O1) : O1
end
end
function expand_derivatives(n::Num, simplify=false; occurrences=nothing)
wrap(expand_derivatives(value(n), simplify; occurrences=occurrences))
function expand_derivatives(n::Num, simplify=false)
wrap(expand_derivatives(value(n), simplify))
end
function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences),
expand_derivatives(imag(n), simplify; occurrences=occurrences)))
function expand_derivatives(n::Complex{Num}, simplify=false)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify),
expand_derivatives(imag(n), simplify)))
end
expand_derivatives(x, simplify=false; occurrences=nothing) = x
expand_derivatives(x, simplify=false) = x

_iszero(x) = false
_isone(x) = false
Expand Down
28 changes: 28 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,34 @@ let
@test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im)
end

# 1262
#
let
@variables t b(t)
D = Differential(t)
expr = b - ((D(b))^2) * D(D(b))
expr2 = D(expr)
@test isequal(expand_derivatives(expr), expr)
@test isequal(expand_derivatives(expr2), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
end

# 1126
#
let
@syms y f(y) g(y) h(y)
D = Differential(y)

expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))

expr = expr_gen(g(y))
# just make sure that no errors are thrown in the following, the results are to complicated to compare
expand_derivatives(expr)
expr = expr_gen(h(y))
expand_derivatives(expr)

expr = expr_gen(f(y))
expand_derivatives(expr)
end

# Check `is_derivative` function
let
Expand Down

0 comments on commit a119204

Please sign in to comment.