Update remote APIs to take a TableLike instead of a dictionary
jleibs committed Dec 17, 2024
1 parent d7a61b0 commit 1a1440d
Showing 3 changed files with 69 additions and 116 deletions.
22 changes: 12 additions & 10 deletions rerun_py/rerun_bindings/rerun_bindings.pyi
@@ -3,7 +3,7 @@ from typing import Iterator, Optional, Sequence, Union

import pyarrow as pa

from .types import AnyColumn, AnyComponentColumn, ComponentLike, IndexValuesLike, MetadataLike, ViewContentsLike
from .types import AnyColumn, AnyComponentColumn, ComponentLike, IndexValuesLike, TableLike, ViewContentsLike

class IndexColumnDescriptor:
"""
@@ -581,30 +581,32 @@ class StorageNodeClient:
"""Get the metadata for all recordings in the storage node."""
...

def register(self, storage_url: str, metadata: Optional[dict[str, MetadataLike]] = None) -> str:
def register(self, storage_url: str, metadata: Optional[TableLike] = None) -> str:
"""
Register a recording along with some metadata.
Parameters
----------
storage_url : str
The URL to the storage location.
metadata : dict[str, MetadataLike]
A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
metadata : Optional[Table | RecordBatch]
A pyarrow Table or RecordBatch containing the metadata to register with the recording.
This Table must contain exactly one row.
"""
...

def update_catalog(self, id: str, metadata: dict[str, MetadataLike]) -> None:
def update_catalog(self, metadata: TableLike) -> None:
"""
Update the metadata for the recording with the given id.
Update the catalog metadata for one or more recordings.
The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update.
The Table must contain an 'id' column, which is used to specify the recording to update for each row.
Parameters
----------
id : str
The id of the recording to update.
metadata : dict[str, MetadataLike]
A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
metadata : Table | RecordBatch
A pyarrow Table or RecordBatch containing the metadata to update.
"""
...
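To make the new calling convention concrete, here is a minimal sketch of registering a recording under the updated API. The storage URL and metadata columns are made-up placeholders, and `client` is assumed to be an already-connected `StorageNodeClient`; the only point being illustrated is that `register` now takes a single-row pyarrow Table (or RecordBatch) instead of a `dict[str, MetadataLike]`.

```python
import pyarrow as pa


def register_recording(client, storage_url: str) -> str:
    """Register a recording with single-row metadata via the new TableLike API."""
    # Previously this was a dict of single-element pyarrow arrays;
    # it is now a Table/RecordBatch that must contain exactly one row.
    metadata = pa.Table.from_pydict(
        {
            "name": ["my_recording"],   # placeholder metadata columns
            "frame_count": [1234],
        }
    )
    return client.register(storage_url, metadata=metadata)
```

A table with more than one row is rejected on the Rust side with "Metadata must contain exactly one row".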
5 changes: 4 additions & 1 deletion rerun_py/rerun_bindings/types.py
@@ -68,4 +68,7 @@
This can be any numpy-compatible array of integers, or a [`pa.Int64Array`][]
"""

MetadataLike: TypeAlias = pa.Array
TableLike: TypeAlias = Union[pa.Table, pa.RecordBatch, pa.RecordBatchReader]
"""
A type alias for table-like pyarrow objects: a Table, RecordBatch, or RecordBatchReader.
"""
158 changes: 53 additions & 105 deletions rerun_py/src/remote.rs
@@ -1,11 +1,12 @@
#![allow(unsafe_op_in_unsafe_fn)]
use arrow::{
array::{ArrayData, RecordBatch, RecordBatchIterator, RecordBatchReader},
array::{RecordBatch, RecordBatchIterator, RecordBatchReader},
datatypes::Schema,
ffi_stream::ArrowArrayStreamReader,
pyarrow::PyArrowType,
};
// False positive due to #[pyfunction] macro
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyDict, Bound, PyResult};
use pyo3::{exceptions::PyRuntimeError, prelude::*, Bound, PyResult};
use re_chunk::{Chunk, TransportChunk};
use re_chunk_store::ChunkStore;
use re_dataframe::ChunkStoreHandle;
@@ -134,69 +135,47 @@ impl PyStorageNodeClient {
/// ----------
/// storage_url : str
/// The URL to the storage location.
/// metadata : dict[str, MetadataLike]
/// A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
/// metadata : Optional[Table | RecordBatch]
/// A pyarrow Table or RecordBatch containing the metadata to register with the recording.
/// This Table must contain exactly one row.
#[pyo3(signature = (
storage_url,
metadata = None
))]
fn register(
&mut self,
storage_url: &str,
metadata: Option<&Bound<'_, PyDict>>,
) -> PyResult<String> {
fn register(&mut self, storage_url: &str, metadata: Option<MetadataLike>) -> PyResult<String> {
self.runtime.block_on(async {
let storage_url = url::Url::parse(storage_url)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;

let _obj = object_store::ObjectStoreScheme::parse(&storage_url)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;

let payload = metadata
let metadata = metadata
.map(|metadata| {
let (schema, data): (
Vec<arrow2::datatypes::Field>,
Vec<Box<dyn arrow2::array::Array>>,
) = metadata
.iter()
.map(|(key, value)| {
let key = key.to_string();
let value = value.extract::<MetadataLike>()?;
let value_array = value.to_arrow2()?;
let field = arrow2::datatypes::Field::new(
key,
value_array.data_type().clone(),
true,
);
Ok((field, value_array))
})
.collect::<PyResult<Vec<_>>>()?
.into_iter()
.unzip();

let schema = arrow2::datatypes::Schema::from(schema);
let data = arrow2::chunk::Chunk::new(data);

let metadata_tc = TransportChunk {
schema: schema.clone(),
data,
};
let metadata = metadata.into_record_batch()?;

if metadata.num_rows() != 1 {
return Err(PyRuntimeError::new_err(
"Metadata must contain exactly one row",
));
}

let metadata_tc = TransportChunk::from_arrow_record_batch(&metadata);

encode(EncoderVersion::V0, metadata_tc)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))
})
.transpose()?
// TODO(zehiko) this is going away soon
.ok_or(PyRuntimeError::new_err("No metadata"))?;
.map(|payload| DataframePart {
encoder_version: EncoderVersion::V0 as i32,
payload,
});

let request = RegisterRecordingRequest {
// TODO(jleibs): Description should really just be in the metadata
description: Default::default(),
storage_url: storage_url.to_string(),
metadata: Some(DataframePart {
encoder_version: EncoderVersion::V0 as i32,
payload,
}),
metadata,
typ: RecordingType::Rrd.into(),
};

@@ -226,48 +205,33 @@ impl PyStorageNodeClient {
})
}

/// Update the metadata for the recording with the given id.
/// Update the catalog metadata for one or more recordings.
///
/// The updates are provided as a pyarrow Table or RecordBatch containing the metadata to update.
/// The Table must contain an 'id' column, which is used to specify the recording to update for each row.
///
/// Parameters
/// ----------
/// id : str
/// The id of the recording to update.
/// metadata : dict[str, MetadataLike]
/// A dictionary where the keys are the metadata columns and the values are pyarrow arrays.
/// metadata : Table | RecordBatch
/// A pyarrow Table or RecordBatch containing the metadata to update.
#[pyo3(signature = (
id,
metadata
))]
fn update_catalog(&mut self, id: &str, metadata: &Bound<'_, PyDict>) -> PyResult<()> {
#[allow(clippy::needless_pass_by_value)]
fn update_catalog(&mut self, metadata: MetadataLike) -> PyResult<()> {
self.runtime.block_on(async {
let (schema, data): (
Vec<arrow2::datatypes::Field>,
Vec<Box<dyn arrow2::array::Array>>,
) = metadata
.iter()
.map(|(key, value)| {
let key = key.to_string();
let value = value.extract::<MetadataLike>()?;
let value_array = value.to_arrow2()?;
let field =
arrow2::datatypes::Field::new(key, value_array.data_type().clone(), true);
Ok((field, value_array))
})
.collect::<PyResult<Vec<_>>>()?
.into_iter()
.unzip();
let metadata = metadata.into_record_batch()?;

let schema = arrow2::datatypes::Schema::from(schema);

let data = arrow2::chunk::Chunk::new(data);
// TODO(jleibs): This id name should probably come from `re_protos`
if metadata.schema().column_with_name("id").is_none() {
return Err(PyRuntimeError::new_err(
"Metadata must contain an 'id' column",
));
}

let metadata_tc = TransportChunk {
schema: schema.clone(),
data,
};
let metadata_tc = TransportChunk::from_arrow_record_batch(&metadata);

let request = UpdateCatalogRequest {
recording_id: Some(RecordingId { id: id.to_owned() }),
metadata: Some(DataframePart {
encoder_version: EncoderVersion::V0 as i32,
payload: encode(EncoderVersion::V0, metadata_tc)
@@ -363,39 +327,23 @@ impl PyStorageNodeClient {
/// A type alias for metadata.
#[derive(FromPyObject)]
enum MetadataLike {
PyArrow(PyArrowType<ArrayData>),
// TODO(jleibs): Support converting other primitives
RecordBatch(PyArrowType<RecordBatch>),
Reader(PyArrowType<ArrowArrayStreamReader>),
}

impl MetadataLike {
fn to_arrow2(&self) -> PyResult<Box<dyn re_chunk::Arrow2Array>> {
match self {
Self::PyArrow(array) => {
let array = arrow2::array::from_data(&array.0);
if array.len() == 1 {
Ok(array)
} else {
Err(PyRuntimeError::new_err(
"Metadata must be a single array, not a list",
))
}
}
}
}

#[allow(dead_code)]
fn to_arrow(&self) -> PyResult<std::sync::Arc<dyn arrow::array::Array>> {
match self {
Self::PyArrow(array) => {
let array = arrow::array::make_array(array.0.clone());
if array.len() == 1 {
Ok(array)
} else {
Err(PyRuntimeError::new_err(
"Metadata must be a single array, not a list",
))
}
}
}
fn into_record_batch(self) -> PyResult<RecordBatch> {
let (schema, batches) = match self {
Self::RecordBatch(record_batch) => (record_batch.0.schema(), vec![record_batch.0]),
Self::Reader(reader) => (
reader.0.schema(),
reader.0.collect::<Result<Vec<_>, _>>().map_err(|err| {
PyRuntimeError::new_err(format!("Failed to read RecordBatches: {err}"))
})?,
),
};

arrow::compute::concat_batches(&schema, &batches)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))
}
}
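Tying the Rust changes back to the Python surface, here is a hedged sketch of calling the reworked `update_catalog`. The recording ids and metadata columns are placeholders and `client` is again assumed to be a connected `StorageNodeClient`; what matters is that the table carries one row per recording and includes the required `id` column (tables without it are rejected, as the check above shows).

```python
import pyarrow as pa


def tag_recordings(client) -> None:
    """Update catalog metadata for several recordings in one call."""
    updates = pa.Table.from_pydict(
        {
            # Required: selects which recording each row applies to.
            "id": ["rec_1", "rec_2"],
            # Arbitrary metadata columns to set for those recordings.
            "operator": ["alice", "bob"],
            "validated": [True, False],
        }
    )
    client.update_catalog(updates)
```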
