Skip to content

Commit

Permalink
Add all_orbits option to BigQuery client and faster downloading
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian committed Nov 27, 2024
1 parent faa7012 commit 568877d
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 19 deletions.
26 changes: 14 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ dependencies = [
"google-cloud-bigquery",
"google-cloud-secret-manager",
"numpy",
"quivr"
"quivr",
"google-cloud-bigquery-storage>=2.27.0",
"tqdm>=4.67.1"
]

[build-system]
Expand Down Expand Up @@ -69,17 +71,17 @@ coverage = "pytest --cov=mpcq --cov-report=xml"

[project.optional-dependencies]
dev = [
"black",
"ipython",
"isort",
"mypy",
"pdm",
"pytest-benchmark",
"pytest-cov",
"pytest-doctestplus",
"pytest-mock",
"pytest",
"ruff",
"black",
"ipython",
"isort",
"mypy",
"pdm",
"pytest-benchmark",
"pytest-cov",
"pytest-doctestplus",
"pytest-mock",
"pytest",
"ruff",
]

[tool.black]
Expand Down
2 changes: 1 addition & 1 deletion src/mpcq/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.2.6.dev20+gfaa7012.d20241127"
97 changes: 91 additions & 6 deletions src/mpcq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def query_observations(self, provids: List[str]) -> MPCObservations:
"""
query_job = self.client.query(query)
results = query_job.result()
table = results.to_arrow()
table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True)

obstime = Time(
table["obstime"].to_numpy(zero_copy_only=False),
Expand Down Expand Up @@ -216,6 +216,91 @@ def query_observations(self, provids: List[str]) -> MPCObservations:
status=table["status"],
)

    def all_orbits(self) -> MPCOrbits:
        """
        Query the MPC database for all orbits and associated data.

        Downloads every row of the ``public_mpc_orbits`` table in a single
        BigQuery query, ordered by epoch. NOTE(review): this fetches the
        entire orbit catalog — potentially a large transfer; the BigQuery
        Storage API client and a tqdm progress bar are enabled below to
        speed up and report the download.

        Returns
        -------
        orbits : MPCOrbits
            The orbits and associated data for all objects in the MPC database.
        """
        query = f"""
        SELECT
            mpc_orbits.id,
            mpc_orbits.unpacked_primary_provisional_designation AS provid,
            mpc_orbits.epoch_mjd,
            mpc_orbits.q,
            mpc_orbits.e,
            mpc_orbits.i,
            mpc_orbits.node,
            mpc_orbits.argperi,
            mpc_orbits.peri_time,
            mpc_orbits.q_unc,
            mpc_orbits.e_unc,
            mpc_orbits.i_unc,
            mpc_orbits.node_unc,
            mpc_orbits.argperi_unc,
            mpc_orbits.peri_time_unc,
            mpc_orbits.a1,
            mpc_orbits.a2,
            mpc_orbits.a3,
            mpc_orbits.h,
            mpc_orbits.g,
            mpc_orbits.created_at,
            mpc_orbits.updated_at
        FROM `{self.dataset_id}.public_mpc_orbits` AS mpc_orbits
        ORDER BY mpc_orbits.epoch_mjd ASC;
        """
        query_job = self.client.query(query)
        results = query_job.result()

        # create_bqstorage_client=True routes the result download through the
        # BigQuery Storage API (much faster than the default REST paging for
        # large results); progress_bar_type="tqdm" shows download progress.
        table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True)

        # Convert the Arrow timestamp columns to astropy Time objects.
        # zero_copy_only=False is required because the conversion to
        # numpy datetime64 cannot be done without a copy.
        created_at = Time(
            table["created_at"].to_numpy(zero_copy_only=False),
            format="datetime64",
            scale="utc",
        )
        updated_at = Time(
            table["updated_at"].to_numpy(zero_copy_only=False),
            format="datetime64",
            scale="utc",
        )

        # Handle NULL values in the epoch_mjd column: ideally
        # we should have the Timestamp class be able to handle this.
        # NULLs arrive as NaN after the Arrow->numpy conversion, so mask
        # them before building the Time object.
        # NOTE(review): epochs are constructed in the TT scale — presumably
        # MPC publishes orbit epochs in TT; confirm against the MPC schema.
        mjd_array = table["epoch_mjd"].to_numpy(zero_copy_only=False)
        mjds = np.ma.masked_array(mjd_array, mask=np.isnan(mjd_array))
        epoch = Time(mjds, format="mjd", scale="tt")

        return MPCOrbits.from_kwargs(
            # Note, since we didn't request a specific provid we use the one MPC provides
            requested_provid=table["provid"],
            id=table["id"],
            provid=table["provid"],
            epoch=Timestamp.from_astropy(epoch),
            q=table["q"],
            e=table["e"],
            i=table["i"],
            node=table["node"],
            argperi=table["argperi"],
            peri_time=table["peri_time"],
            q_unc=table["q_unc"],
            e_unc=table["e_unc"],
            i_unc=table["i_unc"],
            node_unc=table["node_unc"],
            argperi_unc=table["argperi_unc"],
            peri_time_unc=table["peri_time_unc"],
            a1=table["a1"],
            a2=table["a2"],
            a3=table["a3"],
            h=table["h"],
            g=table["g"],
            created_at=Timestamp.from_astropy(created_at),
            updated_at=Timestamp.from_astropy(updated_at),
        )

def query_orbits(self, provids: List[str]) -> MPCOrbits:
"""
Query the MPC database for the orbits and associated data for the given
Expand Down Expand Up @@ -281,7 +366,7 @@ def query_orbits(self, provids: List[str]) -> MPCOrbits:
"""
query_job = self.client.query(query)
results = query_job.result()
table = results.to_arrow()
table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True)

created_at = Time(
table["created_at"].to_numpy(zero_copy_only=False),
Expand Down Expand Up @@ -373,7 +458,7 @@ def query_submission_info(self, submission_ids: List[str]) -> MPCSubmissionResul
"""
query_job = self.client.query(query)
results = query_job.result()
table = results.to_arrow()
table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True)

return MPCSubmissionResults.from_pyarrow(table)

Expand Down Expand Up @@ -423,9 +508,9 @@ def query_submission_history(self, provids: List[str]) -> MPCSubmissionHistory:
results = query_job.result()

# Convert the results to a PyArrow table
table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True)
table = (
results.to_arrow()
.group_by(["requested_provid", "primary_designation", "submission_id"])
table.group_by(["requested_provid", "primary_designation", "submission_id"])
.aggregate(
[("obsid", "count_distinct"), ("obstime", "min"), ("obstime", "max")]
)
Expand Down Expand Up @@ -527,7 +612,7 @@ def query_primary_objects(self, provids: List[str]) -> MPCPrimaryObjects:
"""
query_job = self.client.query(query)
results = query_job.result()
table = results.to_arrow()
table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True)

created_at = Time(
table["created_at"].to_numpy(zero_copy_only=False),
Expand Down

0 comments on commit 568877d

Please sign in to comment.