From 84da13220f7bbd880a579a37853ebb1ea8b81a2c Mon Sep 17 00:00:00 2001
From: Jeremy Leibs
Date: Wed, 18 Dec 2024 08:20:46 +0100
Subject: [PATCH] Update the remote APIs to take and send a Table (#8521)

This allows the update APIs to now support multi-recording updates, and
generally gives users more direct control over things like the column
metadata, etc.
---
 .../proto/rerun/v0/remote_store.proto         |   1 -
 .../re_protos/src/v0/rerun.remote_store.v0.rs |   2 -
 examples/python/remote/metadata.py            |  17 +-
 rerun_py/rerun_bindings/rerun_bindings.pyi    |  22 +--
 rerun_py/rerun_bindings/types.py              |   5 +-
 rerun_py/src/remote.rs                        | 158 ++++++------------
 6 files changed, 83 insertions(+), 122 deletions(-)

diff --git a/crates/store/re_protos/proto/rerun/v0/remote_store.proto b/crates/store/re_protos/proto/rerun/v0/remote_store.proto
index 2955f45b8b3c..5d2f68ca3794 100644
--- a/crates/store/re_protos/proto/rerun/v0/remote_store.proto
+++ b/crates/store/re_protos/proto/rerun/v0/remote_store.proto
@@ -44,7 +44,6 @@ message RegisterRecordingRequest {
 // ---------------- UpdateCatalog -----------------
 
 message UpdateCatalogRequest {
-  rerun.common.v0.RecordingId recording_id = 1;
   DataframePart metadata = 2;
 }
 
diff --git a/crates/store/re_protos/src/v0/rerun.remote_store.v0.rs b/crates/store/re_protos/src/v0/rerun.remote_store.v0.rs
index 1429146e776d..a755b57daac4 100644
--- a/crates/store/re_protos/src/v0/rerun.remote_store.v0.rs
+++ b/crates/store/re_protos/src/v0/rerun.remote_store.v0.rs
@@ -47,8 +47,6 @@ impl ::prost::Name for RegisterRecordingRequest {
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct UpdateCatalogRequest {
-    #[prost(message, optional, tag = "1")]
-    pub recording_id: ::core::option::Option<super::super::common::v0::RecordingId>,
     #[prost(message, optional, tag = "2")]
     pub metadata: ::core::option::Option<DataframePart>,
 }
diff --git a/examples/python/remote/metadata.py b/examples/python/remote/metadata.py
index e199499d8f2a..31371bfac2d3 100644
--- a/examples/python/remote/metadata.py
+++ b/examples/python/remote/metadata.py
@@ -14,23 +14,31 @@
     subparsers = parser.add_subparsers(dest="subcommand")
 
     print_cmd = subparsers.add_parser("print", help="Print everything")
+    register_cmd = subparsers.add_parser("register", help="Register a new recording")
 
     update_cmd = subparsers.add_parser("update", help="Update metadata for a recording")
     update_cmd.add_argument("id", help="ID of the recording to update")
     update_cmd.add_argument("key", help="Key of the metadata to update")
     update_cmd.add_argument("value", help="Value of the metadata to update")
 
+    register_cmd.add_argument("storage_url", help="Storage URL to register")
+
     args = parser.parse_args()
 
     # Register the new rrd
     conn = rr.remote.connect("http://0.0.0.0:51234")
 
-    catalog = pl.from_arrow(conn.query_catalog())
+    catalog = pl.from_arrow(conn.query_catalog().read_all())
 
     if args.subcommand == "print":
         print(catalog)
 
-    if args.subcommand == "update":
+    elif args.subcommand == "register":
+        extra_metadata = pa.Table.from_pydict({"extra": [42]})
+        id = conn.register(args.storage_url, extra_metadata)
+        print(f"Registered new recording with ID: {id}")
+
+    elif args.subcommand == "update":
         id = catalog.filter(catalog["id"].str.starts_with(args.id)).select(pl.first("id")).item()
 
         if id is None:
@@ -38,4 +46,7 @@
             exit(1)
 
         print(f"Updating metadata for {id}")
-        conn.update_catalog(id, {args.key: pa.array([args.value])})
+        new_metadata = pa.Table.from_pydict({"id": [id], args.key: [args.value]})
+        print(new_metadata)
+
+        conn.update_catalog(new_metadata)
diff --git a/rerun_py/rerun_bindings/rerun_bindings.pyi b/rerun_py/rerun_bindings/rerun_bindings.pyi
index 53fd6781b43f..bd046bf15069 100644
--- a/rerun_py/rerun_bindings/rerun_bindings.pyi
+++ b/rerun_py/rerun_bindings/rerun_bindings.pyi
@@ -3,7 +3,7 @@ from typing import Iterator, Optional, Sequence, Union
 
 import pyarrow as pa
 
-from .types import AnyColumn, AnyComponentColumn, ComponentLike, IndexValuesLike, MetadataLike, ViewContentsLike
+from .types import AnyColumn, AnyComponentColumn, ComponentLike, IndexValuesLike, TableLike, ViewContentsLike
 
 class IndexColumnDescriptor:
     """
@@ -581,7 +581,7 @@ class StorageNodeClient:
         """Get the metadata for all recordings in the storage node."""
         ...
 
-    def register(self, storage_url: str, metadata: Optional[dict[str, MetadataLike]] = None) -> str:
+    def register(self, storage_url: str, metadata: Optional[TableLike] = None) -> str:
         """
         Register a recording along with some metadata.
 
@@ -589,22 +589,24 @@
         ----------
         storage_url : str
             The URL to the storage location.
-        metadata : dict[str, MetadataLike]
-            A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+        metadata : Optional[Table | RecordBatch]
+            A pyarrow Table or RecordBatch containing the metadata to register with the recording.
+            This Table must contain only a single row.
 
         """
         ...
 
-    def update_catalog(self, id: str, metadata: dict[str, MetadataLike]) -> None:
+    def update_catalog(self, metadata: TableLike) -> None:
         """
-        Update the metadata for the recording with the given id.
+        Update the catalog metadata for one or more recordings.
+
+        The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update.
+        The Table must contain an 'id' column, which is used to specify the recording to update for each row.
 
         Parameters
         ----------
-        id : str
-            The id of the recording to update.
-        metadata : dict[str, MetadataLike]
-            A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+        metadata : Table | RecordBatch
+            A pyarrow Table or RecordBatch containing the metadata to update.
 
         """
         ...
diff --git a/rerun_py/rerun_bindings/types.py b/rerun_py/rerun_bindings/types.py
index c5ddf94e7477..a38e70036d34 100644
--- a/rerun_py/rerun_bindings/types.py
+++ b/rerun_py/rerun_bindings/types.py
@@ -68,4 +68,7 @@
 This can be any numpy-compatible array of integers, or a [`pa.Int64Array`][]
 """
 
-MetadataLike: TypeAlias = pa.Array
+TableLike: TypeAlias = Union[pa.Table, pa.RecordBatch, pa.RecordBatchReader]
+"""
+A type alias for Table-like pyarrow objects: a Table, RecordBatch, or RecordBatchReader.
+"""
diff --git a/rerun_py/src/remote.rs b/rerun_py/src/remote.rs
index a74f3a1a0666..2619d3990088 100644
--- a/rerun_py/src/remote.rs
+++ b/rerun_py/src/remote.rs
@@ -1,11 +1,12 @@
 #![allow(unsafe_op_in_unsafe_fn)]
 use arrow::{
-    array::{ArrayData, RecordBatch, RecordBatchIterator, RecordBatchReader},
+    array::{RecordBatch, RecordBatchIterator, RecordBatchReader},
     datatypes::Schema,
+    ffi_stream::ArrowArrayStreamReader,
     pyarrow::PyArrowType,
 };
 // False positive due to #[pyfunction] macro
-use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyDict, Bound, PyResult};
+use pyo3::{exceptions::PyRuntimeError, prelude::*, Bound, PyResult};
 use re_chunk::{Chunk, TransportChunk};
 use re_chunk_store::ChunkStore;
 use re_dataframe::ChunkStoreHandle;
@@ -134,17 +135,14 @@ impl PyStorageNodeClient {
     /// ----------
     /// storage_url : str
     ///     The URL to the storage location.
-    /// metadata : dict[str, MetadataLike]
-    ///     A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+    /// metadata : Optional[Table | RecordBatch]
+    ///     A pyarrow Table or RecordBatch containing the metadata to register with the recording.
+    ///     This Table must contain only a single row.
     #[pyo3(signature = (
         storage_url,
         metadata = None
     ))]
-    fn register(
-        &mut self,
-        storage_url: &str,
-        metadata: Option<&Bound<'_, PyDict>>,
-    ) -> PyResult<String> {
+    fn register(&mut self, storage_url: &str, metadata: Option<MetadataLike>) -> PyResult<String> {
         self.runtime.block_on(async {
             let storage_url = url::Url::parse(storage_url)
                 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
@@ -152,51 +150,32 @@ impl PyStorageNodeClient {
             let _obj = object_store::ObjectStoreScheme::parse(&storage_url)
                 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
 
-            let payload = metadata
+            let metadata = metadata
                 .map(|metadata| {
-                    let (schema, data): (
-                        Vec<arrow2::datatypes::Field>,
-                        Vec<Box<dyn arrow2::array::Array>>,
-                    ) = metadata
-                        .iter()
-                        .map(|(key, value)| {
-                            let key = key.to_string();
-                            let value = value.extract::<MetadataLike>()?;
-                            let value_array = value.to_arrow2()?;
-                            let field = arrow2::datatypes::Field::new(
-                                key,
-                                value_array.data_type().clone(),
-                                true,
-                            );
-                            Ok((field, value_array))
-                        })
-                        .collect::<PyResult<Vec<_>>>()?
-                        .into_iter()
-                        .unzip();
-
-                    let schema = arrow2::datatypes::Schema::from(schema);
-                    let data = arrow2::chunk::Chunk::new(data);
-
-                    let metadata_tc = TransportChunk {
-                        schema: schema.clone(),
-                        data,
-                    };
+                    let metadata = metadata.into_record_batch()?;
+
+                    if metadata.num_rows() != 1 {
+                        return Err(PyRuntimeError::new_err(
+                            "Metadata must contain exactly one row",
+                        ));
+                    }
+
+                    let metadata_tc = TransportChunk::from_arrow_record_batch(&metadata);
 
                     encode(EncoderVersion::V0, metadata_tc)
                         .map_err(|err| PyRuntimeError::new_err(err.to_string()))
                 })
                 .transpose()?
-                // TODO(zehiko) this is going away soon
-                .ok_or(PyRuntimeError::new_err("No metadata"))?;
+                .map(|payload| DataframePart {
+                    encoder_version: EncoderVersion::V0 as i32,
+                    payload,
+                });
 
             let request = RegisterRecordingRequest {
                 // TODO(jleibs): Description should really just be in the metadata
                 description: Default::default(),
                 storage_url: storage_url.to_string(),
-                metadata: Some(DataframePart {
-                    encoder_version: EncoderVersion::V0 as i32,
-                    payload,
-                }),
+                metadata,
                 typ: RecordingType::Rrd.into(),
             };
 
@@ -226,48 +205,33 @@ impl PyStorageNodeClient {
         })
     }
 
-    /// Update the metadata for the recording with the given id.
+    /// Update the catalog metadata for one or more recordings.
+    ///
+    /// The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update.
+    /// The Table must contain an 'id' column, which is used to specify the recording to update for each row.
     ///
    /// Parameters
     /// ----------
-    /// id : str
-    ///     The id of the recording to update.
-    /// metadata : dict[str, MetadataLike]
-    ///     A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
+    /// metadata : Table | RecordBatch
+    ///     A pyarrow Table or RecordBatch containing the metadata to update.
     #[pyo3(signature = (
-        id,
         metadata
     ))]
-    fn update_catalog(&mut self, id: &str, metadata: &Bound<'_, PyDict>) -> PyResult<()> {
+    #[allow(clippy::needless_pass_by_value)]
+    fn update_catalog(&mut self, metadata: MetadataLike) -> PyResult<()> {
         self.runtime.block_on(async {
-            let (schema, data): (
-                Vec<arrow2::datatypes::Field>,
-                Vec<Box<dyn arrow2::array::Array>>,
-            ) = metadata
-                .iter()
-                .map(|(key, value)| {
-                    let key = key.to_string();
-                    let value = value.extract::<MetadataLike>()?;
-                    let value_array = value.to_arrow2()?;
-                    let field =
-                        arrow2::datatypes::Field::new(key, value_array.data_type().clone(), true);
-                    Ok((field, value_array))
-                })
-                .collect::<PyResult<Vec<_>>>()?
-                .into_iter()
-                .unzip();
+            let metadata = metadata.into_record_batch()?;
 
-            let schema = arrow2::datatypes::Schema::from(schema);
-
-            let data = arrow2::chunk::Chunk::new(data);
+            // TODO(jleibs): This id name should probably come from `re_protos`
+            if metadata.schema().column_with_name("id").is_none() {
+                return Err(PyRuntimeError::new_err(
+                    "Metadata must contain an 'id' column",
+                ));
+            }
 
-            let metadata_tc = TransportChunk {
-                schema: schema.clone(),
-                data,
-            };
+            let metadata_tc = TransportChunk::from_arrow_record_batch(&metadata);
 
             let request = UpdateCatalogRequest {
-                recording_id: Some(RecordingId { id: id.to_owned() }),
                 metadata: Some(DataframePart {
                     encoder_version: EncoderVersion::V0 as i32,
                     payload: encode(EncoderVersion::V0, metadata_tc)
@@ -363,39 +327,23 @@ impl PyStorageNodeClient {
 /// A type alias for metadata.
 #[derive(FromPyObject)]
 enum MetadataLike {
-    PyArrow(PyArrowType<ArrayData>),
-    // TODO(jleibs): Support converting other primitives
+    RecordBatch(PyArrowType<RecordBatch>),
+    Reader(PyArrowType<ArrowArrayStreamReader>),
 }
 
 impl MetadataLike {
-    fn to_arrow2(&self) -> PyResult<Box<dyn arrow2::array::Array>> {
-        match self {
-            Self::PyArrow(array) => {
-                let array = arrow2::array::from_data(&array.0);
-                if array.len() == 1 {
-                    Ok(array)
-                } else {
-                    Err(PyRuntimeError::new_err(
-                        "Metadata must be a single array, not a list",
-                    ))
-                }
-            }
-        }
-    }
-
-    #[allow(dead_code)]
-    fn to_arrow(&self) -> PyResult<arrow::array::ArrayRef> {
-        match self {
-            Self::PyArrow(array) => {
-                let array = arrow::array::make_array(array.0.clone());
-                if array.len() == 1 {
-                    Ok(array)
-                } else {
-                    Err(PyRuntimeError::new_err(
-                        "Metadata must be a single array, not a list",
-                    ))
-                }
-            }
-        }
+    fn into_record_batch(self) -> PyResult<RecordBatch> {
+        let (schema, batches) = match self {
+            Self::RecordBatch(record_batch) => (record_batch.0.schema(), vec![record_batch.0]),
+            Self::Reader(reader) => (
+                reader.0.schema(),
+                reader.0.collect::<Result<Vec<_>, _>>().map_err(|err| {
+                    PyRuntimeError::new_err(format!("Failed to read RecordBatches: {err}"))
+                })?,
+            ),
+        };
+
+        arrow::compute::concat_batches(&schema, &batches)
+            .map_err(|err| PyRuntimeError::new_err(err.to_string()))
    }
 }
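
For reference, a minimal usage sketch of the updated Python API, following the updated example and
docstrings above. It assumes the example's local setup (a storage node reachable at
http://0.0.0.0:51234); the storage URL and the "origin"/"rating" column names are illustrative
placeholders, not part of the API.

    import pyarrow as pa
    import rerun as rr

    conn = rr.remote.connect("http://0.0.0.0:51234")

    # register() now takes an optional single-row Table/RecordBatch of metadata
    # and returns the id of the newly registered recording.
    new_id = conn.register(
        "file:///tmp/my_recording.rrd",  # hypothetical storage URL
        pa.Table.from_pydict({"origin": ["robot-7"]}),
    )

    # update_catalog() takes a Table/RecordBatch with one row per recording to update;
    # the required "id" column selects which recording each row applies to.
    conn.update_catalog(
        pa.Table.from_pydict({
            "id": [new_id, "another-recording-id"],  # second id is a placeholder
            "rating": [5, 3],
        })
    )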
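Because TableLike also admits a pa.RecordBatchReader (the binding collects the reader's batches and
concatenates them in into_record_batch), a streaming source can be passed as-is. Continuing the
sketch above:

    # The reader's batches are concatenated before encoding, so this is
    # equivalent to passing the corresponding Table.
    batch = pa.RecordBatch.from_pydict({"id": [new_id], "rating": [4]})
    reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
    conn.update_catalog(reader)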