Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlite): add preupdate hook #3625

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ jobs:
- run: >
cargo clippy
--no-default-features
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
-- -D warnings

# Run beta for new warnings but don't break the build.
# Use a subdirectory of `target` to avoid clobbering the cache.
- run: >
cargo +beta clippy
--no-default-features
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--target-dir target/beta/

check-minimal-versions:
Expand Down Expand Up @@ -140,7 +140,7 @@ jobs:
- run: >
cargo test
--no-default-features
--features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }}
--features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,' || ''}}_unstable-all-types,runtime-${{ matrix.runtime }}
--
--test-threads=1
env:
Expand Down
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ authors.workspace = true
repository.workspace = true

[package.metadata.docs.rs]
features = ["all-databases", "_unstable-all-types"]
features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"]
rustdoc-args = ["--cfg", "docsrs"]

[features]
Expand Down Expand Up @@ -108,6 +108,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"]
mysql = ["sqlx-mysql", "sqlx-macros?/mysql"]
sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"]
sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"]
sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should emit a compile error if neither sqlite or sqlite-unbundled is enabled or else it could cause weird errors if it's only enabled on its own.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this.


# types
json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"]
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ be removed in the future.
* May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended.
* Can increase build time due to the use of bindgen.

- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API.
* Exposed as a separate feature because it's generally not enabled by default.
* Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it.

- `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime.

- `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`.
Expand Down
3 changes: 3 additions & 0 deletions sqlx-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"]

regexp = ["dep:regex"]

preupdate-hook = ["libsqlite3-sys/preupdate_hook"]

bundled = ["libsqlite3-sys/bundled"]
unbundled = ["libsqlite3-sys/buildtime_bindgen"]

Expand All @@ -48,6 +50,7 @@ atoi = "2.0"

log = "0.4.18"
tracing = { version = "0.1.37", features = ["log"] }
thiserror = "2.0.0"

serde = { version = "1.0.145", features = ["derive"], optional = true }
regex = { version = "1.5.5", optional = true }
Expand Down
2 changes: 2 additions & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ impl EstablishParams {
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this being set or referenced by anything. Did you mean to expose this on SqliteConnectOptions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was being used within the preupdate_hook module to avoid having to add more cfg checks in the main connection module and make additional fields pub(super), but that does make it a bit harder to find. I went ahead and moved that logic to be with the rest of the hooks.

commit_hook_callback: None,
rollback_hook_callback: None,
})
Expand Down
50 changes: 50 additions & 0 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use libsqlite3_sys::{
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
};
#[cfg(feature = "preupdate-hook")]
pub use preupdate_hook::*;

pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
Expand All @@ -36,6 +38,8 @@ mod executor;
mod explain;
mod handle;
pub(crate) mod intmap;
#[cfg(feature = "preupdate-hook")]
mod preupdate_hook;

mod worker;

Expand Down Expand Up @@ -88,6 +92,7 @@ pub struct UpdateHookResult<'a> {
pub table: &'a str,
pub rowid: i64,
}

pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}

Expand All @@ -112,6 +117,8 @@ pub(crate) struct ConnectionState {
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHookHandler>,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,

commit_hook_callback: Option<CommitHookHandler>,

Expand All @@ -138,6 +145,16 @@ impl ConnectionState {
}
}

#[cfg(feature = "preupdate-hook")]
pub(crate) fn remove_preupdate_hook(&mut self) {
if let Some(mut handler) = self.preupdate_hook_callback.take() {
unsafe {
libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}

pub(crate) fn remove_commit_hook(&mut self) {
if let Some(mut handler) = self.commit_hook_callback.take() {
unsafe {
Expand Down Expand Up @@ -421,6 +438,34 @@ impl LockedSqliteHandle<'_> {
}
}

/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
/// At most one preupdate hook may be registered at a time on a single database connection.
///
/// The preupdate hook only fires for changes to real database tables;
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
///
/// See https://sqlite.org/c3ref/preupdate_count.html
#[cfg(feature = "preupdate-hook")]
pub fn set_preupdate_hook<F>(&mut self, callback: F)
where
F: FnMut(PreupdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_preupdate_hook();
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));

libsqlite3_sys::sqlite3_preupdate_hook(
self.as_raw_handle().as_mut(),
Some(preupdate_hook::<F>),
handler,
);
}
}

/// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback
/// returns `false`, then the operation is turned into a ROLLBACK.
///
Expand Down Expand Up @@ -485,6 +530,11 @@ impl LockedSqliteHandle<'_> {
self.guard.remove_update_hook();
}

#[cfg(feature = "preupdate-hook")]
pub fn remove_preupdate_hook(&mut self) {
self.guard.remove_preupdate_hook();
}

pub fn remove_commit_hook(&mut self) {
self.guard.remove_commit_hook();
}
Expand Down
156 changes: 156 additions & 0 deletions sqlx-sqlite/src/connection/preupdate_hook.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use super::SqliteOperation;
use crate::type_info::DataType;
use crate::{SqliteError, SqliteTypeInfo, SqliteValue};

use libsqlite3_sys::{
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new,
sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK,
};
use std::ffi::CStr;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use std::ptr::NonNull;

#[derive(Debug, thiserror::Error)]
pub enum PreupdateError {
/// Error returned from the database.
#[error("error returned from database: {0}")]
Database(#[source] SqliteError),
/// Index is not within the valid column range
#[error("{0} is not within the valid column range")]
ColumnIndexOutOfBounds(i32),
/// Column value accessor was invoked from an invalid operation
#[error("column value accessor was invoked from an invalid operation")]
InvalidOperation,
}

pub(crate) struct PreupdateHookHandler(
pub(super) NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>,
);
unsafe impl Send for PreupdateHookHandler {}

#[derive(Debug)]
pub struct PreupdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
// The database pointer should not be usable after the preupdate hook.
// The lifetime on this struct needs to ensure it cannot outlive the callback.
db: *mut sqlite3,
old_row_id: i64,
new_row_id: i64,
}

impl<'a> PreupdateHookResult<'a> {
/// Gets the amount of columns in the row being inserted, deleted, or updated.
pub fn get_column_count(&self) -> i32 {
unsafe { sqlite3_preupdate_count(self.db) }
}

/// Gets the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { sqlite3_preupdate_depth(self.db) }
}

/// Gets the row id of the row being updated/deleted.
/// Returns an error if called from an insert operation.
pub fn get_old_row_id(&self) -> Result<i64, PreupdateError> {
if self.operation == SqliteOperation::Insert {
return Err(PreupdateError::InvalidOperation);
}
Ok(self.old_row_id)
}

/// Gets the row id of the row being inserted/updated.
/// Returns an error if called from a delete operation.
pub fn get_new_row_id(&self) -> Result<i64, PreupdateError> {
if self.operation == SqliteOperation::Delete {
return Err(PreupdateError::InvalidOperation);
}
Ok(self.new_row_id)
}

/// Gets the value of the row being updated/deleted at the specified index.
/// Returns an error if called from an insert operation or the index is out of bounds.
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValue, PreupdateError> {
if self.operation == SqliteOperation::Insert {
return Err(PreupdateError::InvalidOperation);
}
self.validate_column_index(i)?;

let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
self.get_value(ret, p_value)
}
}

/// Gets the value of the row being updated/deleted at the specified index.
/// Returns an error if called from a delete operation or the index is out of bounds.
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValue, PreupdateError> {
if self.operation == SqliteOperation::Delete {
return Err(PreupdateError::InvalidOperation);
}
self.validate_column_index(i)?;

let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
self.get_value(ret, p_value)
}
}

fn validate_column_index(&self, i: i32) -> Result<(), PreupdateError> {
if i < 0 || i >= self.get_column_count() {
return Err(PreupdateError::ColumnIndexOutOfBounds(i));
}
Ok(())
}

unsafe fn get_value(
&self,
ret: i32,
p_value: *mut sqlite3_value,
) -> Result<SqliteValue, PreupdateError> {
if ret != SQLITE_OK {
return Err(PreupdateError::Database(SqliteError::new(self.db)));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
}
}

pub(super) extern "C" fn preupdate_hook<F>(
callback: *mut c_void,
db: *mut sqlite3,
op_code: c_int,
database: *const c_char,
table: *const c_char,
old_row_id: i64,
new_row_id: i64,
) where
F: FnMut(PreupdateHookResult) + Send + 'static,
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();

(*callback)(PreupdateHookResult {
operation,
database,
table,
old_row_id,
new_row_id,
db,
})
});
}
}
Loading
Loading