From d58fca58de1aa941c3bdd2a242618734837df497 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:21:35 -0400 Subject: [PATCH] Vectorize DistModFromRedshift --- src/tdastro/astro_utils/snia_utils.py | 46 ++++++++------------ tests/tdastro/astro_utils/test_snia_utils.py | 14 +++++- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/tdastro/astro_utils/snia_utils.py b/src/tdastro/astro_utils/snia_utils.py index 7afc219c..a921ff76 100644 --- a/src/tdastro/astro_utils/snia_utils.py +++ b/src/tdastro/astro_utils/snia_utils.py @@ -176,30 +176,6 @@ def _x0_from_distmod(distmod, x1, c, alpha, beta, m_abs): return x0 -def _distmod_from_redshift(redshift, H0=73.0, Omega_m=0.3): - """Compute distance modulus given redshift and cosmology. - - Parameters - ---------- - redshift : `float` - The redshift value. - H0: `float` - The Hubble constant. - Omega_m: `float` - The matter density. - - Returns - ------- - distmod : `float` - The distance modulus (in mag) - """ - - cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m) - distmod = cosmo.distmod(redshift).value - - return distmod - - class HostmassX1Func(NumericalInversePolynomialFunc): """A class for sampling from the HostmassX1Distr. @@ -320,11 +296,27 @@ class DistModFromRedshift(FunctionNode): """ def __init__(self, redshift, H0=73.0, Omega_m=0.3, **kwargs): + # Create the cosmology once for this node. + self.cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m) + # Call the super class's constructor with the needed information. super().__init__( - func=_distmod_from_redshift, + func=self._distmod_from_redshift, redshift=redshift, - H0=H0, - Omega_m=Omega_m, **kwargs, ) + + def _distmod_from_redshift(self, redshift): + """Compute distance modulus given redshift and cosmology. + + Parameters + ---------- + redshift : `float` or `numpy.ndarray` + The redshift value(s). + + Returns + ------- + distmod : `float` or `numpy.ndarray` + The distance modulus (in mag) + """ + return self.cosmo.distmod(redshift).value diff --git a/tests/tdastro/astro_utils/test_snia_utils.py b/tests/tdastro/astro_utils/test_snia_utils.py index e595c5c9..c1c377f2 100644 --- a/tests/tdastro/astro_utils/test_snia_utils.py +++ b/tests/tdastro/astro_utils/test_snia_utils.py @@ -1,6 +1,7 @@ import numpy as np +import pytest from scipy.stats import norm -from tdastro.astro_utils.snia_utils import HostmassX1Distr, HostmassX1Func +from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Distr, HostmassX1Func from tdastro.rand_nodes.np_random import NumpyRandomFunc @@ -95,3 +96,14 @@ def test_sample_hostmass_x1c(): hm_node1.get_param(states1, "hostmass"), hm_node3.get_param(states3, "hostmass"), ) + + +def test_dist_mod_from_redshift(): + """Test the computation of dist_mod from the redshift.""" + redshifts = [0.01, 0.02, 0.05, 0.5] + expected = [33.08419428, 34.60580484, 36.64346629, 42.17006132] + + for idx, z in enumerate(redshifts): + node = DistModFromRedshift(redshift=z, H0=73.0, Omega_m=0.3) + state = node.sample_parameters(num_samples=1) + assert node.get_param(state, "function_node_result") == pytest.approx(expected[idx])