diff --git a/doc/OnlineDocs/explanation/philosophy/expressions/managing.rst b/doc/OnlineDocs/explanation/philosophy/expressions/managing.rst index f028eef778d..ded3a5f8f9f 100644 --- a/doc/OnlineDocs/explanation/philosophy/expressions/managing.rst +++ b/doc/OnlineDocs/explanation/philosophy/expressions/managing.rst @@ -140,10 +140,6 @@ tree: seven event callbacks that users can hook into, providing very fine-grained control over the expression walker. -:class:`SimpleExpressionVisitor ` - A :func:`visitor` method is called for each node in the tree, - and the visitor class collects information about the tree. - :class:`ExpressionValueVisitor ` When the :func:`visitor` method is called on each node in the tree, the *values* of its children have been computed. The @@ -166,12 +162,6 @@ These classes define a variety of suitable tree search methods: * ``walk_expression``: depth-first traversal of the expression tree. -* :class:`SimpleExpressionVisitor ` - - * ``xbfs``: breadth-first search where leaf nodes are immediately visited - * ``xbfs_yield_leaves``: breadth-first search where leaf nodes are - immediately visited, and the visit method yields a value - * :class:`ExpressionValueVisitor ` * ``dfs_postorder_stack``: postorder depth-first search using a @@ -179,12 +169,11 @@ These classes define a variety of suitable tree search methods: To implement a visitor object, a user needs to provide specializations -for specific events. For legacy visitors based on the PyUtilib -visitor pattern (e.g., :class:`SimpleExpressionVisitor` and -:class:`ExpressionValueVisitor`), one must create a subclass of one of these -classes and override at least one of the following: +for specific events. For legacy visitors based on the PyUtilib visitor +pattern (e.g., :class:`ExpressionValueVisitor`), one must create a +subclass and override at least one of the following: -:func:`visitor` +:func:`visit` Defines the operation that is performed when a node is visited. 
In the :class:`ExpressionValueVisitor ` and @@ -196,10 +185,7 @@ classes and override at least one of the following: Checks if the search should terminate with this node. If no, then this method returns the tuple ``(False, None)``. If yes, then this method returns ``(True, value)``, where *value* is - computed by this method. This method is not used in the - :class:`SimpleExpressionVisitor - ` visitor - class. + computed by this method. :func:`finalize` This method defines the final value that is returned from the @@ -216,8 +202,8 @@ callbacks, which are documented in the class documentation. Detailed documentation of the APIs for these methods is provided with the class documentation for these visitors. -SimpleExpressionVisitor Example -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +StreamBasedExpressionVisitor Example +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In this example, we describe a visitor class that counts the number of nodes in an expression (including leaf nodes). Consider the following @@ -225,10 +211,11 @@ class: .. literalinclude:: /src/expr/managing_visitor1.spy -The class constructor creates a counter, and the :func:`visit` method -increments this counter for every node that is visited. The :func:`finalize` -method returns the value of this counter after the tree has been walked. The -following function illustrates this use of this visitor class: +The :func:`initializeWalker` method creates a counter, and the +:func:`exitNode` method increments this counter for every node that is +visited. The :func:`finalizeResult` method returns the value of this +counter after the tree has been walked. The following function +illustrates the use of this visitor class: .. 
literalinclude:: /src/expr/managing_visitor2.spy diff --git a/doc/OnlineDocs/reference/topical/expressions/visitors.rst b/doc/OnlineDocs/reference/topical/expressions/visitors.rst index 847cc40b07d..ef09c62c2ff 100644 --- a/doc/OnlineDocs/reference/topical/expressions/visitors.rst +++ b/doc/OnlineDocs/reference/topical/expressions/visitors.rst @@ -5,6 +5,5 @@ Visitor Classes .. autosummary:: pyomo.core.expr.StreamBasedExpressionVisitor - pyomo.core.expr.SimpleExpressionVisitor pyomo.core.expr.ExpressionValueVisitor pyomo.core.expr.ExpressionReplacementVisitor diff --git a/pyomo/core/expr/numvalue.py b/pyomo/core/expr/numvalue.py index d2c63a3d967..31349b13c18 100644 --- a/pyomo/core/expr/numvalue.py +++ b/pyomo/core/expr/numvalue.py @@ -116,6 +116,12 @@ def __repr__(self): def __call__(self, exception=None): return self.value + def is_constant(self): + return True + + def is_fixed(self): + return True + nonpyomo_leaf_types.add(NonNumericValue) diff --git a/pyomo/core/expr/visitor.py b/pyomo/core/expr/visitor.py index b172055c8e3..d8ab3b298e5 100644 --- a/pyomo/core/expr/visitor.py +++ b/pyomo/core/expr/visitor.py @@ -679,6 +679,11 @@ def _nonrecursive_walker_loop(self, ptr): ptr = ptr[0] +@deprecated( + "The SimpleExpressionVisitor is deprecated. " + "Please use the StreamBasedExpressionVisitor instead.", + version='6.9.0.dev0', +) class SimpleExpressionVisitor(object): """ Note: @@ -736,6 +741,14 @@ def xbfs(self, node): The return value is determined by the :func:`finalize` function, which may be defined by the user. Defaults to :const:`None`. 
""" + if ( + node.__class__ in nonpyomo_leaf_types + or not node.is_expression_type() + or node.nargs() == 0 + ): + self.visit(node) + return self.finalize() + dq = deque([node]) while dq: current = dq.popleft() @@ -1335,20 +1348,25 @@ def evaluate_expression(exp, exception=True, constant=False): # ===================================================== -class _ComponentVisitor(SimpleExpressionVisitor): +class _ComponentVisitor(StreamBasedExpressionVisitor): def __init__(self, types): - self.seen = set() - if types.__class__ is set: - self.types = types - else: - self.types = set(types) + super().__init__() + if types.__class__ is not set: + types = set(types) + self._types = types - def visit(self, node): - if node.__class__ in self.types: - if id(node) in self.seen: - return - self.seen.add(id(node)) - return node + def initializeWalker(self, expr): + self._objs = [] + self._seen = set() + return True, None + + def finalizeResult(self, result): + return self._objs + + def exitNode(self, node, data): + if node.__class__ in self._types and id(node) not in self._seen: + self._seen.add(id(node)) + self._objs.append(node) def identify_components(expr, component_types): @@ -1370,7 +1388,7 @@ def identify_components(expr, component_types): # in the expression. 
# visitor = _ComponentVisitor(component_types) - yield from visitor.xbfs_yield_leaves(expr) + yield from visitor.walk_expression(expr) # ===================================================== @@ -1378,7 +1396,7 @@ def identify_components(expr, component_types): # ===================================================== -class _VariableVisitor(StreamBasedExpressionVisitor): +class IdentifyVariableVisitor(StreamBasedExpressionVisitor): def __init__(self, include_fixed=False, named_expression_cache=None): """Visitor that collects all unique variables participating in an expression @@ -1392,108 +1410,85 @@ def __init__(self, include_fixed=False, named_expression_cache=None): """ super().__init__() self._include_fixed = include_fixed - if named_expression_cache is None: - # This cache will map named expression ids to the - # tuple: ([variables], {variable ids}) - named_expression_cache = {} - self._named_expression_cache = named_expression_cache - # Stack of active named expressions. This holds the id of - # expressions we are currently in. - self._active_named_expressions = [] + self._cache = named_expression_cache + # Stack of named expressions. This holds the id of the + # subexpression we are currently processing, along with a + # (_objs, _seen, _exprs) tuple for the parent context. + self._expr_stack = [] + # The following attributes will be added by initializeWalker: + # self._objs: the list of found objects + # self._seen: the set of id() values of the objects in self._objs + # self._exprs: list of (e, e.expr) for any (nested) named expressions def initializeWalker(self, expr): - if expr.__class__ in native_types: - return False, [] - elif expr.is_named_expression_type(): - eid = id(expr) - if eid in self._named_expression_cache: - # If we were given a named expression that is already cached, - # just do nothing and return the expression's variables - variables, var_set = self._named_expression_cache[eid] - return False, variables - else: - # We were given a named expression that is not cached. 
- # Initialize data structures and add this expression to the - # stack. This expression will get popped in exitNode. - self._variables = [] - self._seen = set() - self._named_expression_cache[eid] = [], set() - self._active_named_expressions.append(eid) - return True, expr - elif expr.is_variable_type(): - return False, [expr] - else: - self._variables = [] - self._seen = set() - return True, expr + assert not self._expr_stack + self._objs = [] + self._seen = set() + self._exprs = None + if not self.beforeChild(None, expr, 0)[0]: + return False, self.finalizeResult(None) + return True, expr def beforeChild(self, parent, child, index): if child.__class__ in native_types: return False, None - elif child.is_named_expression_type(): - eid = id(child) - if eid in self._named_expression_cache: - # We have already encountered this named expression. We just add - # the cached variables to our list and don't descend. - if self._active_named_expressions: - # If we are in another named expression, we update the - # parent expression's cache. We don't need to update the - # global list as we will do this when we exit the active - # named expression. - parent_eid = self._active_named_expressions[-1] - variables, var_set = self._named_expression_cache[parent_eid] - else: - # If we are not in a named expression, we update the global - # list. - variables = self._variables - var_set = self._seen - for var in self._named_expression_cache[eid][0]: - if id(var) not in var_set: - var_set.add(id(var)) - variables.append(var) - return False, None + elif child.is_expression_type(): + if child.is_named_expression_type(): + return self._process_named_expr(child) else: - # If we are descending into a new named expression, initialize - # a cache to store the expression's local variables. 
- self._named_expression_cache[id(child)] = ([], set()) - self._active_named_expressions.append(id(child)) return True, None - elif child.is_variable_type() and (self._include_fixed or not child.fixed): - if self._active_named_expressions: - # If we are in a named expression, add new variables to the cache. - eid = self._active_named_expressions[-1] - variables, var_set = self._named_expression_cache[eid] - else: - variables = self._variables - var_set = self._seen - if id(child) not in var_set: - var_set.add(id(child)) - variables.append(child) - return False, None - else: - return True, None + if child.is_variable_type() and (self._include_fixed or not child.fixed): + if id(child) not in self._seen: + self._seen.add(id(child)) + self._objs.append(child) + return False, None def exitNode(self, node, data): - if node.is_named_expression_type(): - # If we are returning from a named expression, we have at least one - # active named expression. We must make sure that we properly - # handle the variables for the named expression we just exited. - eid = self._active_named_expressions.pop() - if self._active_named_expressions: - # If we still are in a named expression, we update that expression's - # cache with any new variables encountered. - parent_eid = self._active_named_expressions[-1] - variables, var_set = self._named_expression_cache[parent_eid] - else: - variables = self._variables - var_set = self._seen - for var in self._named_expression_cache[eid][0]: - if id(var) not in var_set: - var_set.add(id(var)) - variables.append(var) + if node.is_named_expression_type() and self._cache is not None: + # If we are returning from a named expression, we must make + # sure that we properly restore the "outer" context and then + # merge the objects from the named expression we just exited + # into the list for the parent expression context. 
+ sub_info = self._objs, self._seen, self._exprs + eid, (self._objs, self._seen, self._exprs) = self._expr_stack.pop() + assert eid == id(node) + self._merge_obj_lists(sub_info) def finalizeResult(self, result): - return self._variables + assert not self._expr_stack + return self._objs + + def _merge_obj_lists(self, info): + _objs, _seen, _exprs = info + self._objs.extend(v for v in _objs if id(v) not in self._seen) + self._seen.update(_seen) + if self._exprs is not None: + self._exprs.extend(_exprs) + + def _process_named_expr(self, child): + eid = id(child) + if self._cache is None: + return True, None + elif eid in self._cache and all(c.expr is e for c, e in self._cache[eid][2]): + # We have already encountered this named expression. We just add + # the cached objects to our list and don't descend. + # + # Note that a cache hit requires not only that we have seen + # this expression before, but also that none of the named + # expressions have changed. If they have, then the cache + # miss will fall over to the else clause below and descend + # into the expression, (implicitly) rebuilding the cache. + self._merge_obj_lists(self._cache[eid]) + return False, None + else: + # If we are descending into a new named expression, initialize + # a cache to store the expression's local objects. + self._expr_stack.append((eid, (self._objs, self._seen, self._exprs))) + self._objs = [] + self._seen = set() + self._exprs = [(child, child.expr)] + self._cache[eid] = (self._objs, self._seen, self._exprs) + return True, None def identify_variables(expr, include_fixed=True, named_expression_cache=None): @@ -1510,13 +1505,17 @@ def identify_variables(expr, include_fixed=True, named_expression_cache=None): Yields: Each variable that is found. 
""" - if named_expression_cache is None: - named_expression_cache = {} - visitor = _VariableVisitor( - named_expression_cache=named_expression_cache, include_fixed=include_fixed - ) - variables = visitor.walk_expression(expr) - yield from variables + v = identify_variables.visitor + save = v._include_fixed, v._cache + try: + v._include_fixed = include_fixed + v._cache = named_expression_cache + yield from v.walk_expression(expr) + finally: + v._include_fixed, v._cache = save + + +identify_variables.visitor = IdentifyVariableVisitor() # ===================================================== @@ -1524,20 +1523,28 @@ def identify_variables(expr, include_fixed=True, named_expression_cache=None): # ===================================================== -class _MutableParamVisitor(SimpleExpressionVisitor): +class IdentifyMutableParamVisitor(IdentifyVariableVisitor): def __init__(self): - self.seen = set() - - def visit(self, node): - if node.__class__ in nonpyomo_leaf_types: - return + # Hide the IdentifyVariableVisitor API (not relevant here) + super().__init__() - # TODO: Confirm that this has the right semantics - if not node.is_variable_type() and node.is_fixed() and not node.is_constant(): - if id(node) in self.seen: - return - self.seen.add(id(node)) - return node + def beforeChild(self, parent, child, index): + if child.__class__ in native_types: + return False, None + elif child.is_expression_type(): + if child.is_named_expression_type(): + return self._process_named_expr(child) + else: + return True, None + if ( + not child.is_variable_type() + and child.is_fixed() + and not child.is_constant() + ): + if id(child) not in self._seen: + self._seen.add(id(child)) + self._objs.append(child) + return False, None def identify_mutable_parameters(expr): @@ -1551,9 +1558,10 @@ def identify_mutable_parameters(expr): Yields: Each mutable parameter that is found. 
""" - visitor = _MutableParamVisitor() - yield from visitor.xbfs_yield_leaves(expr) + yield from identify_mutable_parameters.visitor.walk_expression(expr) + +identify_mutable_parameters.visitor = IdentifyMutableParamVisitor() # ===================================================== # polynomial_degree diff --git a/pyomo/core/tests/unit/test_visitor.py b/pyomo/core/tests/unit/test_visitor.py index 5733710ab46..d9731adde41 100644 --- a/pyomo/core/tests/unit/test_visitor.py +++ b/pyomo/core/tests/unit/test_visitor.py @@ -60,6 +60,7 @@ from pyomo.core.expr.visitor import ( FixedExpressionError, NonConstantExpressionError, + SimpleExpressionVisitor, StreamBasedExpressionVisitor, ExpressionReplacementVisitor, evaluate_expression, @@ -130,6 +131,76 @@ def test_identify_vars_expr(self): self.assertEqual(list(identify_variables(m.E[0])), [m.a]) self.assertEqual(list(identify_variables(m.E[1])), [m.b]) + def test_identify_vars_expr_cache(self): + # + # Identify variables in named expressions + # + m = ConcreteModel() + m.a = Var(initialize=1) + m.b = Var(initialize=2) + m.c = Var(initialize=3) + m.d = Var(initialize=4) + m.e = Expression(expr=3 * m.a) + + cache = {} + self.assertEqual( + list(identify_variables(m.b + m.e, named_expression_cache=cache)), + [m.b, m.a], + ) + self.assertEqual(cache, {id(m.e): ([m.a], {id(m.a)}, [(m.e, m.e.expr)])}) + + # Check that the cache is used (to check, we will cause the cache to lie) + v, s, e = cache[id(m.e)] + v.pop() + v.extend([m.b, m.c]) + s.clear() + s.update((id(m.b), id(m.c))) + self.assertEqual( + list(identify_variables(m.b + m.e, named_expression_cache=cache)), + [m.b, m.c], + ) + + # Check that changing the expression invalidates the cache + m.e = 4 * m.d + self.assertEqual( + list(identify_variables(m.b + m.e, named_expression_cache=cache)), + [m.b, m.d], + ) + + # Check that changing a nested expression invalidates the cache + m.f = Expression(expr=5 * m.a * m.e * m.b) + self.assertEqual( + list(identify_variables(m.c + 
m.f, named_expression_cache=cache)), + [m.c, m.a, m.d, m.b], + ) + self.assertEqual( + cache, + { + id(m.e): ([m.d], {id(m.d)}, [(m.e, m.e.expr)]), + id(m.f): ( + [m.a, m.d, m.b], + {id(m.a), id(m.d), id(m.b)}, + [(m.f, m.f.expr), (m.e, m.e.expr)], + ), + }, + ) + m.e = 5 + self.assertEqual( + list(identify_variables(m.c + m.f, named_expression_cache=cache)), + [m.c, m.a, m.b], + ) + self.assertEqual( + cache, + { + id(m.e): ([], set(), [(m.e, m.e.expr)]), + id(m.f): ( + [m.a, m.b], + {id(m.a), id(m.b)}, + [(m.f, m.f.expr), (m.e, m.e.expr)], + ), + }, + ) + def test_identify_vars_vars(self): m = ConcreteModel() m.I = RangeSet(3) @@ -282,7 +353,7 @@ def test_identify_mutable_parameters_params(self): ) self.assertEqual( list(identify_mutable_parameters(m.a ** m.b[1] + m.b[2])), - [m.b[2], m.a, m.b[1]], + [m.a, m.b[1], m.b[2]], ) self.assertEqual( list(identify_mutable_parameters(m.a ** m.b[1] + m.b[2] * m.b[3] * m.b[2])), @@ -297,17 +368,17 @@ def test_identify_mutable_parameters_params(self): # self.assertEqual( list(identify_mutable_parameters(m.x(m.a, 'string_param', 1, []) * m.b[1])), - [m.b[1], m.a], + [m.a, m.b[1]], ) self.assertEqual( list(identify_mutable_parameters(m.x(m.p, 'string_param', 1, []) * m.b[1])), [m.b[1]], ) self.assertEqual( - list(identify_mutable_parameters(tanh(m.a) * m.b[1])), [m.b[1], m.a] + list(identify_mutable_parameters(tanh(m.a) * m.b[1])), [m.a, m.b[1]] ) self.assertEqual( - list(identify_mutable_parameters(abs(m.a) * m.b[1])), [m.b[1], m.a] + list(identify_mutable_parameters(abs(m.a) * m.b[1])), [m.a, m.b[1]] ) # # Check logic for allowing duplicates @@ -1838,6 +1909,58 @@ def test_evaluate_abex(self): return self.run_walker(self.evaluate_abex()) +class TestSimpleExpressionVisitor(unittest.TestCase): + def test_base_class(self): + m = ConcreteModel() + m.x = Var() + m.y = Var() + m.p = Param(mutable=True) + v = SimpleExpressionVisitor() + + e = 5 + self.assertEqual(v.xbfs(e), None) + self.assertEqual(list(v.xbfs_yield_leaves(e)), []) 
+ + e = m.x + self.assertEqual(v.xbfs(e), None) + self.assertEqual(list(v.xbfs_yield_leaves(e)), []) + + e = m.x + 5 * m.y**m.p + self.assertEqual(v.xbfs(e), None) + self.assertEqual(list(v.xbfs_yield_leaves(e)), []) + + def test_derived_visitor(self): + class _Visitor(SimpleExpressionVisitor): + def __init__(self): + super().__init__() + self.nodes = [] + + def visit(self, node): + self.nodes.append(node) + return node + + def finalize(self): + return len(self.nodes) + + m = ConcreteModel() + m.x = Var() + m.y = Var() + m.p = Param(mutable=True) + v = _Visitor() + + e = 5 + self.assertEqual(v.xbfs(e), 1) + self.assertEqual(list(v.xbfs_yield_leaves(e)), [5]) + + e = m.x + self.assertEqual(v.xbfs(e), 3) + self.assertEqual(list(v.xbfs_yield_leaves(e)), [m.x]) + + e = m.x + 5 * m.y**m.p + self.assertEqual(v.xbfs(e), 11) + self.assertEqual(list(v.xbfs_yield_leaves(e)), [m.x, 5, m.y, m.p]) + + class TestEvaluateExpression(unittest.TestCase): def test_constant(self): m = ConcreteModel() diff --git a/pyomo/util/vars_from_expressions.py b/pyomo/util/vars_from_expressions.py index 878a1a13b58..c5dcd0ef0fd 100644 --- a/pyomo/util/vars_from_expressions.py +++ b/pyomo/util/vars_from_expressions.py @@ -17,7 +17,7 @@ actually in the subtree or not. 
""" from pyomo.core import Block -import pyomo.core.expr as EXPR +from pyomo.core.expr.visitor import IdentifyVariableVisitor def get_vars_from_components( @@ -42,8 +42,8 @@ def get_vars_from_components( descend_into: Ctypes to descend into when finding Constraints descent_order: Traversal strategy for finding the objects of type ctype """ + visitor = IdentifyVariableVisitor(include_fixed, {}) seen = set() - named_expression_cache = {} for constraint in block.component_data_objects( ctype, active=active, @@ -51,11 +51,7 @@ def get_vars_from_components( descend_into=descend_into, descent_order=descent_order, ): - for var in EXPR.identify_variables( - constraint.expr, - include_fixed=include_fixed, - named_expression_cache=named_expression_cache, - ): + for var in visitor.walk_expression(constraint.expr): if id(var) not in seen: seen.add(id(var)) yield var