Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better extension support in xhat inner bound spokes #471

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions mpisppy/cylinders/xhatbase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
###############################################################################
# mpi-sppy: MPI-based Stochastic Programming in PYthon
#
# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for
# Sustainable Energy, LLC, The Regents of the University of California, et al.
# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for
# full copyright and license information.
###############################################################################

import abc
import mpisppy.cylinders.spoke as spoke
from mpisppy.utils.xhat_eval import Xhat_Eval

class XhatInnerBoundBase(spoke.InnerBoundNonantSpoke):

@abc.abstractmethod
def xhat_extension(self):
raise NotImplementedError


def xhat_prep(self):
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

## for later
self.verbose = self.opt.options["verbose"] # typing aid

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError(f"{self.__class__.__name__} must be used with Xhat_Eval.")

xhatter = self.xhat_extension()

### begin iter0 stuff
xhatter.pre_iter0()
if self.opt.extensions is not None:
self.opt.extobject.pre_iter0() # for an extension
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if abs(1 - self.opt.E1) > self.opt.E1_tolerance:
raise ValueError(f"Total probability of scenarios was {self.opt.E1} "+\
f"(E1_tolerance is {self.opt.E1_tolerance})")
### end iter0 stuff (but note: no need for iter 0 solves in an xhatter)

xhatter.post_iter0()
if self.opt.extensions is not None:
self.opt.extobject.post_iter0() # for an extension

self.opt._save_nonants() # make the cache

return xhatter
38 changes: 5 additions & 33 deletions mpisppy/cylinders/xhatlooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
# full copyright and license information.
###############################################################################
# updated April 2020
import mpisppy.cylinders.spoke as spoke
from mpisppy.extensions.xhatlooper import XhatLooper
from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase
import logging
import mpisppy.log

Expand All @@ -20,44 +19,17 @@
logger = logging.getLogger("mpisppy.cylinders.xhatlooper_bounder")


class XhatLooperInnerBound(spoke.InnerBoundNonantSpoke):
class XhatLooperInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'X'

def xhatlooper_prep(self):
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.")

xhatter = XhatLooper(self.opt)

### begin iter0 stuff
xhatter.pre_iter0()
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if abs(1 - self.opt.E1) > self.opt.E1_tolerance:
if self.opt.cylinder_rank == 0:
print("ERROR")
print("Total probability of scenarios was ", self.opt.E1)
print("E1_tolerance = ", self.opt.E1_tolerance)
quit()
### end iter0 stuff

xhatter.post_iter0()
self.opt._save_nonants() # make the cache

return xhatter
def xhat_extension(self):
return XhatLooper(self.opt)

def main(self):
logger.debug(f"Entering main on xhatlooper spoke rank {self.global_rank}")

xhatter = self.xhatlooper_prep()
xhatter = self.xhat_prep()

scen_limit = self.opt.options['xhat_looper_options']['scen_limit']

Expand Down
42 changes: 8 additions & 34 deletions mpisppy/cylinders/xhatshufflelooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,25 @@
import logging
import random
import mpisppy.log
import mpisppy.cylinders.spoke as spoke

from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.extensions.xhatbase import XhatBase
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase

# Could also pass, e.g., sys.stdout instead of a filename
mpisppy.log.setup_logger("mpisppy.cylinders.xhatshufflelooper_bounder",
"xhatclp.log",
level=logging.CRITICAL)
logger = logging.getLogger("mpisppy.cylinders.xhatshufflelooper_bounder")

class XhatShuffleInnerBound(spoke.InnerBoundNonantSpoke):
class XhatShuffleInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'X'

def xhatbase_prep(self):
def xhat_extension(self):
return XhatBase(self.opt)

if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

## for later
self.verbose = self.opt.options["verbose"] # typing aid
self.solver_options = self.opt.options["xhat_looper_options"]["xhat_solver_options"]

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.")

xhatter = XhatBase(self.opt)
self.xhatter = xhatter

### begin iter0 stuff
xhatter.pre_iter0() # for an extension
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if abs(1 - self.opt.E1) > self.opt.E1_tolerance:
raise ValueError(f"Total probability of scenarios was {self.opt.E1} "+\
f"(E1_tolerance is {self.opt.E1_tolerance})")
### end iter0 stuff (but note: no need for iter 0 solves in an xhatter)

xhatter.post_iter0()

self.opt._save_nonants() # make the cache
def xhat_prep(self):
self.xhatter = super().xhat_prep()

## option drive this? (could be dangerous)
self.random_seed = 42
Expand Down Expand Up @@ -90,7 +63,7 @@ def _vb(msg):
def main(self):
logger.debug(f"Entering main on xhatshuffle spoke rank {self.global_rank}")

self.xhatbase_prep()
self.xhat_prep()
if "reverse" in self.opt.options["xhat_looper_options"]:
self.reverse = self.opt.options["xhat_looper_options"]["reverse"]
else:
Expand All @@ -99,6 +72,7 @@ def main(self):
self.iter_step = self.opt.options["xhat_looper_options"]["iter_step"]
else:
self.iter_step = None
self.solver_options = self.opt.options["xhat_looper_options"]["xhat_solver_options"]

# give all ranks the same seed
self.random_stream.seed(self.random_seed)
Expand Down
46 changes: 5 additions & 41 deletions mpisppy/cylinders/xhatspecific_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
# udpated April 20
# specific xhat supplied (copied from xhatlooper_bounder by DLW, Dec 2019)

import mpisppy.cylinders.spoke as spoke
from mpisppy.extensions.xhatspecific import XhatSpecific
from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase

import mpisppy.MPI as mpi
import logging
Expand All @@ -22,47 +21,12 @@


############################################################################
class XhatSpecificInnerBound(spoke.InnerBoundNonantSpoke):
class XhatSpecificInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'S'

def ib_prep(self):
"""
Set up the objects needed for bounding.

Returns:
xhatter (xhatspecific object): Constructed by a call to Prep
"""
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.")

xhatter = XhatSpecific(self.opt)
# somehow deal with the prox option .... TBD .... important for aph APH

# begin iter0 stuff
xhatter.pre_iter0()
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if (abs(1 - self.opt.E1) > self.opt.E1_tolerance):
if self.opt.cylinder_rank == 0:
print("ERROR")
print("Total probability of scenarios was ", self.opt.E1)
print("E1_tolerance = ", self.opt.E1_tolerance)
quit()

### end iter0 stuff

xhatter.post_iter0()
self.opt._save_nonants() # make the cache

return xhatter
def xhat_extension(self):
return XhatSpecific(self.opt)

def main(self):
"""
Expand All @@ -76,7 +40,7 @@ def main(self):
xhat_scenario_dict = self.opt.options["xhat_specific_options"]\
["xhat_scenario_dict"]

xhatter = self.ib_prep()
xhatter = self.xhat_prep()

ib_iter = 1 # ib is for inner bound
while (not self.got_kill_signal()):
Expand Down
38 changes: 8 additions & 30 deletions mpisppy/cylinders/xhatxbar_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
# xbar from xhat (copied from xhat specific, DLW Feb 2023)

import pyomo.environ as pyo
import mpisppy.cylinders.spoke as spoke
from mpisppy.extensions.xhatxbar import XhatXbar
from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase

import mpisppy.MPI as mpi
import logging
Expand All @@ -34,43 +33,22 @@ def _attach_xbars(opt):


############################################################################
class XhatXbarInnerBound(spoke.InnerBoundNonantSpoke):
class XhatXbarInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'B'

def ib_prep(self):
def xhat_extension(self):
return XhatXbar(self.opt)

def xhat_prep(self):
"""
Set up the objects needed for bounding.

Returns:
xhatter (xhatxbar object): Constructed by a call to Prep
"""
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatXbarInnerBound must be used with Xhat_Eval.")

xhatter = XhatXbar(self.opt)
# somehow deal with the prox option .... TBD .... important for aph APH

# begin iter0 stuff
xhatter.pre_iter0()
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if (abs(1 - self.opt.E1) > self.opt.E1_tolerance):
raise RuntimeError(f"Total probability of scenarios was {self.E1}; E1_tolerance = ", self.E1_tolerance)

### end iter0 stuff

xhatter.post_iter0()
xhatter = super().xhat_prep()
_attach_xbars(self.opt)
self.opt._save_nonants() # make the cache

return xhatter

def main(self):
Expand All @@ -81,7 +59,7 @@ def main(self):
dtm = logging.getLogger(f'dtm{global_rank}')
logging.debug("Enter xhatxbar main on rank {}".format(global_rank))

xhatter = self.ib_prep()
xhatter = self.xhat_prep()

ib_iter = 1 # ib is for inner bound
while (not self.got_kill_signal()):
Expand Down
Loading
Loading