diff --git a/lib/OptimizationMOI/src/nlp.jl b/lib/OptimizationMOI/src/nlp.jl index dbfb80089..b745c0aa0 100644 --- a/lib/OptimizationMOI/src/nlp.jl +++ b/lib/OptimizationMOI/src/nlp.jl @@ -375,7 +375,7 @@ function MOI.eval_hessian_lagrangian(evaluator::MOIOptimizationNLPEvaluator{T}, σ, μ) where {T} if evaluator.f.lag_h !== nothing - evaluator.f.lag_h(h, x, σ, μ) + evaluator.f.lag_h(h, x, σ, Vector(μ)) return end if evaluator.f.hess === nothing diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 99743d24d..f4db587b0 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -42,27 +42,27 @@ function SciMLBase.__solve(cache::OptimizationCache{ P, C } - maxiters = if cache.solver_args.epochs === nothing + if OptimizationBase.isa_dataiterator(cache.p) + data = cache.p + dataiterate = true + else + data = [cache.p] + dataiterate = false + end + + epochs = if cache.solver_args.epochs === nothing if cache.solver_args.maxiters === nothing - throw(ArgumentError("The number of epochs must be specified with either the epochs or maxiters kwarg.")) + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) else - cache.solver_args.maxiters + cache.solver_args.maxiters / length(data) end else cache.solver_args.epochs end - maxiters = Optimization._check_and_convert_maxiters(maxiters) - if maxiters === nothing - throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg.")) - end - - if OptimizationBase.isa_dataiterator(cache.p) - data = cache.p - dataiterate = true - else - data = [cache.p] - dataiterate = false + epochs = Optimization._check_and_convert_maxiters(epochs) + if epochs === nothing + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) end opt = cache.opt @@ -75,21 +75,35 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_θ = cache.u0 state = Optimisers.setup(opt, θ) - + iterations = 0 + fevals = 0 + gevals = 0 t0 = time() Optimization.@withprogress cache.progress name="Training" begin - for epoch in 1:maxiters + for epoch in 1:epochs for (i, d) in enumerate(data) if cache.f.fg !== nothing && dataiterate x = cache.f.fg(G, θ, d) + iterations += 1 + fevals += 1 + gevals += 1 elseif dataiterate cache.f.grad(G, θ, d) x = cache.f(θ, d) + iterations += 1 + fevals += 2 + gevals += 1 elseif cache.f.fg !== nothing x = cache.f.fg(G, θ) + iterations += 1 + fevals += 1 + gevals += 1 else cache.f.grad(G, θ) x = cache.f(θ) + iterations += 1 + fevals += 2 + gevals += 1 end opt_state = Optimization.OptimizationState( iter = i + (epoch - 1) * length(data), @@ -112,7 +126,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_err = x min_θ = copy(θ) end - if i == maxiters #Last iter, revert to best. + if i == length(data) #Last iter, revert to best. opt = min_opt x = min_err θ = min_θ @@ -132,10 +146,9 @@ function SciMLBase.__solve(cache::OptimizationCache{ end t1 = time() - stats = Optimization.OptimizationStats(; iterations = maxiters, - time = t1 - t0, fevals = maxiters, gevals = maxiters) + stats = Optimization.OptimizationStats(; iterations, + time = t1 - t0, fevals, gevals) SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], stats = stats) - # here should be build_solution to create the output message end end diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 12b6f2754..4728cbf25 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -27,6 +27,9 @@ using Zygote sol = solve(prob, Optimisers.Adam(), maxiters = 1000) @test 10 * sol.objective < l1 + @test sol.stats.iterations == 1000 + @test sol.stats.fevals == 1000 + @test sol.stats.gevals == 1000 @testset "cache" begin objective(x, p) = (p[1] - x[1])^2 @@ -99,6 +102,10 @@ end res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) @test res.objective < 1e-4 + @test res.stats.iterations == 10000*length(data) + @test res.stats.fevals == 10000*length(data) + @test res.stats.gevals == 10000*length(data) + using MLDataDevices data = CPUDevice()(data) diff --git a/src/sophia.jl b/src/sophia.jl index b63f0c099..9f4d973e9 100644 --- a/src/sophia.jl +++ b/src/sophia.jl @@ -88,7 +88,8 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.grad(gₜ, θ) x = cache.f(θ) end - opt_state = Optimization.OptimizationState(; iter = i + (epoch - 1) * length(data), + opt_state = Optimization.OptimizationState(; + iter = i + (epoch - 1) * length(data), u = θ, objective = first(x), grad = gₜ,