Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add begin_with methods to support database-specific transaction options #3614

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ rand_xoshiro = "0.6.0"
hex = "0.4.3"
tempfile = "3.10.1"
criterion = { version = "0.5.1", features = ["async_tokio"] }
libsqlite3-sys = { version = "0.30.1" }

# If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`,
# even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code.
Expand Down
4 changes: 2 additions & 2 deletions sqlx-core/src/acquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool<DB> {
let conn = self.acquire();

Box::pin(async move {
Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await
Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?), None).await
})
}
}
Expand Down Expand Up @@ -121,7 +121,7 @@ macro_rules! impl_acquire {
'c,
Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>,
> {
$crate::transaction::Transaction::begin(self)
$crate::transaction::Transaction::begin(self, None)
}
}
};
Expand Down
9 changes: 8 additions & 1 deletion sqlx-core/src/any/connection/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::describe::Describe;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use std::borrow::Cow;
use std::fmt::Debug;

pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static {
Expand All @@ -26,7 +27,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static {
fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>;

/// Begin a new transaction or establish a savepoint within the active transaction.
fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>;
///
/// If this is a new transaction, `statement` may be used instead of the
/// default "BEGIN" statement.
///
/// If we are already inside a transaction and `statement.is_some()`, then
/// `Error::InvalidSavePoint` is returned without running any statements.
fn begin(&mut self, statement: Option<Cow<'static, str>>) -> BoxFuture<'_, crate::Result<()>>;

fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>;

Expand Down
13 changes: 12 additions & 1 deletion sqlx-core/src/any/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures_core::future::BoxFuture;
use std::borrow::Cow;

use crate::any::{Any, AnyConnectOptions};
use crate::connection::{ConnectOptions, Connection};
Expand Down Expand Up @@ -87,7 +88,17 @@ impl Connection for AnyConnection {
where
Self: Sized,
{
Transaction::begin(self)
Transaction::begin(self, None)
}

fn begin_with(
&mut self,
statement: impl Into<Cow<'static, str>>,
) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
where
Self: Sized,
{
Transaction::begin(self, Some(statement.into()))
}

fn cached_statements_size(&self) -> usize {
Expand Down
8 changes: 6 additions & 2 deletions sqlx-core/src/any/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures_util::future::BoxFuture;
use std::borrow::Cow;

use crate::any::{Any, AnyConnection};
use crate::error::Error;
Expand All @@ -9,8 +10,11 @@ pub struct AnyTransactionManager;
impl TransactionManager for AnyTransactionManager {
type Database = Any;

fn begin(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> {
conn.backend.begin()
fn begin<'conn>(
conn: &'conn mut AnyConnection,
statement: Option<Cow<'static, str>>,
) -> BoxFuture<'conn, Result<(), Error>> {
conn.backend.begin(statement)
}

fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> {
Expand Down
17 changes: 17 additions & 0 deletions sqlx-core/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::error::Error;
use crate::transaction::Transaction;
use futures_core::future::BoxFuture;
use log::LevelFilter;
use std::borrow::Cow;
use std::fmt::Debug;
use std::str::FromStr;
use std::time::Duration;
Expand Down Expand Up @@ -49,6 +50,22 @@ pub trait Connection: Send {
where
Self: Sized;

/// Begin a new transaction with a custom statement.
///
/// Returns a [`Transaction`] for controlling and tracking the new transaction.
bonsairobo marked this conversation as resolved.
Show resolved Hide resolved
///
/// Returns an error if the connection is already in a transaction or if
/// `statement` does not put the connection into a transaction.
fn begin_with(
&mut self,
statement: impl Into<Cow<'static, str>>,
) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
where
Self: Sized,
{
Transaction::begin(self, Some(statement.into()))
}

/// Execute the function inside a transaction.
///
/// If the function returns an error, the transaction will be rolled back. If it does not
Expand Down
6 changes: 6 additions & 0 deletions sqlx-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ pub enum Error {
#[cfg(feature = "migrate")]
#[error("{0}")]
Migrate(#[source] Box<crate::migrate::MigrateError>),

#[error("attempted to call begin_with at non-zero transaction depth")]
InvalidSavePointStatement,

#[error("got unexpected connection status after attempting to begin transaction")]
BeginFailed,
}

impl StdError for Box<dyn DatabaseError> {}
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/pool/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB
self,
) -> futures_core::future::BoxFuture<'c, Result<crate::transaction::Transaction<'c, DB>, Error>>
{
crate::transaction::Transaction::begin(&mut **self)
crate::transaction::Transaction::begin(&mut **self, None)
}
}

Expand Down
39 changes: 37 additions & 2 deletions sqlx-core/src/pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
//! [`Pool::acquire`] or
//! [`Pool::begin`].

use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
Expand Down Expand Up @@ -367,20 +368,54 @@ impl<DB: Database> Pool<DB> {

/// Retrieves a connection and immediately begins a new transaction.
pub async fn begin(&self) -> Result<Transaction<'static, DB>, Error> {
Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await
Transaction::begin(
MaybePoolConnection::PoolConnection(self.acquire().await?),
None,
)
.await
}

/// Attempts to retrieve a connection and immediately begins a new transaction if successful.
pub async fn try_begin(&self) -> Result<Option<Transaction<'static, DB>>, Error> {
match self.try_acquire() {
Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn))
Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn), None)
.await
.map(Some),

None => Ok(None),
}
}

/// Retrieves a connection and immediately begins a new transaction using `statement`.
pub async fn begin_with(
&self,
statement: impl Into<Cow<'static, str>>,
) -> Result<Transaction<'static, DB>, Error> {
Transaction::begin(
MaybePoolConnection::PoolConnection(self.acquire().await?),
Some(statement.into()),
)
.await
}

/// Attempts to retrieve a connection and, if successful, immediately begins a new
/// transaction using `statement`.
pub async fn try_begin_with(
&self,
statement: impl Into<Cow<'static, str>>,
) -> Result<Option<Transaction<'static, DB>>, Error> {
match self.try_acquire() {
Some(conn) => Transaction::begin(
MaybePoolConnection::PoolConnection(conn),
Some(statement.into()),
)
.await
.map(Some),

None => Ok(None),
}
}

/// Shut down the connection pool, immediately waking all tasks waiting for a connection.
///
/// Upon calling this method, any currently waiting or subsequent calls to [`Pool::acquire`] and
Expand Down
18 changes: 13 additions & 5 deletions sqlx-core/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@ pub trait TransactionManager {
type Database: Database;

/// Begin a new transaction or establish a savepoint within the active transaction.
fn begin(
conn: &mut <Self::Database as Database>::Connection,
) -> BoxFuture<'_, Result<(), Error>>;
///
/// If this is a new transaction, `statement` may be used instead of the
/// default "BEGIN" statement.
///
/// If we are already inside a transaction and `statement.is_some()`, then
/// `Error::InvalidSavePoint` is returned without running any statements.
fn begin<'conn>(
conn: &'conn mut <Self::Database as Database>::Connection,
statement: Option<Cow<'static, str>>,
) -> BoxFuture<'conn, Result<(), Error>>;

/// Commit the active transaction or release the most recent savepoint.
fn commit(
Expand Down Expand Up @@ -83,11 +90,12 @@ where
#[doc(hidden)]
pub fn begin(
conn: impl Into<MaybePoolConnection<'c, DB>>,
statement: Option<Cow<'static, str>>,
) -> BoxFuture<'c, Result<Self, Error>> {
let mut conn = conn.into();

Box::pin(async move {
DB::TransactionManager::begin(&mut conn).await?;
DB::TransactionManager::begin(&mut conn, statement).await?;

Ok(Self {
connection: conn,
Expand Down Expand Up @@ -237,7 +245,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<'

#[inline]
fn begin(self) -> BoxFuture<'t, Result<Transaction<'t, DB>, Error>> {
Transaction::begin(&mut **self)
Transaction::begin(&mut **self, None)
}
}

Expand Down
8 changes: 6 additions & 2 deletions sqlx-mysql/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
use sqlx_core::transaction::TransactionManager;
use std::borrow::Cow;
use std::future;

sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql);
Expand All @@ -37,8 +38,11 @@ impl AnyConnectionBackend for MySqlConnection {
Connection::ping(self)
}

fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
MySqlTransactionManager::begin(self)
fn begin(
&mut self,
statement: Option<Cow<'static, str>>,
) -> BoxFuture<'_, sqlx_core::Result<()>> {
MySqlTransactionManager::begin(self, statement)
}

fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
Expand Down
1 change: 1 addition & 0 deletions sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ impl MySqlConnection {
inner: Box::new(MySqlConnectionInner {
stream,
transaction_depth: 0,
status_flags: Default::default(),
cache_statement: StatementCache::new(options.statement_cache_capacity),
log_settings: options.log_settings.clone(),
}),
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ impl MySqlConnection {
// this indicates either a successful query with no rows at all or a failed query
let ok = packet.ok()?;

self.inner.status_flags = ok.status;

let rows_affected = ok.affected_rows;
logger.increase_rows_affected(rows_affected);
let done = MySqlQueryResult {
Expand Down Expand Up @@ -208,6 +210,8 @@ impl MySqlConnection {
if packet[0] == 0xfe && packet.len() < 9 {
let eof = packet.eof(self.inner.stream.capabilities)?;

self.inner.status_flags = eof.status;

r#yield!(Either::Left(MySqlQueryResult {
rows_affected: 0,
last_insert_id: 0,
Expand Down
Loading
Loading