Skip to content

Commit

Permalink
Fix base tests; rename gurobi to gurobi_persistent
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmundt committed Dec 16, 2024
1 parent 66b5b24 commit 15db95f
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 95 deletions.
66 changes: 49 additions & 17 deletions pyomo/contrib/solver/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class SolverBase:
- version: The version of the solver
- is_persistent: Set to false for all non-persistent solvers.
Additionally, solvers should have a :attr:`config<SolverBase.config>` attribute that
Additionally, solvers should have a :attr:`CONFIG<SolverBase.CONFIG>` attribute that
inherits from one of :class:`SolverConfig<pyomo.contrib.solver.config.SolverConfig>`,
:class:`BranchAndBoundConfig<pyomo.contrib.solver.config.BranchAndBoundConfig>`,
:class:`PersistentSolverConfig<pyomo.contrib.solver.config.PersistentSolverConfig>`, or
Expand Down Expand Up @@ -111,7 +111,9 @@ def solve(self, model: BlockData, **kwargs) -> Results:
results: :class:`Results<pyomo.contrib.solver.results.Results>`
A results object
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Every solver interface must provide their own implementation of `solve` which returns an appropriately populated `Results` object."
)

def available(self) -> Availability:
"""Test if the solver is available on this system.
Expand All @@ -121,7 +123,9 @@ def available(self) -> Availability:
available: Availability
An enum that indicates "how available" the solver is.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Every solver interface must provide their own implementation of `available` which returns an `Availability` object."
)

def version(self) -> Tuple:
"""
Expand All @@ -130,7 +134,9 @@ def version(self) -> Tuple:
version: tuple
A tuple representing the version
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Solvers must return the version number in a tuple form (e.g., `(1, 2, 1)`)."
)

def is_persistent(self) -> bool:
"""
Expand Down Expand Up @@ -175,7 +181,9 @@ def solve(self, model: BlockData, **kwargs) -> Results:
results: :class:`Results<pyomo.contrib.solver.results.Results>`
A results object
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Every solver interface must provide their own implementation of `solve` which returns an appropriately populated `Results` object."
)

def is_persistent(self) -> bool:
"""
Expand Down Expand Up @@ -262,73 +270,97 @@ def set_instance(self, model: BlockData):
"""
Set an instance of the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for setting an instance of the model."
)

def set_objective(self, obj: ObjectiveData):
"""
Set current objective for the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for setting the current objective of the model."
)

def add_variables(self, variables: List[VarData]):
"""
Add variables to the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for adding variables to the model."
)

def add_parameters(self, params: List[ParamData]):
"""
Add parameters to the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for adding parameters to the model."
)

def add_constraints(self, cons: List[ConstraintData]):
"""
Add constraints to the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for adding constraints to the model."
)

def add_block(self, block: BlockData):
"""
Add a block to the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for adding blocks to the model."
)

def remove_variables(self, variables: List[VarData]):
"""
Remove variables from the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for removing variables from the model."
)

def remove_parameters(self, params: List[ParamData]):
"""
Remove parameters from the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for removing parameters from the model."
)

def remove_constraints(self, cons: List[ConstraintData]):
"""
Remove constraints from the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for removing constraints from the model."
)

def remove_block(self, block: BlockData):
"""
Remove a block from the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for removing blocks from the model."
)

def update_variables(self, variables: List[VarData]):
"""
Update variables on the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for updating variables on the model."
)

def update_parameters(self):
"""
Update parameters on the model.
"""
raise NotImplementedError("Subclasses should implement this method.")
raise NotImplementedError(
"Persistent solvers should provide a mechanism for updating parameters on the model."
)


class LegacySolverWrapper:
Expand Down
38 changes: 0 additions & 38 deletions pyomo/contrib/solver/common/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import abc
from typing import List
import datetime

from pyomo.core.base.constraint import ConstraintData, Constraint
from pyomo.core.base.sos import SOSConstraintData, SOSConstraint
Expand All @@ -21,8 +20,6 @@
from pyomo.common.collections import ComponentMap
from pyomo.common.timing import HierarchicalTimer
from pyomo.contrib.solver.common.util import collect_vars_and_named_exprs, get_objective
from pyomo.core.staleflag import StaleFlagManager
from pyomo.contrib.solver.common.results import Results


class PersistentSolverUtils(abc.ABC):
Expand Down Expand Up @@ -498,38 +495,3 @@ def update(self, timer: HierarchicalTimer = None):
timer.start('vars')
self.remove_variables(old_vars)
timer.stop('vars')


class PersistentSolverMixin:
"""
Mixin class for common solver functionality
"""

def solve(self, model, **kwds) -> Results:
start_timestamp = datetime.datetime.now(datetime.timezone.utc)
self._active_config = config = self.config(value=kwds, preserve_implicit=True)
StaleFlagManager.mark_all_as_stale()
if self._last_results_object is not None:
self._last_results_object.solution_loader.invalidate()
if config.timer is None:
config.timer = HierarchicalTimer()
timer = config.timer
if model is not self._model:
timer.start('set_instance')
self.set_instance(model)
timer.stop('set_instance')
else:
timer.start('update')
self.update(timer=timer)
timer.stop('update')
res = self._solve()
self._last_results_object = res
end_timestamp = datetime.datetime.now(datetime.timezone.utc)
res.timing_info.start_timestamp = start_timestamp
res.timing_info.wall_time = (end_timestamp - start_timestamp).total_seconds()
res.timing_info.timer = timer
self._active_config = self.config
return res

def _solve(self):
raise NotImplementedError("Subclasses should implement this method")
4 changes: 2 additions & 2 deletions pyomo/contrib/solver/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .common.factory import SolverFactory
from .solvers.ipopt import Ipopt
from .solvers.gurobi import Gurobi
from .solvers.gurobi_persistent import Gurobi
from .solvers.gurobi_direct import GurobiDirect
from .solvers.highs import Highs

Expand All @@ -22,7 +22,7 @@ def load():
name='ipopt', legacy_name='ipopt_v2', doc='The IPOPT NLP solver'
)(Ipopt)
SolverFactory.register(
name='gurobi', legacy_name='gurobi_v2', doc='Persistent interface to Gurobi'
name='gurobi_persistent', legacy_name='gurobi_persistent_v2', doc='Persistent interface to Gurobi'
)(Gurobi)
SolverFactory.register(
name='gurobi_direct',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@
NoReducedCostsError,
NoSolutionError,
)
from pyomo.contrib.solver.common.persistent import (
PersistentSolverUtils,
PersistentSolverMixin,
)
from pyomo.contrib.solver.common.persistent import PersistentSolverUtils
from pyomo.contrib.solver.common.solution import PersistentSolutionLoader
from pyomo.core.staleflag import StaleFlagManager

Expand Down Expand Up @@ -250,7 +247,6 @@ def __init__(self, **kwds):
PersistentSolverUtils.__init__(
self, treat_fixed_vars_as_params=treat_fixed_vars_as_params
)
PersistentSolverMixin.__init__(self)
Gurobi._num_instances += 1
self._solver_model = None
self._symbol_map = SymbolMap()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pyomo.common.unittest as unittest
import pyomo.environ as pe
from pyomo.contrib.solver.solvers.gurobi import Gurobi
from pyomo.contrib.solver.solvers.gurobi_persistent import Gurobi
from pyomo.contrib.solver.common.results import SolutionStatus
from pyomo.core.expr.taylor_series import taylor_series_expansion

Expand Down
2 changes: 1 addition & 1 deletion pyomo/contrib/solver/tests/solvers/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from pyomo.contrib.solver.common.base import SolverBase
from pyomo.contrib.solver.solvers.ipopt import Ipopt
from pyomo.contrib.solver.solvers.gurobi import Gurobi
from pyomo.contrib.solver.solvers.gurobi_persistent import Gurobi
from pyomo.contrib.solver.solvers.gurobi_direct import GurobiDirect
from pyomo.contrib.solver.solvers.highs import Highs
from pyomo.core.expr.numeric_expr import LinearExpression
Expand Down
60 changes: 29 additions & 31 deletions pyomo/contrib/solver/tests/unit/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ def test_class_method_list(self):
def test_init(self):
self.instance = base.SolverBase()
self.assertFalse(self.instance.is_persistent())
self.assertEqual(self.instance.version(), None)
self.assertEqual(self.instance.name, 'solverbase')
self.assertEqual(self.instance.CONFIG, self.instance.config)
self.assertEqual(self.instance.solve(None), None)
self.assertEqual(self.instance.available(), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.version(), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.solve(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.available(), None)

def test_context_manager(self):
with base.SolverBase() as self.instance:
self.assertFalse(self.instance.is_persistent())
self.assertEqual(self.instance.version(), None)
self.assertEqual(self.instance.name, 'solverbase')
self.assertEqual(self.instance.CONFIG, self.instance.config)
self.assertEqual(self.instance.solve(None), None)
self.assertEqual(self.instance.available(), None)

def test_config_kwds(self):
self.instance = base.SolverBase(tee=True)
Expand Down Expand Up @@ -90,43 +90,40 @@ def test_class_method_list(self):
def test_init(self):
self.instance = base.PersistentSolverBase()
self.assertTrue(self.instance.is_persistent())
self.assertEqual(self.instance.set_instance(None), None)
self.assertEqual(self.instance.add_variables(None), None)
self.assertEqual(self.instance.add_parameters(None), None)
self.assertEqual(self.instance.add_constraints(None), None)
self.assertEqual(self.instance.add_block(None), None)
self.assertEqual(self.instance.remove_variables(None), None)
self.assertEqual(self.instance.remove_parameters(None), None)
self.assertEqual(self.instance.remove_constraints(None), None)
self.assertEqual(self.instance.remove_block(None), None)
self.assertEqual(self.instance.set_objective(None), None)
self.assertEqual(self.instance.update_variables(None), None)
self.assertEqual(self.instance.update_parameters(), None)

with self.assertRaises(NotImplementedError):
self.instance._get_primals()

with self.assertRaises(NotImplementedError):
self.instance._get_duals()

with self.assertRaises(NotImplementedError):
self.instance._get_reduced_costs()

def test_context_manager(self):
with base.PersistentSolverBase() as self.instance:
self.assertTrue(self.instance.is_persistent())
self.assertEqual(self.instance.set_instance(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.add_variables(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.add_parameters(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.add_constraints(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.add_block(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.remove_variables(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.remove_parameters(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.remove_constraints(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.remove_block(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.set_objective(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.update_variables(None), None)
with self.assertRaises(NotImplementedError):
self.assertEqual(self.instance.update_parameters(), None)
with self.assertRaises(NotImplementedError):
self.instance._get_primals()
with self.assertRaises(NotImplementedError):
self.instance._get_duals()
with self.assertRaises(NotImplementedError):
self.instance._get_reduced_costs()

def test_context_manager(self):
with base.PersistentSolverBase() as self.instance:
self.assertTrue(self.instance.is_persistent())


class TestLegacySolverWrapper(unittest.TestCase):
Expand All @@ -148,7 +145,8 @@ def test_class_method_list(self):
def test_context_manager(self):
with _LegacyWrappedSolverBase() as instance:
self.assertIsInstance(instance, _LegacyWrappedSolverBase)
self.assertFalse(instance.available(False))
with self.assertRaises(NotImplementedError):
self.assertFalse(instance.available(False))

def test_map_config(self):
# Create a fake/empty config structure that can be added to an empty
Expand Down

0 comments on commit 15db95f

Please sign in to comment.