Skip to content

Commit

Permalink
scylla-macros: implement enforce_order flavor of SerializeRow
Browse files Browse the repository at this point in the history
Like in the case of `SerializeRow`, some people might be used to working
with the old `ValueList` and already order their Rust struct fields with
accordance to the queries they are used with and don't need the overhead
associated with looking up columns by name. The `enforce_order` mode is
added to `SerializeRow` which works analogously as in `SerializeCql` -
expects the columns to be in the correct order and verifies that this is
the case when serializing, but just fails instead of reordering if that
expectation is broken.
  • Loading branch information
piodul committed Nov 24, 2023
1 parent b41e303 commit 0b95f92
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 5 deletions.
19 changes: 16 additions & 3 deletions scylla-cql/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ pub use scylla_macros::SerializeCql;
/// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait
/// which serializes given Rust structure into bind markers for a CQL statement.
///
/// At the moment, only structs with named fields are supported. The generated
/// implementation of the trait will match the struct fields to bind markers/columns
/// by name automatically.
/// At the moment, only structs with named fields are supported.
///
/// Serialization will fail if there are some bind markers/columns in the statement
/// that don't match to any of the Rust struct fields, _or vice versa_.
Expand Down Expand Up @@ -125,6 +123,21 @@ pub use scylla_macros::SerializeCql;
///
/// # Attributes
///
/// `#[scylla(flavor = "flavor_name")]`
///
/// Allows to choose one of the possible "flavors", i.e. the way how the
/// generated code will approach serialization. Possible flavors are:
///
/// - `"match_by_name"` (default) - the generated implementation _does not
/// require_ the fields in the Rust struct to be in the same order as the
/// columns/bind markers. During serialization, the implementation will take
/// care to serialize the fields in the order which the database expects.
/// - `"enforce_order"` - the generated implementation _requires_ the fields
/// in the Rust struct to be in the same order as the columns/bind markers.
/// If the order is incorrect, type checking/serialization will fail.
/// This is a less robust flavor than `"match_by_name"`, but should be
/// slightly more performant as it doesn't need to perform lookups by name.
///
/// `#[scylla(crate = crate_name)]`
///
/// By default, the code generated by the derive macro will refer to the items
Expand Down
106 changes: 106 additions & 0 deletions scylla-cql/src/types/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,12 @@ pub enum BuiltinTypeCheckErrorKind {
/// A value required by the statement is not provided by the Rust type.
ColumnMissingForValue { name: String },

/// A different column name was expected at given position.
ColumnNameMismatch {
rust_column_name: String,
db_column_name: String,
},

/// One of the columns failed to type check.
ColumnTypeCheckFailed {
name: String,
Expand All @@ -488,6 +494,10 @@ impl Display for BuiltinTypeCheckErrorKind {
"value for column {name} was provided, but there is no bind marker for this column in the query"
)
}
BuiltinTypeCheckErrorKind::ColumnNameMismatch { rust_column_name, db_column_name } => write!(
f,
"expected column with name {db_column_name} at given position, but the Rust field name is {rust_column_name}"
),
BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { name, err } => {
write!(f, "failed to check column {name}: {err}")
}
Expand Down Expand Up @@ -742,4 +752,100 @@ mod tests {
BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { .. }
));
}

#[derive(SerializeRow, Debug, PartialEq, Eq)]
#[scylla(crate = crate, flavor = "enforce_order")]
struct TestRowWithEnforcedOrder {
a: String,
b: i32,
c: Vec<i64>,
}

#[test]
fn test_row_serialization_with_enforced_order_correct_order() {
let spec = [
col("a", ColumnType::Text),
col("b", ColumnType::Int),
col("c", ColumnType::List(Box::new(ColumnType::BigInt))),
];

let reference = do_serialize(("Ala ma kota", 42i32, vec![1i64, 2i64, 3i64]), &spec);
let row = do_serialize(
TestRowWithEnforcedOrder {
a: "Ala ma kota".to_owned(),
b: 42,
c: vec![1, 2, 3],
},
&spec,
);

assert_eq!(reference, row);
}

#[test]
fn test_row_serialization_with_enforced_order_failing_type_check() {
// The order of two last columns is swapped
let spec = [
col("a", ColumnType::Text),
col("c", ColumnType::List(Box::new(ColumnType::BigInt))),
col("b", ColumnType::Int),
];
let ctx = RowSerializationContext { columns: &spec };
let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err();
let err = err.downcast_ref::<BuiltinTypeCheckError>().unwrap();
assert!(matches!(
err.kind,
BuiltinTypeCheckErrorKind::ColumnNameMismatch { .. }
));

let spec_without_c = [
col("a", ColumnType::Text),
col("b", ColumnType::Int),
// Missing column c
];

let ctx = RowSerializationContext {
columns: &spec_without_c,
};
let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err();
let err = err.downcast_ref::<BuiltinTypeCheckError>().unwrap();
assert!(matches!(
err.kind,
BuiltinTypeCheckErrorKind::ColumnMissingForValue { .. }
));

let spec_duplicate_column = [
col("a", ColumnType::Text),
col("b", ColumnType::Int),
col("c", ColumnType::List(Box::new(ColumnType::BigInt))),
// Unexpected last column
col("d", ColumnType::Counter),
];

let ctx = RowSerializationContext {
columns: &spec_duplicate_column,
};
let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err();
let err = err.downcast_ref::<BuiltinTypeCheckError>().unwrap();
assert!(matches!(
err.kind,
BuiltinTypeCheckErrorKind::MissingValueForColumn { .. }
));

let spec_wrong_type = [
col("a", ColumnType::Text),
col("b", ColumnType::Int),
col("c", ColumnType::TinyInt), // Wrong type
];

let ctx = RowSerializationContext {
columns: &spec_wrong_type,
};
let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err();
let err = err.downcast_ref::<BuiltinTypeCheckError>().unwrap();
assert!(matches!(
err.kind,
BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { .. }
));
}
}
145 changes: 143 additions & 2 deletions scylla-macros/src/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::parse_quote;

use super::Flavor;

#[derive(FromAttributes)]
#[darling(attributes(scylla))]
struct Attributes {
#[darling(rename = "crate")]
crate_path: Option<syn::Path>,

flavor: Option<Flavor>,
}

impl Attributes {
Expand Down Expand Up @@ -36,7 +40,11 @@ pub fn derive_serialize_row(tokens_input: TokenStream) -> Result<syn::ItemImpl,

let fields = named_fields.named.iter().cloned().collect();
let ctx = Context { attributes, fields };
let gen = ColumnSortingGenerator { ctx: &ctx };

let gen: Box<dyn Generator> = match ctx.attributes.flavor {
Some(Flavor::MatchByName) | None => Box::new(ColumnSortingGenerator { ctx: &ctx }),
Some(Flavor::EnforceOrder) => Box::new(ColumnOrderedGenerator { ctx: &ctx }),
};

let preliminary_type_check_item = gen.generate_preliminary_type_check();
let serialize_item = gen.generate_serialize();
Expand Down Expand Up @@ -80,13 +88,18 @@ impl Context {
}
}

trait Generator {
fn generate_preliminary_type_check(&self) -> syn::TraitItemFn;
fn generate_serialize(&self) -> syn::TraitItemFn;
}

// Generates an implementation of the trait which sorts the columns according
// to how they are defined in prepared statement metadata.
struct ColumnSortingGenerator<'a> {
ctx: &'a Context,
}

impl<'a> ColumnSortingGenerator<'a> {
impl<'a> Generator for ColumnSortingGenerator<'a> {
fn generate_preliminary_type_check(&self) -> syn::TraitItemFn {
// Need to:
// - Check that all required columns are there and no more
Expand Down Expand Up @@ -245,3 +258,131 @@ impl<'a> ColumnSortingGenerator<'a> {
}
}
}

// Generates an implementation of the trait which requires the columns
// to be placed in the same order as they are defined in the struct.
struct ColumnOrderedGenerator<'a> {
ctx: &'a Context,
}

impl<'a> Generator for ColumnOrderedGenerator<'a> {
fn generate_preliminary_type_check(&self) -> syn::TraitItemFn {
let mut statements: Vec<syn::Stmt> = Vec::new();

let crate_path = self.ctx.attributes.crate_path();

statements.push(self.ctx.generate_mk_typck_err());

// Create an iterator over fields
statements.push(parse_quote! {
let mut column_iter = ctx.columns().iter();
});

// Go over all fields, check their names and then type check
for field in self.ctx.fields.iter() {
let name = field.ident.as_ref().unwrap().to_string();
let typ = &field.ty;
statements.push(parse_quote! {
match column_iter.next() {
Some(spec) => {
if spec.name == #name {
match <#typ as #crate_path::SerializeCql>::preliminary_type_check(&spec.typ) {
Ok(()) => {}
Err(err) => {
return ::std::result::Result::Err(mk_typck_err(
#crate_path::BuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed {
name: <_ as ::std::clone::Clone>::clone(&spec.name),
err,
}
));
}
}
} else {
return ::std::result::Result::Err(mk_typck_err(
#crate_path::BuiltinRowTypeCheckErrorKind::ColumnNameMismatch {
rust_column_name: <_ as ::std::string::ToString>::to_string(#name),
db_column_name: <_ as ::std::clone::Clone>::clone(&spec.name),
}
));
}
}
None => {
return ::std::result::Result::Err(mk_typck_err(
#crate_path::BuiltinRowTypeCheckErrorKind::ColumnMissingForValue {
name: <_ as ::std::string::ToString>::to_string(#name),
}
));
}
}
});
}

// Check whether there are some columns remaining
statements.push(parse_quote! {
if let Some(spec) = column_iter.next() {
return ::std::result::Result::Err(mk_typck_err(
#crate_path::BuiltinRowTypeCheckErrorKind::MissingValueForColumn {
name: <_ as ::std::clone::Clone>::clone(&spec.name),
}
));
}
});

// Concatenate generated code and return
parse_quote! {
fn preliminary_type_check(
ctx: &#crate_path::RowSerializationContext,
) -> ::std::result::Result<(), #crate_path::SerializationError> {
#(#statements)*
::std::result::Result::Ok(())
}
}
}

fn generate_serialize(&self) -> syn::TraitItemFn {
let mut statements: Vec<syn::Stmt> = Vec::new();

let crate_path = self.ctx.attributes.crate_path();

// Declare a helper lambda for creating errors
statements.push(self.ctx.generate_mk_ser_err());

// Create an iterator over fields
statements.push(parse_quote! {
let mut column_iter = ctx.columns().iter();
});

// Serialize each field
for field in self.ctx.fields.iter() {
let name = &field.ident;
let typ = &field.ty;
statements.push(parse_quote! {
if let Some(spec) = column_iter.next() {
let cell_writer = <_ as #crate_path::RowWriter>::make_cell_writer(writer);
match <#typ as #crate_path::SerializeCql>::serialize(&self.#name, &spec.typ, cell_writer) {
Ok(_proof) => {},
Err(err) => {
return ::std::result::Result::Err(mk_ser_err(
#crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed {
name: <_ as ::std::clone::Clone>::clone(&spec.name),
err,
}
));
}
}
}
});
}

parse_quote! {
fn serialize<W: #crate_path::RowWriter>(
&self,
ctx: &#crate_path::RowSerializationContext,
writer: &mut W,
) -> ::std::result::Result<(), #crate_path::SerializationError> {
#(#statements)*
::std::result::Result::Ok(())
}
}
}
}

0 comments on commit 0b95f92

Please sign in to comment.