From 83d73f375703d345527238fa980ab4fa4ec1b9d9 Mon Sep 17 00:00:00 2001 From: Emil Ernerfeldt Date: Mon, 9 Dec 2024 16:47:59 +0100 Subject: [PATCH] Port deserializer to arrow1 --- .../re_types_builder/src/codegen/rust/api.rs | 53 +++++----- .../src/codegen/rust/deserializer.rs | 96 +++++++++---------- 2 files changed, 70 insertions(+), 79 deletions(-) diff --git a/crates/build/re_types_builder/src/codegen/rust/api.rs b/crates/build/re_types_builder/src/codegen/rust/api.rs index 058dce7b7a89e..b8f891c40bad0 100644 --- a/crates/build/re_types_builder/src/codegen/rust/api.rs +++ b/crates/build/re_types_builder/src/codegen/rust/api.rs @@ -181,7 +181,7 @@ fn generate_object_file( code.push_str("\n\n"); - code.push_str("use ::re_types_core::external::arrow2;\n"); + code.push_str("use ::re_types_core::external::arrow;\n"); code.push_str("use ::re_types_core::SerializationResult;\n"); code.push_str("use ::re_types_core::{DeserializationResult, DeserializationError};\n"); code.push_str("use ::re_types_core::{ComponentDescriptor, ComponentName};\n"); @@ -880,8 +880,8 @@ fn quote_trait_impls_for_datatype_or_component( } }; - let quoted_from_arrow2 = if optimize_for_buffer_slice { - let from_arrow2_body = if let Some(forwarded_type) = forwarded_type.as_ref() { + let quoted_from_arrow = if optimize_for_buffer_slice { + let from_arrow_body = if let Some(forwarded_type) = forwarded_type.as_ref() { let is_pod = obj .try_get_attr::(ATTR_RUST_DERIVE) .map_or(false, |d| d.contains("bytemuck::Pod")) @@ -890,11 +890,11 @@ fn quote_trait_impls_for_datatype_or_component( .map_or(false, |d| d.contains("bytemuck::Pod")); if is_pod { quote! { - #forwarded_type::from_arrow2(arrow_data).map(bytemuck::cast_vec) + #forwarded_type::from_arrow(arrow_data).map(bytemuck::cast_vec) } } else { quote! { - #forwarded_type::from_arrow2(arrow_data).map(|v| v.into_iter().map(Self).collect()) + #forwarded_type::from_arrow(arrow_data).map(|v| v.into_iter().map(Self).collect()) } } } else { @@ -906,14 +906,13 @@ fn quote_trait_impls_for_datatype_or_component( // re_tracing::profile_function!(); #![allow(clippy::wildcard_imports)] - use arrow::datatypes::*; - use arrow2::{ array::*, buffer::*}; - use ::re_types_core::{Loggable as _, ResultExt as _}; + use arrow::{array::*, buffer::*, datatypes::*}; + use ::re_types_core::{arrow_zip_validity::ZipValidity, Loggable as _, ResultExt as _}; - // This code-path cannot have null fields. If it does have a validity mask - // all bits must indicate valid data. - if let Some(validity) = arrow_data.validity() { - if validity.unset_bits() != 0 { + // This code-path cannot have null fields. + // If it does have a nulls-array, all bits must indicate valid data. + if let Some(nulls) = arrow_data.nulls() { + if nulls.null_count() != 0 { return Err(DeserializationError::missing_data()); } } @@ -924,13 +923,13 @@ fn quote_trait_impls_for_datatype_or_component( quote! { #[inline] - fn from_arrow2( - arrow_data: &dyn arrow2::array::Array, + fn from_arrow( + arrow_data: &dyn arrow::array::Array, ) -> DeserializationResult> where Self: Sized { - #from_arrow2_body + #from_arrow_body } } } else { @@ -940,7 +939,7 @@ fn quote_trait_impls_for_datatype_or_component( // Forward deserialization to existing datatype if it's transparent. let quoted_deserializer = if let Some(forwarded_type) = forwarded_type.as_ref() { quote! { - #forwarded_type::from_arrow2_opt(arrow_data).map(|v| v.into_iter().map(|v| v.map(Self)).collect()) + #forwarded_type::from_arrow_opt(arrow_data).map(|v| v.into_iter().map(|v| v.map(Self)).collect()) } } else { let quoted_deserializer = quote_arrow_deserializer(arrow_registry, objects, obj); @@ -949,9 +948,9 @@ fn quote_trait_impls_for_datatype_or_component( // re_tracing::profile_function!(); #![allow(clippy::wildcard_imports)] - use arrow::datatypes::*; - use arrow2::{ array::*, buffer::*}; - use ::re_types_core::{Loggable as _, ResultExt as _}; + use arrow::{array::*, buffer::*, datatypes::*}; + use ::re_types_core::{arrow_zip_validity::ZipValidity, Loggable as _, ResultExt as _}; + Ok(#quoted_deserializer) } }; @@ -1019,8 +1018,8 @@ fn quote_trait_impls_for_datatype_or_component( #quoted_serializer // NOTE: Don't inline this, this gets _huge_. - fn from_arrow2_opt( - arrow_data: &dyn arrow2::array::Array, + fn from_arrow_opt( + arrow_data: &dyn arrow::array::Array, ) -> DeserializationResult>> where Self: Sized @@ -1028,7 +1027,7 @@ fn quote_trait_impls_for_datatype_or_component( #quoted_deserializer } - #quoted_from_arrow2 + #quoted_from_arrow } } } @@ -1227,7 +1226,7 @@ fn quote_trait_impls_for_archetype(obj: &Object) -> TokenStream { quote! { if let Some(array) = arrays_by_name.get(#field_typ_fqname_str) { - <#component>::from_arrow2_opt(&**array) + <#component>::from_arrow_opt(&**array) .with_context(#obj_field_fqname)? #quoted_collection } else { @@ -1238,7 +1237,7 @@ fn quote_trait_impls_for_archetype(obj: &Object) -> TokenStream { quote! { if let Some(array) = arrays_by_name.get(#field_typ_fqname_str) { Some({ - <#component>::from_arrow2_opt(&**array) + <#component>::from_arrow_opt(&**array) .with_context(#obj_field_fqname)? #quoted_collection }) @@ -1253,7 +1252,7 @@ fn quote_trait_impls_for_archetype(obj: &Object) -> TokenStream { .ok_or_else(DeserializationError::missing_data) .with_context(#obj_field_fqname)?; - <#component>::from_arrow2_opt(&**array).with_context(#obj_field_fqname)? #quoted_collection + <#component>::from_arrow_opt(&**array).with_context(#obj_field_fqname)? #quoted_collection }} }; @@ -1323,10 +1322,10 @@ fn quote_trait_impls_for_archetype(obj: &Object) -> TokenStream { } #[inline] - fn from_arrow2_components( + fn from_arrow_components( arrow_data: impl IntoIterator, + arrow::array::ArrayRef, )>, ) -> DeserializationResult { re_tracing::profile_function!(); diff --git a/crates/build/re_types_builder/src/codegen/rust/deserializer.rs b/crates/build/re_types_builder/src/codegen/rust/deserializer.rs index 25880acd6fd01..45fe4925e0f1a 100644 --- a/crates/build/re_types_builder/src/codegen/rust/deserializer.rs +++ b/crates/build/re_types_builder/src/codegen/rust/deserializer.rs @@ -18,9 +18,9 @@ use crate::{ /// This short-circuits on error using the `try` (`?`) operator: the outer scope must be one that /// returns a `Result<_, DeserializationError>`! /// -/// There is a 1:1 relationship between `quote_arrow_deserializer` and `Loggable::from_arrow2_opt`: +/// There is a 1:1 relationship between `quote_arrow_deserializer` and `Loggable::from_arrow_opt`: /// ```ignore -/// fn from_arrow2_opt(data: &dyn ::arrow2::array::Array) -> DeserializationResult>> { +/// fn from_arrow_opt(data: &dyn ::arrow::array::Array) -> DeserializationResult>> { /// Ok(#quoted_deserializer) /// } /// ``` @@ -56,7 +56,7 @@ pub fn quote_arrow_deserializer( objects: &Objects, obj: &Object, ) -> TokenStream { - // Runtime identifier of the variable holding the Arrow payload (`&dyn ::arrow2::array::Array`). + // Runtime identifier of the variable holding the Arrow payload (`&dyn ::arrow::array::Array`). let data_src = format_ident!("arrow_data"); let datatype = &arrow_registry.get(&obj.fqname); @@ -236,7 +236,7 @@ pub fn quote_arrow_deserializer( }); let quoted_downcast = { - let cast_as = quote!(arrow2::array::StructArray); + let cast_as = quote!(arrow::array::StructArray); quote_array_downcast(obj_fqname, &data_src, cast_as, "ed_self_datatype) }; quote! {{ @@ -248,19 +248,19 @@ pub fn quote_arrow_deserializer( // datastructures for all of our children. Vec::new() } else { - let (#data_src_fields, #data_src_arrays) = (#data_src.fields(), #data_src.values()); + let (#data_src_fields, #data_src_arrays) = (#data_src.fields(), #data_src.columns()); let arrays_by_name: ::std::collections::HashMap<_, _> = #data_src_fields .iter() - .map(|field| field.name.as_str()) + .map(|field| field.name().as_str()) .zip(#data_src_arrays) .collect(); #(#quoted_field_deserializers;)* - arrow2::bitmap::utils::ZipValidity::new_with_validity( + ZipValidity::new_with_validity( ::itertools::izip!(#(#quoted_field_names),*), - #data_src.validity(), + #data_src.nulls(), ) .map(|opt| opt.map(|(#(#quoted_field_names),*)| Ok(Self { #(#quoted_unwrappings,)* })).transpose()) // NOTE: implicit Vec to Result @@ -274,7 +274,7 @@ pub fn quote_arrow_deserializer( // We use sparse arrow unions for c-style enums, which means only 8 bits is required for each field, // and nulls are encoded with a special 0-index `_null_markers` variant. - let data_src_types = format_ident!("{data_src}_types"); + let data_src_types = format_ident!("{data_src}_type_ids"); let obj_fqname = obj.fqname.as_str(); let quoted_branches = obj.fields.iter().enumerate().map(|(typ, obj_field)| { @@ -287,13 +287,13 @@ pub fn quote_arrow_deserializer( }); let quoted_downcast = { - let cast_as = quote!(arrow2::array::UnionArray); + let cast_as = quote!(arrow::array::UnionArray); quote_array_downcast(obj_fqname, &data_src, &cast_as, "ed_self_datatype) }; quote! {{ let #data_src = #quoted_downcast?; - let #data_src_types = #data_src.types(); + let #data_src_types = #data_src.type_ids(); #data_src_types .iter() @@ -319,8 +319,7 @@ pub fn quote_arrow_deserializer( // We use dense arrow unions for proper sum-type unions. // Nulls are encoded with a special 0-index `_null_markers` variant. - let data_src_types = format_ident!("{data_src}_types"); - let data_src_arrays = format_ident!("{data_src}_arrays"); + let data_src_type_ids = format_ident!("{data_src}_type_ids"); let data_src_offsets = format_ident!("{data_src}_offsets"); let quoted_field_deserializers = obj @@ -349,11 +348,11 @@ pub fn quote_arrow_deserializer( quote! { let #data_dst = { - // NOTE: `data_src_arrays` is a runtime collection of all of the + // NOTE: `data_src` is a runtime collection of all of the // input's payload's union arms, while `#type_id` is our comptime union // arm counter… there's no guarantee it's actually there at // runtime! - if #data_src_arrays.len() <= #type_id { + if #data_src.type_ids().inner().len() <= #type_id { // By not returning an error but rather defaulting to an empty // vector, we introduce some kind of light forwards compatibility: // old clients that don't yet know about the new arms can still @@ -365,8 +364,8 @@ pub fn quote_arrow_deserializer( // )).with_context(#obj_fqname); } - // NOTE: The array indexing is safe: checked above. - let #data_src = &*#data_src_arrays[#type_id]; + // NOTE: indexing is safe: checked above. + let #data_src = #data_src.child(#type_id).as_ref(); #quoted_deserializer.collect::>() } } @@ -417,7 +416,7 @@ pub fn quote_arrow_deserializer( }); let quoted_downcast = { - let cast_as = quote!(arrow2::array::UnionArray); + let cast_as = quote!(arrow::array::UnionArray); quote_array_downcast(obj_fqname, &data_src, &cast_as, "ed_self_datatype) }; @@ -430,7 +429,7 @@ pub fn quote_arrow_deserializer( // datastructures for all of our children. Vec::new() } else { - let (#data_src_types, #data_src_arrays) = (#data_src.types(), #data_src.fields()); + let #data_src_type_ids = #data_src.type_ids(); let #data_src_offsets = #data_src.offsets() // NOTE: expected dense union, got a sparse one instead @@ -440,16 +439,16 @@ pub fn quote_arrow_deserializer( DeserializationError::datatype_mismatch(expected, actual) }).with_context(#obj_fqname)?; - if #data_src_types.len() != #data_src_offsets.len() { + if #data_src_type_ids.len() != #data_src_offsets.len() { // NOTE: need one offset array per union arm! return Err(DeserializationError::offset_slice_oob( - (0, #data_src_types.len()), #data_src_offsets.len(), + (0, #data_src_type_ids.len()), #data_src_offsets.len(), )).with_context(#obj_fqname); } #(#quoted_field_deserializers;)* - #data_src_types + #data_src_type_ids .iter() .enumerate() .map(|(i, typ)| { @@ -495,9 +494,9 @@ enum InnerRepr { /// /// The `datatype` comes from our compile-time Arrow registry, not from the runtime payload! /// If the datatype happens to be a struct or union, this will merely inject a runtime call to -/// `Loggable::from_arrow2_opt` and call it a day, preventing code bloat. +/// `Loggable::from_arrow_opt` and call it a day, preventing code bloat. /// -/// `data_src` is the runtime identifier of the variable holding the Arrow payload (`&dyn ::arrow2::array::Array`). +/// `data_src` is the runtime identifier of the variable holding the Arrow payload (`&dyn ::arrow::array::Array`). /// The returned `TokenStream` always instantiates a `Vec>`. /// /// This short-circuits on error using the `try` (`?`) operator: the outer scope must be one that @@ -509,7 +508,7 @@ fn quote_arrow_field_deserializer( quoted_datatype: &TokenStream, is_nullable: bool, obj_field_fqname: &str, - data_src: &proc_macro2::Ident, // &dyn ::arrow2::array::Array + data_src: &proc_macro2::Ident, // &dyn ::arrow::array::Array inner_repr: InnerRepr, ) -> TokenStream { _ = is_nullable; // not yet used, will be needed very soon @@ -518,7 +517,7 @@ fn quote_arrow_field_deserializer( if let DataType::Extension(fqname, _, _) = datatype { if objects.get(fqname).map_or(false, |obj| obj.is_enum()) { let fqname_use = quote_fqname_as_type_path(fqname); - return quote!(#fqname_use::from_arrow2_opt(#data_src).with_context(#obj_field_fqname)?.into_iter()); + return quote!(#fqname_use::from_arrow_opt(#data_src).with_context(#obj_field_fqname)?.into_iter()); } } @@ -538,11 +537,6 @@ fn quote_arrow_field_deserializer( | DataType::Null => { let quoted_iter_transparency = quote_iterator_transparency(objects, datatype, IteratorKind::OptionValue, None); - let quoted_iter_transparency = if *datatype.to_logical_type() == DataType::Boolean { - quoted_iter_transparency - } else { - quote!(.map(|opt| opt.copied()) #quoted_iter_transparency) - }; let quoted_downcast = { let cast_as = format!("{:?}", datatype.to_logical_type()).replace("DataType::", ""); @@ -565,7 +559,7 @@ fn quote_arrow_field_deserializer( DataType::Utf8 => { let quoted_downcast = { - let cast_as = quote!(arrow2::array::Utf8Array); + let cast_as = quote!(StringArray); quote_array_downcast(obj_field_fqname, data_src, cast_as, quoted_datatype) }; @@ -583,9 +577,9 @@ fn quote_arrow_field_deserializer( let #data_src_buf = #data_src.values(); let offsets = #data_src.offsets(); - arrow2::bitmap::utils::ZipValidity::new_with_validity( + ZipValidity::new_with_validity( offsets.windows(2), - #data_src.validity(), + #data_src.nulls(), ) .map(|elem| elem.map(|window| { // NOTE: Do _not_ use `Buffer::sliced`, it panics on malformed inputs. @@ -603,10 +597,8 @@ fn quote_arrow_field_deserializer( (start, end), #data_src_buf.len(), )); } - // Safety: all checked above. - #[allow(unsafe_code, clippy::undocumented_unsafe_blocks)] - // NOTE: The `clone` is a `Buffer::clone`, which is just a refcount bump. - let data = unsafe { #data_src_buf.clone().sliced_unchecked(start, len) }; + #[allow(unsafe_code, clippy::undocumented_unsafe_blocks)] // TODO: unsafe slice again + let data = #data_src_buf.slice_with_length(start, len); Ok(data) }).transpose() @@ -632,7 +624,7 @@ fn quote_arrow_field_deserializer( ); let quoted_downcast = { - let cast_as = quote!(arrow2::array::FixedSizeListArray); + let cast_as = quote!(arrow::array::FixedSizeListArray); quote_array_downcast(obj_field_fqname, data_src, cast_as, quoted_datatype) }; @@ -662,7 +654,7 @@ fn quote_arrow_field_deserializer( #quoted_inner.collect::>() }; - arrow2::bitmap::utils::ZipValidity::new_with_validity(offsets, #data_src.validity()) + ZipValidity::new_with_validity(offsets, #data_src.nulls()) .map(|elem| elem.map(|(start, end): (usize, usize)| { // NOTE: Do _not_ use `Buffer::sliced`, it panics on malformed inputs. @@ -673,7 +665,7 @@ fn quote_arrow_field_deserializer( // NOTE: It is absolutely crucial we explicitly handle the // boundchecks manually first, otherwise rustc completely chokes // when slicing the data (as in: a 100x perf drop)! - if end > #data_src_inner.len() { + if #data_src_inner.len() < end { // error context is appended below during final collection return Err(DeserializationError::offset_slice_oob( (start, end), #data_src_inner.len(), @@ -746,7 +738,7 @@ fn quote_arrow_field_deserializer( ); let quoted_downcast = { - let cast_as = quote!(arrow2::array::ListArray); + let cast_as = quote!(arrow::array::ListArray); quote_array_downcast(obj_field_fqname, data_src, cast_as, quoted_datatype) }; let quoted_collect_inner = match inner_repr { @@ -757,8 +749,8 @@ fn quote_arrow_field_deserializer( let quoted_inner_data_range = match inner_repr { InnerRepr::BufferT => { quote! { - #[allow(unsafe_code, clippy::undocumented_unsafe_blocks)] - let data = unsafe { #data_src_inner.clone().sliced_unchecked(start, end - start) }; + #[allow(unsafe_code, clippy::undocumented_unsafe_blocks)] // TODO: unsafe + let data = #data_src_inner.clone().slice(start, end - start); let data = ::re_types_core::ArrowBuffer::from(data); } } @@ -813,9 +805,9 @@ fn quote_arrow_field_deserializer( }; let offsets = #data_src.offsets(); - arrow2::bitmap::utils::ZipValidity::new_with_validity( + ZipValidity::new_with_validity( offsets.windows(2), - #data_src.validity(), + #data_src.nulls(), ) .map(|elem| elem.map(|window| { // NOTE: Do _not_ use `Buffer::sliced`, it panics on malformed inputs. @@ -850,7 +842,7 @@ fn quote_arrow_field_deserializer( unreachable!() }; let fqname_use = quote_fqname_as_type_path(fqname); - quote!(#fqname_use::from_arrow2_opt(#data_src).with_context(#obj_field_fqname)?.into_iter()) + quote!(#fqname_use::from_arrow_opt(#data_src).with_context(#obj_field_fqname)?.into_iter()) } _ => unimplemented!("{datatype:#?}"), @@ -994,7 +986,7 @@ fn quote_iterator_transparency( /// /// There is a 1:1 relationship between `quote_arrow_deserializer_buffer_slice` and `Loggable::from_arrow`: /// ```ignore -/// fn from_arrow(data: &dyn ::arrow2::array::Array) -> DeserializationResult> { +/// fn from_arrow(data: &dyn ::arrow::array::Array) -> DeserializationResult> { /// Ok(#quoted_deserializer_) /// } /// ``` @@ -1005,7 +997,7 @@ pub fn quote_arrow_deserializer_buffer_slice( objects: &Objects, obj: &Object, ) -> TokenStream { - // Runtime identifier of the variable holding the Arrow payload (`&dyn ::arrow2::array::Array`). + // Runtime identifier of the variable holding the Arrow payload (`&dyn ::arrow::array::Array`). let data_src = format_ident!("arrow_data"); let datatype = &arrow_registry.get(&obj.fqname); @@ -1077,7 +1069,7 @@ fn quote_arrow_field_deserializer_buffer_slice( datatype: &DataType, is_nullable: bool, obj_field_fqname: &str, - data_src: &proc_macro2::Ident, // &dyn ::arrow2::array::Array + data_src: &proc_macro2::Ident, // &dyn ::arrow::array::Array ) -> TokenStream { _ = is_nullable; // not yet used, will be needed very soon @@ -1107,7 +1099,7 @@ fn quote_arrow_field_deserializer_buffer_slice( quote! { #quoted_downcast? .values() - .as_slice() + .as_ref() } } @@ -1121,7 +1113,7 @@ fn quote_arrow_field_deserializer_buffer_slice( ); let quoted_downcast = { - let cast_as = quote!(arrow2::array::FixedSizeListArray); + let cast_as = quote!(arrow::array::FixedSizeListArray); quote_array_downcast( obj_field_fqname, data_src,