From 58f2dd9622f1b976169266c22bd079108cd748bd Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Mon, 18 Jul 2022 18:52:57 +0200 Subject: [PATCH] handling *args and **kwargs node links --- sisl/viz/nodes/node.py | 63 +++++++++-- sisl/viz/nodes/tests/test_lazynodes.py | 144 ++++++++++++++++++++++++- 2 files changed, 197 insertions(+), 10 deletions(-) diff --git a/sisl/viz/nodes/node.py b/sisl/viz/nodes/node.py index f343d72d2a..0e044d008b 100644 --- a/sisl/viz/nodes/node.py +++ b/sisl/viz/nodes/node.py @@ -1,6 +1,7 @@ from __future__ import annotations import inspect +from multiprocessing import connection from typing import Any, Optional from collections import ChainMap from collections.abc import Mapping @@ -77,6 +78,8 @@ class Node(NDArrayOperatorsMixin, metaclass=NodeMeta): """ # Object that will be the reference for output that has not been returned. _blank = object() + # This is the signal to remove a kwarg from the inputs. + DELETE_KWARG = object() # Whether debugging messages should be issued _debug: bool = False _debug_show_inputs: bool = False @@ -368,22 +371,30 @@ def update_inputs(self, *args, **inputs): if not inputs and len(args) == 0: return + # If arg inputs are provided, just replace the whole args input if len(args) > 0: if self._args_inputs_key is None: raise ValueError(f"{self}.update_inputs does not support positional arguments, please provide values as keyword arguments.") inputs[self._args_inputs_key] = args + # If kwargs inputs are provided, add them to the previous input kwargs. if self._kwargs_inputs_key is not None: inputs[self._kwargs_inputs_key] = self._inputs.get(self._kwargs_inputs_key, {}).copy() - for k in inputs: + keys = list(inputs) + for k in keys: if k not in self._input_fields: - inputs[self._kwargs_inputs_key][k] = inputs.pop(k) + value = inputs.pop(k) + # This might be an actual value or a signal to delete a kwarg + if value is self.DELETE_KWARG: + inputs[self._kwargs_inputs_key].pop(k, None) + else: + inputs[self._kwargs_inputs_key][k] = value # Otherwise, update the inputs self._inputs.update(inputs) # Now, update all connections between nodes - self._update_connections(inputs) + self._update_connections(self._inputs) # Mark the node as outdated self._receive_outdated() @@ -399,18 +410,52 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): def __getitem__(self, key): return GetItemNode(data=self, key=key) - def _update_connections(self, updated_inputs={}): - for key, input in updated_inputs.items(): + def _update_connections(self, inputs): + + def _update(key, value): # Get the old connected node (if any) and tell them # that we are no longer using their input - old_connection = self._input_nodes.pop(key, None) + old_connection = self._input_nodes.get(key, None) + if old_connection is value: + # The input value has not been updated, no need to update any connections + return + if old_connection is not None: + self._input_nodes.pop(key) old_connection._receive_output_unlink(self) # If the new input is a node, create the connection - if isinstance(input, Node): - self._input_nodes[key] = input - input._receive_output_link(self) + if isinstance(value, Node): + self._input_nodes[key] = value + value._receive_output_link(self) + + previous_connections = list(self._input_nodes) + + for key, input in inputs.items(): + if key == self._args_inputs_key: + # Loop through all the current *args to update connections + for i, item in enumerate(input): + _update(f'{key}[{i}]', item) + # For indices higher than the current *args length, remove the connections. + # (this is because the previous *args might have been longer) + for k in previous_connections: + if k.startswith(f'{key}['): + if int(k[len(key)+1:-1]) > len(input): + _update(k, None) + elif key == self._kwargs_inputs_key: + current_kwargs = [] + # Loop through all the current **kwargs to update connections + for k, item in input.items(): + connection_key = f'{key}[{k}]' + current_kwargs.append(connection_key) + _update(connection_key, item) + # Remove connections for those keys that are no longer in the kwargs + for k in previous_connections: + if k.startswith(f'{key}[') and k not in current_kwargs: + _update(k, None) + else: + # This is the normal case, where the key is not either the *args or the **kwargs key. + _update(key, input) # We need to handle the case of the args_inputs key. diff --git a/sisl/viz/nodes/tests/test_lazynodes.py b/sisl/viz/nodes/tests/test_lazynodes.py index ae0edf7dea..79af9a7a3c 100644 --- a/sisl/viz/nodes/tests/test_lazynodes.py +++ b/sisl/viz/nodes/tests/test_lazynodes.py @@ -172,4 +172,146 @@ def reduce_(*nums, factor: int = 1): val2 = reduce_(val, 4, factor=1) assert val2.get() == 16 - assert val._nupdates == 1 \ No newline at end of file + assert val._nupdates == 1 + +@lazy_context(nodes=True) +def test_update_args(): + """When calling update_inputs with args, the old args should be completely + discarded and replaced by the ones provided. + """ + + @Node.from_func + def reduce_(*nums, factor: int = 1): + + val = 0 + for num in nums: + val += num + return val * factor + + val = reduce_(1, 2, 3, factor=2) + + assert val.get() == 12 + + val.update_inputs(4, 5, factor=3) + + assert val.get() == 27 + +@lazy_context(nodes=True) +def test_node_links_args(): + + @Node.from_func + def my_node(*some_args): + return some_args + + node1 = my_node() + node2 = my_node(4, 1, node1) + + # Check that node1 knows that node2 uses its output + assert len(node1._output_links) == 1 + assert node1._output_links[0] is node2 + + # And that node2 knows it's using node1 as an input. + assert len(node2._input_nodes) == 1 + assert 'some_args[2]' in node2._input_nodes + assert node2._input_nodes['some_args[2]'] is node1 + + # Now check that if we update node2, the connections + # will be removed. + node2.update_inputs(2) + + assert len(node2._input_nodes) == 0 + assert len(node1._output_links) == 0 + + # Check that connections are properly built when + # updating inputs with a value containing a node. + node2.update_inputs(node1) + + # Check that node1 knows that node2 uses its output + assert len(node1._output_links) == 1 + assert node1._output_links[0] is node2 + + # And that node2 knows it's using node1 as an input. + assert len(node2._input_nodes) == 1 + assert 'some_args[0]' in node2._input_nodes + assert node2._input_nodes['some_args[0]'] is node1 + +@lazy_context(nodes=True) +def test_kwargs(): + """Checks that functions with **kwargs are correctly handled by Node.""" + + @Node.from_func + def my_dict(**some_kwargs): + return some_kwargs + + val = my_dict(a=2, b=4) + + assert val.get() == {"a": 2, "b": 4} + + val2 = my_dict(old=val) + + assert val2.get() == {"old": {"a": 2, "b": 4}} + +@lazy_context(nodes=True) +def test_update_kwargs(): + """Checks that functions with **kwargs are correctly handled by Node.""" + + @Node.from_func + def my_dict(**some_kwargs): + return some_kwargs + + val = my_dict(a=2, b=4) + assert val.get() == {"a": 2, "b": 4} + + val.update_inputs(a=3) + assert val.get() == {"a": 3, "b": 4} + + val.update_inputs(c=5) + assert val.get() == {"a": 3, "b": 4, "c": 5} + + val.update_inputs(a=Node.DELETE_KWARG) + assert val.get() == {"b": 4, "c": 5} + +@lazy_context(nodes=True) +def test_node_links_kwargs(): + + @Node.from_func + def my_node(**some_kwargs): + return some_kwargs + + node1 = my_node() + node2 = my_node(a=node1) + + # Check that node1 knows its output is being used by + # node2 + assert len(node1._output_links) == 1 + assert node1._output_links[0] is node2 + + # And that node2 knows it's using node1 as an input. + assert len(node2._input_nodes) == 1 + assert 'some_kwargs[a]' in node2._input_nodes + assert node2._input_nodes['some_kwargs[a]'] is node1 + + # Test that kwargs that no longer exist are delinked. + + # Now check that if we update node2, the connections + # will be removed. + node2.update_inputs(a="other value") + + assert len(node1._output_links) == 0 + assert len(node2._input_nodes) == 0 + + # Finally check that connections are properly built when + # updating inputs with a value containing a node. + node3 = my_node() + node2.update_inputs(a=node3) + + # Check that node3 knows its output is being used by + # node2 + assert len(node3._output_links) == 1 + assert node3._output_links[0] is node2 + + # And that node2 knows it's using node3 as an input. + assert len(node2._input_nodes) == 1 + assert 'some_kwargs[a]' in node2._input_nodes + assert node2._input_nodes['some_kwargs[a]'] is node3 +