From 9316867ab8d29f4f7e2d6681583e7d93a352669e Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Tue, 15 Oct 2024 12:05:55 -0400
Subject: [PATCH] Allow NumericalInversePolynomialFunc to take a function

---
 src/tdastro/rand_nodes/scipy_random.py        | 37 +++++++++++++----
 tests/tdastro/rand_nodes/test_scipy_random.py | 40 +++++++++++++++++++
 2 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/src/tdastro/rand_nodes/scipy_random.py b/src/tdastro/rand_nodes/scipy_random.py
index 4b633ec2..59d95804 100644
--- a/src/tdastro/rand_nodes/scipy_random.py
+++ b/src/tdastro/rand_nodes/scipy_random.py
@@ -9,15 +9,29 @@
 from tdastro.graph_state import transpose_dict_of_list
 
 
+class PDFFunctionWrapper:
+    """A minimal adapter that exposes a bare PDF function as a ``pdf`` attribute.
+
+    Attributes
+    ----------
+    _pdf : function
+        The wrapped PDF function; also exposed publicly as ``pdf``.
+    """
+
+    def __init__(self, func):
+        self._pdf = func
+        self.pdf = self._pdf
+
+
 class NumericalInversePolynomialFunc(FunctionNode):
     """A class for sampling from scipy's NumericalInversePolynomial
-    given an distribution object or a class from which to create
-    such an object.
+    given a distribution function, an object with a pdf function,
+    or a class from which to create such an object.
 
     Note
     ----
     If a class is provided, then the sampling function will create a new
-    object (with the sampled parameters) for each sampling.
+    object (with the sampled parameters) for each sampling. This is very expensive.
 
     Attributes
     ----------
@@ -42,12 +56,17 @@ class NumericalInversePolynomialFunc(FunctionNode):
     """
 
     def __init__(self, dist=None, seed=None, **kwargs):
-        # Check that the distribution object/class has a pdf or logpdf function.
-        if not hasattr(dist, "pdf") and not hasattr(dist, "logpdf"):
+        # Check that the distribution object/class has a pdf or logpdf function
+        # or that we have provided a function directly.
+        if hasattr(dist, "pdf") or hasattr(dist, "logpdf"):
+            self._dist = dist
+        elif callable(dist):
+            self._dist = PDFFunctionWrapper(dist)
+        else:
             raise ValueError("Distribution must have either pdf() or logpdf().")
-        self._dist = dist
 
-        # Classes show up as type="type"
+        # Classes show up as type="type". In this case we will need to create
+        # a concrete object from the class and any given parameters.
         if isinstance(dist, type):
             self._inv_poly = None
             self._vect_sample = np.vectorize(self._create_and_sample)
@@ -79,7 +98,9 @@ def set_seed(self, new_seed):
         self._rng = np.random.default_rng(seed=new_seed)
 
     def _create_and_sample(self, args, rng):
-        """Create the distribution function and sample it.
+        """Create the distribution function and sample it. This is only
+        needed if our distribution is in the form of a class that must
+        be instantiated with different parameters each sampling run.
 
         Parameters
         ----------
diff --git a/tests/tdastro/rand_nodes/test_scipy_random.py b/tests/tdastro/rand_nodes/test_scipy_random.py
index 3cc0400c..7ab69172 100644
--- a/tests/tdastro/rand_nodes/test_scipy_random.py
+++ b/tests/tdastro/rand_nodes/test_scipy_random.py
@@ -44,6 +44,26 @@ def pdf(self, xval):
         return 1.0 / self.width
 
 
+def triangle_pdf(x):
+    """A pdf function which is a line from (0, 1.0) to (2.0, 0.0).
+
+    Parameters
+    ----------
+    x : float
+        The value of x at which to evaluate.
+
+    Returns
+    -------
+    y : float
+        The pdf value at x.
+    """
+    if x < 0.0:
+        return 0.0
+    if x > 2.0:
+        return 0.0
+    return 1.0 - 0.5 * x
+
+
 def test_numerical_inverse_polynomial_func_object():
     """Test that we can generate numbers from a uniform distribution."""
     dist = FlatDist(min_val=0.25, max_val=0.75)
@@ -133,3 +153,23 @@ def test_numerical_inverse_polynomial_func_class():
     # Test that the generated values are consistent with the distributions.
     assert np.all(values >= min_vals)
     assert np.all(values <= max_vals)
+
+
+def test_numerical_inverse_polynomial_func_function():
+    """Test that we can create a NumericalInversePolynomialFunc from a function."""
+    scipy_node = NumericalInversePolynomialFunc(triangle_pdf, seed=100)
+
+    # Test that we get distribution that ramps down as x increases.
+    num_samples = 50_000
+    counts = np.zeros(10)
+    for _ in range(num_samples):
+        value = scipy_node.generate()
+        assert 0.0 <= value <= 2.0
+
+        bin_idx = min(int(10 * value / 2.0), 9)  # value == 2.0 goes in the last bin.
+        counts[bin_idx] += 1
+
+    mid_heights = np.array([0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25, 0.15, 0.05])
+    expected = mid_heights * 0.2 * num_samples
+    for idx in range(10):
+        assert np.abs(counts[idx] - expected[idx]) < 200.0