diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 0af0fe42..ad91a56a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,4 +1,3 @@ -import jax.lax import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -261,10 +260,10 @@ def __init__( ) ) - self.get_iota = lambda x: jax.lax.cond( - "iota" in conditional_names, - lambda _: x["iota"], - lambda _: spin_to_iota( + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -276,9 +275,7 @@ def __init__( x["q"], self.freq_ref, 0.0, - ), - operand=None, - ) + ) @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -371,10 +368,10 @@ def __init__( and "M_c" in conditional_names ) - self.get_iota = lambda x: jax.lax.cond( - "iota" in conditional_names, - lambda _: x["iota"], - lambda _: spin_to_iota( + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -386,9 +383,7 @@ def __init__( x["q"], self.freq_ref, 0.0, - ), - operand=None, - ) + ) @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota):