diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 51cbedd8..d8024c5c 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -201,10 +201,12 @@ def __str__(self): def _update_node_string(self, extra_tag=None): """Update the node's string.""" - pos_string = f"{self.node_pos}:" if self.node_pos is not None else "" if self.node_label is not None: - self.node_string = f"{pos_string}{self.node_label}" + # If a label is given, just use that. + self.node_string = self.node_label else: + # Otherwise use a combination of the node's class and position. + pos_string = f"{self.node_pos}:" if self.node_pos is not None else "" self.node_string = f"{pos_string}{self.__class__.__qualname__}" # Allow for the appending of an extra tag. @@ -215,6 +217,17 @@ def _update_node_string(self, extra_tag=None): for _, setter_info in self.setters.items(): setter_info.node_name = self.node_string + def set_node_label(self, new_label): + """Set the node's label field. + + Parameter + --------- + new_label : str or None + The new label for the node. + """ + self.node_label = new_label + self._update_node_string() + def set_graph_positions(self, seen_nodes=None): """Force an update of the graph structure (numbering of each node). @@ -239,8 +252,8 @@ def set_graph_positions(self, seen_nodes=None): for dep in self.direct_dependencies: dep.set_graph_positions(seen_nodes) - def get_param(self, graph_state, name): - """Get the value of a parameter stored in this node. + def get_param(self, graph_state, name, default=None): + """Get the value of a parameter stored in this node or a default value. Note ---- @@ -253,20 +266,23 @@ def get_param(self, graph_state, name): The dictionary of graph state information. name : `str` The parameter name to query. + default : any + The default value to return if the parameter is not in GraphState. Returns ------- any - The parameter value. + The parameter value or the default. Raises ------ - ``KeyError`` if this parameter has not be set. ``ValueError`` if graph_state is None. """ if graph_state is None: raise ValueError(f"Unable to look up parameter={name}. No graph_state given.") - return graph_state[self.node_string][name] + if self.node_string in graph_state and name in graph_state[self.node_string]: + return graph_state[self.node_string][name] + return default def get_local_params(self, graph_state): """Get a dictionary of all parameters local to this node. diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index 17b68da7..ed99c03c 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -47,7 +47,7 @@ class GraphState: are fixed in this GraphState instance. """ - _NAME_SEPARATOR = "|>" + _NAME_SEPARATOR = "." def __init__(self, num_samples=1): if num_samples < 1: @@ -60,6 +60,17 @@ def __init__(self, num_samples=1): def __len__(self): return self.num_parameters + def __contains__(self, key): + if key in self.states: + return True + elif self._NAME_SEPARATOR in key: + tokens = key.split(self._NAME_SEPARATOR) + if len(tokens) != 2: + raise KeyError(f"Invalid GraphState key: {key}") + return tokens[0] in self.states and tokens[1] in self.states[tokens[0]] + else: + return False + def __str__(self): str_lines = [] for node_name, node_vars in self.states.items(): @@ -97,8 +108,17 @@ def __eq__(self, other): return True def __getitem__(self, key): - """Access the dictionary of parameter values for a node name.""" - return self.states[key] + """Access the dictionary of parameter values for a node name. Allows + access by both the pair of keys and the extended name.""" + if key in self.states: + return self.states[key] + elif self._NAME_SEPARATOR in key: + tokens = key.split(self._NAME_SEPARATOR) + if len(tokens) != 2: + raise KeyError(f"Invalid GraphState key: {key}") + return self.states[tokens[0]][tokens[1]] + else: + raise KeyError(f"Unknown GraphState key: {key}") @staticmethod def extended_param_name(node_name, param_name): @@ -320,6 +340,22 @@ def to_table(self): values[self.extended_param_name(node_name, param_name)] = np.array(param_value) return values + def to_dict(self): + """Flatten the graph state to a dictionary with columns for each parameter. + + The column names are: {node_name}{separator}{param_name} + + Returns + ------- + values : dict + The resulting dictionary. + """ + values = {} + for node_name, node_params in self.states.items(): + for param_name, param_value in node_params.items(): + values[self.extended_param_name(node_name, param_name)] = list(param_value) + return values + def save_to_file(self, filename, overwrite=False): """Save the GraphState to a file. diff --git a/tests/tdastro/sources/test_static_source.py b/tests/tdastro/sources/test_static_source.py index 5d191f77..d0407d35 100644 --- a/tests/tdastro/sources/test_static_source.py +++ b/tests/tdastro/sources/test_static_source.py @@ -25,7 +25,7 @@ def test_static_source() -> None: assert model.get_param(state, "ra") is None assert model.get_param(state, "dec") is None assert model.get_param(state, "distance") is None - assert str(model) == "0:my_static_source" + assert str(model) == "my_static_source" times = np.array([1, 2, 3, 4, 5, 10]) wavelengths = np.array([100.0, 200.0, 300.0]) @@ -47,8 +47,8 @@ def test_test_physical_model_pytree(): state = model.sample_parameters() pytree = model.build_pytree(state) - assert pytree["0:my_static_source"]["brightness"] == 10.0 - assert len(pytree["0:my_static_source"]) == 1 + assert pytree["my_static_source"]["brightness"] == 10.0 + assert len(pytree["my_static_source"]) == 1 assert len(pytree) == 1 diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 3eebfc14..76f88abf 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -102,6 +102,15 @@ def test_parameterized_node(): assert model1.value2(state) == 0.5 assert model1.value_sum(state) == 1.0 + # We use the default for un-assigned parameters. + assert model1.get_param(state, "value3") is None + assert model1.get_param(state, "value4", 1.0) == 1.0 + + # If we set a position (and there is no node_label), the position shows up in the name. + model1.node_pos = 100 + model1._update_node_string() + assert str(model1) == "100:PairModel" + # Use value1=model.value and value2=1.0 model2 = PairModel(value1=model1.value1, value2=1.0, node_label="test") assert str(model2) == "test" @@ -111,11 +120,6 @@ def test_parameterized_node(): assert model2.get_param(state, "value2") == 1.0 assert model2.get_param(state, "value_sum") == 1.5 - # If we set an ID it shows up in the name. - model2.node_pos = 100 - model2._update_node_string() - assert str(model2) == "100:test" - # Compute value1 from model2's result and value2 from the sampler function. # The sampler function is auto-wrapped in a FunctionNode. model3 = PairModel(value1=model2.value_sum, value2=_sampler_fun) @@ -186,18 +190,18 @@ def test_parameterized_node_build_pytree(): graph_state = model2.sample_parameters() pytree = model2.build_pytree(graph_state) - assert pytree["1:A"]["value1"] == 0.5 - assert pytree["1:A"]["value2"] == 1.5 - assert pytree["0:B"]["value2"] == 3.0 + assert pytree["A"]["value1"] == 0.5 + assert pytree["A"]["value2"] == 1.5 + assert pytree["B"]["value2"] == 3.0 # Manually set value2 to allow_gradient to False and check that it no # longer appears in the pytree. model1.setters["value2"].allow_gradient = False pytree = model2.build_pytree(graph_state) - assert pytree["1:A"]["value1"] == 0.5 - assert pytree["0:B"]["value2"] == 3.0 - assert "value2" not in pytree["1:A"] + assert pytree["A"]["value1"] == 0.5 + assert pytree["B"]["value2"] == 3.0 + assert "value2" not in pytree["A"] # If we set node B's value1 to allow the gradient, it will appear and # neither of node A's value will appear (because the gradient stops at @@ -206,9 +210,9 @@ def test_parameterized_node_build_pytree(): model2.setters["value1"].allow_gradient = True pytree = model2.build_pytree(graph_state) - assert "1:A" not in pytree - assert pytree["0:B"]["value1"] == 0.5 - assert pytree["0:B"]["value2"] == 3.0 + assert "A" not in pytree + assert pytree["B"]["value1"] == 0.5 + assert pytree["B"]["value2"] == 3.0 def test_single_variable_node(): @@ -328,6 +332,6 @@ def _test_func2(value1, value2): values, gradients = gr_func(pytree) print(gradients) assert values == 9.0 - assert gradients["0:sum:_test_func"]["value1"] == 1.0 - assert gradients["1:div:_test_func2"]["value1"] == 2.0 - assert gradients["1:div:_test_func2"]["value2"] == -16.0 + assert gradients["sum:_test_func"]["value1"] == 1.0 + assert gradients["div:_test_func2"]["value1"] == 2.0 + assert gradients["div:_test_func2"]["value2"] == -16.0 diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 3344dff6..7df62f3b 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -27,6 +27,13 @@ def test_create_single_sample_graph_state(): with pytest.raises(KeyError): _ = state["c"]["v1"] + # We can access the entries using the extended key name. + assert state[f"a{state._NAME_SEPARATOR}v1"] == 1.0 + assert state[f"a{state._NAME_SEPARATOR}v2"] == 2.0 + assert state[f"b{state._NAME_SEPARATOR}v1"] == 3.0 + with pytest.raises(KeyError): + _ = state[f"c{state._NAME_SEPARATOR}v1"] + # We can create a human readable string representation of the GraphState. debug_str = str(state) assert debug_str == "a:\n v1: 1.0\n v2: 2.0\nb:\n v1: 3.0" @@ -63,9 +70,30 @@ def test_create_single_sample_graph_state(): # Test we cannot use a name containing the separator as a substring. with pytest.raises(ValueError): - state.set("a|>b", "v1", 10.0) + state.set(f"a{state._NAME_SEPARATOR}b", "v1", 10.0) with pytest.raises(ValueError): - state.set("b", "v1|>v3", 10.0) + state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0) + + +def test_graph_state_contains(): + """Test that we can use the 'in' operator in GraphState.""" + state = GraphState() + state.set("a", "v1", 1.0) + state.set("a", "v2", 2.0) + state.set("b", "v1", 3.0) + + assert "a" in state + assert "b" in state + assert "c" not in state + + assert f"a{state._NAME_SEPARATOR}v1" in state + assert f"a{state._NAME_SEPARATOR}v2" in state + assert f"a{state._NAME_SEPARATOR}v3" not in state + assert f"b{state._NAME_SEPARATOR}v1" in state + assert f"c{state._NAME_SEPARATOR}v1" not in state + + with pytest.raises(KeyError): + assert f"b{state._NAME_SEPARATOR}v1{state._NAME_SEPARATOR}v2" not in state def test_create_multi_sample_graph_state(): @@ -266,6 +294,29 @@ def test_graph_state_from_table(): np.testing.assert_allclose(state["b"]["v1"], [7.0, 8.0, 9.0]) +def test_graph_state_to_dict(): + """Test that we can create a dictionary from a GraphState.""" + state = GraphState(num_samples=3) + state.set("a", "v1", [1.0, 2.0, 3.0]) + state.set("a", "v2", [3.0, 4.0, 5.0]) + state.set("b", "v1", [6.0, 7.0, 8.0]) + + result = state.to_dict() + assert len(result) == 3 + np.testing.assert_allclose( + result[GraphState.extended_param_name("a", "v1")], + [1.0, 2.0, 3.0], + ) + np.testing.assert_allclose( + result[GraphState.extended_param_name("a", "v2")], + [3.0, 4.0, 5.0], + ) + np.testing.assert_allclose( + result[GraphState.extended_param_name("b", "v1")], + [6.0, 7.0, 8.0], + ) + + def test_graph_state_to_table(): """Test that we can create an AstroPy Table from a GraphState.""" state = GraphState(num_samples=3)