Skip to content

Commit

Permalink
MOI vector lambda and iteration fixes in Optimisers
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 27, 2024
1 parent c526d71 commit 47a2481
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 22 deletions.
2 changes: 1 addition & 1 deletion lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ function MOI.eval_hessian_lagrangian(evaluator::MOIOptimizationNLPEvaluator{T},
σ,
μ) where {T}
if evaluator.f.lag_h !== nothing
evaluator.f.lag_h(h, x, σ, μ)
evaluator.f.lag_h(h, x, σ, Vector(μ))
return
end
if evaluator.f.hess === nothing
Expand Down
53 changes: 33 additions & 20 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,27 @@ function SciMLBase.__solve(cache::OptimizationCache{
P,
C
}
maxiters = if cache.solver_args.epochs === nothing
if OptimizationBase.isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end

epochs = if cache.solver_args.epochs === nothing
if cache.solver_args.maxiters === nothing
throw(ArgumentError("The number of epochs must be specified with either the epochs or maxiters kwarg."))
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
else
cache.solver_args.maxiters
cache.solver_args.maxiters / length(data)
end
else
cache.solver_args.epochs
end

maxiters = Optimization._check_and_convert_maxiters(maxiters)
if maxiters === nothing
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end

if OptimizationBase.isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
epochs = Optimization._check_and_convert_maxiters(epochs)
if epochs === nothing
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
end

opt = cache.opt
Expand All @@ -75,21 +75,35 @@ function SciMLBase.__solve(cache::OptimizationCache{
min_θ = cache.u0

state = Optimisers.setup(opt, θ)

iterations = 0
fevals = 0
gevals = 0
t0 = time()
Optimization.@withprogress cache.progress name="Training" begin
for epoch in 1:maxiters
for epoch in 1:epochs
for (i, d) in enumerate(data)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = Optimization.OptimizationState(
iter = i + (epoch - 1) * length(data),
Expand All @@ -112,7 +126,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
min_err = x
min_θ = copy(θ)
end
if i == maxiters #Last iter, revert to best.
if i == length(data) #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
Expand All @@ -132,10 +146,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
end

t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)
stats = Optimization.OptimizationStats(; iterations,
time = t1 - t0, fevals, gevals)
SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], stats = stats)
# here should be build_solution to create the output message
end

end
7 changes: 7 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ using Zygote

sol = solve(prob, Optimisers.Adam(), maxiters = 1000)
@test 10 * sol.objective < l1
@test sol.stats.iterations == 1000
@test sol.stats.fevals == 1000
@test sol.stats.gevals == 1000

@testset "cache" begin
objective(x, p) = (p[1] - x[1])^2
Expand Down Expand Up @@ -99,6 +102,10 @@ end
res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

@test res.objective < 1e-4
@test res.stats.iterations == 10000*length(data)
@test res.stats.fevals == 10000*length(data)
@test res.stats.gevals == 10000*length(data)


using MLDataDevices
data = CPUDevice()(data)
Expand Down
3 changes: 2 additions & 1 deletion src/sophia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
cache.f.grad(gₜ, θ)
x = cache.f(θ)
end
opt_state = Optimization.OptimizationState(; iter = i + (epoch - 1) * length(data),
opt_state = Optimization.OptimizationState(;
iter = i + (epoch - 1) * length(data),
u = θ,
objective = first(x),
grad = gₜ,
Expand Down

0 comments on commit 47a2481

Please sign in to comment.