diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index d57612dec3..2493cf002c 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -26,7 +26,7 @@ pub(crate) use ssl_config::SslConfig; use crate::authentication::AuthenticatorProvider; use scylla_cql::frame::response::authenticate::Authenticate; -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::convert::TryFrom; use std::io::ErrorKind; use std::net::{IpAddr, SocketAddr}; @@ -52,7 +52,7 @@ use crate::frame::{ request::{self, batch, execute, query, register, SerializableRequest}, response::{event::Event, result, NonErrorResponse, Response, ResponseOpcode}, server_event_type::EventType, - value::{BatchValues, ValueList}, + value::{BatchValues, BatchValuesIterator, ValueList}, FrameParams, SerializedRequest, }; use crate::query::Query; @@ -772,11 +772,62 @@ impl Connection { pub(crate) async fn batch_with_consistency( &self, - batch: &Batch, + init_batch: &Batch, values: impl BatchValues, consistency: Consistency, serial_consistency: Option, ) -> Result { + let batch = { + let mut to_prepare = HashSet::::new(); + + { + let mut value_iter = values.batch_values_iter(); + for stmt in &init_batch.statements { + if let BatchStatement::Query(query) = stmt { + let value = value_iter.next_serialized().transpose()?; + if let Some(v) = value { + if v.len() > 0 { + to_prepare.insert(query.contents.clone()); + } + } + } else { + value_iter.skip_next(); + } + } + } + + let mut prepared_queries = HashMap::::new(); + + for query in &to_prepare { + let prepared = self.prepare(&Query::new(query)).await?; + prepared_queries.insert(query.clone(), prepared); + } + + let mut batch: Cow; + + if to_prepare.is_empty() { + batch = Cow::Borrowed(init_batch); + } else { + batch = Cow::Owned(Default::default()); + batch.to_mut().config = init_batch.config.clone(); + for stmt in &init_batch.statements { + match stmt { + BatchStatement::Query(query) => { + match prepared_queries.get(&query.contents) { + Some(prepared) => batch.to_mut().append_statement(prepared.clone()), + None => batch.to_mut().append_statement(query.clone()) + } + } + BatchStatement::PreparedStatement(prepared) => { + batch.to_mut().append_statement(prepared.clone()); + } + } + } + } + + batch + }; + let batch_frame = batch::Batch { statements: Cow::Borrowed(&batch.statements), values,