diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 8f53e24fa9..2b7b0b4ae7 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -16,9 +16,7 @@ pub use scylla_macros::ValueList; /// Derive macro for the [`SerializeCql`](crate::types::serialize::value::SerializeCql) trait /// which serializes given Rust structure as a User Defined Type (UDT). /// -/// At the moment, only structs with named fields are supported. The generated -/// implementation of the trait will match the struct fields to UDT fields -/// by name automatically. +/// At the moment, only structs with named fields are supported. /// /// Serialization will fail if there are some fields in the UDT that don't match /// to any of the Rust struct fields, _or vice versa_. @@ -50,6 +48,21 @@ pub use scylla_macros::ValueList; /// /// # Attributes /// +/// `#[scylla(flavor = "flavor_name")]` +/// +/// Allows to choose one of the possible "flavors", i.e. the way how the +/// generated code will approach serialization. Possible flavors are: +/// +/// - `"match_by_name"` (default) - the generated implementation _does not +/// require_ the fields in the Rust struct to be in the same order as the +/// fields in the UDT. During serialization, the implementation will take +/// care to serialize the fields in the order which the database expects. +/// - `"enforce_order"` - the generated implementation _requires_ the fields +/// in the Rust struct to be in the same order as the fields in the UDT. +/// If the order is incorrect, type checking/serialization will fail. +/// This is a less robust flavor than `"match_by_name"`, but should be +/// slightly more performant as it doesn't need to perform lookups by name. +/// /// `#[scylla(crate = crate_name)]` /// /// By default, the code generated by the derive macro will refer to the items diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 85033dac25..567b59cfab 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1314,6 +1314,12 @@ pub enum UdtTypeCheckErrorKind { /// The Rust data contains a field that is not present in the UDT UnexpectedFieldInDestination { field_name: String }, + + /// A different field name was expected at given position. + FieldNameMismatch { + rust_field_name: String, + db_field_name: String, + }, } impl Display for UdtTypeCheckErrorKind { @@ -1337,6 +1343,10 @@ impl Display for UdtTypeCheckErrorKind { f, "the field {field_name} present in the Rust data is not present in the CQL type" ), + UdtTypeCheckErrorKind::FieldNameMismatch { rust_field_name, db_field_name } => write!( + f, + "expected field with name {db_field_name} at given position, but the Rust field name is {rust_field_name}" + ), } } } @@ -1668,4 +1678,164 @@ mod tests { check_with_type(ColumnType::Int, 123_i32, CqlValue::Int(123_i32)); check_with_type(ColumnType::Double, 123_f64, CqlValue::Double(123_f64)); } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, flavor = "enforce_order")] + struct TestUdtWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_udt_serialization_with_enforced_order_correct_order() { + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let reference = do_serialize( + CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "typ".to_string(), + fields: vec![ + ( + "a".to_string(), + Some(CqlValue::Text(String::from("Ala ma kota"))), + ), + ("b".to_string(), Some(CqlValue::Int(42))), + ( + "c".to_string(), + Some(CqlValue::List(vec![ + CqlValue::BigInt(1), + CqlValue::BigInt(2), + CqlValue::BigInt(3), + ])), + ), + ], + }, + &typ, + ); + let udt = do_serialize( + TestUdtWithEnforcedOrder { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &typ, + ); + + assert_eq!(reference, udt); + } + + #[test] + fn test_udt_serialization_with_enforced_order_failing_type_check() { + let typ_not_udt = ColumnType::Ascii; + let udt = TestUdtWithEnforcedOrder::default(); + + let mut data = Vec::new(); + + let err = <_ as SerializeCql>::serialize(&udt, &typ_not_udt, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) + )); + + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + // Two first columns are swapped + ("b".to_string(), ColumnType::Int), + ("a".to_string(), ColumnType::Text), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ, CellWriter::new(&mut data)).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::FieldNameMismatch { .. }) + )); + + let typ_without_c = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Last field is missing + ], + }; + + let err = <_ as SerializeCql>::serialize(&udt, &typ_without_c, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::MissingField { .. }) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::UnexpectedFieldInDestination { .. } + ) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ("c".to_string(), ColumnType::TinyInt), // Wrong column type + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::UdtError( + UdtSerializationErrorKind::FieldSerializationFailed { .. } + ) + )); + } } diff --git a/scylla-macros/src/serialize/cql.rs b/scylla-macros/src/serialize/cql.rs index f19e47b27c..d3c5788401 100644 --- a/scylla-macros/src/serialize/cql.rs +++ b/scylla-macros/src/serialize/cql.rs @@ -3,11 +3,15 @@ use proc_macro::TokenStream; use proc_macro2::Span; use syn::parse_quote; +use super::Flavor; + #[derive(FromAttributes)] #[darling(attributes(scylla))] struct Attributes { #[darling(rename = "crate")] crate_path: Option, + + flavor: Option, } impl Attributes { @@ -36,7 +40,11 @@ pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result = match ctx.attributes.flavor { + Some(Flavor::MatchByName) | None => Box::new(FieldSortingGenerator { ctx: &ctx }), + Some(Flavor::EnforceOrder) => Box::new(FieldOrderedGenerator { ctx: &ctx }), + }; let serialize_item = gen.generate_serialize(); @@ -93,13 +101,17 @@ impl Context { } } +trait Generator { + fn generate_serialize(&self) -> syn::TraitItemFn; +} + // Generates an implementation of the trait which sorts the fields according // to how it is defined in the database. struct FieldSortingGenerator<'a> { ctx: &'a Context, } -impl<'a> FieldSortingGenerator<'a> { +impl<'a> Generator for FieldSortingGenerator<'a> { fn generate_serialize(&self) -> syn::TraitItemFn { // Need to: // - Check that all required fields are there and no more @@ -222,3 +234,108 @@ impl<'a> FieldSortingGenerator<'a> { } } } + +// Generates an implementation of the trait which requires the fields +// to be placed in the same order as they are defined in the struct. +struct FieldOrderedGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> Generator for FieldOrderedGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Check that the type we want to serialize to is a UDT + statements.push( + self.ctx + .generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)), + ); + + // Turn the cell writer into a value builder + statements.push(parse_quote! { + let mut builder = #crate_path::CellWriter::into_value_builder(writer); + }); + + // Create an iterator over fields + statements.push(parse_quote! { + let mut field_iter = field_types.iter(); + }); + + // Serialize each field + for field in self.ctx.fields.iter() { + let rust_field_ident = field.ident.as_ref().unwrap(); + let rust_field_name = rust_field_ident.to_string(); + let typ = &field.ty; + statements.push(parse_quote! { + match field_iter.next() { + Some((field_name, typ)) => { + if field_name == #rust_field_name { + let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder); + match <#typ as #crate_path::SerializeCql>::serialize(&self.#rust_field_ident, typ, sub_builder) { + Ok(_proof) => {}, + Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::UdtSerializationErrorKind::FieldSerializationFailed { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + err, + } + )); + } + } + } else { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::FieldNameMismatch { + rust_field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + db_field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + } + None => { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::MissingField { + field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + } + )); + } + } + }); + } + + // Check whether there are some fields remaining + statements.push(parse_quote! { + if let Some((field_name, typ)) = field_iter.next() { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::UnexpectedFieldInDestination { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + typ: &#crate_path::ColumnType, + writer: #crate_path::CellWriter<'b>, + ) -> ::std::result::Result<#crate_path::WrittenCellProof<'b>, #crate_path::SerializationError> { + #(#statements)* + let proof = #crate_path::CellValueBuilder::finish(builder) + .map_err(|_| #crate_path::SerializationError::new( + #crate_path::BuiltinTypeSerializationError { + rust_name: ::std::any::type_name::(), + got: <_ as ::std::clone::Clone>::clone(typ), + kind: #crate_path::BuiltinTypeSerializationErrorKind::SizeOverflow, + } + ) as #crate_path::SerializationError)?; + ::std::result::Result::Ok(proof) + } + } + } +} diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs index 53abe0f296..183183fa91 100644 --- a/scylla-macros/src/serialize/mod.rs +++ b/scylla-macros/src/serialize/mod.rs @@ -1,2 +1,20 @@ +use darling::FromMeta; + pub(crate) mod cql; pub(crate) mod row; + +#[derive(Copy, Clone, PartialEq, Eq)] +enum Flavor { + MatchByName, + EnforceOrder, +} + +impl FromMeta for Flavor { + fn from_string(value: &str) -> darling::Result { + match value { + "match_by_name" => Ok(Self::MatchByName), + "enforce_order" => Ok(Self::EnforceOrder), + _ => Err(darling::Error::unknown_value(value)), + } + } +}