diff --git a/scripts/one_galaxy_shear.py b/scripts/one_galaxy_image_interim_samples.py similarity index 71% rename from scripts/one_galaxy_shear.py rename to scripts/one_galaxy_image_interim_samples.py index 74b5531..3fc542a 100755 --- a/scripts/one_galaxy_shear.py +++ b/scripts/one_galaxy_image_interim_samples.py @@ -12,9 +12,8 @@ get_target_galaxy_params_simple, get_target_images_single, get_true_params_from_galaxy_params, - pipeline_image_interim_samples, + pipeline_image_interim_samples_one_galaxy, ) -from bpd.pipelines.shear_inference import pipeline_shear_inference init_fnc = init_with_truth @@ -22,6 +21,9 @@ def main( tag: str, seed: int, + n_gals: int = 100, # technically, in this file it means 'noise realizations' + n_samples_per_gal: int = 100, + n_vec: int = 50, # how many galaxies to process simultaneously in 1 GPU core g1: float = 0.02, g2: float = 0.0, lf: float = 6.0, @@ -30,13 +32,11 @@ def main( slen: int = 53, fft_size: int = 256, background: float = 1.0, - n_gals: int = 1000, # technically, here it means 'noise realizations' - n_samples_shear: int = 3000, - n_samples_per_gal: int = 100, + initial_step_size: float = 1e-3, trim: int = 1, ): rng_key = random.key(seed) - pkey, nkey, gkey, skey = random.split(rng_key, 3) + pkey, nkey, gkey = random.split(rng_key, 3) # directory structure dirpath = DATA_DIR / "cache_chains" / tag @@ -44,51 +44,52 @@ def main( if not dirpath.exists(): dirpath.mkdir(exist_ok=True) + fpath = dirpath / f"e_post_{seed}.npy" + # get images - galaxy_params = get_target_galaxy_params_simple( - pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=shape_noise + galaxy_params = get_target_galaxy_params_simple( # default hlr, x, y + pkey, lf=lf, g1=g1, g2=g2, shape_noise=shape_noise ) - target_images = get_target_images_single( + draw_params = {**galaxy_params} + draw_params["f"] = 10 ** draw_params.pop("lf") + target_images, _ = get_target_images_single( nkey, n_samples=n_gals, - single_galaxy_params=galaxy_params, + single_galaxy_params=draw_params, background=background, slen=slen, ) + assert target_images.shape == (n_gals, slen, slen) true_params = get_true_params_from_galaxy_params(galaxy_params) # prepare pipelines pipe1 = partial( - pipeline_image_interim_samples, + pipeline_image_interim_samples_one_galaxy, initialization_fnc=init_fnc, - n_samples=k, - max_num_doublings=5, - initial_step_size=1e-3, - n_warmup_steps=500, - is_mass_matrix_diagonal=True, - background=background, + sigma_e_int=sigma_e_int, + n_samples=n_samples_per_gal, + initial_step_size=initial_step_size, slen=slen, - pixel_scale=pixel_scale, fft_size=fft_size, + background=background, ) vpipe1 = vmap(jjit(pipe1), (0, None, 0)) - pipe2 = partial( - pipeline_shear_inference, - true_g=jnp.array([g1, g2]), - sigma_e=shape_noise, - sigma_e_int=sigma_e_int, - n_samples=n_samples_shear, - ) - vpipe2 = vmap(pipe2, in_axes=(0, 0)) - + # initialization gkeys = random.split(gkey, n_gals) + init_positions = vmap(init_fnc, (0, None))(keys, true_params) + + galaxy_samples = vpipe1(gkeys, true_params, target_images) + + e_post = jnp.stack([galaxy_samples["e1"], galaxy_samples["e2"]], axis=-1) + jnp.save( + g_samples = vpipe2(skey, e_post)