Skip to content

Commit

Permalink
Merge pull request #160 from lincc-frameworks/save_lightcurves
Browse files Browse the repository at this point in the history
Add a bunch of helper functions
  • Loading branch information
jeremykubica authored Oct 15, 2024
2 parents 7b38e91 + c747ef8 commit 81ed6b6
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 32 deletions.
30 changes: 23 additions & 7 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,12 @@ def __str__(self):

def _update_node_string(self, extra_tag=None):
"""Update the node's string."""
pos_string = f"{self.node_pos}:" if self.node_pos is not None else ""
if self.node_label is not None:
self.node_string = f"{pos_string}{self.node_label}"
# If a label is given, just use that.
self.node_string = self.node_label
else:
# Otherwise use a combination of the node's class and position.
pos_string = f"{self.node_pos}:" if self.node_pos is not None else ""
self.node_string = f"{pos_string}{self.__class__.__qualname__}"

# Allow for the appending of an extra tag.
Expand All @@ -215,6 +217,17 @@ def _update_node_string(self, extra_tag=None):
for _, setter_info in self.setters.items():
setter_info.node_name = self.node_string

def set_node_label(self, new_label):
"""Set the node's label field.
Parameter
---------
new_label : str or None
The new label for the node.
"""
self.node_label = new_label
self._update_node_string()

def set_graph_positions(self, seen_nodes=None):
"""Force an update of the graph structure (numbering of each node).
Expand All @@ -239,8 +252,8 @@ def set_graph_positions(self, seen_nodes=None):
for dep in self.direct_dependencies:
dep.set_graph_positions(seen_nodes)

def get_param(self, graph_state, name):
"""Get the value of a parameter stored in this node.
def get_param(self, graph_state, name, default=None):
"""Get the value of a parameter stored in this node or a default value.
Note
----
Expand All @@ -253,20 +266,23 @@ def get_param(self, graph_state, name):
The dictionary of graph state information.
name : `str`
The parameter name to query.
default : any
The default value to return if the parameter is not in GraphState.
Returns
-------
any
The parameter value.
The parameter value or the default.
Raises
------
``KeyError`` if this parameter has not be set.
``ValueError`` if graph_state is None.
"""
if graph_state is None:
raise ValueError(f"Unable to look up parameter={name}. No graph_state given.")
return graph_state[self.node_string][name]
if self.node_string in graph_state and name in graph_state[self.node_string]:
return graph_state[self.node_string][name]
return default

def get_local_params(self, graph_state):
"""Get a dictionary of all parameters local to this node.
Expand Down
42 changes: 39 additions & 3 deletions src/tdastro/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class GraphState:
are fixed in this GraphState instance.
"""

_NAME_SEPARATOR = "|>"
_NAME_SEPARATOR = "."

def __init__(self, num_samples=1):
if num_samples < 1:
Expand All @@ -60,6 +60,17 @@ def __init__(self, num_samples=1):
def __len__(self):
return self.num_parameters

def __contains__(self, key):
if key in self.states:
return True
elif self._NAME_SEPARATOR in key:
tokens = key.split(self._NAME_SEPARATOR)
if len(tokens) != 2:
raise KeyError(f"Invalid GraphState key: {key}")
return tokens[0] in self.states and tokens[1] in self.states[tokens[0]]
else:
return False

def __str__(self):
str_lines = []
for node_name, node_vars in self.states.items():
Expand Down Expand Up @@ -97,8 +108,17 @@ def __eq__(self, other):
return True

def __getitem__(self, key):
"""Access the dictionary of parameter values for a node name."""
return self.states[key]
"""Access the dictionary of parameter values for a node name. Allows
access by both the pair of keys and the extended name."""
if key in self.states:
return self.states[key]
elif self._NAME_SEPARATOR in key:
tokens = key.split(self._NAME_SEPARATOR)
if len(tokens) != 2:
raise KeyError(f"Invalid GraphState key: {key}")
return self.states[tokens[0]][tokens[1]]
else:
raise KeyError(f"Unknown GraphState key: {key}")

@staticmethod
def extended_param_name(node_name, param_name):
Expand Down Expand Up @@ -320,6 +340,22 @@ def to_table(self):
values[self.extended_param_name(node_name, param_name)] = np.array(param_value)
return values

def to_dict(self):
"""Flatten the graph state to a dictionary with columns for each parameter.
The column names are: {node_name}{separator}{param_name}
Returns
-------
values : dict
The resulting dictionary.
"""
values = {}
for node_name, node_params in self.states.items():
for param_name, param_value in node_params.items():
values[self.extended_param_name(node_name, param_name)] = list(param_value)
return values

def save_to_file(self, filename, overwrite=False):
"""Save the GraphState to a file.
Expand Down
6 changes: 3 additions & 3 deletions tests/tdastro/sources/test_static_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_static_source() -> None:
assert model.get_param(state, "ra") is None
assert model.get_param(state, "dec") is None
assert model.get_param(state, "distance") is None
assert str(model) == "0:my_static_source"
assert str(model) == "my_static_source"

times = np.array([1, 2, 3, 4, 5, 10])
wavelengths = np.array([100.0, 200.0, 300.0])
Expand All @@ -47,8 +47,8 @@ def test_test_physical_model_pytree():
state = model.sample_parameters()

pytree = model.build_pytree(state)
assert pytree["0:my_static_source"]["brightness"] == 10.0
assert len(pytree["0:my_static_source"]) == 1
assert pytree["my_static_source"]["brightness"] == 10.0
assert len(pytree["my_static_source"]) == 1
assert len(pytree) == 1


Expand Down
38 changes: 21 additions & 17 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def test_parameterized_node():
assert model1.value2(state) == 0.5
assert model1.value_sum(state) == 1.0

# We use the default for un-assigned parameters.
assert model1.get_param(state, "value3") is None
assert model1.get_param(state, "value4", 1.0) == 1.0

# If we set a position (and there is no node_label), the position shows up in the name.
model1.node_pos = 100
model1._update_node_string()
assert str(model1) == "100:PairModel"

# Use value1=model.value and value2=1.0
model2 = PairModel(value1=model1.value1, value2=1.0, node_label="test")
assert str(model2) == "test"
Expand All @@ -111,11 +120,6 @@ def test_parameterized_node():
assert model2.get_param(state, "value2") == 1.0
assert model2.get_param(state, "value_sum") == 1.5

# If we set an ID it shows up in the name.
model2.node_pos = 100
model2._update_node_string()
assert str(model2) == "100:test"

# Compute value1 from model2's result and value2 from the sampler function.
# The sampler function is auto-wrapped in a FunctionNode.
model3 = PairModel(value1=model2.value_sum, value2=_sampler_fun)
Expand Down Expand Up @@ -186,18 +190,18 @@ def test_parameterized_node_build_pytree():
graph_state = model2.sample_parameters()

pytree = model2.build_pytree(graph_state)
assert pytree["1:A"]["value1"] == 0.5
assert pytree["1:A"]["value2"] == 1.5
assert pytree["0:B"]["value2"] == 3.0
assert pytree["A"]["value1"] == 0.5
assert pytree["A"]["value2"] == 1.5
assert pytree["B"]["value2"] == 3.0

# Manually set value2 to allow_gradient to False and check that it no
# longer appears in the pytree.
model1.setters["value2"].allow_gradient = False

pytree = model2.build_pytree(graph_state)
assert pytree["1:A"]["value1"] == 0.5
assert pytree["0:B"]["value2"] == 3.0
assert "value2" not in pytree["1:A"]
assert pytree["A"]["value1"] == 0.5
assert pytree["B"]["value2"] == 3.0
assert "value2" not in pytree["A"]

# If we set node B's value1 to allow the gradient, it will appear and
# neither of node A's value will appear (because the gradient stops at
Expand All @@ -206,9 +210,9 @@ def test_parameterized_node_build_pytree():
model2.setters["value1"].allow_gradient = True

pytree = model2.build_pytree(graph_state)
assert "1:A" not in pytree
assert pytree["0:B"]["value1"] == 0.5
assert pytree["0:B"]["value2"] == 3.0
assert "A" not in pytree
assert pytree["B"]["value1"] == 0.5
assert pytree["B"]["value2"] == 3.0


def test_single_variable_node():
Expand Down Expand Up @@ -328,6 +332,6 @@ def _test_func2(value1, value2):
values, gradients = gr_func(pytree)
print(gradients)
assert values == 9.0
assert gradients["0:sum:_test_func"]["value1"] == 1.0
assert gradients["1:div:_test_func2"]["value1"] == 2.0
assert gradients["1:div:_test_func2"]["value2"] == -16.0
assert gradients["sum:_test_func"]["value1"] == 1.0
assert gradients["div:_test_func2"]["value1"] == 2.0
assert gradients["div:_test_func2"]["value2"] == -16.0
55 changes: 53 additions & 2 deletions tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def test_create_single_sample_graph_state():
with pytest.raises(KeyError):
_ = state["c"]["v1"]

# We can access the entries using the extended key name.
assert state[f"a{state._NAME_SEPARATOR}v1"] == 1.0
assert state[f"a{state._NAME_SEPARATOR}v2"] == 2.0
assert state[f"b{state._NAME_SEPARATOR}v1"] == 3.0
with pytest.raises(KeyError):
_ = state[f"c{state._NAME_SEPARATOR}v1"]

# We can create a human readable string representation of the GraphState.
debug_str = str(state)
assert debug_str == "a:\n v1: 1.0\n v2: 2.0\nb:\n v1: 3.0"
Expand Down Expand Up @@ -63,9 +70,30 @@ def test_create_single_sample_graph_state():

# Test we cannot use a name containing the separator as a substring.
with pytest.raises(ValueError):
state.set("a|>b", "v1", 10.0)
state.set(f"a{state._NAME_SEPARATOR}b", "v1", 10.0)
with pytest.raises(ValueError):
state.set("b", "v1|>v3", 10.0)
state.set("b", f"v1{state._NAME_SEPARATOR}v3", 10.0)


def test_graph_state_contains():
"""Test that we can use the 'in' operator in GraphState."""
state = GraphState()
state.set("a", "v1", 1.0)
state.set("a", "v2", 2.0)
state.set("b", "v1", 3.0)

assert "a" in state
assert "b" in state
assert "c" not in state

assert f"a{state._NAME_SEPARATOR}v1" in state
assert f"a{state._NAME_SEPARATOR}v2" in state
assert f"a{state._NAME_SEPARATOR}v3" not in state
assert f"b{state._NAME_SEPARATOR}v1" in state
assert f"c{state._NAME_SEPARATOR}v1" not in state

with pytest.raises(KeyError):
assert f"b{state._NAME_SEPARATOR}v1{state._NAME_SEPARATOR}v2" not in state


def test_create_multi_sample_graph_state():
Expand Down Expand Up @@ -266,6 +294,29 @@ def test_graph_state_from_table():
np.testing.assert_allclose(state["b"]["v1"], [7.0, 8.0, 9.0])


def test_graph_state_to_dict():
"""Test that we can create a dictionary from a GraphState."""
state = GraphState(num_samples=3)
state.set("a", "v1", [1.0, 2.0, 3.0])
state.set("a", "v2", [3.0, 4.0, 5.0])
state.set("b", "v1", [6.0, 7.0, 8.0])

result = state.to_dict()
assert len(result) == 3
np.testing.assert_allclose(
result[GraphState.extended_param_name("a", "v1")],
[1.0, 2.0, 3.0],
)
np.testing.assert_allclose(
result[GraphState.extended_param_name("a", "v2")],
[3.0, 4.0, 5.0],
)
np.testing.assert_allclose(
result[GraphState.extended_param_name("b", "v1")],
[6.0, 7.0, 8.0],
)


def test_graph_state_to_table():
"""Test that we can create an AstroPy Table from a GraphState."""
state = GraphState(num_samples=3)
Expand Down

0 comments on commit 81ed6b6

Please sign in to comment.