Skip to content

Commit

Permalink
Remove if-else function
Browse files Browse the repository at this point in the history
  • Loading branch information
tsunhopang committed Sep 5, 2024
1 parent c4995d7 commit 762b7e0
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.lax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped
Expand Down Expand Up @@ -260,10 +261,10 @@ def __init__(
)
)

if "iota" in conditional_names:
self.get_iota = lambda x: x["iota"]
else:
self.get_iota = lambda x: spin_to_iota(
self.get_iota = lambda x: jax.lax.cond(
"iota" in conditional_names,
lambda _: x["iota"],
lambda _: spin_to_iota(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
Expand All @@ -275,7 +276,9 @@ def __init__(
x["q"],
self.freq_ref,
0.0,
)
),
operand=None,
)

@jnp.vectorize
def _calc_R_det_arg(ra, dec, psi, iota, gmst):
Expand Down Expand Up @@ -368,10 +371,10 @@ def __init__(
and "M_c" in conditional_names
)

if "iota" in conditional_names:
self.get_iota = lambda x: x["iota"]
else:
self.get_iota = lambda x: spin_to_iota(
self.get_iota = lambda x: jax.lax.cond(
"iota" in conditional_names,
lambda _: x["iota"],
lambda _: spin_to_iota(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
Expand All @@ -383,7 +386,9 @@ def __init__(
x["q"],
self.freq_ref,
0.0,
)
),
operand=None,
)

@jnp.vectorize
def _calc_R_dets(ra, dec, psi, iota):
Expand Down

0 comments on commit 762b7e0

Please sign in to comment.