Skip to content

Commit

Permalink
Update broken test for OptimizationManopt
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jun 13, 2024
1 parent a0dfc18 commit 8420994
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
10 changes: 8 additions & 2 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
local x, cur, state

manifold = haskey(cache.solver_args, :manifold) ? cache.solver_args[:manifold] : nothing
gradF = haskey(cache.solver_args, :riemannian_grad) ? cache.solver_args[:riemannian_grad] : nothing
hessF = haskey(cache.solver_args, :riemannian_hess) ? cache.solver_args[:riemannian_hess] : nothing

if manifold === nothing
throw(ArgumentError("Manifold not specified in the problem for e.g. `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))`."))
Expand Down Expand Up @@ -433,9 +435,13 @@ function SciMLBase.__solve(cache::OptimizationCache{

_loss = build_loss(cache.f, cache, _cb)

gradF = build_gradF(cache.f, cur)
if gradF === nothing
gradF = build_gradF(cache.f, cur)
end

hessF = build_hessF(cache.f, cur)
if hessF === nothing
hessF = build_hessF(cache.f, cur)
end

if haskey(solver_kwarg, :stopping_criterion)
stopping_criterion = Manopt.StopWhenAny(solver_kwarg.stopping_criterion...)
Expand Down
5 changes: 2 additions & 3 deletions lib/OptimizationManopt/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ end
optprob = OptimizationFunction(rosenbrock, AutoForwardDiff())
prob = OptimizationProblem(optprob, x0, p; manifold = R2)

@test_broken Optimization.solve(prob, opt)
@test_broken sol.minimum < 0.1
sol = Optimization.solve(prob, opt)
@test sol.minimum < 0.1
end

@testset "TrustRegions" begin
Expand Down Expand Up @@ -207,7 +207,6 @@ end
q = Matrix{Float64}(I, 5, 5) .+ 2.0
data2 = [exp(M, q, σ * rand(M; vector_at = q)) for i in 1:m]

f(M, x, p = nothing) = sum(distance(M, x, data2[i])^2 for i in 1:m)
f(x, p = nothing) = sum(distance(M, x, data2[i])^2 for i in 1:m)

optf = OptimizationFunction(f, Optimization.AutoFiniteDiff())
Expand Down

0 comments on commit 8420994

Please sign in to comment.