Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
gregreen committed Sep 19, 2023
2 parents b8b4a82 + 3d1321e commit 3feb350
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 11 deletions.
18 changes: 16 additions & 2 deletions scripts/plot_flow_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,14 @@ def add_1dpopulation_boundaries(axs, dim1, attrs):
valid_keys = ["x", "y", "z", "cylR"]
plot_sph, plot_cyl = [], []
if "volume_type" not in attrs or attrs["volume_type"] == "sphere":
r_in, r_out = 1 / attrs["parallax_max"], 1 / attrs["parallax_min"]
if "r_out" in attrs:
r_out = attrs["r_out"]
else:
r_out = 1 / attrs["parallax_min"]
if "r_in" in attrs:
r_in = attrs["r_in"]
else:
r_in = 1 / attrs["parallax_max"]
plot_sph = [r_in, r_out]
elif attrs["volume_type"] == "cylinder":
if "r_in" in attrs:
Expand Down Expand Up @@ -357,7 +364,14 @@ def add_2dpopulation_boundaries(axs, dim1, dim2, attrs, color="white"):

plot_sph, plot_cyl = [], []
if "volume_type" not in attrs or attrs["volume_type"] == "sphere":
r_in, r_out = 1 / attrs["parallax_max"], 1 / attrs["parallax_min"]
if "r_out" in attrs:
r_out = attrs["r_out"]
else:
r_out = 1 / attrs["parallax_min"]
if "r_in" in attrs:
r_in = attrs["r_in"]
else:
r_in = 1 / attrs["parallax_max"]
plot_sph = [r_in, r_out]
elif attrs["volume_type"] == "cylinder":
if "r_in" in attrs:
Expand Down
3 changes: 1 addition & 2 deletions scripts/plot_gaia.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,7 @@ def plot_vcirc_2d_slice(
raise ValueError(f"dimension {dim} not supported")

fig, all_axs = plt.subplots(
2,
3,
2, 3,
figsize=(6, 2.2),
dpi=200,
gridspec_kw=dict(width_ratios=[2, 2, 2], height_ratios=[0.2, 2]),
Expand Down
78 changes: 73 additions & 5 deletions scripts/plot_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,71 @@
dpi = 200


def get_sampling_progressbar_fn(n_batches, n_samples):
widgets = [
progressbar.Bar(),
progressbar.Percentage(),
" | ",
progressbar.Timer(format="Elapsed: %(elapsed)s"),
" | ",
progressbar.AdaptiveETA(),
" | ",
progressbar.Variable("batches_done", width=6, precision=0),
", ",
progressbar.Variable("n_batches", width=6, precision=0),
", ",
progressbar.Variable("n_samples", width=8, precision=0),
]
bar = progressbar.ProgressBar(max_value=n_batches, widgets=widgets)
# n_batches = n_batches
# n_samples = n_samples

def update_progressbar(i):
bar.update(i + 1, batches_done=i + 1, n_batches=n_batches,
n_samples=n_samples)

return update_progressbar


def get_model_values(phi_model, q_eval, fig_dir=None, fname=None, save=True):
# Calculate the model potential values
# In practice, up to 1e6 densities can be calculated at once..
batch_size = 131072
n0 = len(q_eval)
q_eval = tf.data.Dataset.from_tensor_slices(q_eval).batch(batch_size)

if save:
fname = os.path.join(fig_dir, f'data/{fname}_{n0}.npz')
if (not save) or not os.path.exists(fname):
rhos = []
dphi_dqs = []
phis = []

bar, iteration = get_sampling_progressbar_fn(len(q_eval), n0), 0
for i, b in enumerate(q_eval):
phi,dphi_dq,d2phi_dq2 = potential_tf.calc_phi_derivatives(
phi_model['phi'], b, return_phi=True
)
rhos.append(2.325*d2phi_dq2.numpy()/(4*np.pi)) # [M_Sun/pc^3]
dphi_dqs.append(dphi_dq)
phis.append(phi)
bar(iteration)
iteration += 1
rhos = np.concatenate(rhos)
dphi_dqs = np.concatenate(dphi_dqs)
phis = np.concatenate(phis)
if save:
Path(os.path.join(fig_dir), 'data').mkdir(parents=True, exist_ok=True)
np.savez(fname, phi=phis, dphi_dq=dphi_dqs, rho=rhos)
else:
npzfile = np.load(fname)
rhos = npzfile['rho']
dphi_dqs = npzfile['dphi_dq']
phis = npzfile['phi']

return phis, dphi_dqs, rhos


def plot_rho(
phi_model,
coords_train,
Expand Down Expand Up @@ -175,9 +240,12 @@ def plot_rho(
nbins = 64
weights = np.full_like(
x_train, 1 / ((xmax - xmin) * (xmax - xmin) / nbins**2 * 2 * dz) / 10**9
)
)
grid_size = 32
x = np.linspace(xmin, xmax, grid_size + 1)
y = np.linspace(ymin, ymax, grid_size + 1)
h = ax_e.hist2d(
x_train, y_train, range=lims, weights=weights, bins=64, rasterized=True
x_train, y_train, range=lims, weights=weights, bins=(x, y), rasterized=True
) # , norm=matplotlib.colors.LogNorm(vmin=1))
cb_e = fig.colorbar(h[3], cax=cax_e, orientation="horizontal")
cb_e.ax.xaxis.set_ticks_position("top")
Expand Down Expand Up @@ -811,7 +879,7 @@ def main():
"--fig-fmt",
type=str,
nargs="+",
default=("svg",),
default=("pdf",),
help="Formats in which to save figures (svg, png, pdf, etc.).",
)
parser.add_argument(
Expand Down Expand Up @@ -967,8 +1035,8 @@ def main():
("vy", "vz"),
]
print(" Calculating Phi gradients (might take a while) ...")
_, dphi_dq, _ = potential_tf.calc_phi_derivatives(
phi_model["phi"], df_data["eta"][:, :3], return_phi=True
_, dphi_dq, _ = get_model_values(
phi_model, df_data["eta"][:, :3], save=False
)
for dim1, dim2 in dims:
print(f" --> ({dim1}, {dim2})")
Expand Down
18 changes: 16 additions & 2 deletions scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@ def cut(eta, attrs):
z = eta[:, 2]

if "volume_type" not in attrs or attrs["volume_type"] == "sphere":
r_in, r_out = 1 / attrs["parallax_max"], 1 / attrs["parallax_min"]
if "r_out" in attrs:
r_out = attrs["r_out"]
else:
r_out = 1 / attrs["parallax_min"]
if "r_in" in attrs:
r_in = attrs["r_in"]
else:
r_in = 1 / attrs["parallax_max"]
idx = (r2 > r_in**2) & (r2 < r_out**2)
elif attrs["volume_type"] == "cylinder":
R_out, H_out = attrs["R_out"], attrs["H_out"]
Expand Down Expand Up @@ -593,7 +600,14 @@ def get_index_of_points_inside_attrs(eta, attrs, r=None, R=None, z=None):
R = np.sum(eta[:, :2] ** 2, axis=1) ** 0.5
z = eta[:, 2]
if "volume_type" not in attrs or attrs["volume_type"] == "sphere":
r_in, r_out = 1 / attrs["parallax_max"], 1 / attrs["parallax_min"]
if "r_out" in attrs:
r_out = attrs["r_out"]
else:
r_out = 1 / attrs["parallax_min"]
if "r_in" in attrs:
r_in = attrs["r_in"]
else:
r_in = 1 / attrs["parallax_max"]
idx = (r >= r_in) & (r <= r_out)
elif attrs["volume_type"] == "cylinder":
R_out, H_out = attrs["R_out"], attrs["H_out"]
Expand Down

0 comments on commit 3feb350

Please sign in to comment.