Skip to content

Commit

Permalink
Reduce duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmundt committed Dec 16, 2024
1 parent ae8a685 commit c5d2cbc
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 205 deletions.
40 changes: 40 additions & 0 deletions pyomo/contrib/solver/common/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@
# __________________________________________________________________________

import abc
import datetime
from typing import List

from pyomo.core.base.constraint import ConstraintData, Constraint
from pyomo.core.base.sos import SOSConstraintData, SOSConstraint
from pyomo.core.base.var import VarData
from pyomo.core.base.param import ParamData, Param
from pyomo.core.base.objective import ObjectiveData
from pyomo.core.staleflag import StaleFlagManager
from pyomo.common.collections import ComponentMap
from pyomo.common.timing import HierarchicalTimer
from pyomo.contrib.solver.common.results import Results
from pyomo.contrib.solver.common.util import collect_vars_and_named_exprs, get_objective


Expand Down Expand Up @@ -495,3 +498,40 @@ def update(self, timer: HierarchicalTimer = None):
timer.start('vars')
self.remove_variables(old_vars)
timer.stop('vars')


class PersistentSolverMixin:
"""
The `solve` method in Gurobi and Highs is exactly the same, so this Mixin
minimizes the duplicate code
"""
def solve(self, model, **kwds) -> Results:
start_timestamp = datetime.datetime.now(datetime.timezone.utc)
self._active_config = config = self.config(value=kwds, preserve_implicit=True)
StaleFlagManager.mark_all_as_stale()

if self._last_results_object is not None:
self._last_results_object.solution_loader.invalidate()
if config.timer is None:
config.timer = HierarchicalTimer()
timer = config.timer

if model is not self._model:
timer.start('set_instance')
self.set_instance(model)
timer.stop('set_instance')
else:
timer.start('update')
self.update(timer=timer)
timer.stop('update')

res = self._solve()
self._last_results_object = res

end_timestamp = datetime.datetime.now(datetime.timezone.utc)
res.timing_info.start_timestamp = start_timestamp
res.timing_info.wall_time = (end_timestamp - start_timestamp).total_seconds()
res.timing_info.timer = timer
self._active_config = self.config

return res
12 changes: 10 additions & 2 deletions pyomo/contrib/solver/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,23 @@ class NoDualsError(PyomoException):
def __init__(self):
super().__init__(
'Solver does not currently have valid duals. Please '
'check the termination condition.'
'check results.termination_condition and/or results.solution_status.'
)


class NoReducedCostsError(PyomoException):
def __init__(self):
super().__init__(
'Solver does not currently have valid reduced costs. Please '
'check the termination condition.'
'check results.termination_condition and/or results.solution_status.'
)


class IncompatibleModelError(PyomoException):
def __init__(self):
super().__init__(
'Model is not compatible with the chosen solver. Please check '
'the model and solver.'
)


Expand Down
4 changes: 2 additions & 2 deletions pyomo/contrib/solver/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .common.factory import SolverFactory
from .solvers.ipopt import Ipopt
from .solvers.gurobi_persistent import Gurobi
from .solvers.gurobi_persistent import GurobiPersistent
from .solvers.gurobi_direct import GurobiDirect
from .solvers.highs import Highs

Expand All @@ -25,7 +25,7 @@ def load():
name='gurobi_persistent',
legacy_name='gurobi_persistent_v2',
doc='Persistent interface to Gurobi',
)(Gurobi)
)(GurobiPersistent)
SolverFactory.register(
name='gurobi_direct',
legacy_name='gurobi_direct_v2',
Expand Down
69 changes: 43 additions & 26 deletions pyomo/contrib/solver/solvers/gurobi_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pyomo.common.config import ConfigValue
from pyomo.common.dependencies import attempt_import
from pyomo.common.enums import ObjectiveSense
from pyomo.common.errors import MouseTrap
from pyomo.common.errors import MouseTrap, ApplicationError
from pyomo.common.shutdown import python_is_shutting_down
from pyomo.common.tee import capture_output, TeeStream
from pyomo.common.timing import HierarchicalTimer
Expand All @@ -34,6 +34,7 @@
NoDualsError,
NoReducedCostsError,
NoSolutionError,
IncompatibleModelError,
)
from pyomo.contrib.solver.common.results import (
Results,
Expand Down Expand Up @@ -168,17 +169,12 @@ def get_reduced_costs(self, vars_to_load=None):
return ComponentMap(iterator)


class GurobiDirect(SolverBase):
CONFIG = GurobiConfig()

_available = None
_num_instances = 0
_tc_map = None

def __init__(self, **kwds):
super().__init__(**kwds)
GurobiDirect._num_instances += 1

class GurobiSolverMixin:
"""
gurobi_direct and gurobi_persistent check availability and set versions
in the same way. This moves the logic to a central location to reduce
duplicate code.
"""
def available(self):
if not gurobipy_available: # this triggers the deferred import
return Availability.NotFound
Expand All @@ -199,7 +195,7 @@ def _check_license(self):

if avail:
if self._available is None:
self._available = GurobiDirect._check_full_license(m)
self._available = self._check_full_license(m)
return self._available
return Availability.BadLicense

Expand All @@ -215,18 +211,6 @@ def _check_full_license(cls, model=None):
except gurobipy.GurobiError:
return Availability.LimitedLicense

def __del__(self):
if not python_is_shutting_down():
GurobiDirect._num_instances -= 1
if GurobiDirect._num_instances == 0:
self.release_license()

@staticmethod
def release_license():
if gurobipy_available:
with capture_output(capture_fd=True):
gurobipy.disposeDefaultEnv()

def version(self):
version = (
gurobipy.GRB.VERSION_MAJOR,
Expand All @@ -235,9 +219,42 @@ def version(self):
)
return version


class GurobiDirect(GurobiSolverMixin, SolverBase):
"""
Interface to Gurobi direct (not persistent)
"""
CONFIG = GurobiConfig()

_available = None
_num_instances = 0
_tc_map = None

def __init__(self, **kwds):
super().__init__(**kwds)
GurobiDirect._num_instances += 1

@staticmethod
def release_license():
if gurobipy_available:
with capture_output(capture_fd=True):
gurobipy.disposeDefaultEnv()

def __del__(self):
if not python_is_shutting_down():
GurobiDirect._num_instances -= 1
if GurobiDirect._num_instances == 0:
self.release_license()

def solve(self, model, **kwds) -> Results:
start_timestamp = datetime.datetime.now(datetime.timezone.utc)
config = self.config(value=kwds, preserve_implicit=True)
if not self.available():
c = self.__class__
raise ApplicationError(
f'Solver {c.__module__}.{c.__qualname__} is not available '
f'({self.available()}).'
)
if config.timer is None:
config.timer = HierarchicalTimer()
timer = config.timer
Expand All @@ -251,7 +268,7 @@ def solve(self, model, **kwds) -> Results:
timer.stop('compile_model')

if len(repn.objectives) > 1:
raise ValueError(
raise IncompatibleModelError(
f"The {self.__class__.__name__} solver only supports models "
f"with zero or one objectives (received {len(repn.objectives)})."
)
Expand Down
Loading

0 comments on commit c5d2cbc

Please sign in to comment.