diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs
index c7cc85d233..9e80247e20 100644
--- a/scylla-cql/src/errors.rs
+++ b/scylla-cql/src/errors.rs
@@ -349,8 +349,8 @@ pub enum BadQuery {
     BadKeyspaceName(#[from] BadKeyspaceName),
 
     /// Too many queries in the batch statement
-    #[error("Number of Queries in Batch Statement has exceeded the max value of 65,536")]
-    TooManyQueriesInBatchStatement,
+    #[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")]
+    TooManyQueriesInBatchStatement(usize),
 
     /// Other reasons of bad query
     #[error("{0}")]
diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs
index 403b6ab5fd..3da4e26d01 100644
--- a/scylla-cql/src/frame/frame_errors.rs
+++ b/scylla-cql/src/frame/frame_errors.rs
@@ -40,7 +40,7 @@ pub enum ParseError {
     #[error(transparent)]
     IoError(#[from] std::io::Error),
     #[error("type not yet implemented, id: {0}")]
-    TypeNotImplemented(i16),
+    TypeNotImplemented(u16),
     #[error(transparent)]
     SerializeValuesError(#[from] SerializeValuesError),
     #[error(transparent)]
diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs
index 92b8b61ec4..3c0bad3931 100644
--- a/scylla-cql/src/frame/request/batch.rs
+++ b/scylla-cql/src/frame/request/batch.rs
@@ -81,7 +81,7 @@ where
         buf.put_u8(self.batch_type as u8);
 
         // Serializing queries
-        types::write_u16(self.statements.len().try_into()?, buf);
+        types::write_short(self.statements.len().try_into()?, buf);
 
         let counts_mismatch_err = |n_values: usize, n_statements: usize| {
             ParseError::BadDataToSerialize(format!(
@@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedValues>> {
     fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
         let batch_type = buf.get_u8().try_into()?;
 
-        let statements_count: usize = types::read_u16(buf)?.try_into()?;
+        let statements_count: usize = types::read_short(buf)?.try_into()?;
         let statements_with_values = (0..statements_count)
             .map(|_| {
                 let batch_statement = BatchStatement::deserialize(buf)?;
diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs
index fa964e9478..1c004e07cf 100644
--- a/scylla-cql/src/frame/types.rs
+++ b/scylla-cql/src/frame/types.rs
@@ -16,7 +16,7 @@ use uuid::Uuid;
 #[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)]
 #[cfg_attr(feature = "serde", derive(serde::Deserialize))]
 #[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))]
-#[repr(i16)]
+#[repr(u16)]
 pub enum Consistency {
     Any = 0x0000,
     One = 0x0001,
@@ -175,8 +175,8 @@ fn type_long() {
     }
 }
 
-pub fn read_short(buf: &mut &[u8]) -> Result<i16, ParseError> {
-    let v = buf.read_i16::<BigEndian>()?;
+pub fn read_short(buf: &mut &[u8]) -> Result<u16, ParseError> {
+    let v = buf.read_u16::<BigEndian>()?;
     Ok(v)
 }
 
@@ -185,11 +185,7 @@ pub fn read_u16(buf: &mut &[u8]) -> Result<u16, ParseError> {
     Ok(v)
 }
 
-pub fn write_short(v: i16, buf: &mut impl BufMut) {
-    buf.put_i16(v);
-}
-
-pub fn write_u16(v: u16, buf: &mut impl BufMut) {
+pub fn write_short(v: u16, buf: &mut impl BufMut) {
     buf.put_u16(v);
 }
 
@@ -200,14 +196,14 @@ pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
 }
 
 fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> {
-    let v: i16 = v.try_into()?;
+    let v: u16 = v.try_into()?;
     write_short(v, buf);
     Ok(())
 }
 
 #[test]
 fn type_short() {
-    let vals = [i16::MIN, -1, 0, 1, i16::MAX];
+    let vals: [u16; 3] = [0, 1, u16::MAX];
     for val in vals.iter() {
         let mut buf = Vec::new();
         write_short(*val, &mut buf);
@@ -215,15 +211,6 @@ fn type_short() {
     }
 }
 
-#[test]
-fn type_u16() {
-    let vals = [0, 1, u16::MAX];
-    for val in vals.iter() {
-        let mut buf = Vec::new();
-        write_u16(*val, &mut buf);
-        assert_eq!(read_u16(&mut &buf[..]).unwrap(), *val);
-    }
-}
 // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208
 pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result<Option<&'a [u8]>, ParseError> {
     let len = read_int(buf)?;
@@ -488,11 +475,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, ParseError> {
 }
 
 pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) {
-    write_short(c as i16, buf);
+    write_short(c as u16, buf);
 }
 
 pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) {
-    write_short(c as i16, buf);
+    write_short(c as u16, buf);
 }
 
 #[test]
diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs
index e9164f2531..617dce4820 100644
--- a/scylla-cql/src/frame/value.rs
+++ b/scylla-cql/src/frame/value.rs
@@ -63,7 +63,7 @@ pub struct Time(pub Duration);
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
 pub struct SerializedValues {
     serialized_values: Vec<u8>,
-    values_num: i16,
+    values_num: u16,
     contains_names: bool,
 }
 
@@ -134,7 +134,7 @@ impl SerializedValues {
         if self.contains_names {
             return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
         }
-        if self.values_num == i16::MAX {
+        if self.values_num == u16::MAX {
             return Err(SerializeValuesError::TooManyValues);
         }
 
@@ -158,7 +158,7 @@ impl SerializedValues {
             return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
         }
         self.contains_names = true;
-        if self.values_num == i16::MAX {
+        if self.values_num == u16::MAX {
             return Err(SerializeValuesError::TooManyValues);
         }
 
@@ -184,7 +184,7 @@ impl SerializedValues {
     }
 
     pub fn write_to_request(&self, buf: &mut impl BufMut) {
-        buf.put_i16(self.values_num);
+        buf.put_u16(self.values_num);
         buf.put(&self.serialized_values[..]);
     }
 
@@ -192,7 +192,7 @@ impl SerializedValues {
         self.values_num == 0
     }
 
-    pub fn len(&self) -> i16 {
+    pub fn len(&self) -> u16 {
         self.values_num
     }
 
diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml
index 3408d10330..2460e020b9 100644
--- a/scylla/Cargo.toml
+++ b/scylla/Cargo.toml
@@ -60,7 +60,6 @@ criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0
 tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
 assert_matches = "1.5.0"
 rand_chacha = "0.3.1"
-bcs = "0.1.5"
 
 [[bench]]
 name = "benchmark"
diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs
index b57d5d4b23..9814e7350d 100644
--- a/scylla/src/statement/prepared_statement.rs
+++ b/scylla/src/statement/prepared_statement.rs
@@ -339,7 +339,7 @@ impl PreparedStatement {
 #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
 pub enum PartitionKeyExtractionError {
     #[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
-    NoPkIndexValue(u16, i16),
+    NoPkIndexValue(u16, u16),
 }
 
 #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs
index 6195de30df..b3d96b25f5 100644
--- a/scylla/src/transport/large_batch_statements_test.rs
+++ b/scylla/src/transport/large_batch_statements_test.rs
@@ -1,5 +1,7 @@
-use bcs::serialize_into;
+use assert_matches::assert_matches;
+
 use scylla_cql::errors::{BadQuery, QueryError};
+use scylla_cql::Consistency;
 
 use crate::batch::BatchType;
 use crate::query::Query;
@@ -16,48 +18,55 @@ async fn test_large_batch_statements() {
     let ks = unique_keyspace_name();
     session = create_test_session(session, &ks).await;
 
+    // table should be initially empty
+    let query_result = simple_fetch_all(&session, &ks).await;
+    assert_eq!(query_result.rows.unwrap().len(), 0);
+
+    // Add batch
     let max_number_of_queries = u16::MAX as usize;
-    write_batch(&session, max_number_of_queries).await;
+    write_batch(&session, max_number_of_queries, &ks).await;
 
-    let key_prefix = vec![0];
-    let keys = find_keys_by_prefix(&session, key_prefix.clone()).await;
-    assert_eq!(keys.len(), max_number_of_queries);
+    // Query batch
+    let query_result = simple_fetch_all(&session, &ks).await;
+    assert_eq!(query_result.rows.unwrap().len(), max_number_of_queries);
 
+    // Now try with too many queries
     let too_many_queries = u16::MAX as usize + 1;
-
-    let err = write_batch(&session, too_many_queries).await;
-
-    assert!(err.is_err());
+    let batch_insert_result = write_batch(&session, too_many_queries, &ks).await;
+    assert_matches!(
+        batch_insert_result.unwrap_err(),
+        QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(too_many_queries))
+    )
 }
 
 async fn create_test_session(session: Session, ks: &String) -> Session {
     session
         .query(
-            format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}",ks),
+            format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks),
             &[],
         )
         .await.unwrap();
     session
-        .query("DROP TABLE IF EXISTS kv.pairs;", &[])
+        .query(format!("DROP TABLE IF EXISTS {}.pairs;", ks), &[])
        .await
         .unwrap();
     session
         .query(
-            "CREATE TABLE IF NOT EXISTS kv.pairs (dummy int, k blob, v blob, primary key (dummy, k))",
+            format!("CREATE TABLE IF NOT EXISTS {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", ks),
             &[],
         )
         .await.unwrap();
     session
 }
 
-async fn write_batch(session: &Session, n: usize) -> Result<QueryResult, QueryError> {
+async fn write_batch(session: &Session, n: usize, ks: &String) -> Result<QueryResult, QueryError> {
     let mut batch_query = Batch::new(BatchType::Logged);
     let mut batch_values = Vec::new();
     for i in 0..n {
         let mut key = vec![0];
-        serialize_into(&mut key, &(i as usize)).unwrap();
+        key.extend(i.to_be_bytes().as_slice());
         let value = key.clone();
-        let query = "INSERT INTO kv.pairs (dummy, k, v) VALUES (0, ?, ?)";
+        let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks);
         let values = vec![key, value];
         batch_values.push(values);
         let query = Query::new(query);
@@ -66,41 +75,7 @@ async fn write_batch(session: &Session, n: usize) -> Result<QueryResult, QueryError> {
     session.batch(&batch_query, batch_values).await
 }
 
-async fn find_keys_by_prefix(session: &Session, key_prefix: Vec<u8>) -> Vec<Vec<u8>> {
-    let len = key_prefix.len();
-    let rows = match get_upper_bound_option(&key_prefix) {
-        None => {
-            let values = (key_prefix,);
-            let query = "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? ALLOW FILTERING";
-            session.query(query, values).await.unwrap()
-        }
-        Some(upper_bound) => {
-            let values = (key_prefix, upper_bound);
-            let query =
-                "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? AND k < ? ALLOW FILTERING";
-            session.query(query, values).await.unwrap()
-        }
-    };
-    let mut keys = Vec::new();
-    if let Some(rows) = rows.rows {
-        for row in rows.into_typed::<(Vec<u8>,)>() {
-            let key = row.unwrap();
-            let short_key = key.0[len..].to_vec();
-            keys.push(short_key);
-        }
-    }
-    keys
-}
-
-fn get_upper_bound_option(key_prefix: &[u8]) -> Option<Vec<u8>> {
-    let len = key_prefix.len();
-    for i in (0..len).rev() {
-        let val = key_prefix[i];
-        if val < u8::MAX {
-            let mut upper_bound = key_prefix[0..i + 1].to_vec();
-            upper_bound[i] += 1;
-            return Some(upper_bound);
-        }
-    }
-    None
+async fn simple_fetch_all(session: &Session, ks: &String) -> QueryResult {
+    let select_query = format!("SELECT * FROM {}.pairs", ks);
+    session.query(select_query, &[]).await.unwrap()
 }
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index f92067363d..2f67874f8c 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -1145,9 +1145,10 @@ impl Session {
         // If users batch statements by shard, they will be rewarded with full shard awareness
 
         // check to ensure that we don't send a batch statement with more than u16::MAX queries
-        if batch.statements.len() > u16::MAX as usize {
+        let batch_statements_length = batch.statements.len();
+        if batch_statements_length > u16::MAX as usize {
             return Err(QueryError::BadQuery(
-                BadQuery::TooManyQueriesInBatchStatement,
+                BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
             ));
         }
 
         // Extract first serialized_value