Skip to content

Commit

Permalink
feat: add preupdate hook
Browse files Browse the repository at this point in the history
  • Loading branch information
aschey committed Dec 5, 2024
1 parent 42ce24d commit bcdb609
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 13 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ jobs:
- run: >
cargo clippy
--no-default-features
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
-- -D warnings
# Run beta for new warnings but don't break the build.
# Use a subdirectory of `target` to avoid clobbering the cache.
- run: >
cargo +beta clippy
--no-default-features
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--target-dir target/beta/
check-minimal-versions:
Expand Down Expand Up @@ -140,7 +140,7 @@ jobs:
- run: >
cargo test
--no-default-features
--features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }}
--features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,'}}_unstable-all-types,runtime-${{ matrix.runtime }}
--
--test-threads=1
env:
Expand Down
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ authors.workspace = true
repository.workspace = true

[package.metadata.docs.rs]
features = ["all-databases", "_unstable-all-types"]
features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"]
rustdoc-args = ["--cfg", "docsrs"]

[features]
Expand Down Expand Up @@ -108,6 +108,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"]
mysql = ["sqlx-mysql", "sqlx-macros?/mysql"]
sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"]
sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"]
sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"]

# types
json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"]
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ be removed in the future.
* May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended.
* Can increase build time due to the use of bindgen.

- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API.
* Exposed as a separate feature because it's generally not enabled by default.
* Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it.

- `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime.

- `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`.
Expand Down
2 changes: 2 additions & 0 deletions sqlx-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"]

regexp = ["dep:regex"]

preupdate-hook = ["libsqlite3-sys/preupdate_hook"]

bundled = ["libsqlite3-sys/bundled"]
unbundled = ["libsqlite3-sys/buildtime_bindgen"]

Expand Down
2 changes: 2 additions & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ impl EstablishParams {
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: None,
commit_hook_callback: None,
rollback_hook_callback: None,
})
Expand Down
221 changes: 221 additions & 0 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use libsqlite3_sys::{
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
};
#[cfg(feature = "preupdate-hook")]
pub use preupdate_hook::*;

pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
Expand Down Expand Up @@ -88,6 +90,7 @@ pub struct UpdateHookResult<'a> {
pub table: &'a str,
pub rowid: i64,
}

pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}

Expand All @@ -112,6 +115,8 @@ pub(crate) struct ConnectionState {
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHookHandler>,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,

commit_hook_callback: Option<CommitHookHandler>,

Expand Down Expand Up @@ -544,3 +549,219 @@ impl Statements {
self.temp = None;
}
}

#[cfg(feature = "preupdate-hook")]
mod preupdate_hook {
use super::ConnectionState;
use super::LockedSqliteHandle;
use super::SqliteOperation;
use crate::type_info::DataType;
use crate::{SqliteError, SqliteTypeInfo, SqliteValue};
use libsqlite3_sys::{
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_hook,
sqlite3_preupdate_new, sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK,
};
use sqlx_core::error::Error;
use std::ffi::CStr;
use std::fmt::Debug;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use std::ptr::NonNull;

pub struct PreupdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
pub case: PreupdateCase,
}

pub(crate) struct PreupdateHookHandler(
NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>,
);
unsafe impl Send for PreupdateHookHandler {}

/// The possible cases for when a PreUpdate Hook gets triggered. Allows access to the relevant
/// functions for each case through the contained values.
pub enum PreupdateCase {
/// Pre-update hook was triggered by an insert.
Insert(PreupdateNewValueAccessor),
/// Pre-update hook was triggered by a delete.
Delete(PreupdateOldValueAccessor),
/// Pre-update hook was triggered by an update.
Update {
old_value_accessor: PreupdateOldValueAccessor,
new_value_accessor: PreupdateNewValueAccessor,
},
/// This variant is not normally produced by SQLite. You may encounter it
/// if you're using a different version than what's supported by this library.
Unknown,
}

/// An accessor for the old values of the row being deleted/updated during the preupdate callback.
#[derive(Debug)]
pub struct PreupdateOldValueAccessor {
db: *mut sqlite3,
old_row_id: i64,
}

impl PreupdateOldValueAccessor {
/// Gets the amount of columns in the row being deleted/updated.
pub fn get_column_count(&self) -> i32 {
unsafe { sqlite3_preupdate_count(self.db) }
}

/// Gets the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { sqlite3_preupdate_depth(self.db) }
}

/// Gets the row id of the row being updated/deleted.
pub fn get_old_row_id(&self) -> i64 {
self.old_row_id
}

/// Gets the value of the row being updated/deleted at the specified index.
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValue, Error> {
let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
if ret != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(self.db))));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
}
}
}

/// An accessor for the new values of the row being inserted/updated during the preupdate callback.
#[derive(Debug)]
pub struct PreupdateNewValueAccessor {
db: *mut sqlite3,
new_row_id: i64,
}

impl PreupdateNewValueAccessor {
/// Gets the amount of columns in the row being inserted/updated.
pub fn get_column_count(&self) -> i32 {
unsafe { sqlite3_preupdate_count(self.db) }
}

/// Gets the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { sqlite3_preupdate_depth(self.db) }
}

/// Gets the row id of the row being inserted/updated.
pub fn get_new_row_id(&self) -> i64 {
self.new_row_id
}

/// Gets the value of the row being updated/deleted at the specified index.
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValue, Error> {
let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
if ret != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(self.db))));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
}
}
}

impl ConnectionState {
pub(crate) fn remove_preupdate_hook(&mut self) {
if let Some(mut handler) = self.preupdate_hook_callback.take() {
unsafe {
sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}
}

impl LockedSqliteHandle<'_> {
/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
/// At most one preupdate hook may be registered at a time on a single database connection.
///
/// The preupdate hook only fires for changes to real database tables;
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
///
/// See https://sqlite.org/c3ref/preupdate_count.html
pub fn set_preupdate_hook<F>(&mut self, callback: F)
where
F: FnMut(PreupdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_preupdate_hook();
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));

sqlite3_preupdate_hook(
self.as_raw_handle().as_mut(),
Some(preupdate_hook::<F>),
handler,
);
}
}

pub fn remove_preupdate_hook(&mut self) {
self.guard.remove_preupdate_hook();
}
}

extern "C" fn preupdate_hook<F>(
callback: *mut c_void,
db: *mut sqlite3,
op_code: c_int,
database: *const c_char,
table: *const c_char,
old_row_id: i64,
new_row_id: i64,
) where
F: FnMut(PreupdateHookResult),
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();

let preupdate_case = match operation {
SqliteOperation::Insert => {
PreupdateCase::Insert(PreupdateNewValueAccessor { db, new_row_id })
}
SqliteOperation::Delete => {
PreupdateCase::Delete(PreupdateOldValueAccessor { db, old_row_id })
}
SqliteOperation::Update => PreupdateCase::Update {
old_value_accessor: PreupdateOldValueAccessor { db, old_row_id },
new_value_accessor: PreupdateNewValueAccessor { db, new_row_id },
},
SqliteOperation::Unknown(_) => PreupdateCase::Unknown,
};
(*callback)(PreupdateHookResult {
operation,
database,
table,
case: preupdate_case,
})
});
}
}
}
4 changes: 4 additions & 0 deletions sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ use std::sync::atomic::AtomicBool;
pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use column::SqliteColumn;
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
#[cfg(feature = "preupdate-hook")]
pub use connection::{
PreupdateCase, PreupdateHookResult, PreupdateNewValueAccessor, PreupdateOldValueAccessor,
};
pub use database::Sqlite;
pub use error::SqliteError;
pub use options::{
Expand Down
Loading

0 comments on commit bcdb609

Please sign in to comment.