diff --git a/sea-orm-macros/src/derives/active_enum.rs b/sea-orm-macros/src/derives/active_enum.rs index 591da6f33..f163275f2 100644 --- a/sea-orm-macros/src/derives/active_enum.rs +++ b/sea-orm-macros/src/derives/active_enum.rs @@ -337,6 +337,16 @@ impl ActiveEnum { } } + #[automatically_derived] + impl sea_orm::TryGetableArray for #ident { + fn try_get_by(res: &sea_orm::QueryResult, index: I) -> std::result::Result, sea_orm::TryGetError> { + <::Value as sea_orm::ActiveEnumValue>::try_get_vec_by(res, index)? + .into_iter() + .map(|value| ::try_from_value(&value).map_err(Into::into)) + .collect() + } + } + #[automatically_derived] #[allow(clippy::from_over_into)] impl Into for #ident { diff --git a/sea-orm-macros/src/derives/try_getable_from_json.rs b/sea-orm-macros/src/derives/try_getable_from_json.rs index 8742d3c9e..b111b2497 100644 --- a/sea-orm-macros/src/derives/try_getable_from_json.rs +++ b/sea-orm-macros/src/derives/try_getable_from_json.rs @@ -2,6 +2,15 @@ use proc_macro2::{Ident, TokenStream}; use quote::quote; pub fn expand_derive_from_json_query_result(ident: Ident) -> syn::Result { + let impl_not_u8 = if cfg!(feature = "postgres-array") { + quote!( + #[automatically_derived] + impl sea_orm::sea_query::value::with_array::NotU8 for #ident {} + ) + } else { + quote!() + }; + Ok(quote!( #[automatically_derived] impl sea_orm::TryGetableFromJson for #ident {} @@ -43,5 +52,7 @@ pub fn expand_derive_from_json_query_result(ident: Ident) -> syn::Result + ValueType + Nullable + TryGetable; + /// Define the Rust type that each enum variant corresponds. + type Value: ActiveEnumValue; - /// Define the enum value in Vector type. - type ValueVec: IntoIterator; + /// This has no purpose. It will be removed in the next major version. + type ValueVec; /// Get the name of enum fn name() -> DynIden; @@ -144,19 +144,53 @@ pub trait ActiveEnum: Sized + Iterable { } } -impl TryGetable for Vec -where - T: ActiveEnum, - T::ValueVec: TryGetable, -{ - fn try_get_by(res: &QueryResult, index: I) -> Result { - ::try_get_by(res, index)? - .into_iter() - .map(|value| T::try_from_value(&value).map_err(Into::into)) - .collect() - } +/// The Rust Value backing ActiveEnums +pub trait ActiveEnumValue: Into + ValueType + Nullable + TryGetable { + /// For getting an array of enum. Postgres only + fn try_get_vec_by(res: &QueryResult, index: I) -> Result, TryGetError>; +} + +macro_rules! impl_active_enum_value { + ($type:ident) => { + impl ActiveEnumValue for $type { + fn try_get_vec_by( + _res: &QueryResult, + _index: I, + ) -> Result, TryGetError> { + panic!("Not supported by `postgres-array`") + } + } + }; } +macro_rules! impl_active_enum_value_with_pg_array { + ($type:ident) => { + impl ActiveEnumValue for $type { + fn try_get_vec_by( + _res: &QueryResult, + _index: I, + ) -> Result, TryGetError> { + #[cfg(feature = "postgres-array")] + { + >::try_get_by(_res, _index) + } + #[cfg(not(feature = "postgres-array"))] + panic!("`postgres-array` is not enabled") + } + } + }; +} + +impl_active_enum_value!(u8); +impl_active_enum_value!(u16); +impl_active_enum_value!(u32); +impl_active_enum_value!(u64); +impl_active_enum_value_with_pg_array!(String); +impl_active_enum_value_with_pg_array!(i8); +impl_active_enum_value_with_pg_array!(i16); +impl_active_enum_value_with_pg_array!(i32); +impl_active_enum_value_with_pg_array!(i64); + impl TryFromU64 for T where T: ActiveEnum, diff --git a/src/entity/column.rs b/src/entity/column.rs index a2b7d8723..2117182be 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -281,6 +281,7 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr { } /// Cast value of an enum column as enum type; do nothing if `self` is not an enum. + /// Will also transform `Array(Vec)` into `Json(Vec)` if the column type is `Json`. fn save_enum_as(&self, val: Expr) -> SimpleExpr { cast_enum_as(val, self, |col, enum_name, col_type| { let type_name = match col_type { @@ -412,9 +413,41 @@ where { let col_def = col.def(); let col_type = col_def.get_column_type(); - match col_type.get_enum_name() { - Some(enum_name) => f(expr, SeaRc::clone(enum_name), col_type), - None => expr.into(), + + match col_type { + #[cfg(all(feature = "with-json", feature = "postgres-array"))] + ColumnType::Json | ColumnType::JsonBinary => { + use sea_query::ArrayType; + use serde_json::Value as Json; + + #[allow(clippy::boxed_local)] + fn unbox(boxed: Box) -> T { + *boxed + } + + let expr = expr.into(); + match expr { + SimpleExpr::Value(Value::Array(ArrayType::Json, Some(json_vec))) => { + // flatten Array(Vec) into Json + let json_vec: Vec = json_vec + .into_iter() + .filter_map(|val| match val { + Value::Json(Some(json)) => Some(unbox(json)), + _ => None, + }) + .collect(); + SimpleExpr::Value(Value::Json(Some(Box::new(json_vec.into())))) + } + SimpleExpr::Value(Value::Array(ArrayType::Json, None)) => { + SimpleExpr::Value(Value::Json(None)) + } + _ => expr, + } + } + _ => match col_type.get_enum_name() { + Some(enum_name) => f(expr, SeaRc::clone(enum_name), col_type), + None => expr.into(), + }, } } diff --git a/src/executor/query.rs b/src/executor/query.rs index 28093822b..95ec29c42 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -960,6 +960,25 @@ fn try_get_many_with_slice_len_of(len: usize, cols: &[String]) -> Result<(), Try } } +/// An interface to get an array of values from the query result. +/// A type can only implement `ActiveEnum` or `TryGetableFromJson`, but not both. +/// A blanket impl is provided for `TryGetableFromJson`, while the impl for `ActiveEnum` +/// is provided by the `DeriveActiveEnum` macro. So as an end user you won't normally +/// touch this trait. +pub trait TryGetableArray: Sized { + /// Just a delegate + fn try_get_by(res: &QueryResult, index: I) -> Result, TryGetError>; +} + +impl TryGetable for Vec +where + T: TryGetableArray, +{ + fn try_get_by(res: &QueryResult, index: I) -> Result { + T::try_get_by(res, index) + } +} + // TryGetableFromJson // /// An interface to get a JSON from the query result @@ -999,6 +1018,22 @@ where _ => unreachable!(), } } + + /// Get a Vec from an Array of Json + fn from_json_vec(value: serde_json::Value) -> Result, TryGetError> { + match value { + serde_json::Value::Array(values) => { + let mut res = Vec::new(); + for item in values { + res.push(serde_json::from_value(item).map_err(json_err)?); + } + Ok(res) + } + _ => Err(TryGetError::DbErr(DbErr::Json( + "Value is not an Array".to_owned(), + ))), + } + } } #[cfg(feature = "with-json")] @@ -1011,6 +1046,16 @@ where } } +#[cfg(feature = "with-json")] +impl TryGetableArray for T +where + T: TryGetableFromJson, +{ + fn try_get_by(res: &QueryResult, index: I) -> Result, TryGetError> { + T::from_json_vec(serde_json::Value::try_get_by(res, index)?) + } +} + // TryFromU64 // /// Try to convert a type to a u64 pub trait TryFromU64: Sized { diff --git a/tests/active_enum_tests.rs b/tests/active_enum_tests.rs index e1aff3f98..547532052 100644 --- a/tests/active_enum_tests.rs +++ b/tests/active_enum_tests.rs @@ -22,8 +22,13 @@ async fn main() -> Result<(), DbErr> { create_tables(&ctx.db).await?; insert_active_enum(&ctx.db).await?; insert_active_enum_child(&ctx.db).await?; + + #[cfg(feature = "sqlx-postgres")] + insert_active_enum_vec(&ctx.db).await?; + find_related_active_enum(&ctx.db).await?; find_linked_active_enum(&ctx.db).await?; + ctx.delete().await; Ok(()) @@ -205,6 +210,72 @@ pub async fn insert_active_enum_child(db: &DatabaseConnection) -> Result<(), DbE Ok(()) } +pub async fn insert_active_enum_vec(db: &DatabaseConnection) -> Result<(), DbErr> { + use categories::*; + + let model = Model { + id: 1, + categories: None, + }; + + assert_eq!( + model, + ActiveModel { + id: Set(1), + categories: Set(None), + ..Default::default() + } + .insert(db) + .await? + ); + assert_eq!(model, Entity::find().one(db).await?.unwrap()); + assert_eq!( + model, + Entity::find() + .filter(Column::Id.is_not_null()) + .filter(Column::Categories.is_null()) + .one(db) + .await? + .unwrap() + ); + + let _ = ActiveModel { + id: Set(1), + categories: Set(Some(vec![Category::Big, Category::Small])), + ..model.into_active_model() + } + .save(db) + .await?; + + let model = Entity::find().one(db).await?.unwrap(); + assert_eq!( + model, + Model { + id: 1, + categories: Some(vec![Category::Big, Category::Small]), + } + ); + assert_eq!( + model, + Entity::find() + .filter(Column::Id.eq(1)) + .filter(Expr::cust_with_values( + r#"$1 = ANY("categories")"#, + vec![Category::Big] + )) + .one(db) + .await? + .unwrap() + ); + + let res = model.delete(db).await?; + + assert_eq!(res.rows_affected, 1); + assert_eq!(Entity::find().one(db).await?, None); + + Ok(()) +} + pub async fn find_related_active_enum(db: &DatabaseConnection) -> Result<(), DbErr> { assert_eq!( active_enum::Model { diff --git a/tests/common/features/active_enum_vec.rs b/tests/common/features/active_enum_vec.rs new file mode 100644 index 000000000..66bda0de3 --- /dev/null +++ b/tests/common/features/active_enum_vec.rs @@ -0,0 +1,16 @@ +use super::sea_orm_active_enums::*; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[cfg_attr(feature = "sqlx-postgres", sea_orm(schema_name = "public"))] +#[sea_orm(table_name = "active_enum")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub categories: Option>, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/tests/common/features/categories.rs b/tests/common/features/categories.rs new file mode 100644 index 000000000..00f25674d --- /dev/null +++ b/tests/common/features/categories.rs @@ -0,0 +1,15 @@ +use super::sea_orm_active_enums::*; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "categories")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: i32, + pub categories: Option>, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/tests/common/features/json_vec_derive.rs b/tests/common/features/json_vec_derive.rs index cc6e177b7..58eb3ed4f 100644 --- a/tests/common/features/json_vec_derive.rs +++ b/tests/common/features/json_vec_derive.rs @@ -1,19 +1,46 @@ -use sea_orm::entity::prelude::*; -use sea_orm::FromJsonQueryResult; -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "json_vec")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: i32, - pub str_vec: Option, +pub mod json_string_vec { + use sea_orm::entity::prelude::*; + use sea_orm::FromJsonQueryResult; + use serde::{Deserialize, Serialize}; + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "json_string_vec")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub str_vec: Option, + } + + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromJsonQueryResult)] + pub struct StringVec(pub Vec); + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + + impl ActiveModelBehavior for ActiveModel {} } -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} +pub mod json_struct_vec { + use sea_orm::entity::prelude::*; + use sea_orm_macros::FromJsonQueryResult; + use serde::{Deserialize, Serialize}; -impl ActiveModelBehavior for ActiveModel {} + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "json_struct_vec")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(column_type = "JsonBinary")] + pub struct_vec: Vec, + } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromJsonQueryResult)] -pub struct StringVec(pub Vec); + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromJsonQueryResult)] + pub struct JsonColumn { + pub value: String, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + + impl ActiveModelBehavior for ActiveModel {} +} diff --git a/tests/common/features/mod.rs b/tests/common/features/mod.rs index 84a6f7f22..3d4dee9a5 100644 --- a/tests/common/features/mod.rs +++ b/tests/common/features/mod.rs @@ -1,9 +1,11 @@ pub mod active_enum; pub mod active_enum_child; +pub mod active_enum_vec; pub mod applog; pub mod binary; pub mod bits; pub mod byte_primary_key; +pub mod categories; pub mod collection; pub mod collection_expanded; pub mod custom_active_model; @@ -28,10 +30,12 @@ pub mod value_type; pub use active_enum::Entity as ActiveEnum; pub use active_enum_child::Entity as ActiveEnumChild; +pub use active_enum_vec::Entity as ActiveEnumVec; pub use applog::Entity as Applog; pub use binary::Entity as Binary; pub use bits::Entity as Bits; pub use byte_primary_key::Entity as BytePrimaryKey; +pub use categories::Entity as Categories; pub use collection::Entity as Collection; pub use collection_expanded::Entity as CollectionExpanded; pub use dyn_table_name_lazy_static::Entity as DynTableNameLazyStatic; @@ -40,6 +44,8 @@ pub use event_trigger::Entity as EventTrigger; pub use insert_default::Entity as InsertDefault; pub use json_struct::Entity as JsonStruct; pub use json_vec::Entity as JsonVec; +pub use json_vec_derive::json_string_vec::Entity as JsonStringVec; +pub use json_vec_derive::json_struct_vec::Entity as JsonStructVec; pub use metadata::Entity as Metadata; pub use pi::Entity as Pi; pub use repository::Entity as Repository; diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 6e1f71cc0..0fb15cf02 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -5,7 +5,8 @@ use sea_orm::{ ExecResult, Schema, }; use sea_query::{ - extension::postgres::Type, Alias, BlobSize, ColumnDef, ForeignKeyCreateStatement, IntoIden, + extension::postgres::Type, Alias, BlobSize, ColumnDef, ColumnType, ForeignKeyCreateStatement, + IntoIden, }; pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { @@ -18,8 +19,6 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_byte_primary_key_table(db).await?; create_satellites_table(db).await?; create_transaction_log_table(db).await?; - create_json_vec_table(db).await?; - create_json_struct_table(db).await?; let create_enum_stmts = match db_backend { DbBackend::MySql | DbBackend::Sqlite => Vec::new(), @@ -50,10 +49,16 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_dyn_table_name_lazy_static_table(db).await?; create_value_type_table(db).await?; + create_json_vec_table(db).await?; + create_json_struct_table(db).await?; + create_json_string_vec_table(db).await?; + create_json_struct_vec_table(db).await?; + if DbBackend::Postgres == db_backend { create_value_type_postgres_table(db).await?; create_collection_table(db).await?; create_event_trigger_table(db).await?; + create_categories_table(db).await?; } Ok(()) @@ -341,6 +346,42 @@ pub async fn create_json_struct_table(db: &DbConn) -> Result create_table(db, &stmt, JsonStruct).await } +pub async fn create_json_string_vec_table(db: &DbConn) -> Result { + let create_table_stmt = sea_query::Table::create() + .table(JsonStringVec.table_ref()) + .col( + ColumnDef::new(json_vec_derive::json_string_vec::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col(ColumnDef::new(json_vec_derive::json_string_vec::Column::StrVec).json()) + .to_owned(); + + create_table(db, &create_table_stmt, JsonStringVec).await +} + +pub async fn create_json_struct_vec_table(db: &DbConn) -> Result { + let create_table_stmt = sea_query::Table::create() + .table(JsonStructVec.table_ref()) + .col( + ColumnDef::new(json_vec_derive::json_struct_vec::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(json_vec_derive::json_struct_vec::Column::StructVec) + .json_binary() + .not_null(), + ) + .to_owned(); + + create_table(db, &create_table_stmt, JsonStructVec).await +} + pub async fn create_collection_table(db: &DbConn) -> Result { db.execute(sea_orm::Statement::from_string( db.get_database_backend(), @@ -521,6 +562,21 @@ pub async fn create_teas_table(db: &DbConn) -> Result { create_table(db, &create_table_stmt, Teas).await } +pub async fn create_categories_table(db: &DbConn) -> Result { + let create_table_stmt = sea_query::Table::create() + .table(categories::Entity.table_ref()) + .col( + ColumnDef::new(categories::Column::Id) + .integer() + .not_null() + .primary_key(), + ) + .col(ColumnDef::new(categories::Column::Categories).array(ColumnType::String(Some(1)))) + .to_owned(); + + create_table(db, &create_table_stmt, Categories).await +} + pub async fn create_binary_table(db: &DbConn) -> Result { let create_table_stmt = sea_query::Table::create() .table(binary::Entity.table_ref()) diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index bd4d38bf5..2cdc00d21 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,7 +1,8 @@ use pretty_assertions::assert_eq; use sea_orm::{ - ColumnTrait, ColumnType, ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, - DbBackend, DbConn, DbErr, EntityTrait, ExecResult, Iterable, Schema, Statement, + ColumnTrait, ColumnType, ConnectOptions, ConnectionTrait, Database, DatabaseBackend, + DatabaseConnection, DbBackend, DbConn, DbErr, EntityTrait, ExecResult, Iterable, Schema, + Statement, }; use sea_query::{ extension::postgres::{Type, TypeCreateStatement}, @@ -48,7 +49,9 @@ pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { let url = format!("{base_url}/{db_name}"); Database::connect(&url).await.unwrap() } else { - Database::connect(base_url).await.unwrap() + let mut options: ConnectOptions = base_url.into(); + options.sqlx_logging(false); + Database::connect(options).await.unwrap() } } diff --git a/tests/json_vec_tests.rs b/tests/json_vec_tests.rs index 9af2a3304..f92cb6732 100644 --- a/tests/json_vec_tests.rs +++ b/tests/json_vec_tests.rs @@ -14,7 +14,9 @@ async fn main() -> Result<(), DbErr> { let ctx = TestContext::new("json_vec_tests").await; create_tables(&ctx.db).await?; insert_json_vec(&ctx.db).await?; - insert_json_vec_derive(&ctx.db).await?; + insert_json_string_vec_derive(&ctx.db).await?; + insert_json_struct_vec_derive(&ctx.db).await?; + ctx.delete().await; Ok(()) @@ -44,10 +46,10 @@ pub async fn insert_json_vec(db: &DatabaseConnection) -> Result<(), DbErr> { Ok(()) } -pub async fn insert_json_vec_derive(db: &DatabaseConnection) -> Result<(), DbErr> { - let json_vec = json_vec_derive::Model { +pub async fn insert_json_string_vec_derive(db: &DatabaseConnection) -> Result<(), DbErr> { + let json_vec = json_vec_derive::json_string_vec::Model { id: 2, - str_vec: Some(json_vec_derive::StringVec(vec![ + str_vec: Some(json_vec_derive::json_string_vec::StringVec(vec![ "4".to_string(), "5".to_string(), "6".to_string(), @@ -58,8 +60,37 @@ pub async fn insert_json_vec_derive(db: &DatabaseConnection) -> Result<(), DbErr assert_eq!(result, json_vec); - let model = json_vec_derive::Entity::find() - .filter(json_vec_derive::Column::Id.eq(json_vec.id)) + let model = json_vec_derive::json_string_vec::Entity::find() + .filter(json_vec_derive::json_string_vec::Column::Id.eq(json_vec.id)) + .one(db) + .await?; + + assert_eq!(model, Some(json_vec)); + + Ok(()) +} + +pub async fn insert_json_struct_vec_derive(db: &DatabaseConnection) -> Result<(), DbErr> { + let json_vec = json_vec_derive::json_struct_vec::Model { + id: 2, + struct_vec: vec![ + json_vec_derive::json_struct_vec::JsonColumn { + value: "4".to_string(), + }, + json_vec_derive::json_struct_vec::JsonColumn { + value: "5".to_string(), + }, + json_vec_derive::json_struct_vec::JsonColumn { + value: "6".to_string(), + }, + ], + }; + + let result = json_vec.clone().into_active_model().insert(db).await?; + assert_eq!(result, json_vec); + + let model = json_vec_derive::json_struct_vec::Entity::find() + .filter(json_vec_derive::json_struct_vec::Column::Id.eq(json_vec.id)) .one(db) .await?;