Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more support for out-of-place trust-region solvers #66

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NLSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ export Backtracking, Static, HZAW
export FFQuadInterp

include("globalization/trs_solvers/root.jl")
export NWI, Dogleg, NTR
export NWI, Dogleg, NTR, TCG

# Quasi-Newton (including Newton and gradient descent) functionality
include("quasinewton/quasinewton.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/globalization/trs_solvers/TRS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function (ms::TRSolver)(∇f, H, Δ, p)
x, info = trs(H, ∇f, Δ)
p .= x[:, 1]

m = dot(∇f, p) + dot(p, H * p) / 2
m = dot(∇f, p) + dot(p, H, p) / 2
interior = norm(p, 2) ≤ Δ
return (
p = p,
Expand Down
42 changes: 38 additions & 4 deletions src/globalization/trs_solvers/root.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@
abstract type TRSPSolver end
abstract type NearlyExactTRSP <: TRSPSolver end

trs_supports_outofplace(trs) = false

function trs_outofplace_check(trs,prob)
if !trs_supports_outofplace(trs)
throw(
ErrorException("solve() not defined for OutOfPlace() with $(typeof(trs).name.wrapper) for $(typeof(prob).name.wrapper)"),
)
end
end

include("solvers/NWI.jl")
include("solvers/Dogleg.jl")
include("solvers/NTR.jl")
include("solvers/TCG.jl")
#include("subproblemsolvers/TRS.jl") just make an example instead of relying onTRS.jl

function tr_return(; λ, ∇f, H, s, interior, solved, hard_case, Δ, m = nothing)
m = m isa Nothing ? dot(∇f, s) + dot(s, H * s) / 2 : m
m = m isa Nothing ? dot(∇f, s) + dot(s, H, s) / 2 : m
(
p = s,
mz = m,
Expand All @@ -23,13 +33,37 @@ function tr_return(; λ, ∇f, H, s, interior, solved, hard_case, Δ, m = nothin
)
end

function update_H!(H, h, λ = nothing)
update_H!(mstyle::OutOfPlace,H, h, λ) = _update_H(H, h, λ)
update_H!(mstyle::OutOfPlace,H, h) = _update_H(H, h, nothing)
update_H!(mstyle::InPlace,H, h, λ) = _update_H!(H, h, λ)
update_H!(mstyle::InPlace,H, h) = _update_H!(H, h, nothing)

function _update_H!(H, h, λ)
T = eltype(h)
n = length(h)
if !(λ == T(0))
if λ == nothing
for i = 1:n
@inbounds H[i, i] = λ isa Nothing ? h[i] : h[i] + λ
@inbounds H[i, i] = h[i]
end
elseif !(λ == T(0))
for i = 1:n
@inbounds H[i, i] = h[i] + λ
end
end
H
end

function _update_H(H, h, λ = nothing)
T = eltype(h)
if λ == nothing
h̄ = Diagonal(h)
H̄ = H - Diagonal(H)
return H̄ + h̄
elseif !(λ == T(0))
h̄ = Diagonal(h)
H̄ = H - Diagonal(H)
return H̄ + h̄ + λ*I
else
return H
end
end
4 changes: 3 additions & 1 deletion src/globalization/trs_solvers/solvers/Dogleg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ struct Dogleg{T} <: TRSPSolver
end
Dogleg() = Dogleg(nothing)

trs_supports_outofplace(trs::Dogleg) = true

function (dogleg::Dogleg)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
T = eltype(p)
n = length(∇f)
Expand Down Expand Up @@ -80,7 +82,7 @@ function (dogleg::Dogleg)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxite
interior = false
end
end
m = dot(∇f, p) + dot(p, H * p) / 2
m = dot(∇f, p) + dot(p, H, p) / 2

return (
p = p,
Expand Down
62 changes: 39 additions & 23 deletions src/globalization/trs_solvers/solvers/NTR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,21 @@ function (ms::NTR)(
n = length(∇f)
h = H isa UniformScaling ? copy(∇f) .* 0 .+ 1 : diag(H)
H = H isa UniformScaling ? Diagonal(copy(∇f) .* 0 .+ 1) : H

inplace = mstyle == InPlace()
# Check for interior convergence
if λ == T(0)
F = cholesky(Symmetric(H); check = false)
s .= -∇f
s .= F \ s
if inplace
s .= -∇f
s .= F \ s
else
s = -∇f
s = F \ s
end
s₂ = norm(s, 2)

if issuccess(F) && s₂ < Δ
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -94,7 +99,7 @@ function (ms::NTR)(
λL, λU = isg.L, isg.U

for iter = 1:maxiter
H = update_H!(H, h, λ)
H = update_H!(mstyle, H, h, λ)
F = cholesky(Symmetric(H); check = false)
in𝓖, linpack = false, false
#===========================================================================
Expand All @@ -109,13 +114,17 @@ function (ms::NTR)(
# Algorithm 7.3.1 on p. 185 in [ConnGouldTointBook]
# Step 1 was factorizing
# Step 2
s .= -∇f
s .= F \ s

if inplace
s .= -∇f
s .= F \ s
else
s = -∇f
s = F \ s
end
# Check if step is approximately equal to the radius
s₂ = norm(s, 2)
if s₂ ≈ Δ
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand Down Expand Up @@ -145,15 +154,18 @@ function (ms::NTR)(
if in𝓖
linpack = true
w, u = λL_with_linpack(F)
λL = max(λL, λ - dot(u, H * u))
λL = max(λL, λ - dot(u, H, u))

α, s_g, m_g = 𝓖_root(u, s, Δ, ∇f, H)
s .= s_g

if inplace
s .= s_g
else
s = s_g
end
s₂ = norm(s)
# check hard case convergnce
if α^2 * dot(u, H * u) ≤ κhard * (dot(s, H * s) + λ * Δ^2)
H = update_H!(H, h)
if α^2 * dot(u, H, u) ≤ κhard * (dot(s, H, s) + λ * Δ^2)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -167,7 +179,7 @@ function (ms::NTR)(
)
end
# If not the hard case solution, try to factorize H(λ⁺)
H = update_H!(H, h, λ⁺)
H = update_H!(mstyle, H, h, λ⁺)
F = cholesky(H; check = false)
if issuccess(F) # Then we're in L, great! lemma 7.3.2
λ = λ⁺
Expand All @@ -180,7 +192,7 @@ function (ms::NTR)(

# check for convergence
if in𝓖 && abs(s₂ - Δ) ≤ κeasy * Δ
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -194,9 +206,13 @@ function (ms::NTR)(
elseif abs(s₂ - Δ) ≤ κeasy * Δ # implicitly "if in 𝓕" since we're in that branch
# u and α comes from linpack
if linpack
if α^2 * dot(u, H * u) ≤ κhard * (dot(sλ, H * sλ) * Δ^2)
s .= s .+ α * u
H = update_H!(H, h)
if α^2 * dot(u, H, u) ≤ κhard * (dot(sλ, H, sλ) * Δ^2)
if inplace
s .= s .+ α * u
else
s = s + α * u
end
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -216,10 +232,10 @@ function (ms::NTR)(
# lower bound, we cannot apply the Newton step here.
δ, v = λL_in_𝓝(H, F)
λL = max(λL, λ + δ / dot(v, v)) # update lower bound
λ = max(sqrt(λL * λU), λL + θ * (λU - λL)) # no converence possible, so step in bracket
λ = max(sqrt(λL * λU), λL + θ * (λU - λL)) # no convergence possible, so step in bracket
end
end
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
tr_return(;
λ = λ,
∇f = ∇f,
Expand Down Expand Up @@ -272,9 +288,9 @@ function 𝓖_root(u, s, Δ, ∇f, H)
α₂ = (-pb - pd) / 2pa

s₁ = s + α₁ * u
m₁ = dot(∇f, s₁) + dot(s₁, H * s₁) / 2
m₁ = dot(∇f, s₁) + dot(s₁, H, s₁) / 2
s₂ = s + α₂ * u
m₂ = dot(∇f, s₂) + dot(s₂, H * s₂) / 2
m₂ = dot(∇f, s₂) + dot(s₂, H, s₂) / 2
α, s, m = m₁ ≤ m₂ ? (α₁, s₁, m₁) : (α₂, s₂, m₂)
α, s, m
end
49 changes: 32 additions & 17 deletions src/globalization/trs_solvers/solvers/NWI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct NWI{T} <: NearlyExactTRSP
end
NWI() = NWI(eigen)
summary(::NWI) = "Trust Region (Newton, eigen)"

trs_supports_outofplace(trs::NWI) = true
"""
initial_safeguards(B, h, g, Δ)

Expand Down Expand Up @@ -112,19 +112,27 @@ function is_maybe_hard_case(QΛQ, Qt∇f::AbstractVector{T}) where {T}
end

# Equation 4.38 in N&W (2006)
calc_p!(p, Qt∇f, QΛQ, λ) = calc_p!(p, Qt∇f, QΛQ, λ, 1)
calc_p!(mstyle::MutateStyle, p, Qt∇f, QΛQ, λ) = calc_p!(mstyle, p, Qt∇f, QΛQ, λ, 1)

# Equation 4.45 in N&W (2006) since we allow for first_j > 1
function calc_p!(p, Qt∇f, QΛQ, λ::T, first_j) where {T}
function calc_p!(mstyle::MutateStyle, p, Qt∇f, QΛQ, λ::T, first_j) where {T}
inplace = mstyle === InPlace()
# Reset search direction to 0
fill!(p, T(0))

p = if inplace
fill!(p, T(0))
else
T(0) .* p
end
# Unpack eigenvalues and eigenvectors
Λ = QΛQ.values
Q = QΛQ.vectors
for j = first_j:length(Λ)
κ = Qt∇f[j] / (Λ[j] + λ)
@. p = p - κ * Q[:, j]
if inplace
@. p = p - κ * Q[:, j]
else
p = p .- κ .* Q[:, j]
end
end
p
end
Expand Down Expand Up @@ -153,7 +161,7 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
n = length(∇f)
H = H isa UniformScaling ? Diagonal(copy(∇f) .* 0 .+ 1) : H
h = diag(H)

inplace = mstyle == InPlace()
# Note that currently the eigenvalues are only sorted if H is perfectly
# symmetric. (Julia issue #17093)
if H isa Diagonal
Expand All @@ -176,14 +184,14 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
# positive, so the Newton step, pN, is fine unless norm(pN, 2) > Δ.
if λmin >= sqrt(eps(T))
λ = T(0) # no amount of I is added yet
p = calc_p!(p, Qt∇f, QΛQ, λ) # calculate the Newton step
p = calc_p!(mstyle, p, Qt∇f, QΛQ, λ) # calculate the Newton step
if norm(p, 2) ≤ Δ
# No shrinkage is necessary: -(H \ ∇f) is the minimizer
interior = true
solved = true
hard_case = false

m = dot(∇f, p) + dot(p, H * p) / 2
m = dot(∇f, p) + dot(p, H, p) / 2

return (
p = p,
Expand Down Expand Up @@ -218,7 +226,7 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)

# The old p is discarded, and replaced with one that takes into account
# the first j such that λj ≠ λmin. Formula 4.45 in N&W (2006)
pλ = calc_p!(p, Qt∇f, QΛQ, λ, first_j)
pλ = calc_p!(mstyle, p, Qt∇f, QΛQ, λ, first_j)

# Check if the choice of λ leads to a solution inside the trust region.
# If it does, then we construct the "hard case solution".
Expand All @@ -228,9 +236,12 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)

tau = sqrt(Δ^2 - norm(pλ, 2)^2)

@. p = -pλ + tau * Q[:, 1]

m = dot(∇f, p) + dot(p, H * p) / 2
if inplace
@. p = -pλ + tau * Q[:, 1]
else
p = tau .* Q[:, 1] .- pλ
end
m = dot(∇f, p) + dot(p, H, p) / 2

return (
p = p,
Expand All @@ -257,7 +268,7 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
λ = safeguard_λ(λ, isg)
for iter = 1:maxiter
λ_previous = λ
H = update_H!(H, h, λ)
H = update_H!(mstyle, H, h, λ)

F =
H isa Diagonal ? cholesky(H; check = false) :
Expand All @@ -271,7 +282,11 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
continue
end
R = F.U
p .= R \ (R' \ -∇f)
if inplace
p .= R \ (R' \ -∇f)
else
p = R \ (R' \ -∇f)
end
q_l = R' \ p

p_norm = norm(p, 2)
Expand All @@ -289,8 +304,8 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
end
end

H = update_H!(H, h)
m = dot(∇f, p) + dot(p, H * p) / 2
H = update_H!(mstyle, H, h)
m = dot(∇f, p) + dot(p, H, p) / 2
return (
p = p,
mz = m,
Expand Down
Loading
Loading