From 436694b3fb49eb4cc8bd09a8e44323ae746c4b50 Mon Sep 17 00:00:00 2001 From: Austin Schey Date: Tue, 3 Dec 2024 20:27:31 -0800 Subject: [PATCH] feat: add preupdate hook --- Cargo.lock | 12 +- sqlx-sqlite/Cargo.toml | 3 +- sqlx-sqlite/src/connection/establish.rs | 1 + sqlx-sqlite/src/connection/mod.rs | 202 ++++++++++++++++++++- sqlx-sqlite/src/lib.rs | 5 +- tests/sqlite/sqlite.rs | 222 +++++++++++++++++++++++- 6 files changed, 430 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2da47afa54..36cb94422e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1177,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -1914,7 +1914,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -3986,7 +3986,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -4815,7 +4815,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4856,7 +4856,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4899,5 +4899,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index 391bf4523c..2cc56740f2 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -58,7 +58,8 @@ default-features = false features = [ "pkg-config", "vcpkg", - "unlock_notify" + "unlock_notify", + "preupdate_hook" ] [dependencies.sqlx-core] diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 40f9b4c302..2fbe700724 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -296,6 +296,7 @@ impl EstablishParams { log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, + preupdate_hook_callback: None, commit_hook_callback: None, rollback_hook_callback: None, }) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a579b8a605..4375e239eb 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -11,8 +11,10 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_preupdate_count, sqlite3_preupdate_depth, + sqlite3_preupdate_hook, sqlite3_preupdate_new, sqlite3_preupdate_old, sqlite3_progress_handler, + sqlite3_rollback_hook, sqlite3_update_hook, sqlite3_value, sqlite3_value_type, SQLITE_DELETE, + SQLITE_INSERT, SQLITE_OK, SQLITE_UPDATE, }; pub(crate) use handle::ConnectionHandle; @@ -26,7 +28,8 @@ use crate::connection::establish::EstablishParams; use crate::connection::worker::ConnectionWorker; use crate::options::OptimizeOnClose; use crate::statement::VirtualStatement; -use crate::{Sqlite, SqliteConnectOptions}; +use crate::type_info::DataType; +use crate::{Sqlite, SqliteConnectOptions, SqliteError, SqliteTypeInfo, SqliteValue}; pub(crate) mod collation; pub(crate) mod describe; @@ -88,6 +91,14 @@ pub struct UpdateHookResult<'a> { pub table: &'a str, pub rowid: i64, } + +pub struct PreupdateHookResult<'a> { + pub operation: SqliteOperation, + pub database: &'a str, + pub table: &'a str, + pub case: PreupdateCase, +} + pub(crate) struct UpdateHookHandler(NonNull); unsafe impl Send for UpdateHookHandler {} @@ -97,6 +108,108 @@ unsafe impl Send for CommitHookHandler {} pub(crate) struct RollbackHookHandler(NonNull); unsafe impl Send for RollbackHookHandler {} +pub(crate) struct PreupdateHookHandler(NonNull); +unsafe impl Send for PreupdateHookHandler {} + +/// The possible cases for when a PreUpdate Hook gets triggered. Allows access to the relevant +/// functions for each case through the contained values. +pub enum PreupdateCase { + /// Pre-update hook was triggered by an insert. + Insert(PreupdateNewValueAccessor), + /// Pre-update hook was triggered by a delete. + Delete(PreupdateOldValueAccessor), + /// Pre-update hook was triggered by an update. + Update { + old_value_accessor: PreupdateOldValueAccessor, + new_value_accessor: PreupdateNewValueAccessor, + }, + /// This variant is not normally produced by SQLite. You may encounter it + /// if you're using a different version than what's supported by this library. + Unknown, +} + +/// An accessor for the new values of the row being inserted/updated during the preupdate callback. +#[derive(Debug)] +pub struct PreupdateNewValueAccessor { + db: *mut sqlite3, + new_row_id: i64, +} + +impl PreupdateNewValueAccessor { + /// Gets the amount of columns in the row being inserted/updated. + pub fn get_column_count(&self) -> i32 { + unsafe { sqlite3_preupdate_count(self.db) } + } + + /// Gets the depth of the query that triggered the preupdate hook. + /// Returns 0 if the preupdate callback was invoked as a result of + /// a direct insert, update, or delete operation; + /// 1 for inserts, updates, or deletes invoked by top-level triggers; + /// 2 for changes resulting from triggers called by top-level triggers; and so forth. + pub fn get_query_depth(&self) -> i32 { + unsafe { sqlite3_preupdate_depth(self.db) } + } + + /// Gets the row id of the row being inserted/updated. + pub fn get_new_row_id(&self) -> i64 { + self.new_row_id + } + + /// Gets the value of the row being updated/deleted at the specified index. + pub fn get_new_column_value(&self, i: i32) -> Result { + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_new(self.db, i, &mut p_value); + if ret != SQLITE_OK { + return Err(Error::Database(Box::new(SqliteError::new(self.db)))); + } + let data_type = DataType::from_code(sqlite3_value_type(p_value)); + Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + } + } +} + +/// An accessor for the old values of the row being deleted/updated during the preupdate callback. +#[derive(Debug)] +pub struct PreupdateOldValueAccessor { + db: *mut sqlite3, + old_row_id: i64, +} + +impl PreupdateOldValueAccessor { + /// Gets the amount of columns in the row being deleted/updated. + pub fn get_column_count(&self) -> i32 { + unsafe { sqlite3_preupdate_count(self.db) } + } + + /// Gets the depth of the query that triggered the preupdate hook. + /// Returns 0 if the preupdate callback was invoked as a result of + /// a direct insert, update, or delete operation; + /// 1 for inserts, updates, or deletes invoked by top-level triggers; + /// 2 for changes resulting from triggers called by top-level triggers; and so forth. + pub fn get_query_depth(&self) -> i32 { + unsafe { sqlite3_preupdate_depth(self.db) } + } + + /// Gets the row id of the row being updated/deleted. + pub fn get_old_row_id(&self) -> i64 { + self.old_row_id + } + + /// Gets the value of the row being updated/deleted at the specified index. + pub fn get_old_column_value(&self, i: i32) -> Result { + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_old(self.db, i, &mut p_value); + if ret != SQLITE_OK { + return Err(Error::Database(Box::new(SqliteError::new(self.db)))); + } + let data_type = DataType::from_code(sqlite3_value_type(p_value)); + Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + } + } +} + pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -113,6 +226,8 @@ pub(crate) struct ConnectionState { update_hook_callback: Option, + preupdate_hook_callback: Option, + commit_hook_callback: Option, rollback_hook_callback: Option, @@ -138,6 +253,15 @@ impl ConnectionState { } } + pub(crate) fn remove_preupdate_hook(&mut self) { + if let Some(mut handler) = self.preupdate_hook_callback.take() { + unsafe { + sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } + pub(crate) fn remove_commit_hook(&mut self) { if let Some(mut handler) = self.commit_hook_callback.take() { unsafe { @@ -312,6 +436,47 @@ extern "C" fn update_hook( } } +extern "C" fn preupdate_hook( + callback: *mut c_void, + db: *mut sqlite3, + op_code: c_int, + database: *const i8, + table: *const i8, + old_row_id: i64, + new_row_id: i64, +) where + F: FnMut(PreupdateHookResult), +{ + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + let operation: SqliteOperation = op_code.into(); + let database = CStr::from_ptr(database).to_str().unwrap_or_default(); + let table = CStr::from_ptr(table).to_str().unwrap_or_default(); + + let preupdate_case = match operation { + SqliteOperation::Insert => { + PreupdateCase::Insert(PreupdateNewValueAccessor { db, new_row_id }) + } + SqliteOperation::Delete => { + PreupdateCase::Delete(PreupdateOldValueAccessor { db, old_row_id }) + } + SqliteOperation::Update => PreupdateCase::Update { + old_value_accessor: PreupdateOldValueAccessor { db, old_row_id }, + new_value_accessor: PreupdateNewValueAccessor { db, new_row_id }, + }, + SqliteOperation::Unknown(_) => PreupdateCase::Unknown, + }; + (*callback)(PreupdateHookResult { + operation, + database, + table, + case: preupdate_case, + }) + }); + } +} + extern "C" fn commit_hook(callback: *mut c_void) -> c_int where F: FnMut() -> bool, @@ -476,6 +641,33 @@ impl LockedSqliteHandle<'_> { } } + /// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table. + /// At most one preupdate hook may be registered at a time on a single database connection. + /// + /// The preupdate hook only fires for changes to real database tables; + /// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1. + /// + /// See https://sqlite.org/c3ref/preupdate_count.html + pub fn set_preupdate_hook(&mut self, callback: F) + where + F: FnMut(PreupdateHookResult) + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_preupdate_hook(); + self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback)); + + sqlite3_preupdate_hook( + self.as_raw_handle().as_mut(), + Some(preupdate_hook::), + handler, + ); + } + } + /// Removes the progress handler on a database connection. The method does nothing if no handler was set. pub fn remove_progress_handler(&mut self) { self.guard.remove_progress_handler(); @@ -492,6 +684,10 @@ impl LockedSqliteHandle<'_> { pub fn remove_rollback_hook(&mut self) { self.guard.remove_rollback_hook(); } + + pub fn remove_preupdate_hook(&mut self) { + self.guard.remove_preupdate_hook(); + } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f8f5534879..4b52d1f516 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,7 +46,10 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; -pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; +pub use connection::{ + LockedSqliteHandle, PreupdateCase, PreupdateHookResult, PreupdateNewValueAccessor, + PreupdateOldValueAccessor, SqliteConnection, SqliteOperation, UpdateHookResult, +}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b733ccbb4c..f7d83ec921 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,11 +2,15 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; +use sqlx::Decode; +use sqlx::Value; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx_sqlite::PreupdateCase; use sqlx_test::new; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; #[sqlx_macros::test] @@ -798,7 +802,7 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow:: #[sqlx_macros::test] async fn test_query_with_update_hook() -> anyhow::Result<()> { let mut conn = new::().await?; - + static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); conn.lock_handle().await?.set_update_hook(move |result| { @@ -806,12 +810,14 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> { assert_eq!(result.operation, SqliteOperation::Insert); assert_eq!(result.database, "main"); assert_eq!(result.table, "tweet"); - assert_eq!(result.rowid, 2); + assert_eq!(result.rowid, 4); + CALLED.store(true, Ordering::Relaxed); }); let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )") .execute(&mut conn) .await?; + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -852,10 +858,11 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul #[sqlx_macros::test] async fn test_query_with_commit_hook() -> anyhow::Result<()> { let mut conn = new::().await?; - + static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); conn.lock_handle().await?.set_commit_hook(move || { + CALLED.store(true, Ordering::Relaxed); assert_eq!(state, "test"); false }); @@ -870,7 +877,7 @@ async fn test_query_with_commit_hook() -> anyhow::Result<()> { } _ => panic!("expected an error"), } - + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -916,8 +923,10 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> { // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); + static CALLED: AtomicBool = AtomicBool::new(false); conn.lock_handle().await?.set_rollback_hook(move || { assert_eq!(state, "test"); + CALLED.store(true, Ordering::Relaxed); }); let mut tx = conn.begin().await?; @@ -925,6 +934,7 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> { .execute(&mut *tx) .await?; tx.rollback().await?; + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -960,3 +970,207 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { + let mut conn = new::().await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Insert); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + if let PreupdateCase::Insert(accessor) = result.case { + assert_eq!(4, accessor.get_column_count()); + assert_eq!(2, accessor.get_new_row_id()); + assert_eq!(0, accessor.get_query_depth()); + assert_eq!( + 4, + >::decode( + accessor.get_new_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + "Hello, World", + >::decode( + accessor.get_new_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + // out of bounds access should return an error + assert!(accessor.get_new_column_value(4).is_err()); + } else { + panic!("wrong preupdate case"); + } + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { + let mut conn = new::().await?; + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )") + .execute(&mut conn) + .await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Delete); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + if let PreupdateCase::Delete(accessor) = result.case { + assert_eq!(4, accessor.get_column_count()); + assert_eq!(2, accessor.get_old_row_id()); + assert_eq!(0, accessor.get_query_depth()); + assert_eq!( + 5, + >::decode( + accessor.get_old_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + "Hello, World", + >::decode( + accessor.get_old_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + // out of bounds access should return an error + assert!(accessor.get_old_column_value(4).is_err()); + } else { + panic!("wrong preupdate case"); + } + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("DELETE FROM tweet WHERE id = 5") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { + let mut conn = new::().await?; + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )") + .execute(&mut conn) + .await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Update); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + if let PreupdateCase::Update { + old_value_accessor, + new_value_accessor, + } = result.case + { + assert_eq!(4, old_value_accessor.get_column_count()); + assert_eq!(4, new_value_accessor.get_column_count()); + + assert_eq!(3, old_value_accessor.get_old_row_id()); + assert_eq!(3, new_value_accessor.get_new_row_id()); + + assert_eq!(0, old_value_accessor.get_query_depth()); + assert_eq!(0, new_value_accessor.get_query_depth()); + + assert_eq!( + 6, + >::decode( + old_value_accessor.get_old_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + 6, + >::decode( + new_value_accessor.get_new_column_value(0).unwrap().as_ref(), + ) + .unwrap() + ); + + assert_eq!( + "Hello, World", + >::decode( + old_value_accessor.get_old_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + assert_eq!( + "Hello, World2", + >::decode( + new_value_accessor.get_new_column_value(1).unwrap().as_ref(), + ) + .unwrap() + ); + + // out of bounds access should return an error + assert!(old_value_accessor.get_old_column_value(4).is_err()); + assert!(new_value_accessor.get_new_column_value(4).is_err()); + } else { + panic!("wrong preupdate case"); + } + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_preupdate_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +}