Skip to content

Commit

Permalink
Faster finite forward. (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkofod authored Feb 13, 2019
1 parent 46fd262 commit 99c1038
Showing 1 changed file with 49 additions and 15 deletions.
64 changes: 49 additions & 15 deletions src/objective_types/oncedifferentiable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ function OnceDifferentiable(f, x_seed::AbstractArray{T},
F::Real,
DF::AbstractArray,
autodiff) where T

# When here, at the constructor with positional autodiff, it should already
# be the case, that f is inplace.
if typeof(f) <: Union{InplaceObjective, NotInplaceObjective}

fF = make_f(f, x_seed, F)
Expand All @@ -46,7 +47,10 @@ function OnceDifferentiable(f, x_seed::AbstractArray{T},
# Figure out which Val-type to use for DiffEqDiffTools based on our
# symbol interface.
fdtype = diffeqdiff_fdtype(autodiff)
gcache = DiffEqDiffTools.GradientCache(x_seed, x_seed, fdtype)
df_array_spec = DF
x_array_spec = x_seed
return_spec = typeof(F)
gcache = DiffEqDiffTools.GradientCache(df_array_spec, x_array_spec, fdtype, return_spec)

function g!(storage, x)
DiffEqDiffTools.finite_difference_gradient!(storage, f, x, gcache)
Expand Down Expand Up @@ -86,36 +90,66 @@ function OnceDifferentiable(f, x::AbstractArray, F::AbstractArray,
end
OnceDifferentiable(f, x, F, alloc_DF(x, F), :forward, chunk)
end
function OnceDifferentiable(f, x::AbstractArray, F::AbstractArray, DF::AbstractArray,
autodiff::Symbol , chunk::ForwardDiff.Chunk = ForwardDiff.Chunk(x))
function OnceDifferentiable(f, x_seed::AbstractArray, F::AbstractArray, DF::AbstractArray,
autodiff::Symbol , chunk::ForwardDiff.Chunk = ForwardDiff.Chunk(x_seed))
if typeof(f) <: Union{InplaceObjective, NotInplaceObjective}
fF = make_f(f, x, F)
dfF = make_df(f, x, F)
fdfF = make_fdf(f, x, F)
return OnceDifferentiable(fF, dfF, fdfF, x, F, DF)
fF = make_f(f, x_seed, F)
dfF = make_df(f, x_seed, F)
fdfF = make_fdf(f, x_seed, F)
return OnceDifferentiable(fF, dfF, fdfF, x_seed, F, DF)
else
if is_finitediff(autodiff)

# Figure out which Val-type to use for DiffEqDiffTools based on our
# symbol interface.
fdtype = diffeqdiff_fdtype(autodiff)
# Apparently only the third input is aliased.
j_diffeqdiff_cache = DiffEqDiffTools.JacobianCache(similar(x_seed), similar(F), similar(F), fdtype)
if autodiff == :finiteforward
# These copies can be done away with if we add a keyword for
# reusing arrays instead for overwriting them.
Fx = copy(F)
DF = copy(DF)

x_f, x_df = x_of_nans(x_seed), x_of_nans(x_seed)
f_calls, j_calls = [0,], [0,]
function j_finiteforward!(J, x)
# Exploit the possibility that it might be that x_f == x
# then we don't have to call f again.

# if at least one element of x_f is different from x, update
if any(x_f .!= x)
Fx = similar(Fx)
f(Fx, x)
f_calls .+= 1
end

DiffEqDiffTools.finite_difference_jacobian!(J, f, x, j_diffeqdiff_cache, Fx)
end
function fj_finiteforward!(F, J, x)
f(F, x)
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, j_diffeqdiff_cache, F)
end


return OnceDifferentiable(f, j_finiteforward!, fj_finiteforward!, copy(F), copy(DF), x_f, x_df, f_calls, j_calls)
end

central_cache = DiffEqDiffTools.JacobianCache(similar(x), similar(F), similar(F), fdtype)
function fj_finitediff!(F, J, x)
f(F, x)
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, central_cache)
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, j_diffeqdiff_cache)
F
end
function j_finitediff!(J, x)
F_cache = similar(F)
fj_finitediff!(F_cache, J, x)
end
return OnceDifferentiable(f, j_finitediff!, fj_finitediff!, x, F, DF)

return OnceDifferentiable(f, j_finitediff!, fj_finitediff!, x_seed, F, DF)

elseif is_forwarddiff(autodiff)

jac_cfg = ForwardDiff.JacobianConfig(f, F, x, chunk)
ForwardDiff.checktag(jac_cfg, f, x)
jac_cfg = ForwardDiff.JacobianConfig(f, F, x_seed, chunk)
ForwardDiff.checktag(jac_cfg, f, x_seed)

F2 = copy(F)
function j_forwarddiff!(J, x)
Expand All @@ -127,7 +161,7 @@ function OnceDifferentiable(f, x::AbstractArray, F::AbstractArray, DF::AbstractA
DiffResults.value(jac_res)
end

return OnceDifferentiable(f, j_forwarddiff!, fj_forwarddiff!, x, F, DF)
return OnceDifferentiable(f, j_forwarddiff!, fj_forwarddiff!, x_seed, F, DF)
else
error("The autodiff value $(autodiff) is not supported. Use :finite or :forward.")
end
Expand Down

0 comments on commit 99c1038

Please sign in to comment.