Skip to content

Commit

Permalink
use norm if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 15, 2024
1 parent 9ebc63b commit f17734a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 21 deletions.
26 changes: 9 additions & 17 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, vmap
from jax.numpy.linalg import norm
from jax.typing import ArrayLike

from bpd.prior import inv_shear_func1, inv_shear_func2, inv_shear_transformation

_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))


def shear_loglikelihood_unreduced(
g: tuple[float, float], e_post, prior: Callable, interim_prior: Callable
Expand All @@ -19,34 +23,22 @@ def shear_loglikelihood_unreduced(
# normalization in priors can be ignored for now as alpha is fixed.
_, K, _ = e_post.shape # (N, K, 2)

Check failure on line 24 in bpd/likelihood.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F841)

bpd/likelihood.py:24:8: F841 Local variable `K` is assigned to but never used

Check failure on line 24 in bpd/likelihood.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F841)

bpd/likelihood.py:24:8: F841 Local variable `K` is assigned to but never used

e_post_mag = jnp.sqrt(e_post[..., 0] ** 2 + e_post[..., 1] ** 2)
e_post_mag = norm(e_post, axis=-1)
denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform

# for num, use trick
# p(w_n' | g, alpha ) = p(w_n' \cross^{-1} g | alpha ) = p(w_n | alpha) * |jac(w_n / w_n')|

# shape = (N, K, 2)
grad1 = vmap(
vmap(grad(inv_shear_func1, argnums=0), in_axes=(0, None)),
in_axes=(0, None),
)(e_post, g)

grad2 = vmap(
vmap(grad(inv_shear_func2, argnums=0), in_axes=(0, None)),
in_axes=(0, None),
)(e_post, g)

grad1 = _grad_fnc1(e_post, g)
grad2 = _grad_fnc2(e_post, g)
absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0])

e_post_unsheared = inv_shear_transformation(e_post, g)
e_obs_unsheared_mag = jnp.sqrt(
e_post_unsheared[..., 0] ** 2 + e_post_unsheared[..., 1] ** 2
)
e_obs_unsheared_mag = norm(e_post_unsheared, axis=-1)
num = prior(e_obs_unsheared_mag) * absjacdet # (N, K)

ratio = jnp.log((1 / K)) + jsp.special.logsumexp(
jnp.log(num) - jnp.log(denom), axis=-1
)
ratio = jsp.special.logsumexp(jnp.log(num) - jnp.log(denom), axis=-1)
return ratio


Expand Down
5 changes: 2 additions & 3 deletions bpd/pipelines/toy_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax import Array, random, vmap
from jax import jit as jjit
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm

from bpd.chains import run_inference_nuts
from bpd.prior import ellip_mag_prior, sample_synthetic_sheared_ellips_unclipped
Expand All @@ -23,9 +24,7 @@ def logtarget(

# ignore angle prior assumed uniform
# prior enforces magnitude < 1.0 for posterior samples
e_sheared_mag = jnp.sqrt(e_sheared[0] ** 2 + e_sheared[1] ** 2)
prior = jnp.log(interim_prior(e_sheared_mag))

prior = jnp.log(interim_prior(norm(e_sheared)))
likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m))
return prior + likelihood

Expand Down
3 changes: 2 additions & 1 deletion bpd/prior.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax.numpy as jnp
from jax import Array, random
from jax.numpy.linalg import norm


def ellip_mag_prior(e, sigma: float):
Expand Down Expand Up @@ -136,7 +137,7 @@ def sample_synthetic_sheared_ellips_clipped(
# clip magnitude to < 1
# preserve angle after noise added when clipping
beta = jnp.arctan2(e_obs[:, :, 1], e_obs[:, :, 0]) / 2
e_obs_mag = jnp.sqrt(e_obs[:, :, 0] ** 2 + e_obs[:, :, 1] ** 2)
e_obs_mag = norm(e_obs, axis=-1)
e_obs_mag = jnp.clip(e_obs_mag, 0, e_tol) # otherwise likelihood explodes

final_eobs1 = e_obs_mag * jnp.cos(2 * beta)
Expand Down

0 comments on commit f17734a

Please sign in to comment.