Skip to content

Commit

Permalink
Merge pull request #162 from lincc-frameworks/vectorize_distmod
Browse files Browse the repository at this point in the history
Vectorize DistModFromRedshift
  • Loading branch information
jeremykubica authored Oct 15, 2024
2 parents 81ed6b6 + d58fca5 commit a452fd6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 28 deletions.
46 changes: 19 additions & 27 deletions src/tdastro/astro_utils/snia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,30 +176,6 @@ def _x0_from_distmod(distmod, x1, c, alpha, beta, m_abs):
return x0


def _distmod_from_redshift(redshift, H0=73.0, Omega_m=0.3):
"""Compute distance modulus given redshift and cosmology.
Parameters
----------
redshift : `float`
The redshift value.
H0: `float`
The Hubble constant.
Omega_m: `float`
The matter density.
Returns
-------
distmod : `float`
The distance modulus (in mag)
"""

cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)
distmod = cosmo.distmod(redshift).value

return distmod


class HostmassX1Func(NumericalInversePolynomialFunc):
"""A class for sampling from the HostmassX1Distr.
Expand Down Expand Up @@ -320,11 +296,27 @@ class DistModFromRedshift(FunctionNode):
"""

def __init__(self, redshift, H0=73.0, Omega_m=0.3, **kwargs):
# Create the cosmology once for this node.
self.cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)

# Call the super class's constructor with the needed information.
super().__init__(
func=_distmod_from_redshift,
func=self._distmod_from_redshift,
redshift=redshift,
H0=H0,
Omega_m=Omega_m,
**kwargs,
)

def _distmod_from_redshift(self, redshift):
"""Compute distance modulus given redshift and cosmology.
Parameters
----------
redshift : `float` or `numpy.ndarray`
The redshift value(s).
Returns
-------
distmod : `float` or `numpy.ndarray`
The distance modulus (in mag)
"""
return self.cosmo.distmod(redshift).value
14 changes: 13 additions & 1 deletion tests/tdastro/astro_utils/test_snia_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
from scipy.stats import norm
from tdastro.astro_utils.snia_utils import HostmassX1Distr, HostmassX1Func
from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func
from tdastro.rand_nodes.np_random import NumpyRandomFunc


Expand Down Expand Up @@ -95,3 +96,14 @@ def test_sample_hostmass_x1c():
hm_node1.get_param(states1, "hostmass"),
hm_node3.get_param(states3, "hostmass"),
)


def test_dist_mod_from_redshift():
"""Test the computation of dist_mod from the redshift."""
redshifts = [0.01, 0.02, 0.05, 0.5]
expected = [33.08419428, 34.60580484, 36.64346629, 42.17006132]

for idx, z in enumerate(redshifts):
node = DistModFromRedshift(redshift=z, H0=73.0, Omega_m=0.3)
state = node.sample_parameters(num_samples=1)
assert node.get_param(state, "function_node_result") == pytest.approx(expected[idx])

0 comments on commit a452fd6

Please sign in to comment.