diff --git a/docs/notebooks/opsim_notebook.ipynb b/docs/notebooks/opsim_notebook.ipynb index 418fbb3e..439db6c8 100644 --- a/docs/notebooks/opsim_notebook.ipynb +++ b/docs/notebooks/opsim_notebook.ipynb @@ -292,6 +292,34 @@ "%%timeit\n", "_ = opsim_data2.range_search(query_ra, query_dec, 0.5)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling\n", + "\n", + "We can sample RA, dec, and time from an `OpSim` object using the `OpSimRADECSampler` node. This will select a random observation from the OpSim and then create a random (RA, dec) from near the center of that observation.\n", + "\n", + "This allows us to generate fake query locations for testing that are compatible with arbitrary OpSims as opposed to generating a large number of fake locations and then filtering only those that match the opsim. The distribution will be weighted by the opsim's pointing. If it points at region A 90% of time time and region B 10% of the time, then 90% of the resulting points will be from region A and 10% from region B." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tdastro.math_nodes.ra_dec_sampler import OpSimRADECSampler\n", + "\n", + "sampler_node = OpSimRADECSampler(ops_data, in_order=True)\n", + "\n", + "# Test we can generate a single value.\n", + "(ra, dec, time) = sampler_node.generate(num_samples=10)\n", + "\n", + "for i in range(10):\n", + " print(f\"{i}: ({ra[i]}, {dec[i]}) at t={time[i]}\")" + ] } ], "metadata": { diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 0751368d..430d0a06 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -810,17 +810,33 @@ def compute(self, graph_state, rng_info=None, **kwargs): self._save_results(results, graph_state) return results - def resample_and_compute(self, given_args=None, rng_info=None): - """A helper function for JAX gradients that runs the sampling then computation. + def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): + """A helper function that regenerates the parameters for this nodes and the + ones above it, then returns the the output or this individual node. + + This is used both for testing and for computing JAX gradients. Parameters ---------- given_args : `dict`, optional A dictionary representing the given arguments for this sample run. This can be used as the JAX PyTree for differentiation. + num_samples : `int` + A count of the number of samples to compute. + Default: 1 rng_info : numpy.random._generator.Generator, optional A given numpy random number generator to use for this computation. If not provided, the function uses the node's random number generator. + **kwargs : `dict`, optional + Additional function arguments. """ - graph_state = self.sample_parameters(given_args, 1, rng_info) - return self.compute(graph_state, rng_info) + state = self.sample_parameters(given_args, num_samples, rng_info) + + # Get the result(s) of compute from the state object. + if len(self.outputs) == 1: + return self.get_param(state, self.outputs[0]) + + results = [] + for output_name in self.outputs: + results.append(self.get_param(state, output_name)) + return results diff --git a/src/tdastro/math_nodes/given_sampler.py b/src/tdastro/math_nodes/given_sampler.py index 6e8a1e0e..53317005 100644 --- a/src/tdastro/math_nodes/given_sampler.py +++ b/src/tdastro/math_nodes/given_sampler.py @@ -8,10 +8,11 @@ from astropy.table import Table from tdastro.base_models import FunctionNode +from tdastro.math_nodes.np_random import NumpyRandomFunc -class GivenSampler(FunctionNode): - """A FunctionNode that returns given results. +class GivenValueList(FunctionNode): + """A FunctionNode that returns given results for a single parameter. Attributes ---------- @@ -22,8 +23,11 @@ class GivenSampler(FunctionNode): """ def __init__(self, values, **kwargs): - self.values = np.array(values) + self.values = np.asarray(values) + if len(values) == 0: + raise ValueError("No values provided for GivenValueList") self.next_ind = 0 + super().__init__(self._non_func, **kwargs) def _non_func(self): @@ -73,42 +77,105 @@ def compute(self, graph_state, rng_info=None, **kwargs): return results -class TableSampler(FunctionNode): - """A FunctionNode that returns values from a table, including - a Pandas DataFrame or AstroPy Table. +class GivenValueSampler(NumpyRandomFunc): + """A FunctionNode that returns randomly selected items from a given list + with replacement. Attributes ---------- + values : list or numpy.ndarray + The values to select from. + _num_values : int + The number of values that can be sampled. + """ + + def __init__(self, values, seed=None, **kwargs): + self.values = np.asarray(values) + self._num_values = len(values) + if self._num_values == 0: + raise ValueError("No values provided for NumpySamplerNode") + + super().__init__("uniform", seed=seed, **kwargs) + + def compute(self, graph_state, rng_info=None, **kwargs): + """Return the given values. + + Parameters + ---------- + graph_state : `GraphState` + An object mapping graph parameters to their values. This object is modified + in place as it is sampled. + rng_info : numpy.random._generator.Generator, optional + A given numpy random number generator to use for this computation. If not + provided, the function uses the node's random number generator. + **kwargs : `dict`, optional + Additional function arguments. + + Returns + ------- + results : any + The result of the computation. This return value is provided so that testing + functions can easily access the results. + """ + rng = rng_info if rng_info is not None else self._rng + + if graph_state.num_samples == 1: + inds = rng.integers(0, self._num_values) + else: + inds = rng.integers(0, self._num_values, size=graph_state.num_samples) + + return self.values[inds] + + +class TableSampler(NumpyRandomFunc): + """A FunctionNode that returns values from a table-like data, + including a Pandas DataFrame or AstroPy Table. The results returned + can be in-order (for testing) or randomly selected with replacement. + + Parameters + ---------- data : pandas.DataFrame, astropy.table.Table, or dict The object containing the data to sample. - columns : list of str - The column names for the output columns. + in_order : bool + Return the given data in order of the rows (True). If False, performs + random sampling with replacement. Default: False + + Attributes + ---------- + data : astropy.table.Table + The object containing the data to sample. + in_order : bool + Return the given data in order of the rows (True). If False, performs + random sampling with replacement. Default: False next_ind : int - The next index to sample. + The next index to sample for in order sampling. + num_values : int + The total number of items from which to draw the data. """ - def __init__(self, data, node_label=None, **kwargs): + def __init__(self, data, in_order=False, **kwargs): self.next_ind = 0 + self.in_order = in_order if isinstance(data, dict): - self.data = pd.DataFrame(data) + self.data = Table(data) elif isinstance(data, Table): - self.data = data.to_pandas() - elif isinstance(data, pd.DataFrame): self.data = data.copy() + elif isinstance(data, pd.DataFrame): + self.data = Table.from_pandas(data) else: raise TypeError("Unsupported data type for TableSampler.") - # Add each of the flow's data columns as an output parameter. - self.columns = [x for x in self.data.columns] - super().__init__(self._non_func, node_label=node_label, outputs=self.columns, **kwargs) + # Check there are some rows. + self._num_values = len(self.data) + if self._num_values == 0: + raise ValueError("No data provided to TableSampler.") - def _non_func(self): - """This function does nothing. Everything happens in the overloaded compute().""" - pass + # Add each of the flow's data columns as an output parameter. + super().__init__("uniform", outputs=self.data.colnames, **kwargs) def reset(self): - """Reset the next index to use.""" + """Reset the next index to use. Only used for in-order sampling.""" self.next_ind = 0 def compute(self, graph_state, rng_info=None, **kwargs): @@ -120,8 +187,8 @@ def compute(self, graph_state, rng_info=None, **kwargs): An object mapping graph parameters to their values. This object is modified in place as it is sampled. rng_info : numpy.random._generator.Generator, optional - Unused in this function, but included to provide consistency with other - compute functions. + A given numpy random number generator to use for this computation. If not + provided, the function uses the node's random number generator. **kwargs : `dict`, optional Additional function arguments. @@ -131,24 +198,27 @@ def compute(self, graph_state, rng_info=None, **kwargs): The result of the computation. This return value is provided so that testing functions can easily access the results. """ - # Check that we have enough points left to sample. - end_index = self.next_ind + graph_state.num_samples - if end_index > len(self.data): - raise IndexError() + # Compute the indices to sample. + if self.in_order: + # Check that we have enough points left to sample. + end_index = self.next_ind + graph_state.num_samples + if end_index > len(self.data): + raise IndexError() - # Extract the table for [self.next_ind, end_index) and move - # the index counter. - samples = self.data[self.next_ind : end_index] - self.next_ind = end_index + sample_inds = np.arange(self.next_ind, end_index) + self.next_ind = end_index + else: + rng = rng_info if rng_info is not None else self._rng + sample_inds = rng.integers(0, self._num_values, size=graph_state.num_samples) - # Parse out each column in the flow samples as a result vector. + # Parse out each column into a separate parameter with the column name as its name. results = [] - for attr_name in self.columns: - attr_values = samples[attr_name].values + for attr_name in self.outputs: + attr_values = np.asarray(self.data[attr_name][sample_inds]) if graph_state.num_samples == 1: results.append(attr_values[0]) else: - results.append(np.array(attr_values)) + results.append(attr_values) # Save and return the results. self._save_results(results, graph_state) diff --git a/src/tdastro/math_nodes/np_random.py b/src/tdastro/math_nodes/np_random.py index 9b40d38b..b43876ff 100644 --- a/src/tdastro/math_nodes/np_random.py +++ b/src/tdastro/math_nodes/np_random.py @@ -103,23 +103,3 @@ def compute(self, graph_state, rng_info=None, **kwargs): results = self.func(**args, size=num_samples) self._save_results(results, graph_state) return results - - def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): - """A helper function for testing that regenerates the output. - - Parameters - ---------- - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. - num_samples : `int` - A count of the number of samples to compute. - Default: 1 - rng_info : numpy.random._generator.Generator, optional - A given numpy random number generator to use for this computation. If not - provided, the function uses the node's random number generator. - **kwargs : `dict`, optional - Additional function arguments. - """ - state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, rng_info, **kwargs) diff --git a/src/tdastro/math_nodes/ra_dec_sampler.py b/src/tdastro/math_nodes/ra_dec_sampler.py new file mode 100644 index 00000000..ba7cac1d --- /dev/null +++ b/src/tdastro/math_nodes/ra_dec_sampler.py @@ -0,0 +1,142 @@ +"""Samplers used for generating (RA, dec) coordinates.""" + +import numpy as np + +from tdastro.math_nodes.given_sampler import TableSampler +from tdastro.math_nodes.np_random import NumpyRandomFunc + + +class UniformRADEC(NumpyRandomFunc): + """A FunctionNode that uniformly samples (RA, dec) over a sphere, + + Attributes + ---------- + use_degrees : bool + The default return unit. If True returns samples in degrees. + Otherwise, if False, returns samples in radians. + """ + + def __init__(self, outputs=None, seed=None, use_degrees=True, **kwargs): + self.use_degrees = use_degrees + + # Override key arguments. We create a uniform sampler function, but + # won't need it because the subclass overloads compute(). + func_name = "uniform" + outputs = ["ra", "dec"] + super().__init__(func_name, outputs=outputs, seed=seed, **kwargs) + + def compute(self, graph_state, rng_info=None, **kwargs): + """Return the given values. + + Parameters + ---------- + graph_state : `GraphState` + An object mapping graph parameters to their values. This object is modified + in place as it is sampled. + rng_info : numpy.random._generator.Generator, optional + A given numpy random number generator to use for this computation. If not + provided, the function uses the node's random number generator. + **kwargs : `dict`, optional + Additional function arguments. + + Returns + ------- + results : any + The result of the computation. This return value is provided so that testing + functions can easily access the results. + """ + rng = rng_info if rng_info is not None else self._rng + + # Generate the random (RA, dec) lists. + ra = rng.uniform(0.0, 2.0 * np.pi, size=graph_state.num_samples) + dec = np.arcsin(rng.uniform(-1.0, 1.0, size=graph_state.num_samples)) + if self.use_degrees: + ra = np.degrees(ra) + dec = np.degrees(dec) + + # If we are generating a single sample, return floats. + if graph_state.num_samples == 1: + ra = ra[0] + dec = dec[0] + + # Set the outputs and return the results. This takes the place of + # function node's _save_results() function because we know the outputs. + graph_state.set(self.node_string, "ra", ra) + graph_state.set(self.node_string, "dec", dec) + return [ra, dec] + + +class OpSimRADECSampler(TableSampler): + """A FunctionNode that samples RA and dec (and time) from an OpSim. + RA and dec are returned in degrees. + + Note + ---- + Does not currently use uniform sampling from the radius. Uses a very + rough approximate as a proof of concept. Do not use for statistical analysis. + + Parameters + ---------- + data : OpSim + The OpSim object to use for sampling. + radius : float + The radius of the observations in degrees. Use 0.0 to just sample + the centers of the images. Default: 0.0 + in_order : bool + Return the given data in order of the rows (True). If False, performs + random sampling with replacement. Default: False + """ + + def __init__(self, data, radius=0.0, in_order=False, **kwargs): + if radius < 0.0: + raise ValueError("Invalid radius: {radius}") + self.radius = radius + + data_dict = { + "ra": data["ra"], + "dec": data["dec"], + "time": data["time"], + } + super().__init__(data_dict, in_order=in_order, **kwargs) + + def compute(self, graph_state, rng_info=None, **kwargs): + """Return the given values. + + Parameters + ---------- + graph_state : `GraphState` + An object mapping graph parameters to their values. This object is modified + in place as it is sampled. + rng_info : numpy.random._generator.Generator, optional + A given numpy random number generator to use for this computation. If not + provided, the function uses the node's random number generator. + **kwargs : `dict`, optional + Additional function arguments. + + Returns + ------- + results : any + The result of the computation. This return value is provided so that testing + functions can easily access the results. + """ + # Sample the center RA, dec, and times without the radius. + results = super().compute(graph_state, rng_info=rng_info, **kwargs) + + if self.radius > 0.0: + # Add an offset around the center. This is currently a placeholder that does + # NOT produce a uniform sampling. TODO: Make this uniform sampling. + rng = rng_info if rng_info is not None else self._rng + + # Choose a uniform circle around the center point. Not that this is not uniform over + # the final RA, dec because it does not account for compression in dec around the polls. + offset_amt = self.radius * np.sqrt(rng.uniform(0.0, 1.0, size=graph_state.num_samples)) + offset_ang = 2.0 * np.pi * rng.uniform(0.0, 1.0, size=graph_state.num_samples) + + # Add the offsets to RA and dec. Keep time unchanged. + results[0] += offset_amt * np.cos(offset_ang) # RA + results[1] += offset_amt * np.sin(offset_ang) # dec + + # Resave the results (overwriting the previous results) + self._save_results(results, graph_state) + + return results diff --git a/src/tdastro/math_nodes/scipy_random.py b/src/tdastro/math_nodes/scipy_random.py index 642f231c..8564d4c5 100644 --- a/src/tdastro/math_nodes/scipy_random.py +++ b/src/tdastro/math_nodes/scipy_random.py @@ -147,26 +147,6 @@ def compute(self, graph_state, rng_info=None, **kwargs): self._save_results(results, graph_state) return results - def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs): - """A helper function for testing that regenerates the output. - - Parameters - ---------- - given_args : `dict`, optional - A dictionary representing the given arguments for this sample run. - This can be used as the JAX PyTree for differentiation. - num_samples : `int` - A count of the number of samples to compute. - Default: 1 - rng_info : numpy.random._generator.Generator, optional - A given numpy random number generator to use for this computation. If not - provided, the function uses the node's random number generator. - **kwargs : `dict`, optional - Additional function arguments. - """ - state = self.sample_parameters(given_args, num_samples, rng_info) - return self.compute(state, rng_info, **kwargs) - class PDFFunctionWrapper: """A class that just wraps a given PDF function. diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py index c753a0aa..67663d82 100644 --- a/tests/tdastro/math_nodes/test_basic_math_node.py +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -147,7 +147,7 @@ def test_basic_math_node_autodiff_jax(): state = node.sample_parameters() pytree = node.build_pytree(state) - gr_func = jax.value_and_grad(node.resample_and_compute) + gr_func = jax.value_and_grad(node.generate) values, gradients = gr_func(pytree) assert values == 2.0 assert gradients["a_node"]["a"] > 0.0 diff --git a/tests/tdastro/math_nodes/test_given_sampler.py b/tests/tdastro/math_nodes/test_given_sampler.py index 45b40b4d..f9c51865 100644 --- a/tests/tdastro/math_nodes/test_given_sampler.py +++ b/tests/tdastro/math_nodes/test_given_sampler.py @@ -4,7 +4,7 @@ from astropy.table import Table from tdastro.base_models import FunctionNode from tdastro.graph_state import GraphState -from tdastro.math_nodes.given_sampler import GivenSampler, TableSampler +from tdastro.math_nodes.given_sampler import GivenValueList, GivenValueSampler, TableSampler def _test_func(value1, value2): @@ -20,9 +20,9 @@ def _test_func(value1, value2): return value1 + value2 -def test_given_sampler(): - """Test that we can retrieve numbers from a GivenSampler.""" - given_node = GivenSampler([1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5]) +def test_given_value_list(): + """Test that we can retrieve numbers from a GivenValueList.""" + given_node = GivenValueList([1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5]) # Check that we generate the correct result and save it in the GraphState. state1 = GraphState(num_samples=2) @@ -40,12 +40,12 @@ def test_given_sampler(): assert np.array_equal(results, [2.5, 3.0]) assert np.array_equal(given_node.get_param(state3, "function_node_result"), [2.5, 3.0]) - # Check that GivenSampler raises an error when it has run out of samples. + # Check that GivenValueList raises an error when it has run out of samples. state4 = GraphState(num_samples=4) with pytest.raises(IndexError): _ = given_node.compute(state4) - # Resetting the GivenSampler starts back at the beginning. + # Resetting the GivenValueList starts back at the beginning. given_node.reset() state5 = GraphState(num_samples=6) results = given_node.compute(state5) @@ -56,10 +56,10 @@ def test_given_sampler(): ) -def test_test_given_sampler_compound(): - """Test that we can use the GivenSampler as input into another node.""" +def test_test_given_value_list_compound(): + """Test that we can use the GivenValueList as input into another node.""" values = [1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5, 4.0, 10.0, -2.0] - given_node = GivenSampler(values) + given_node = GivenValueList(values) # Create a function node that takes the next value and adds 2.0. compound_node = FunctionNode(_test_func, value1=given_node, value2=2.0) @@ -76,6 +76,21 @@ def test_test_given_sampler_compound(): ) +def test_given_value_sampler(): + """Test that we can retrieve numbers from a GivenValueSampler.""" + given_node = GivenValueSampler([1, 3, 5, 7]) + + # Check that we have sampled uniformly from the given options. + state = GraphState(num_samples=5_000) + results = given_node.compute(state) + assert len(results) == 5_000 + assert np.all((results == 1) | (results == 3) | (results == 5) | (results == 7)) + assert len(results[results == 1]) > 1000 + assert len(results[results == 3]) > 1000 + assert len(results[results == 5]) > 1000 + assert len(results[results == 7]) > 1000 + + @pytest.mark.parametrize("test_data_type", ["dict", "ap_table", "pd_df"]) def test_table_sampler(test_data_type): """Test that we can retrieve numbers from a TableSampler from a @@ -97,7 +112,7 @@ def test_table_sampler(test_data_type): data = None # Create the table sampler from the data. - table_node = TableSampler(data, node_label="node") + table_node = TableSampler(data, in_order=True, node_label="node") state = table_node.sample_parameters(num_samples=2) assert len(state) == 3 assert np.allclose(state["node"]["A"], [1, 2]) @@ -127,3 +142,35 @@ def test_table_sampler(test_data_type): assert np.allclose(state["node"]["A"], [1, 2]) assert np.allclose(state["node"]["B"], [1, 1]) assert np.allclose(state["node"]["C"], [3, 4]) + + +def test_table_sampler_ranndomized(): + """Test that we can retrieve numbers from a TableSampler.""" + raw_data_dict = { + "A": [1, 3, 5], + "B": [2, 4, 6], + } + + # Create the table sampler from the data. + table_node = TableSampler(raw_data_dict, node_label="node") + state = table_node.sample_parameters(num_samples=2000) + assert len(state) == 2 + + # We have sampled the a_vals roughly uniformly from the three options. + a_vals = state["node"]["A"] + assert len(a_vals) == 2000 + assert np.all((a_vals == 1) | (a_vals == 3) | (a_vals == 5)) + assert len(a_vals[a_vals == 1]) > 500 + assert len(a_vals[a_vals == 3]) > 500 + assert len(a_vals[a_vals == 5]) > 500 + + # We have sampled the b_vals roughly uniformly from the three options. + b_vals = state["node"]["B"] + assert len(b_vals) == 2000 + assert np.all((b_vals == 2) | (b_vals == 4) | (b_vals == 6)) + assert len(b_vals[b_vals == 2]) > 500 + assert len(b_vals[b_vals == 4]) > 500 + assert len(b_vals[b_vals == 6]) > 500 + + # We always sample consistent ROWS of a and b. + assert np.all(b_vals - a_vals == 1) diff --git a/tests/tdastro/math_nodes/test_np_random.py b/tests/tdastro/math_nodes/test_np_random.py index 98f11536..7ee91e29 100644 --- a/tests/tdastro/math_nodes/test_np_random.py +++ b/tests/tdastro/math_nodes/test_np_random.py @@ -36,13 +36,6 @@ def test_numpy_random_uniform(): assert np.all(values >= 10.0) assert np.abs(np.mean(values) - 15.0) < 0.5 - # We can override the range dynamically. - values = np.array([np_node3.generate(low=2.0) for _ in range(10_000)]) - assert len(np.unique(values)) > 10 - assert np.all(values <= 20.0) - assert np.all(values >= 2.0) - assert np.abs(np.mean(values) - 11.0) < 0.5 - def test_numpy_random_uniform_multi(): """Test that we can many generate numbers at once from a uniform distribution.""" diff --git a/tests/tdastro/math_nodes/test_ra_dec_sampler.py b/tests/tdastro/math_nodes/test_ra_dec_sampler.py new file mode 100644 index 00000000..f46bcb44 --- /dev/null +++ b/tests/tdastro/math_nodes/test_ra_dec_sampler.py @@ -0,0 +1,109 @@ +import numpy as np +from tdastro.astro_utils.opsim import OpSim +from tdastro.math_nodes.ra_dec_sampler import OpSimRADECSampler, UniformRADEC + + +def test_uniform_ra_dec(): + """Test that we can generate numbers from a uniform distribution on a sphere.""" + sampler_node = UniformRADEC(seed=100, node_label="sampler") + + # Test we can generate a single value. + (ra, dec) = sampler_node.generate(num_samples=1) + assert 0.0 <= ra <= 360.0 + assert -90.0 <= dec <= 90.0 + + # Generate many samples. + num_samples = 20_000 + state = sampler_node.sample_parameters(num_samples=num_samples) + + all_ra = state["sampler"]["ra"] + assert len(all_ra) == num_samples + assert np.all(all_ra >= 0.0) + assert np.all(all_ra <= 360.0) + + all_dec = state["sampler"]["dec"] + assert len(all_dec) == num_samples + assert np.all(all_dec >= -90.0) + assert np.all(all_dec <= 90.0) + + # Compute histograms of RA and dec values. + ra_bins = np.zeros(36) + dec_bins = np.zeros(18) + for idx in range(num_samples): + ra_bins[int(all_ra[idx] / 10.0)] += 1 + dec_bins[int((all_dec[idx] + 90.0) / 10.0)] += 1 + + # Check that all RA bins have approximately equal samples. + expected_count = num_samples / 36 + for bin_count in ra_bins: + assert 0.8 <= bin_count / expected_count <= 1.2 + + # Check that the dec bins around the poles have less samples + # than the bins around the equator. + assert dec_bins[0] < 0.25 * dec_bins[9] + assert dec_bins[17] < 0.25 * dec_bins[10] + + # Check that we can generate uniform samples in radians. + sampler_node2 = UniformRADEC(seed=100, node_label="sampler2", use_degrees=False) + state2 = sampler_node2.sample_parameters(num_samples=num_samples) + + all_ra = state2["sampler2"]["ra"] + assert len(all_ra) == num_samples + assert np.all(all_ra >= 0.0) + assert np.all(all_ra <= 2.0 * np.pi) + + all_dec = state2["sampler2"]["dec"] + assert len(all_dec) == num_samples + assert np.all(all_dec >= -np.pi) + assert np.all(all_dec <= np.pi) + + +def test_opsim_ra_dec_sampler(): + """Test that we can sample from am OpSim object.""" + values = { + "observationStartMJD": np.array([0.0, 1.0, 2.0, 3.0, 4.0]), + "fieldRA": np.array([15.0, 30.0, 15.0, 0.0, 60.0]), + "fieldDec": np.array([-10.0, -5.0, 0.0, 5.0, 10.0]), + "zp_nJy": np.ones(5), + } + ops_data = OpSim(values) + assert len(ops_data) == 5 + + sampler_node = OpSimRADECSampler(ops_data, in_order=True) + + # Test we can generate a single value. + (ra, dec, time) = sampler_node.generate(num_samples=1) + assert ra == 15.0 + assert dec == -10.0 + assert time == 0.0 + + # Test we can generate multiple observations + (ra, dec, time) = sampler_node.generate(num_samples=2) + assert np.allclose(ra, [30.0, 15.0]) + assert np.allclose(dec, [-5.0, 0.0]) + assert np.allclose(time, [1.0, 2.0]) + + # Do randomized sampling. + sampler_node2 = OpSimRADECSampler(ops_data, in_order=False, seed=100, node_label="sampler") + state = sampler_node2.sample_parameters(num_samples=5000) + + # Check that the samples are uniform and consistent. + int_times = state["sampler"]["time"].astype(int) + assert np.allclose(state["sampler"]["ra"], values["fieldRA"][int_times]) + assert np.allclose(state["sampler"]["dec"], values["fieldDec"][int_times]) + assert len(int_times[int_times == 0]) > 750 + assert len(int_times[int_times == 1]) > 750 + assert len(int_times[int_times == 2]) > 750 + assert len(int_times[int_times == 3]) > 750 + assert len(int_times[int_times == 4]) > 750 + + # Do randomized sampling with offsets. + sampler_node3 = OpSimRADECSampler(ops_data, in_order=False, seed=100, radius=0.1, node_label="sampler") + state = sampler_node3.sample_parameters(num_samples=5000) + + # Check that the samples are not all the centers (unique values > 5) but are close. + int_times = state["sampler"]["time"].astype(int) + assert len(np.unique(state["sampler"]["ra"])) > 5 + assert len(np.unique(state["sampler"]["dec"])) > 5 + assert np.allclose(state["sampler"]["ra"], values["fieldRA"][int_times], atol=0.2) + assert np.allclose(state["sampler"]["dec"], values["fieldDec"][int_times], atol=0.2) diff --git a/tests/tdastro/sources/test_physical_models.py b/tests/tdastro/sources/test_physical_models.py index e6235981..a23b0f08 100644 --- a/tests/tdastro/sources/test_physical_models.py +++ b/tests/tdastro/sources/test_physical_models.py @@ -2,7 +2,7 @@ import pytest from astropy.cosmology import Planck18 from tdastro.astro_utils.passbands import PassbandGroup -from tdastro.math_nodes.given_sampler import GivenSampler +from tdastro.math_nodes.given_sampler import GivenValueList from tdastro.sources.physical_model import PhysicalModel from tdastro.sources.static_source import StaticSource @@ -60,7 +60,7 @@ def test_physical_model_evaluate(): """Test that we can evaluate a PhysicalModel.""" times = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) waves = np.array([4000.0, 5000.0]) - brightness = GivenSampler([10.0, 20.0, 30.0]) + brightness = GivenValueList([10.0, 20.0, 30.0]) static_source = StaticSource(brightness=brightness) # Providing no state should give a single sample. @@ -110,7 +110,7 @@ def test_physical_model_get_band_fluxes(passbands_dir): # If we use multiple samples, we should get a correctly sized array. n_samples = 21 brightness_list = [1.5 * i for i in range(n_samples)] - static_source2 = StaticSource(brightness=GivenSampler(brightness_list)) + static_source2 = StaticSource(brightness=GivenValueList(brightness_list)) state2 = static_source2.sample_parameters(num_samples=n_samples) band_fluxes2 = static_source2.get_band_fluxes(passbands, times, filters, state2) assert band_fluxes2.shape == (n_samples, n_passbands) diff --git a/tests/tdastro/test_base_models.py b/tests/tdastro/test_base_models.py index 6e2e5ea1..9f966755 100644 --- a/tests/tdastro/test_base_models.py +++ b/tests/tdastro/test_base_models.py @@ -241,6 +241,9 @@ def test_function_node_basic(): assert my_func.compute(state, value2=3.0, value1=1.0) == 4.0 assert str(my_func) == "FunctionNode:_test_func_0" + # We can also compute this result (for testing) by calling generate(). + assert my_func.generate() == 3.0 + def test_function_node_chain(): """Test that we can create and query a chained FunctionNode.""" @@ -333,7 +336,7 @@ def _test_func2(value1, value2): graph_state = sum_node.sample_parameters() pytree = sum_node.build_pytree(graph_state) - gr_func = jax.value_and_grad(sum_node.resample_and_compute) + gr_func = jax.value_and_grad(sum_node.generate) values, gradients = gr_func(pytree) assert values == 9.0 assert gradients["sum"]["value1"] == 1.0