Skip to content

Commit

Permalink
continue draft
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Nov 5, 2024
1 parent f54d7c1 commit 9d758a5
Showing 1 changed file with 28 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
get_target_galaxy_params_simple,
get_target_images_single,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples,
pipeline_image_interim_samples_one_galaxy,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference

init_fnc = init_with_truth


def main(
tag: str,
seed: int,
n_gals: int = 100, # technically, in this file it means 'noise realizations'
n_samples_per_gal: int = 100,
n_vec: int = 50, # how many galaxies to process simultaneously in 1 GPU core
g1: float = 0.02,
g2: float = 0.0,
lf: float = 6.0,
Expand All @@ -30,65 +32,64 @@ def main(
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
n_gals: int = 1000, # technically, here it means 'noise realizations'
n_samples_shear: int = 3000,
n_samples_per_gal: int = 100,
initial_step_size: float = 1e-3,
trim: int = 1,
):
rng_key = random.key(seed)
pkey, nkey, gkey, skey = random.split(rng_key, 3)
pkey, nkey, gkey = random.split(rng_key, 3)

# directory structure
dirpath = DATA_DIR / "cache_chains" / tag

if not dirpath.exists():
dirpath.mkdir(exist_ok=True)

fpath = dirpath / f"e_post_{seed}.npy"

# get images
galaxy_params = get_target_galaxy_params_simple(
pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=shape_noise
galaxy_params = get_target_galaxy_params_simple( # default hlr, x, y
pkey, lf=lf, g1=g1, g2=g2, shape_noise=shape_noise
)

target_images = get_target_images_single(
draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")
target_images, _ = get_target_images_single(
nkey,
n_samples=n_gals,
single_galaxy_params=galaxy_params,
single_galaxy_params=draw_params,
background=background,
slen=slen,
)
assert target_images.shape == (n_gals, slen, slen)

true_params = get_true_params_from_galaxy_params(galaxy_params)

# prepare pipelines
pipe1 = partial(
pipeline_image_interim_samples,
pipeline_image_interim_samples_one_galaxy,
initialization_fnc=init_fnc,
n_samples=k,
max_num_doublings=5,
initial_step_size=1e-3,
n_warmup_steps=500,
is_mass_matrix_diagonal=True,
background=background,
sigma_e_int=sigma_e_int,
n_samples=n_samples_per_gal,
initial_step_size=initial_step_size,
slen=slen,
pixel_scale=pixel_scale,
fft_size=fft_size,
background=background,
)
vpipe1 = vmap(jjit(pipe1), (0, None, 0))

pipe2 = partial(
pipeline_shear_inference,
true_g=jnp.array([g1, g2]),
sigma_e=shape_noise,
sigma_e_int=sigma_e_int,
n_samples=n_samples_shear,
)
vpipe2 = vmap(pipe2, in_axes=(0, 0))

# initialization
gkeys = random.split(gkey, n_gals)
init_positions = vmap(init_fnc, (0, None))(keys, true_params)


galaxy_samples = vpipe1(gkeys, true_params, target_images)

Check failure on line 85 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (E303)

scripts/one_galaxy_image_interim_samples.py:85:5: E303 Too many blank lines (2)

Check failure on line 85 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E303)

scripts/one_galaxy_image_interim_samples.py:85:5: E303 Too many blank lines (2)

Check failure on line 86 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (W293)

scripts/one_galaxy_image_interim_samples.py:86:1: W293 Blank line contains whitespace

Check failure on line 86 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (W293)

scripts/one_galaxy_image_interim_samples.py:86:1: W293 Blank line contains whitespace


e_post = jnp.stack([galaxy_samples["e1"], galaxy_samples["e2"]], axis=-1)

Check failure on line 89 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (E303)

scripts/one_galaxy_image_interim_samples.py:89:5: E303 Too many blank lines (3)

Check failure on line 89 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E303)

scripts/one_galaxy_image_interim_samples.py:89:5: E303 Too many blank lines (3)

jnp.save(

g_samples = vpipe2(skey, e_post)

Check failure on line 93 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (E251)

scripts/one_galaxy_image_interim_samples.py:93:14: E251 Unexpected spaces around keyword / parameter equals

Check failure on line 93 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (E251)

scripts/one_galaxy_image_interim_samples.py:93:16: E251 Unexpected spaces around keyword / parameter equals

Check failure on line 93 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E251)

scripts/one_galaxy_image_interim_samples.py:93:14: E251 Unexpected spaces around keyword / parameter equals

Check failure on line 93 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E251)

scripts/one_galaxy_image_interim_samples.py:93:16: E251 Unexpected spaces around keyword / parameter equals

Check failure on line 94 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff

scripts/one_galaxy_image_interim_samples.py:93:37: SyntaxError: Expected ')', found newline

Check failure on line 94 in scripts/one_galaxy_image_interim_samples.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff

scripts/one_galaxy_image_interim_samples.py:93:37: SyntaxError: Expected ')', found newline

Expand Down

0 comments on commit 9d758a5

Please sign in to comment.