From 69c0c8cf4774cdaddceb8458f6ab9fed25a8ee7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Medina?= Date: Sat, 7 Dec 2024 00:34:05 -0800 Subject: [PATCH] Add `flatten` attribute to derive SerializeRow Currently only the `match_by_name` flavor is supported. All the needed structs/traits to make this work are marked as `#[doc(hidden)]` to not increase the public API surface. Effort was done to not change any of the existing API. --- scylla-cql/src/lib.rs | 116 +++++++++++++ scylla-cql/src/types/serialize/row.rs | 58 ++++++- scylla-macros/src/serialize/row.rs | 234 +++++++++++++++++--------- scylla/src/lib.rs | 2 +- scylla/src/macros.rs | 8 + 5 files changed, 338 insertions(+), 80 deletions(-) diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 09a4e56d7..74be72698 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -79,4 +79,120 @@ pub mod _macro_internal { pub use crate::types::serialize::{ CellValueBuilder, CellWriter, RowWriter, SerializationError, }; + + pub mod ser { + pub mod row { + pub use crate::{ + frame::response::result::ColumnSpec, + types::serialize::{ + row::{ + mk_ser_err, mk_typck_err, BuiltinSerializationErrorKind, + BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, RowSerializationContext, + }, + value::SerializeValue, + writers::WrittenCellProof, + RowWriter, SerializationError, + }, + }; + + /// Whether a field used a column to finish its serialization or not + /// + /// Used when serializing by name as a single column may not have finished a rust + /// field in the case of a flattened struct + /// + /// For now this enum is an implementation detail of `#[derive(SerializeRow)]` when + /// serializing by name + #[derive(Debug)] + #[doc(hidden)] + pub enum FieldStatus { + /// The column finished the serialization for this field + Done, + /// The column was used but there are other fields not yet serialized + NotDone, + /// The column did not belong to this field + NotUsed, + } + + /// Represents a set of values that can be sent along a CQL statement when serializing by name + /// + /// For now this trait is an implementation detail of `#[derive(SerializeRow)]` when + /// serializing by name + #[doc(hidden)] + pub trait SerializeRowByName { + /// A type that can handle serialization of this struct column-by-column + type Partial<'d>: PartialSerializeRowByName + where + Self: 'd; + + /// Returns a type that can serialize this row "column-by-column" + fn partial(&self) -> Self::Partial<'_>; + } + + /// How to serialize a row column-by-column + /// + /// For now this trait is an implementation detail of `#[derive(SerializeRow)]` when + /// serializing by name + #[doc(hidden)] + pub trait PartialSerializeRowByName { + /// Serializes a single column in the row according to the information in the + /// given context + /// + /// It returns whether the column finished the serialization of the struct, did + /// it partially, none of at all, or errored + fn serialize_field( + &mut self, + spec: &ColumnSpec, + writer: &mut RowWriter<'_>, + ) -> Result; + + /// Checks if there are any missing columns to finish the serialization + fn check_missing(self) -> Result<(), SerializationError>; + } + + pub struct ByName<'t, T: SerializeRowByName>(pub &'t T); + + impl ByName<'_, T> { + /// Serializes all the fields/columns by name + pub fn serialize( + self, + ctx: &RowSerializationContext, + writer: &mut RowWriter<'_>, + ) -> Result<(), SerializationError> { + let mut partial = self.0.partial(); + + for spec in ctx.columns() { + let serialized = partial.serialize_field(spec, writer)?; + + if matches!(serialized, FieldStatus::NotUsed) { + return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::NoColumnWithName { + name: spec.name().to_owned(), + }, + )); + } + } + + partial.check_missing()?; + + Ok(()) + } + } + + pub fn serialize_column<'b, T>( + value: &impl SerializeValue, + spec: &ColumnSpec, + writer: &'b mut RowWriter<'_>, + ) -> Result, SerializationError> { + let sub_writer = writer.make_cell_writer(); + value.serialize(spec.typ(), sub_writer).map_err(|err| { + super::row::mk_ser_err::( + BuiltinSerializationErrorKind::ColumnSerializationFailed { + name: spec.name().to_owned(), + err, + }, + ) + }) + } + } + } } diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index cf05398fd..ee3192be6 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -572,7 +572,8 @@ pub struct BuiltinTypeCheckError { pub kind: BuiltinTypeCheckErrorKind, } -fn mk_typck_err(kind: impl Into) -> SerializationError { +#[doc(hidden)] +pub fn mk_typck_err(kind: impl Into) -> SerializationError { mk_typck_err_named(std::any::type_name::(), kind) } @@ -598,7 +599,8 @@ pub struct BuiltinSerializationError { pub kind: BuiltinSerializationErrorKind, } -fn mk_ser_err(kind: impl Into) -> SerializationError { +#[doc(hidden)] +pub fn mk_ser_err(kind: impl Into) -> SerializationError { mk_ser_err_named(std::any::type_name::(), kind) } @@ -1677,4 +1679,56 @@ pub(crate) mod tests { assert_eq!(reference, row); } + + #[test] + fn test_row_serialization_nested_structs() { + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct InnerColumnsOne { + x: i32, + y: f64, + } + + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct InnerColumnsTwo { + z: bool, + } + + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct OuterColumns { + #[scylla(flatten)] + inner_one: InnerColumnsOne, + a: String, + #[scylla(flatten)] + inner_two: InnerColumnsTwo, + } + + let spec = [ + col("a", ColumnType::Text), + col("x", ColumnType::Int), + col("z", ColumnType::Boolean), + col("y", ColumnType::Double), + ]; + + let value = OuterColumns { + inner_one: InnerColumnsOne { x: 5, y: 1.0 }, + a: "something".to_owned(), + inner_two: InnerColumnsTwo { z: true }, + }; + + let reference = do_serialize( + ( + &value.a, + &value.inner_one.x, + &value.inner_two.z, + &value.inner_one.y, + ), + &spec, + ); + let row = do_serialize(value, &spec); + + assert_eq!(reference, row); + } } diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index a1695fa57..f4cdfea95 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use darling::FromAttributes; use proc_macro::TokenStream; -use proc_macro2::Span; use syn::parse_quote; use crate::Flavor; @@ -55,6 +54,11 @@ struct FieldAttributes { // instead of the Rust field name. rename: Option, + // If set, then this field's columns are serialized using its own implementation + // of `SerializeRow` and flattened as if they were fields in this struct. + #[darling(default)] + flatten: bool, + // If true, then the field is not serialized at all, but simply ignored. // All other attributes are ignored. #[darling(default)] @@ -64,6 +68,8 @@ struct FieldAttributes { struct Context { attributes: Attributes, fields: Vec, + struct_name: syn::Ident, + generics: syn::Generics, } pub(crate) fn derive_serialize_row(tokens_input: TokenStream) -> Result { @@ -90,7 +96,12 @@ pub(crate) fn derive_serialize_row(tokens_input: TokenStream) -> Result>()?; - let ctx = Context { attributes, fields }; + let ctx = Context { + attributes, + fields, + struct_name: struct_name.clone(), + generics: input.generics.clone(), + }; ctx.validate(&input.ident)?; let gen: Box = match ctx.attributes.flavor { @@ -137,6 +148,30 @@ impl Context { } } + // `flatten` annotations is not yet supported outside of `match_by_name` + if !matches!(self.attributes.flavor, Flavor::MatchByName) { + if let Some(field) = self.fields.iter().find(|f| f.attrs.flatten) { + let err = darling::Error::custom( + "the `flatten` annotations is only supported wit the `match_by_name` flavor", + ) + .with_span(&field.ident); + errors.push(err); + } + } + + // Check that no renames are attempted on flattened fields + let rename_flatten_errors = self + .fields + .iter() + .filter(|f| f.attrs.flatten && f.attrs.rename.is_some()) + .map(|f| { + darling::Error::custom( + "`rename` and `flatten` annotations do not make sense together", + ) + .with_span(&f.ident) + }); + errors.extend(rename_flatten_errors); + // Check for name collisions let mut used_names = HashMap::::new(); for field in self.fields.iter() { @@ -200,94 +235,134 @@ impl Generator for ColumnSortingGenerator<'_> { // Need to: // - Check that all required columns are there and no more // - Check that the column types match - let mut statements: Vec = Vec::new(); let crate_path = self.ctx.attributes.crate_path(); + let struct_name = &self.ctx.struct_name; + let (impl_generics, ty_generics, where_clause) = self.ctx.generics.split_for_impl(); + let partial_struct_name = syn::Ident::new( + &format!("_{}ScyllaSerPartial", struct_name), + struct_name.span(), + ); + let mut partial_generics = self.ctx.generics.clone(); + let partial_lt: syn::LifetimeParam = syn::parse_quote!('scylla_ser_partial); + if !self.ctx.fields.is_empty() { + partial_generics + .params + .push(syn::GenericParam::Lifetime(partial_lt.clone())); + } - let rust_field_idents = self - .ctx - .fields - .iter() - .map(|f| f.ident.clone()) - .collect::>(); - let rust_field_names = self + let (partial_impl_generics, partial_ty_generics, partial_where_clause) = + partial_generics.split_for_impl(); + + let flattened: Vec<_> = self.ctx.fields.iter().filter(|f| f.attrs.flatten).collect(); + let flattened_fields: Vec<_> = flattened.iter().map(|f| &f.ident).collect(); + let flattened_tys: Vec<_> = flattened.iter().map(|f| &f.ty).collect(); + + let unflattened: Vec<_> = self .ctx .fields .iter() - .map(|f| f.column_name()) - .collect::>(); - let udt_field_names = rust_field_names.clone(); // For now, it's the same - let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + .filter(|f| !f.attrs.flatten) + .collect(); + let unflattened_columns: Vec<_> = unflattened.iter().map(|f| f.column_name()).collect(); + let unflattened_fields: Vec<_> = unflattened.iter().map(|f| &f.ident).collect(); + let unflattened_tys: Vec<_> = unflattened.iter().map(|f| &f.ty).collect(); + + let all_names = self.ctx.fields.iter().map(|f| f.column_name()); + + let partial_struct: syn::ItemStruct = parse_quote! { + struct #partial_struct_name #partial_generics { + #(#unflattened_fields: &#partial_lt #unflattened_tys,)* + #(#flattened_fields: <#flattened_tys as #crate_path::ser::row::SerializeRowByName>::Partial<#partial_lt>,)* + missing: ::std::collections::HashSet<&'static str>, + } + }; + + let serialize_field_block: syn::Block = if self.ctx.fields.is_empty() { + parse_quote! {{ + ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotUsed) + }} + } else { + parse_quote! {{ + match spec.name() { + #(#unflattened_columns => { + #crate_path::ser::row::serialize_column::<#struct_name #ty_generics>( + &self.#unflattened_fields, spec, writer, + )?; + self.missing.remove(#unflattened_columns); + })* + _ => { + #({ + match self.#flattened_fields.serialize_field(spec, writer)? { + #crate_path::ser::row::FieldStatus::Done => { + self.missing.remove(stringify!(#flattened_fields)); + return ::std::result::Result::Ok(if self.missing.is_empty() { + #crate_path::ser::row::FieldStatus::Done + } else { + #crate_path::ser::row::FieldStatus::NotDone + }); + } + #crate_path::ser::row::FieldStatus::NotDone => { + return ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotDone) + } + #crate_path::ser::row::FieldStatus::NotUsed => {} + }; + })* - // Declare a helper lambda for creating errors - statements.push(self.ctx.generate_mk_typck_err()); - statements.push(self.ctx.generate_mk_ser_err()); + return ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotUsed); + } + } - // Generate a "visited" flag for each field - let visited_flag_names = rust_field_idents - .iter() - .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) - .collect::>(); - statements.extend::>(parse_quote! { - #(let mut #visited_flag_names = false;)* - }); + ::std::result::Result::Ok(if self.missing.is_empty() { + #crate_path::ser::row::FieldStatus::Done + } else { + #crate_path::ser::row::FieldStatus::NotDone + }) + }} + }; + + let partial_serialize: syn::ItemImpl = parse_quote! { + impl #partial_impl_generics #crate_path::ser::row::PartialSerializeRowByName for #partial_struct_name #partial_ty_generics #partial_where_clause { + fn serialize_field( + &mut self, + spec: &#crate_path::ColumnSpec, + writer: &mut #crate_path::RowWriter<'_>, + ) -> ::std::result::Result<#crate_path::ser::row::FieldStatus, #crate_path::SerializationError> { + #serialize_field_block + } - // Generate a variable that counts down visited fields. - let field_count = self.ctx.fields.len(); - statements.push(parse_quote! { - let mut remaining_count = #field_count; - }); + fn check_missing(self) -> ::std::result::Result<(), #crate_path::SerializationError> { + use ::std::iter::{Iterator as _, IntoIterator as _}; - // Generate a loop over the fields and a `match` block to match on - // the field name. - statements.push(parse_quote! { - for spec in ctx.columns() { - match spec.name() { - #( - #udt_field_names => { - let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); - match <#field_types as #crate_path::SerializeValue>::serialize(&self.#rust_field_idents, spec.typ(), sub_writer) { - ::std::result::Result::Ok(_proof) => {} - ::std::result::Result::Err(err) => { - return ::std::result::Result::Err(mk_ser_err( - #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { - name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), - err, - } - )); - } - } - if !#visited_flag_names { - #visited_flag_names = true; - remaining_count -= 1; - } - } - )* - _ => return ::std::result::Result::Err(mk_typck_err( - #crate_path::BuiltinRowTypeCheckErrorKind::NoColumnWithName { - name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), - } - )), + let ::std::option::Option::Some(missing) = self.missing.into_iter().nth(0) else { + return ::std::result::Result::Ok(()); + }; + + match missing { + #(stringify!(#flattened_fields) => self.#flattened_fields.check_missing(),)* + _ => ::std::result::Result::Err(#crate_path::ser::row::mk_typck_err::<#struct_name #ty_generics>(#crate_path::BuiltinRowTypeCheckErrorKind::ValueMissingForColumn { + name: <_ as ::std::borrow::ToOwned>::to_owned(missing), + })) + } } } - }); + }; - // Finally, check that all fields were consumed. - // If there are some missing fields, return an error - statements.push(parse_quote! { - if remaining_count > 0 { - #( - if !#visited_flag_names { - return ::std::result::Result::Err(mk_typck_err( - #crate_path::BuiltinRowTypeCheckErrorKind::ValueMissingForColumn { - name: <_ as ::std::string::ToString>::to_string(#rust_field_names), - } - )); + let serialize_by_name: syn::ItemImpl = parse_quote! { + impl #impl_generics #crate_path::ser::row::SerializeRowByName for #struct_name #ty_generics #where_clause { + type Partial<#partial_lt> = #partial_struct_name #partial_ty_generics where Self: #partial_lt; + + fn partial(&self) -> Self::Partial<'_> { + use ::std::iter::FromIterator as _; + + #partial_struct_name { + #(#unflattened_fields: &self.#unflattened_fields,)* + #(#flattened_fields: self.#flattened_fields.partial(),)* + missing: ::std::collections::HashSet::from_iter([#(#all_names,)*]), } - )* - ::std::unreachable!() + } } - }); + }; parse_quote! { fn serialize<'b>( @@ -295,8 +370,13 @@ impl Generator for ColumnSortingGenerator<'_> { ctx: &#crate_path::RowSerializationContext, writer: &mut #crate_path::RowWriter<'b>, ) -> ::std::result::Result<(), #crate_path::SerializationError> { - #(#statements)* - ::std::result::Result::Ok(()) + #partial_struct + #partial_serialize + + #[allow(non_local_definitions)] + #serialize_by_name + + #crate_path::ser::row::ByName(self).serialize(ctx, writer) } } } diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 8dc56420a..586e85361 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -81,7 +81,7 @@ //! .query_unpaged("SELECT a, b FROM ks.tab", &[]) //! .await? //! .into_rows_result()?; -//! +//! //! for row in query_rows.rows()? { //! // Parse row as int and text \ //! let (int_val, text_val): (i32, &str) = row?; diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 3e75fa1ab..80a3507a6 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -246,6 +246,14 @@ pub use scylla_cql::macros::SerializeValue; /// /// Don't use the field during serialization. /// +/// `#[scylla(flatten)]` +/// +/// Use this field's `SerializeRow` implementation to serialize its columns as part +/// of this struct. Note that the name of this field is ignored and hence the +/// `rename` attribute does not make sense here and will cause a compilation +/// error. Currently this is only supported for the `"match_by_name"` flavor in both +/// the wrapper struct and this flattened struct. +/// /// --- /// pub use scylla_cql::macros::SerializeRow;