Skip to content

Commit

Permalink
Merge pull request Pyomo#3409 from jsiirola/configdict-catch-recursion
Browse files Browse the repository at this point in the history
ConfigDict: prevent recursion on partially-constructed objects
  • Loading branch information
jsiirola authored Nov 8, 2024
2 parents 5e9b0f2 + 91ec93f commit c662b35
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 13 deletions.
33 changes: 22 additions & 11 deletions pyomo/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,17 +1693,23 @@ class UninitializedMixin(object):
@property
def _data(self):
#
# This is a possibly dangerous construct: falling back on
# calling the _default can mask a real problem in the default
# type/value.
# We assume that _default is usually a concrete value. But, we
# also accept a types (classes) and initialization functions as
# defaults, in which case we will construct an instance of that
# class and use that as the default. If they both raise
# exceptions, we will let the original exception propagate up.
#
try:
self._setter(self._default)
except:
if hasattr(self._default, '__call__'):
self._setter(self._default())
else:
raise
_default_val = self._default()
try:
self._setter(_default_val)
return self._data
except:
pass
raise
return self._data

@_data.setter
Expand Down Expand Up @@ -2705,14 +2711,19 @@ def __len__(self):
def __iter__(self):
return map(attrgetter('_name'), self._data.values())

def __getattr__(self, name):
def __getattr__(self, attr):
# Note: __getattr__ is only called after all "usual" attribute
# lookup methods have failed. So, if we get here, we already
# know that key is not a __slot__ or a method, etc...
_name = name.replace(' ', '_')
if _name not in self._data:
raise AttributeError("Unknown attribute '%s'" % name)
return ConfigDict.__getitem__(self, _name)
_attr = attr.replace(' ', '_')
# Note: we test for "_data" because finding attributes on a
# partially constructed ConfigDict (before the _data attribute
# was declared) can lead to infinite recursion.
if _attr == "_data" or _attr not in self._data:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'"
)
return ConfigDict.__getitem__(self, _attr)

def __setattr__(self, name, value):
if name in ConfigDict._reserved_words:
Expand Down
17 changes: 15 additions & 2 deletions pyomo/common/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,7 +1863,18 @@ def test_default_function(self):
c.value()

c = ConfigValue('a', domain=int)
with self.assertRaisesRegex(ValueError, 'invalid value for configuration'):
with self.assertRaisesRegex(
ValueError, '(?s)invalid value for configuration.*casting a'
):
c.value()

# Test that if both the default and the result from calling the
# default raise exceptions, the propagated exception is from
# castig the original default:
c = ConfigValue(default=lambda: 'a', domain=int)
with self.assertRaisesRegex(
ValueError, "(?s)invalid value for configuration.*lambda"
):
c.value()

def test_set_default(self):
Expand Down Expand Up @@ -2736,7 +2747,9 @@ def test_getattr_setattr(self):
):
config.baz = 10

with self.assertRaisesRegex(AttributeError, "Unknown attribute 'baz'"):
with self.assertRaisesRegex(
AttributeError, "'ConfigDict' object has no attribute 'baz'"
):
a = config.baz

def test_nonString_keys(self):
Expand Down
15 changes: 15 additions & 0 deletions pyomo/common/tests/test_tee.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

import gc
import os
import time
import sys
Expand All @@ -23,6 +24,20 @@


class TestTeeStream(unittest.TestCase):
def setUp(self):
self.reenable_gc = gc.isenabled()
gc.disable()
# Set a short switch interval so that the threading tests behave
# as expected
self.switchinterval = sys.getswitchinterval()
sys.setswitchinterval(tee._poll_interval / 100)

def tearDown(self):
sys.setswitchinterval(self.switchinterval)
if self.reenable_gc:
gc.enable()
gc.collect()

def test_stdout(self):
a = StringIO()
b = StringIO()
Expand Down

0 comments on commit c662b35

Please sign in to comment.