Skip to content

Commit

Permalink
Merge pull request #12 from B612-Asteroid-Institute/nt/add_walltime
Browse files Browse the repository at this point in the history
Add a run duration field to the search summary
  • Loading branch information
ntellis authored Aug 13, 2024
2 parents b78e5e6 + b3d3ea8 commit 19117e3
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ipod/ipod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
from typing import Optional, Tuple, Type, Union

import numpy as np
Expand Down Expand Up @@ -40,6 +41,7 @@ class SearchSummary(qv.Table):
arc_length = qv.Float64Column()
num_obs_prev = qv.Int64Column()
num_obs = qv.Int64Column()
run_duration = qv.Float64Column()


class OrbitOutliers(qv.Table):
Expand Down Expand Up @@ -84,6 +86,8 @@ def ipod(
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers, PrecoveryCandidates, SearchSummary]:

time_start = time.perf_counter()
logger.debug(f"Running ipod with orbit {orbit.orbit_id[0].as_py()}...")
if astrometric_errors is None:
astrometric_errors = DEFAULT_ASTROMETRIC_ERRORS
Expand Down Expand Up @@ -301,6 +305,7 @@ def ipod(
"arc_length": 0,
"num_obs_prev": orbit_iter.num_obs[0].as_py(),
"num_obs": 0,
"run_duration": 0.0,
}

failed_corrections = 0
Expand Down Expand Up @@ -731,6 +736,8 @@ def ipod(
candidates = candidates_iter.apply_mask(
pc.is_in(candidates_iter.observation_id, orbit_members_iter.obs_id)
)
time_end = time.perf_counter()
search_summary_iter["run_duration"] = time_end - time_start

search_summary = SearchSummary.from_kwargs(
**{k: [v] for k, v in search_summary_iter.items()}
Expand Down
20 changes: 20 additions & 0 deletions ipod/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def merge_and_extend_orbits(
orbit_members_out = FittedOrbitMembers.empty()
orbit_candidates_out = PrecoveryCandidates.empty()
search_summary_out = SearchSummary.empty()
search_summary_iter = SearchSummary.empty()

if max_processes is None:
max_processes = mp.cpu_count()
Expand Down Expand Up @@ -88,6 +89,8 @@ def merge_and_extend_orbits(
np.ceil(len(orbits_iter) / max_processes).astype(int), chunk_size
)

search_summary_last_iter = search_summary_iter

# Run iterative precovery and differential correction
(
orbits_iter,
Expand Down Expand Up @@ -123,6 +126,8 @@ def merge_and_extend_orbits(
)
break

search_summary_iter = search_summary_iter.sort_by(["orbit_id"])

orbits_iter, orbit_members_iter = assign_duplicate_observations(
orbits_iter, orbit_members_iter
)
Expand Down Expand Up @@ -187,6 +192,21 @@ def merge_and_extend_orbits(
pc.is_in(search_summary_iter.orbit_id, orbits_iter.orbit_id)
)

# if we are on at least the second iteration, accumulate the run_duration
if iterations > 0:
mask = pc.is_in(
search_summary_last_iter.orbit_id, search_summary_iter.orbit_id
)
search_summary_last_iter = search_summary_last_iter.apply_mask(mask)
search_summary_last_iter = search_summary_last_iter.sort_by(["orbit_id"])
search_summary_iter = search_summary_iter.set_column(
"run_duration",
pc.add(
search_summary_last_iter.run_duration,
search_summary_iter.run_duration,
),
)

# Identify orbits that have not had their differential correction converge to a new solution
# and add them to the outgoing tables, also identify any orbits that have not
# had any new observations added since the previous iteration
Expand Down
58 changes: 58 additions & 0 deletions ipod/tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
import time

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -361,3 +362,60 @@ def test_ipod_orbit_outliers_all_bad(
all_obs_less_outliers.id.to_numpy(zero_copy_only=False),
)
break


def test_ipod_runtime(precovery_db, orbits, observations, od_observations):
    """Verify that ipod reports a positive run_duration in its SearchSummary,
    and that the reported duration is bounded above by the wall-clock time of
    the ipod call itself.

    Parameters (pytest fixtures):
        precovery_db: tuple of (database handle, test db directory, offset ids).
        orbits: table of test orbits selectable by object_id.
        observations: tuple of (exposures, detections, associations) tables.
        od_observations: observation table filterable by detection id.
    """

    db, test_db_dir, offset_ids = precovery_db

    exposures, detections, associations = observations

    for object_id in OBJECT_IDS:

        orbit = orbits.select("object_id", object_id)

        associations_i = associations.select("object_id", orbit.object_id[0])

        # Restrict the OD observations to those associated with this object.
        od_observations_i = od_observations.apply_mask(
            pc.is_in(od_observations.id, associations_i.detection_id)
        )

        fitted_orbits, fitted_orbit_members = evaluate_orbits(
            orbit, od_observations_i, propagator=PYOORB()
        )

        detections_i = detections.apply_mask(
            pc.is_in(detections.id, od_observations_i.id)
        )
        mjd_min = pc.min(detections_i.time.mjd())
        mjd_max = pc.max(detections_i.time.mjd())

        # Bracket the ipod call with a wall-clock timer so we can place an
        # upper bound on the run_duration reported in the search summary.
        time_start = time.perf_counter()
        ipod_result = ipod(
            fitted_orbits,
            od_observations_i[:10],
            max_tolerance=10.0,
            tolerance_step=2.0,
            delta_time=10.0,
            min_mjd=mjd_min.as_py() - 1.0,
            max_mjd=mjd_max.as_py() + 1.0,
            astrometric_errors={"default": (0.1, 0.1)},
            database=db,
        )
        time_end = time.perf_counter()

        (
            ipod_fitted_orbits_i,
            ipod_fitted_orbit_members_i,
            precovery_candidates,
            search_summary,
        ) = ipod_result

        # The summary's run_duration must be positive ...
        assert search_summary.run_duration[0].as_py() > 0.0

        # ... and strictly less than the total wall-clock time of the call.
        assert time_end - time_start > search_summary.run_duration[0].as_py()

        # One object is sufficient to exercise the run_duration accounting.
        break

0 comments on commit 19117e3

Please sign in to comment.