From a155ebb4b910c73cf84d7fb33af791d7732721da Mon Sep 17 00:00:00 2001
From: Nicola Di Cicco <93935338+nicoladicicco@users.noreply.github.com>
Date: Fri, 19 Jul 2024 10:36:11 +0200
Subject: [PATCH] Use coordinate descent for continuous variables

---
Review note: the post-image blob ids on the `index` lines below were not
regenerated after the review fixes were folded into this patch.

 Project.toml       |  3 ++-
 src/solver.jl      | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 test/raw_solver.jl | 55 ++++++++++++++++++++++++++++++++++
 3 files changed, 152 insertions(+), 4 deletions(-)

diff --git a/Project.toml b/Project.toml
index 8cf79d7..16fe459 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,8 +30,9 @@ julia = "1.6"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 
 [targets]
-test = ["Aqua", "Test", "TestItemRunner"]
+test = ["Aqua", "Intervals", "Test", "TestItemRunner"]
diff --git a/src/solver.jl b/src/solver.jl
index 7aa2ea4..b020818 100644
--- a/src/solver.jl
+++ b/src/solver.jl
@@ -336,6 +336,88 @@ function _move!(s, x::Int, dim::Int = 0)
     return best_values, best_swap, tabu
 end
 
+"""
+    armijo_line_search(f, x, d, fx; α0 = 1.0, β = 0.5, c = 1e-4,
+                       slope = -sum(abs2, d), maxiter = 64)
+
+Return a step size `α` along direction `d` that satisfies the Armijo (sufficient
+decrease) condition `f(x + α * d) ≤ fx + c * α * slope`, backtracking from `α0`.
+
+# Arguments
+- `f`: a function to minimize
+- `x`: current point (value of the selected variable)
+- `d`: descent direction (e.g., negative gradient)
+- `fx`: value of f at `x`
+
+# Keywords
+- `α0`: initial step size
+- `β`: step size reduction factor
+- `c`: Armijo condition constant
+- `slope`: directional derivative of `f` at `x` along `d`; the default assumes that
+  `d` is the negative gradient
+- `maxiter`: maximum number of step sizes tried
+
+# Returns
+The first step size that satisfies the condition, or a zero step size if none of the
+`maxiter` candidates does.
+"""
+function armijo_line_search(
+    f, x, d, fx; α0 = 1.0, β = 0.5, c = 1e-4, slope = -sum(abs2, d), maxiter = 64
+)
+    α = α0
+    for _ in 1:maxiter
+        # Sufficient decrease relative to the first-order prediction along `d`
+        f(x + α * d) ≤ fx + c * α * slope && return α
+        α *= β
+    end
+    return zero(α)
+end
+
+"""
+    _coordinate_descent_move!(s, x; h = 1e-6)
+
+Run one iteration of coordinate descent over the axis of variable `x`.
+
+The derivative of the error is estimated by a central finite difference of half-width
+`h` (probe points are clamped to the domain of `x`), and the step size is determined
+via the Armijo condition for line search. The new value is kept only if it strictly
+lowers the error; otherwise the solver is restored to the previous value of `x`.
+
+Return `true` if the value of `x` was updated, `false` otherwise.
+"""
+function _coordinate_descent_move!(s, x; h = 1e-6)
+    domain = get_variable(s, x).domain
+    current_value = _value(s, x)
+
+    # Error of the solver once variable `x` is set to `val`
+    function f(val)
+        _value!(s, x, val)
+        _compute!(s)
+        return get_error(s)
+    end
+    # Projection onto the domain, so that no candidate is evaluated outside of it
+    project(val) = clamp(val, domain.lb, domain.ub)
+
+    current_error = f(current_value)
+
+    # Central difference over probe points kept inside the domain
+    lo, hi = project(current_value - h), project(current_value + h)
+    hi > lo || return false # degenerate domain, nowhere to move
+    grad = (f(hi) - f(lo)) / (hi - lo)
+    if iszero(grad) || !isfinite(grad)
+        f(current_value) # undo the probe evaluations
+        return false
+    end
+
+    α = armijo_line_search(f ∘ project, current_value, -grad, current_error)
+    new_value = project(current_value - α * grad)
+    improved = f(new_value) < current_error
+
+    # `f` mutates the solver: roll back when the move does not lower the error
+    improved || f(current_value)
+    return improved
+end
+
 """
     _step!(s)
 
@@ -345,10 +427,20 @@ function _step!(s)
     # select worst variables
     x = _select_worse(s)
     _verbose(s, "Selected x = $x")
 
-    # Local move (change the value of the selected variable)
-    best_values, best_swap, tabu = _move!(s, x)
-    # _compute!(s)
+    if _value(s, x) isa Integer
+        # Local move (change the value of the selected variable)
+        best_values, best_swap, tabu = _move!(s, x)
+        # _compute!(s)
+    else
+        # We perform coordinate descent over the variable axis
+        _coordinate_descent_move!(s, x)
+        # `tabu` is read unconditionally below, so it needs a value on this path too;
+        # coordinate descent does not use the tabu list.
+        # NOTE(review): `best_values`/`best_swap` stay unassigned on this path; confirm
+        # nothing further down `_step!` reads them when `tabu` is false.
+        tabu = false
+    end
 
     # If local move is bad (tabu), then try permutation
     if tabu
diff --git a/test/raw_solver.jl b/test/raw_solver.jl
index d20658b..d7e4d63 100644
--- a/test/raw_solver.jl
+++ b/test/raw_solver.jl
@@ -83,6 +83,50 @@ function sudoku(n; start = nothing)
     return m
 end
 
+"""
+    chemical_equilibrium(A, B, C)
+
+Build an `:equilibrium` model with one variable per entry of `C` (number of moles per
+compound), one mass conservation constraint per entry of `B`, and the free energy of
+the mixture as objective.
+"""
+function chemical_equilibrium(A, B, C)
+    m = model(; kind = :equilibrium)
+
+    N = length(C)
+    M = length(B)
+
+    d = domain(0..maximum(B))
+
+    # Add variables, number of moles per compound
+    foreach(_ -> variable!(m, d), 1:N)
+
+    # Mass conservation of quantity `i`: zero within a 1e-6 tolerance, else the residual
+    # NOTE(review): column `i` of `A` is paired with `x`; confirm the orientation of `A`
+    conserve = i -> (x -> begin
+        δ = abs(sum(A[:, i] .* x) - B[i])
+        return δ ≤ 1.0e-6 ? 0.0 : δ
+    end)
+
+    # Add constraints
+    for i in 1:M
+        constraint!(m, conserve(i), 1:N)
+    end
+
+    # computes the total energy freed by the reaction
+    free_energy = x -> begin
+        total = sum(x)
+        return sum(eachindex(x)) do j
+            # x⋅log(x) → 0 as x → 0, so an absent compound contributes no energy
+            iszero(x[j]) ? zero(total) : x[j] * (C[j] + log(x[j] / total))
+        end
+    end
+
+    objective!(m, free_energy)
+
+    return m
+end
+
 @testset "Raw solver: internals" begin
     models = [
         sudoku(2)
@@ -203,3 +247,14 @@ end
     @info "Sol (vals): $(!isnothing(best_value(s)) ? best_values(s) : nothing)"
     @info time_info(s)
 end
+
+@testset "Raw solver: chemical equilibrium" begin
+    A = [2.0 1.0 0.0; 6.0 2.0 1.0; 1.0 2.0 4.0]
+    B = [20.0, 30.0, 25.0]
+    C = [-10.0, -8.0, -6.0]
+    m = chemical_equilibrium(A, B, C)
+    s = solver(m; options = Options(print_level = :minimal))
+    solve!(s)
+    display(solution(s))
+    display(s.time_stamps)
+end