From a0cdce5cf6b314e6e1959163900ec185cd3319b6 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 14 Oct 2024 08:36:31 -0400 Subject: [PATCH 1/4] Add ability to export to dict --- pyproject.toml | 1 + src/tdastro/graph_state.py | 16 ++++++++++++++++ tests/tdastro/test_graph_state.py | 23 +++++++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4be87883..515a3f2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ requires-python = ">=3.9" dependencies = [ "astropy", "jax", + "nested-pandas", "numpy", "pandas", "scipy", diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index 17b68da7..ef8a88c5 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -320,6 +320,22 @@ def to_table(self): values[self.extended_param_name(node_name, param_name)] = np.array(param_value) return values + def to_dict(self): + """Flatten the graph state to a dictionary with columns for each parameter. + + The column names are: {node_name}{separator}{param_name} + + Returns + ------- + values : dict + The resulting dictionary. + """ + values = {} + for node_name, node_params in self.states.items(): + for param_name, param_value in node_params.items(): + values[self.extended_param_name(node_name, param_name)] = np.array(param_value) + return values + def save_to_file(self, filename, overwrite=False): """Save the GraphState to a file. diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 3344dff6..4070b998 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -266,6 +266,29 @@ def test_graph_state_from_table(): np.testing.assert_allclose(state["b"]["v1"], [7.0, 8.0, 9.0]) +def test_graph_state_to_dict(): + """Test that we can create a dictionary from a GraphState.""" + state = GraphState(num_samples=3) + state.set("a", "v1", [1.0, 2.0, 3.0]) + state.set("a", "v2", [3.0, 4.0, 5.0]) + state.set("b", "v1", [6.0, 7.0, 8.0]) + + result = state.to_dict() + assert len(result) == 3 + np.testing.assert_allclose( + result[GraphState.extended_param_name("a", "v1")].data, + [1.0, 2.0, 3.0], + ) + np.testing.assert_allclose( + result[GraphState.extended_param_name("a", "v2")].data, + [3.0, 4.0, 5.0], + ) + np.testing.assert_allclose( + result[GraphState.extended_param_name("b", "v1")].data, + [6.0, 7.0, 8.0], + ) + + def test_graph_state_to_table(): """Test that we can create an AstroPy Table from a GraphState.""" state = GraphState(num_samples=3) From 3c53b7341c18f5eaa322b5fadfdcabcca525ee49 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 14 Oct 2024 09:29:41 -0400 Subject: [PATCH 2/4] cleanups --- src/tdastro/base_models.py | 30 ++++++++++---- src/tdastro/graph_state.py | 28 +++++++++++-- tests/tdastro/sources/test_static_source.py | 6 +-- tests/tdastro/test_base_models.py | 44 +++++++++++---------- tests/tdastro/test_graph_state.py | 17 +++++--- 5 files changed, 86 insertions(+), 39 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 773321a9..67f9e2ab 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -206,10 +206,12 @@ def __str__(self): def _update_node_string(self, extra_tag=None): """Update the node's string.""" - pos_string = f"{self.node_pos}:" if self.node_pos is not None else "" if self.node_label is not None: - self.node_string = f"{pos_string}{self.node_label}" + # If a label is given, just use that. + self.node_string = self.node_label else: + # Otherwise use a combination of the node's class and position. + pos_string = f"{self.node_pos}:" if self.node_pos is not None else "" self.node_string = f"{pos_string}{self.__class__.__qualname__}" # Allow for the appending of an extra tag. @@ -224,6 +226,17 @@ def _update_node_string(self, extra_tag=None): for _, setter_info in self.setters.items(): setter_info.node_name = self.node_string + def set_node_label(self, new_label): + """Set the node's label field. + + Parameter + --------- + new_label : str or None + The new label for the node. + """ + self.node_label = new_label + self._update_node_string() + def set_graph_positions(self, seen_nodes=None): """Force an update of the graph structure (numbering of each node). @@ -248,8 +261,8 @@ def set_graph_positions(self, seen_nodes=None): for dep in self.direct_dependencies: dep.set_graph_positions(seen_nodes) - def get_param(self, graph_state, name): - """Get the value of a parameter stored in this node. + def get_param(self, graph_state, name, default=None): + """Get the value of a parameter stored in this node or a default value. Note ---- @@ -262,20 +275,23 @@ def get_param(self, graph_state, name): The dictionary of graph state information. name : `str` The parameter name to query. + default : any + The default value to return if the parameter is not in GraphState. Returns ------- any - The parameter value. + The parameter value or the default. Raises ------ - ``KeyError`` if this parameter has not be set. ``ValueError`` if graph_state is None. """ if graph_state is None: raise ValueError(f"Unable to look up parameter={name}. No graph_state given.") - return graph_state[self.node_string][name] + if self.node_string in graph_state and name in graph_state[self.node_string]: + return graph_state[self.node_string][name] + return default def get_local_params(self, graph_state): """Get a dictionary of all parameters local to this node. diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index ef8a88c5..78275f03 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -47,7 +47,7 @@ class GraphState: are fixed in this GraphState instance. """ - _NAME_SEPARATOR = "|>" + _NAME_SEPARATOR = "." def __init__(self, num_samples=1): if num_samples < 1: @@ -60,6 +60,17 @@ def __init__(self, num_samples=1): def __len__(self): return self.num_parameters + def __contains__(self, key): + if key in self.states: + return True + elif self._NAME_SEPARATOR in key: + tokens = key.split(self._NAME_SEPARATOR) + if len(tokens) != 2: + raise KeyError(f"Invalid GraphState key: {key}") + return True + else: + return False + def __str__(self): str_lines = [] for node_name, node_vars in self.states.items(): @@ -97,8 +108,17 @@ def __eq__(self, other): return True def __getitem__(self, key): - """Access the dictionary of parameter values for a node name.""" - return self.states[key] + """Access the dictionary of parameter values for a node name. Allows + access by both the pair of keys and the extended name.""" + if key in self.states: + return self.states[key] + elif self._NAME_SEPARATOR in key: + tokens = key.split(self._NAME_SEPARATOR) + if len(tokens) != 2: + raise KeyError(f"Invalid GraphState key: {key}") + return self.states[tokens[0]][tokens[1]] + else: + raise KeyError(f"Unknown GraphState key: {key}") @staticmethod def extended_param_name(node_name, param_name): @@ -333,7 +353,7 @@ def to_dict(self): values = {} for node_name, node_params in self.states.items(): for param_name, param_value in node_params.items(): - values[self.extended_param_name(node_name, param_name)] = np.array(param_value) + values[self.extended_param_name(node_name, param_name)] = list(param_value) return values def save_to_file(self, filename, overwrite=False): diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index 5d191f77..d0407d35 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -25,7 +25,7 @@ def test_static_source() -> None: assert model.get_param(state, "ra") is None assert model.get_param(state, "dec") is None assert model.get_param(state, "distance") is None - assert str(model) == "0:my_static_source" + assert str(model) == "my_static_source" times = np.array([1, 2, 3, 4, 5, 10]) wavelengths = np.array([100.0, 200.0, 300.0]) @@ -47,8 +47,8 @@ def test_test_physical_model_pytree(): state = model.sample_parameters() pytree = model.build_pytree(state) - assert pytree["0:my_static_source"]["brightness"] == 10.0 - assert len(pytree["0:my_static_source"]) == 1 + assert pytree["my_static_source"]["brightness"] == 10.0 + assert len(pytree["my_static_source"]) == 1 assert len(pytree) == 1 diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index e0f0ee31..fafc0b35 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -102,6 +102,15 @@ def test_parameterized_node(): assert model1.value2(state) == 0.5 assert model1.value_sum(state) == 1.0 + # We use the default for un-assigned parameters. + assert model1.get_param(state, "value3") is None + assert model1.get_param(state, "value4", 1.0) == 1.0 + + # If we set a position (and there is no node_label), the position shows up in the name. + model1.node_pos = 100 + model1._update_node_string() + assert str(model1) == "100:PairModel" + # Use value1=model.value and value2=1.0 model2 = PairModel(value1=model1.value1, value2=1.0, node_label="test") assert str(model2) == "test" @@ -111,11 +120,6 @@ def test_parameterized_node(): assert model2.get_param(state, "value2") == 1.0 assert model2.get_param(state, "value_sum") == 1.5 - # If we set an ID it shows up in the name. - model2.node_pos = 100 - model2._update_node_string() - assert str(model2) == "100:test" - # Compute value1 from model2's result and value2 from the sampler function. # The sampler function is auto-wrapped in a FunctionNode. model3 = PairModel(value1=model2.value_sum, value2=_sampler_fun) @@ -175,9 +179,9 @@ def test_parameterized_node_get_info(): # Get the node strings. node_strings = model3.get_all_node_info("node_string") assert len(node_strings) == 6 - assert "0:node3" in node_strings - assert "1:node1" in node_strings - assert "3:node2" in node_strings + assert "node3" in node_strings + assert "node1" in node_strings + assert "node2" in node_strings # Get the node hash values and check they are all unique. node_hashes = model3.get_all_node_info("node_hash") @@ -212,18 +216,18 @@ def test_parameterized_node_build_pytree(): graph_state = model2.sample_parameters() pytree = model2.build_pytree(graph_state) - assert pytree["1:A"]["value1"] == 0.5 - assert pytree["1:A"]["value2"] == 1.5 - assert pytree["0:B"]["value2"] == 3.0 + assert pytree["A"]["value1"] == 0.5 + assert pytree["A"]["value2"] == 1.5 + assert pytree["B"]["value2"] == 3.0 # Manually set value2 to allow_gradient to False and check that it no # longer appears in the pytree. model1.setters["value2"].allow_gradient = False pytree = model2.build_pytree(graph_state) - assert pytree["1:A"]["value1"] == 0.5 - assert pytree["0:B"]["value2"] == 3.0 - assert "value2" not in pytree["1:A"] + assert pytree["A"]["value1"] == 0.5 + assert pytree["B"]["value2"] == 3.0 + assert "value2" not in pytree["A"] # If we set node B's value1 to allow the gradient, it will appear and # neither of node A's value will appear (because the gradient stops at @@ -232,9 +236,9 @@ def test_parameterized_node_build_pytree(): model2.setters["value1"].allow_gradient = True pytree = model2.build_pytree(graph_state) - assert "1:A" not in pytree - assert pytree["0:B"]["value1"] == 0.5 - assert pytree["0:B"]["value2"] == 3.0 + assert "A" not in pytree + assert pytree["B"]["value1"] == 0.5 + assert pytree["B"]["value2"] == 3.0 def test_single_variable_node(): @@ -354,6 +358,6 @@ def _test_func2(value1, value2): values, gradients = gr_func(pytree) print(gradients) assert values == 9.0 - assert gradients["0:sum:_test_func"]["value1"] == 1.0 - assert gradients["1:div:_test_func2"]["value1"] == 2.0 - assert gradients["1:div:_test_func2"]["value2"] == -16.0 + assert gradients["sum:_test_func"]["value1"] == 1.0 + assert gradients["div:_test_func2"]["value1"] == 2.0 + assert gradients["div:_test_func2"]["value2"] == -16.0 diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 4070b998..b5245ed5 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -27,6 +27,13 @@ def test_create_single_sample_graph_state(): with pytest.raises(KeyError): _ = state["c"]["v1"] + # We can access the entries using the extended key name. + assert state[f"a{state._NAME_SEPARATOR}v1"] == 1.0 + assert state[f"a{state._NAME_SEPARATOR}v2"] == 2.0 + assert state[f"b{state._NAME_SEPARATOR}v1"] == 3.0 + with pytest.raises(KeyError): + _ = state[f"c{state._NAME_SEPARATOR}v1"] + # We can create a human readable string representation of the GraphState. debug_str = str(state) assert debug_str == "a:\n v1: 1.0\n v2: 2.0\nb:\n v1: 3.0" @@ -63,9 +70,9 @@ def test_create_single_sample_graph_state(): # Test we cannot use a name containing the separator as a substring. with pytest.raises(ValueError): - state.set("a|>b", "v1", 10.0) + state.set(f"a{state._NAME_SEPARATOR}b", "v1", 10.0) with pytest.raises(ValueError): - state.set("b", "v1|>v3", 10.0) + state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0) def test_create_multi_sample_graph_state(): @@ -276,15 +283,15 @@ def test_graph_state_to_dict(): result = state.to_dict() assert len(result) == 3 np.testing.assert_allclose( - result[GraphState.extended_param_name("a", "v1")].data, + result[GraphState.extended_param_name("a", "v1")], [1.0, 2.0, 3.0], ) np.testing.assert_allclose( - result[GraphState.extended_param_name("a", "v2")].data, + result[GraphState.extended_param_name("a", "v2")], [3.0, 4.0, 5.0], ) np.testing.assert_allclose( - result[GraphState.extended_param_name("b", "v1")].data, + result[GraphState.extended_param_name("b", "v1")], [6.0, 7.0, 8.0], ) From 829573e6efa7692855310f83966c72b2d0c32032 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 14 Oct 2024 09:34:59 -0400 Subject: [PATCH 3/4] Update pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 515a3f2b..4be87883 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ requires-python = ">=3.9" dependencies = [ "astropy", "jax", - "nested-pandas", "numpy", "pandas", "scipy", From c747ef8be418a8f01a418cc75ca785cc80ce15e4 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:17:46 -0400 Subject: [PATCH 4/4] Address PR comment --- src/tdastro/graph_state.py | 2 +- tests/tdastro/test_graph_state.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index 78275f03..ed99c03c 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -67,7 +67,7 @@ def __contains__(self, key): tokens = key.split(self._NAME_SEPARATOR) if len(tokens) != 2: raise KeyError(f"Invalid GraphState key: {key}") - return True + return tokens[0] in self.states and tokens[1] in self.states[tokens[0]] else: return False diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index b5245ed5..7df62f3b 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -75,6 +75,27 @@ def test_create_single_sample_graph_state(): state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0) +def test_graph_state_contains(): + """Test that we can use the 'in' operator in GraphState.""" + state = GraphState() + state.set("a", "v1", 1.0) + state.set("a", "v2", 2.0) + state.set("b", "v1", 3.0) + + assert "a" in state + assert "b" in state + assert "c" not in state + + assert f"a{state._NAME_SEPARATOR}v1" in state + assert f"a{state._NAME_SEPARATOR}v2" in state + assert f"a{state._NAME_SEPARATOR}v3" not in state + assert f"b{state._NAME_SEPARATOR}v1" in state + assert f"c{state._NAME_SEPARATOR}v1" not in state + + with pytest.raises(KeyError): + assert f"b{state._NAME_SEPARATOR}v1{state._NAME_SEPARATOR}v2" not in state + + def test_create_multi_sample_graph_state(): """Test that we can create and access a multi-sample GraphState.""" state = GraphState(5)