From 2685ab9539f12075442e50b00d740d7989ea174c Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Mon, 24 Jun 2024 14:58:46 +1000 Subject: [PATCH] cassandra 5.0 vector type CREATE/INSERT support makes progress towards: https://github.com/scylladb/scylla-rust-driver/issues/1014 The vector type is introduced by the currently in beta cassandra 5. See: https://cassandra.apache.org/doc/latest/cassandra/reference/vector-data-type.html Scylla does not support vector types and so the tests are setup to only compile/run with a new cassandra_tests config. This commit does not add support for retrieving the data via a SELECT. That was omitted to reduce scope and will be implemented in follow up work. --- .github/workflows/cassandra.yml | 2 +- scylla/Cargo.toml | 2 +- scylla/src/transport/session_test.rs | 105 +++++++++++++++++++++++++++ scylla/src/transport/topology.rs | 47 ++++++++++++ scylla/src/utils/parse.rs | 15 ++++ 5 files changed, 169 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cassandra.yml b/.github/workflows/cassandra.yml index 4926ece5d6..6c8a71874d 100644 --- a/.github/workflows/cassandra.yml +++ b/.github/workflows/cassandra.yml @@ -31,7 +31,7 @@ jobs: run: cargo build --verbose --tests --features "full-serialization" - name: Run tests on cassandra run: | - CDC='disabled' RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements + CDC='disabled' RUSTFLAGS="--cfg cassandra_tests" RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements - name: Stop the cluster if: ${{ always() }} run: docker compose -f test/cluster/cassandra/docker-compose.yml stop diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml index d5d9551e2d..390a99501b 100644 --- a/scylla/Cargo.toml +++ b/scylla/Cargo.toml @@ -98,4 +98,4 @@ harness = false [lints.rust] unnameable_types = "warn" unreachable_pub = "warn" -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)'] } +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)', 'cfg(cassandra_tests)'] } diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index 3e3552c230..f3bb159bc3 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -3165,3 +3165,108 @@ async fn test_api_migration_session_sharing() { assert!(matched); } } + +#[cfg(cassandra_tests)] +#[tokio::test] +async fn test_vector_type_metadata() { + setup_tracing(); + let session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + + session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); + session + .query_unpaged( + format!( + "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector, c vector)", + ks + ), + &[], + ) + .await + .unwrap(); + + session.refresh_metadata().await.unwrap(); + let metadata = session.get_cluster_data(); + let columns = &metadata.keyspaces[&ks].tables["t"].columns; + assert_eq!( + columns["b"].type_, + CqlType::Vector { + type_: Box::new(CqlType::Native(NativeType::Int)), + dimensions: 4, + }, + ); + assert_eq!( + columns["c"].type_, + CqlType::Vector { + type_: Box::new(CqlType::Native(NativeType::Text)), + dimensions: 2, + }, + ); +} + +#[cfg(cassandra_tests)] +#[tokio::test] +async fn test_vector_type_unprepared() { + setup_tracing(); + let session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + + session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); + session + .query_unpaged( + format!( + "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector, c vector)", + ks + ), + &[], + ) + .await + .unwrap(); + + session + .query_unpaged( + format!( + "INSERT INTO {}.t (a, b, c) VALUES (1, [1, 2, 3, 4], ['foo', 'bar'])", + ks + ), + &[], + ) + .await + .unwrap(); + + // TODO: Implement and test SELECT statements and bind values (`?`) +} + +#[cfg(cassandra_tests)] +#[tokio::test] +async fn test_vector_type_prepared() { + setup_tracing(); + let session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + + session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); + session + .query_unpaged( + format!( + "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector, c vector)", + ks + ), + &[], + ) + .await + .unwrap(); + + let prepared_statement = session + .prepare(format!( + "INSERT INTO {}.t (a, b, c) VALUES (?, [11, 12, 13, 14], ['afoo', 'abar'])", + ks + )) + .await + .unwrap(); + session + .execute_unpaged(&prepared_statement, &(2,)) + .await + .unwrap(); + + // TODO: Implement and test SELECT statements and bind values (`?`) +} diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index 2b83b7ade2..7bbfc6b2a1 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -188,6 +188,12 @@ enum PreCqlType { type_: PreCollectionType, }, Tuple(Vec), + Vector { + type_: Box, + /// matches the datatype used by the java driver: + /// + dimensions: i32, + }, UserDefinedType { frozen: bool, name: String, @@ -211,6 +217,10 @@ impl PreCqlType { .map(|t| t.into_cql_type(keyspace_name, udts)) .collect(), ), + PreCqlType::Vector { type_, dimensions } => CqlType::Vector { + type_: Box::new(type_.into_cql_type(keyspace_name, udts)), + dimensions, + }, PreCqlType::UserDefinedType { frozen, name } => { let definition = match udts .get(keyspace_name) @@ -236,6 +246,12 @@ pub enum CqlType { type_: CollectionType, }, Tuple(Vec), + Vector { + type_: Box, + /// matches the datatype used by the java driver: + /// + dimensions: i32, + }, UserDefinedType { frozen: bool, // Using Arc here in order not to have many copies of the same definition @@ -1137,6 +1153,7 @@ fn topo_sort_udts(udts: &mut Vec) -> Result<(), Quer PreCqlType::Tuple(types) => types .iter() .for_each(|type_| do_with_referenced_udts(what, type_)), + PreCqlType::Vector { type_, .. } => do_with_referenced_udts(what, type_), PreCqlType::UserDefinedType { name, .. } => what(name), } } @@ -1637,6 +1654,22 @@ fn parse_cql_type(p: ParserState<'_>) -> ParseResult<(PreCqlType, ParserState<'_ })?; Ok((PreCqlType::Tuple(types), p)) + } else if let Ok(p) = p.accept("vector<") { + let (inner_type, p) = parse_cql_type(p)?; + + let p = p.skip_white(); + let p = p.accept(",")?; + let p = p.skip_white(); + let (size, p) = p.parse_i32()?; + let p = p.skip_white(); + let p = p.accept(">")?; + + let typ = PreCqlType::Vector { + type_: Box::new(inner_type), + dimensions: size, + }; + + Ok((typ, p)) } else if let Ok((typ, p)) = parse_native_type(p) { Ok((PreCqlType::Native(typ), p)) } else if let Ok((name, p)) = parse_user_defined_type(p) { @@ -1827,6 +1860,20 @@ mod tests { PreCqlType::Native(NativeType::Varint), ]), ), + ( + "vector", + PreCqlType::Vector { + type_: Box::new(PreCqlType::Native(NativeType::Int)), + dimensions: 5, + }, + ), + ( + "vector", + PreCqlType::Vector { + type_: Box::new(PreCqlType::Native(NativeType::Text)), + dimensions: 1234, + }, + ), ( "com.scylladb.types.AwesomeType", PreCqlType::UserDefinedType { diff --git a/scylla/src/utils/parse.rs b/scylla/src/utils/parse.rs index 1c5e59ecb7..02e3f77f64 100644 --- a/scylla/src/utils/parse.rs +++ b/scylla/src/utils/parse.rs @@ -87,6 +87,21 @@ impl<'s> ParserState<'s> { me } + /// Parses a sequence of digits and '-' as an integer. + /// Consumes characters until it finds a character that is not a digit or '-'. + /// + /// An error is returned if: + /// * The first character is not a digit or '-' + /// * The integer is larger than i32 + pub(crate) fn parse_i32(self) -> ParseResult<(i32, Self)> { + let (digits, p) = self.take_while(|c| c.is_ascii_digit() || c == '-'); + if let Ok(value) = digits.parse() { + Ok((value, p)) + } else { + Err(p.error(ParseErrorCause::Other("Expected 32-bit signed integer"))) + } + } + /// Skips characters from the beginning while they satisfy given predicate /// and returns new parser state which pub(crate) fn take_while(self, mut pred: impl FnMut(char) -> bool) -> (&'s str, Self) {