coretree LC intersection: add workarounds for LastJourney data
michaelbuehlmann committed Nov 15, 2024
1 parent 89aba8b commit 434e033
Showing 2 changed files with 134 additions and 20 deletions.
126 changes: 106 additions & 20 deletions haccytrees/scripts/intersect_cores_with_lightcone.py
@@ -9,6 +9,11 @@
import pygio
import numpy as np
import h5py
+from mpi4py import MPI
+
+from haccytrees.utils.mpi_error_handler import init_mpi_error_handler
+
+init_mpi_error_handler()

core_fields = [
"x",
@@ -128,7 +133,6 @@ def read_lightcone(partition_cube: Partition, lightcone_path: Path, simulation_n
halo_lc["fof_halo_tag_clean"][mask_fragment],
halo_lc["fragment_idx"][mask_fragment],
) = split_fragment_tag(halo_lc["id"][mask_fragment])
-assert len(np.unique(halo_lc["fof_halo_tag_clean"])) == len(halo_lc["id"])
assert np.all(halo_lc["fof_halo_tag_clean"] >= 0)

halo_lc["qz"] = (halo_lc["fof_halo_tag_clean"] % simulation_np) / simulation_np
@@ -149,7 +153,44 @@ def read_lightcone(partition_cube: Partition, lightcone_path: Path, simulation_n
s = np.argsort(halo_lc["fof_halo_tag_clean"])
halo_lc = {k: v[s] for k, v in halo_lc.items()}

unique_id = (halo_lc["replication"] << 41) + halo_lc["fof_halo_tag_clean"]
assert np.all(halo_lc["fof_halo_tag_clean"] >= 0)
assert np.all(halo_lc["fof_halo_tag_clean"] < (1 << 41))
assert np.all(halo_lc["replication"] >= 0)
assert np.all(halo_lc["replication"] < (1 << 22))
unique_id = (halo_lc["replication"].astype(np.int64) << 41) + halo_lc["fof_halo_tag_clean"]
num_duplicates = len(unique_id) - len(np.unique(unique_id))
max_duplicates = 0
num_duplicated_halos = 0
if num_duplicates > 0:
_uq, _idx, _cnt = np.unique(unique_id, return_index=True, return_counts=True)
_mask = _cnt > 1
max_duplicates = np.max(_cnt)
num_duplicated_halos = np.sum(_mask)
# _prtidx = np.argmax(_cnt)
# _prtindices = np.nonzero(unique_id == _uq[_prtidx])[0]
# print(
# f"DEBUG:: rank {partition_cube.rank} found {np.sum(_mask)} halos with the same fof_halo_tag/replication",
# f"max duplicate count {_cnt[_prtidx]}",
# f"fof_halo_tag_clean={halo_lc['fof_halo_tag_clean'][_prtindices]}",
# f"fof_halo_tag={halo_lc['id'][_prtindices]}",
# f"replication={halo_lc['replication'][_prtindices]}",
# f"pos=[{halo_lc['x'][_prtindices]}, {halo_lc['y'][_prtindices]}, {halo_lc['z'][_prtindices]}]",
# flush=True
# )
halo_lc = {k: v[_idx] for k, v in halo_lc.items()}
unique_id = unique_id[_idx]

# make sure it's sorted by fof_halo_tag_clean
s = np.argsort(halo_lc["fof_halo_tag_clean"])
halo_lc = {k: v[s] for k, v in halo_lc.items()}
unique_id = unique_id[s]

max_duplicates_global = partition_cube.comm.reduce(max_duplicates, op=MPI.MAX, root=0)
num_duplicated_halos_global = partition_cube.comm.reduce(num_duplicated_halos, op=MPI.SUM, root=0)
if partition_cube.rank == 0:
print(f"DEBUG:: found {num_duplicated_halos_global} duplicated halos, max duplicates {max_duplicates_global}", flush=True)


assert len(np.unique(unique_id)) == len(halo_lc["id"])
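
The packing added above is invertible as long as the asserted bit ranges hold, and `np.unique(..., return_index=True)` keeps the first occurrence of each duplicate. A minimal sketch with made-up values:

    import numpy as np

    replication = np.array([0, 0, 1, 1], dtype=np.int64)
    tag_clean = np.array([7, 7, 7, 42], dtype=np.int64)

    # replication occupies bits 41 and up, tag_clean bits 0-40
    uid = (replication << 41) + tag_clean
    assert np.all(uid >> 41 == replication)
    assert np.all((uid & ((1 << 41) - 1)) == tag_clean)

    # return_index keeps the first occurrence of each value, so the
    # duplicated (tag, replication) pair is dropped
    _uq, _idx = np.unique(uid, return_index=True)
    assert list(tag_clean[_idx]) == [7, 7, 42]
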

unique_tags, unique_reverse, unique_counts = np.unique(
@@ -282,15 +323,21 @@ def cli(
partition_cube.comm.Barrier()

# Read LC shell per step
+if partition_cube.rank == 0:
+    print(" - Reading lightcone", flush=True)
lightcone_catalog = lightcone_pattern.replace("#", str(step))
halo_lc = read_lightcone(partition_cube, lightcone_catalog, sim_np)
partition_cube.comm.Barrier()

+if partition_cube.rank == 0:
+    print(" - Distribute cores at step", flush=True)
cores_step = distribute_cores_at_step(
partition_cube, corematrix, snap_num, sim_np
)
-# At this point, cores and halos on the lightcone are on the same rank

+# At this point, cores and halos on the lightcone are on the same rank
+if partition_cube.rank == 0:
+    print(" - Match LC with cores", flush=True)
# Get all the cores whose parent halo intersects with the lightcone
mask = np.isin(cores_step["fof_halo_tag_clean"], halo_lc["fof_halo_tag_clean"])
cores_step = {k: v[mask] for k, v in cores_step.items()}
@@ -300,63 +347,88 @@ def cli(
cores_step = {k: v[s] for k, v in cores_step.items()}

# Match cores to halos by fof_halo_tag (without fragment index)
+assert np.all(np.diff(halo_lc["fof_halo_tag_clean"]) >= 0)  # check if sorted
lc_index = np.searchsorted(
halo_lc["fof_halo_tag_clean"], cores_step["fof_halo_tag_clean"]
)
+assert np.all(lc_index < len(halo_lc["fof_halo_tag_clean"]))
+assert np.all(lc_index >= 0)
assert np.all(
halo_lc["fof_halo_tag_clean"][lc_index] == cores_step["fof_halo_tag_clean"]
)
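
The `np.isin` filter plus `np.searchsorted` lookup above is a standard sorted-array join. A toy version with hypothetical tag values, showing why the follow-up asserts matter (`searchsorted` returns insertion points, which only coincide with matches when every queried tag is actually present):

    import numpy as np

    lc_tags = np.array([3, 8, 15])              # sorted, deduplicated LC tags
    core_tags = np.array([3, 3, 8, 15, 15])     # several cores per halo
    idx = np.searchsorted(lc_tags, core_tags)   # -> [0, 0, 1, 2, 2]
    assert np.all(lc_tags[idx] == core_tags)    # would fail for a missing tag
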

# Find fof_halo_tag with correct fragment tag idx inside cores
+if partition_cube.rank == 0:
+    print(" - find central core with matching fragment tag", flush=True)
cores_lc_host_tag = halo_lc["id"][lc_index]
mask_missing = ~np.isin(cores_lc_host_tag, cores_step["fof_halo_tag"])
-if np.any(mask_missing):
-    message = f"DEBUG rank {partition_cube.rank}:\n"
-    message += f" Missing {np.sum(mask_missing)} fof_halo_tag (out of {len(mask_missing)})\n"
-    message += f" Missing fof_halo_tag: {cores_lc_host_tag[mask_missing]}\n"
-    message += f" Missing fof_halo_tag_clean: {halo_lc['fof_halo_tag_clean'][lc_index][mask_missing]}\n"
-    message += f" Missing fragment_idx: {halo_lc['fragment_idx'][lc_index][mask_missing]}\n"
-    message += (
-        f" Required by cores: {cores_step['fof_halo_tag'][mask_missing]}\n"
-    )
-    print(message)
+num_missing = np.sum(mask_missing)
+num_total = len(mask_missing)
+if num_missing > 0:
+    # message = f"DEBUG rank {partition_cube.rank}:\n"
+    # message += f" Missing {np.sum(mask_missing)} fof_halo_tag (out of {len(mask_missing)})\n"
+    # message += f" Missing fof_halo_tag: {cores_lc_host_tag[mask_missing]}\n"
+    # message += f" Missing fof_halo_tag_clean: {halo_lc['fof_halo_tag_clean'][lc_index][mask_missing]}\n"
+    # message += f" Missing fragment_idx: {halo_lc['fragment_idx'][lc_index][mask_missing]}\n"
+    # message += (
+    #     f" Required by cores: {cores_step['fof_halo_tag'][mask_missing]}\n"
+    # )
+    # print(message)
+
+    # Remove missing halos
+    cores_step = {k: v[~mask_missing] for k, v in cores_step.items()}
+    cores_lc_host_tag = cores_lc_host_tag[~mask_missing]
+    lc_index = lc_index[~mask_missing]

+num_missing_global = partition_cube.comm.reduce(num_missing, root=0)
+num_total_global = partition_cube.comm.reduce(num_total, root=0)
+if partition_cube.rank == 0:
+    print(f" - missing {num_missing_global} out of {num_total_global} fof_halo_tag", flush=True)
assert np.all(np.isin(cores_lc_host_tag, cores_step["fof_halo_tag"]))
partition_cube.comm.Barrier()
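
A note on the reduce calls above: mpi4py's `comm.reduce` defaults to `op=MPI.SUM` and returns `None` on non-root ranks, so only the maximum needs an explicit op. A minimal sketch of the same pattern:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    local = comm.rank + 1                          # stand-in for a per-rank count
    total = comm.reduce(local, root=0)             # SUM by default
    peak = comm.reduce(local, op=MPI.MAX, root=0)
    if comm.rank == 0:
        print(total, peak)
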

# Handle replications
+if partition_cube.rank == 0:
+    print(" - handle replications", flush=True)
if len(lc_index) > 0:
_counts = halo_lc["replications_count"][lc_index]
assert np.all(_counts > 0)
# replicate each core by the number of replications
s = np.repeat(np.arange(len(cores_lc_host_tag)), _counts)
cores_step = {k: v[s] for k, v in cores_step.items()}
lc_index = lc_index[s]
-# offset each repeated index by 1
+# offset each repeated index by 1 (so it points to the next replicated halo in the lightcone)
lc_index += np.concatenate([np.arange(c) for c in _counts])
# make sure we didn't screw up anything
assert np.all(
halo_lc["fof_halo_tag_clean"][lc_index]
== cores_step["fof_halo_tag_clean"]
)
-cores_lc_host_tag = halo_lc["id"][lc_index]
+mask_invalid = halo_lc["id"][lc_index] != cores_lc_host_tag[s]
+if np.any(mask_invalid):
+    print(f"DEBUG:: found {np.sum(mask_invalid)} mismatching fof_halo_tag")
+    print(halo_lc["id"][lc_index][mask_invalid], cores_lc_host_tag[s][mask_invalid], flush=True)
+# cores_lc_host_tag = halo_lc["id"][lc_index]
+cores_lc_host_tag = cores_lc_host_tag[s]
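
The repeat-and-offset trick above relies on the replicas of a halo occupying consecutive rows of `halo_lc`. A toy example with made-up counts:

    import numpy as np

    _counts = np.array([2, 1, 3])                  # replications per matched core
    base = np.array([0, 5, 9])                     # lc_index before replication
    s = np.repeat(np.arange(len(base)), _counts)   # -> [0, 0, 1, 2, 2, 2]
    lc_index = base[s]                             # -> [0, 0, 5, 9, 9, 9]
    lc_index += np.concatenate([np.arange(c) for c in _counts])
    # -> [0, 1, 5, 9, 10, 11]: each copy points at the next replica's row
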

# Some sanity checks:
# - we have all cores of halos in the lightcone
assert np.all(np.isin(cores_lc_host_tag, cores_step["fof_halo_tag"]))
cores_lc_host_idx = np.searchsorted(
cores_step["fof_halo_tag"], halo_lc["id"][lc_index]
cores_step["fof_halo_tag"],cores_lc_host_tag
)
# - check the fof_halo_tag matches
assert np.all(
cores_step["fof_halo_tag"][cores_lc_host_idx] == halo_lc["id"][lc_index]
cores_step["fof_halo_tag"][cores_lc_host_idx] == cores_lc_host_tag
)
# apparently they don't need to be centrals...
# assert np.all(cores_step["central"][cores_lc_host_idx] > 0)

partition_cube.comm.Barrier()
+if partition_cube.rank == 0:
+    print(" - calculate core offsets", flush=True)
# Calculate distances
+mask_valid = np.ones_like(cores_step["x"], dtype=np.bool_)
for x in "xyz":
_dx = (
cores_step[x]
@@ -366,17 +438,22 @@ def cli(
_dx[_dx > config.simulation.rl / 2] -= config.simulation.rl
_dx[_dx < -config.simulation.rl / 2] += config.simulation.rl
if not np.all(np.abs(_dx) <= 20):
print("DEBUG:: found large dx")
print(x, _dx[np.abs(_dx) > 20], flush=True)
assert np.all(np.abs(_dx) <= 20)
print("DEBUG:: found large dx", x, _dx[np.abs(_dx) > 20], flush=True)
mask_valid &= np.abs(_dx) <= 20
# assert np.all(np.abs(_dx) <= 20)
cores_step[f"d{x}"] = _dx
+cores_step = {k: v[mask_valid] for k, v in cores_step.items()}
+lc_index = lc_index[mask_valid]
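
The two wrap lines implement the usual minimum-image convention for a periodic box of side `config.simulation.rl`. A scalar sketch (3400 Mpc/h is used purely as an example value for the box size):

    rl = 3400.0
    core_x, host_x = 1.0, 3399.0
    dx = core_x - host_x      # -3398.0: looks huge across the boundary
    if dx > rl / 2:
        dx -= rl
    if dx < -rl / 2:
        dx += rl
    assert dx == 2.0          # the core is really 2 Mpc/h from its host
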

for x in "xyz":
cores_step[x] = halo_lc[x][lc_index] + cores_step[f"d{x}"]
cores_step[f"host_{x}"] = halo_lc[x][lc_index]
r = np.sqrt(np.sum([cores_step[x] ** 2 for x in "xyz"], axis=0))
rhost = np.sqrt(np.sum([cores_step[f"host_{x}"] ** 2 for x in "xyz"], axis=0))

+partition_cube.comm.Barrier()
+if partition_cube.rank == 0:
+    print(" - calculate angular coordinates", flush=True)
# Calculate angular lightcone coordinates
# theta between [0, pi], phi between [0, 2pi]
cores_step["theta"] = np.arccos(cores_step["z"] / r)
@@ -396,6 +473,13 @@ def cli(
assert np.all(cores_step["host_phi"] >= 0)
assert np.all(cores_step["host_phi"] <= 2 * np.pi)

+# if phi is 2pi, set it to 0
+cores_step["phi"] = np.fmod(cores_step["phi"], 2 * np.pi)
+cores_step["host_phi"] = np.fmod(cores_step["host_phi"], 2 * np.pi)
+
+partition_cube.comm.Barrier()
+if partition_cube.rank == 0:
+    print(" - distribute cores on LC", flush=True)
# distribute cores by the angular position of the host halo
cores_step = s2_distribute(
partition_s2,
@@ -404,6 +488,8 @@
phi_key="host_phi",
)
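
A hedged sketch of what this step assumes about the mpipartition API: an `S2Partition` that tiles the sphere into per-rank domains, and an `s2_distribute` that ships each row of a column dictionary to the rank owning its (theta, phi) position. Only the `phi_key` keyword is visible above; everything else here is an assumption:

    import numpy as np
    from mpipartition import S2Partition, s2_distribute

    partition_s2 = S2Partition()
    data = {
        "host_theta": np.array([0.3, 2.1]),
        "host_phi": np.array([0.1, 4.2]),
    }
    data = s2_distribute(partition_s2, data, theta_key="host_theta", phi_key="host_phi")
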

+if partition_cube.rank == 0:
+    print(" - write HDF5", flush=True)
# Write cores to file
output_file = output_base + f"-{step}.{partition_cube.rank}.hdf5"
output_fields = core_fields + [
28 changes: 28 additions & 0 deletions haccytrees/utils/mpi_error_handler.py
@@ -0,0 +1,28 @@
+from mpi4py import MPI
+import sys
+import logging
+import traceback
+
+comm = MPI.COMM_WORLD
+
+
+def _mpi_exception_handler(exc_type, exc_value, exc_traceback):
+    rank = comm.Get_rank()
+    if issubclass(exc_type, KeyboardInterrupt):
+        sys.__excepthook__(exc_type, exc_value, exc_traceback)
+        return
+    logging.error(
+        f"Uncaught exception on rank {rank}",
+        exc_info=(exc_type, exc_value, exc_traceback),
+    )
+    sys.stderr.write(f"Uncaught exception on rank {rank}\n")
+    sys.stderr.write(
+        "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
+    )
+    with open(f"error_rank_{rank}.txt", "w") as f:
+        f.write("".join(traceback.format_exception_only(exc_type, exc_value)))
+    comm.Abort(1)
+
+
+def init_mpi_error_handler():
+    sys.excepthook = _mpi_exception_handler
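
The hook exists because an uncaught exception on one rank would otherwise leave the remaining ranks blocked in their next collective call; `comm.Abort(1)` tears the whole job down instead. A minimal usage sketch (the failing script is hypothetical):

    # run with e.g.: mpirun -n 4 python crash_demo.py
    from haccytrees.utils.mpi_error_handler import init_mpi_error_handler

    init_mpi_error_handler()

    raise RuntimeError("boom")  # logged, dumped to error_rank_<rank>.txt, then MPI_Abort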
