diff --git a/Cargo.lock b/Cargo.lock index 10e24131bd..04969c2af2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1177,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -1914,7 +1914,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -3972,7 +3972,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -4801,7 +4801,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4842,7 +4842,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4885,5 +4885,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] diff --git a/Cargo.toml b/Cargo.toml index bf0a867e1e..23e77bd1fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -184,6 +184,7 @@ rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" criterion = { version = "0.5.1", features = ["async_tokio"] } +libsqlite3-sys = { version = "0.30.1" } # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. diff --git a/sqlx-core/src/acquire.rs b/sqlx-core/src/acquire.rs index c9d7fb215c..59bac9fa59 100644 --- a/sqlx-core/src/acquire.rs +++ b/sqlx-core/src/acquire.rs @@ -93,7 +93,7 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool { let conn = self.acquire(); Box::pin(async move { - Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await + Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?), None).await }) } } @@ -121,7 +121,7 @@ macro_rules! impl_acquire { 'c, Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>, > { - $crate::transaction::Transaction::begin(self) + $crate::transaction::Transaction::begin(self, None) } } }; diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f3..2fe9ed7656 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -3,6 +3,7 @@ use crate::describe::Describe; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use std::borrow::Cow; use std::fmt::Debug; pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { @@ -26,7 +27,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin(&mut self, statement: Option>) -> BoxFuture<'_, crate::Result<()>>; fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>; diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848a..8cf8fc510c 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnectOptions}; use crate::connection::{ConnectOptions, Connection}; @@ -87,7 +88,17 @@ impl Connection for AnyConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce4175626..4972268499 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,4 +1,5 @@ use futures_util::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnection}; use crate::error::Error; @@ -9,8 +10,11 @@ pub struct AnyTransactionManager; impl TransactionManager for AnyTransactionManager { type Database = Any; - fn begin(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { - conn.backend.begin() + fn begin<'conn>( + conn: &'conn mut AnyConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + conn.backend.begin(statement) } fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c629..ba226bc814 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -4,6 +4,7 @@ use crate::error::Error; use crate::transaction::Transaction; use futures_core::future::BoxFuture; use log::LevelFilter; +use std::borrow::Cow; use std::fmt::Debug; use std::str::FromStr; use std::time::Duration; @@ -49,6 +50,22 @@ pub trait Connection: Send { where Self: Sized; + /// Begin a new transaction with a custom statement. + /// + /// Returns a [`Transaction`] for controlling and tracking the new transaction. + /// + /// Returns an error if the connection is already in a transaction or if + /// `statement` does not put the connection into a transaction. + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) + } + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 17774addd2..150d643180 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -111,6 +111,12 @@ pub enum Error { #[cfg(feature = "migrate")] #[error("{0}")] Migrate(#[source] Box), + + #[error("attempted to call begin_with at non-zero transaction depth")] + InvalidSavePointStatement, + + #[error("got unexpected connection status after attempting to begin transaction")] + BeginFailed, } impl StdError for Box {} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index bf3a6d4b1c..c029fec6eb 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -191,7 +191,7 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection futures_core::future::BoxFuture<'c, Result, Error>> { - crate::transaction::Transaction::begin(&mut **self) + crate::transaction::Transaction::begin(&mut **self, None) } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index e998618413..b759bacdda 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,6 +54,7 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. +use std::borrow::Cow; use std::fmt; use std::future::Future; use std::pin::Pin; @@ -367,13 +368,17 @@ impl Pool { /// Retrieves a connection and immediately begins a new transaction. pub async fn begin(&self) -> Result, Error> { - Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + None, + ) + .await } /// Attempts to retrieve a connection and immediately begins a new transaction if successful. pub async fn try_begin(&self) -> Result>, Error> { match self.try_acquire() { - Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn)) + Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn), None) .await .map(Some), @@ -381,6 +386,36 @@ impl Pool { } } + /// Retrieves a connection and immediately begins a new transaction using `statement`. + pub async fn begin_with( + &self, + statement: impl Into>, + ) -> Result, Error> { + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + Some(statement.into()), + ) + .await + } + + /// Attempts to retrieve a connection and, if successful, immediately begins a new + /// transaction using `statement`. + pub async fn try_begin_with( + &self, + statement: impl Into>, + ) -> Result>, Error> { + match self.try_acquire() { + Some(conn) => Transaction::begin( + MaybePoolConnection::PoolConnection(conn), + Some(statement.into()), + ) + .await + .map(Some), + + None => Ok(None), + } + } + /// Shut down the connection pool, immediately waking all tasks waiting for a connection. /// /// Upon calling this method, any currently waiting or subsequent calls to [`Pool::acquire`] and diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3a..d9459c53d4 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -16,9 +16,16 @@ pub trait TransactionManager { type Database: Database; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin( - conn: &mut ::Connection, - ) -> BoxFuture<'_, Result<(), Error>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin<'conn>( + conn: &'conn mut ::Connection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>>; /// Commit the active transaction or release the most recent savepoint. fn commit( @@ -83,11 +90,12 @@ where #[doc(hidden)] pub fn begin( conn: impl Into>, + statement: Option>, ) -> BoxFuture<'c, Result> { let mut conn = conn.into(); Box::pin(async move { - DB::TransactionManager::begin(&mut conn).await?; + DB::TransactionManager::begin(&mut conn, statement).await?; Ok(Self { connection: conn, @@ -237,7 +245,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' #[inline] fn begin(self) -> BoxFuture<'t, Result, Error>> { - Transaction::begin(&mut **self) + Transaction::begin(&mut **self, None) } } diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 0466bfc0a4..96190f0bd2 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -16,6 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::borrow::Cow; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -37,8 +38,11 @@ impl AnyConnectionBackend for MySqlConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - MySqlTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + MySqlTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 468478e550..f52756d4c1 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -28,6 +28,7 @@ impl MySqlConnection { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, + status_flags: Default::default(), cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 07c7979b08..169dee76b7 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -166,6 +166,8 @@ impl MySqlConnection { // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; + self.inner.status_flags = ok.status; + let rows_affected = ok.affected_rows; logger.increase_rows_affected(rows_affected); let done = MySqlQueryResult { @@ -208,6 +210,8 @@ impl MySqlConnection { if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.inner.stream.capabilities)?; + self.inner.status_flags = eof.status; + r#yield!(Either::Left(MySqlQueryResult { rows_affected: 0, last_insert_id: 0, diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a7701..0a2f5fb839 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use futures_core::future::BoxFuture; @@ -7,6 +8,7 @@ pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; use crate::error::Error; +use crate::protocol::response::Status; use crate::protocol::statement::StmtClose; use crate::protocol::text::{Ping, Quit}; use crate::statement::MySqlStatementMetadata; @@ -34,6 +36,7 @@ pub(crate) struct MySqlConnectionInner { // transaction status pub(crate) transaction_depth: usize, + status_flags: Status, // cache by query string to the statement id and metadata cache_statement: StatementCache<(u32, MySqlStatementMetadata)>, @@ -41,6 +44,14 @@ pub(crate) struct MySqlConnectionInner { log_settings: LogSettings, } +impl MySqlConnection { + pub(crate) fn in_transaction(&self) -> bool { + self.inner + .status_flags + .intersects(Status::SERVER_STATUS_IN_TRANS) + } +} + impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() @@ -111,7 +122,17 @@ impl Connection for MySqlConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/protocol/response/status.rs b/sqlx-mysql/src/protocol/response/status.rs index bf5013deed..4a8bb0375a 100644 --- a/sqlx-mysql/src/protocol/response/status.rs +++ b/sqlx-mysql/src/protocol/response/status.rs @@ -1,7 +1,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad // https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status bitflags::bitflags! { - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct Status: u16 { // Is raised when a multi-statement transaction has been started, either explicitly, // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index d8538cc2b3..11f56c0cb9 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use futures_core::future::BoxFuture; use crate::connection::Waiting; @@ -14,12 +16,24 @@ pub struct MySqlTransactionManager; impl TransactionManager for MySqlTransactionManager { type Database = MySql; - fn begin(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut MySqlConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - - conn.execute(&*begin_ansi_transaction_sql(depth)).await?; - conn.inner.transaction_depth = depth + 1; + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in a transaction + // (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + conn.execute(&*statement).await?; + if !conn.in_transaction() { + return Err(Error::BeginFailed); + } + conn.inner.transaction_depth += 1; Ok(()) }) diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index efa9a044bc..d189301c13 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,6 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::borrow::Cow; use std::future; use sqlx_core::any::{ @@ -39,8 +40,11 @@ impl AnyConnectionBackend for PgConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - PgTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index c139f8e53d..96e3e2fe12 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -127,6 +128,13 @@ impl PgConnection { Ok(()) } + + pub(crate) fn in_transaction(&self) -> bool { + match self.inner.transaction_status { + TransactionStatus::Transaction => true, + TransactionStatus::Error | TransactionStatus::Idle => false, + } + } } impl Debug for PgConnection { @@ -179,7 +187,17 @@ impl Connection for PgConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index e7c78488eb..f70961cc19 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::error::Error; use crate::executor::Executor; @@ -13,13 +14,27 @@ pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; - fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut PgConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { + let depth = conn.inner.transaction_depth; + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in + // a transaction (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + let rollback = Rollback::new(conn); - let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth); - rollback.conn.queue_simple_query(&query)?; - rollback.conn.inner.transaction_depth += 1; + rollback.conn.queue_simple_query(&statement)?; rollback.conn.wait_until_ready().await?; + if !rollback.conn.in_transaction() { + return Err(Error::BeginFailed); + } + rollback.conn.inner.transaction_depth += 1; rollback.defuse(); Ok(()) diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d9931..2c74c01494 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{ Either, Sqlite, SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnectOptions, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, @@ -37,8 +39,11 @@ impl AnyConnectionBackend for SqliteConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - SqliteTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + SqliteTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a579b8a605..53c3156e9d 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cmp::Ordering; use std::ffi::CStr; use std::fmt::Write; @@ -11,8 +12,8 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, + sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; pub(crate) use handle::ConnectionHandle; @@ -235,7 +236,17 @@ impl Connection for SqliteConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { @@ -492,6 +503,11 @@ impl LockedSqliteHandle<'_> { pub fn remove_rollback_hook(&mut self) { self.guard.remove_rollback_hook(); } + + pub(crate) fn in_transaction(&mut self) -> bool { + let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; + ret == 0 + } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index a01de2419c..c8e6f0a268 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -56,6 +56,7 @@ enum Command { }, Begin { tx: rendezvous_oneshot::Sender>, + statement: Option>, }, Commit { tx: rendezvous_oneshot::Sender>, @@ -180,11 +181,26 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } - Command::Begin { tx } => { + Command::Begin { tx, statement } => { let depth = conn.transaction_depth; + + let statement = match statement { + // custom `BEGIN` statements are not allowed if + // we're already in a transaction (we need to + // issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => { + if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { + break; + } + continue; + }, + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + let res = conn.handle - .exec(begin_ansi_transaction_sql(depth)) + .exec(statement) .map(|_| { conn.transaction_depth += 1; }); @@ -331,8 +347,11 @@ impl ConnectionWorker { Ok(rx) } - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - self.oneshot_cmd_with_ack(|tx| Command::Begin { tx }) + pub(crate) async fn begin( + &mut self, + statement: Option>, + ) -> Result<(), Error> { + self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement }) .await? } diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b1..d217cffd61 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; @@ -10,8 +11,22 @@ pub struct SqliteTransactionManager; impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; - fn begin(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(conn.worker.begin()) + fn begin<'conn>( + conn: &'conn mut SqliteConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + Box::pin(async { + let is_custom_statement = statement.is_some(); + conn.worker.begin(statement).await?; + if is_custom_statement { + // Check that custom statement actually put the connection into a transaction. + let mut handle = conn.lock_handle().await?; + if !handle.in_transaction() { + return Err(Error::BeginFailed); + } + } + Ok(()) + }) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 7c84266c32..3ee1024fc8 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, mysql::MySql, Connection}; +use sqlx::{error::ErrorKind, mysql::MySql, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index d6f78140da..32bf814770 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, postgres::Postgres, Connection}; +use sqlx::{error::ErrorKind, postgres::Postgres, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 1f6b797e69..8729842b70 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Executor}; +use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Error, Executor}; use sqlx_test::new; #[sqlx_macros::test] @@ -70,3 +70,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b733ccbb4c..55b2630f90 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -6,6 +6,7 @@ use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx_sqlite::LockedSqliteHandle; use sqlx_test::new; use std::sync::Arc; @@ -960,3 +961,53 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn it_can_use_transaction_options() -> anyhow::Result<()> { + async fn check_txn_state(conn: &mut SqliteConnection, expected: SqliteTransactionState) { + let state = transaction_state(&mut conn.lock_handle().await.unwrap()); + assert_eq!(state, expected); + } + + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); + + check_txn_state(&mut conn, SqliteTransactionState::None).await; + + let mut tx = conn.begin_with("BEGIN DEFERRED").await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await; + drop(tx); + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; + drop(tx); + + let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; + drop(tx); + + Ok(()) +} + +fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { + use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE}; + + let unchecked_state = + unsafe { sqlite3_txn_state(handle.as_raw_handle().as_ptr(), std::ptr::null()) }; + match unchecked_state { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => panic!("unknown txn state: {unchecked_state}"), + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SqliteTransactionState { + None, + Read, + Write, +}