Skip to content

Commit

Permalink
ENH: minimize_ipopt: add callback support
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Sep 25, 2023
1 parent f68a639 commit b06a95c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
34 changes: 30 additions & 4 deletions cyipopt/scipy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import cyipopt


class IpoptProblemWrapper(object):
class IpoptProblemWrapper(cyipopt.Problem):
"""Class used to map a scipy minimize definition to a cyipopt problem.
Parameters
Expand Down Expand Up @@ -73,6 +73,16 @@ class IpoptProblemWrapper(object):
True, the constraint function `fun` is expected to return a tuple
`(con_val, con_jac)` consisting of the evaluated constraint `con_val`
and the evaluated jacobian `con_jac`.
callback : callable, optional
A callable with signature
``callback(OptimizeResult: intermediate_result)``
that is called once per iteration,
where ``intermediate_result`` is a keyword parameter containing an
:py:class:`scipy.optimize.OptimizeResult` with attributes ``x`` and
``fun``, the present values of the parameter vector and objective
function.
eps : float, optional
Epsilon used in finite differences.
con_dims : array_like, optional
Expand All @@ -97,6 +107,7 @@ def __init__(self,
hess=None,
hessp=None,
constraints=(),
callback=None,
eps=1e-8,
con_dims=(),
sparse_jacs=(),
Expand Down Expand Up @@ -150,6 +161,7 @@ def wrapped_fun(x):
self._constraint_kwargs = []
self._constraint_jac_is_sparse = sparse_jacs
self._constraint_jacobian_structure = (jac_nnz_row, jac_nnz_col)
self.callback = callback
if isinstance(constraints, dict):
constraints = (constraints, )
for con in constraints:
Expand Down Expand Up @@ -239,6 +251,12 @@ def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
ls_trials):

self.nit = iter_count
if self.callback is not None:
iterate = self.get_current_iterate()
res = OptimizeResult()
res.x = iterate["x"]
res.fun = obj_value
self.callback(intermediate_result=res)


def get_bounds(bounds):
Expand Down Expand Up @@ -465,6 +483,16 @@ def minimize_ipopt(fun,
constraint function ``fun`` must return a tuple ``(con_val, con_jac)``
consisting of the evaluated constraint ``con_val`` and the evaluated
Jacobian ``con_jac``.
callback : callable, optional
A callable with signature
``callback(OptimizeResult: intermediate_result)``
that is called once per iteration,
where ``intermediate_result`` is a keyword parameter containing an
:py:class:`scipy.optimize.OptimizeResult` with attributes ``x`` and
``fun``, the present values of the parameter vector and objective
function.
tol : float, optional (default=1e-8)
The desired relative convergence tolerance, passed as an option to
Ipopt. See [1]_ for details.
Expand Down Expand Up @@ -559,6 +587,7 @@ def minimize_ipopt(fun,
hess=hess,
hessp=hessp,
constraints=constraints,
callback=callback,
eps=1e-8,
con_dims=con_dims,
sparse_jacs=sparse_jacs,
Expand Down Expand Up @@ -661,9 +690,6 @@ def _minimize_ipopt_iv(fun, x0, args, kwargs, method, jac, hess, hessp,
constraints = optimize._minimize.standardize_constraints(constraints, x0,
'old')

if method is None and callback is not None:
raise NotImplementedError('`callback` is not yet supported by Ipopt.`')

if tol is not None:
tol = np.asarray(tol)[()]
if tol.ndim != 0 or not np.issubdtype(tol.dtype, np.number) or tol <= 0:
Expand Down
17 changes: 13 additions & 4 deletions cyipopt/tests/unit/test_scipy_optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ def f(x):
with pytest.raises(ValueError, match=message):
cyipopt.minimize_ipopt(f, x0, bounds=[['low', 'high']])

message = "`callback` is not yet supported by Ipopt."
with pytest.raises(NotImplementedError, match=message):
cyipopt.minimize_ipopt(f, x0, callback='a duck')

message = "`tol` must be a positive scalar."
with pytest.raises(ValueError, match=message):
cyipopt.minimize_ipopt(f, x0, tol=[1, 2, 3])
Expand Down Expand Up @@ -575,3 +571,16 @@ def test_minimize_late_binding_bug():
assert res.success
np.testing.assert_allclose(res.x, ref.x)
np.testing.assert_allclose(res.fun, ref.fun)

@pytest.mark.skipif("scipy" not in sys.modules,
reason="Test only valid if Scipy available.")
def test_minimize_callback():
# Test that `callback` works
def f(x):
return x @ x

def callback(intermediate_result):
assert intermediate_result.fun == f(intermediate_result.x)
assert False

cyipopt.minimize_ipopt(f, [1, 2, 3], callback=callback)

0 comments on commit b06a95c

Please sign in to comment.