From 1a1440d14ee93343fb3d7c319411c474de60e4ad Mon Sep 17 00:00:00 2001
From: Jeremy Leibs
Date: Tue, 17 Dec 2024 17:44:51 -0500
Subject: [PATCH] Update remote APIs to take a TableLike instead of a
 dictionary

---
 rerun_py/rerun_bindings/rerun_bindings.pyi |  22 +--
 rerun_py/rerun_bindings/types.py           |   5 +-
 rerun_py/src/remote.rs                     | 158 +++++++-------------
 3 files changed, 69 insertions(+), 116 deletions(-)

diff --git a/rerun_py/rerun_bindings/rerun_bindings.pyi b/rerun_py/rerun_bindings/rerun_bindings.pyi
index 53fd6781b43f..bd046bf15069 100644
--- a/rerun_py/rerun_bindings/rerun_bindings.pyi
+++ b/rerun_py/rerun_bindings/rerun_bindings.pyi
@@ -3,7 +3,7 @@ from typing import Iterator, Optional, Sequence, Union
 
 import pyarrow as pa
 
-from .types import AnyColumn, AnyComponentColumn, ComponentLike, IndexValuesLike, MetadataLike, ViewContentsLike
+from .types import AnyColumn, AnyComponentColumn, ComponentLike, IndexValuesLike, TableLike, ViewContentsLike
 
 class IndexColumnDescriptor:
     """
@@ -581,7 +581,7 @@ class StorageNodeClient:
         """Get the metadata for all recordings in the storage node."""
         ...
 
-    def register(self, storage_url: str, metadata: Optional[dict[str, MetadataLike]] = None) -> str:
+    def register(self, storage_url: str, metadata: Optional[TableLike] = None) -> str:
         """
         Register a recording along with some metadata.
 
@@ -589,22 +589,24 @@
         ----------
         storage_url : str
            The URL to the storage location.
-        metadata : dict[str, MetadataLike]
-            A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+        metadata : Optional[Table | RecordBatch]
+            A pyarrow Table or RecordBatch containing the metadata to register with the recording.
+            This Table must contain only a single row.
 
         """
         ...
 
-    def update_catalog(self, id: str, metadata: dict[str, MetadataLike]) -> None:
+    def update_catalog(self, metadata: TableLike) -> None:
         """
-        Update the metadata for the recording with the given id.
+        Update the catalog metadata for one or more recordings.
+
+        The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update.
+        The Table must contain an 'id' column, which is used to specify the recording to update for each row.
 
         Parameters
         ----------
-        id : str
-            The id of the recording to update.
-        metadata : dict[str, MetadataLike]
-            A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+        metadata : Table | RecordBatch
+            A pyarrow Table or RecordBatch containing the metadata to update.
 
         """
         ...
diff --git a/rerun_py/rerun_bindings/types.py b/rerun_py/rerun_bindings/types.py
index c5ddf94e7477..a38e70036d34 100644
--- a/rerun_py/rerun_bindings/types.py
+++ b/rerun_py/rerun_bindings/types.py
@@ -68,4 +68,7 @@
 This can be any numpy-compatible array of integers, or a [`pa.Int64Array`][]
 """
 
-MetadataLike: TypeAlias = pa.Array
+TableLike: TypeAlias = Union[pa.Table, pa.RecordBatch, pa.RecordBatchReader]
+"""
+A type alias for table-like pyarrow objects.
+"""
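
Note on the new alias (not part of the diff): `TableLike` accepts a `pa.Table`, a `pa.RecordBatch`, or a `pa.RecordBatchReader`. A minimal sketch of the three accepted forms, with illustrative column names:

    import pyarrow as pa

    # Each of these values satisfies the new TableLike alias.
    table = pa.Table.from_pydict({"id": ["rec_1"], "episode": ["kitchen_01"]})
    batch = pa.RecordBatch.from_pydict({"id": ["rec_1"], "episode": ["kitchen_01"]})
    reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
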
+""" diff --git a/rerun_py/src/remote.rs b/rerun_py/src/remote.rs index a74f3a1a0666..2619d3990088 100644 --- a/rerun_py/src/remote.rs +++ b/rerun_py/src/remote.rs @@ -1,11 +1,12 @@ #![allow(unsafe_op_in_unsafe_fn)] use arrow::{ - array::{ArrayData, RecordBatch, RecordBatchIterator, RecordBatchReader}, + array::{RecordBatch, RecordBatchIterator, RecordBatchReader}, datatypes::Schema, + ffi_stream::ArrowArrayStreamReader, pyarrow::PyArrowType, }; // False positive due to #[pyfunction] macro -use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyDict, Bound, PyResult}; +use pyo3::{exceptions::PyRuntimeError, prelude::*, Bound, PyResult}; use re_chunk::{Chunk, TransportChunk}; use re_chunk_store::ChunkStore; use re_dataframe::ChunkStoreHandle; @@ -134,17 +135,14 @@ impl PyStorageNodeClient { /// ---------- /// storage_url : str /// The URL to the storage location. - /// metadata : dict[str, MetadataLike] - /// A dictionary where the keys are the metadata columns and the values are pyarrow arrays. + /// metadata : Optional[Table | RecordBatch] + /// A pyarrow Table or RecordBatch containing the metadata to update. + /// This Table must contain only a single row. #[pyo3(signature = ( storage_url, metadata = None ))] - fn register( - &mut self, - storage_url: &str, - metadata: Option<&Bound<'_, PyDict>>, - ) -> PyResult { + fn register(&mut self, storage_url: &str, metadata: Option) -> PyResult { self.runtime.block_on(async { let storage_url = url::Url::parse(storage_url) .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; @@ -152,51 +150,32 @@ impl PyStorageNodeClient { let _obj = object_store::ObjectStoreScheme::parse(&storage_url) .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; - let payload = metadata + let metadata = metadata .map(|metadata| { - let (schema, data): ( - Vec, - Vec>, - ) = metadata - .iter() - .map(|(key, value)| { - let key = key.to_string(); - let value = value.extract::()?; - let value_array = value.to_arrow2()?; - let field = arrow2::datatypes::Field::new( - key, - value_array.data_type().clone(), - true, - ); - Ok((field, value_array)) - }) - .collect::>>()? - .into_iter() - .unzip(); - - let schema = arrow2::datatypes::Schema::from(schema); - let data = arrow2::chunk::Chunk::new(data); - - let metadata_tc = TransportChunk { - schema: schema.clone(), - data, - }; + let metadata = metadata.into_record_batch()?; + + if metadata.num_rows() != 1 { + return Err(PyRuntimeError::new_err( + "Metadata must contain exactly one row", + )); + } + + let metadata_tc = TransportChunk::from_arrow_record_batch(&metadata); encode(EncoderVersion::V0, metadata_tc) .map_err(|err| PyRuntimeError::new_err(err.to_string())) }) .transpose()? - // TODO(zehiko) this is going away soon - .ok_or(PyRuntimeError::new_err("No metadata"))?; + .map(|payload| DataframePart { + encoder_version: EncoderVersion::V0 as i32, + payload, + }); let request = RegisterRecordingRequest { // TODO(jleibs): Description should really just be in the metadata description: Default::default(), storage_url: storage_url.to_string(), - metadata: Some(DataframePart { - encoder_version: EncoderVersion::V0 as i32, - payload, - }), + metadata, typ: RecordingType::Rrd.into(), }; @@ -226,48 +205,33 @@ impl PyStorageNodeClient { }) } - /// Update the metadata for the recording with the given id. + /// Update the catalog metadata for one or more recordings. + /// + /// The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update. 
@@ -226,48 +205,33 @@ impl PyStorageNodeClient {
         })
     }
 
-    /// Update the metadata for the recording with the given id.
+    /// Update the catalog metadata for one or more recordings.
+    ///
+    /// The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update.
+    /// The Table must contain an 'id' column, which is used to specify the recording to update for each row.
     ///
     /// Parameters
     /// ----------
-    /// id : str
-    ///     The id of the recording to update.
-    /// metadata : dict[str, MetadataLike]
-    ///     A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+    /// metadata : Table | RecordBatch
+    ///     A pyarrow Table or RecordBatch containing the metadata to update.
     #[pyo3(signature = (
-        id,
         metadata
     ))]
-    fn update_catalog(&mut self, id: &str, metadata: &Bound<'_, PyDict>) -> PyResult<()> {
+    #[allow(clippy::needless_pass_by_value)]
+    fn update_catalog(&mut self, metadata: MetadataLike) -> PyResult<()> {
         self.runtime.block_on(async {
-            let (schema, data): (
-                Vec<arrow2::datatypes::Field>,
-                Vec<Box<dyn arrow2::array::Array>>,
-            ) = metadata
-                .iter()
-                .map(|(key, value)| {
-                    let key = key.to_string();
-                    let value = value.extract::<MetadataLike>()?;
-                    let value_array = value.to_arrow2()?;
-                    let field =
-                        arrow2::datatypes::Field::new(key, value_array.data_type().clone(), true);
-                    Ok((field, value_array))
-                })
-                .collect::<PyResult<Vec<_>>>()?
-                .into_iter()
-                .unzip();
+            let metadata = metadata.into_record_batch()?;
 
-            let schema = arrow2::datatypes::Schema::from(schema);
-
-            let data = arrow2::chunk::Chunk::new(data);
+            // TODO(jleibs): This id name should probably come from `re_protos`
+            if metadata.schema().column_with_name("id").is_none() {
+                return Err(PyRuntimeError::new_err(
+                    "Metadata must contain an 'id' column",
+                ));
+            }
 
-            let metadata_tc = TransportChunk {
-                schema: schema.clone(),
-                data,
-            };
+            let metadata_tc = TransportChunk::from_arrow_record_batch(&metadata);
 
             let request = UpdateCatalogRequest {
-                recording_id: Some(RecordingId { id: id.to_owned() }),
                 metadata: Some(DataframePart {
                     encoder_version: EncoderVersion::V0 as i32,
                     payload: encode(EncoderVersion::V0, metadata_tc)
@@ -363,39 +327,23 @@ impl PyStorageNodeClient {
 /// A type alias for metadata.
 #[derive(FromPyObject)]
 enum MetadataLike {
-    PyArrow(PyArrowType<ArrayData>),
-    // TODO(jleibs): Support converting other primitives
+    RecordBatch(PyArrowType<RecordBatch>),
+    Reader(PyArrowType<ArrowArrayStreamReader>),
 }
 
 impl MetadataLike {
-    fn to_arrow2(&self) -> PyResult<Box<dyn arrow2::array::Array>> {
-        match self {
-            Self::PyArrow(array) => {
-                let array = arrow2::array::from_data(&array.0);
-                if array.len() == 1 {
-                    Ok(array)
-                } else {
-                    Err(PyRuntimeError::new_err(
-                        "Metadata must be a single array, not a list",
-                    ))
-                }
-            }
-        }
-    }
-
-    #[allow(dead_code)]
-    fn to_arrow(&self) -> PyResult<std::sync::Arc<dyn arrow::array::Array>> {
-        match self {
-            Self::PyArrow(array) => {
-                let array = arrow::array::make_array(array.0.clone());
-                if array.len() == 1 {
-                    Ok(array)
-                } else {
-                    Err(PyRuntimeError::new_err(
-                        "Metadata must be a single array, not a list",
-                    ))
-                }
-            }
-        }
+    fn into_record_batch(self) -> PyResult<RecordBatch> {
+        let (schema, batches) = match self {
+            Self::RecordBatch(record_batch) => (record_batch.0.schema(), vec![record_batch.0]),
+            Self::Reader(reader) => (
+                reader.0.schema(),
+                reader.0.collect::<Result<Vec<_>, _>>().map_err(|err| {
+                    PyRuntimeError::new_err(format!("Failed to read RecordBatches: {err}"))
+                })?,
+            ),
+        };
+
+        arrow::compute::concat_batches(&schema, &batches)
+            .map_err(|err| PyRuntimeError::new_err(err.to_string()))
+    }
 }
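
Usage sketch for the new `update_catalog` signature (not part of the diff; `client` is assumed to be an already-connected `StorageNodeClient`, and every column other than 'id' is illustrative). A single call can update several recordings, with the 'id' column selecting the target recording for each row:

    import pyarrow as pa

    # Each row updates the recording named in its 'id' cell; tables without
    # an 'id' column are rejected by the check in update_catalog.
    updates = pa.Table.from_pydict(
        {
            "id": ["rec_1", "rec_2"],
            "episode": ["kitchen_01", "kitchen_02"],
        }
    )
    client.update_catalog(updates)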