From 7ab05a62c7eba5f725a254a427035a31035ffba7 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:50:02 -0800 Subject: [PATCH] chore: test begin_with works for all SQLite "BEGIN" statements --- sqlx-sqlite/src/connection/mod.rs | 23 +++++++++++++++++++++ sqlx-sqlite/src/lib.rs | 4 +++- tests/sqlite/sqlite.rs | 33 +++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 53c3156e9d..d7cd48fe25 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -508,6 +508,29 @@ impl LockedSqliteHandle<'_> { let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; ret == 0 } + + /// Calls `sqlite3_txn_state` on this handle. + pub fn transaction_state(&mut self) -> Result { + use libsqlite3_sys::{ + sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE, + }; + + let state = + match unsafe { sqlite3_txn_state(self.as_raw_handle().as_ptr(), std::ptr::null()) } { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => return Err(Error::Protocol("Invalid transaction state".into())), + }; + Ok(state) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SqliteTransactionState { + None, + Read, + Write, } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f8f5534879..398ecf59e8 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,7 +46,9 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; -pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; +pub use connection::{ + LockedSqliteHandle, SqliteConnection, SqliteOperation, SqliteTransactionState, UpdateHookResult, +}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b733ccbb4c..11a582c370 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -960,3 +960,36 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn it_can_use_transaction_options() -> anyhow::Result<()> { + use sqlx_sqlite::SqliteTransactionState; + + async fn check_txn_state( + conn: &mut SqliteConnection, + expected: SqliteTransactionState, + ) -> Result<(), sqlx::Error> { + let state = conn.lock_handle().await?.transaction_state()?; + assert_eq!(state, expected); + Ok(()) + } + + let mut conn = new::().await?; + + check_txn_state(&mut conn, SqliteTransactionState::None).await?; + + let mut tx = conn.begin_with("BEGIN DEFERRED").await?; + check_txn_state(&mut *tx, SqliteTransactionState::None).await?; + drop(tx); + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + drop(tx); + + // Note: may result in database locked errors if tests are run in parallel + let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; + check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + drop(tx); + + Ok(()) +}