Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assume unknown functions are non-linear in hessian_sparsity #1384

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -646,24 +646,13 @@ let
linearity_rules = [
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
@rule (~f)(~x::(!isidx)) => _scalar

@rule (~f)(~x::isidx) => if haslinearity_1(~f)
combine_terms_1(linearity_1(~f), ~x)
else
error("Function of unknown linearity used: ", ~f)
end
@rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar
@rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x)
@rule (~f)(~x, ~y) => begin
if haslinearity_2(~f)
a = isidx(~x) ? ~x : _scalar
b = isidx(~y) ? ~y : _scalar
combine_terms_2(linearity_2(~f), a, b)
else
error("Function of unknown linearity used: ", ~f)
end
end
@rule ~x::issym => 0]
@rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar)

@rule ~x::issym => 0
]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm))

global hessian_sparsity
Expand Down Expand Up @@ -696,9 +685,8 @@ let
@assert !(expr isa AbstractArray)
expr = value(expr)
u = map(value, vars)
idx(i) = TermCombination(Set([Dict(i=>1)]))
dict = Dict(u .=> idx.(1:length(u)))
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_mkterm)(expr)
dict = Dict(ui => TermCombination(Set([Dict(i=>1)])) for (i, ui) in enumerate(u))
f = Rewriters.Prewalk(x-> get(dict, x, x); maketerm=basic_mkterm)(expr)
lp = linearity_propagator(f)
S = _sparse(lp, length(u))
S = full ? S : tril(S)
Expand Down
22 changes: 3 additions & 19 deletions src/linearity.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,45 @@
using SpecialFunctions
import Base.Broadcast


const linearity_known_1 = IdDict{Function,Bool}()
const linearity_known_2 = IdDict{Function,Bool}()

const linearity_map_1 = IdDict{Function, Bool}()
const linearity_map_2 = IdDict{Function, Tuple{Bool, Bool, Bool}}()

# 1-arg

const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]

const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh, slog, ssqrt, scbrt]

# We store 3 bools even for 1-arg functions for type stability
const three_trues = (true, true, true)
for f in monadic_linear
linearity_known_1[f] = true
linearity_map_1[f] = true
end

for f in monadic_nonlinear
linearity_known_1[f] = true
linearity_map_1[f] = false
end

# 2-arg
for f in [+, rem2pi, -, >, isless, <, isequal, max, min, convert, <=, >=]
linearity_known_2[f] = true
linearity_map_2[f] = (true, true, true)
end

for f in [*]
linearity_known_2[f] = true
linearity_map_2[f] = (true, true, false)
end

for f in [/]
linearity_known_2[f] = true
linearity_map_2[f] = (true, false, false)
end
for f in [\]
linearity_known_2[f] = true
linearity_map_2[f] = (false, true, false)
end

for f in [hypot, atan, mod, rem, lbeta, ^, beta]
linearity_known_2[f] = true
linearity_map_2[f] = (false, false, false)
end

haslinearity_1(@nospecialize(f)) = get(linearity_known_1, f, false)
haslinearity_2(@nospecialize(f)) = get(linearity_known_2, f, false)

linearity_1(@nospecialize(f)) = linearity_map_1[f]
linearity_2(@nospecialize(f)) = linearity_map_2[f]
# Fallback assumption: a function without a registered entry is treated as non-linear, i.e., its second derivatives may be non-zero
linearity_1(@nospecialize(f)) = get(linearity_map_1, f, false)
linearity_2(@nospecialize(f)) = get(linearity_map_2, f, (false, false, false))

# TermCombination datastructure

Expand Down
113 changes: 113 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,116 @@ let
@test isequal(expand_derivatives(D(Symbolics.scbrt(1 + x ^ 2))), simplify((2x) / (3Symbolics.scbrt(1 + x^2)^2)))
@test isequal(expand_derivatives(D(Symbolics.slog(1 + x ^ 2))), simplify((2x) / (1 + x ^ 2)))
end

# Hessian sparsity involving unknown functions
let
    # Every case below adds one extra term to the same quadratic in `x₁`, `x₂` and
    # checks the resulting Hessian sparsity pattern with respect to `[x₁, x₂]`.
    # `p` and `q[1]` are symbols that are NOT differentiation variables.
    @variables x₁ x₂ p q[1:1]
    vars = [x₁, x₂]
    quad = 3x₁^2 + 4x₁ * x₂
    pattern(ex) = Matrix(Symbolics.hessian_sparsity(ex, vars))

    # Pattern of the bare quadratic: there is no x₂^2 term, so entry (2, 2) is empty.
    unchanged = [true true; true false]
    # Pattern once the extra term is (assumed to be) non-linear in x₂.
    x₂_nonlinear = [true true; true true]

    @test pattern(quad) == unchanged
    @test pattern(quad + p) == unchanged

    # issue 643: example test2_num
    @test pattern(quad + q[1]) == unchanged

    # Custom function: By default assumed to be non-linear
    myexp(x) = exp(x)
    @register_symbolic myexp(x)
    @test pattern(quad + myexp(p)) == unchanged
    @test pattern(quad + myexp(x₂)) == x₂_nonlinear

    # Two-argument custom function without registered linearity.
    # (The original listed the `mylogaddexp(p, 2)` case twice; the copy-paste
    # duplicate has been dropped since it added no coverage.)
    mylogaddexp(x, y) = log(exp(x) + exp(y))
    @register_symbolic mylogaddexp(x, y)
    @test pattern(quad + mylogaddexp(p, 2)) == unchanged
    @test pattern(quad + mylogaddexp(3, p)) == unchanged
    @test pattern(quad + mylogaddexp(p, q[1])) == unchanged
    @test pattern(quad + mylogaddexp(p, x₂)) == x₂_nonlinear
    @test pattern(quad + mylogaddexp(x₂, 4)) == x₂_nonlinear

    # Custom linear function: Possible to extend `Symbolics.linearity_1`/`Symbolics.linearity_2`
    myidentity(x) = x
    @register_symbolic myidentity(x)
    Symbolics.linearity_1(::typeof(myidentity)) = true
    @test pattern(quad + myidentity(p)) == unchanged
    @test pattern(quad + myidentity(q[1])) == unchanged
    # Declared linear, so even applied to x₂ it contributes nothing to the Hessian.
    @test pattern(quad + myidentity(x₂)) == unchanged

    # Declared linear in the first argument only: `(true, false, false)`.
    mymul1plog(x, y) = x * (1 + log(y))
    @register_symbolic mymul1plog(x, y)
    Symbolics.linearity_2(::typeof(mymul1plog)) = (true, false, false)
    @test pattern(quad + mymul1plog(p, q[1])) == unchanged
    @test pattern(quad + mymul1plog(x₂, q[1])) == unchanged
    @test pattern(quad + mymul1plog(q[1], x₂)) == x₂_nonlinear
end

# issue #555
let
    # First example: one parameter `p[1]` and one differentiation variable `x[1]`.
    @variables p[1:1] x[1:1]
    p, x = collect(p), collect(x)
    p₁, x₁ = only(p), only(x)
    # Linear in x[1]: the 1×1 Hessian is zero.
    @test collect(Symbolics.sparsehessian(p₁ * x₁, x)) == fill(0, 1, 1)
    # Quadratic in x[1]: the 1×1 Hessian is 2p[1].
    @test isequal(collect(Symbolics.sparsehessian(p₁ * x₁^2, x)), fill(2p₁, 1, 1))

    # Second example: squared sum, differentiated w.r.t. a subset and w.r.t. all of `a`.
    @variables a[1:2]
    a = collect(a)
    a₁ = first(a)
    squared_sum = (a₁ + a[2])^2
    @test Symbolics.hessian(squared_sum, [a₁]) == fill(2, 1, 1)
    @test collect(Symbolics.sparsehessian(squared_sum, [a₁])) == fill(2, 1, 1)
    @test collect(Symbolics.sparsehessian(squared_sum, a)) == fill(2, 2, 2)
end

# issue #847
let
    @variables x[1:2] y[1:2]
    x, y = Symbolics.scalarize(x), Symbolics.scalarize(y)
    xsum = x[1] + x[2]
    ysum = y[1] + y[2]

    # Product of a sum over `x` with a sum over `y`, checked with respect to `x` only.
    product = xsum * ysum
    @test Symbolics.islinear(product, x)
    @test Symbolics.isaffine(product, x)

    # The plain sum over `x`.
    @test Symbolics.islinear(xsum, x)
    @test Symbolics.isaffine(xsum, x)
end

# issue #790
let
    # Single-element vector whose entry is the sum of the inputs shifted by -1.
    shifted_sum(u) = [sum(u) - 1]
    @variables xs[1:2] ys[1:1]
    unknowns = Symbolics.scalarize(xs)
    weights = Symbolics.scalarize(ys)
    weighted = dot(weights, shifted_sum(unknowns))
    # Expected: affine but not linear in `xs`, hence an all-false Hessian pattern.
    @test !Symbolics.islinear(weighted, unknowns)
    @test Symbolics.isaffine(weighted, unknowns)
    @test collect(Symbolics.hessian_sparsity(weighted, unknowns)) == falses(2, 2)
end

# issue #749
let
    @variables x y
    # Register a two-argument method for a function with no linearity entry.
    @register_symbolic Base.FastMath.exp_fast(x, y)
    vars = [x, y]
    call = Base.FastMath.exp_fast(x, y)
    # Expected: treated as non-linear in both arguments, so the pattern is fully dense.
    @test !Symbolics.islinear(call, vars)
    @test !Symbolics.isaffine(call, vars)
    @test collect(Symbolics.hessian_sparsity(call, vars)) == trues(2, 2)
end
Loading