Skip to content

Commit

Permalink
Check if user supplied old style gradients and Hessians and correct i…
Browse files Browse the repository at this point in the history
…t. (#6)

* Check if user supplied old style gradients and Hessians.

* Add tests.
  • Loading branch information
pkofod authored Apr 11, 2017
1 parent 368feba commit 076d706
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/NLSolversBase.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__precompile__(true)

module NLSolversBase

using Compat
Expand Down
84 changes: 73 additions & 11 deletions src/objective_types.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
@compat abstract type AbstractObjective end

function fix_order(storage_input, x_input, fun!, fun!_msg)
_storage = copy(storage_input)
_x = copy(x_input)
fun!(_storage, _x)
if _storage == storage_input && _x != x_input
warn("Storage (g) and evaluation point (x) order has changed. The order is now $(fun!_msg)(storage, x) as opposed to the old $(fun!_msg)(x, storage). Changing the order and proceeding, but please change your code to use the new syntax.")
return (storage, x) -> fun!(x, storage)
else
return (storage, x) -> fun!(storage, x)
end
end
# Used for objectives and solvers where no gradient is available/exists
type NonDifferentiable{T} <: AbstractObjective
f
f_x::T
last_x_f::Array{T}
f_calls::Vector{Int}
end
type UnitializedNonDifferentiable <: AbstractObjective
f
end
# The user friendly/short form NonDifferentiable constructor
NonDifferentiable(f) = UnitializedNonDifferentiable(f)
NonDifferentiable{T}(f, x_seed::Array{T}) = NonDifferentiable(f, f(x_seed), copy(x_seed), [1])

# Used for objectives and solvers where the gradient is available/exists
Expand All @@ -22,20 +37,37 @@ type OnceDifferentiable{T, Tgrad} <: AbstractObjective
f_calls::Vector{Int}
g_calls::Vector{Int}
end
type UnitializedOnceDifferentiable <: AbstractObjective
f
g!
fg!
end
# The user friendly/short form OnceDifferentiable constructor
function OnceDifferentiable(f, g!, fg!, x_seed)
OnceDifferentiable(f, g!, fg!) = UnitializedOnceDifferentiable(f, g!, fg!)
OnceDifferentiable(f, g!) = UnitializedOnceDifferentiable(f, g!, nothing)
OnceDifferentiable(f) = UnitializedOnceDifferentiable(f, nothing, nothing)

function OnceDifferentiable(f, g!, fg!, x_seed::AbstractArray)
g = similar(x_seed)
g!(g, x_seed)
OnceDifferentiable(f, g!, fg!, f(x_seed), g, copy(x_seed), copy(x_seed), [1], [1])

_new_g! = fix_order(g, x_seed, g!, "g!")
_new_fg! = fix_order(g, x_seed, fg!, "fg!")

f_val = _new_fg!(g, x_seed)
OnceDifferentiable(f, _new_g!, _new_fg!, f_val, g, copy(x_seed), copy(x_seed), [1], [1])
end

# Automatically create the fg! helper function if only f and g! is provided
function OnceDifferentiable(f, g!, x_seed)
function OnceDifferentiable(f, g!, x_seed::AbstractArray)
g = similar(x_seed)

_new_g! = fix_order(g, x_seed, g!, "g!")

function fg!(storage, x)
g!(storage, x)
_new_g!(storage, x)
return f(x)
end
return OnceDifferentiable(f, g!, fg!, x_seed)
return OnceDifferentiable(f, _new_g!, fg!, x_seed)
end

# Used for objectives and solvers where the gradient and Hessian is available/exists
Expand All @@ -54,14 +86,41 @@ type TwiceDifferentiable{T<:Real} <: AbstractObjective
g_calls::Vector{Int}
h_calls::Vector{Int}
end
type UnitializedTwiceDifferentiable <: AbstractObjective
f
g!
fg!
h!
end
TwiceDifferentiable(f, g!, h!) = UnitializedTwiceDifferentiable(f, g!, nothing, h!)
TwiceDifferentiable(f, g!) = UnitializedTwiceDifferentiable(f, g!, nothing, nothing)
TwiceDifferentiable(f) = UnitializedTwiceDifferentiable(f, nothing, nothing, nothing)
# The user friendly/short form TwiceDifferentiable constructor
function TwiceDifferentiable{T}(f, g!, fg!, h!, x_seed::Array{T})
n_x = length(x_seed)
g = similar(x_seed)
H = Array{T}(n_x, n_x)
g!(g, x_seed)
h!(H, x_seed)
TwiceDifferentiable(f, g!, fg!, h!, f(x_seed),

_new_g! = fix_order(g, x_seed, g!, "g!")
_new_fg! = fix_order(g, x_seed, fg!, "fg!")

local _new_h!
try
_H = copy(H)
_x = copy(x_seed)
h!(_H, _x)
_new_h! = (storage, x) -> h!(storage, x)
catch m
if isa(m, MethodError) || isa(m, BoundsError)
warn("Storage and evaluation point order has changed. The syntax is now h!(storage, x) as opposed to the old h!(x, storage). Your Hessian appears to have it the wrong way around. Changing the order and proceeding, but please change your code to use the new syntax.")
_new_h! = (storage, x) -> h!(x, storage)
end
end

f_val = _new_fg!(g, x_seed)
_new_h!(H, x_seed)

TwiceDifferentiable(f, _new_g!, _new_fg!, _new_h!, f_val,
g, H, copy(x_seed),
copy(x_seed), copy(x_seed), [1], [1], [1])
end
Expand All @@ -71,9 +130,12 @@ function TwiceDifferentiable{T}(f,
g!,
h!,
x_seed::Array{T})
g = similar(x_seed)
_new_g! = fix_order(g, x_seed, g!, "g!")

function fg!(storage::Vector, x::Vector)
g!(storage, x)
_new_g!(storage, x)
return f(x)
end
return TwiceDifferentiable(f, g!, fg!, h!, x_seed)
return TwiceDifferentiable(f, _new_g!, fg!, h!, x_seed)
end
45 changes: 45 additions & 0 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
@testset "deprecations" begin

# Test example
function exponential(x::Vector)
return exp((2.0 - x[1])^2) + exp((3.0 - x[2])^2)
end

function exponential_gradient!(x, storage)
storage[1] = -2.0 * (2.0 - x[1]) * exp((2.0 - x[1])^2)
storage[2] = -2.0 * (3.0 - x[2]) * exp((3.0 - x[2])^2)
end

function exponential_hessian!(x, storage)
storage[1, 1] = 2.0 * exp((2.0 - x[1])^2) * (2.0 * x[1]^2 - 8.0 * x[1] + 9)
storage[1, 2] = 0.0
storage[2, 1] = 0.0
storage[2, 2] = 2.0 * exp((3.0 - x[1])^2) * (2.0 * x[2]^2 - 12.0 * x[2] + 19)
end

x_seed = [0.0, 0.0]
f_x_seed = 8157.682077608529

nd = NonDifferentiable(exponential, x_seed)
@test nd.f == exponential
@test value(nd) == f_x_seed
@test nd.last_x_f == [0.0, 0.0]
@test nd.f_calls == [1]

od = OnceDifferentiable(exponential, exponential_gradient!, x_seed)
@test od.f == exponential
#@test od.g! == exponential_gradient!
@test value(od) == f_x_seed
@test od.last_x_f == [0.0, 0.0]
@test od.f_calls == [1]
@test od.g_calls == [1]

td = TwiceDifferentiable(exponential, exponential_gradient!, exponential_hessian!, x_seed)
@test td.f == exponential
#@test td.g! == exponential_gradient!
@test value(td) == f_x_seed
@test td.last_x_f == [0.0, 0.0]
@test td.f_calls == [1]
@test td.g_calls == [1]
@test td.h_calls == [1]
end
4 changes: 2 additions & 2 deletions test/objective_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@

od = OnceDifferentiable(exponential, exponential_gradient!, x_seed)
@test od.f == exponential
@test od.g! == exponential_gradient!
#@test od.g! == exponential_gradient!
@test value(od) == f_x_seed
@test od.last_x_f == [0.0, 0.0]
@test od.f_calls == [1]
@test od.g_calls == [1]

td = TwiceDifferentiable(exponential, exponential_gradient!, exponential_hessian!, x_seed)
@test td.f == exponential
@test td.g! == exponential_gradient!
#@test td.g! == exponential_gradient!
@test value(td) == f_x_seed
@test td.last_x_f == [0.0, 0.0]
@test td.f_calls == [1]
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ using Base.Test

include("objective_types.jl")
include("interface.jl")
include("deprecations.jl")

0 comments on commit 076d706

Please sign in to comment.