Skip to content

Commit

Permalink
Allow NumericalInversePolynomialFunc to take a function
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 15, 2024
1 parent 7b38e91 commit 9316867
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
37 changes: 29 additions & 8 deletions src/tdastro/rand_nodes/scipy_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,29 @@
from tdastro.graph_state import transpose_dict_of_list


class PDFFunctionWrapper:
"""A class that just wraps a given PDF function.
Attributes
----------
_pdf : function
The PDF function.
"""

def __init__(self, func):
self._pdf = func
self.pdf = self._pdf


class NumericalInversePolynomialFunc(FunctionNode):
"""A class for sampling from scipy's NumericalInversePolynomial
given an distribution object or a class from which to create
such an object.
given a distribution function, an object with a pdf function,
or a class from which to create such an object.
Note
----
If a class is provided, then the sampling function will create a new
object (with the sampled parameters) for each sampling.
object (with the sampled parameters) for each sampling. This is very expensive.
Attributes
----------
Expand All @@ -42,12 +56,17 @@ class NumericalInversePolynomialFunc(FunctionNode):
"""

def __init__(self, dist=None, seed=None, **kwargs):
# Check that the distribution object/class has a pdf or logpdf function.
if not hasattr(dist, "pdf") and not hasattr(dist, "logpdf"):
# Check that the distribution object/class has a pdf or logpdf function
# or that we have provided a function directly.
if hasattr(dist, "pdf") and not hasattr(dist, "logpdf"):
self._dist = dist
elif callable(dist):
self._dist = PDFFunctionWrapper(dist)
else:
raise ValueError("Distribution must have either pdf() or logpdf().")
self._dist = dist

# Classes show up as type="type"
# Classes show up as type="type". In this case we will need to create
# a concrete object from the class and any given parameters.
if isinstance(dist, type):
self._inv_poly = None
self._vect_sample = np.vectorize(self._create_and_sample)
Expand Down Expand Up @@ -79,7 +98,9 @@ def set_seed(self, new_seed):
self._rng = np.random.default_rng(seed=new_seed)

def _create_and_sample(self, args, rng):
"""Create the distribution function and sample it.
"""Create the distribution function and sample it. This is only
needed if our distribution is in the form of a class that must
be instantiated with different parameters each sampling run.
Parameters
----------
Expand Down
40 changes: 40 additions & 0 deletions tests/tdastro/rand_nodes/test_scipy_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ def pdf(self, xval):
return 1.0 / self.width


def triangle_pdf(x):
"""A pdf function with is a line from (0, 1.0) to (2.0, 0.0).
Parameters
----------
x : float
The value of x at which to evaluate.
Returns
-------
y : float
The pdf value at x.
"""
if x < 0.0:
return 0.0
if x > 2.0:
return 0.0
return 1.0 - 0.5 * x


def test_numerical_inverse_polynomial_func_object():
"""Test that we can generate numbers from a uniform distribution."""
dist = FlatDist(min_val=0.25, max_val=0.75)
Expand Down Expand Up @@ -133,3 +153,23 @@ def test_numerical_inverse_polynomial_func_class():
# Test that the generated values are consistent with the distributions.
assert np.all(values >= min_vals)
assert np.all(values <= max_vals)


def test_numerical_inverse_polynomial_func_function():
"""Test that we can create a NumericalInversePolynomialFunc from a function."""
scipy_node = NumericalInversePolynomialFunc(triangle_pdf, seed=100)

# Test that we get distribution that ramps down as x increases.
num_samples = 50_000
counts = np.zeros(10)
for _ in range(num_samples):
value = scipy_node.generate()
assert 0.0 <= value <= 2.0

bin = int(10 * value / 2.0)
counts[bin] += 1

mid_heights = np.array([0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25, 0.15, 0.05])
expected = mid_heights * 0.2 * num_samples
for idx in range(10):
assert np.abs(counts[idx] - expected[idx]) < 200.0

0 comments on commit 9316867

Please sign in to comment.