Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convergence test interim samples from image #36

Merged
merged 9 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ jobs:
- name: Run Ruff
run: ruff check --output-format=github .

- name: Run Tests
- name: Run fast tests
run: |
pytest --durations=0
pytest -m "not slow" --durations=0

- name: Run slow tests
run: |
pytest -m "slow" --durations=0
16 changes: 14 additions & 2 deletions bpd/pipelines/image_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior
from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation


def get_target_galaxy_params_simple(
Expand All @@ -37,6 +37,18 @@ def get_target_galaxy_params_simple(
}


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2))
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

return true_params # don't add g1,g2 back as we are not inferring those


def get_target_images_single(
rng_key: PRNGKeyArray,
n_samples: int,
Expand Down Expand Up @@ -112,7 +124,7 @@ def pipeline_image_interim_samples_one_galaxy(
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = False,
is_mass_matrix_diagonal: bool = True,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,6 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra"
addopts = "-ra -v --strict-markers"
filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
12 changes: 2 additions & 10 deletions scripts/one_galaxy_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference
from bpd.prior import scalar_shear_transformation

init_fnc = init_with_truth

Expand Down Expand Up @@ -53,19 +53,11 @@ def main(
nkey,
n_samples=n_gals,
single_galaxy_params=galaxy_params,
psf_hlr=psf_hlr,
background=background,
slen=slen,
pixel_scale=pixel_scale,
)

true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2))
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime
true_params = get_true_params_from_galaxy_params(galaxy_params)

# prepare pipelines
pipe1 = partial(
Expand Down
71 changes: 69 additions & 2 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
from jax import random, vmap

from bpd.chains import run_inference_nuts
from bpd.initialization import init_with_truth
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples_one_galaxy,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference
from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips
from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples
from bpd.prior import ellip_mag_prior, sample_synthetic_sheared_ellips_unclipped


@pytest.mark.parametrize("seed", [1234, 4567])
def test_interim_ellipticity_posterior_convergence(seed):
def test_interim_toy_convergence(seed):
"""Check efficiency and convergence of chains for 100 galaxies."""
g1, g2 = 0.02, 0.0
sigma_m = 1e-4
Expand Down Expand Up @@ -74,7 +81,7 @@ def test_interim_ellipticity_posterior_convergence(seed):


@pytest.mark.parametrize("seed", [1234, 4567])
def test_shear_posterior_convergence(seed):
def test_toy_shear_convergence(seed):
g1, g2 = 0.02, 0.0
sigma_m = 1e-4
sigma_e = 1e-3
Expand Down Expand Up @@ -124,3 +131,63 @@ def test_shear_posterior_convergence(seed):

assert ess > 0.5 * 4000
assert jnp.abs(rhat - 1) < 0.01


@pytest.mark.slow
@pytest.mark.parametrize("seed", [1234, 4567])
def test_low_noise_single_galaxy_interim_samples(seed):
lf = 6.0
hlr = 1.0
g1, g2 = 0.02, 0.0
sigma_e = 1e-3
sigma_e_int = 3e-2
n_samples = 500
background = 1.0
slen = 53
fft_size = 256
init_fnc = init_with_truth

rng_key = random.key(seed)
pkey, nkey, gkey = random.split(rng_key, 3)

galaxy_params = get_target_galaxy_params_simple(
pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=sigma_e
)

draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")
target_image = get_target_images_single(
nkey,
n_samples=1,
single_galaxy_params=draw_params,
background=background,
slen=slen,
)[0]
true_params = get_true_params_from_galaxy_params(galaxy_params)

pipe1 = partial(
pipeline_image_interim_samples_one_galaxy,
initialization_fnc=init_fnc,
sigma_e_int=sigma_e_int,
n_samples=n_samples,
slen=slen,
fft_size=fft_size,
n_warmup_steps=300,
)
vpipe1 = vmap(jjit(pipe1), (0, 0, None))

# chain initialization
# one galaxy, test convergence, so 4 random seeds
keys = random.split(gkey, 4)
init_positions = vmap(init_fnc, (0, None))(keys, true_params)

samples = vpipe1(keys, init_positions, target_image)

# check each component
for _, v in samples.items():
assert v.shape == (4, n_samples)
ess = effective_sample_size(v)
rhat = potential_scale_reduction(v)

assert ess > 0.5 * n_samples
assert jnp.abs(rhat - 1) < 0.01