Skip to content

Commit

Permalink
handling *args and **kwargs node links
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Jul 18, 2022
1 parent 4391711 commit 58f2dd9
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 10 deletions.
63 changes: 54 additions & 9 deletions sisl/viz/nodes/node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
from multiprocessing import connection
from typing import Any, Optional
from collections import ChainMap
from collections.abc import Mapping
Expand Down Expand Up @@ -77,6 +78,8 @@ class Node(NDArrayOperatorsMixin, metaclass=NodeMeta):
"""
# Object that will be the reference for output that has not been returned.
_blank = object()
# This is the signal to remove a kwarg from the inputs.
DELETE_KWARG = object()
# Whether debugging messages should be issued
_debug: bool = False
_debug_show_inputs: bool = False
Expand Down Expand Up @@ -368,22 +371,30 @@ def update_inputs(self, *args, **inputs):
if not inputs and len(args) == 0:
return

# If arg inputs are provided, just replace the whole args input
if len(args) > 0:
if self._args_inputs_key is None:
raise ValueError(f"{self}.update_inputs does not support positional arguments, please provide values as keyword arguments.")
inputs[self._args_inputs_key] = args

# If kwargs inputs are provided, add them to the previous input kwargs.
if self._kwargs_inputs_key is not None:
inputs[self._kwargs_inputs_key] = self._inputs.get(self._kwargs_inputs_key, {}).copy()
for k in inputs:
keys = list(inputs)
for k in keys:
if k not in self._input_fields:
inputs[self._kwargs_inputs_key][k] = inputs.pop(k)
value = inputs.pop(k)
# This might be an actual value or a signal to delete a kwarg
if value is self.DELETE_KWARG:
inputs[self._kwargs_inputs_key].pop(k, None)
else:
inputs[self._kwargs_inputs_key][k] = value

# Otherwise, update the inputs
self._inputs.update(inputs)

# Now, update all connections between nodes
self._update_connections(inputs)
self._update_connections(self._inputs)

# Mark the node as outdated
self._receive_outdated()
Expand All @@ -399,18 +410,52 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
def __getitem__(self, key):
return GetItemNode(data=self, key=key)

def _update_connections(self, updated_inputs={}):
for key, input in updated_inputs.items():
def _update_connections(self, inputs):

def _update(key, value):
# Get the old connected node (if any) and tell them
# that we are no longer using their input
old_connection = self._input_nodes.pop(key, None)
old_connection = self._input_nodes.get(key, None)
if old_connection is value:
# The input value has not been updated, no need to update any connections
return

if old_connection is not None:
self._input_nodes.pop(key)
old_connection._receive_output_unlink(self)

# If the new input is a node, create the connection
if isinstance(input, Node):
self._input_nodes[key] = input
input._receive_output_link(self)
if isinstance(value, Node):
self._input_nodes[key] = value
value._receive_output_link(self)

previous_connections = list(self._input_nodes)

for key, input in inputs.items():
if key == self._args_inputs_key:
# Loop through all the current *args to update connections
for i, item in enumerate(input):
_update(f'{key}[{i}]', item)
# For indices higher than the current *args length, remove the connections.
# (this is because the previous *args might have been longer)
for k in previous_connections:
if k.startswith(f'{key}['):
if int(k[len(key)+1:-1]) > len(input):
_update(k, None)
elif key == self._kwargs_inputs_key:
current_kwargs = []
# Loop through all the current **kwargs to update connections
for k, item in input.items():
connection_key = f'{key}[{k}]'
current_kwargs.append(connection_key)
_update(connection_key, item)
# Remove connections for those keys that are no longer in the kwargs
for k in previous_connections:
if k.startswith(f'{key}[') and k not in current_kwargs:
_update(k, None)
else:
# This is the normal case, where the key is not either the *args or the **kwargs key.
_update(key, input)

# We need to handle the case of the args_inputs key.

Expand Down
144 changes: 143 additions & 1 deletion sisl/viz/nodes/tests/test_lazynodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,146 @@ def reduce_(*nums, factor: int = 1):
val2 = reduce_(val, 4, factor=1)

assert val2.get() == 16
assert val._nupdates == 1
assert val._nupdates == 1

@lazy_context(nodes=True)
def test_update_args():
"""When calling update_inputs with args, the old args should be completely
discarded and replaced by the ones provided.
"""

@Node.from_func
def reduce_(*nums, factor: int = 1):

val = 0
for num in nums:
val += num
return val * factor

val = reduce_(1, 2, 3, factor=2)

assert val.get() == 12

val.update_inputs(4, 5, factor=3)

assert val.get() == 27

@lazy_context(nodes=True)
def test_node_links_args():

@Node.from_func
def my_node(*some_args):
return some_args

node1 = my_node()
node2 = my_node(4, 1, node1)

# Check that node1 knows that node2 uses its output
assert len(node1._output_links) == 1
assert node1._output_links[0] is node2

# And that node2 knows it's using node1 as an input.
assert len(node2._input_nodes) == 1
assert 'some_args[2]' in node2._input_nodes
assert node2._input_nodes['some_args[2]'] is node1

# Now check that if we update node2, the connections
# will be removed.
node2.update_inputs(2)

assert len(node2._input_nodes) == 0
assert len(node1._output_links) == 0

# Check that connections are properly built when
# updating inputs with a value containing a node.
node2.update_inputs(node1)

# Check that node1 knows that node2 uses its output
assert len(node1._output_links) == 1
assert node1._output_links[0] is node2

# And that node2 knows it's using node1 as an input.
assert len(node2._input_nodes) == 1
assert 'some_args[0]' in node2._input_nodes
assert node2._input_nodes['some_args[0]'] is node1

@lazy_context(nodes=True)
def test_kwargs():
"""Checks that functions with **kwargs are correctly handled by Node."""

@Node.from_func
def my_dict(**some_kwargs):
return some_kwargs

val = my_dict(a=2, b=4)

assert val.get() == {"a": 2, "b": 4}

val2 = my_dict(old=val)

assert val2.get() == {"old": {"a": 2, "b": 4}}

@lazy_context(nodes=True)
def test_update_kwargs():
"""Checks that functions with **kwargs are correctly handled by Node."""

@Node.from_func
def my_dict(**some_kwargs):
return some_kwargs

val = my_dict(a=2, b=4)
assert val.get() == {"a": 2, "b": 4}

val.update_inputs(a=3)
assert val.get() == {"a": 3, "b": 4}

val.update_inputs(c=5)
assert val.get() == {"a": 3, "b": 4, "c": 5}

val.update_inputs(a=Node.DELETE_KWARG)
assert val.get() == {"b": 4, "c": 5}

@lazy_context(nodes=True)
def test_node_links_kwargs():

@Node.from_func
def my_node(**some_kwargs):
return some_kwargs

node1 = my_node()
node2 = my_node(a=node1)

# Check that node1 knows its output is being used by
# node2
assert len(node1._output_links) == 1
assert node1._output_links[0] is node2

# And that node2 knows it's using node1 as an input.
assert len(node2._input_nodes) == 1
assert 'some_kwargs[a]' in node2._input_nodes
assert node2._input_nodes['some_kwargs[a]'] is node1

# Test that kwargs that no longer exist are delinked.

# Now check that if we update node2, the connections
# will be removed.
node2.update_inputs(a="other value")

assert len(node1._output_links) == 0
assert len(node2._input_nodes) == 0

# Finally check that connections are properly built when
# updating inputs with a value containing a node.
node3 = my_node()
node2.update_inputs(a=node3)

# Check that node3 knows its output is being used by
# node2
assert len(node3._output_links) == 1
assert node3._output_links[0] is node2

# And that node2 knows it's using node3 as an input.
assert len(node2._input_nodes) == 1
assert 'some_kwargs[a]' in node2._input_nodes
assert node2._input_nodes['some_kwargs[a]'] is node3

0 comments on commit 58f2dd9

Please sign in to comment.