From 9f203fc8b2553db2e66b41352367984ed9abcdd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BC=8A=E6=AC=A7?= Date: Tue, 20 Aug 2024 17:15:48 +0800 Subject: [PATCH] feat: Add support for async proxy connection. (#2278) * Try to attach async to proxy trait first. * Update proxy connection to support async. * Add example. * Try to fix CI. * Remove CI for cloudflare worker example at this moment... * Improve SQL serializer --- .../.gitignore | 5 + .../Cargo.toml | 52 +++++ .../proxy_cloudflare_worker_example/README.md | 13 ++ .../Wrangler.toml | 12 + .../src/entity.rs | 17 ++ .../src/lib.rs | 16 ++ .../src/orm.rs | 218 ++++++++++++++++++ .../src/route.rs | 90 ++++++++ examples/proxy_gluesql_example/Cargo.toml | 6 +- examples/proxy_gluesql_example/src/main.rs | 9 +- src/database/db_connection.rs | 10 +- src/database/mod.rs | 2 +- src/database/proxy.rs | 38 ++- src/database/stream/query.rs | 5 +- src/database/stream/transaction.rs | 5 +- src/driver/proxy.rs | 64 ++--- 16 files changed, 474 insertions(+), 88 deletions(-) create mode 100644 examples/proxy_cloudflare_worker_example/.gitignore create mode 100644 examples/proxy_cloudflare_worker_example/Cargo.toml create mode 100644 examples/proxy_cloudflare_worker_example/README.md create mode 100644 examples/proxy_cloudflare_worker_example/Wrangler.toml create mode 100644 examples/proxy_cloudflare_worker_example/src/entity.rs create mode 100644 examples/proxy_cloudflare_worker_example/src/lib.rs create mode 100644 examples/proxy_cloudflare_worker_example/src/orm.rs create mode 100644 examples/proxy_cloudflare_worker_example/src/route.rs diff --git a/examples/proxy_cloudflare_worker_example/.gitignore b/examples/proxy_cloudflare_worker_example/.gitignore new file mode 100644 index 000000000..3a951ebee --- /dev/null +++ b/examples/proxy_cloudflare_worker_example/.gitignore @@ -0,0 +1,5 @@ +target +node_modules +.wrangler +build +dist diff --git a/examples/proxy_cloudflare_worker_example/Cargo.toml b/examples/proxy_cloudflare_worker_example/Cargo.toml new file mode 100644 index 000000000..404d87cf4 --- /dev/null +++ b/examples/proxy_cloudflare_worker_example/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "sea-orm-proxy-cloudflare-worker-example" +version = "0.1.0" +authors = ["Langyo "] +edition = "2021" +publish = false + +[workspace] + +[package.metadata.release] +release = false + +# https://github.com/rustwasm/wasm-pack/issues/1247 +[package.metadata.wasm-pack.profile.release] +wasm-opt = false + +[lib] +crate-type = ["cdylib"] + +[dependencies] +anyhow = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +once_cell = "1" +async-trait = "0.1" + +worker = { version = "0.3.0", features = ['http', 'axum', "d1"] } +worker-macros = { version = "0.3.0", features = ['http'] } +axum = { version = "0.7", default-features = false, features = ["macros"] } +tower-service = "0.3.2" + +chrono = "0.4" +uuid = { version = "1", features = ["v4"] } + +console_error_panic_hook = { version = "0.1" } +wasm-bindgen = "0.2.92" +wasm-bindgen-futures = { version = "0.4" } +gloo = "0.11" +oneshot = "0.1" + +sea-orm = { path = "../../", default-features = false, features = [ + "macros", + "proxy", + "with-uuid", + "with-chrono", + "with-json", + "debug-print", +] } + +[patch.crates-io] +# https://github.com/cloudflare/workers-rs/pull/591 +worker = { git = "https://github.com/cloudflare/workers-rs.git", rev = "ff2e6a0fd58b7e7b4b7651aba46e04067597eb03" } diff --git a/examples/proxy_cloudflare_worker_example/README.md 
b/examples/proxy_cloudflare_worker_example/README.md
new file mode 100644
index 000000000..87632c9fe
--- /dev/null
+++ b/examples/proxy_cloudflare_worker_example/README.md
@@ -0,0 +1,13 @@
+# SeaORM Proxy Demo for Cloudflare Workers
+
+This is a simple Cloudflare Worker written in Rust. It uses `sea-orm` to interact with a SQLite database stored in Cloudflare D1, and `axum` as the server framework.
+
+It is inspired by the [Cloudflare Workers Demo with Rust](https://github.com/logankeenan/full-stack-rust-cloudflare-axum).
+
+## Run
+
+Make sure you have `npm` and `cargo` installed, and that your `nodejs` and `rust` toolchains are up to date.
+
+```bash
+npx wrangler dev
+```
diff --git a/examples/proxy_cloudflare_worker_example/Wrangler.toml b/examples/proxy_cloudflare_worker_example/Wrangler.toml
new file mode 100644
index 000000000..16b0d9d43
--- /dev/null
+++ b/examples/proxy_cloudflare_worker_example/Wrangler.toml
@@ -0,0 +1,12 @@
+name = "axum"
+main = "build/worker/shim.mjs"
+compatibility_date = "2024-07-08"
+
+[[d1_databases]]
+binding = "test-d1"
+database_name = "axumtest"
+# Change this if you want to use your own database
+database_id = "00000000-0000-0000-0000-000000000000"
+
+[build]
+command = "cargo install -q worker-build && worker-build --release"
diff --git a/examples/proxy_cloudflare_worker_example/src/entity.rs b/examples/proxy_cloudflare_worker_example/src/entity.rs
new file mode 100644
index 000000000..868846046
--- /dev/null
+++ b/examples/proxy_cloudflare_worker_example/src/entity.rs
@@ -0,0 +1,17 @@
+use sea_orm::entity::prelude::*;
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
+#[sea_orm(table_name = "posts")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i64,
+
+    pub title: String,
+    pub text: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
diff --git a/examples/proxy_cloudflare_worker_example/src/lib.rs b/examples/proxy_cloudflare_worker_example/src/lib.rs
new file mode 100644
index 000000000..16095ed66
--- /dev/null
+++ b/examples/proxy_cloudflare_worker_example/src/lib.rs
@@ -0,0 +1,16 @@
+use anyhow::Result;
+use axum::{body::Body, response::Response};
+use tower_service::Service;
+use worker::{event, Context, Env, HttpRequest};
+
+pub(crate) mod entity;
+pub(crate) mod orm;
+pub(crate) mod route;
+
+// https://developers.cloudflare.com/workers/languages/rust
+#[event(fetch)]
+async fn fetch(req: HttpRequest, env: Env, _ctx: Context) -> Result<Response<Body>> {
+    console_error_panic_hook::set_once();
+
+    Ok(route::router(env).call(req).await?)
+} diff --git a/examples/proxy_cloudflare_worker_example/src/orm.rs b/examples/proxy_cloudflare_worker_example/src/orm.rs new file mode 100644 index 000000000..d6cd65a0a --- /dev/null +++ b/examples/proxy_cloudflare_worker_example/src/orm.rs @@ -0,0 +1,218 @@ +use anyhow::{anyhow, Context, Result}; +use std::{collections::BTreeMap, sync::Arc}; +use wasm_bindgen::JsValue; + +use sea_orm::{ + ConnectionTrait, Database, DatabaseConnection, DbBackend, DbErr, ProxyDatabaseTrait, + ProxyExecResult, ProxyRow, RuntimeErr, Schema, Statement, Value, Values, +}; +use worker::{console_log, Env}; + +struct ProxyDb { + env: Arc, +} + +impl std::fmt::Debug for ProxyDb { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProxyDb").finish() + } +} + +impl ProxyDb { + async fn do_query(env: Arc, statement: Statement) -> Result> { + let sql = statement.sql.clone(); + let values = match statement.values { + Some(Values(values)) => values + .iter() + .map(|val| match &val { + Value::BigInt(Some(val)) => JsValue::from(val.to_string()), + Value::BigUnsigned(Some(val)) => JsValue::from(val.to_string()), + Value::Int(Some(val)) => JsValue::from(*val), + Value::Unsigned(Some(val)) => JsValue::from(*val), + Value::SmallInt(Some(val)) => JsValue::from(*val), + Value::SmallUnsigned(Some(val)) => JsValue::from(*val), + Value::TinyInt(Some(val)) => JsValue::from(*val), + Value::TinyUnsigned(Some(val)) => JsValue::from(*val), + + Value::Float(Some(val)) => JsValue::from_f64(*val as f64), + Value::Double(Some(val)) => JsValue::from_f64(*val), + + Value::Bool(Some(val)) => JsValue::from(*val), + Value::Bytes(Some(val)) => JsValue::from(format!( + "X'{}'", + val.iter() + .map(|byte| format!("{:02x}", byte)) + .collect::() + )), + Value::Char(Some(val)) => JsValue::from(val.to_string()), + Value::Json(Some(val)) => JsValue::from(val.to_string()), + Value::String(Some(val)) => JsValue::from(val.to_string()), + + Value::ChronoDate(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTime(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTimeLocal(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTimeUtc(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTimeWithTimeZone(Some(val)) => JsValue::from(val.to_string()), + + _ => JsValue::NULL, + }) + .collect(), + None => Vec::new(), + }; + + console_log!("SQL query values: {:?}", values); + let ret = env.d1("test-d1")?.prepare(sql).bind(&values)?.all().await?; + if let Some(message) = ret.error() { + return Err(anyhow!(message.to_string())); + } + + let ret = ret.results::()?; + let ret = ret + .iter() + .map(|row| { + let mut values = BTreeMap::new(); + for (key, value) in row.as_object().unwrap() { + values.insert( + key.clone(), + match &value { + serde_json::Value::Bool(val) => Value::Bool(Some(*val)), + serde_json::Value::Number(val) => { + if val.is_i64() { + Value::BigInt(Some(val.as_i64().unwrap())) + } else if val.is_u64() { + Value::BigUnsigned(Some(val.as_u64().unwrap())) + } else { + Value::Double(Some(val.as_f64().unwrap())) + } + } + serde_json::Value::String(val) => { + Value::String(Some(Box::new(val.clone()))) + } + _ => unreachable!("Unsupported JSON value"), + }, + ); + } + ProxyRow { values } + }) + .collect(); + console_log!("SQL query result: {:?}", ret); + + Ok(ret) + } + + async fn do_execute(env: Arc, statement: Statement) -> Result { + let sql = statement.sql.clone(); + let values = match statement.values { + Some(Values(values)) => values + .iter() + .map(|val| 
match &val { + Value::BigInt(Some(val)) => JsValue::from(val.to_string()), + Value::BigUnsigned(Some(val)) => JsValue::from(val.to_string()), + Value::Int(Some(val)) => JsValue::from(*val), + Value::Unsigned(Some(val)) => JsValue::from(*val), + Value::SmallInt(Some(val)) => JsValue::from(*val), + Value::SmallUnsigned(Some(val)) => JsValue::from(*val), + Value::TinyInt(Some(val)) => JsValue::from(*val), + Value::TinyUnsigned(Some(val)) => JsValue::from(*val), + + Value::Float(Some(val)) => JsValue::from_f64(*val as f64), + Value::Double(Some(val)) => JsValue::from_f64(*val), + + Value::Bool(Some(val)) => JsValue::from(*val), + Value::Bytes(Some(val)) => JsValue::from(format!( + "X'{}'", + val.iter() + .map(|byte| format!("{:02x}", byte)) + .collect::() + )), + Value::Char(Some(val)) => JsValue::from(val.to_string()), + Value::Json(Some(val)) => JsValue::from(val.to_string()), + Value::String(Some(val)) => JsValue::from(val.to_string()), + + Value::ChronoDate(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTime(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTimeLocal(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTimeUtc(Some(val)) => JsValue::from(val.to_string()), + Value::ChronoDateTimeWithTimeZone(Some(val)) => JsValue::from(val.to_string()), + + _ => JsValue::NULL, + }) + .collect(), + None => Vec::new(), + }; + + let ret = env + .d1("test-d1")? + .prepare(sql) + .bind(&values)? + .run() + .await? + .meta()?; + console_log!("SQL execute result: {:?}", ret); + + let last_insert_id = ret + .as_ref() + .map(|meta| meta.last_row_id.unwrap_or(0)) + .unwrap_or(0) as u64; + let rows_affected = ret + .as_ref() + .map(|meta| meta.rows_written.unwrap_or(0)) + .unwrap_or(0) as u64; + + Ok(ProxyExecResult { + last_insert_id, + rows_affected, + }) + } +} + +#[async_trait::async_trait] +impl ProxyDatabaseTrait for ProxyDb { + async fn query(&self, statement: Statement) -> Result, DbErr> { + console_log!("SQL query: {:?}", statement); + + let env = self.env.clone(); + let (tx, rx) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let ret = Self::do_query(env, statement).await; + tx.send(ret).unwrap(); + }); + + let ret = rx.await.unwrap(); + ret.map_err(|err| DbErr::Conn(RuntimeErr::Internal(err.to_string()))) + } + + async fn execute(&self, statement: Statement) -> Result { + console_log!("SQL execute: {:?}", statement); + + let env = self.env.clone(); + let (tx, rx) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let ret = Self::do_execute(env, statement).await; + tx.send(ret).unwrap(); + }); + + let ret = rx.await.unwrap(); + ret.map_err(|err| DbErr::Conn(RuntimeErr::Internal(err.to_string()))) + } +} + +pub async fn init_db(env: Arc) -> Result { + let db = Database::connect_proxy(DbBackend::Sqlite, Arc::new(Box::new(ProxyDb { env }))) + .await + .context("Failed to connect to database")?; + let builder = db.get_database_backend(); + + console_log!("Connected to database"); + + db.execute( + builder.build( + Schema::new(builder) + .create_table_from_entity(crate::entity::Entity) + .if_not_exists(), + ), + ) + .await?; + + Ok(db) +} diff --git a/examples/proxy_cloudflare_worker_example/src/route.rs b/examples/proxy_cloudflare_worker_example/src/route.rs new file mode 100644 index 000000000..f96831c29 --- /dev/null +++ b/examples/proxy_cloudflare_worker_example/src/route.rs @@ -0,0 +1,90 @@ +use anyhow::Result; +use std::sync::Arc; + +use axum::{extract::State, http::StatusCode, 
response::IntoResponse, routing::get, Router}; +use worker::{console_error, console_log, Env}; + +use sea_orm::{ + ActiveModelTrait, + ActiveValue::{NotSet, Set}, + EntityTrait, +}; + +#[derive(Clone)] +struct CFEnv { + pub env: Arc, +} + +unsafe impl Send for CFEnv {} +unsafe impl Sync for CFEnv {} + +pub fn router(env: Env) -> Router { + let state = CFEnv { env: Arc::new(env) }; + + Router::new() + .route("/", get(handler_get)) + .route("/generate", get(handler_generate)) + .with_state(state) +} + +async fn handler_get( + State(state): State, +) -> Result { + let env = state.env.clone(); + let db = crate::orm::init_db(env).await.map_err(|err| { + console_log!("Failed to connect to database: {:?}", err); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to connect to database".to_string(), + ) + })?; + + let ret = crate::entity::Entity::find() + .all(&db) + .await + .map_err(|err| { + console_log!("Failed to query database: {:?}", err); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to query database".to_string(), + ) + })?; + let ret = serde_json::to_string(&ret).map_err(|err| { + console_error!("Failed to serialize response: {:?}", err); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to serialize response".to_string(), + ) + })?; + + Ok(ret.into_response()) +} + +async fn handler_generate( + State(state): State, +) -> Result { + let env = state.env.clone(); + let db = crate::orm::init_db(env).await.map_err(|err| { + console_log!("Failed to connect to database: {:?}", err); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to connect to database".to_string(), + ) + })?; + + let ret = crate::entity::ActiveModel { + id: NotSet, + title: Set(chrono::Utc::now().to_rfc3339()), + text: Set(uuid::Uuid::new_v4().to_string()), + }; + + let ret = ret.insert(&db).await.map_err(|err| { + console_log!("Failed to insert into database: {:?}", err); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to insert into database".to_string(), + ) + })?; + + Ok(format!("Inserted: {:?}", ret).into_response()) +} diff --git a/examples/proxy_gluesql_example/Cargo.toml b/examples/proxy_gluesql_example/Cargo.toml index bc495135a..9529d69a6 100644 --- a/examples/proxy_gluesql_example/Cargo.toml +++ b/examples/proxy_gluesql_example/Cargo.toml @@ -14,12 +14,10 @@ serde = { version = "1" } futures = { version = "0.3" } async-stream = { version = "0.3" } futures-util = { version = "0.3" } +async-trait = { version = "0.1" } sqlparser = "0.40" -sea-orm = { path = "../../", features = [ - "proxy", - "debug-print", -] } +sea-orm = { path = "../../", features = ["proxy", "debug-print"] } gluesql = { version = "0.15", default-features = false, features = [ "memory-storage", ] } diff --git a/examples/proxy_gluesql_example/src/main.rs b/examples/proxy_gluesql_example/src/main.rs index 17c942d71..0bbe7eee3 100644 --- a/examples/proxy_gluesql_example/src/main.rs +++ b/examples/proxy_gluesql_example/src/main.rs @@ -27,8 +27,9 @@ impl std::fmt::Debug for ProxyDb { } } +#[async_trait::async_trait] impl ProxyDatabaseTrait for ProxyDb { - fn query(&self, statement: Statement) -> Result, DbErr> { + async fn query(&self, statement: Statement) -> Result, DbErr> { println!("SQL query: {:?}", statement); let sql = statement.sql.clone(); @@ -64,7 +65,7 @@ impl ProxyDatabaseTrait for ProxyDb { Ok(ret) } - fn execute(&self, statement: Statement) -> Result { + async fn execute(&self, statement: Statement) -> Result { let sql = if let Some(values) = statement.values { // Replace all the '?' 
with the statement values
             use sqlparser::ast::{Expr, Value};
@@ -149,9 +150,9 @@ async fn main() {
 
     let db = Database::connect_proxy(
         DbBackend::Sqlite,
-        Arc::new(Mutex::new(Box::new(ProxyDb {
+        Arc::new(Box::new(ProxyDb {
             mem: Mutex::new(glue),
-        }))),
+        })),
     )
     .await
     .unwrap();
diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs
index 0981c3ca6..a06381532 100644
--- a/src/database/db_connection.rs
+++ b/src/database/db_connection.rs
@@ -133,7 +133,7 @@ impl ConnectionTrait for DatabaseConnection {
             #[cfg(feature = "mock")]
             DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt),
             #[cfg(feature = "proxy")]
-            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.execute(stmt),
+            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.execute(stmt).await,
             DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
         }
     }
@@ -162,7 +162,7 @@ impl ConnectionTrait for DatabaseConnection {
             DatabaseConnection::ProxyDatabaseConnection(conn) => {
                 let db_backend = conn.get_database_backend();
                 let stmt = Statement::from_string(db_backend, sql);
-                conn.execute(stmt)
+                conn.execute(stmt).await
             }
             DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
         }
     }
@@ -181,7 +181,7 @@ impl ConnectionTrait for DatabaseConnection {
             #[cfg(feature = "mock")]
             DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt),
             #[cfg(feature = "proxy")]
-            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.query_one(stmt),
+            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.query_one(stmt).await,
             DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
         }
     }
@@ -199,7 +199,7 @@ impl ConnectionTrait for DatabaseConnection {
             #[cfg(feature = "mock")]
             DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt),
             #[cfg(feature = "proxy")]
-            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.query_all(stmt),
+            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.query_all(stmt).await,
             DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
         }
     }
@@ -470,7 +470,7 @@ impl DatabaseConnection {
             #[cfg(feature = "mock")]
             DatabaseConnection::MockDatabaseConnection(conn) => conn.ping(),
             #[cfg(feature = "proxy")]
-            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.ping(),
+            DatabaseConnection::ProxyDatabaseConnection(conn) => conn.ping().await,
             DatabaseConnection::Disconnected => Err(conn_err("Disconnected")),
         }
     }
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 67a8d7279..2f6d535fb 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -102,7 +102,7 @@ impl Database {
     #[instrument(level = "trace", skip(proxy_func_arc))]
     pub async fn connect_proxy(
         db_type: DbBackend,
-        proxy_func_arc: std::sync::Arc<Mutex<Box<dyn ProxyDatabaseTrait>>>,
+        proxy_func_arc: std::sync::Arc<Box<dyn ProxyDatabaseTrait>>,
     ) -> Result<DatabaseConnection, DbErr> {
         match db_type {
             DbBackend::MySql => {
diff --git a/src/database/proxy.rs b/src/database/proxy.rs
index 9bd1228ea..3d93d2983 100644
--- a/src/database/proxy.rs
+++ b/src/database/proxy.rs
@@ -4,24 +4,25 @@ use sea_query::{Value, ValueType};
 use std::{collections::BTreeMap, fmt::Debug};
 
 /// Defines the [ProxyDatabaseTrait] to save the functions
+#[async_trait::async_trait]
 pub trait ProxyDatabaseTrait: Send + Sync + std::fmt::Debug {
     /// Execute a query in the [ProxyDatabase], and return the query results
-    fn query(&self, statement: Statement) -> Result<Vec<ProxyRow>, DbErr>;
+    async fn query(&self, statement: Statement) -> Result<Vec<ProxyRow>, DbErr>;
 
     /// Execute a command in the [ProxyDatabase], and report the number of rows affected
-    fn execute(&self, statement: Statement) -> Result<ProxyExecResult, DbErr>;
+    async fn execute(&self, statement: Statement) -> Result<ProxyExecResult, DbErr>;
 
     /// Begin a transaction in the [ProxyDatabase]
-    fn begin(&self) {}
+    async fn begin(&self) {}
 
     /// Commit a transaction in the [ProxyDatabase]
-    fn commit(&self) {}
+    async fn commit(&self) {}
 
     /// Rollback a transaction in the [ProxyDatabase]
-    fn rollback(&self) {}
+    async fn rollback(&self) {}
 
     /// Ping the [ProxyDatabase], it should return an error if the database is not available
-    fn ping(&self) -> Result<(), DbErr> {
+    async fn ping(&self) -> Result<(), DbErr> {
         Ok(())
     }
 }
@@ -207,12 +208,12 @@ mod tests {
     struct ProxyDb {}
 
     impl ProxyDatabaseTrait for ProxyDb {
-        fn query(&self, statement: Statement) -> Result<Vec<ProxyRow>, DbErr> {
+        async fn query(&self, statement: Statement) -> Result<Vec<ProxyRow>, DbErr> {
             println!("SQL query: {}", statement.sql);
             Ok(vec![].into())
         }
 
-        fn execute(&self, statement: Statement) -> Result<ProxyExecResult, DbErr> {
+        async fn execute(&self, statement: Statement) -> Result<ProxyExecResult, DbErr> {
             println!("SQL execute: {}", statement.sql);
             Ok(ProxyExecResult {
                 last_insert_id: 1,
@@ -223,28 +224,25 @@ mod tests {
 
     #[smol_potat::test]
     async fn create_proxy_conn() {
-        let _db =
-            Database::connect_proxy(DbBackend::MySql, Arc::new(Mutex::new(Box::new(ProxyDb {}))))
-                .await
-                .unwrap();
+        let _db = Database::connect_proxy(DbBackend::MySql, Arc::new(Box::new(ProxyDb {})))
+            .await
+            .unwrap();
     }
 
     #[smol_potat::test]
     async fn select_rows() {
-        let db =
-            Database::connect_proxy(DbBackend::MySql, Arc::new(Mutex::new(Box::new(ProxyDb {}))))
-                .await
-                .unwrap();
+        let db = Database::connect_proxy(DbBackend::MySql, Arc::new(Box::new(ProxyDb {})))
+            .await
+            .unwrap();
 
         let _ = cake::Entity::find().all(&db).await;
     }
 
     #[smol_potat::test]
     async fn insert_one_row() {
-        let db =
-            Database::connect_proxy(DbBackend::MySql, Arc::new(Mutex::new(Box::new(ProxyDb {}))))
-                .await
-                .unwrap();
+        let db = Database::connect_proxy(DbBackend::MySql, Arc::new(Box::new(ProxyDb {})))
+            .await
+            .unwrap();
 
         let item = cake::ActiveModel {
             id: NotSet,
diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs
index 3c77b7b0d..a8670fa3f 100644
--- a/src/database/stream/query.rs
+++ b/src/database/stream/query.rs
@@ -86,10 +86,7 @@ impl QueryStream {
             }
             #[cfg(feature = "proxy")]
             InnerConnection::Proxy(c) => {
-                let _start = _metric_callback.is_some().then(std::time::SystemTime::now);
-                let stream = c.fetch(stmt);
-                let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
-                MetricStream::new(_metric_callback, stmt, elapsed, stream)
+                todo!("Proxy connection is not supported")
             }
             #[allow(unreachable_patterns)]
             _ => unreachable!(),
diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs
index 3e9285d15..0040455a3 100644
--- a/src/database/stream/transaction.rs
+++ b/src/database/stream/transaction.rs
@@ -88,10 +88,7 @@ impl<'a> TransactionStream<'a> {
             }
             #[cfg(feature = "proxy")]
             InnerConnection::Proxy(c) => {
-                let _start = _metric_callback.is_some().then(std::time::SystemTime::now);
-                let stream = c.fetch(stmt);
-                let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
-                MetricStream::new(_metric_callback, stmt, elapsed, stream)
+                todo!("Proxy connection is not supported")
             }
             #[allow(unreachable_patterns)]
             _ => unreachable!(),
diff --git a/src/driver/proxy.rs b/src/driver/proxy.rs
index d46335e74..cf852895f 100644
--- a/src/driver/proxy.rs
+++ b/src/driver/proxy.rs
@@ -2,12 +2,7 @@ use crate::{
     debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, ProxyDatabaseTrait,
     QueryResult, Statement,
 };
-use futures::Stream;
-use std::{
-    fmt::Debug,
-    pin::Pin,
-    sync::{Arc, Mutex},
-};
+use std::{fmt::Debug, sync::Arc};
 use tracing::instrument;
 
 /// Defines a database driver for the [ProxyDatabase]
@@ -18,7 +13,7 @@ pub struct ProxyDatabaseConnector;
 #[derive(Debug)]
 pub struct ProxyDatabaseConnection {
     db_backend: DbBackend,
-    proxy: Arc<Mutex<Box<dyn ProxyDatabaseTrait>>>,
+    proxy: Arc<Box<dyn ProxyDatabaseTrait>>,
 }
 
 impl ProxyDatabaseConnector {
@@ -34,7 +29,7 @@ impl ProxyDatabaseConnector {
     #[instrument(level = "trace")]
     pub fn connect(
         db_type: DbBackend,
-        func: Arc<Mutex<Box<dyn ProxyDatabaseTrait>>>,
+        func: Arc<Box<dyn ProxyDatabaseTrait>>,
     ) -> Result<DatabaseConnection, DbErr> {
         Ok(DatabaseConnection::ProxyDatabaseConnection(Arc::new(
             ProxyDatabaseConnection::new(db_type, func),
@@ -44,7 +39,7 @@ impl ProxyDatabaseConnector {
 
 impl ProxyDatabaseConnection {
     /// Create a connection to the [ProxyDatabase]
-    pub fn new(db_backend: DbBackend, funcs: Arc<Mutex<Box<dyn ProxyDatabaseTrait>>>) -> Self {
+    pub fn new(db_backend: DbBackend, funcs: Arc<Box<dyn ProxyDatabaseTrait>>) -> Self {
         Self {
             db_backend,
             proxy: funcs.to_owned(),
@@ -58,21 +53,16 @@ impl ProxyDatabaseConnection {
 
     /// Execute the SQL statement in the [ProxyDatabase]
     #[instrument(level = "trace")]
-    pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
+    pub async fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
         debug_print!("{}", statement);
-        Ok(self
-            .proxy
-            .lock()
-            .map_err(exec_err)?
-            .execute(statement)?
-            .into())
+        Ok(self.proxy.execute(statement).await?.into())
     }
 
     /// Return one [QueryResult] if the query was successful
     #[instrument(level = "trace")]
-    pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
+    pub async fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
         debug_print!("{}", statement);
-        let result = self.proxy.lock().map_err(query_err)?.query(statement)?;
+        let result = self.proxy.query(statement).await?;
 
         if let Some(first) = result.first() {
             return Ok(Some(QueryResult {
@@ -85,9 +75,9 @@ impl ProxyDatabaseConnection {
     /// Return all [QueryResult]s if the query was successful
     #[instrument(level = "trace")]
-    pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
+    pub async fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
         debug_print!("{}", statement);
-        let result = self.proxy.lock().map_err(query_err)?.query(statement)?;
+        let result = self.proxy.query(statement).await?;
 
         Ok(result
             .into_iter()
             .map(|row| QueryResult {
@@ -97,45 +87,27 @@ impl ProxyDatabaseConnection {
             .collect())
     }
 
-    /// Return [QueryResult]s from a multi-query operation
-    #[instrument(level = "trace")]
-    pub fn fetch(
-        &self,
-        statement: &Statement,
-    ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
-        match self.query_all(statement.clone()) {
-            Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))),
-            Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())),
-        }
-    }
-
     /// Create a statement block of SQL statements that execute together.
     #[instrument(level = "trace")]
-    pub fn begin(&self) {
-        self.proxy.lock().expect("Failed to acquire mocker").begin()
+    pub async fn begin(&self) {
+        self.proxy.begin().await
     }
 
     /// Commit a transaction atomically to the database
     #[instrument(level = "trace")]
-    pub fn commit(&self) {
-        self.proxy
-            .lock()
-            .expect("Failed to acquire mocker")
-            .commit()
+    pub async fn commit(&self) {
+        self.proxy.commit().await
     }
 
     /// Roll back a faulty transaction
     #[instrument(level = "trace")]
-    pub fn rollback(&self) {
-        self.proxy
-            .lock()
-            .expect("Failed to acquire mocker")
-            .rollback()
+    pub async fn rollback(&self) {
+        self.proxy.rollback().await
     }
 
     /// Checks if a connection to the database is still valid.
-    pub fn ping(&self) -> Result<(), DbErr> {
-        self.proxy.lock().map_err(query_err)?.ping()
+    pub async fn ping(&self) -> Result<(), DbErr> {
+        self.proxy.ping().await
     }
 }
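
---

For reference, implementing the reworked trait downstream now looks roughly like the sketch below. It is based only on the async `ProxyDatabaseTrait` methods and the new `Arc<Box<dyn ProxyDatabaseTrait>>` signature of `Database::connect_proxy` introduced in this patch; the `MyProxy` type and `connect` helper are hypothetical placeholders, and `async-trait` is assumed as a dependency, as in the updated examples.

```rust
use std::sync::Arc;

use sea_orm::{
    Database, DatabaseConnection, DbBackend, DbErr, ProxyDatabaseTrait, ProxyExecResult,
    ProxyRow, Statement,
};

/// Hypothetical proxy backend, used only for illustration.
#[derive(Debug)]
struct MyProxy;

#[async_trait::async_trait]
impl ProxyDatabaseTrait for MyProxy {
    async fn query(&self, statement: Statement) -> Result<Vec<ProxyRow>, DbErr> {
        // An async backend (HTTP, D1, GlueSQL, ...) can now be awaited here directly.
        println!("SQL query: {}", statement.sql);
        Ok(Vec::new())
    }

    async fn execute(&self, statement: Statement) -> Result<ProxyExecResult, DbErr> {
        println!("SQL execute: {}", statement.sql);
        Ok(ProxyExecResult {
            last_insert_id: 1,
            rows_affected: 1,
        })
    }
}

async fn connect() -> Result<DatabaseConnection, DbErr> {
    // The proxy is now passed as `Arc<Box<dyn ProxyDatabaseTrait>>`,
    // without the `Mutex` wrapper the old synchronous API required.
    Database::connect_proxy(DbBackend::Sqlite, Arc::new(Box::new(MyProxy))).await
}
```

Because every trait method takes `&self` on a `Send + Sync` implementor, callers no longer lock the proxy themselves; an implementor that needs interior mutability keeps its own lock, as the GlueSQL example does with its `mem: Mutex::new(glue)` field.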