Skip to content

Commit

Permalink
Merge pull request #207 from lincc-frameworks/random_functions
Browse files Browse the repository at this point in the history
Extend random nodes for testing and simulation
  • Loading branch information
jeremykubica authored Jan 2, 2025
2 parents 6eaa1ca + 06f2ca2 commit 80b2b6e
Show file tree
Hide file tree
Showing 12 changed files with 468 additions and 100 deletions.
28 changes: 28 additions & 0 deletions docs/notebooks/opsim_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,34 @@
"%%timeit\n",
"_ = opsim_data2.range_search(query_ra, query_dec, 0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sampling\n",
"\n",
"We can sample RA, dec, and time from an `OpSim` object using the `OpSimRADECSampler` node. This will select a random observation from the OpSim and then create a random (RA, dec) from near the center of that observation.\n",
"\n",
"This allows us to generate fake query locations for testing that are compatible with arbitrary OpSims as opposed to generating a large number of fake locations and then filtering only those that match the opsim. The distribution will be weighted by the opsim's pointing. If it points at region A 90% of the time and region B 10% of the time, then 90% of the resulting points will be from region A and 10% from region B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tdastro.math_nodes.ra_dec_sampler import OpSimRADECSampler\n",
"\n",
"sampler_node = OpSimRADECSampler(ops_data, in_order=True)\n",
"\n",
"# Test that we can generate multiple values at once.\n",
"(ra, dec, time) = sampler_node.generate(num_samples=10)\n",
"\n",
"for i in range(10):\n",
" print(f\"{i}: ({ra[i]}, {dec[i]}) at t={time[i]}\")"
]
}
],
"metadata": {
Expand Down
24 changes: 20 additions & 4 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,17 +777,33 @@ def compute(self, graph_state, rng_info=None, **kwargs):
self._save_results(results, graph_state)
return results

def resample_and_compute(self, given_args=None, rng_info=None):
"""A helper function for JAX gradients that runs the sampling then computation.
def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
"""A helper function that regenerates the parameters for this node and the
ones above it, then returns the output of this individual node.
This is used both for testing and for computing JAX gradients.
Parameters
----------
given_args : `dict`, optional
A dictionary representing the given arguments for this sample run.
This can be used as the JAX PyTree for differentiation.
num_samples : `int`
A count of the number of samples to compute.
Default: 1
rng_info : numpy.random._generator.Generator, optional
A given numpy random number generator to use for this computation. If not
provided, the function uses the node's random number generator.
**kwargs : `dict`, optional
Additional function arguments.
"""
graph_state = self.sample_parameters(given_args, 1, rng_info)
return self.compute(graph_state, rng_info)
state = self.sample_parameters(given_args, num_samples, rng_info)

# Get the result(s) of compute from the state object.
if len(self.outputs) == 1:
return self.get_param(state, self.outputs[0])

results = []
for output_name in self.outputs:
results.append(self.get_param(state, output_name))
return results
138 changes: 104 additions & 34 deletions src/tdastro/math_nodes/given_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from astropy.table import Table

from tdastro.base_models import FunctionNode
from tdastro.math_nodes.np_random import NumpyRandomFunc


class GivenSampler(FunctionNode):
"""A FunctionNode that returns given results.
class GivenValueList(FunctionNode):
"""A FunctionNode that returns given results for a single parameter.
Attributes
----------
Expand All @@ -22,8 +23,11 @@ class GivenSampler(FunctionNode):
"""

def __init__(self, values, **kwargs):
self.values = np.array(values)
self.values = np.asarray(values)
if len(values) == 0:
raise ValueError("No values provided for GivenValueList")
self.next_ind = 0

super().__init__(self._non_func, **kwargs)

def _non_func(self):
Expand Down Expand Up @@ -73,42 +77,105 @@ def compute(self, graph_state, rng_info=None, **kwargs):
return results


class TableSampler(FunctionNode):
"""A FunctionNode that returns values from a table, including
a Pandas DataFrame or AstroPy Table.
class GivenValueSampler(NumpyRandomFunc):
    """A FunctionNode that returns randomly selected items from a given list
    with replacement.

    Parameters
    ----------
    values : list or numpy.ndarray
        The values to select from. Must contain at least one entry.
    seed : int, optional
        Forwarded unchanged to the ``NumpyRandomFunc`` constructor
        (presumably used to seed the node's random number generator).
    **kwargs : `dict`, optional
        Additional arguments forwarded to the ``NumpyRandomFunc`` constructor.

    Attributes
    ----------
    values : numpy.ndarray
        The values to select from.
    _num_values : int
        The number of values that can be sampled.

    Raises
    ------
    ValueError
        If ``values`` is empty.
    """

    def __init__(self, values, seed=None, **kwargs):
        self.values = np.asarray(values)

        # Measure the converted array (the object that compute() indexes into)
        # rather than the raw argument.
        self._num_values = len(self.values)
        if self._num_values == 0:
            raise ValueError("No values provided for GivenValueSampler")

        # The parent constructor requires a function name; compute() below is
        # overridden and draws integer indices directly from the generator.
        super().__init__("uniform", seed=seed, **kwargs)

    def compute(self, graph_state, rng_info=None, **kwargs):
        """Return values selected at random (with replacement) from ``self.values``.

        Parameters
        ----------
        graph_state : `GraphState`
            An object mapping graph parameters to their values. This object is modified
            in place as it is sampled.
        rng_info : numpy.random._generator.Generator, optional
            A given numpy random number generator to use for this computation. If not
            provided, the function uses the node's random number generator.
        **kwargs : `dict`, optional
            Additional function arguments.

        Returns
        -------
        results : any
            The result of the computation. A single element of ``values`` when
            ``graph_state.num_samples`` is 1, otherwise an array with one selected
            element per sample. This return value is provided so that testing
            functions can easily access the results.
        """
        rng = rng_info if rng_info is not None else self._rng

        if graph_state.num_samples == 1:
            # No ``size`` argument: draw one scalar index so the result is a
            # single item rather than a length-1 array.
            inds = rng.integers(0, self._num_values)
        else:
            inds = rng.integers(0, self._num_values, size=graph_state.num_samples)

        results = self.values[inds]

        # Record the sampled values in the graph state, as the other compute()
        # implementations do, so they can be read back from the state.
        self._save_results(results, graph_state)
        return results


class TableSampler(NumpyRandomFunc):
"""A FunctionNode that returns values from table-like data,
including a Pandas DataFrame or AstroPy Table. The results returned
can be in-order (for testing) or randomly selected with replacement.
Parameters
----------
data : pandas.DataFrame, astropy.table.Table, or dict
The object containing the data to sample.
columns : list of str
The column names for the output columns.
in_order : bool
Return the given data in order of the rows (True). If False, performs
random sampling with replacement. Default: False
Attributes
----------
data : astropy.table.Table
The object containing the data to sample.
in_order : bool
Return the given data in order of the rows (True). If False, performs
random sampling with replacement. Default: False
next_ind : int
The next index to sample.
The next index to sample for in-order sampling.
num_values : int
The total number of items from which to draw the data.
"""

def __init__(self, data, node_label=None, **kwargs):
def __init__(self, data, in_order=False, **kwargs):
self.next_ind = 0
self.in_order = in_order

if isinstance(data, dict):
self.data = pd.DataFrame(data)
self.data = Table(data)
elif isinstance(data, Table):
self.data = data.to_pandas()
elif isinstance(data, pd.DataFrame):
self.data = data.copy()
elif isinstance(data, pd.DataFrame):
self.data = Table.from_pandas(data)
else:
raise TypeError("Unsupported data type for TableSampler.")

# Add each of the flow's data columns as an output parameter.
self.columns = [x for x in self.data.columns]
super().__init__(self._non_func, node_label=node_label, outputs=self.columns, **kwargs)
# Check there are some rows.
self._num_values = len(self.data)
if self._num_values == 0:
raise ValueError("No data provided to TableSampler.")

def _non_func(self):
"""This function does nothing. Everything happens in the overloaded compute()."""
pass
# Add each of the flow's data columns as an output parameter.
super().__init__("uniform", outputs=self.data.colnames, **kwargs)

def reset(self):
"""Reset the next index to use."""
"""Reset the next index to use. Only used for in-order sampling."""
self.next_ind = 0

def compute(self, graph_state, rng_info=None, **kwargs):
Expand All @@ -120,8 +187,8 @@ def compute(self, graph_state, rng_info=None, **kwargs):
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
rng_info : numpy.random._generator.Generator, optional
Unused in this function, but included to provide consistency with other
compute functions.
A given numpy random number generator to use for this computation. If not
provided, the function uses the node's random number generator.
**kwargs : `dict`, optional
Additional function arguments.
Expand All @@ -131,24 +198,27 @@ def compute(self, graph_state, rng_info=None, **kwargs):
The result of the computation. This return value is provided so that testing
functions can easily access the results.
"""
# Check that we have enough points left to sample.
end_index = self.next_ind + graph_state.num_samples
if end_index > len(self.data):
raise IndexError()
# Compute the indices to sample.
if self.in_order:
# Check that we have enough points left to sample.
end_index = self.next_ind + graph_state.num_samples
if end_index > len(self.data):
raise IndexError()

# Extract the table for [self.next_ind, end_index) and move
# the index counter.
samples = self.data[self.next_ind : end_index]
self.next_ind = end_index
sample_inds = np.arange(self.next_ind, end_index)
self.next_ind = end_index
else:
rng = rng_info if rng_info is not None else self._rng
sample_inds = rng.integers(0, self._num_values, size=graph_state.num_samples)

# Parse out each column in the flow samples as a result vector.
# Parse out each column into a separate parameter with the column name as its name.
results = []
for attr_name in self.columns:
attr_values = samples[attr_name].values
for attr_name in self.outputs:
attr_values = np.asarray(self.data[attr_name][sample_inds])
if graph_state.num_samples == 1:
results.append(attr_values[0])
else:
results.append(np.array(attr_values))
results.append(attr_values)

# Save and return the results.
self._save_results(results, graph_state)
Expand Down
20 changes: 0 additions & 20 deletions src/tdastro/math_nodes/np_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,3 @@ def compute(self, graph_state, rng_info=None, **kwargs):
results = self.func(**args, size=num_samples)
self._save_results(results, graph_state)
return results

def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
    """Resample this node's parameters and recompute its output (a testing aid).

    Parameters
    ----------
    given_args : `dict`, optional
        A dictionary representing the given arguments for this sample run.
        This can be used as the JAX PyTree for differentiation.
    num_samples : `int`
        A count of the number of samples to compute.
        Default: 1
    rng_info : numpy.random._generator.Generator, optional
        A given numpy random number generator to use for this computation. If not
        provided, the function uses the node's random number generator.
    **kwargs : `dict`, optional
        Additional function arguments, passed through to ``compute``.

    Returns
    -------
    The value returned by ``self.compute`` for the freshly sampled state.
    """
    # First draw a new state, then evaluate this node against it.
    fresh_state = self.sample_parameters(given_args, num_samples, rng_info)
    output = self.compute(fresh_state, rng_info, **kwargs)
    return output
Loading

0 comments on commit 80b2b6e

Please sign in to comment.