Skip to content

Commit

Permalink
Add ability to wrap logpdf functions as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 15, 2024
1 parent 9316867 commit eef873a
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 43 deletions.
87 changes: 68 additions & 19 deletions src/tdastro/rand_nodes/scipy_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,6 @@
from tdastro.graph_state import transpose_dict_of_list


class PDFFunctionWrapper:
"""A class that just wraps a given PDF function.
Attributes
----------
_pdf : function
The PDF function.
"""

def __init__(self, func):
self._pdf = func
self.pdf = self._pdf


class NumericalInversePolynomialFunc(FunctionNode):
"""A class for sampling from scipy's NumericalInversePolynomial
given a distribution function, an object with a pdf function,
Expand Down Expand Up @@ -58,12 +44,9 @@ class NumericalInversePolynomialFunc(FunctionNode):
def __init__(self, dist=None, seed=None, **kwargs):
# Check that the distribution object/class has a pdf or logpdf function
# or that we have provided a function directly.
if hasattr(dist, "pdf") and not hasattr(dist, "logpdf"):
self._dist = dist
elif callable(dist):
self._dist = PDFFunctionWrapper(dist)
else:
if not hasattr(dist, "pdf") and not hasattr(dist, "logpdf"):
raise ValueError("Distribution must have either pdf() or logpdf().")
self._dist = dist

# Classes show up as type="type". In this case we will need to create
# a concrete object from the class and any given parameters.
Expand Down Expand Up @@ -183,3 +166,69 @@ def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
"""
state = self.sample_parameters(given_args, num_samples, rng_info)
return self.compute(state, rng_info, **kwargs)


class PDFFunctionWrapper:
"""A class that just wraps a given PDF function.
Attributes
----------
_pdf : function
The PDF function.
"""

def __init__(self, func):
self._pdf = func
self.pdf = self._pdf


class SamplePDF(NumericalInversePolynomialFunc):
"""A node for sampling from a given PDF function.
Parameters
----------
dist : function, class, or object
The pdf function from which to sample or a class/object with that function.
"""

def __init__(self, dist, **kwargs):
if hasattr(dist, "pdf"):
self.dist_obj = dist
elif callable(dist):
self.dist_obj = PDFFunctionWrapper(dist)
else:
raise ValueError("No pdf function detected.")
super().__init__(self.dist_obj, **kwargs)


class LogPDFFunctionWrapper:
"""A class that just wraps a given Log PDF function.
Attributes
----------
_log_pdf : function
The log PDF function.
"""

def __init__(self, func):
self._log_pdf = func
self.logpdf = self._log_pdf


class SampleLogPDF(NumericalInversePolynomialFunc):
"""A node for sampling from a given PDF function.
Parameters
----------
dist : function, class, or object
The pdf function from which to sample or a class/object with that function.
"""

def __init__(self, dist, **kwargs):
if hasattr(dist, "logpdf"):
self.dist_obj = dist
elif callable(dist):
self.dist_obj = LogPDFFunctionWrapper(dist)
else:
raise ValueError("No logpdf function detected.")
super().__init__(self.dist_obj, **kwargs)
68 changes: 44 additions & 24 deletions tests/tdastro/rand_nodes/test_scipy_random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
from tdastro.rand_nodes.np_random import NumpyRandomFunc
from tdastro.rand_nodes.scipy_random import NumericalInversePolynomialFunc
from tdastro.rand_nodes.scipy_random import (
NumericalInversePolynomialFunc,
SampleLogPDF,
SamplePDF,
)


class FlatDist:
Expand Down Expand Up @@ -44,26 +48,6 @@ def pdf(self, xval):
return 1.0 / self.width


def triangle_pdf(x):
"""A pdf function with is a line from (0, 1.0) to (2.0, 0.0).
Parameters
----------
x : float
The value of x at which to evaluate.
Returns
-------
y : float
The pdf value at x.
"""
if x < 0.0:
return 0.0
if x > 2.0:
return 0.0
return 1.0 - 0.5 * x


def test_numerical_inverse_polynomial_func_object():
"""Test that we can generate numbers from a uniform distribution."""
dist = FlatDist(min_val=0.25, max_val=0.75)
Expand Down Expand Up @@ -155,9 +139,45 @@ def test_numerical_inverse_polynomial_func_class():
assert np.all(values <= max_vals)


def test_numerical_inverse_polynomial_func_function():
"""Test that we can create a NumericalInversePolynomialFunc from a function."""
scipy_node = NumericalInversePolynomialFunc(triangle_pdf, seed=100)
def test_numerical_sample_pdf():
"""Test that we can create a SamplePDF node from a function."""

def _triangle_pdf(x):
if x < 0.0:
return 0.0
if x > 2.0:
return 0.0
return 1.0 - 0.5 * x

scipy_node = SamplePDF(_triangle_pdf, seed=100)

# Test that we get distribution that ramps down as x increases.
num_samples = 50_000
counts = np.zeros(10)
for _ in range(num_samples):
value = scipy_node.generate()
assert 0.0 <= value <= 2.0

bin = int(10 * value / 2.0)
counts[bin] += 1

mid_heights = np.array([0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25, 0.15, 0.05])
expected = mid_heights * 0.2 * num_samples
for idx in range(10):
assert np.abs(counts[idx] - expected[idx]) < 200.0


def test_numerical_sample_logpdf():
"""Test that we can create a SampleLogPDF node from a function."""

def _triangle_logpdf(x):
if x < 0.0:
return -np.inf
if x > 2.0:
return -np.inf
return np.log(1.0 - 0.5 * x)

scipy_node = SampleLogPDF(_triangle_logpdf, seed=100)

# Test that we get distribution that ramps down as x increases.
num_samples = 50_000
Expand Down

0 comments on commit eef873a

Please sign in to comment.