diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 740d54644c..877bcec75d 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -21,6 +21,10 @@ pub mod _macro_internal { }; pub use crate::macros::*; + pub use crate::types::serialize::row::{ + RowSerializationContext, RowSerializationError, RowSerializationErrorKind, + RowTypeCheckError, RowTypeCheckErrorKind, SerializeRow, + }; pub use crate::types::serialize::value::{ SerializeCql, UdtSerializationError, UdtSerializationErrorKind, UdtTypeCheckError, UdtTypeCheckErrorKind, diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index e438d60354..9be4d4ef20 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -75,6 +75,68 @@ pub use scylla_macros::ValueList; /// to either the `scylla` or `scylla-cql` crate. pub use scylla_macros::SerializeCql; +/// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait +/// which serializes given Rust structure into bind markers for a CQL statement. +/// +/// At the moment, only structs with named fields are supported. The generated +/// implementation of the trait will match the struct fields to bind markers/columns +/// by name automatically. +/// +/// Serialization will fail if there are some bind markers/columns in the statement +/// that don't match to any of the Rust struct fields, _or vice versa_. +/// +/// In case of failure, either [`RowTypeCheckError`](crate::types::serialize::row::RowTypeCheckError) +/// or [`RowSerializationError`](crate::types::serialize::row::RowSerializationError) +/// will be returned. +/// +/// # Example +/// +/// A UDT defined like this: +/// +/// ```notrust +/// CREATE TYPE ks.my_udt (a int, b text, c blob); +/// ``` +/// +/// ...can be serialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::SerializeRow; +/// #[derive(SerializeRow)] +/// # #[scylla(crate = scylla_cql)] +/// struct MyUdt { +/// a: i32, +/// b: Option, +/// c: Vec, +/// } +/// ``` +/// +/// # Attributes +/// +/// `#[scylla(crate = crate_name)]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait +/// using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::SerializeRow; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +pub use scylla_macros::SerializeRow; + // Reexports for derive(IntoUserType) pub use bytes::{BufMut, Bytes, BytesMut}; diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 2e9832412d..e51f2da2ea 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -1,5 +1,8 @@ +use std::fmt::Display; use std::sync::Arc; +use thiserror::Error; + use crate::frame::response::result::ColumnSpec; use crate::frame::value::ValueList; @@ -47,3 +50,221 @@ impl SerializeRow for T { .map_err(|err| Arc::new(err) as SerializationError) } } + +/// Returned by the code generated by [`SerializeRow`] macro when the types +/// of the bind markers expected by the database do not match the expectations +/// of a Rust struct. +/// +/// Returned by the [`SerializeRow::preliminary_type_check`] method from +/// the trait implementation generated by the macro. +#[derive(Debug, Error)] +#[error("Failed to type check Rust struct {rust_name} as a CQL row: {kind}")] +pub struct RowTypeCheckError { + /// Name of the Rust structure that was being serialized. + pub rust_name: String, + + /// Detailed infomation about why type checking of the row failed. + pub kind: RowTypeCheckErrorKind, +} + +/// Detailed information about why type checking of the row failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum RowTypeCheckErrorKind { + /// There is a Rust struct field that must be serialized but does not + /// match against any of the bind markers. + MissingColumn { column_name: String }, + + /// There is a bind marker with a column name that does not match against + /// any of the Rust struct fields. + UnexpectedColumn { column_name: String }, + + /// Failed to type check one of the fields against the expected type + /// of the corresponding bind marker. + ColumnTypeCheckFailed { + column_name: String, + err: SerializationError, + }, +} + +impl Display for RowTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RowTypeCheckErrorKind::MissingColumn { column_name } => write!( + f, + "no bind marker or column named {column_name} in the statement spec" + ), + RowTypeCheckErrorKind::UnexpectedColumn { column_name } => write!( + f, + "the bind marker or column {column_name} does not correspond to any of the rust struct fields" + ), + RowTypeCheckErrorKind::ColumnTypeCheckFailed { column_name, err } => { + write!(f, "the bind marker {column_name} failed to type check: {err}") + } + } + } +} + +/// Returned by the code generated by [`SerializeRow`] macro when a Rust struct +/// fails to be serialized as bind markers for a statement. +/// +/// Returned by the [`SerializeRow::serialize`] method from the trait +/// implementation generated by the macro. +#[derive(Debug, Error)] +#[error("Failed to type check Rust struct {rust_name} as a CQL row: {kind}")] +pub struct RowSerializationError { + /// Name of the Rust structure that was being serialized. + pub rust_name: String, + + /// Detailed infomation about why serialization failed. + pub kind: RowSerializationErrorKind, +} + +/// Detailed information about why serialization of the row failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum RowSerializationErrorKind { + /// One of the bind markers or columns failed to be serialized. + ColumnSerializationFailed { + column_name: String, + err: SerializationError, + }, +} + +impl Display for RowSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RowSerializationErrorKind::ColumnSerializationFailed { column_name, err } => { + write!( + f, + "the bind marker or column {column_name} failed to serialize: {err}" + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use scylla_macros::SerializeRow; + + use super::{RowSerializationContext, RowTypeCheckError, RowTypeCheckErrorKind, SerializeRow}; + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + + fn do_serialize(t: T, columns: &[ColumnSpec]) -> Vec { + let ctx = RowSerializationContext { columns }; + T::preliminary_type_check(&ctx).unwrap(); + let mut ret = Vec::new(); + t.serialize(&ctx, &mut ret).unwrap(); + ret + } + + fn col(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } + } + + // Do not remove. It's not used in tests but we keep it here to check that + // we properly ignore warnings about unused variables, unnecessary `mut`s + // etc. that usually pop up when generating code for empty structs. + #[derive(SerializeRow)] + #[scylla(crate = crate)] + struct TestRowWithNoColumns {} + + #[derive(SerializeRow, Debug, PartialEq, Eq)] + #[scylla(crate = crate)] + struct TestRowWithColumnSorting { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_row_serialization_with_column_sorting_correct_order() { + let spec = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + ]; + + let reference = do_serialize(("Ala ma kota", 42i32, vec![1i64, 2i64, 3i64]), &spec); + let row = do_serialize( + TestRowWithColumnSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_with_column_sorting_incorrect_order() { + // The order of two last columns is swapped + let spec = [ + col("a", ColumnType::Text), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + col("b", ColumnType::Int), + ]; + + let reference = do_serialize(("Ala ma kota", vec![1i64, 2i64, 3i64], 42i32), &spec); + let row = do_serialize( + TestRowWithColumnSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_failing_type_check() { + let spec_without_c = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + // Missing column c + ]; + + let ctx = RowSerializationContext { + columns: &spec_without_c, + }; + let err = TestRowWithColumnSorting::preliminary_type_check(&ctx).unwrap_err(); + let err = err.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + RowTypeCheckErrorKind::MissingColumn { .. } + )); + + let spec_duplicate_column = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + // Unexpected last column + col("d", ColumnType::Counter), + ]; + + let ctx = RowSerializationContext { + columns: &spec_duplicate_column, + }; + let err = TestRowWithColumnSorting::preliminary_type_check(&ctx).unwrap_err(); + let err = err.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + RowTypeCheckErrorKind::UnexpectedColumn { .. } + )); + + // TODO: Test case for mismatched field types + // Can't do it without proper SerializeRaw implementation of field types + } +} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 84ee58bca0..64ce0ee06e 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -18,6 +18,15 @@ pub fn serialize_cql_derive(tokens_input: TokenStream) -> TokenStream { } } +/// See the documentation for this item in the `scylla` crate. +#[proc_macro_derive(SerializeRow, attributes(scylla))] +pub fn serialize_row_derive(tokens_input: TokenStream) -> TokenStream { + match serialize::row::derive_serialize_row(tokens_input) { + Ok(t) => t.into_token_stream().into(), + Err(e) => e.into_compile_error().into(), + } +} + /// #[derive(FromRow)] derives FromRow for struct /// Works only on simple structs without generics etc #[proc_macro_derive(FromRow, attributes(scylla_crate))] diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs index 15fd9ae87c..53abe0f296 100644 --- a/scylla-macros/src/serialize/mod.rs +++ b/scylla-macros/src/serialize/mod.rs @@ -1 +1,2 @@ pub(crate) mod cql; +pub(crate) mod row; diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs new file mode 100644 index 0000000000..967f779589 --- /dev/null +++ b/scylla-macros/src/serialize/row.rs @@ -0,0 +1,249 @@ +use darling::FromAttributes; +use proc_macro::TokenStream; +use proc_macro2::Span; +use syn::parse_quote; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct Attributes { + #[darling(rename = "crate")] + crate_path: Option, +} + +impl Attributes { + fn crate_path(&self) -> syn::Path { + self.crate_path + .as_ref() + .map(|p| parse_quote!(#p::_macro_internal)) + .unwrap_or_else(|| parse_quote!(::scylla::_macro_internal)) + } +} + +struct Context { + struct_name: syn::Ident, + attributes: Attributes, + fields: Vec, +} + +pub fn derive_serialize_row(tokens_input: TokenStream) -> Result { + let input: syn::DeriveInput = syn::parse(tokens_input)?; + let struct_name = input.ident.clone(); + let named_fields = crate::parser::parse_named_fields(&input, "SerializeRow")?; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let attributes = Attributes::from_attributes(&input.attrs)?; + + let crate_path = attributes.crate_path(); + let implemented_trait: syn::Path = parse_quote!(#crate_path::SerializeRow); + + let fields = named_fields.named.iter().cloned().collect(); + let ctx = Context { + struct_name: struct_name.clone(), + attributes, + fields, + }; + let gen = ColumnSortingGenerator { ctx: &ctx }; + + let preliminary_type_check_item = gen.generate_preliminary_type_check(); + let serialize_item = gen.generate_serialize(); + + let res = parse_quote! { + impl<#impl_generics> #implemented_trait for #struct_name #ty_generics #where_clause { + #preliminary_type_check_item + #serialize_item + } + }; + Ok(res) +} + +// Generates an implementation of the trait which sorts the columns according +// to how they are defined in prepared statement metadata. +struct ColumnSortingGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> ColumnSortingGenerator<'a> { + fn generate_preliminary_type_check(&self) -> syn::TraitItemFn { + // Need to: + // - Check that all required columns are there and no more + // - Check that the column types match + + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + let row_name = self.ctx.struct_name.to_string(); + let rust_field_names = self + .ctx + .fields + .iter() + .map(|f| f.ident.as_ref().unwrap().to_string()) + .collect::>(); + let column_names = rust_field_names.clone(); // For now, it's the same + let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + let field_count = self.ctx.fields.len(); + + statements.push(parse_quote! { + let mk_error = |kind: #crate_path::RowTypeCheckErrorKind| -> ::std::sync::Arc<#crate_path::RowTypeCheckError> { + ::std::sync::Arc::new( + #crate_path::RowTypeCheckError { + rust_name: <_ as ::std::string::ToString>::to_string(#row_name), + kind, + } + ) + }; + }); + + // Generate a "visited" flag for each field + let visited_flag_names = rust_field_names + .iter() + .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) + .collect::>(); + statements.extend::>(parse_quote! { + #(let mut #visited_flag_names = false;)* + }); + + // Generate a variable that counts down visited fields. + statements.push(parse_quote! { + let mut remaining_count = #field_count; + }); + + // Generate a loop over the fields and a `match` block to match on + // the field name. + statements.push(parse_quote! { + for spec in ctx.columns() { + match ::std::string::String::as_str(&spec.name) { + #( + #column_names => { + match <#field_types as #crate_path::SerializeCql>::preliminary_type_check(&spec.typ) { + ::std::result::Result::Ok(()) => {} + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err(mk_error( + #crate_path::RowTypeCheckErrorKind::ColumnTypeCheckFailed { + column_name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + }; + if !#visited_flag_names { + #visited_flag_names = true; + remaining_count -= 1; + } + }, + )* + other => return ::std::result::Result::Err(mk_error( + #crate_path::RowTypeCheckErrorKind::UnexpectedColumn { + column_name: <_ as ::std::clone::Clone>::clone(&&spec.name), + } + )), + } + } + }); + + // Finally, check that all fields were consumed. + // If there are some missing fields, return an error + statements.push(parse_quote! { + if remaining_count > 0 { + #( + if !#visited_flag_names { + return ::std::result::Result::Err(mk_error( + #crate_path::RowTypeCheckErrorKind::MissingColumn { + column_name: <_ as ::std::string::ToString>::to_string(#rust_field_names), + } + )); + } + )* + ::std::unreachable!() + } + }); + + // Concatenate generated code and return + parse_quote! { + fn preliminary_type_check( + ctx: &#crate_path::RowSerializationContext, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } + + fn generate_serialize(&self) -> syn::TraitItemFn { + // Implementation can assume that preliminary_type_check was called + // (although not in an unsafe way). + // Need to write the fields as they appear in the type definition. + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + let row_name = self.ctx.struct_name.to_string(); + let rust_field_idents = self + .ctx + .fields + .iter() + .map(|f| f.ident.clone()) + .collect::>(); + let rust_field_names = rust_field_idents + .iter() + .map(|i| i.as_ref().unwrap().to_string()) + .collect::>(); + let udt_field_names = rust_field_names.clone(); // For now, it's the same + let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + + // Declare a helper lambda for creating errors + statements.push(parse_quote! { + let mk_error = |kind: #crate_path::RowSerializationErrorKind| -> ::std::sync::Arc<#crate_path::RowSerializationError> { + ::std::sync::Arc::new( + #crate_path::RowSerializationError { + rust_name: <_ as ::std::string::ToString>::to_string(#row_name), + kind, + } + ) + }; + }); + + // Serialize the field count. + // TODO: This should be done by the driver, not the user logic - it knows exactly + // how many columns should be written. This should be done shortly after + // the transition to the new traits is complete. + statements.push(parse_quote! { + out.extend_from_slice(&(ctx.columns().len() as u16).to_be_bytes()); + }); + + // Generate a loop over the fields and a `match` block to match on + // the field name. + statements.push(parse_quote! { + for spec in ctx.columns() { + match ::std::string::String::as_str(&spec.name) { + #( + #udt_field_names => { + match <#field_types as #crate_path::SerializeCql>::serialize(&self.#rust_field_idents, &spec.typ, out) { + ::std::result::Result::Ok(()) => {} + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err(mk_error( + #crate_path::RowSerializationErrorKind::ColumnSerializationFailed { + column_name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + } + } + )* + _ => {} + } + } + }); + + parse_quote! { + fn serialize( + &self, + ctx: &#crate_path::RowSerializationContext, + out: &mut ::std::vec::Vec<::std::primitive::u8>, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } +} diff --git a/scylla/tests/integration/hygiene.rs b/scylla/tests/integration/hygiene.rs index 12d55ccb61..cf2aaed7b3 100644 --- a/scylla/tests/integration/hygiene.rs +++ b/scylla/tests/integration/hygiene.rs @@ -64,7 +64,7 @@ macro_rules! test_crate { assert_eq!(sv, sv2); } - #[derive(_scylla::macros::SerializeCql)] + #[derive(_scylla::macros::SerializeCql, _scylla::macros::SerializeRow)] #[scylla(crate = _scylla)] struct TestStructNew { x: ::core::primitive::i32,