From 55192864f39326d8b38b159ab14ecaed6489f6e4 Mon Sep 17 00:00:00 2001 From: John Siirola Date: Mon, 16 Sep 2024 08:56:10 -0600 Subject: [PATCH] Resolve issue in filter/validate deprecation path --- pyomo/core/base/set.py | 38 ++++++++++++++++---------- pyomo/core/tests/unit/test_set.py | 44 +++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/pyomo/core/base/set.py b/pyomo/core/base/set.py index 69b21c4d78b..5a15321c084 100644 --- a/pyomo/core/base/set.py +++ b/pyomo/core/base/set.py @@ -1484,18 +1484,7 @@ def _cb_validate_filter(self, mode, val_iter): try: flag = fcn(block, (), *vstar) if flag: - deprecation_warning( - f"{self.__class__.__name__} {self.name}: '{mode}=' " - "callback signature matched (block, *value). " - "Please update the callback to match the signature " - f"(block, value{', *index' if comp.is_indexed() else ''}).", - version='6.8.0', - ) - orig_fcn = fcn._fcn - fcn = ParameterizedScalarCallInitializer( - lambda m, v: orig_fcn(m, *v), True - ) - setattr(comp, '_' + mode, fcn) + self._filter_validate_scalar_api_deprecation(mode, warning=True) yield value continue except TypeError: @@ -1536,6 +1525,21 @@ def _cb_validate_filter(self, mode, val_iter): ) raise exc from None + def _filter_validate_scalar_api_deprecation(self, mode, warning): + comp = self.parent_component() + fcn = getattr(comp, '_' + mode) + if warning: + deprecation_warning( + f"{self.__class__.__name__} {self.name}: '{mode}=' " + "callback signature matched (block, *value). " + "Please update the callback to match the signature " + f"(block, value{', *index' if comp.is_indexed() else ''}).", + version='6.8.0', + ) + orig_fcn = fcn._fcn + fcn = ParameterizedScalarCallInitializer(lambda m, v: orig_fcn(m, *v), True) + setattr(comp, '_' + mode, fcn) + def _cb_normalized_dimen_verifier(self, dimen, val_iter): for value in val_iter: if value.__class__ in native_types: @@ -2256,14 +2260,20 @@ def __init__(self, *args, **kwds): self._init_values._init = CountedCallInitializer( self, self._init_values._init ) - # HACK: the DAT parser needs to know the domain of a set in - # order to correctly parse the data stream. + if not self.is_indexed(): + # HACK: the DAT parser needs to know the domain of a set in + # order to correctly parse the data stream. if self._init_domain.constant(): self._domain = self._init_domain(self.parent_block(), None, self) if self._init_dimen.constant(): self._dimen = self._init_dimen(self.parent_block(), None) + if self._filter.__class__ is ParameterizedIndexedCallInitializer: + self._filter_validate_scalar_api_deprecation('filter', warning=False) + if self._validate.__class__ is ParameterizedIndexedCallInitializer: + self._filter_validate_scalar_api_deprecation('validate', warning=False) + @deprecated( "check_values() is deprecated: Sets only contain valid members", version='5.7' ) diff --git a/pyomo/core/tests/unit/test_set.py b/pyomo/core/tests/unit/test_set.py index 6529c2b60a9..6312aaf63c6 100644 --- a/pyomo/core/tests/unit/test_set.py +++ b/pyomo/core/tests/unit/test_set.py @@ -4181,6 +4181,19 @@ def test_indexed_set(self): self.assertIs(type(m.I[3]), InsertionOrderSetData) self.assertEqual(m.I.data(), {1: (4, 2, 5), 2: (4, 2, 5), 3: (4, 2, 5)}) + # Explicit (constant dict) construction + m = ConcreteModel() + m.I = Set([1, 2], initialize={1: (4, 2, 5), 2: (7, 6)}) + self.assertEqual(len(m.I), 2) + self.assertEqual(list(m.I[1]), [4, 2, 5]) + self.assertEqual(list(m.I[2]), [7, 6]) + self.assertIsNot(m.I[1], m.I[2]) + self.assertTrue(m.I[1].isordered()) + self.assertTrue(m.I[2].isordered()) + self.assertIs(type(m.I[1]), InsertionOrderSetData) + self.assertIs(type(m.I[2]), InsertionOrderSetData) + self.assertEqual(m.I.data(), {1: (4, 2, 5), 2: (7, 6)}) + # Explicit (constant) construction m = ConcreteModel() m.I = Set([1, 2, 3], initialize=(4, 2, 5), ordered=Set.SortedOrder) @@ -4255,7 +4268,7 @@ def test_indexing(self): def test_add_filter_validate(self): m = ConcreteModel() m.I = Set(domain=Integers) - self.assertIs(m.I.filter, None) + self.assertIs(m.I._filter, None) with self.assertRaisesRegex( ValueError, r"Cannot add value 1.5 to Set I.\n" @@ -4302,7 +4315,7 @@ def _l_tri(model, i, j): return i >= j m.K = Set(initialize=RangeSet(3) * RangeSet(3), filter=_l_tri) - self.assertIsInstance(m.K.filter, ParameterizedScalarCallInitializer) + self.assertIsInstance(m.K._filter, ParameterizedScalarCallInitializer) self.assertEqual(list(m.K), [(1, 1), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3)]) output = StringIO() @@ -4334,6 +4347,18 @@ def _lt_3(model, i): self.assertEqual(output.getvalue(), "") self.assertEqual(list(m.L[2]), [1, 2, 0]) + # This tests that the deprecation path works correctly in the + # case that the callback doesn't raise an error or ever return + # False + + def _l_off_diag(model, i, j): + self.assertIs(model, m) + return i != j + + m.M = Set(initialize=RangeSet(3) * RangeSet(3), filter=_l_off_diag) + self.assertIsInstance(m.M._filter, ParameterizedScalarCallInitializer) + self.assertEqual(list(m.M), [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]) + m = ConcreteModel() def _validate(model, val): @@ -4374,12 +4399,15 @@ def _validate(model, i, j): m.I2 = Set(validate=_validate) with LoggingIntercept(module='pyomo.core') as output: self.assertTrue(m.I2.add((0, 1))) - self.assertRegex( - output.getvalue().replace('\n', ' '), - r"DEPRECATED: OrderedScalarSet I2: 'validate=' callback " - r"signature matched \(block, \*value\). Please update the " - r"callback to match the signature \(block, value\)", - ) + # Note that we are not emitting a deprecation warning (yet) + # for scalar sets + # self.assertEqual(output.getvalue(), "") + # output.getvalue().replace('\n', ' '), + # r"DEPRECATED: OrderedScalarSet I2: 'validate=' callback " + # r"signature matched \(block, \*value\). Please update the " + # r"callback to match the signature \(block, value\)", + # ) + self.assertEqual(output.getvalue(), "") with LoggingIntercept(module='pyomo.core') as output: with self.assertRaisesRegex( ValueError,