From 363ea57cf47e5b054ffeb93249f67f2e59857a30 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:36:21 -0800 Subject: [PATCH 1/9] function to simplify later steps --- bpd/pipelines/image_ellips.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/bpd/pipelines/image_ellips.py b/bpd/pipelines/image_ellips.py index cac4afe..fa62e03 100644 --- a/bpd/pipelines/image_ellips.py +++ b/bpd/pipelines/image_ellips.py @@ -10,7 +10,7 @@ from bpd.chains import run_inference_nuts from bpd.draw import draw_gaussian, draw_gaussian_galsim from bpd.noise import add_noise -from bpd.prior import ellip_mag_prior, sample_ellip_prior +from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation def get_target_galaxy_params_simple( @@ -37,6 +37,18 @@ def get_target_galaxy_params_simple( } +def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]): + true_params = {**galaxy_params} + e1, e2 = true_params.pop("e1"), true_params.pop("e2") + g1, g2 = true_params.pop("g1"), true_params.pop("g2") + + e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2)) + true_params["e1"] = e1_prime + true_params["e2"] = e2_prime + + return true_params # don't add g1,g2 back as we are not inferring those + + def get_target_images_single( rng_key: PRNGKeyArray, n_samples: int, @@ -112,7 +124,7 @@ def pipeline_image_interim_samples_one_galaxy( max_num_doublings: int = 5, initial_step_size: float = 1e-3, n_warmup_steps: int = 500, - is_mass_matrix_diagonal: bool = False, + is_mass_matrix_diagonal: bool = True, slen: int = 53, fft_size: int = 256, background: float = 1.0, From a51d9b3a9cc5394a513943aeba5697d8b9a08ebc Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:36:32 -0800 Subject: [PATCH 2/9] refactor, still not done until later pr --- scripts/one_galaxy_shear.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/scripts/one_galaxy_shear.py b/scripts/one_galaxy_shear.py index 6e147b3..74b5531 100755 --- a/scripts/one_galaxy_shear.py +++ b/scripts/one_galaxy_shear.py @@ -11,10 +11,10 @@ from bpd.pipelines.image_ellips import ( get_target_galaxy_params_simple, get_target_images_single, + get_true_params_from_galaxy_params, pipeline_image_interim_samples, ) from bpd.pipelines.shear_inference import pipeline_shear_inference -from bpd.prior import scalar_shear_transformation init_fnc = init_with_truth @@ -53,19 +53,11 @@ def main( nkey, n_samples=n_gals, single_galaxy_params=galaxy_params, - psf_hlr=psf_hlr, background=background, slen=slen, - pixel_scale=pixel_scale, ) - true_params = {**galaxy_params} - e1, e2 = true_params.pop("e1"), true_params.pop("e2") - g1, g2 = true_params.pop("g1"), true_params.pop("g2") - - e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2)) - true_params["e1"] = e1_prime - true_params["e2"] = e2_prime + true_params = get_true_params_from_galaxy_params(galaxy_params) # prepare pipelines pipe1 = partial( From 0442294027abef13a26063de71b9add520c1c234 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:36:37 -0800 Subject: [PATCH 3/9] test draft --- tests/test_convergence.py | 69 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index f2c4c25..f70b39c 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -9,6 +9,13 @@ from jax import random, vmap from bpd.chains import run_inference_nuts +from bpd.initialization import init_with_truth +from bpd.pipelines.image_ellips import ( + get_target_galaxy_params_simple, + get_target_images_single, + get_true_params_from_galaxy_params, + pipeline_image_interim_samples_one_galaxy, +) from bpd.pipelines.shear_inference import pipeline_shear_inference from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples @@ -16,7 +23,7 @@ @pytest.mark.parametrize("seed", [1234, 4567]) -def test_interim_ellipticity_posterior_convergence(seed): +def test_interim_toy_convergence(seed): """Check efficiency and convergence of chains for 100 galaxies.""" g1, g2 = 0.02, 0.0 sigma_m = 1e-4 @@ -74,7 +81,7 @@ def test_interim_ellipticity_posterior_convergence(seed): @pytest.mark.parametrize("seed", [1234, 4567]) -def test_shear_posterior_convergence(seed): +def test_toy_shear_convergence(seed): g1, g2 = 0.02, 0.0 sigma_m = 1e-4 sigma_e = 1e-3 @@ -124,3 +131,61 @@ def test_shear_posterior_convergence(seed): assert ess > 0.5 * 4000 assert jnp.abs(rhat - 1) < 0.01 + + +@pytest.mark.parametrize("seed", [1234, 4567, 1111, 2222]) +def test_low_noise_single_galaxy_interim_samples(seed): + lf = 6.0 + hlr = 1.0 + g1, g2 = 0.02, 0.0 + sigma_e = 1e-3 + sigma_e_int = 3e-2 + n_samples = 1000 + background = 1.0 + slen = 53 + fft_size = 256 + init_fnc = init_with_truth + + rng_key = random.key(seed) + pkey, nkey, gkey = random.split(rng_key, 3) + + galaxy_params = get_target_galaxy_params_simple( + pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=sigma_e + ) + + draw_params = {**galaxy_params} + draw_params["f"] = 10 ** draw_params.pop("lf") + target_image = get_target_images_single( + nkey, + n_samples=1, + single_galaxy_params=draw_params, + background=background, + slen=slen, + )[0] + true_params = get_true_params_from_galaxy_params(galaxy_params) + + pipe1 = partial( + pipeline_image_interim_samples_one_galaxy, + initialization_fnc=init_fnc, + sigma_e_int=sigma_e_int, + n_samples=n_samples, + slen=slen, + fft_size=fft_size, + ) + vpipe1 = vmap(jjit(pipe1), (0, 0, None)) + + # chain initialization + # one galaxy, test convergence, so 4 random seeds + keys = random.split(gkey, 4) + init_positions = vmap(init_fnc, (0, None))(keys, true_params) + + samples = vpipe1(keys, init_positions, target_image) + + # check each component + for _, v in samples.item(): + assert v.shape == (4, 1000) + ess = effective_sample_size(v) + rhat = potential_scale_reduction(v) + + assert ess > 0.5 * 4000 + assert jnp.abs(rhat - 1) < 0.01 From 601be04489b1f7744c3a781f6cc72a10f095e362 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:46:49 -0800 Subject: [PATCH 4/9] less stuff to run --- tests/test_convergence.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index f70b39c..b7b3690 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -133,14 +133,14 @@ def test_toy_shear_convergence(seed): assert jnp.abs(rhat - 1) < 0.01 -@pytest.mark.parametrize("seed", [1234, 4567, 1111, 2222]) +@pytest.mark.parametrize("seed", [1234, 4567]) def test_low_noise_single_galaxy_interim_samples(seed): lf = 6.0 hlr = 1.0 g1, g2 = 0.02, 0.0 sigma_e = 1e-3 sigma_e_int = 3e-2 - n_samples = 1000 + n_samples = 500 background = 1.0 slen = 53 fft_size = 256 @@ -171,6 +171,7 @@ def test_low_noise_single_galaxy_interim_samples(seed): n_samples=n_samples, slen=slen, fft_size=fft_size, + n_warmup_steps=300, ) vpipe1 = vmap(jjit(pipe1), (0, 0, None)) @@ -183,9 +184,9 @@ def test_low_noise_single_galaxy_interim_samples(seed): # check each component for _, v in samples.item(): - assert v.shape == (4, 1000) + assert v.shape == (4, n_samples) ess = effective_sample_size(v) rhat = potential_scale_reduction(v) - assert ess > 0.5 * 4000 + assert ess > 0.5 * n_samples assert jnp.abs(rhat - 1) < 0.01 From f4fb1e9c1131c56253fea04b214c0c038abcd184 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 14:36:01 -0800 Subject: [PATCH 5/9] bug fix --- tests/test_convergence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index b7b3690..00f87a2 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -183,7 +183,7 @@ def test_low_noise_single_galaxy_interim_samples(seed): samples = vpipe1(keys, init_positions, target_image) # check each component - for _, v in samples.item(): + for _, v in samples.items(): assert v.shape == (4, n_samples) ess = effective_sample_size(v) rhat = potential_scale_reduction(v) From b688e8759a9179c5f1e6fd1f1a9864a2b05155e5 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 11:56:16 -0800 Subject: [PATCH 6/9] separate out slow and quick tests --- .github/workflows/tests.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cc8ee22..6281d33 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,10 @@ jobs: - name: Run Ruff run: ruff check --output-format=github . - - name: Run Tests + - name: Run fast tests run: | - pytest --durations=0 + pytest -m "not slow" --durations=0 + + - name: Run slow tests + run: | + pytest -m "slow" --durations=0 From 80d1d8ed34128d91c4d2c50d75f681585fdc01f3 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 11:57:11 -0800 Subject: [PATCH 7/9] add slow marker --- tests/test_convergence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 00f87a2..b3d8d47 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -133,6 +133,7 @@ def test_toy_shear_convergence(seed): assert jnp.abs(rhat - 1) < 0.01 +@pytest.mark.slow @pytest.mark.parametrize("seed", [1234, 4567]) def test_low_noise_single_galaxy_interim_samples(seed): lf = 6.0 From e67657a0d78fdd1b741fddc8aa19269a07dec1cd Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 11:57:16 -0800 Subject: [PATCH 8/9] flag I alwasy wante --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5fc2afe..d569a09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,5 +86,5 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra" +addopts = "-ra -v" filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"] From c81cde4949ead8d26ea4ab0cad62aca1c3e9fe77 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 12:05:21 -0800 Subject: [PATCH 9/9] register mark --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d569a09..e9e0692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,5 +86,6 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra -v" +addopts = "-ra -v --strict-markers" filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"] +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]