diff --git a/mpisppy/cylinders/xhatbase.py b/mpisppy/cylinders/xhatbase.py new file mode 100644 index 00000000..8012cfc1 --- /dev/null +++ b/mpisppy/cylinders/xhatbase.py @@ -0,0 +1,54 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### + +import abc +import mpisppy.cylinders.spoke as spoke +from mpisppy.utils.xhat_eval import Xhat_Eval + +class XhatInnerBoundBase(spoke.InnerBoundNonantSpoke): + + @abc.abstractmethod + def xhat_extension(self): + raise NotImplementedError + + + def xhat_prep(self): + if "bundles_per_rank" in self.opt.options\ + and self.opt.options["bundles_per_rank"] != 0: + raise RuntimeError("xhat spokes cannot have bundles (yet)") + + ## for later + self.verbose = self.opt.options["verbose"] # typing aid + + if not isinstance(self.opt, Xhat_Eval): + raise RuntimeError(f"{self.__class__.__name__} must be used with Xhat_Eval.") + + xhatter = self.xhat_extension() + + ### begin iter0 stuff + xhatter.pre_iter0() + if self.opt.extensions is not None: + self.opt.extobject.pre_iter0() # for an extension + self.opt._save_original_nonants() + + self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers + + self.opt._update_E1() + if abs(1 - self.opt.E1) > self.opt.E1_tolerance: + raise ValueError(f"Total probability of scenarios was {self.opt.E1} "+\ + f"(E1_tolerance is {self.opt.E1_tolerance})") + ### end iter0 stuff (but note: no need for iter 0 solves in an xhatter) + + xhatter.post_iter0() + if self.opt.extensions is not None: + self.opt.extobject.post_iter0() # for an extension + + self.opt._save_nonants() # make the cache + + return xhatter diff --git a/mpisppy/cylinders/xhatlooper_bounder.py b/mpisppy/cylinders/xhatlooper_bounder.py index 6cd38695..d82c462a 100644 --- a/mpisppy/cylinders/xhatlooper_bounder.py +++ b/mpisppy/cylinders/xhatlooper_bounder.py @@ -7,9 +7,8 @@ # full copyright and license information. ############################################################################### # updated April 2020 -import mpisppy.cylinders.spoke as spoke from mpisppy.extensions.xhatlooper import XhatLooper -from mpisppy.utils.xhat_eval import Xhat_Eval +from mpisppy.cylinders.xhatbase import XhatInnerBoundBase import logging import mpisppy.log @@ -20,44 +19,17 @@ logger = logging.getLogger("mpisppy.cylinders.xhatlooper_bounder") -class XhatLooperInnerBound(spoke.InnerBoundNonantSpoke): +class XhatLooperInnerBound(XhatInnerBoundBase): converger_spoke_char = 'X' - def xhatlooper_prep(self): - if "bundles_per_rank" in self.opt.options\ - and self.opt.options["bundles_per_rank"] != 0: - raise RuntimeError("xhat spokes cannot have bundles (yet)") - - if not isinstance(self.opt, Xhat_Eval): - raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.") - - xhatter = XhatLooper(self.opt) - - ### begin iter0 stuff - xhatter.pre_iter0() - self.opt._save_original_nonants() - - self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers - - self.opt._update_E1() - if abs(1 - self.opt.E1) > self.opt.E1_tolerance: - if self.opt.cylinder_rank == 0: - print("ERROR") - print("Total probability of scenarios was ", self.opt.E1) - print("E1_tolerance = ", self.opt.E1_tolerance) - quit() - ### end iter0 stuff - - xhatter.post_iter0() - self.opt._save_nonants() # make the cache - - return xhatter + def xhat_extension(self): + return XhatLooper(self.opt) def main(self): logger.debug(f"Entering main on xhatlooper spoke rank {self.global_rank}") - xhatter = self.xhatlooper_prep() + xhatter = self.xhat_prep() scen_limit = self.opt.options['xhat_looper_options']['scen_limit'] diff --git a/mpisppy/cylinders/xhatshufflelooper_bounder.py b/mpisppy/cylinders/xhatshufflelooper_bounder.py index 44a8d00b..057346a0 100644 --- a/mpisppy/cylinders/xhatshufflelooper_bounder.py +++ b/mpisppy/cylinders/xhatshufflelooper_bounder.py @@ -9,10 +9,9 @@ import logging import random import mpisppy.log -import mpisppy.cylinders.spoke as spoke -from mpisppy.utils.xhat_eval import Xhat_Eval from mpisppy.extensions.xhatbase import XhatBase +from mpisppy.cylinders.xhatbase import XhatInnerBoundBase # Could also pass, e.g., sys.stdout instead of a filename mpisppy.log.setup_logger("mpisppy.cylinders.xhatshufflelooper_bounder", @@ -20,41 +19,15 @@ level=logging.CRITICAL) logger = logging.getLogger("mpisppy.cylinders.xhatshufflelooper_bounder") -class XhatShuffleInnerBound(spoke.InnerBoundNonantSpoke): +class XhatShuffleInnerBound(XhatInnerBoundBase): converger_spoke_char = 'X' - def xhatbase_prep(self): + def xhat_extension(self): + return XhatBase(self.opt) - if "bundles_per_rank" in self.opt.options\ - and self.opt.options["bundles_per_rank"] != 0: - raise RuntimeError("xhat spokes cannot have bundles (yet)") - - ## for later - self.verbose = self.opt.options["verbose"] # typing aid - self.solver_options = self.opt.options["xhat_looper_options"]["xhat_solver_options"] - - if not isinstance(self.opt, Xhat_Eval): - raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.") - - xhatter = XhatBase(self.opt) - self.xhatter = xhatter - - ### begin iter0 stuff - xhatter.pre_iter0() # for an extension - self.opt._save_original_nonants() - - self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers - - self.opt._update_E1() - if abs(1 - self.opt.E1) > self.opt.E1_tolerance: - raise ValueError(f"Total probability of scenarios was {self.opt.E1} "+\ - f"(E1_tolerance is {self.opt.E1_tolerance})") - ### end iter0 stuff (but note: no need for iter 0 solves in an xhatter) - - xhatter.post_iter0() - - self.opt._save_nonants() # make the cache + def xhat_prep(self): + self.xhatter = super().xhat_prep() ## option drive this? (could be dangerous) self.random_seed = 42 @@ -90,7 +63,7 @@ def _vb(msg): def main(self): logger.debug(f"Entering main on xhatshuffle spoke rank {self.global_rank}") - self.xhatbase_prep() + self.xhat_prep() if "reverse" in self.opt.options["xhat_looper_options"]: self.reverse = self.opt.options["xhat_looper_options"]["reverse"] else: @@ -99,6 +72,7 @@ def main(self): self.iter_step = self.opt.options["xhat_looper_options"]["iter_step"] else: self.iter_step = None + self.solver_options = self.opt.options["xhat_looper_options"]["xhat_solver_options"] # give all ranks the same seed self.random_stream.seed(self.random_seed) diff --git a/mpisppy/cylinders/xhatspecific_bounder.py b/mpisppy/cylinders/xhatspecific_bounder.py index 0a921724..7fed91d7 100644 --- a/mpisppy/cylinders/xhatspecific_bounder.py +++ b/mpisppy/cylinders/xhatspecific_bounder.py @@ -9,9 +9,8 @@ # udpated April 20 # specific xhat supplied (copied from xhatlooper_bounder by DLW, Dec 2019) -import mpisppy.cylinders.spoke as spoke from mpisppy.extensions.xhatspecific import XhatSpecific -from mpisppy.utils.xhat_eval import Xhat_Eval +from mpisppy.cylinders.xhatbase import XhatInnerBoundBase import mpisppy.MPI as mpi import logging @@ -22,47 +21,12 @@ ############################################################################ -class XhatSpecificInnerBound(spoke.InnerBoundNonantSpoke): +class XhatSpecificInnerBound(XhatInnerBoundBase): converger_spoke_char = 'S' - def ib_prep(self): - """ - Set up the objects needed for bounding. - - Returns: - xhatter (xhatspecific object): Constructed by a call to Prep - """ - if "bundles_per_rank" in self.opt.options\ - and self.opt.options["bundles_per_rank"] != 0: - raise RuntimeError("xhat spokes cannot have bundles (yet)") - - if not isinstance(self.opt, Xhat_Eval): - raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.") - - xhatter = XhatSpecific(self.opt) - # somehow deal with the prox option .... TBD .... important for aph APH - - # begin iter0 stuff - xhatter.pre_iter0() - self.opt._save_original_nonants() - - self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers - - self.opt._update_E1() - if (abs(1 - self.opt.E1) > self.opt.E1_tolerance): - if self.opt.cylinder_rank == 0: - print("ERROR") - print("Total probability of scenarios was ", self.opt.E1) - print("E1_tolerance = ", self.opt.E1_tolerance) - quit() - - ### end iter0 stuff - - xhatter.post_iter0() - self.opt._save_nonants() # make the cache - - return xhatter + def xhat_extension(self): + return XhatSpecific(self.opt) def main(self): """ @@ -76,7 +40,7 @@ def main(self): xhat_scenario_dict = self.opt.options["xhat_specific_options"]\ ["xhat_scenario_dict"] - xhatter = self.ib_prep() + xhatter = self.xhat_prep() ib_iter = 1 # ib is for inner bound while (not self.got_kill_signal()): diff --git a/mpisppy/cylinders/xhatxbar_bounder.py b/mpisppy/cylinders/xhatxbar_bounder.py index f98bb208..39f3716f 100644 --- a/mpisppy/cylinders/xhatxbar_bounder.py +++ b/mpisppy/cylinders/xhatxbar_bounder.py @@ -10,9 +10,8 @@ # xbar from xhat (copied from xhat specific, DLW Feb 2023) import pyomo.environ as pyo -import mpisppy.cylinders.spoke as spoke from mpisppy.extensions.xhatxbar import XhatXbar -from mpisppy.utils.xhat_eval import Xhat_Eval +from mpisppy.cylinders.xhatbase import XhatInnerBoundBase import mpisppy.MPI as mpi import logging @@ -34,43 +33,22 @@ def _attach_xbars(opt): ############################################################################ -class XhatXbarInnerBound(spoke.InnerBoundNonantSpoke): +class XhatXbarInnerBound(XhatInnerBoundBase): converger_spoke_char = 'B' - def ib_prep(self): + def xhat_extension(self): + return XhatXbar(self.opt) + + def xhat_prep(self): """ Set up the objects needed for bounding. Returns: xhatter (xhatxbar object): Constructed by a call to Prep """ - if "bundles_per_rank" in self.opt.options\ - and self.opt.options["bundles_per_rank"] != 0: - raise RuntimeError("xhat spokes cannot have bundles (yet)") - - if not isinstance(self.opt, Xhat_Eval): - raise RuntimeError("XhatXbarInnerBound must be used with Xhat_Eval.") - - xhatter = XhatXbar(self.opt) - # somehow deal with the prox option .... TBD .... important for aph APH - - # begin iter0 stuff - xhatter.pre_iter0() - self.opt._save_original_nonants() - - self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers - - self.opt._update_E1() - if (abs(1 - self.opt.E1) > self.opt.E1_tolerance): - raise RuntimeError(f"Total probability of scenarios was {self.E1}; E1_tolerance = ", self.E1_tolerance) - - ### end iter0 stuff - - xhatter.post_iter0() + xhatter = super().xhat_prep() _attach_xbars(self.opt) - self.opt._save_nonants() # make the cache - return xhatter def main(self): @@ -81,7 +59,7 @@ def main(self): dtm = logging.getLogger(f'dtm{global_rank}') logging.debug("Enter xhatxbar main on rank {}".format(global_rank)) - xhatter = self.ib_prep() + xhatter = self.xhat_prep() ib_iter = 1 # ib is for inner bound while (not self.got_kill_signal()): diff --git a/mpisppy/utils/cfg_vanilla.py b/mpisppy/utils/cfg_vanilla.py index 4c1fe84b..393b2a01 100644 --- a/mpisppy/utils/cfg_vanilla.py +++ b/mpisppy/utils/cfg_vanilla.py @@ -389,6 +389,7 @@ def _PHBase_spoke_foundation( rho_setter=None, all_nodenames=None, ph_extensions=None, + extension_kwargs=None, ): # only the shared options shoptions = shared_options(cfg) @@ -410,6 +411,8 @@ def _PHBase_spoke_foundation( spoke_dict["opt_kwargs"]["rho_setter"] = rho_setter if ph_extensions is not None: spoke_dict["opt_kwargs"]["extensions"] = ph_extensions + if extension_kwargs is not None: + spoke_dict["opt_kwargs"]["extension_kwargs"] = extension_kwargs return spoke_dict @@ -423,6 +426,7 @@ def _Xhat_Eval_spoke_foundation( rho_setter=None, all_nodenames=None, ph_extensions=None, + extension_kwargs=None, ): spoke_dict = _PHBase_spoke_foundation( spoke_class, @@ -433,11 +437,10 @@ def _Xhat_Eval_spoke_foundation( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, - ph_extensions=ph_extensions) + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, + ) spoke_dict["opt_class"] = Xhat_Eval - if ph_extensions is not None: - spoke_dict["opt_kwargs"]["ph_extensions"] = ph_extensions - del spoke_dict["opt_kwargs"]["extensions"] # ph_extensions in Xhat_Eval return spoke_dict @@ -449,6 +452,8 @@ def lagrangian_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames=None, + ph_extensions=None, + extension_kwargs=None, ): lagrangian_spoke = _PHBase_spoke_foundation( LagrangianOuterBound, @@ -459,6 +464,8 @@ def lagrangian_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) if cfg.lagrangian_iter0_mipgap is not None: lagrangian_spoke["opt_kwargs"]["options"]["iter0_solver_options"]\ @@ -479,6 +486,8 @@ def reduced_costs_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames=None, + ph_extensions=None, + extension_kwargs=None, ): rc_spoke = _PHBase_spoke_foundation( ReducedCostsSpoke, @@ -489,6 +498,8 @@ def reduced_costs_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) add_ph_tracking(rc_spoke, cfg, spoke=True) @@ -506,6 +517,8 @@ def lagranger_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames = None, + ph_extensions=None, + extension_kwargs=None, ): lagranger_spoke = _PHBase_spoke_foundation( LagrangerOuterBound, @@ -516,6 +529,8 @@ def lagranger_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) if cfg.lagranger_iter0_mipgap is not None: lagranger_spoke["opt_kwargs"]["options"]["iter0_solver_options"]\ @@ -539,6 +554,8 @@ def subgradient_spoke( scenario_creator_kwargs=None, rho_setter=None, all_nodenames=None, + ph_extensions=None, + extension_kwargs=None, ): subgradient_spoke = _PHBase_spoke_foundation( SubgradientOuterBound, @@ -549,6 +566,8 @@ def subgradient_spoke( scenario_creator_kwargs=scenario_creator_kwargs, rho_setter=rho_setter, all_nodenames=all_nodenames, + ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) if cfg.subgradient_iter0_mipgap is not None: subgradient_spoke["opt_kwargs"]["options"]["iter0_solver_options"]\ @@ -571,6 +590,7 @@ def xhatlooper_spoke( all_scenario_names, scenario_creator_kwargs=None, ph_extensions=None, + extension_kwargs=None, ): xhatlooper_dict = _Xhat_Eval_spoke_foundation( @@ -581,6 +601,7 @@ def xhatlooper_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatlooper_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat @@ -602,6 +623,7 @@ def xhatxbar_spoke( scenario_creator_kwargs=None, variable_probability=None, ph_extensions=None, + extension_kwargs=None, all_nodenames=None, ): xhatxbar_dict = _Xhat_Eval_spoke_foundation( @@ -612,6 +634,7 @@ def xhatxbar_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, all_nodenames=all_nodenames, ) @@ -635,6 +658,7 @@ def xhatshuffle_spoke( all_nodenames=None, scenario_creator_kwargs=None, ph_extensions=None, + extension_kwargs=None, ): xhatshuffle_dict = _Xhat_Eval_spoke_foundation( @@ -646,6 +670,7 @@ def xhatshuffle_spoke( all_nodenames=all_nodenames, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatshuffle_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat xhatshuffle_dict["opt_kwargs"]["options"]["xhat_looper_options"] = { @@ -669,7 +694,8 @@ def xhatspecific_spoke( scenario_dict, all_nodenames=None, scenario_creator_kwargs=None, - ph_extensions=None, + ph_extensions=None, + extension_kwargs=None, ): xhatspecific_dict = _Xhat_Eval_spoke_foundation( @@ -680,6 +706,7 @@ def xhatspecific_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatspecific_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat return xhatspecific_dict @@ -691,6 +718,7 @@ def xhatlshaped_spoke( all_scenario_names, scenario_creator_kwargs=None, ph_extensions=None, + extension_kwargs=None, ): xhatlshaped_dict = _Xhat_Eval_spoke_foundation( @@ -701,6 +729,7 @@ def xhatlshaped_spoke( all_scenario_names, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions, + extension_kwargs=extension_kwargs, ) xhatlshaped_dict["opt_kwargs"]["options"]['bundles_per_rank'] = 0 # no bundles for xhat diff --git a/mpisppy/utils/xhat_eval.py b/mpisppy/utils/xhat_eval.py index fd93dd72..41217d30 100644 --- a/mpisppy/utils/xhat_eval.py +++ b/mpisppy/utils/xhat_eval.py @@ -43,7 +43,8 @@ def __init__( mpicomm=None, scenario_creator_kwargs=None, variable_probability=None, - ph_extensions=None, + extensions=None, + extension_kwargs=None, ): super().__init__( @@ -52,7 +53,8 @@ def __init__( scenario_creator, scenario_denouement=scenario_denouement, all_nodenames=all_nodenames, - extensions=ph_extensions, + extensions=extensions, + extension_kwargs=extension_kwargs, mpicomm=mpicomm, scenario_creator_kwargs=scenario_creator_kwargs, variable_probability=variable_probability,