Skip to content

Commit

Permalink
Revert "Remove if-else function"
Browse files Browse the repository at this point in the history
This reverts commit 762b7e0.
  • Loading branch information
tsunhopang committed Sep 5, 2024
1 parent 762b7e0 commit 8805205
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax.lax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped
Expand Down Expand Up @@ -261,10 +260,10 @@ def __init__(
)
)

self.get_iota = lambda x: jax.lax.cond(
"iota" in conditional_names,
lambda _: x["iota"],
lambda _: spin_to_iota(
if "iota" in conditional_names:
self.get_iota = lambda x: x["iota"]
else:
self.get_iota = lambda x: spin_to_iota(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
Expand All @@ -276,9 +275,7 @@ def __init__(
x["q"],
self.freq_ref,
0.0,
),
operand=None,
)
)

@jnp.vectorize
def _calc_R_det_arg(ra, dec, psi, iota, gmst):
Expand Down Expand Up @@ -371,10 +368,10 @@ def __init__(
and "M_c" in conditional_names
)

self.get_iota = lambda x: jax.lax.cond(
"iota" in conditional_names,
lambda _: x["iota"],
lambda _: spin_to_iota(
if "iota" in conditional_names:
self.get_iota = lambda x: x["iota"]
else:
self.get_iota = lambda x: spin_to_iota(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
Expand All @@ -386,9 +383,7 @@ def __init__(
x["q"],
self.freq_ref,
0.0,
),
operand=None,
)
)

@jnp.vectorize
def _calc_R_dets(ra, dec, psi, iota):
Expand Down

0 comments on commit 8805205

Please sign in to comment.