diff --git a/.gitignore b/.gitignore index d6af0d6d..293baf5a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__ dist target *.so +*.pyd .vscode .devcontainer .python-version diff --git a/Cargo.lock b/Cargo.lock index 60095bba..280c2399 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,9 +116,9 @@ checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" [[package]] name = "bytemuck" -version = "1.14.3" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" dependencies = [ "bytemuck_derive", ] @@ -134,6 +134,12 @@ dependencies = [ "syn 2.0.50", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.5.0" @@ -427,6 +433,15 @@ version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -517,8 +532,10 @@ name = "medmodels-core" version = "0.1.2" dependencies = [ "chrono", + "itertools", "medmodels-utils", "polars", + "roaring", "ron", "serde", ] @@ -1335,6 +1352,16 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "roaring" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +dependencies = [ + "bytemuck", + "byteorder", +] + [[package]] name = "ron" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index eac0a8c7..40154170 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,11 +13,8 @@ description = "Limebit MedModels Crate" [workspace.dependencies] hashbrown = { version = "0.14.5", features = ["serde"] } serde = { version = "1.0.203", features = ["derive"] } -ron = "0.8.1" -chrono = { version = "0.4.38", features = ["serde"] } -pyo3 = { version = "0.21.2", features = ["chrono"] } polars = { version = "0.40.0", features = ["polars-io"] } -pyo3-polars = "0.14.0" +chrono = { version = "0.4.38", features = ["serde"] } medmodels = { version = "0.1.2", path = "crates/medmodels" } medmodels-core = { version = "0.1.2", path = "crates/medmodels-core" } diff --git a/crates/medmodels-core/Cargo.toml b/crates/medmodels-core/Cargo.toml index 58097fcf..48225587 100644 --- a/crates/medmodels-core/Cargo.toml +++ b/crates/medmodels-core/Cargo.toml @@ -12,5 +12,7 @@ medmodels-utils = { workspace = true } polars = { workspace = true } serde = { workspace = true } -ron = { workspace = true } chrono = { workspace = true } +ron = "0.8.1" +roaring = "0.10.6" +itertools = "0.13.0" diff --git a/crates/medmodels-core/src/errors/medrecord.rs b/crates/medmodels-core/src/errors/medrecord.rs index f7afb230..3ad22a14 100644 --- a/crates/medmodels-core/src/errors/medrecord.rs +++ b/crates/medmodels-core/src/errors/medrecord.rs @@ -10,6 +10,7 @@ pub enum MedRecordError { ConversionError(String), AssertionError(String), SchemaError(String), + QueryError(String), } impl Error for MedRecordError { @@ -20,6 +21,7 @@ impl Error for MedRecordError { MedRecordError::ConversionError(message) => message, MedRecordError::AssertionError(message) => message, MedRecordError::SchemaError(message) => message, + MedRecordError::QueryError(message) => message, } } } @@ -32,6 +34,7 @@ impl Display for MedRecordError { Self::ConversionError(message) => write!(f, "ConversionError: {}", message), Self::AssertionError(message) => write!(f, "AssertionError: {}", message), Self::SchemaError(message) => write!(f, "SchemaError: {}", message), + Self::QueryError(message) => write!(f, "QueryError: {}", message), } } } diff --git a/crates/medmodels-core/src/errors/mod.rs b/crates/medmodels-core/src/errors/mod.rs index b0c37588..069281ca 100644 --- a/crates/medmodels-core/src/errors/mod.rs +++ b/crates/medmodels-core/src/errors/mod.rs @@ -14,6 +14,8 @@ impl From for MedRecordError { } } +pub type MedRecordResult = Result; + #[cfg(test)] mod test { use super::{GraphError, MedRecordError}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs index f02f12d4..bdb2f12d 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs @@ -1,8 +1,16 @@ -use super::{Contains, EndsWith, MedRecordValue, StartsWith}; -use crate::errors::MedRecordError; +use super::{ + Abs, Contains, EndsWith, Lowercase, MedRecordValue, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, +}; +use crate::errors::{MedRecordError, MedRecordResult}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; -use std::{cmp::Ordering, fmt::Display, hash::Hash}; +use std::{ + cmp::Ordering, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Sub}, +}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum MedRecordAttribute { @@ -43,15 +51,6 @@ impl TryFrom for MedRecordAttribute { } } -impl Display for MedRecordAttribute { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::String(value) => write!(f, "{}", value), - Self::Int(value) => write!(f, "{}", value), - } - } -} - impl PartialEq for MedRecordAttribute { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -80,6 +79,140 @@ impl PartialOrd for MedRecordAttribute { } } +impl Display for MedRecordAttribute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::String(value) => write!(f, "{}", value), + Self::Int(value) => write!(f, "{}", value), + } + } +} + +// TODO: Add tests +impl Add for MedRecordAttribute { + type Output = MedRecordResult; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Ok(MedRecordAttribute::String(value + rhs.as_str())) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value + rhs)) + } + } + } +} + +// TODO: Add tests +impl Sub for MedRecordAttribute { + type Output = MedRecordResult; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value - rhs)) + } + } + } +} + +// TODO: Add tests +impl Mul for MedRecordAttribute { + type Output = MedRecordResult; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value * rhs)) + } + } + } +} + +// TODO: Add tests +impl Pow for MedRecordAttribute { + fn pow(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value.pow(rhs as u32))) + } + } + } +} + +// TODO: Add tests +impl Mod for MedRecordAttribute { + fn r#mod(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value % rhs)) + } + } + } +} + +// TODO: Add tests +impl Abs for MedRecordAttribute { + fn abs(self) -> Self { + match self { + MedRecordAttribute::Int(value) => MedRecordAttribute::Int(value.abs()), + _ => self, + } + } +} + impl StartsWith for MedRecordAttribute { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -137,6 +270,72 @@ impl Contains for MedRecordAttribute { } } +// TODO: Add tests +impl Slice for MedRecordAttribute { + fn slice(self, range: std::ops::Range) -> Self { + match self { + MedRecordAttribute::String(value) => value[range].into(), + MedRecordAttribute::Int(value) => value.to_string()[range].into(), + } + } +} + +// TODO: Add tests +impl Trim for MedRecordAttribute { + fn trim(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimStart for MedRecordAttribute { + fn trim_start(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_start().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimEnd for MedRecordAttribute { + fn trim_end(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_end().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl Lowercase for MedRecordAttribute { + fn lowercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_lowercase()), + _ => self, + } + } +} + +// TODO: Add tests +impl Uppercase for MedRecordAttribute { + fn uppercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_uppercase()), + _ => self, + } + } +} + #[cfg(test)] mod test { use super::MedRecordAttribute; diff --git a/crates/medmodels-core/src/medrecord/datatypes/mod.rs b/crates/medmodels-core/src/medrecord/datatypes/mod.rs index 0beca37e..ada0f6c0 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/mod.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/mod.rs @@ -2,6 +2,7 @@ mod attribute; mod value; pub use self::{attribute::MedRecordAttribute, value::MedRecordValue}; +use super::EdgeIndex; use crate::errors::MedRecordError; use serde::{Deserialize, Serialize}; use std::{fmt::Display, ops::Range}; @@ -51,6 +52,24 @@ impl From<&MedRecordValue> for DataType { } } +impl From for DataType { + fn from(value: MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + +impl From<&MedRecordAttribute> for DataType { + fn from(value: &MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + impl PartialEq for DataType { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -126,28 +145,52 @@ impl DataType { } } -pub trait Pow: Sized { - fn pow(self, exp: Self) -> Result; -} - -pub trait Mod: Sized { - fn r#mod(self, other: Self) -> Result; -} - pub trait StartsWith { fn starts_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl StartsWith for EdgeIndex { + fn starts_with(&self, other: &Self) -> bool { + self.to_string().starts_with(&other.to_string()) + } +} + pub trait EndsWith { fn ends_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl EndsWith for EdgeIndex { + fn ends_with(&self, other: &Self) -> bool { + self.to_string().ends_with(&other.to_string()) + } +} + pub trait Contains { fn contains(&self, other: &Self) -> bool; } -pub trait PartialNeq: PartialEq { - fn neq(&self, other: &Self) -> bool; +// TODO: Add tests +impl Contains for EdgeIndex { + fn contains(&self, other: &Self) -> bool { + self.to_string().contains(&other.to_string()) + } +} + +pub trait Pow: Sized { + fn pow(self, exp: Self) -> Result; +} + +pub trait Mod: Sized { + fn r#mod(self, other: Self) -> Result; +} + +// TODO: Add tests +impl Mod for EdgeIndex { + fn r#mod(self, other: Self) -> Result { + Ok(self % other) + } } pub trait Round { @@ -194,15 +237,6 @@ pub trait Slice { fn slice(self, range: Range) -> Self; } -impl PartialNeq for T -where - T: PartialOrd, -{ - fn neq(&self, other: &Self) -> bool { - self != other - } -} - #[cfg(test)] mod test { use super::{DataType, MedRecordValue}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/value.rs b/crates/medmodels-core/src/medrecord/datatypes/value.rs index 792d879d..f3995102 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/value.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/value.rs @@ -3,7 +3,7 @@ use super::{ Trim, TrimEnd, TrimStart, Uppercase, }; use crate::errors::MedRecordError; -use chrono::NaiveDateTime; +use chrono::{DateTime, NaiveDateTime}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; use std::{ @@ -210,9 +210,17 @@ impl Add for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() + rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot add None to {}", value)), ), @@ -327,9 +335,17 @@ impl Sub for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() - rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot subtract None from {}", value)), ), @@ -621,9 +637,17 @@ impl Div for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::String(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), - (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => Err( - MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => { + Ok(DateTime::from_timestamp( + (value.and_utc().timestamp() as f64 / other as f64).floor() as i64, + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Float(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), @@ -966,6 +990,53 @@ impl Mod for MedRecordValue { } } +impl Round for MedRecordValue { + fn round(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.round()), + _ => self, + } + } +} + +impl Ceil for MedRecordValue { + fn ceil(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.ceil()), + _ => self, + } + } +} + +impl Floor for MedRecordValue { + fn floor(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.floor()), + _ => self, + } + } +} + +impl Abs for MedRecordValue { + fn abs(self) -> Self { + match self { + MedRecordValue::Int(value) => MedRecordValue::Int(value.abs()), + MedRecordValue::Float(value) => MedRecordValue::Float(value.abs()), + _ => self, + } + } +} + +impl Sqrt for MedRecordValue { + fn sqrt(self) -> Self { + match self { + MedRecordValue::Int(value) => MedRecordValue::Float((value as f64).sqrt()), + MedRecordValue::Float(value) => MedRecordValue::Float(value.sqrt()), + _ => self, + } + } +} + impl StartsWith for MedRecordValue { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -1081,53 +1152,6 @@ impl Slice for MedRecordValue { } } -impl Round for MedRecordValue { - fn round(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.round()), - _ => self, - } - } -} - -impl Ceil for MedRecordValue { - fn ceil(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.ceil()), - _ => self, - } - } -} - -impl Floor for MedRecordValue { - fn floor(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.floor()), - _ => self, - } - } -} - -impl Abs for MedRecordValue { - fn abs(self) -> Self { - match self { - MedRecordValue::Int(value) => MedRecordValue::Int(value.abs()), - MedRecordValue::Float(value) => MedRecordValue::Float(value.abs()), - _ => self, - } - } -} - -impl Sqrt for MedRecordValue { - fn sqrt(self) -> Self { - match self { - MedRecordValue::Int(value) => MedRecordValue::Float((value as f64).sqrt()), - MedRecordValue::Float(value) => MedRecordValue::Float(value.sqrt()), - _ => self, - } - } -} - impl Trim for MedRecordValue { fn trim(self) -> Self { match self { @@ -1183,7 +1207,7 @@ mod test { Uppercase, }, }; - use chrono::NaiveDateTime; + use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; #[test] fn test_default() { @@ -1669,9 +1693,23 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - + MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 4) + .unwrap() + .and_time(NaiveTime::MIN) + ), + (MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 2) + .unwrap() + .and_time(NaiveTime::MIN) + ) + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 3) + .unwrap() + .and_time(NaiveTime::MIN) + )) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1794,9 +1832,12 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - - MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime(DateTime::from_timestamp(0, 0).unwrap().naive_utc()), + (MedRecordValue::DateTime(NaiveDateTime::MAX) + - MedRecordValue::DateTime(NaiveDateTime::MAX)) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1951,15 +1992,15 @@ mod test { / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(0)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(0_f64)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(false)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::String("value".to_string()) @@ -1982,7 +2023,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Int(5) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Int(0) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2003,7 +2044,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Float(5_f64) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Float(0_f64) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2016,11 +2057,11 @@ mod test { (MedRecordValue::Bool(false) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(0)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Bool(false) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2032,16 +2073,16 @@ mod test { assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(0)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) + assert_eq!( + MedRecordValue::DateTime(NaiveDateTime::MIN), + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(1)).unwrap() ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(0_f64)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(false)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) @@ -2056,11 +2097,11 @@ mod test { (MedRecordValue::Null / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Null / MedRecordValue::Int(0)) + assert!((MedRecordValue::Null / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Null / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Null / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Null / MedRecordValue::DateTime(NaiveDateTime::MIN)) diff --git a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs index e4879307..2a0f3354 100644 --- a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs +++ b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs @@ -71,7 +71,7 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_diagnosis_ids = (0..patient_diagnosis.height()).collect::>(); + let patient_diagnosis_ids = (0..patient_diagnosis.height() as u32).collect::>(); let cursor = Cursor::new(PATIENT_DRUG); let patient_drug = CsvReadOptions::default() @@ -79,8 +79,8 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_drug_ids = (patient_diagnosis.height() - ..patient_diagnosis.height() + patient_drug.height()) + let patient_drug_ids = (patient_diagnosis.height() as u32 + ..(patient_diagnosis.height() + patient_drug.height()) as u32) .collect::>(); let cursor = Cursor::new(PATIENT_PROCEDURE); @@ -89,8 +89,9 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_procedure_ids = (patient_diagnosis.height() + patient_drug.height() - ..patient_diagnosis.height() + patient_drug.height() + patient_procedure.height()) + let patient_procedure_ids = ((patient_diagnosis.height() + patient_drug.height()) as u32 + ..(patient_diagnosis.height() + patient_drug.height() + patient_procedure.height()) + as u32) .collect::>(); let mut medrecord = Self::from_dataframes( diff --git a/crates/medmodels-core/src/medrecord/graph/edge.rs b/crates/medmodels-core/src/medrecord/graph/edge.rs index a45b6c4d..36b790d8 100644 --- a/crates/medmodels-core/src/medrecord/graph/edge.rs +++ b/crates/medmodels-core/src/medrecord/graph/edge.rs @@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Edge { - pub attributes: Attributes, - pub(super) source_node_index: NodeIndex, - pub(super) target_node_index: NodeIndex, + pub(crate) attributes: Attributes, + pub(crate) source_node_index: NodeIndex, + pub(crate) target_node_index: NodeIndex, } impl Edge { diff --git a/crates/medmodels-core/src/medrecord/graph/mod.rs b/crates/medmodels-core/src/medrecord/graph/mod.rs index 0a7da3de..96a82584 100644 --- a/crates/medmodels-core/src/medrecord/graph/mod.rs +++ b/crates/medmodels-core/src/medrecord/graph/mod.rs @@ -9,18 +9,18 @@ use node::Node; use serde::{Deserialize, Serialize}; use std::{ collections::{HashMap, HashSet}, - sync::atomic::AtomicUsize, + sync::atomic::AtomicU32, }; pub type NodeIndex = MedRecordAttribute; -pub type EdgeIndex = usize; +pub type EdgeIndex = u32; pub type Attributes = HashMap; #[derive(Serialize, Deserialize, Debug)] pub(super) struct Graph { pub(crate) nodes: MrHashMap, pub(crate) edges: MrHashMap, - edge_index_counter: AtomicUsize, + edge_index_counter: AtomicU32, } impl Clone for Graph { @@ -28,7 +28,7 @@ impl Clone for Graph { Self { nodes: self.nodes.clone(), edges: self.edges.clone(), - edge_index_counter: AtomicUsize::new( + edge_index_counter: AtomicU32::new( self.edge_index_counter .load(std::sync::atomic::Ordering::Relaxed), ), @@ -42,7 +42,7 @@ impl Graph { Self { nodes: MrHashMap::new(), edges: MrHashMap::new(), - edge_index_counter: AtomicUsize::new(0), + edge_index_counter: AtomicU32::new(0), } } @@ -50,7 +50,7 @@ impl Graph { Self { nodes: MrHashMap::with_capacity(node_capacity), edges: MrHashMap::with_capacity(edge_capacity), - edge_index_counter: AtomicUsize::new(0), + edge_index_counter: AtomicU32::new(0), } } @@ -58,13 +58,13 @@ impl Graph { self.nodes.clear(); self.edges.clear(); - self.edge_index_counter = AtomicUsize::new(0); + self.edge_index_counter = AtomicU32::new(0); } pub fn clear_edges(&mut self) { self.edges.clear(); - self.edge_index_counter = AtomicUsize::new(0); + self.edge_index_counter = AtomicU32::new(0); } pub fn node_count(&self) -> usize { @@ -359,7 +359,7 @@ impl Graph { self.edges.contains_key(edge_index) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, GraphError> { @@ -381,6 +381,29 @@ impl Graph { })) } + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, GraphError> { + Ok(self + .nodes + .get(node_index) + .ok_or(GraphError::IndexError(format!( + "Cannot find node with index {}", + node_index + )))? + .incoming_edge_indices + .iter() + .map(|edge_index| { + &self + .edges + .get(edge_index) + .expect("Edge must exist") + .source_node_index + })) + } + pub fn neighbors_undirected( &self, node_index: &NodeIndex, @@ -913,7 +936,7 @@ mod test { fn test_neighbors() { let graph = create_graph(); - let neighbors = graph.neighbors(&"0".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -923,7 +946,7 @@ mod test { let graph = create_graph(); assert!(graph - .neighbors(&"50".into()) + .neighbors_outgoing(&"50".into()) .is_err_and(|e| matches!(e, GraphError::IndexError(_)))); } @@ -931,7 +954,7 @@ mod test { fn test_neighbors_undirected() { let graph = create_graph(); - let neighbors = graph.neighbors(&"2".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = graph.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/graph/node.rs b/crates/medmodels-core/src/medrecord/graph/node.rs index 9af16851..4d90ee0f 100644 --- a/crates/medmodels-core/src/medrecord/graph/node.rs +++ b/crates/medmodels-core/src/medrecord/graph/node.rs @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Node { - pub attributes: Attributes, - pub(super) outgoing_edge_indices: MrHashSet, - pub(super) incoming_edge_indices: MrHashSet, + pub(crate) attributes: Attributes, + pub(crate) outgoing_edge_indices: MrHashSet, + pub(crate) incoming_edge_indices: MrHashSet, } impl Node { diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index ee4e8ea0..f9b4b03f 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -11,9 +11,24 @@ pub use self::{ graph::{Attributes, EdgeIndex, NodeIndex}, group_mapping::Group, querying::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation, - TransformationOperation, ValueOperand, + attributes::{ + AttributesTreeOperand, MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, + }, + edges::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, + EdgeIndicesOperand, EdgeOperand, + }, + nodes::{ + EdgeDirection, NodeIndexComparisonOperand, NodeIndexOperand, + NodeIndicesComparisonOperand, NodeIndicesOperand, NodeOperand, + }, + traits::DeepClone, + values::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, + }, + wrapper::{CardinalityWrapper, Wrapper}, }, schema::{AttributeDataType, AttributeType, GroupSchema, Schema}, }; @@ -22,7 +37,7 @@ use ::polars::frame::DataFrame; use graph::Graph; use group_mapping::GroupMapping; use polars::{dataframe_to_edges, dataframe_to_nodes}; -use querying::{EdgeSelection, NodeSelection}; +use querying::{edges::EdgeSelection, nodes::NodeSelection}; use serde::{Deserialize, Serialize}; use std::{fs, mem, path::Path}; @@ -683,12 +698,22 @@ impl MedRecord { self.group_mapping.contains_group(group) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, MedRecordError> { self.graph - .neighbors(node_index) + .neighbors_outgoing(node_index) + .map_err(MedRecordError::from) + } + + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, MedRecordError> { + self.graph + .neighbors_incoming(node_index) .map_err(MedRecordError::from) } @@ -706,12 +731,18 @@ impl MedRecord { self.group_mapping.clear(); } - pub fn select_nodes(&self, operation: NodeOperation) -> NodeSelection { - NodeSelection::new(self, operation) + pub fn select_nodes(&self, query: Q) -> NodeSelection + where + Q: FnOnce(&mut Wrapper), + { + NodeSelection::new(self, query) } - pub fn select_edges(&self, operation: EdgeOperation) -> EdgeSelection { - EdgeSelection::new(self, operation) + pub fn select_edges(&self, query: Q) -> EdgeSelection + where + Q: FnOnce(&mut Wrapper), + { + EdgeSelection::new(self, query) } } @@ -1870,7 +1901,7 @@ mod test { fn test_neighbors() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"0".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -1881,7 +1912,7 @@ mod test { // Querying neighbors of a non-existing node sohuld fail assert!(medrecord - .neighbors(&"0".into()) + .neighbors_outgoing(&"0".into()) .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); } @@ -1889,7 +1920,7 @@ mod test { fn test_neighbors_undirected() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"2".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = medrecord.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs new file mode 100644 index 00000000..8e60945f --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs @@ -0,0 +1,135 @@ +mod operand; +mod operation; + +use super::{ + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{Attributes, EdgeIndex, MedRecordAttribute, NodeIndex}, + MedRecord, +}; +pub use operand::{ + AttributesTreeOperand, MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, +}; +pub use operation::{AttributesTreeOperation, MultipleAttributesOperation}; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum MultipleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +pub(crate) trait GetAttributes { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes>; +} + +impl GetAttributes for NodeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.node_attributes(self) + } +} + +impl GetAttributes for EdgeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.edge_attributes(self) + } +} + +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), +} + +impl Context { + pub(crate) fn get_attributes<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult>> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_attributes(medrecord, node_indices).map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_attributes(medrecord, edge_indices).map(|(_, value)| value), + ) + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs new file mode 100644 index 00000000..f53b287d --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs @@ -0,0 +1,933 @@ +use super::{ + operation::{AttributesTreeOperation, MultipleAttributesOperation, SingleAttributeOperation}, + BinaryArithmeticKind, Context, GetAttributes, MultipleComparisonKind, MultipleKind, + SingleComparisonKind, SingleKind, UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + BoxedIterator, + }, + MedRecordAttribute, Wrapper, + }, + MedRecord, +}; +use std::{fmt::Display, hash::Hash}; + +macro_rules! implement_attributes_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new( + self.deep_clone(), + MultipleKind::$variant, + ); + + self.operations + .push(AttributesTreeOperation::AttributesOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_attribute_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleAttributesOperation::AttributeOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_attribute_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations + .push($operation::SingleAttributeComparisonOperation { + operand: attribute.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: attribute.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $attribute_type:ty) => { + pub fn $name(&self, attribute: $attribute_type) { + self.0.write_or_panic().$name(attribute) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeComparisonOperand { + Operand(SingleAttributeOperand), + Attribute(MedRecordAttribute), +} + +impl DeepClone for SingleAttributeComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attribute(attribute) => Self::Attribute(attribute.clone()), + } + } +} + +impl From> for SingleAttributeComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleAttributeComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleAttributeComparisonOperand { + fn from(value: V) -> Self { + Self::Attribute(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesComparisonOperand { + Operand(MultipleAttributesOperand), + Attributes(Vec), +} + +impl DeepClone for MultipleAttributesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attributes(attribute) => Self::Attributes(attribute.clone()), + } + } +} + +impl From> for MultipleAttributesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleAttributesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleAttributesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Attributes(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleAttributesComparisonOperand +{ + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct AttributesTreeOperand { + pub(crate) context: Context, + operations: Vec, +} + +impl DeepClone for AttributesTreeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl AttributesTreeOperand { + pub(crate) fn new(context: Context) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, Vec)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attributes_operation!(max, Max); + implement_attributes_operation!(min, Min); + implement_attributes_operation!(count, Count); + implement_attributes_operation!(sum, Sum); + implement_attributes_operation!(first, First); + implement_attributes_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + AttributesTreeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + AttributesTreeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, AttributesTreeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + AttributesTreeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, AttributesTreeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + AttributesTreeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + AttributesTreeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, AttributesTreeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, AttributesTreeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, AttributesTreeOperation, Add); + implement_binary_arithmetic_operation!(sub, AttributesTreeOperation, Sub); + implement_binary_arithmetic_operation!(mul, AttributesTreeOperation, Mul); + implement_binary_arithmetic_operation!(pow, AttributesTreeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, AttributesTreeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, AttributesTreeOperation, Abs); + implement_unary_arithmetic_operation!(trim, AttributesTreeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, AttributesTreeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, AttributesTreeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, AttributesTreeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, AttributesTreeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(AttributesTreeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, AttributesTreeOperation::IsString); + implement_assertion_operation!(is_int, AttributesTreeOperation::IsInt); + implement_assertion_operation!(is_max, AttributesTreeOperation::IsMax); + implement_assertion_operation!(is_min, AttributesTreeOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(AttributesTreeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(self.context.clone()); + + query(&mut operand); + + self.operations + .push(AttributesTreeOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context) -> Self { + AttributesTreeOperand::new(context).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(min, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(count, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(sum, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(first, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(last, MultipleAttributesOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query) + } +} + +#[derive(Debug, Clone)] +pub struct MultipleAttributesOperand { + pub(crate) context: AttributesTreeOperand, + pub(crate) kind: MultipleKind, + operations: Vec, +} + +impl DeepClone for MultipleAttributesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleAttributesOperand { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, MedRecordAttribute)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attribute_operation!(max, Max); + implement_attribute_operation!(min, Min); + implement_attribute_operation!(count, Count); + implement_attribute_operation!(sum, Sum); + implement_attribute_operation!(first, First); + implement_attribute_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + MultipleAttributesOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + MultipleAttributesOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + less_than, + MultipleAttributesOperation, + LessThan + ); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + MultipleAttributesOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + equal_to, + MultipleAttributesOperation, + EqualTo + ); + implement_single_attribute_comparison_operation!( + not_equal_to, + MultipleAttributesOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + MultipleAttributesOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!( + ends_with, + MultipleAttributesOperation, + EndsWith + ); + implement_single_attribute_comparison_operation!( + contains, + MultipleAttributesOperation, + Contains + ); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, MultipleAttributesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleAttributesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleAttributesOperation, Mul); + implement_binary_arithmetic_operation!(pow, MultipleAttributesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleAttributesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, MultipleAttributesOperation, Abs); + implement_unary_arithmetic_operation!(trim, MultipleAttributesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleAttributesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleAttributesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleAttributesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleAttributesOperation, Uppercase); + + #[allow(clippy::wrong_self_convention)] + pub fn to_values(&mut self) -> Wrapper { + let operand = Wrapper::::new( + values::Context::MultipleAttributesOperand(self.deep_clone()), + "unused".into(), + ); + + self.operations.push(MultipleAttributesOperation::ToValues { + operand: operand.clone(), + }); + + operand + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleAttributesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleAttributesOperation::IsString); + implement_assertion_operation!(is_int, MultipleAttributesOperation::IsInt); + implement_assertion_operation!(is_max, MultipleAttributesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleAttributesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleAttributesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + query(&mut operand); + + self.operations + .push(MultipleAttributesOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + MultipleAttributesOperand::new(context, kind).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, SingleAttributeOperand); + implement_wrapper_operand_with_return!(min, SingleAttributeOperand); + implement_wrapper_operand_with_return!(count, SingleAttributeOperand); + implement_wrapper_operand_with_return!(sum, SingleAttributeOperand); + implement_wrapper_operand_with_return!(first, SingleAttributeOperand); + implement_wrapper_operand_with_return!(last, SingleAttributeOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + implement_wrapper_operand_with_return!(to_values, MultipleValuesOperand); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query) + } +} + +#[derive(Debug, Clone)] +pub struct SingleAttributeOperand { + pub(crate) context: MultipleAttributesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleAttributeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleAttributeOperand { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(attribute), |attribute, operation| { + if let Some(attribute) = attribute { + operation.evaluate(medrecord, attribute) + } else { + Ok(None) + } + }) + } + + implement_single_attribute_comparison_operation!( + greater_than, + SingleAttributeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + SingleAttributeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, SingleAttributeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + SingleAttributeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, SingleAttributeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + SingleAttributeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + SingleAttributeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, SingleAttributeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, SingleAttributeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, SingleAttributeOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleAttributeOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleAttributeOperation, Mul); + implement_binary_arithmetic_operation!(pow, SingleAttributeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleAttributeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, SingleAttributeOperation, Abs); + implement_unary_arithmetic_operation!(trim, SingleAttributeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleAttributeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleAttributeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleAttributeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleAttributeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleAttributeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleAttributeOperation::IsString); + implement_assertion_operation!(is_int, SingleAttributeOperation::IsInt); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleAttributeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + query(&mut operand); + + self.operations + .push(SingleAttributeOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + SingleAttributeOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attribute) + } + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs new file mode 100644 index 00000000..f498d570 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs @@ -0,0 +1,1431 @@ +use super::{ + operand::{ + MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, + }, + AttributesTreeOperand, BinaryArithmeticKind, GetAttributes, MultipleComparisonKind, + SingleComparisonKind, UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::{MultipleKind, SingleKind}, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + BoxedIterator, + }, + DataType, MedRecordAttribute, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Range, Sub}, +}; + +macro_rules! get_multiple_operand_attributes { + ($kind:ident, $attributes:expr) => { + match $kind { + MultipleKind::Max => Box::new(AttributesTreeOperation::get_max($attributes)?), + MultipleKind::Min => Box::new(AttributesTreeOperation::get_min($attributes)?), + MultipleKind::Count => Box::new(AttributesTreeOperation::get_count($attributes)?), + MultipleKind::Sum => Box::new(AttributesTreeOperation::get_sum($attributes)?), + MultipleKind::First => Box::new(AttributesTreeOperation::get_first($attributes)?), + MultipleKind::Last => Box::new(AttributesTreeOperation::get_last($attributes)?), + } + }; +} + +macro_rules! get_single_operand_attribute { + ($kind:ident, $attributes:expr) => { + match $kind { + SingleKind::Max => MultipleAttributesOperation::get_max($attributes)?.1, + SingleKind::Min => MultipleAttributesOperation::get_min($attributes)?.1, + SingleKind::Count => MultipleAttributesOperation::get_count($attributes), + SingleKind::Sum => MultipleAttributesOperation::get_sum($attributes)?, + SingleKind::First => MultipleAttributesOperation::get_first($attributes)?, + SingleKind::Last => MultipleAttributesOperation::get_last($attributes)?, + } + }; +} + +macro_rules! get_single_attribute_comparison_operand_attribute { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleAttributeComparisonOperand::Operand(operand) => { + let context = &operand.context.context.context; + let kind = &operand.context.kind; + + let comparison_attributes = context + .get_attributes($medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + let kind = &operand.kind; + + let comparison_attributes = + get_single_operand_attribute!(kind, comparison_attributes); + + operand.evaluate($medrecord, comparison_attributes)?.ok_or( + MedRecordError::QueryError("No attribute to compare".to_string()), + )? + } + SingleAttributeComparisonOperand::Attribute(attribute) => attribute.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum AttributesTreeOperation { + AttributesOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for AttributesTreeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributesOperation { operand } => Self::AttributesOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl AttributesTreeOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + match self { + Self::AttributesOperation { operand } => Ok(Box::new( + Self::evaluate_attributes_operation(medrecord, attributes, operand)?, + )), + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsInt => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsMax => { + let max_attributes = Self::get_max(attributes)?; + + Ok(Box::new( + max_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::IsMin => { + let min_attributes = Self::get_min(attributes)?; + + Ok(Box::new( + min_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + Self::Exclude { operand } => Self::evaluate_exclude(medrecord, attributes, operand), + } + } + + #[inline] + pub(crate) fn get_max<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_min<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| (index, MedRecordAttribute::Int(attribute.len() as i64)))) + } + + #[inline] + pub(crate) fn get_sum<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |sum, attribute| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_attributes_operation<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + operand: &Wrapper, + ) -> MedRecordResult)>> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attributes.collect::>(); + + let multiple_operand_attributes: Box> = + get_multiple_operand_attributes!(kind, attributes.clone().into_iter()); + + let result = operand.evaluate(medrecord, multiple_operand_attributes)?; + + let mut attributes = attributes.into_iter().collect::>(); + + Ok(result + .map(move |(index, _)| (index, attributes.remove(&index).expect("Index must exist")))) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute > &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute >= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute < &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute <= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute == &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute != &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.starts_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.ends_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.contains(&comparison_attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + operand + .evaluate(medrecord, comparison_attributes)? + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| !comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult)>> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes: Box< + dyn Iterator)>>, + > = match kind { + BinaryArithmeticKind::Add => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.add(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Sub => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.sub(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mul => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.mul(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Pow => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.pow(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mod => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.r#mod(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + }; + + Ok(Box::new( + attributes.collect::>>()?.into_iter(), + )) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator)>, + kind: UnaryArithmeticKind, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator)>, + range: Range, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| attribute.slice(range.clone())) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult)>> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } + + #[inline] + fn evaluate_exclude<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + operand: &Wrapper, + ) -> MedRecordResult)>> { + let attributes = attributes.collect::>(); + + let result = operand + .evaluate(medrecord, attributes.clone().into_iter())? + .map(|(index, _)| index) + .collect::>(); + + Ok(Box::new( + attributes + .into_iter() + .filter(move |(index, _)| !result.contains(index)), + )) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesOperation { + AttributeOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + ToValues { + operand: Wrapper, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for MultipleAttributesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributeOperation { operand } => Self::AttributeOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::ToValues { operand } => Self::ToValues { + operand: operand.deep_clone(), + }, + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl MultipleAttributesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::AttributeOperation { operand } => { + Self::evaluate_attribute_operation(medrecord, attributes, operand) + } + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::ToValues { operand } => Ok(Box::new(Self::evaluate_to_values( + medrecord, attributes, operand, + )?)), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_attribute = Self::get_max(attributes)?; + + Ok(Box::new(std::iter::once(max_attribute))) + } + Self::IsMin => { + let min_attribute = Self::get_min(attributes)?; + + Ok(Box::new(std::iter::once(min_attribute))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + Self::Exclude { operand } => Self::evaluate_exclude(medrecord, attributes, operand), + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let max_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(max_attribute, |max_attribute, attribute| { + match attribute.1.partial_cmp(&max_attribute.1) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(max_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let min_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(min_attribute, |min_attribute, attribute| { + match attribute.1.partial_cmp(&min_attribute.1) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(min_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordAttribute { + MedRecordAttribute::Int(attributes.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(first_attribute.1, |sum, (_, attribute)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + fn evaluate_attribute_operation<'a, T>( + medrecord: &'a MedRecord, + attribtues: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attribtues.collect::>(); + + let attribute = get_single_operand_attribute!(kind, attributes.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, attribute)? { + Some(_) => Box::new(attributes.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute > &comparison_attribute + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute >= &comparison_attribute + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute < &comparison_attribute + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute <= &comparison_attribute + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute == &comparison_attribute + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute != &comparison_attribute + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.starts_with(&comparison_attribute) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.ends_with(&comparison_attribute) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.contains(&comparison_attribute) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + operand + .evaluate(medrecord, comparison_attributes)? + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + comparison_attributes.contains(attribute) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + !comparison_attributes.contains(attribute) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes = attributes + .map(move |(t, attribute)| { + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute.clone()), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute.clone()), + BinaryArithmeticKind::Mul => { + attribute.clone().mul(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Pow => { + attribute.clone().pow(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Mod => { + attribute.clone().r#mod(arithmetic_attribute.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the attributes using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(attributes.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| { + let attribute = match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }; + (t, attribute) + }) + } + + pub(crate) fn get_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| { + let value = index.get_attributes(medrecord)?.get(&attribute).ok_or( + MedRecordError::QueryError(format!( + "Cannot find attribute {} for index {}", + attribute, index + )), + )?; + + Ok((index, value.clone())) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_to_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let values = Self::get_values(medrecord, attributes.clone().into_iter())?; + + let mut attributes = attributes.into_iter().collect::>(); + + let values = operand.evaluate(medrecord, values.into_iter())?; + + Ok(values.map(move |(index, _)| { + ( + index, + attributes.remove(&index).expect("Attribute must exist"), + ) + })) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator, + range: Range, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| (t, attribute.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } + + #[inline] + fn evaluate_exclude<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let result = operand + .evaluate(medrecord, attributes.clone().into_iter())? + .map(|(index, _)| index) + .collect::>(); + + Ok(Box::new( + attributes + .into_iter() + .filter(move |(index, _)| !result.contains(index)), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeOperation { + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for SingleAttributeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl SingleAttributeOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attribute, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + })), + Self::Slice(range) => Ok(Some(attribute.slice(range.clone()))), + Self::IsString => Ok(match attribute { + MedRecordAttribute::String(_) => Some(attribute), + _ => None, + }), + Self::IsInt => Ok(match attribute { + MedRecordAttribute::Int(_) => Some(attribute), + _ => None, + }), + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attribute, either, or) + } + Self::Exclude { operand } => { + Ok(match operand.evaluate(medrecord, attribute.clone())? { + Some(_) => None, + None => Some(attribute), + }) + } + } + } + + #[inline] + fn evaluate_single_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => attribute > comparison_attribute, + SingleComparisonKind::GreaterThanOrEqualTo => attribute >= comparison_attribute, + SingleComparisonKind::LessThan => attribute < comparison_attribute, + SingleComparisonKind::LessThanOrEqualTo => attribute <= comparison_attribute, + SingleComparisonKind::EqualTo => attribute == comparison_attribute, + SingleComparisonKind::NotEqualTo => attribute != comparison_attribute, + SingleComparisonKind::StartsWith => attribute.starts_with(&comparison_attribute), + SingleComparisonKind::EndsWith => attribute.ends_with(&comparison_attribute), + SingleComparisonKind::Contains => attribute.contains(&comparison_attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_multiple_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + operand + .evaluate(medrecord, comparison_attributes)? + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_attributes.contains(&attribute), + MultipleComparisonKind::IsNotIn => !comparison_attributes.contains(&attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute), + BinaryArithmeticKind::Mul => attribute.mul(arithmetic_attribute), + BinaryArithmeticKind::Pow => attribute.pow(arithmetic_attribute), + BinaryArithmeticKind::Mod => attribute.r#mod(arithmetic_attribute), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, attribute.clone())?; + let or_result = or.evaluate(medrecord, attribute)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs new file mode 100644 index 00000000..f78eabb0 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs @@ -0,0 +1,61 @@ +mod operand; +mod operation; +mod selection; + +pub use operand::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, EdgeIndicesOperand, + EdgeOperand, +}; +pub use operation::EdgeOperation; +pub use selection::EdgeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs new file mode 100644 index 00000000..26583415 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs @@ -0,0 +1,711 @@ +use super::{ + operation::{EdgeIndexOperation, EdgeIndicesOperation, EdgeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + attributes::{self, AttributesTreeOperand}, + nodes::NodeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + wrapper::Wrapper, + BoxedIterator, + }, + CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, + }, + MedRecord, +}; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct EdgeOperand { + pub(crate) operations: Vec, +} + +impl DeepClone for EdgeOperand { + fn deep_clone(&self) -> Self { + Self { + operations: self + .operations + .iter() + .map(|operation| operation.deep_clone()) + .collect(), + } + } +} + +impl EdgeOperand { + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + let edge_indices = Box::new(medrecord.edge_indices()) as BoxedIterator<&'a EdgeIndex>; + + self.operations + .iter() + .try_fold(edge_indices, |edge_indices, operation| { + operation.evaluate(medrecord, edge_indices) + }) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::EdgeOperand(self.deep_clone()), + attribute, + ); + + self.operations.push(EdgeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::EdgeOperand( + self.deep_clone(), + )); + + self.operations.push(EdgeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(EdgeOperation::Indices { + operand: operand.clone(), + }); + + operand + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.operations.push(EdgeOperation::InGroup { + group: group.into(), + }); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.operations.push(EdgeOperation::HasAttribute { + attribute: attribute.into(), + }); + } + + pub fn source_node(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(EdgeOperation::SourceNode { + operand: operand.clone(), + }); + + operand + } + + pub fn target_node(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(EdgeOperation::TargetNode { + operand: operand.clone(), + }); + + operand + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + self.operations.push(EdgeOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new() -> Self { + EdgeOperand::new().into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord) + } + + pub fn attribute(&self, attribute: A) -> Wrapper + where + A: Into, + { + self.0.write_or_panic().attribute(attribute.into()) + } + + pub fn attributes(&self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&self) -> Wrapper { + self.0.write_or_panic().index() + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.0.write_or_panic().in_group(group); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.0.write_or_panic().has_attribute(attribute); + } + + pub fn source_node(&self) -> Wrapper { + self.0.write_or_panic().source_node() + } + + pub fn target_node(&self) -> Wrapper { + self.0.write_or_panic().target_node() + } + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} + +macro_rules! implement_index_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(EdgeIndicesOperation::EdgeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_index_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, index: V) { + self.operations + .push($operation::EdgeIndexComparisonOperation { + operand: index.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, index: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: index.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $index_type:ty) => { + pub fn $name(&self, index: $index_type) { + self.0.write_or_panic().$name(index) + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexComparisonOperand { + Operand(EdgeIndexOperand), + Index(EdgeIndex), +} + +impl DeepClone for EdgeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(index) => Self::Index(*index), + } + } +} + +impl From> for EdgeIndexComparisonOperand { + fn from(index: Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndexComparisonOperand { + fn from(index: &Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) + } +} + +impl> From for EdgeIndexComparisonOperand { + fn from(index: V) -> Self { + Self::Index(index.into()) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesComparisonOperand { + Operand(EdgeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for EdgeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(indices) => Self::Indices(indices.clone()), + } + } +} + +impl From> for EdgeIndicesComparisonOperand { + fn from(indices: Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndicesComparisonOperand { + fn from(indices: &Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) + } +} + +impl> From> for EdgeIndicesComparisonOperand { + fn from(indices: Vec) -> Self { + Self::Indices(indices.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for EdgeIndicesComparisonOperand { + fn from(indices: [V; N]) -> Self { + indices.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndicesOperand { + pub(crate) context: EdgeOperand, + operations: Vec, +} + +impl DeepClone for EdgeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndicesOperand { + pub(crate) fn new(context: EdgeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let indices = Box::new(indices) as BoxedIterator; + + self.operations + .iter() + .try_fold(indices, |index_tuples, operation| { + operation.evaluate(medrecord, index_tuples) + }) + } + + implement_index_operation!(max, Max); + implement_index_operation!(min, Min); + implement_index_operation!(count, Count); + implement_index_operation!(sum, Sum); + implement_index_operation!(first, First); + implement_index_operation!(last, Last); + + implement_single_index_comparison_operation!(greater_than, EdgeIndicesOperation, GreaterThan); + implement_single_index_comparison_operation!( + greater_than_or_equal_to, + EdgeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_index_comparison_operation!(less_than, EdgeIndicesOperation, LessThan); + implement_single_index_comparison_operation!( + less_than_or_equal_to, + EdgeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_index_comparison_operation!(equal_to, EdgeIndicesOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, EdgeIndicesOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, EdgeIndicesOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, EdgeIndicesOperation, EndsWith); + implement_single_index_comparison_operation!(contains, EdgeIndicesOperation, Contains); + + pub fn is_in>(&mut self, indices: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, indices: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndicesOperation, Mod); + + implement_assertion_operation!(is_max, EdgeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, EdgeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(self.context.clone()); + + query(&mut operand); + + self.operations + .push(EdgeIndicesOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeOperand) -> Self { + EdgeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, indices) + } + + implement_wrapper_operand_with_return!(max, EdgeIndexOperand); + implement_wrapper_operand_with_return!(min, EdgeIndexOperand); + implement_wrapper_operand_with_return!(count, EdgeIndexOperand); + implement_wrapper_operand_with_return!(sum, EdgeIndexOperand); + implement_wrapper_operand_with_return!(first, EdgeIndexOperand); + implement_wrapper_operand_with_return!(last, EdgeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndexOperand { + pub(crate) context: EdgeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for EdgeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndexOperand { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: EdgeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(index), |index, operation| { + if let Some(index) = index { + operation.evaluate(medrecord, index) + } else { + Ok(None) + } + }) + } + + implement_single_index_comparison_operation!(greater_than, EdgeIndexOperation, GreaterThan); + implement_single_index_comparison_operation!( + greater_than_or_equal_to, + EdgeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_index_comparison_operation!(less_than, EdgeIndexOperation, LessThan); + implement_single_index_comparison_operation!( + less_than_or_equal_to, + EdgeIndexOperation, + LessThanOrEqualTo + ); + implement_single_index_comparison_operation!(equal_to, EdgeIndexOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, EdgeIndexOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, EdgeIndexOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, EdgeIndexOperation, EndsWith); + implement_single_index_comparison_operation!(contains, EdgeIndexOperation, Contains); + + pub fn is_in>(&mut self, indices: V) { + self.operations + .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, indices: V) { + self.operations + .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndexOperation, Mod); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(self.context.clone(), self.kind.clone()); + + query(&mut operand); + + self.operations + .push(EdgeIndexOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + EdgeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: EdgeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, index) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs new file mode 100644 index 00000000..c6437f2f --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs @@ -0,0 +1,807 @@ +use super::{ + operand::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, + EdgeIndicesOperand, + }, + BinaryArithmeticKind, EdgeOperand, MultipleComparisonKind, SingleComparisonKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{Contains, EndsWith, Mod, StartsWith}, + querying::{ + attributes::AttributesTreeOperand, + edges::SingleKind, + nodes::NodeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + wrapper::Wrapper, + BoxedIterator, + }, + CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, MedRecordValue, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + collections::HashSet, + ops::{Add, Mul, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeOperation { + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, + }, + + InGroup { + group: CardinalityWrapper, + }, + HasAttribute { + attribute: CardinalityWrapper, + }, + + SourceNode { + operand: Wrapper, + }, + TargetNode { + operand: Wrapper, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for EdgeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { + operand: operand.deep_clone(), + }, + Self::InGroup { group } => Self::InGroup { + group: group.clone(), + }, + Self::HasAttribute { attribute } => Self::HasAttribute { + attribute: attribute.clone(), + }, + Self::SourceNode { operand } => Self::SourceNode { + operand: operand.deep_clone(), + }, + Self::TargetNode { operand } => Self::TargetNode { + operand: operand.deep_clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl EdgeOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + ) -> MedRecordResult> { + Ok(match self { + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::InGroup { group } => Box::new(Self::evaluate_in_group( + medrecord, + edge_indices, + group.clone(), + )), + Self::HasAttribute { attribute } => Box::new(Self::evaluate_has_attribute( + medrecord, + edge_indices, + attribute.clone(), + )), + Self::SourceNode { operand } => Box::new(Self::evaluate_source_node( + medrecord, + edge_indices, + operand, + )?), + Self::TargetNode { operand } => Box::new(Self::evaluate_target_node( + medrecord, + edge_indices, + operand, + )?), + Self::EitherOr { either, or } => { + // TODO: This is a temporary solution. It should be optimized. + let either_result = either.evaluate(medrecord)?.collect::>(); + let or_result = or.evaluate(medrecord)?.collect::>(); + + Box::new(edge_indices.filter(move |node_index| { + either_result.contains(node_index) || or_result.contains(node_index) + })) + } + Self::Exclude { operand } => { + let result = operand.evaluate(medrecord)?.collect::>(); + + Box::new(edge_indices.filter(move |node_index| !result.contains(node_index))) + } + }) + } + + #[inline] + pub(crate) fn get_values<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + attribute: MedRecordAttribute, + ) -> impl Iterator { + edge_indices.flat_map(move |edge_index| { + Some(( + edge_index, + medrecord + .edge_attributes(edge_index) + .expect("Edge must exist") + .get(&attribute)? + .clone(), + )) + }) + } + + #[inline] + fn evaluate_values<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let values = Self::get_values( + medrecord, + edge_indices, + operand.0.read_or_panic().attribute.clone(), + ); + + Ok(operand.evaluate(medrecord, values)?.map(|value| value.0)) + } + + #[inline] + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + ) -> impl Iterator)> { + edge_indices.map(move |edge_index| { + let attributes = medrecord + .edge_attributes(edge_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (edge_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, edge_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let edge_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, edge_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(edge_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + group: CardinalityWrapper, + ) -> impl Iterator { + edge_indices.filter(move |edge_index| { + let groups_of_edge = medrecord + .groups_of_edge(edge_index) + .expect("Node must exist"); + + let groups_of_edge = groups_of_edge.collect::>(); + + match &group { + CardinalityWrapper::Single(group) => groups_of_edge.contains(&group), + CardinalityWrapper::Multiple(groups) => { + groups.iter().all(|group| groups_of_edge.contains(&group)) + } + } + }) + } + + #[inline] + fn evaluate_has_attribute<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + attribute: CardinalityWrapper, + ) -> impl Iterator { + edge_indices.filter(move |edge_index| { + let attributes_of_edge = medrecord + .edge_attributes(edge_index) + .expect("Node must exist") + .keys(); + + let attributes_of_edge = attributes_of_edge.collect::>(); + + match &attribute { + CardinalityWrapper::Single(attribute) => attributes_of_edge.contains(&attribute), + CardinalityWrapper::Multiple(attributes) => attributes + .iter() + .all(|attribute| attributes_of_edge.contains(&attribute)), + } + }) + } + + #[inline] + fn evaluate_source_node<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let node_indices = operand.evaluate(medrecord)?.collect::>(); + + Ok(edge_indices.filter(move |edge_index| { + let edge_endpoints = medrecord + .edge_endpoints(edge_index) + .expect("Edge must exist"); + + node_indices.contains(edge_endpoints.0) + })) + } + + #[inline] + fn evaluate_target_node<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let node_indices = operand.evaluate(medrecord)?.collect::>(); + + Ok(edge_indices.filter(move |edge_index| { + let edge_endpoints = medrecord + .edge_endpoints(edge_index) + .expect("Edge must exist"); + + node_indices.contains(edge_endpoints.1) + })) + } +} + +macro_rules! get_edge_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => EdgeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => EdgeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => EdgeIndicesOperation::get_count($indices), + SingleKind::Sum => EdgeIndicesOperation::get_sum($indices), + SingleKind::First => EdgeIndicesOperation::get_first($indices)?, + SingleKind::Last => EdgeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_edge_index_comparison_operand_index { + ($operand:ident, $medrecord:ident) => { + match $operand { + EdgeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. + let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_edge_index!(kind, comparison_indices); + + operand.evaluate($medrecord, comparison_index)?.ok_or( + MedRecordError::QueryError("No index to compare".to_string()), + )? + } + EdgeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesOperation { + EdgeIndexOperation { + operand: Wrapper, + }, + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for EdgeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexOperation { operand } => Self::EdgeIndexOperation { + operand: operand.deep_clone(), + }, + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl EdgeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexOperation { operand } => { + Self::evaluate_edge_index_operation(medrecord, indices, operand) + } + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + Self::Exclude { operand } => { + let edge_indices = indices.collect::>(); + + let result = operand + .evaluate(medrecord, edge_indices.clone().into_iter())? + .collect::>(); + + Ok(Box::new( + edge_indices + .into_iter() + .filter(move |index| !result.contains(index)), + )) + } + } + } + + #[inline] + pub(crate) fn get_max(indices: impl Iterator) -> MedRecordResult { + indices.max().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + + #[inline] + pub(crate) fn get_min(indices: impl Iterator) -> MedRecordResult { + indices.min().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> EdgeIndex { + indices.count() as EdgeIndex + } + + #[inline] + pub(crate) fn get_sum(indices: impl Iterator) -> EdgeIndex { + indices.sum() + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_edge_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_edge_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_edge_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_edge_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + let comparison_indices = context.evaluate(medrecord)?.cloned(); + + operand + .evaluate(medrecord, comparison_indices)? + .collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(indices + .map(move |index| match kind { + BinaryArithmeticKind::Add => Ok(index.add(arithmetic_index)), + BinaryArithmeticKind::Sub => Ok(index.sub(arithmetic_index)), + BinaryArithmeticKind::Mul => Ok(index.mul(arithmetic_index)), + BinaryArithmeticKind::Pow => Ok(index.pow(arithmetic_index)), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index), + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexOperation { + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for EdgeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl EdgeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: EdgeIndex, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, index, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indcies_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + Self::Exclude { operand } => { + let result = operand.evaluate(medrecord, index)?.is_some(); + + Ok(if result { None } else { Some(index) }) + } + } + } + + #[inline] + fn evaluate_edge_index_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_edge_indcies_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + let comparison_indices = context.evaluate(medrecord)?.cloned(); + + operand + .evaluate(medrecord, comparison_indices)? + .collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: EdgeIndex, + operand: &EdgeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index), + BinaryArithmeticKind::Mul => index.mul(arithmetic_index), + BinaryArithmeticKind::Pow => index.pow(arithmetic_index), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: EdgeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index)?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/selection.rs b/crates/medmodels-core/src/medrecord/querying/edges/selection.rs new file mode 100644 index 00000000..a0d0a519 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/selection.rs @@ -0,0 +1,32 @@ +use super::EdgeOperand; +use crate::{ + errors::MedRecordResult, + medrecord::{querying::wrapper::Wrapper, EdgeIndex, MedRecord}, +}; + +#[derive(Debug, Clone)] +pub struct EdgeSelection<'a> { + medrecord: &'a MedRecord, + operand: Wrapper, +} + +impl<'a> EdgeSelection<'a> { + pub fn new(medrecord: &'a MedRecord, query: Q) -> Self + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + Self { medrecord, operand } + } + + pub fn iter(&'a self) -> MedRecordResult> { + self.operand.evaluate(self.medrecord) + } + + pub fn collect>(&'a self) -> MedRecordResult { + Ok(FromIterator::from_iter(self.iter()?)) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/mod.rs b/crates/medmodels-core/src/medrecord/querying/mod.rs index 1f999f78..0096f87e 100644 --- a/crates/medmodels-core/src/medrecord/querying/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/mod.rs @@ -1,9 +1,8 @@ -mod operation; -mod selection; +pub mod attributes; +pub mod edges; +pub mod nodes; +pub mod traits; +pub mod values; +pub mod wrapper; -pub use self::operation::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation, - TransformationOperation, ValueOperand, -}; -pub(super) use self::selection::{EdgeSelection, NodeSelection}; +pub(crate) type BoxedIterator<'a, T> = Box + 'a>; diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs new file mode 100644 index 00000000..4714ccd4 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs @@ -0,0 +1,71 @@ +mod operand; +mod operation; +mod selection; + +pub use operand::{ + NodeIndexComparisonOperand, NodeIndexOperand, NodeIndicesComparisonOperand, NodeIndicesOperand, + NodeOperand, +}; +pub use operation::{EdgeDirection, NodeOperation}; +pub use selection::NodeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs new file mode 100644 index 00000000..329dc8f7 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs @@ -0,0 +1,791 @@ +use super::{ + operation::{EdgeDirection, NodeIndexOperation, NodeIndicesOperation, NodeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + attributes::{self, AttributesTreeOperand}, + edges::EdgeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, + }, + Group, MedRecordAttribute, NodeIndex, + }, + MedRecord, +}; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct NodeOperand { + operations: Vec, +} + +impl DeepClone for NodeOperand { + fn deep_clone(&self) -> Self { + Self { + operations: self + .operations + .iter() + .map(|operation| operation.deep_clone()) + .collect(), + } + } +} + +impl NodeOperand { + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + let node_indices = Box::new(medrecord.node_indices()) as BoxedIterator<'a, &'a NodeIndex>; + + self.operations + .iter() + .try_fold(node_indices, |node_indices, operation| { + operation.evaluate(medrecord, node_indices) + }) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::NodeOperand(self.deep_clone()), + attribute, + ); + + self.operations.push(NodeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::NodeOperand( + self.deep_clone(), + )); + + self.operations.push(NodeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(NodeOperation::Indices { + operand: operand.clone(), + }); + + operand + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.operations.push(NodeOperation::InGroup { + group: group.into(), + }); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.operations.push(NodeOperation::HasAttribute { + attribute: attribute.into(), + }); + } + + pub fn outgoing_edges(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::OutgoingEdges { + operand: operand.clone(), + }); + + operand + } + + pub fn incoming_edges(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::IncomingEdges { + operand: operand.clone(), + }); + + operand + } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::Neighbors { + operand: operand.clone(), + direction, + }); + + operand + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + self.operations.push(NodeOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new() -> Self { + NodeOperand::new().into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord) + } + + pub fn attribute(&mut self, attribute: A) -> Wrapper + where + A: Into, + { + self.0.write_or_panic().attribute(attribute.into()) + } + + pub fn attributes(&mut self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&mut self) -> Wrapper { + self.0.write_or_panic().index() + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.0.write_or_panic().in_group(group); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.0.write_or_panic().has_attribute(attribute); + } + + pub fn outgoing_edges(&mut self) -> Wrapper { + self.0.write_or_panic().outgoing_edges() + } + + pub fn incoming_edges(&mut self) -> Wrapper { + self.0.write_or_panic().incoming_edges() + } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + self.0.write_or_panic().neighbors(direction) + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} + +macro_rules! implement_index_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(NodeIndicesOperation::NodeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_index_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, index: V) { + self.operations + .push($operation::NodeIndexComparisonOperation { + operand: index.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, index: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: index.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $index_type:ty) => { + pub fn $name(&self, index: $index_type) { + self.0.write_or_panic().$name(index) + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndexComparisonOperand { + Operand(NodeIndexOperand), + Index(NodeIndex), +} + +impl DeepClone for NodeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(index) => Self::Index(index.clone()), + } + } +} + +impl From> for NodeIndexComparisonOperand { + fn from(index: Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndexComparisonOperand { + fn from(index: &Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) + } +} + +impl> From for NodeIndexComparisonOperand { + fn from(index: V) -> Self { + Self::Index(index.into()) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesComparisonOperand { + Operand(NodeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for NodeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(indices) => Self::Indices(indices.clone()), + } + } +} + +impl From> for NodeIndicesComparisonOperand { + fn from(indices: Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndicesComparisonOperand { + fn from(indices: &Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) + } +} + +impl> From> for NodeIndicesComparisonOperand { + fn from(indices: Vec) -> Self { + Self::Indices(indices.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for NodeIndicesComparisonOperand { + fn from(indices: [V; N]) -> Self { + indices.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndicesOperand { + pub(crate) context: NodeOperand, + operations: Vec, +} + +impl DeepClone for NodeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndicesOperand { + pub(crate) fn new(context: NodeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let indices = Box::new(indices) as BoxedIterator; + + self.operations + .iter() + .try_fold(indices, |index_tuples, operation| { + operation.evaluate(medrecord, index_tuples) + }) + } + + implement_index_operation!(max, Max); + implement_index_operation!(min, Min); + implement_index_operation!(count, Count); + implement_index_operation!(sum, Sum); + implement_index_operation!(first, First); + implement_index_operation!(last, Last); + + implement_single_index_comparison_operation!(greater_than, NodeIndicesOperation, GreaterThan); + implement_single_index_comparison_operation!( + greater_than_or_equal_to, + NodeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_index_comparison_operation!(less_than, NodeIndicesOperation, LessThan); + implement_single_index_comparison_operation!( + less_than_or_equal_to, + NodeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_index_comparison_operation!(equal_to, NodeIndicesOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, NodeIndicesOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, NodeIndicesOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, NodeIndicesOperation, EndsWith); + implement_single_index_comparison_operation!(contains, NodeIndicesOperation, Contains); + + pub fn is_in>(&mut self, indices: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, indices: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndicesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndicesOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndicesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndicesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndicesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndicesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndicesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(NodeIndicesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndicesOperation::IsString); + implement_assertion_operation!(is_int, NodeIndicesOperation::IsInt); + implement_assertion_operation!(is_max, NodeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, NodeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(self.context.clone()); + + query(&mut operand); + + self.operations + .push(NodeIndicesOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeOperand) -> Self { + NodeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, indices) + } + + implement_wrapper_operand_with_return!(max, NodeIndexOperand); + implement_wrapper_operand_with_return!(min, NodeIndexOperand); + implement_wrapper_operand_with_return!(count, NodeIndexOperand); + implement_wrapper_operand_with_return!(sum, NodeIndexOperand); + implement_wrapper_operand_with_return!(first, NodeIndexOperand); + implement_wrapper_operand_with_return!(last, NodeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndexOperand { + pub(crate) context: NodeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for NodeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndexOperand { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: NodeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(index), |index, operation| { + if let Some(index) = index { + operation.evaluate(medrecord, index) + } else { + Ok(None) + } + }) + } + + implement_single_index_comparison_operation!(greater_than, NodeIndexOperation, GreaterThan); + implement_single_index_comparison_operation!( + greater_than_or_equal_to, + NodeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_index_comparison_operation!(less_than, NodeIndexOperation, LessThan); + implement_single_index_comparison_operation!( + less_than_or_equal_to, + NodeIndexOperation, + LessThanOrEqualTo + ); + implement_single_index_comparison_operation!(equal_to, NodeIndexOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, NodeIndexOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, NodeIndexOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, NodeIndexOperation, EndsWith); + implement_single_index_comparison_operation!(contains, NodeIndexOperation, Contains); + + pub fn is_in>(&mut self, indices: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, indices: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: indices.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndexOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndexOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndexOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndexOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndexOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndexOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndexOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations.push(NodeIndexOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndexOperation::IsString); + implement_assertion_operation!(is_int, NodeIndexOperation::IsInt); + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(self.context.clone(), self.kind.clone()); + + query(&mut operand); + + self.operations + .push(NodeIndexOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + NodeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: NodeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, index) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs new file mode 100644 index 00000000..7db9544b --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs @@ -0,0 +1,1024 @@ +use super::{ + operand::{ + NodeIndexComparisonOperand, NodeIndexOperand, NodeIndicesComparisonOperand, + NodeIndicesOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, NodeOperand, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::AttributesTreeOperand, + edges::EdgeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, + }, + DataType, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, + }, +}; +use itertools::Itertools; +use roaring::RoaringBitmap; +use std::{ + cmp::Ordering, + collections::HashSet, + ops::{Add, Mul, Range, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeDirection { + Incoming, + Outgoing, + Both, +} + +#[derive(Debug, Clone)] +pub enum NodeOperation { + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, + }, + + InGroup { + group: CardinalityWrapper, + }, + HasAttribute { + attribute: CardinalityWrapper, + }, + + OutgoingEdges { + operand: Wrapper, + }, + IncomingEdges { + operand: Wrapper, + }, + + Neighbors { + operand: Wrapper, + direction: EdgeDirection, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for NodeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { + operand: operand.deep_clone(), + }, + Self::InGroup { group } => Self::InGroup { + group: group.clone(), + }, + Self::HasAttribute { attribute } => Self::HasAttribute { + attribute: attribute.clone(), + }, + Self::OutgoingEdges { operand } => Self::OutgoingEdges { + operand: operand.deep_clone(), + }, + Self::IncomingEdges { operand } => Self::IncomingEdges { + operand: operand.deep_clone(), + }, + Self::Neighbors { + operand, + direction: drection, + } => Self::Neighbors { + operand: operand.deep_clone(), + direction: drection.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl NodeOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + ) -> MedRecordResult> { + Ok(match self { + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( + medrecord, + node_indices, + operand.clone(), + )?), + Self::InGroup { group } => Box::new(Self::evaluate_in_group( + medrecord, + node_indices, + group.clone(), + )), + Self::HasAttribute { attribute } => Box::new(Self::evaluate_has_attribute( + medrecord, + node_indices, + attribute.clone(), + )), + Self::OutgoingEdges { operand } => Box::new(Self::evaluate_outgoing_edges( + medrecord, + node_indices, + operand.clone(), + )?), + Self::IncomingEdges { operand } => Box::new(Self::evaluate_incoming_edges( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Neighbors { + operand, + direction: drection, + } => Box::new(Self::evaluate_neighbors( + medrecord, + node_indices, + operand.clone(), + drection.clone(), + )?), + Self::EitherOr { either, or } => { + // TODO: This is a temporary solution. It should be optimized. + let either_result = either.evaluate(medrecord)?.collect::>(); + let or_result = or.evaluate(medrecord)?.collect::>(); + + Box::new(node_indices.filter(move |node_index| { + either_result.contains(node_index) || or_result.contains(node_index) + })) + } + Self::Exclude { operand } => { + let result = operand.evaluate(medrecord)?.collect::>(); + + Box::new(node_indices.filter(move |node_index| !result.contains(node_index))) + } + }) + } + + #[inline] + pub(crate) fn get_values<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + attribute: MedRecordAttribute, + ) -> impl Iterator { + node_indices.flat_map(move |node_index| { + Some(( + node_index, + medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .get(&attribute)? + .clone(), + )) + }) + } + + #[inline] + fn evaluate_values<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let values = Self::get_values( + medrecord, + node_indices, + operand.0.read_or_panic().attribute.clone(), + ); + + Ok(operand.evaluate(medrecord, values)?.map(|value| value.0)) + } + + #[inline] + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + ) -> impl Iterator)> { + node_indices.map(move |node_index| { + let attributes = medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (node_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, node_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let node_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, node_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(node_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + group: CardinalityWrapper, + ) -> impl Iterator { + node_indices.filter(move |node_index| { + let groups_of_node = medrecord + .groups_of_node(node_index) + .expect("Node must exist"); + + let groups_of_node = groups_of_node.collect::>(); + + match &group { + CardinalityWrapper::Single(group) => groups_of_node.contains(&group), + CardinalityWrapper::Multiple(groups) => { + groups.iter().all(|group| groups_of_node.contains(&group)) + } + } + }) + } + + #[inline] + fn evaluate_has_attribute<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + attribute: CardinalityWrapper, + ) -> impl Iterator { + node_indices.filter(move |node_index| { + let attributes_of_node = medrecord + .node_attributes(node_index) + .expect("Node must exist") + .keys(); + + let attributes_of_node = attributes_of_node.collect::>(); + + match &attribute { + CardinalityWrapper::Single(attribute) => attributes_of_node.contains(&attribute), + CardinalityWrapper::Multiple(attributes) => attributes + .iter() + .all(|attribute| attributes_of_node.contains(&attribute)), + } + }) + } + + #[inline] + fn evaluate_outgoing_edges<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + let edge_indices = operand.evaluate(medrecord)?.collect::(); + + Ok(node_indices.filter(move |node_index| { + let outgoing_edge_indices = medrecord + .outgoing_edges(node_index) + .expect("Node must exist"); + + let outgoing_edge_indices = outgoing_edge_indices.collect::(); + + !outgoing_edge_indices.is_disjoint(&edge_indices) + })) + } + + #[inline] + fn evaluate_incoming_edges<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + let edge_indices = operand.evaluate(medrecord)?.collect::(); + + Ok(node_indices.filter(move |node_index| { + let incoming_edge_indices = medrecord + .incoming_edges(node_index) + .expect("Node must exist"); + + let incoming_edge_indices = incoming_edge_indices.collect::(); + + !incoming_edge_indices.is_disjoint(&edge_indices) + })) + } + + #[inline] + fn evaluate_neighbors<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + direction: EdgeDirection, + ) -> MedRecordResult> { + let result = operand.evaluate(medrecord)?.collect::>(); + + Ok(node_indices.filter(move |node_index| { + let mut neighbors: Box> = match direction { + EdgeDirection::Incoming => Box::new( + medrecord + .neighbors_incoming(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Outgoing => Box::new( + medrecord + .neighbors_outgoing(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Both => Box::new( + medrecord + .neighbors_undirected(node_index) + .expect("Node must exist"), + ), + }; + + neighbors.any(|neighbor| result.contains(&neighbor)) + })) + } +} + +macro_rules! get_node_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => NodeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => NodeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => NodeIndicesOperation::get_count($indices), + SingleKind::Sum => NodeIndicesOperation::get_sum($indices)?, + SingleKind::First => NodeIndicesOperation::get_first($indices)?, + SingleKind::Last => NodeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_node_index_comparison_operand { + ($operand:ident, $medrecord:ident) => { + match $operand { + NodeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. + let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_node_index!(kind, comparison_indices); + + operand.evaluate($medrecord, comparison_index)?.ok_or( + MedRecordError::QueryError("No index to compare".to_string()), + )? + } + NodeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesOperation { + NodeIndexOperation { + operand: Wrapper, + }, + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for NodeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexOperation { operand } => Self::NodeIndexOperation { + operand: operand.deep_clone(), + }, + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl NodeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::NodeIndexOperation { operand } => { + Self::evaluate_node_index_operation(medrecord, indices, operand) + } + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(indices, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(indices, range.clone()))), + Self::IsString => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + Self::Exclude { operand } => { + let node_indices = indices.collect::>(); + + let result = operand + .evaluate(medrecord, node_indices.clone().into_iter())? + .collect::>(); + + Ok(Box::new( + node_indices + .into_iter() + .filter(move |index| !result.contains(index)), + )) + } + } + } + + #[inline] + pub(crate) fn get_max( + mut indices: impl Iterator, + ) -> MedRecordResult { + let max_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(max_index, |max_index, index| { + match index + .partial_cmp(&max_index) { + Some(Ordering::Greater) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(max_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_index), + } + }) + } + + #[inline] + pub(crate) fn get_min( + mut indices: impl Iterator, + ) -> MedRecordResult { + let min_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(min_index, |min_index, index| { + match index.partial_cmp(&min_index) { + Some(Ordering::Less) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(min_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_index), + } + }) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> NodeIndex { + MedRecordAttribute::Int(indices.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum( + mut indices: impl Iterator, + ) -> MedRecordResult { + let first_index = indices + .next() + .ok_or(MedRecordError::QueryError("No indices to sum".to_string()))?; + + indices.try_fold(first_index, |sum, index| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&index); + + sum.add(index).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_node_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_node_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_node_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_node_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + let comparison_indices = context.evaluate(medrecord)?.cloned(); + + operand + .evaluate(medrecord, comparison_indices)? + .collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + let indices = indices + .map(move |index| { + match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index.clone()), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index.clone()), + BinaryArithmeticKind::Mul => { + index.clone().mul(arithmetic_index.clone()) + } + BinaryArithmeticKind::Pow => { + index.clone().pow(arithmetic_index.clone()) + } + BinaryArithmeticKind::Mod => { + index.clone().r#mod(arithmetic_index.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the indices using .is_string() or .is_int()", + kind, + )) + }) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(indices.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation( + indices: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + indices.map(move |index| match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + }) + } + + #[inline] + fn evaluate_slice( + indices: impl Iterator, + range: Range, + ) -> impl Iterator { + indices.map(move |index| index.slice(range.clone())) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndexOperation { + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for NodeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl NodeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: NodeIndex, + ) -> MedRecordResult> { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, index, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + })), + Self::Slice(range) => Ok(Some(index.slice(range.clone()))), + Self::IsString => Ok(match index { + MedRecordAttribute::String(_) => Some(index), + _ => None, + }), + Self::IsInt => Ok(match index { + MedRecordAttribute::Int(_) => Some(index), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + Self::Exclude { operand } => { + let result = operand.evaluate(medrecord, index.clone())?.is_some(); + + Ok(if result { None } else { Some(index) }) + } + } + } + + #[inline] + fn evaluate_node_index_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_node_indices_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + let comparison_indices = context.evaluate(medrecord)?.cloned(); + + operand + .evaluate(medrecord, comparison_indices)? + .collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: NodeIndex, + operand: &NodeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index)?, + BinaryArithmeticKind::Sub => index.sub(arithmetic_index)?, + BinaryArithmeticKind::Mul => index.mul(arithmetic_index)?, + BinaryArithmeticKind::Pow => index.pow(arithmetic_index)?, + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: NodeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index.clone())?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs b/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs new file mode 100644 index 00000000..d994543d --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs @@ -0,0 +1,35 @@ +use super::NodeOperand; +use crate::{ + errors::MedRecordResult, + medrecord::{querying::wrapper::Wrapper, MedRecord, NodeIndex}, +}; + +#[derive(Debug, Clone)] +pub struct NodeSelection<'a> { + medrecord: &'a MedRecord, + operand: Wrapper, +} + +impl<'a> NodeSelection<'a> { + pub fn new(medrecord: &'a MedRecord, query: Q) -> Self + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + Self { medrecord, operand } + } + + pub fn iter(self) -> MedRecordResult> { + self.operand.evaluate(self.medrecord) + } + + pub fn collect(self) -> MedRecordResult + where + B: FromIterator<&'a NodeIndex>, + { + Ok(FromIterator::from_iter(self.iter()?)) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs b/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs deleted file mode 100644 index f005c53f..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs +++ /dev/null @@ -1,475 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - operand::{ArithmeticOperation, EdgeIndexInOperand, IntoVecEdgeIndex, ValueOperand}, - AttributeOperation, NodeOperation, Operation, -}; -use crate::medrecord::{ - datatypes::{ - Abs, Ceil, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, Trim, TrimEnd, TrimStart, - Uppercase, - }, - EdgeIndex, MedRecord, MedRecordAttribute, -}; - -#[derive(Debug, Clone)] -pub enum EdgeIndexOperation { - Gt(EdgeIndex), - Lt(EdgeIndex), - Gte(EdgeIndex), - Lte(EdgeIndex), - Eq(EdgeIndex), - In(Box), -} - -#[derive(Debug, Clone)] -pub enum EdgeOperation { - Attribute(AttributeOperation), - Index(EdgeIndexOperation), - - ConnectedSource(MedRecordAttribute), - ConnectedTarget(MedRecordAttribute), - InGroup(MedRecordAttribute), - HasAttribute(MedRecordAttribute), - - ConnectedSourceWith(Box), - ConnectedTargetWith(Box), - - HasParallelEdgesWith(Box), - HasParallelEdgesWithSelfComparison(Box), - - And(Box<(EdgeOperation, EdgeOperation)>), - Or(Box<(EdgeOperation, EdgeOperation)>), - Not(Box), -} - -impl Operation for EdgeOperation { - type IndexType = EdgeIndex; - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a> { - match self { - EdgeOperation::Attribute(attribute_operation) => { - Self::evaluate_attribute(indices, attribute_operation, |index| { - medrecord.edge_attributes(index) - }) - } - EdgeOperation::Index(index_operation) => { - Self::evaluate_index(medrecord, indices, index_operation) - } - - EdgeOperation::ConnectedSource(attribute_operand) => Box::new( - Self::evaluate_connected_target(medrecord, indices, attribute_operand), - ), - EdgeOperation::ConnectedTarget(attribute_operand) => Box::new( - Self::evaluate_connected_source(medrecord, indices, attribute_operand), - ), - EdgeOperation::InGroup(attribute_operand) => Box::new(Self::evaluate_in_group( - medrecord, - indices, - attribute_operand, - )), - EdgeOperation::HasAttribute(attribute_operand) => Box::new( - Self::evaluate_has_attribute(indices, attribute_operand, |index| { - medrecord.edge_attributes(index) - }), - ), - - EdgeOperation::ConnectedSourceWith(operation) => Box::new( - Self::evaluate_connected_source_with(medrecord, indices, *operation), - ), - EdgeOperation::ConnectedTargetWith(operation) => Box::new( - Self::evaluate_connected_target_with(medrecord, indices, *operation), - ), - - EdgeOperation::HasParallelEdgesWith(operation) => { - Self::evaluate_has_parallel_edges_with(medrecord, Box::new(indices), *operation) - } - EdgeOperation::HasParallelEdgesWithSelfComparison(operation) => { - Self::evaluate_has_parallel_edges_with_compare_to_self( - medrecord, - Box::new(indices), - *operation, - ) - } - - EdgeOperation::And(operations) => Box::new(Self::evaluate_and( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - EdgeOperation::Or(operations) => Box::new(Self::evaluate_or( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - EdgeOperation::Not(operation) => Box::new(Self::evaluate_not( - medrecord, - indices.collect::>(), - *operation, - )), - } - } -} - -impl EdgeOperation { - pub fn and(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::And(Box::new((self, operation))) - } - - pub fn or(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::Or(Box::new((self, operation))) - } - - pub fn xor(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::And(Box::new((self, operation))).not() - } - - pub fn not(self) -> EdgeOperation { - EdgeOperation::Not(Box::new(self)) - } - - fn evaluate_index<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator + 'a, - operation: EdgeIndexOperation, - ) -> Box + 'a> { - match operation { - EdgeIndexOperation::Gt(operand) => { - Box::new(Self::evaluate_index_gt(edge_indices, operand)) - } - EdgeIndexOperation::Lt(operand) => { - Box::new(Self::evaluate_index_lt(edge_indices, operand)) - } - EdgeIndexOperation::Gte(operand) => { - Box::new(Self::evaluate_index_gte(edge_indices, operand)) - } - EdgeIndexOperation::Lte(operand) => { - Box::new(Self::evaluate_index_lte(edge_indices, operand)) - } - EdgeIndexOperation::Eq(operand) => { - Box::new(Self::evaluate_index_eq(edge_indices, operand)) - } - EdgeIndexOperation::In(operands) => Box::new(Self::evaluate_index_in( - edge_indices, - operands.into_vec_edge_index(medrecord), - )), - } - } - - fn evaluate_connected_target<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - *endpoints.1 == attribute_operand - }) - } - - fn evaluate_connected_source<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - *endpoints.0 == attribute_operand - }) - } - - fn evaluate_in_group<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - let edges_in_group = match medrecord.edges_in_group(&attribute_operand) { - Ok(edges_in_group) => edges_in_group.collect::>(), - Err(_) => Vec::new(), - }; - - edge_indices.filter(move |index| edges_in_group.contains(index)) - } - - fn evaluate_connected_target_with<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - operation - .clone() - .evaluate(medrecord, vec![endpoints.1].into_iter()) - .count() - > 0 - }) - } - - fn evaluate_connected_source_with<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - operation - .clone() - .evaluate(medrecord, vec![endpoints.0].into_iter()) - .count() - > 0 - }) - } - - fn evaluate_has_parallel_edges_with<'a>( - medrecord: &'a MedRecord, - edge_indices: Box + 'a>, - operation: EdgeOperation, - ) -> Box + 'a> { - Box::new(edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - let edges = medrecord - .edges_connecting(vec![endpoints.0], vec![endpoints.1]) - .filter(|other_index| other_index != index); - - operation.clone().evaluate(medrecord, edges).count() > 0 - })) - } - - fn convert_value_operand<'a>( - medrecord: &'a MedRecord, - index: &'a EdgeIndex, - value_operand: ValueOperand, - ) -> Option { - match value_operand { - ValueOperand::Value(value) => Some(ValueOperand::Value(value)), - ValueOperand::Evaluate(attribute) => Some(ValueOperand::Value( - medrecord - .edge_attributes(index) - .ok()? - .get(&attribute)? - .clone(), - )), - ValueOperand::ArithmeticOperation(operation, attribute, other_value) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - let result = match operation { - ArithmeticOperation::Addition => value.clone() + other_value, - ArithmeticOperation::Subtraction => value.clone() - other_value, - ArithmeticOperation::Multiplication => value.clone() * other_value, - ArithmeticOperation::Division => value.clone() / other_value, - ArithmeticOperation::Power => value.clone().pow(other_value), - ArithmeticOperation::Modulo => value.clone().r#mod(other_value), - } - .ok()?; - - Some(ValueOperand::Value(result)) - } - ValueOperand::Slice(attribute, range) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - Some(ValueOperand::Value(value.clone().slice(range))) - } - ValueOperand::TransformationOperation(operation, attribute) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - let result = match operation { - super::operand::TransformationOperation::Round => value.clone().round(), - super::operand::TransformationOperation::Ceil => value.clone().ceil(), - super::operand::TransformationOperation::Floor => value.clone().floor(), - super::operand::TransformationOperation::Abs => value.clone().abs(), - super::operand::TransformationOperation::Sqrt => value.clone().sqrt(), - super::operand::TransformationOperation::Trim => value.clone().trim(), - super::operand::TransformationOperation::TrimStart => { - value.clone().trim_start() - } - super::operand::TransformationOperation::TrimEnd => value.clone().trim_end(), - super::operand::TransformationOperation::Lowercase => value.clone().lowercase(), - super::operand::TransformationOperation::Uppercase => value.clone().uppercase(), - }; - - Some(ValueOperand::Value(result)) - } - } - } - fn evaluate_has_parallel_edges_with_compare_to_self<'a>( - medrecord: &'a MedRecord, - edge_indices: Box + 'a>, - operation: EdgeOperation, - ) -> Box + 'a> { - Box::new(edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - let edges = medrecord - .edges_connecting(vec![endpoints.0], vec![endpoints.1]) - .filter(|other_index| other_index != index); - - let operation = operation.clone(); - - let EdgeOperation::Attribute(operation) = operation else { - return operation.evaluate(medrecord, edges).count() > 0; - }; - - match operation { - AttributeOperation::Gt(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Gt(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Lt(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Lt(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Gte(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Gte(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Lte(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Lte(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Eq(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Eq(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Neq(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Neq(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::In(attribute, value) => { - Self::evaluate_attribute( - edges, - AttributeOperation::In(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::NotIn(attribute, value) => { - Self::evaluate_attribute( - edges, - AttributeOperation::In(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::StartsWith(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::StartsWith(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::EndsWith(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::EndsWith(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Contains(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Contains(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - } - })) - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/mod.rs b/crates/medmodels-core/src/medrecord/querying/operation/mod.rs deleted file mode 100644 index 174adeda..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/mod.rs +++ /dev/null @@ -1,394 +0,0 @@ -mod edge_operation; -mod node_operation; -mod operand; - -pub use self::{ - edge_operation::EdgeOperation, - node_operation::NodeOperation, - operand::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - NodeAttributeOperand, NodeIndexOperand, NodeOperand, TransformationOperation, ValueOperand, - }, -}; -use crate::{ - errors::MedRecordError, - medrecord::{ - datatypes::{ - Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, PartialNeq, Pow, Round, Slice, - Sqrt, StartsWith, Trim, TrimEnd, TrimStart, Uppercase, - }, - Attributes, MedRecord, MedRecordAttribute, MedRecordValue, - }, -}; - -macro_rules! implement_attribute_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operand: ValueOperand, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - match &value_operand { - ValueOperand::Value(value_operand) => value.$evaluate(value_operand), - ValueOperand::Evaluate(value_attribute) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - value.$evaluate(other) - } - ValueOperand::ArithmeticOperation( - operation, - value_attribute, - value_operand, - ) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - let operation = match operation { - ArithmeticOperation::Addition => other.clone() + value_operand.clone(), - ArithmeticOperation::Subtraction => { - other.clone() - value_operand.clone() - } - ArithmeticOperation::Multiplication => { - other.clone() * value_operand.clone() - } - ArithmeticOperation::Division => other.clone() / value_operand.clone(), - ArithmeticOperation::Power => other.clone().pow(value_operand.clone()), - ArithmeticOperation::Modulo => { - other.clone().r#mod(value_operand.clone()) - } - }; - - match operation { - Ok(operation) => value.$evaluate(&operation), - Err(_) => false, - } - } - ValueOperand::TransformationOperation(operation, value_attribute) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - let operation = match operation { - TransformationOperation::Round => other.clone().round(), - TransformationOperation::Ceil => other.clone().ceil(), - TransformationOperation::Floor => other.clone().floor(), - TransformationOperation::Abs => other.clone().abs(), - TransformationOperation::Sqrt => other.clone().sqrt(), - TransformationOperation::Trim => other.clone().trim(), - TransformationOperation::TrimStart => other.clone().trim_start(), - TransformationOperation::TrimEnd => other.clone().trim_end(), - TransformationOperation::Lowercase => other.clone().lowercase(), - TransformationOperation::Uppercase => other.clone().uppercase(), - }; - - value.$evaluate(&operation) - } - ValueOperand::Slice(value_attribute, range) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - value.$evaluate(&other.clone().slice(range.clone())) - } - } - }) - } - }; -} - -macro_rules! implement_index_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a>( - indices: impl Iterator, - operand: Self::IndexType, - ) -> impl Iterator - where - Self::IndexType: 'a, - { - indices.filter(move |index| (*index).$evaluate(&operand)) - } - }; -} - -pub(super) trait Operation: Sized { - type IndexType: PartialEq + PartialNeq + PartialOrd; - - fn evaluate_and<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation1: Self, - operation2: Self, - ) -> impl Iterator { - let operation1_indices = operation1 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - let operation2_indices = operation2 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices.into_iter().filter(move |index| { - operation1_indices.contains(index) && operation2_indices.contains(index) - }) - } - - fn evaluate_or<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation1: Self, - operation2: Self, - ) -> impl Iterator { - let operation1_indices = operation1 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - let operation2_indices = operation2 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices.into_iter().filter(move |index| { - operation1_indices.contains(index) || operation2_indices.contains(index) - }) - } - - fn evaluate_not<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation: Self, - ) -> impl Iterator { - let operation_indices = operation - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices - .into_iter() - .filter(move |index| !operation_indices.contains(index)) - } - - fn evaluate_attribute_in<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operands: Vec, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - value_operands.contains(value) - }) - } - - fn evaluate_attribute_not_in<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operands: Vec, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - !value_operands.contains(value) - }) - } - - implement_attribute_evaluate!(evaluate_attribute_gt, gt); - implement_attribute_evaluate!(evaluate_attribute_lt, lt); - implement_attribute_evaluate!(evaluate_attribute_gte, ge); - implement_attribute_evaluate!(evaluate_attribute_lte, le); - implement_attribute_evaluate!(evaluate_attribute_eq, eq); - implement_attribute_evaluate!(evaluate_attribute_neq, neq); - implement_attribute_evaluate!(evaluate_attribute_starts_with, starts_with); - implement_attribute_evaluate!(evaluate_attribute_ends_with, ends_with); - implement_attribute_evaluate!(evaluate_attribute_contains, contains); - - fn evaluate_has_attribute<'a, P>( - indices: impl Iterator, - attribute_operand: MedRecordAttribute, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - attributes.contains_key(&attribute_operand) - }) - } - - fn evaluate_attribute<'a, P>( - indices: impl Iterator + 'a, - operation: AttributeOperation, - attributes_for_index_fn: P, - ) -> Box + 'a> - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError> + 'a, - Self: 'a, - { - match operation { - AttributeOperation::Gt(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_gt( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Lt(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_lt( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Gte(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_gte( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Lte(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_lte( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Eq(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_eq( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Neq(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_neq( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::In(attribute_operand, value_operands) => { - Box::new(Self::evaluate_attribute_in( - indices, - attribute_operand, - value_operands, - attributes_for_index_fn, - )) - } - AttributeOperation::NotIn(attribute_operand, value_operands) => { - Box::new(Self::evaluate_attribute_not_in( - indices, - attribute_operand, - value_operands, - attributes_for_index_fn, - )) - } - AttributeOperation::StartsWith(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_starts_with( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::EndsWith(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_ends_with( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Contains(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_contains( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - } - } - - implement_index_evaluate!(evaluate_index_gt, gt); - implement_index_evaluate!(evaluate_index_lt, lt); - implement_index_evaluate!(evaluate_index_gte, ge); - implement_index_evaluate!(evaluate_index_lte, le); - implement_index_evaluate!(evaluate_index_eq, eq); - - fn evaluate_index_in<'a>( - indices: impl Iterator, - operands: Vec, - ) -> impl Iterator - where - Self::IndexType: 'a, - { - indices.filter(move |index| operands.contains(index)) - } - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a>; -} - -#[derive(Debug, Clone)] -pub enum AttributeOperation { - Gt(MedRecordAttribute, ValueOperand), - Lt(MedRecordAttribute, ValueOperand), - Gte(MedRecordAttribute, ValueOperand), - Lte(MedRecordAttribute, ValueOperand), - Eq(MedRecordAttribute, ValueOperand), - Neq(MedRecordAttribute, ValueOperand), - In(MedRecordAttribute, Vec), - NotIn(MedRecordAttribute, Vec), - StartsWith(MedRecordAttribute, ValueOperand), - EndsWith(MedRecordAttribute, ValueOperand), - Contains(MedRecordAttribute, ValueOperand), -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs b/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs deleted file mode 100644 index 677db205..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs +++ /dev/null @@ -1,246 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - edge_operation::EdgeOperation, - operand::{IntoVecNodeIndex, NodeIndexInOperand}, - AttributeOperation, Operation, -}; -use crate::medrecord::{ - datatypes::{Contains, EndsWith, StartsWith}, - MedRecord, MedRecordAttribute, NodeIndex, -}; - -macro_rules! implement_index_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a>( - indices: impl Iterator, - operand: NodeIndex, - ) -> impl Iterator { - indices.filter(move |index| (*index).$evaluate(&operand)) - } - }; -} - -#[derive(Debug, Clone)] -pub enum NodeIndexOperation { - Gt(NodeIndex), - Lt(NodeIndex), - Gte(NodeIndex), - Lte(NodeIndex), - Eq(NodeIndex), - In(Box), - StartsWith(NodeIndex), - EndsWith(NodeIndex), - Contains(NodeIndex), -} - -#[derive(Debug, Clone)] -pub enum NodeOperation { - Attribute(AttributeOperation), - Index(NodeIndexOperation), - - InGroup(MedRecordAttribute), - HasAttribute(MedRecordAttribute), - - HasIncomingEdgeWith(Box), - HasOutgoingEdgeWith(Box), - HasNeighborWith(Box), - HasNeighborUndirectedWith(Box), - - And(Box<(NodeOperation, NodeOperation)>), - Or(Box<(NodeOperation, NodeOperation)>), - Not(Box), -} - -impl Operation for NodeOperation { - type IndexType = NodeIndex; - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a> { - match self { - NodeOperation::Attribute(attribute_operation) => { - Self::evaluate_attribute(indices, attribute_operation, |index| { - medrecord.node_attributes(index) - }) - } - NodeOperation::Index(index_operation) => { - Self::evaluate_index(medrecord, indices, index_operation) - } - - NodeOperation::InGroup(attribute_operand) => Box::new(Self::evaluate_in_group( - medrecord, - indices, - attribute_operand, - )), - NodeOperation::HasAttribute(attribute_operand) => Box::new( - Self::evaluate_has_attribute(indices, attribute_operand, |index| { - medrecord.node_attributes(index) - }), - ), - - NodeOperation::HasOutgoingEdgeWith(operation) => Box::new( - Self::evaluate_has_outgoing_edge_with(medrecord, indices, *operation), - ), - NodeOperation::HasIncomingEdgeWith(operation) => Box::new( - Self::evaluate_has_incoming_edge_with(medrecord, indices, *operation), - ), - NodeOperation::HasNeighborWith(operation) => Box::new( - Self::evaluate_has_neighbor_with(medrecord, indices, *operation), - ), - NodeOperation::HasNeighborUndirectedWith(operation) => Box::new( - Self::evaluate_has_neighbor_undirected_with(medrecord, indices, *operation), - ), - - NodeOperation::And(operations) => Box::new(Self::evaluate_and( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - NodeOperation::Or(operations) => Box::new(Self::evaluate_or( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - NodeOperation::Not(operation) => Box::new(Self::evaluate_not( - medrecord, - indices.collect::>(), - *operation, - )), - } - } -} - -impl NodeOperation { - pub fn and(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::And(Box::new((self, operation))) - } - - pub fn or(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::Or(Box::new((self, operation))) - } - - pub fn xor(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::And(Box::new((self, operation))).not() - } - - pub fn not(self) -> NodeOperation { - NodeOperation::Not(Box::new(self)) - } - - fn evaluate_index<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, - operation: NodeIndexOperation, - ) -> Box + 'a> { - match operation { - NodeIndexOperation::Gt(operand) => { - Box::new(Self::evaluate_index_gt(node_indices, operand)) - } - NodeIndexOperation::Lt(operand) => { - Box::new(Self::evaluate_index_lt(node_indices, operand)) - } - NodeIndexOperation::Gte(operand) => { - Box::new(Self::evaluate_index_gte(node_indices, operand)) - } - NodeIndexOperation::Lte(operand) => { - Box::new(Self::evaluate_index_lte(node_indices, operand)) - } - NodeIndexOperation::Eq(operand) => { - Box::new(Self::evaluate_index_eq(node_indices, operand)) - } - NodeIndexOperation::In(operands) => Box::new(Self::evaluate_index_in( - node_indices, - operands.into_vec_node_index(medrecord), - )), - NodeIndexOperation::StartsWith(operand) => { - Box::new(Self::evaluate_index_starts_with(node_indices, operand)) - } - NodeIndexOperation::EndsWith(operand) => { - Box::new(Self::evaluate_index_ends_with(node_indices, operand)) - } - NodeIndexOperation::Contains(operand) => { - Box::new(Self::evaluate_index_contains(node_indices, operand)) - } - } - } - - implement_index_evaluate!(evaluate_index_starts_with, starts_with); - implement_index_evaluate!(evaluate_index_ends_with, ends_with); - implement_index_evaluate!(evaluate_index_contains, contains); - - fn evaluate_in_group<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - let nodes_in_group = match medrecord.nodes_in_group(&attribute_operand) { - Ok(nodes_in_group) => nodes_in_group.collect::>(), - Err(_) => Vec::new(), - }; - - node_indices.filter(move |index| nodes_in_group.contains(index)) - } - - fn evaluate_has_outgoing_edge_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: EdgeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(edges) = medrecord.outgoing_edges(index) else { - return false; - }; - - let edge_indices = operation.clone().evaluate(medrecord, edges); - - edge_indices.count() > 0 - }) - } - - fn evaluate_has_incoming_edge_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: EdgeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(edges) = medrecord.incoming_edges(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, edges).count() > 0 - }) - } - - fn evaluate_has_neighbor_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(neighbors) = medrecord.neighbors(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, neighbors).count() > 0 - }) - } - - fn evaluate_has_neighbor_undirected_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(neighbors) = medrecord.neighbors_undirected(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, neighbors).count() > 0 - }) - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/operand.rs b/crates/medmodels-core/src/medrecord/querying/operation/operand.rs deleted file mode 100644 index c7b7849e..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/operand.rs +++ /dev/null @@ -1,649 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - edge_operation::EdgeIndexOperation, - node_operation::{NodeIndexOperation, NodeOperation}, - AttributeOperation, EdgeOperation, Operation, -}; -use crate::medrecord::{ - EdgeIndex, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, -}; -use std::{fmt::Debug, ops::Range}; - -#[derive(Debug, Clone)] -pub enum ArithmeticOperation { - Addition, - Subtraction, - Multiplication, - Division, - Power, - Modulo, -} - -#[derive(Debug, Clone)] -pub enum TransformationOperation { - Round, - Ceil, - Floor, - Abs, - Sqrt, - - Trim, - TrimStart, - TrimEnd, - - Lowercase, - Uppercase, -} - -#[derive(Debug, Clone)] -pub enum ValueOperand { - Value(MedRecordValue), - Evaluate(MedRecordAttribute), - ArithmeticOperation(ArithmeticOperation, MedRecordAttribute, MedRecordValue), - TransformationOperation(TransformationOperation, MedRecordAttribute), - Slice(MedRecordAttribute, Range), -} - -pub trait IntoValueOperand { - fn into_value_operand(self) -> ValueOperand; -} - -impl> IntoValueOperand for T { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Value(self.into()) - } -} -impl IntoValueOperand for NodeAttributeOperand { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Evaluate(self.into()) - } -} -impl IntoValueOperand for EdgeAttributeOperand { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Evaluate(self.into()) - } -} -impl IntoValueOperand for ValueOperand { - fn into_value_operand(self) -> ValueOperand { - self - } -} - -#[derive(Debug, Clone)] -pub struct NodeAttributeOperand(MedRecordAttribute); - -impl From for NodeAttributeOperand { - fn from(value: MedRecordAttribute) -> Self { - NodeAttributeOperand(value) - } -} - -impl From for MedRecordAttribute { - fn from(val: NodeAttributeOperand) -> Self { - val.0 - } -} - -impl NodeAttributeOperand { - pub fn greater(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Gt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Lt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn greater_or_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Gte( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less_or_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Lte( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Eq( - self.into(), - operand.into_value_operand(), - )) - } - pub fn not_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Neq( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn r#in(self, operand: Vec>) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::In( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - pub fn not_in(self, operand: Vec>) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::NotIn( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - - pub fn starts_with(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::StartsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn ends_with(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::EndsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn contains(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Contains( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn add(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Addition, self.into(), value.into()) - } - - pub fn sub(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Subtraction, - self.into(), - value.into(), - ) - } - - pub fn mul(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Multiplication, - self.into(), - value.into(), - ) - } - - pub fn div(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Division, self.into(), value.into()) - } - - pub fn pow(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Power, self.into(), value.into()) - } - - pub fn r#mod(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Modulo, self.into(), value.into()) - } - - pub fn round(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Round, self.into()) - } - - pub fn ceil(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Ceil, self.into()) - } - - pub fn floor(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Floor, self.into()) - } - - pub fn abs(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Abs, self.into()) - } - - pub fn sqrt(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Sqrt, self.into()) - } - - pub fn trim(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Trim, self.into()) - } - - pub fn trim_start(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimStart, self.into()) - } - - pub fn trim_end(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimEnd, self.into()) - } - - pub fn lowercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Lowercase, self.into()) - } - - pub fn uppercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Uppercase, self.into()) - } - - pub fn slice(self, range: Range) -> ValueOperand { - ValueOperand::Slice(self.into(), range) - } -} - -#[derive(Debug, Clone)] -pub struct EdgeAttributeOperand(MedRecordAttribute); - -impl From for MedRecordAttribute { - fn from(val: EdgeAttributeOperand) -> Self { - val.0 - } -} - -impl EdgeAttributeOperand { - pub fn greater(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Gt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Lt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn greater_or_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Gte( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less_or_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Lte( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Eq( - self.into(), - operand.into_value_operand(), - )) - } - pub fn not_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Neq( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn r#in(self, operand: Vec>) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::In( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - pub fn not_in(self, operand: Vec>) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::NotIn( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - - pub fn starts_with(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::StartsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn ends_with(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::EndsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn contains(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Contains( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn add(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Addition, self.into(), value.into()) - } - - pub fn sub(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Subtraction, - self.into(), - value.into(), - ) - } - - pub fn mul(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Multiplication, - self.into(), - value.into(), - ) - } - - pub fn div(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Division, self.into(), value.into()) - } - - pub fn pow(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Power, self.into(), value.into()) - } - - pub fn r#mod(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Modulo, self.into(), value.into()) - } - - pub fn round(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Round, self.into()) - } - - pub fn ceil(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Ceil, self.into()) - } - - pub fn floor(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Floor, self.into()) - } - - pub fn abs(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Abs, self.into()) - } - - pub fn sqrt(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Sqrt, self.into()) - } - - pub fn trim(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Trim, self.into()) - } - - pub fn trim_start(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimStart, self.into()) - } - - pub fn trim_end(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimEnd, self.into()) - } - - pub fn lowercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Lowercase, self.into()) - } - - pub fn uppercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Uppercase, self.into()) - } - - pub fn slice(self, range: Range) -> ValueOperand { - ValueOperand::Slice(self.into(), range) - } -} - -#[derive(Debug, Clone)] -pub enum NodeIndexInOperand { - Vector(Vec), - Operation(NodeOperation), -} - -impl From> for NodeIndexInOperand -where - T: Into, -{ - fn from(value: Vec) -> NodeIndexInOperand { - NodeIndexInOperand::Vector(value.into_iter().map(|value| value.into()).collect()) - } -} - -impl From for NodeIndexInOperand { - fn from(value: NodeOperation) -> Self { - NodeIndexInOperand::Operation(value) - } -} - -pub(super) trait IntoVecNodeIndex { - fn into_vec_node_index(self, medrecord: &MedRecord) -> Vec; -} - -impl IntoVecNodeIndex for NodeIndexInOperand { - fn into_vec_node_index(self, medrecord: &MedRecord) -> Vec { - match self { - NodeIndexInOperand::Vector(value) => value, - NodeIndexInOperand::Operation(operation) => operation - .evaluate(medrecord, medrecord.node_indices()) - .cloned() - .collect(), - } - } -} - -#[derive(Debug, Clone)] -pub struct NodeIndexOperand; - -impl NodeIndexOperand { - pub fn greater(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Gt(operand.into())) - } - pub fn less(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Lt(operand.into())) - } - pub fn greater_or_equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Gte(operand.into())) - } - pub fn less_or_equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Lte(operand.into())) - } - - pub fn equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Eq(operand.into())) - } - pub fn not_equal(self, operand: impl Into) -> NodeOperation { - self.equal(operand).not() - } - - pub fn r#in(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::In(Box::new(operand.into()))) - } - pub fn not_in(self, operand: impl Into) -> NodeOperation { - self.r#in(operand).not() - } - - pub fn starts_with(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::StartsWith(operand.into())) - } - - pub fn ends_with(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::EndsWith(operand.into())) - } - - pub fn contains(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Contains(operand.into())) - } -} - -#[derive(Debug, Clone)] -pub struct NodeOperand; - -impl NodeOperand { - pub fn in_group(self, operand: impl Into) -> NodeOperation { - NodeOperation::InGroup(operand.into()) - } - - pub fn has_attribute(self, operand: impl Into) -> NodeOperation { - NodeOperation::HasAttribute(operand.into()) - } - - pub fn has_outgoing_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasOutgoingEdgeWith(operation.into()) - } - pub fn has_incoming_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasIncomingEdgeWith(operation.into()) - } - pub fn has_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasOutgoingEdgeWith(operation.clone().into()) - .or(NodeOperation::HasIncomingEdgeWith(operation.into())) - } - - pub fn has_neighbor_with(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::HasNeighborWith(Box::new(operation)) - } - pub fn has_neighbor_undirected_with(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::HasNeighborUndirectedWith(Box::new(operation)) - } - - pub fn attribute(self, attribute: impl Into) -> NodeAttributeOperand { - NodeAttributeOperand(attribute.into()) - } - - pub fn index(self) -> NodeIndexOperand { - NodeIndexOperand - } -} - -pub fn node() -> NodeOperand { - NodeOperand -} - -#[derive(Debug, Clone)] -pub enum EdgeIndexInOperand { - Vector(Vec), - Operation(EdgeOperation), -} - -impl> From> for EdgeIndexInOperand { - fn from(value: Vec) -> EdgeIndexInOperand { - EdgeIndexInOperand::Vector(value.into_iter().map(|value| value.into()).collect()) - } -} - -impl From for EdgeIndexInOperand { - fn from(value: EdgeOperation) -> Self { - EdgeIndexInOperand::Operation(value) - } -} - -pub(super) trait IntoVecEdgeIndex { - fn into_vec_edge_index(self, medrecord: &MedRecord) -> Vec; -} - -impl IntoVecEdgeIndex for EdgeIndexInOperand { - fn into_vec_edge_index(self, medrecord: &MedRecord) -> Vec { - match self { - EdgeIndexInOperand::Vector(value) => value, - EdgeIndexInOperand::Operation(operation) => operation - .evaluate(medrecord, medrecord.edge_indices()) - .copied() - .collect(), - } - } -} - -#[derive(Debug, Clone)] -pub struct EdgeIndexOperand; - -impl EdgeIndexOperand { - pub fn greater(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Gt(operand)) - } - pub fn less(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Lt(operand)) - } - pub fn greater_or_equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Gte(operand)) - } - pub fn less_or_equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Lte(operand)) - } - - pub fn equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Eq(operand)) - } - pub fn not_equal(self, operand: EdgeIndex) -> EdgeOperation { - self.equal(operand).not() - } - - pub fn r#in(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::In(Box::new(operand.into()))) - } - pub fn not_in(self, operand: impl Into) -> EdgeOperation { - self.r#in(operand).not() - } -} - -#[derive(Debug, Clone)] -pub struct EdgeOperand; - -impl EdgeOperand { - pub fn connected_target(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::ConnectedSource(operand.into()) - } - - pub fn connected_source(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::ConnectedTarget(operand.into()) - } - - pub fn connected(self, operand: impl Into) -> EdgeOperation { - let attribute = operand.into(); - - EdgeOperation::ConnectedSource(attribute.clone()) - .or(EdgeOperation::ConnectedTarget(attribute)) - } - - pub fn in_group(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::InGroup(operand.into()) - } - - pub fn has_attribute(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::HasAttribute(operand.into()) - } - - pub fn connected_source_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedSourceWith(operation.into()) - } - - pub fn connected_target_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedTargetWith(operation.into()) - } - - pub fn connected_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedSourceWith(operation.clone().into()) - .or(EdgeOperation::ConnectedTargetWith(operation.into())) - } - - pub fn has_parallel_edges_with(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::HasParallelEdgesWith(Box::new(operation)) - } - - pub fn has_parallel_edges_with_self_comparison( - self, - operation: EdgeOperation, - ) -> EdgeOperation { - EdgeOperation::HasParallelEdgesWithSelfComparison(Box::new(operation)) - } - - pub fn attribute(self, attribute: impl Into) -> EdgeAttributeOperand { - EdgeAttributeOperand(attribute.into()) - } - - pub fn index(self) -> EdgeIndexOperand { - EdgeIndexOperand - } -} - -pub fn edge() -> EdgeOperand { - EdgeOperand -} diff --git a/crates/medmodels-core/src/medrecord/querying/selection.rs b/crates/medmodels-core/src/medrecord/querying/selection.rs deleted file mode 100644 index 82e8356e..00000000 --- a/crates/medmodels-core/src/medrecord/querying/selection.rs +++ /dev/null @@ -1,1741 +0,0 @@ -use super::operation::{EdgeOperation, NodeOperation, Operation}; -use crate::medrecord::{EdgeIndex, MedRecord, NodeIndex}; - -#[derive(Debug)] -pub struct NodeSelection<'a> { - medrecord: &'a MedRecord, - operation: NodeOperation, -} - -impl<'a> NodeSelection<'a> { - pub fn new(medrecord: &'a MedRecord, operation: NodeOperation) -> Self { - Self { - medrecord, - operation, - } - } - - pub fn iter(self) -> impl Iterator { - self.operation - .evaluate(self.medrecord, self.medrecord.node_indices()) - } - - pub fn collect>(self) -> B { - FromIterator::from_iter(self.iter()) - } -} - -#[derive(Debug)] -pub struct EdgeSelection<'a> { - medrecord: &'a MedRecord, - operation: EdgeOperation, -} - -impl<'a> EdgeSelection<'a> { - pub fn new(medrecord: &'a MedRecord, operation: EdgeOperation) -> Self { - Self { - medrecord, - operation, - } - } - - pub fn iter(self) -> impl Iterator { - self.operation - .evaluate(self.medrecord, self.medrecord.edge_indices()) - } - - pub fn collect>(self) -> B { - FromIterator::from_iter(self.iter()) - } -} - -#[cfg(test)] -mod test { - use crate::medrecord::{edge, node, Attributes, MedRecord, MedRecordAttribute, NodeIndex}; - use std::collections::HashMap; - - fn create_nodes() -> Vec<(NodeIndex, Attributes)> { - vec![ - ( - "0".into(), - HashMap::from([ - ("lorem".into(), "ipsum".into()), - ("dolor".into(), " ipsum ".into()), - ("test".into(), "Ipsum".into()), - ("integer".into(), 1.into()), - ("float".into(), 0.5.into()), - ]), - ), - ( - "1".into(), - HashMap::from([("amet".into(), "consectetur".into())]), - ), - ( - "2".into(), - HashMap::from([("adipiscing".into(), "elit".into())]), - ), - ("3".into(), HashMap::new()), - ] - } - - fn create_edges() -> Vec<(NodeIndex, NodeIndex, Attributes)> { - vec![ - ( - "0".into(), - "1".into(), - HashMap::from([ - ("sed".into(), "do".into()), - ("eiusmod".into(), "tempor".into()), - ("dolor".into(), " do ".into()), - ("test".into(), "DO".into()), - ]), - ), - ( - "1".into(), - "2".into(), - HashMap::from([("incididunt".into(), "ut".into())]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([ - ("test".into(), 1.into()), - ("integer".into(), 1.into()), - ("float".into(), 0.5.into()), - ]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([("test".into(), 0.into())]), - ), - ] - } - - fn create_medrecord() -> MedRecord { - let nodes = create_nodes(); - let edges = create_edges(); - - MedRecord::from_tuples(nodes, Some(edges), None).unwrap() - } - - #[test] - fn test_iter() { - let medrecord = create_medrecord(); - - assert_eq!( - 1, - medrecord - .select_nodes(node().has_attribute("lorem")) - .iter() - .count(), - ); - - assert_eq!( - 1, - medrecord - .select_edges(edge().has_attribute("sed")) - .iter() - .count(), - ); - } - - #[test] - fn test_collect() { - let medrecord = create_medrecord(); - - assert_eq!( - vec![&MedRecordAttribute::from("0")], - medrecord - .select_nodes(node().has_attribute("lorem")) - .collect::>(), - ); - - assert_eq!( - vec![&0], - medrecord - .select_edges(edge().has_attribute("sed")) - .collect::>(), - ); - } - - #[test] - fn test_select_nodes_node() { - let mut medrecord = create_medrecord(); - - medrecord - .add_group("test".into(), Some(vec!["0".into()]), None) - .unwrap(); - - // Node in group - assert_eq!( - 1, - medrecord - .select_nodes(node().in_group("test")) - .iter() - .count(), - ); - - // Node has attribute - assert_eq!( - 1, - medrecord - .select_nodes(node().has_attribute("lorem")) - .iter() - .count(), - ); - - // Node has outgoing edge with - assert_eq!( - 1, - medrecord - .select_nodes(node().has_outgoing_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has incoming edge with - assert_eq!( - 1, - medrecord - .select_nodes(node().has_incoming_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has edge with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has neighbor with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_neighbor_with(node().index().equal("2"))) - .iter() - .count(), - ); - assert_eq!( - 1, - medrecord - .select_nodes(node().has_neighbor_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Node has undirected neighbor with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_neighbor_undirected_with(node().index().equal("1"))) - .iter() - .count(), - ); - } - - #[test] - fn test_select_nodes_node_index() { - let medrecord = create_medrecord(); - - // Index greater - assert_eq!( - 2, - medrecord - .select_nodes(node().index().greater("1")) - .iter() - .count(), - ); - - // Index less - assert_eq!( - 1, - medrecord - .select_nodes(node().index().less("1")) - .iter() - .count(), - ); - - // Index greater or equal - assert_eq!( - 3, - medrecord - .select_nodes(node().index().greater_or_equal("1")) - .iter() - .count(), - ); - - // Index less or equal - assert_eq!( - 2, - medrecord - .select_nodes(node().index().less_or_equal("1")) - .iter() - .count(), - ); - - // Index equal - assert_eq!( - 1, - medrecord - .select_nodes(node().index().equal("1")) - .iter() - .count(), - ); - - // Index not equal - assert_eq!( - 3, - medrecord - .select_nodes(node().index().not_equal("1")) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_nodes(node().index().r#in(vec!["1"])) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_nodes(node().index().r#in(node().has_attribute("lorem"))) - .iter() - .count(), - ); - - // Index not in - assert_eq!( - 3, - medrecord - .select_nodes(node().index().not_in(node().has_attribute("lorem"))) - .iter() - .count(), - ); - - // Index starts with - assert_eq!( - 1, - medrecord - .select_nodes(node().index().starts_with("1")) - .iter() - .count(), - ); - - // Index ends with - assert_eq!( - 1, - medrecord - .select_nodes(node().index().ends_with("1")) - .iter() - .count(), - ); - - // Index contains - assert_eq!( - 1, - medrecord - .select_nodes(node().index().contains("1")) - .iter() - .count(), - ); - } - - #[test] - fn test_select_nodes_node_attribute() { - let medrecord = create_medrecord(); - - // Attribute greater - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").greater("ipsum")) - .iter() - .count(), - ); - - // Attribute less - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").less("ipsum")) - .iter() - .count(), - ); - - // Attribute greater or equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").greater_or_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute less or equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").less_or_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").equal("ipsum")) - .iter() - .count(), - ); - - // Attribute not equal - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").not_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute in - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").r#in(vec!["ipsum"])) - .iter() - .count(), - ); - - // Attribute not in - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").not_in(vec!["ipsum"])) - .iter() - .count(), - ); - - // Attribute starts with - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").starts_with("ip")) - .iter() - .count(), - ); - - // Attribute ends with - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").ends_with("um")) - .iter() - .count(), - ); - - // Attribute contains - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").contains("su")) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").equal(node().attribute("lorem"))) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Returns nothing because can't sub a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Doesn't work because can't sub a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Returns nothing because can't div a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Doesn't work because can't div a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Returns nothing because can't pow a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Doesn't work because can't pow a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Returns nothing because can't mod a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Doesn't work because can't mod a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute abs - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").abs()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sqrt - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").sqrt()) // sqrt(1) = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_start - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim_start()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_end - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim_end()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute lowercase - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("test").lowercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute uppercase - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("test").uppercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute slice - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").slice(2..7)) - ) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge() { - let mut medrecord = create_medrecord(); - - medrecord - .add_group("test".into(), None, Some(vec![0])) - .unwrap(); - - // Edge connected to target - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_target("2")) - .iter() - .count(), - ); - - // Edge connected to source - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_source("0")) - .iter() - .count(), - ); - - // Edge connected - assert_eq!( - 2, - medrecord.select_edges(edge().connected("1")).iter().count(), - ); - - // Edge in group - assert_eq!( - 1, - medrecord - .select_edges(edge().in_group("test")) - .iter() - .count(), - ); - - // Edge has attribute - assert_eq!( - 1, - medrecord - .select_edges(edge().has_attribute("sed")) - .iter() - .count(), - ); - - // Edge connected to target with - assert_eq!( - 1, - medrecord - .select_edges(edge().connected_target_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Edge connected to source with - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_source_with(node().index().equal("0"))) - .iter() - .count(), - ); - - // Edge connected with - assert_eq!( - 2, - medrecord - .select_edges(edge().connected_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Edge has parallel edges with - assert_eq!( - 2, - medrecord - .select_edges(edge().has_parallel_edges_with(edge().has_attribute("test"))) - .iter() - .count(), - ); - - // Edge has parallel edges with self comparison - assert_eq!( - 1, - medrecord - .select_edges( - edge().has_parallel_edges_with_self_comparison( - edge() - .attribute("test") - .equal(edge().attribute("test").sub(1)) - ) - ) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge_index() { - let medrecord = create_medrecord(); - - // Index greater - assert_eq!( - 2, - medrecord - .select_edges(edge().index().greater(1)) - .iter() - .count(), - ); - - // Index less - assert_eq!( - 1, - medrecord - .select_edges(edge().index().less(1)) - .iter() - .count(), - ); - - // Index greater or equal - assert_eq!( - 3, - medrecord - .select_edges(edge().index().greater_or_equal(1)) - .iter() - .count(), - ); - - // Index less or equal - assert_eq!( - 2, - medrecord - .select_edges(edge().index().less_or_equal(1)) - .iter() - .count(), - ); - - // Index equal - assert_eq!( - 1, - medrecord - .select_edges(edge().index().equal(1)) - .iter() - .count(), - ); - - // Index not equal - assert_eq!( - 3, - medrecord - .select_edges(edge().index().not_equal(1)) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_edges(edge().index().r#in(vec![1_usize])) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_edges(edge().index().r#in(edge().has_attribute("sed"))) - .iter() - .count(), - ); - - // Index not in - assert_eq!( - 3, - medrecord - .select_edges(edge().index().not_in(edge().has_attribute("sed"))) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge_attribute() { - let medrecord = create_medrecord(); - - // Attribute greater - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").greater("do")) - .iter() - .count(), - ); - - // Attribute less - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").less("do")) - .iter() - .count(), - ); - - // Attribute greater or equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").greater_or_equal("do")) - .iter() - .count(), - ); - - // Attribute less or equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").less_or_equal("do")) - .iter() - .count(), - ); - - // Attribute equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").equal("do")) - .iter() - .count(), - ); - - // Attribute not equal - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_equal("do")) - .iter() - .count(), - ); - - // Attribute in - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").r#in(vec!["do"])) - .iter() - .count(), - ); - - // Attribute not in - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_in(vec!["do"])) - .iter() - .count(), - ); - - // Attribute starts with - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").starts_with("d")) - .iter() - .count(), - ); - - // Attribute ends with - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").ends_with("o")) - .iter() - .count(), - ); - - // Attribute contains - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").contains("do")) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").equal(edge().attribute("sed"))) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_equal(edge().attribute("sed"))) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Returns nothing because can't sub a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Doesn't work because can't sub a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Returns nothing because can't div a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Doesn't work because can't div a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Returns nothing because can't pow a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .equal(edge().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Doesn't work because can't pow a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .not_equal(edge().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Returns nothing because can't mod a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .equal(edge().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Doesn't work because can't mod a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .not_equal(edge().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute abs - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").abs()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sqrt - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").sqrt()) // sqrt(1) = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_start - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim_start()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_end - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim_end()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute lowercase - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("test").lowercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute uppercase - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("test").uppercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute slice - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").slice(2..4)) - ) - .iter() - .count(), - ); - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/traits.rs b/crates/medmodels-core/src/medrecord/querying/traits.rs new file mode 100644 index 00000000..4e8d33e8 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/traits.rs @@ -0,0 +1,21 @@ +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +pub trait DeepClone { + fn deep_clone(&self) -> Self; +} + +pub(crate) trait ReadWriteOrPanic { + fn read_or_panic(&self) -> RwLockReadGuard<'_, T>; + + fn write_or_panic(&self) -> RwLockWriteGuard<'_, T>; +} + +impl ReadWriteOrPanic for RwLock { + fn read_or_panic(&self) -> RwLockReadGuard<'_, T> { + self.read().unwrap() + } + + fn write_or_panic(&self) -> RwLockWriteGuard<'_, T> { + self.write().unwrap() + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/mod.rs b/crates/medmodels-core/src/medrecord/querying/values/mod.rs new file mode 100644 index 00000000..893fa7c9 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/mod.rs @@ -0,0 +1,188 @@ +mod operand; +mod operation; + +use super::{ + attributes::{ + self, AttributesTreeOperation, MultipleAttributesOperand, MultipleAttributesOperation, + }, + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{MedRecordAttribute, MedRecordValue}, + MedRecord, +}; +pub use operand::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, +}; +use std::fmt::Display; + +macro_rules! get_attributes { + ($operand:ident, $medrecord:ident, $operation:ident, $multiple_attributes_operand:ident) => {{ + let indices = $operand.evaluate($medrecord)?; + + let attributes = $operation::get_attributes($medrecord, indices); + + let attributes = $multiple_attributes_operand + .context + .evaluate($medrecord, attributes)?; + + let attributes: Box> = + match $multiple_attributes_operand.kind { + attributes::MultipleKind::Max => { + Box::new(AttributesTreeOperation::get_max(attributes)?) + } + attributes::MultipleKind::Min => { + Box::new(AttributesTreeOperation::get_min(attributes)?) + } + attributes::MultipleKind::Count => { + Box::new(AttributesTreeOperation::get_count(attributes)?) + } + attributes::MultipleKind::Sum => { + Box::new(AttributesTreeOperation::get_sum(attributes)?) + } + attributes::MultipleKind::First => { + Box::new(AttributesTreeOperation::get_first(attributes)?) + } + attributes::MultipleKind::Last => { + Box::new(AttributesTreeOperation::get_last(attributes)?) + } + }; + + let attributes = $multiple_attributes_operand.evaluate($medrecord, attributes)?; + + Box::new( + MultipleAttributesOperation::get_values($medrecord, attributes)? + .map(|(_, value)| value), + ) + }}; +} + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Mean, + Median, + Mode, + Std, + Var, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Div, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Div => write!(f, "div"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Round, + Ceil, + Floor, + Abs, + Sqrt, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), + MultipleAttributesOperand(MultipleAttributesOperand), +} + +impl Context { + pub(crate) fn get_values<'a>( + &self, + medrecord: &'a MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_values(medrecord, node_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_values(medrecord, edge_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::MultipleAttributesOperand(multiple_attributes_operand) => { + match &multiple_attributes_operand.context.context { + attributes::Context::NodeOperand(node_operand) => { + get_attributes!( + node_operand, + medrecord, + NodeOperation, + multiple_attributes_operand + ) + } + attributes::Context::EdgeOperand(edge_operand) => { + get_attributes!( + edge_operand, + medrecord, + EdgeOperation, + multiple_attributes_operand + ) + } + } + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operand.rs b/crates/medmodels-core/src/medrecord/querying/values/operand.rs new file mode 100644 index 00000000..087b3261 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/operand.rs @@ -0,0 +1,630 @@ +use super::{ + operation::{MultipleValuesOperation, SingleValueOperation}, + BinaryArithmeticKind, Context, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, + MedRecordAttribute, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use std::hash::Hash; + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleValuesOperation::ValueOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::SingleValueComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleValueComparisonOperand { + Operand(SingleValueOperand), + Value(MedRecordValue), +} + +impl DeepClone for SingleValueComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Value(value) => Self::Value(value.clone()), + } + } +} + +impl From> for SingleValueComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleValueComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleValueComparisonOperand { + fn from(value: V) -> Self { + Self::Value(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesComparisonOperand { + Operand(MultipleValuesOperand), + Values(Vec), +} + +impl DeepClone for MultipleValuesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Values(value) => Self::Values(value.clone()), + } + } +} + +impl From> for MultipleValuesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleValuesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleValuesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Values(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleValuesComparisonOperand +{ + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct MultipleValuesOperand { + pub(crate) context: Context, + pub(crate) attribute: MedRecordAttribute, + operations: Vec, +} + +impl DeepClone for MultipleValuesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + attribute: self.attribute.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleValuesOperand { + pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { + Self { + context, + attribute, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + let values = Box::new(values) as BoxedIterator<(&'a T, MedRecordValue)>; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(mean, Mean); + implement_value_operation!(median, Median); + implement_value_operation!(mode, Mode); + implement_value_operation!(std, Std); + implement_value_operation!(var, Var); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!( + greater_than, + MultipleValuesOperation, + GreaterThan + ); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + MultipleValuesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, MultipleValuesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + MultipleValuesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, MultipleValuesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, MultipleValuesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, MultipleValuesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, MultipleValuesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, MultipleValuesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, MultipleValuesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleValuesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleValuesOperation, Mul); + implement_binary_arithmetic_operation!(div, MultipleValuesOperation, Div); + implement_binary_arithmetic_operation!(pow, MultipleValuesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleValuesOperation, Mod); + + implement_unary_arithmetic_operation!(round, MultipleValuesOperation, Round); + implement_unary_arithmetic_operation!(ceil, MultipleValuesOperation, Ceil); + implement_unary_arithmetic_operation!(floor, MultipleValuesOperation, Floor); + implement_unary_arithmetic_operation!(abs, MultipleValuesOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, MultipleValuesOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, MultipleValuesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleValuesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleValuesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleValuesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleValuesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleValuesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleValuesOperation::IsString); + implement_assertion_operation!(is_int, MultipleValuesOperation::IsInt); + implement_assertion_operation!(is_float, MultipleValuesOperation::IsFloat); + implement_assertion_operation!(is_bool, MultipleValuesOperation::IsBool); + implement_assertion_operation!(is_datetime, MultipleValuesOperation::IsDateTime); + implement_assertion_operation!(is_null, MultipleValuesOperation::IsNull); + implement_assertion_operation!(is_max, MultipleValuesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleValuesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleValuesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + + query(&mut operand); + + self.operations + .push(MultipleValuesOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { + MultipleValuesOperand::new(context, attribute).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, SingleValueOperand); + implement_wrapper_operand_with_return!(min, SingleValueOperand); + implement_wrapper_operand_with_return!(mean, SingleValueOperand); + implement_wrapper_operand_with_return!(median, SingleValueOperand); + implement_wrapper_operand_with_return!(mode, SingleValueOperand); + implement_wrapper_operand_with_return!(std, SingleValueOperand); + implement_wrapper_operand_with_return!(var, SingleValueOperand); + implement_wrapper_operand_with_return!(count, SingleValueOperand); + implement_wrapper_operand_with_return!(sum, SingleValueOperand); + implement_wrapper_operand_with_return!(first, SingleValueOperand); + implement_wrapper_operand_with_return!(last, SingleValueOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} + +#[derive(Debug, Clone)] +pub struct SingleValueOperand { + pub(crate) context: MultipleValuesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleValueOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleValueOperand { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, SingleValueOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + SingleValueOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, SingleValueOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + SingleValueOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, SingleValueOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, SingleValueOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, SingleValueOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, SingleValueOperation, EndsWith); + implement_single_value_comparison_operation!(contains, SingleValueOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, SingleValueOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleValueOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleValueOperation, Mul); + implement_binary_arithmetic_operation!(div, SingleValueOperation, Div); + implement_binary_arithmetic_operation!(pow, SingleValueOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleValueOperation, Mod); + + implement_unary_arithmetic_operation!(round, SingleValueOperation, Round); + implement_unary_arithmetic_operation!(ceil, SingleValueOperation, Ceil); + implement_unary_arithmetic_operation!(floor, SingleValueOperation, Floor); + implement_unary_arithmetic_operation!(abs, SingleValueOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, SingleValueOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, SingleValueOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleValueOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleValueOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleValueOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleValueOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleValueOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleValueOperation::IsString); + implement_assertion_operation!(is_int, SingleValueOperation::IsInt); + implement_assertion_operation!(is_float, SingleValueOperation::IsFloat); + implement_assertion_operation!(is_bool, SingleValueOperation::IsBool); + implement_assertion_operation!(is_datetime, SingleValueOperation::IsDateTime); + implement_assertion_operation!(is_null, SingleValueOperation::IsNull); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleValueOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } + + pub fn exclude(&mut self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + query(&mut operand); + + self.operations + .push(SingleValueOperation::Exclude { operand }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + SingleValueOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } + + pub fn exclude(&self, query: Q) + where + Q: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().exclude(query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operation.rs b/crates/medmodels-core/src/medrecord/querying/values/operation.rs new file mode 100644 index 00000000..80732c60 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/operation.rs @@ -0,0 +1,984 @@ +use super::{ + operand::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, + StartsWith, Trim, TrimEnd, TrimStart, Uppercase, + }, + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, + DataType, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + collections::HashSet, + hash::Hash, + ops::{Add, Div, Mul, Range, Sub}, +}; + +macro_rules! get_single_operand_value { + ($kind:ident, $values:expr) => { + match $kind { + SingleKind::Max => MultipleValuesOperation::get_max($values)?.1, + SingleKind::Min => MultipleValuesOperation::get_min($values)?.1, + SingleKind::Mean => MultipleValuesOperation::get_mean($values)?, + SingleKind::Median => MultipleValuesOperation::get_median($values)?, + SingleKind::Mode => MultipleValuesOperation::get_mode($values)?, + SingleKind::Std => MultipleValuesOperation::get_std($values)?, + SingleKind::Var => MultipleValuesOperation::get_var($values)?, + SingleKind::Count => MultipleValuesOperation::get_count($values), + SingleKind::Sum => MultipleValuesOperation::get_sum($values)?, + SingleKind::First => MultipleValuesOperation::get_first($values)?, + SingleKind::Last => MultipleValuesOperation::get_last($values)?, + } + }; +} + +macro_rules! get_single_value_comparison_operand_value { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleValueComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let attribute = operand.context.attribute.clone(); + let kind = &operand.kind; + + let comparison_values = context + .get_values($medrecord, attribute)? + .map(|value| (&0, value)); + + let comparison_value = get_single_operand_value!(kind, comparison_values); + + operand.evaluate($medrecord, comparison_value)?.ok_or( + MedRecordError::QueryError("No index to compare".to_string()), + )? + } + SingleValueComparisonOperand::Value(value) => value.clone(), + } + }; +} + +macro_rules! get_median { + ($values:ident, $variant:ident) => { + if $values.len() % 2 == 0 { + let middle = $values.len() / 2; + + let first = $values.get(middle - 1).unwrap(); + let second = $values.get(middle).unwrap(); + + let first = MedRecordValue::$variant(*first); + let second = MedRecordValue::$variant(*second); + + first.add(second).unwrap().div(MedRecordValue::Int(2)) + } else { + let middle = $values.len() / 2; + + Ok(MedRecordValue::$variant( + $values.get(middle).unwrap().clone(), + )) + } + }; +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesOperation { + ValueOperation { + operand: Wrapper, + }, + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for MultipleValuesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::ValueOperation { operand } => Self::ValueOperation { + operand: operand.deep_clone(), + }, + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl MultipleValuesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::ValueOperation { operand } => { + Self::evaluate_value_operation(medrecord, values, operand) + } + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, values, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation( + medrecord, values, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, values, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(values, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(values, range.clone()))), + Self::IsString => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Int(_)) + }))) + } + Self::IsFloat => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Float(_)) + }))) + } + Self::IsBool => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Bool(_)) + }))) + } + Self::IsDateTime => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::DateTime(_)) + }))) + } + Self::IsNull => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Null) + }))) + } + Self::IsMax => { + let max_value = Self::get_max(values)?; + + Ok(Box::new(std::iter::once(max_value))) + } + Self::IsMin => { + let min_value = Self::get_min(values)?; + + Ok(Box::new(std::iter::once(min_value))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, values, either, or) + } + Self::Exclude { operand } => Self::evaluate_exclude(medrecord, values, operand), + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { + let max_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(max_value, |max_value, value| { + match value.1.partial_cmp(&max_value.1) { + Some(Ordering::Greater) => Ok(value), + None => { + let first_dtype = DataType::from(value.1); + let second_dtype = DataType::from(max_value.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_value), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { + let min_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(min_value, |min_value, value| { + match value.1.partial_cmp(&min_value.1) { + Some(Ordering::Less) => Ok(value), + None => { + let first_dtype = DataType::from(value.1); + let second_dtype = DataType::from(min_value.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_value), + } + }) + } + + #[inline] + pub(crate) fn get_mean<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let (sum, count) = values.try_fold((first_value.1, 1), |(sum, count), (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + match sum.add(value) { + Ok(sum) => Ok((sum, count + 1)), + Err(_) => Err(MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_dtype, second_dtype + ))), + } + })?; + + sum.div(MedRecordValue::Int(count as i64)) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_median<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let first_data_type = DataType::from(&first_value.1); + + match first_value.1 { + MedRecordValue::Int(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value as f64); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::Float(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::DateTime(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::DateTime(naive_date_time) => Ok(naive_date_time), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort(); + + get_median!(values, DateTime) + } + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of data type {}", + first_data_type + )))?, + } + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_mode<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.map(|(_, value)| value).collect::>(); + + let most_common_value = values + .first() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))? + .clone(); + let most_common_count = values + .iter() + .filter(|value| **value == most_common_value) + .count(); + + let (_, most_common_value) = values.clone().into_iter().fold( + (most_common_count, most_common_value), + |acc, value| { + let count = values.iter().filter(|v| **v == value).count(); + + if count > acc.0 { + (count, value) + } else { + acc + } + }, + ); + + Ok(most_common_value) + } + + #[inline] + // 👀 + pub(crate) fn get_std<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let variance = Self::get_var(values)?; + + let MedRecordValue::Float(variance) = variance else { + unreachable!() + }; + + Ok(MedRecordValue::Float(variance.sqrt())) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_var<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.collect::>(); + + let mean = Self::get_mean(values.clone().into_iter())?; + + let MedRecordValue::Float(mean) = mean else { + let data_type = DataType::from(mean); + + return Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )); + }; + + let values = values + .into_iter() + .map(|value| { + let data_type = DataType::from(&value.1); + + match value.1 { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )), + }}) + .collect::>>()?; + + let values_length = values.len(); + + let variance = values + .into_iter() + .map(|value| (value - mean).powi(2)) + .sum::() + / values_length as f64; + + Ok(MedRecordValue::Float(variance)) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordValue { + MedRecordValue::Int(values.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(first_value.1, |sum, (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + sum.add(value).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + values + .next() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_last<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + values + .last() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + fn evaluate_value_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let values = values.collect::>(); + + let value = get_single_operand_value!(kind, values.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, value)? { + Some(_) => Box::new(values.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_value_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + values.filter(move |(_, value)| value > &comparison_value), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value >= &comparison_value), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + values.filter(move |(_, value)| value < &comparison_value), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value <= &comparison_value), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + values.filter(move |(_, value)| value == &comparison_value), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value != &comparison_value), + )), + SingleComparisonKind::StartsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.starts_with(&comparison_value) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.ends_with(&comparison_value) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(values.filter(move |(_, value)| { + value.contains(&comparison_value) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_values_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); + + // TODO: This is a temporary solution. It should be optimized. + let comparison_values = context + .get_values(medrecord, attribute)? + .map(|value| (&0, value)); + + operand + .evaluate(medrecord, comparison_values)? + .map(|(_, value)| value) + .collect::>() + } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(values.filter(move |(_, value)| { + comparison_values.contains(value) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(values.filter(move |(_, value)| { + !comparison_values.contains(value) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + let values = values + .map(move |(t, value)| { + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value.clone()), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value.clone()), + BinaryArithmeticKind::Mul => { + value.clone().mul(arithmetic_value.clone()) + } + BinaryArithmeticKind::Div => { + value.clone().div(arithmetic_value.clone()) + } + BinaryArithmeticKind::Pow => { + value.clone().pow(arithmetic_value.clone()) + } + BinaryArithmeticKind::Mod => { + value.clone().r#mod(arithmetic_value.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the values using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(values.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + values: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + values.map(move |(t, value)| { + let value = match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + }; + (t, value) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + values: impl Iterator, + range: Range, + ) -> impl Iterator { + values.map(move |(t, value)| (t, value.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash>( + medrecord: &'a MedRecord, + values: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let values = values.collect::>(); + + let either_values = either.evaluate(medrecord, values.clone().into_iter())?; + let or_values = or.evaluate(medrecord, values.into_iter())?; + + Ok(Box::new( + either_values.chain(or_values).unique_by(|value| value.0), + )) + } + + #[inline] + fn evaluate_exclude<'a, T: 'a + Eq + Hash>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let values = values.collect::>(); + + let result = operand + .evaluate(medrecord, values.clone().into_iter())? + .map(|(t, _)| t) + .collect::>(); + + Ok(Box::new( + values.into_iter().filter(move |(t, _)| !result.contains(t)), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleValueOperation { + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, + Exclude { + operand: Wrapper, + }, +} + +impl DeepClone for SingleValueOperation { + fn deep_clone(&self) -> Self { + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + Self::Exclude { operand } => Self::Exclude { + operand: operand.deep_clone(), + }, + } + } +} + +impl SingleValueOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, value, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation(medrecord, value, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, value, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + })), + Self::Slice(range) => Ok(Some(value.slice(range.clone()))), + Self::IsString => Ok(match value { + MedRecordValue::String(_) => Some(value), + _ => None, + }), + Self::IsInt => Ok(match value { + MedRecordValue::Int(_) => Some(value), + _ => None, + }), + Self::IsFloat => Ok(match value { + MedRecordValue::Float(_) => Some(value), + _ => None, + }), + Self::IsBool => Ok(match value { + MedRecordValue::Bool(_) => Some(value), + _ => None, + }), + Self::IsDateTime => Ok(match value { + MedRecordValue::DateTime(_) => Some(value), + _ => None, + }), + Self::IsNull => Ok(match value { + MedRecordValue::Null => Some(value), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, value, either, or), + Self::Exclude { operand } => Ok(match operand.evaluate(medrecord, value.clone())? { + Some(_) => None, + None => Some(value), + }), + } + } + + #[inline] + fn evaluate_single_value_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => value > comparison_value, + SingleComparisonKind::GreaterThanOrEqualTo => value >= comparison_value, + SingleComparisonKind::LessThan => value < comparison_value, + SingleComparisonKind::LessThanOrEqualTo => value <= comparison_value, + SingleComparisonKind::EqualTo => value == comparison_value, + SingleComparisonKind::NotEqualTo => value != comparison_value, + SingleComparisonKind::StartsWith => value.starts_with(&comparison_value), + SingleComparisonKind::EndsWith => value.ends_with(&comparison_value), + SingleComparisonKind::Contains => value.contains(&comparison_value), + }; + + Ok(if comparison_result { Some(value) } else { None }) + } + + #[inline] + fn evaluate_multiple_values_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); + + // TODO: This is a temporary solution. It should be optimized. + let comparison_values = context + .get_values(medrecord, attribute)? + .map(|value| (&0, value)); + + operand + .evaluate(medrecord, comparison_values)? + .map(|(_, value)| value) + .collect::>() + } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_values.contains(&value), + MultipleComparisonKind::IsNotIn => !comparison_values.contains(&value), + }; + + Ok(if comparison_result { Some(value) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + value: MedRecordValue, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value), + BinaryArithmeticKind::Mul => value.mul(arithmetic_value), + BinaryArithmeticKind::Div => value.div(arithmetic_value), + BinaryArithmeticKind::Pow => value.pow(arithmetic_value), + BinaryArithmeticKind::Mod => value.r#mod(arithmetic_value), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + value: MedRecordValue, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, value.clone())?; + let or_result = or.evaluate(medrecord, value)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/wrapper.rs b/crates/medmodels-core/src/medrecord/querying/wrapper.rs new file mode 100644 index 00000000..a5d338bc --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/wrapper.rs @@ -0,0 +1,45 @@ +use super::traits::{DeepClone, ReadWriteOrPanic}; +use std::sync::{Arc, RwLock}; + +#[repr(transparent)] +#[derive(Debug, Clone)] +pub struct Wrapper(pub(crate) Arc>); + +impl From for Wrapper { + fn from(value: T) -> Self { + Self(Arc::new(RwLock::new(value))) + } +} + +impl DeepClone for Wrapper +where + T: DeepClone, +{ + fn deep_clone(&self) -> Self { + self.0.read_or_panic().deep_clone().into() + } +} + +#[derive(Debug, Clone)] +pub enum CardinalityWrapper { + Single(T), + Multiple(Vec), +} + +impl From for CardinalityWrapper { + fn from(value: T) -> Self { + Self::Single(value) + } +} + +impl From> for CardinalityWrapper { + fn from(value: Vec) -> Self { + Self::Multiple(value) + } +} + +impl From<[T; N]> for CardinalityWrapper { + fn from(value: [T; N]) -> Self { + Self::Multiple(value.to_vec()) + } +} diff --git a/crates/medmodels-core/src/medrecord/schema.rs b/crates/medmodels-core/src/medrecord/schema.rs index 2bcdd562..8015870e 100644 --- a/crates/medmodels-core/src/medrecord/schema.rs +++ b/crates/medmodels-core/src/medrecord/schema.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use super::{Attributes, EdgeIndex, NodeIndex}; use crate::{ errors::GraphError, diff --git a/docs/user_guide/getstarted.md b/docs/user_guide/getstarted.md index 74105b75..bdad0686 100644 --- a/docs/user_guide/getstarted.md +++ b/docs/user_guide/getstarted.md @@ -301,25 +301,13 @@ patient_drug_edges = medrecord.add_edges_polars( ) ``` -### Adding single entries +### Removing entries -Single nodes can be added or removed to an existing MedRecord by their unique identifier. Attributes can also be added during that process. - -Single edges between a source node and target node can be added to the MedRecord instance by specifiying the source and the target node identifier. Attributes for the connection can also be included. - -```python -medrecord.add_node(node="pat_6", attributes={"age": 67, "gender": "F"}) -# add connection between nodes, will return the edge identifier -edge_pat6_pat2_id = medrecord.add_edge( - source_node="pat_6", target_node="pat_2", attributes={"relationship": "Mother"} -) -``` - -Nodes and edges can be easily removed by their identifier. To check if a node or edge exists, the `contain_node()` or `contain_edge()` functions can be used. If a node is deleted from the MedRecord, its corresponding edges will also be removed. +Nodes and edges can be easily removed by their identifier. To check if a node or edge exists, the `contains_node()` or `contains_edge()` functions can be used. If a node is deleted from the MedRecord, its corresponding edges will also be removed. ```python # returns attributes for the node that will be removed -medrecord.remove_node("pat_6") +medrecord.remove_nodes("pat_6") medrecord.contains_node("pat_6") or medrecord.contains_edge(edge_pat6_pat2_id) ``` @@ -434,7 +422,7 @@ additional_young_id = medrecord.select_nodes( node().attribute("age").greater_or_equal(young_age) & node().attribute("age").less(higher_age) ) -medrecord.add_node_to_group(group="Young", node=additional_young_id) +medrecord.add_nodes_to_group(group="Young", nodes=additional_young_id) print( f"Patients in Group 'Young' if threshold age is {higher_age}: {medrecord.group('Young')}" @@ -446,12 +434,12 @@ print( It is possible to remove nodes from groups and to remove groups entirely from the MedRecord. ```python -medrecord.remove_node_from_group(group="Young", node=additional_young_id) +medrecord.remove_nodes_from_group(group="Young", nodes=additional_young_id) print(f"Patients in group 'Young': {medrecord.select_nodes(node().in_group('Young'))}") print(f"The MedRecord contains {medrecord.group_count()} groups.") -medrecord.remove_group("Woman") +medrecord.remove_groups("Woman") print( f"After the removal operation, the MedRecord contains {medrecord.group_count()} groups." ) diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index 76f13110..55f090bb 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union from medmodels.medrecord.types import ( Attributes, @@ -28,13 +28,6 @@ if TYPE_CHECKING: else: from typing_extensions import TypeAlias -ValueOperand: TypeAlias = Union[ - MedRecordValue, - MedRecordAttribute, - PyValueArithmeticOperation, - PyValueTransformationOperation, -] - PyDataType: TypeAlias = Union[ PyString, PyInt, @@ -156,8 +149,7 @@ class PyMedRecord: source_node_indices: NodeIndexInputList, target_node_indices: NodeIndexInputList, ) -> List[EdgeIndex]: ... - def add_node(self, node_index: NodeIndex, attributes: AttributesInput) -> None: ... - def remove_node( + def remove_nodes( self, node_index: NodeIndexInputList ) -> Dict[NodeIndex, Attributes]: ... def replace_node_attributes( @@ -176,13 +168,7 @@ class PyMedRecord: def add_nodes_dataframes( self, nodes_dataframe: List[PolarsNodeDataFrameInput] ) -> None: ... - def add_edge( - self, - source_node_index: NodeIndex, - target_node_index: NodeIndex, - attributes: AttributesInput, - ) -> EdgeIndex: ... - def remove_edge( + def remove_edges( self, edge_index: EdgeIndexInputList ) -> Dict[EdgeIndex, Attributes]: ... def replace_edge_attributes( @@ -207,17 +193,17 @@ class PyMedRecord: node_indices_to_add: Optional[NodeIndexInputList], edge_indices_to_add: Optional[EdgeIndexInputList], ) -> None: ... - def remove_group(self, group: GroupInputList) -> None: ... - def add_node_to_group( + def remove_groups(self, group: GroupInputList) -> None: ... + def add_nodes_to_group( self, group: Group, node_index: NodeIndexInputList ) -> None: ... - def add_edge_to_group( + def add_edges_to_group( self, group: Group, edge_index: EdgeIndexInputList ) -> None: ... - def remove_node_from_group( + def remove_nodes_from_group( self, group: Group, node_index: NodeIndexInputList ) -> None: ... - def remove_edge_from_group( + def remove_edges_from_group( self, group: Group, edge_index: EdgeIndexInputList ) -> None: ... def nodes_in_group(self, group: GroupInputList) -> Dict[Group, List[NodeIndex]]: ... @@ -241,171 +227,442 @@ class PyMedRecord: self, node_indices: NodeIndexInputList ) -> Dict[NodeIndex, List[NodeIndex]]: ... def clear(self) -> None: ... - def select_nodes(self, operation: PyNodeOperation) -> List[NodeIndex]: ... - def select_edges(self, operation: PyEdgeOperation) -> List[EdgeIndex]: ... + def select_nodes( + self, query: Callable[[PyNodeOperand], None] + ) -> List[NodeIndex]: ... + def select_edges( + self, query: Callable[[PyEdgeOperand], None] + ) -> List[EdgeIndex]: ... def clone(self) -> PyMedRecord: ... -class PyValueArithmeticOperation: ... -class PyValueTransformationOperation: ... - -class PyNodeOperation: - def logical_and(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def logical_or(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def logical_xor(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def logical_not(self) -> PyNodeOperation: ... - -class PyEdgeOperation: - def logical_and(self, operation: PyEdgeOperation) -> PyEdgeOperation: ... - def logical_or(self, operation: PyEdgeOperation) -> PyEdgeOperation: ... - def logical_xor(self, operation: PyEdgeOperation) -> PyEdgeOperation: ... - def logical_not(self) -> PyEdgeOperation: ... - -class PyNodeAttributeOperand: - def greater( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def less( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def greater_or_equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def less_or_equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def not_equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def is_in(self, operands: List[MedRecordValue]) -> PyNodeOperation: ... - def not_in(self, operands: List[MedRecordValue]) -> PyNodeOperation: ... - def starts_with( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def ends_with( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def contains( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def add(self, value: MedRecordValue) -> ValueOperand: ... - def sub(self, value: MedRecordValue) -> ValueOperand: ... - def mul(self, value: MedRecordValue) -> ValueOperand: ... - def div(self, value: MedRecordValue) -> ValueOperand: ... - def pow(self, value: MedRecordValue) -> ValueOperand: ... - def mod(self, value: MedRecordValue) -> ValueOperand: ... - def round(self) -> ValueOperand: ... - def ceil(self) -> ValueOperand: ... - def floor(self) -> ValueOperand: ... - def abs(self) -> ValueOperand: ... - def sqrt(self) -> ValueOperand: ... - def trim(self) -> ValueOperand: ... - def trim_start(self) -> ValueOperand: ... - def trim_end(self) -> ValueOperand: ... - def lowercase(self) -> ValueOperand: ... - def uppercase(self) -> ValueOperand: ... - def slice(self, start: int, end: int) -> ValueOperand: ... - -class PyEdgeAttributeOperand: - def greater( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def less( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def greater_or_equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def less_or_equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def not_equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def is_in(self, operands: List[MedRecordValue]) -> PyEdgeOperation: ... - def not_in(self, operands: List[MedRecordValue]) -> PyEdgeOperation: ... - def starts_with( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def ends_with( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def contains( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def add(self, value: MedRecordValue) -> ValueOperand: ... - def sub(self, value: MedRecordValue) -> ValueOperand: ... - def mul(self, value: MedRecordValue) -> ValueOperand: ... - def div(self, value: MedRecordValue) -> ValueOperand: ... - def pow(self, value: MedRecordValue) -> ValueOperand: ... - def mod(self, value: MedRecordValue) -> ValueOperand: ... - def round(self) -> ValueOperand: ... - def ceil(self) -> ValueOperand: ... - def floor(self) -> ValueOperand: ... - def abs(self) -> ValueOperand: ... - def sqrt(self) -> ValueOperand: ... - def trim(self) -> ValueOperand: ... - def trim_start(self) -> ValueOperand: ... - def trim_end(self) -> ValueOperand: ... - def lowercase(self) -> ValueOperand: ... - def uppercase(self) -> ValueOperand: ... - def slice(self, start: int, end: int) -> ValueOperand: ... +class PyEdgeDirection(Enum): + Incoming = 0 + Outgoing = 1 + Both = 2 + +class PyNodeOperand: + def attribute(self, attribute: MedRecordAttribute) -> PyMultipleValuesOperand: ... + def attributes(self) -> PyAttributesTreeOperand: ... + def index(self) -> PyNodeIndicesOperand: ... + def in_group(self, group: Union[Group, List[Group]]) -> None: ... + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: ... + def outgoing_edges(self) -> PyEdgeOperand: ... + def incoming_edges(self) -> PyEdgeOperand: ... + def neighbors(self, direction: PyEdgeDirection) -> PyNodeOperand: ... + def either_or( + self, + either: Callable[[PyNodeOperand], None], + or_: Callable[[PyNodeOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyNodeOperand], None]) -> None: ... + def deep_clone(self) -> PyNodeOperand: ... + +PyNodeIndexComparisonOperand: TypeAlias = Union[NodeIndex, PyNodeIndexOperand] +PyNodeIndexArithmeticOperand: TypeAlias = PyNodeIndexComparisonOperand +PyNodeIndicesComparisonOperand: TypeAlias = Union[List[NodeIndex], PyNodeIndicesOperand] + +class PyNodeIndicesOperand: + def max(self) -> PyNodeIndexOperand: ... + def min(self) -> PyNodeIndexOperand: ... + def count(self) -> PyNodeIndexOperand: ... + def sum(self) -> PyNodeIndexOperand: ... + def first(self) -> PyNodeIndexOperand: ... + def last(self) -> PyNodeIndexOperand: ... + def greater_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def contains(self, index: PyNodeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def add(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyNodeIndicesOperand], None], + or_: Callable[[PyNodeIndicesOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyNodeIndicesOperand], None]) -> None: ... + def deep_clone(self) -> PyNodeIndicesOperand: ... class PyNodeIndexOperand: - def greater(self, operand: NodeIndex) -> PyNodeOperation: ... - def less(self, operand: NodeIndex) -> PyNodeOperation: ... - def greater_or_equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def less_or_equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def not_equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def is_in(self, operand: List[NodeIndex]) -> PyNodeOperation: ... - def not_in(self, operand: List[NodeIndex]) -> PyNodeOperation: ... - def starts_with(self, operand: NodeIndex) -> PyNodeOperation: ... - def ends_with(self, operand: NodeIndex) -> PyNodeOperation: ... - def contains(self, operand: NodeIndex) -> PyNodeOperation: ... + def greater_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def contains(self, index: PyNodeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def add(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def either_or( + self, + either: Callable[[PyNodeIndexOperand], None], + or_: Callable[[PyNodeIndexOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyNodeIndexOperand], None]) -> None: ... + def deep_clone(self) -> PyNodeIndexOperand: ... + +class PyEdgeOperand: + def attribute(self, attribute: MedRecordAttribute) -> PyMultipleValuesOperand: ... + def attributes(self) -> PyAttributesTreeOperand: ... + def index(self) -> PyEdgeIndicesOperand: ... + def in_group(self, group: Union[Group, List[Group]]) -> None: ... + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: ... + def source_node(self) -> PyNodeOperand: ... + def target_node(self) -> PyNodeOperand: ... + def either_or( + self, + either: Callable[[PyEdgeOperand], None], + or_: Callable[[PyEdgeOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyEdgeOperand], None]) -> None: ... + def deep_clone(self) -> PyEdgeOperand: ... + +PyEdgeIndexComparisonOperand: TypeAlias = Union[EdgeIndex, PyEdgeIndexOperand] +PyEdgeIndexArithmeticOperand: TypeAlias = PyEdgeIndexComparisonOperand +PyEdgeIndicesComparisonOperand: TypeAlias = Union[List[EdgeIndex], PyEdgeIndicesOperand] + +class PyEdgeIndicesOperand: + def max(self) -> PyEdgeIndexOperand: ... + def min(self) -> PyEdgeIndexOperand: ... + def count(self) -> PyEdgeIndexOperand: ... + def sum(self) -> PyEdgeIndexOperand: ... + def first(self) -> PyEdgeIndexOperand: ... + def last(self) -> PyEdgeIndexOperand: ... + def greater_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def contains(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def add(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyEdgeIndicesOperand], None], + or_: Callable[[PyEdgeIndicesOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyEdgeIndicesOperand], None]) -> None: ... + def deep_clone(self) -> PyEdgeIndicesOperand: ... class PyEdgeIndexOperand: - def greater(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def less(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def greater_or_equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def less_or_equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def not_equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def is_in(self, operand: List[EdgeIndex]) -> PyEdgeOperation: ... - def not_in(self, operand: List[EdgeIndex]) -> PyEdgeOperation: ... + def greater_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def contains(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def add(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def either_or( + self, + either: Callable[[PyEdgeIndexOperand], None], + or_: Callable[[PyEdgeIndexOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyEdgeIndexOperand], None]) -> None: ... + def deep_clone(self) -> PyEdgeIndexOperand: ... -class PyNodeOperand: - def in_group(self, operand: Group) -> PyNodeOperation: ... - def has_attribute(self, operand: MedRecordAttribute) -> PyNodeOperation: ... - def has_outgoing_edge_with(self, operation: PyEdgeOperation) -> PyNodeOperation: ... - def has_incoming_edge_with(self, operation: PyEdgeOperation) -> PyNodeOperation: ... - def has_edge_with(self, operation: PyEdgeOperation) -> PyNodeOperation: ... - def has_neighbor_with(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def has_neighbor_undirected_with( - self, operation: PyNodeOperation - ) -> PyNodeOperation: ... - def attribute(self, attribute: MedRecordAttribute) -> PyNodeAttributeOperand: ... - def index(self) -> PyNodeIndexOperand: ... +PySingleValueComparisonOperand: TypeAlias = Union[MedRecordValue, PySingleValueOperand] +PySingleValueArithmeticOperand: TypeAlias = PySingleValueComparisonOperand +PyMultipleValuesComparisonOperand: TypeAlias = Union[ + List[MedRecordValue], PyMultipleValuesOperand +] -class PyEdgeOperand: - def connected_target(self, operand: NodeIndex) -> PyEdgeOperation: ... - def connected_source(self, operand: NodeIndex) -> PyEdgeOperation: ... - def connected(self, operand: NodeIndex) -> PyEdgeOperation: ... - def in_group(self, operand: Group) -> PyEdgeOperation: ... - def has_attribute(self, operand: MedRecordAttribute) -> PyEdgeOperation: ... - def connected_source_with(self, operation: PyNodeOperation) -> PyEdgeOperation: ... - def connected_target_with(self, operation: PyNodeOperation) -> PyEdgeOperation: ... - def connected_with(self, operation: PyNodeOperation) -> PyEdgeOperation: ... - def has_parallel_edges_with( - self, operation: PyEdgeOperation - ) -> PyEdgeOperation: ... - def has_parallel_edges_with_self_comparison( - self, operation: PyEdgeOperation - ) -> PyEdgeOperation: ... - def attribute(self, attribute: MedRecordAttribute) -> PyEdgeAttributeOperand: ... - def index(self) -> PyEdgeIndexOperand: ... +class PyMultipleValuesOperand: + def max(self) -> PySingleValueOperand: ... + def min(self) -> PySingleValueOperand: ... + def mean(self) -> PySingleValueOperand: ... + def median(self) -> PySingleValueOperand: ... + def mode(self) -> PySingleValueOperand: ... + def std(self) -> PySingleValueOperand: ... + def var(self) -> PySingleValueOperand: ... + def count(self) -> PySingleValueOperand: ... + def sum(self) -> PySingleValueOperand: ... + def first(self) -> PySingleValueOperand: ... + def last(self) -> PySingleValueOperand: ... + def greater_than(self, value: PySingleValueComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: PySingleValueComparisonOperand + ) -> None: ... + def less_than(self, value: PySingleValueComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def not_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def starts_with(self, value: PySingleValueComparisonOperand) -> None: ... + def ends_with(self, value: PySingleValueComparisonOperand) -> None: ... + def contains(self, value: PySingleValueComparisonOperand) -> None: ... + def is_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def is_not_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def add(self, value: PySingleValueArithmeticOperand) -> None: ... + def sub(self, value: PySingleValueArithmeticOperand) -> None: ... + def mul(self, value: PySingleValueArithmeticOperand) -> None: ... + def div(self, value: PySingleValueArithmeticOperand) -> None: ... + def pow(self, value: PySingleValueArithmeticOperand) -> None: ... + def mod(self, value: PySingleValueArithmeticOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... + def abs(self) -> None: ... + def sqrt(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... + def is_null(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyMultipleValuesOperand], None], + or_: Callable[[PyMultipleValuesOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyMultipleValuesOperand], None]) -> None: ... + def deep_clone(self) -> PyMultipleValuesOperand: ... + +class PySingleValueOperand: + def greater_than(self, value: PySingleValueComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: PySingleValueComparisonOperand + ) -> None: ... + def less_than(self, value: PySingleValueComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def not_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def starts_with(self, value: PySingleValueComparisonOperand) -> None: ... + def ends_with(self, value: PySingleValueComparisonOperand) -> None: ... + def contains(self, value: PySingleValueComparisonOperand) -> None: ... + def is_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def is_not_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def add(self, value: PySingleValueArithmeticOperand) -> None: ... + def sub(self, value: PySingleValueArithmeticOperand) -> None: ... + def mul(self, value: PySingleValueArithmeticOperand) -> None: ... + def div(self, value: PySingleValueArithmeticOperand) -> None: ... + def pow(self, value: PySingleValueArithmeticOperand) -> None: ... + def mod(self, value: PySingleValueArithmeticOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... + def abs(self) -> None: ... + def sqrt(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... + def is_null(self) -> None: ... + def either_or( + self, + either: Callable[[PySingleValueOperand], None], + or_: Callable[[PySingleValueOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PySingleValueOperand], None]) -> None: ... + def deep_clone(self) -> PySingleValueOperand: ... + +PySingleAttributeComparisonOperand: TypeAlias = Union[ + MedRecordAttribute, PySingleAttributeOperand +] +PySingleAttributeArithmeticOperand: TypeAlias = PySingleAttributeComparisonOperand +PyMultipleAttributesComparisonOperand: TypeAlias = Union[ + List[MedRecordAttribute], PyMultipleAttributesOperand +] + +class PyAttributesTreeOperand: + def max(self) -> PyMultipleAttributesOperand: ... + def min(self) -> PyMultipleAttributesOperand: ... + def count(self) -> PyMultipleAttributesOperand: ... + def sum(self) -> PyMultipleAttributesOperand: ... + def first(self) -> PyMultipleAttributesOperand: ... + def last(self) -> PyMultipleAttributesOperand: ... + def greater_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def starts_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def ends_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def contains(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def is_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def add(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def sub(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mul(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def pow(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mod(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyAttributesTreeOperand], None], + or_: Callable[[PyAttributesTreeOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyAttributesTreeOperand], None]) -> None: ... + def deep_clone(self) -> PyAttributesTreeOperand: ... + +class PyMultipleAttributesOperand: + def max(self) -> PySingleAttributeOperand: ... + def min(self) -> PySingleAttributeOperand: ... + def count(self) -> PySingleAttributeOperand: ... + def sum(self) -> PySingleAttributeOperand: ... + def first(self) -> PySingleAttributeOperand: ... + def last(self) -> PySingleAttributeOperand: ... + def greater_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def starts_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def ends_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def contains(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def is_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def add(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def sub(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mul(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def pow(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mod(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def to_values(self) -> PyMultipleValuesOperand: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyMultipleAttributesOperand], None], + or_: Callable[[PyMultipleAttributesOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PyMultipleAttributesOperand], None]) -> None: ... + def deep_clone(self) -> PyMultipleAttributesOperand: ... + +class PySingleAttributeOperand: + def greater_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def starts_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def ends_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def contains(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def is_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def add(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def sub(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mul(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def pow(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mod(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def either_or( + self, + either: Callable[[PySingleAttributeOperand], None], + or_: Callable[[PySingleAttributeOperand], None], + ) -> None: ... + def exclude(self, query: Callable[[PySingleAttributeOperand], None]) -> None: ... + def deep_clone(self) -> PySingleAttributeOperand: ... diff --git a/medmodels/medrecord/__init__.py b/medmodels/medrecord/__init__.py index 3a9f3f6f..09b4e6dd 100644 --- a/medmodels/medrecord/__init__.py +++ b/medmodels/medrecord/__init__.py @@ -11,12 +11,12 @@ ) from medmodels.medrecord.medrecord import ( EdgeIndex, - EdgeOperation, + EdgeQuery, MedRecord, NodeIndex, - NodeOperation, + NodeQuery, ) -from medmodels.medrecord.querying import edge, node +from medmodels.medrecord.querying import EdgeOperand, NodeOperand from medmodels.medrecord.schema import AttributeType, GroupSchema, Schema __all__ = [ @@ -33,10 +33,10 @@ "AttributeType", "Schema", "GroupSchema", - "node", - "edge", "NodeIndex", "EdgeIndex", - "NodeOperation", - "EdgeOperation", + "EdgeQuery", + "NodeQuery", + "NodeOperand", + "EdgeOperand", ] diff --git a/medmodels/medrecord/builder.py b/medmodels/medrecord/builder.py index 65998f0a..ab014de1 100644 --- a/medmodels/medrecord/builder.py +++ b/medmodels/medrecord/builder.py @@ -31,7 +31,7 @@ is_polars_node_dataframe_input_list, ) -NodeInput = Union[ +NodeInputBuilder = Union[ NodeTuple, List[NodeTuple], PandasNodeDataFrameInput, @@ -39,10 +39,17 @@ PolarsNodeDataFrameInput, List[PolarsNodeDataFrameInput], ] -NodeInputWithGroup = Tuple[NodeInput, Group] -def is_node_input(value: object) -> TypeIs[NodeInput]: +def is_node_input_builder(value: object) -> TypeIs[NodeInputBuilder]: + """Check if a value is a valid node input. + + Args: + value (object): The value to check. + + Returns: + TypeIs[NodeInput]: True if the value is a valid node input, otherwise False. + """ return ( is_node_tuple(value) or is_node_tuple_list(value) @@ -53,7 +60,7 @@ def is_node_input(value: object) -> TypeIs[NodeInput]: ) -EdgeInput = Union[ +EdgeInputBuilder = Union[ EdgeTuple, List[EdgeTuple], PandasEdgeDataFrameInput, @@ -61,7 +68,29 @@ def is_node_input(value: object) -> TypeIs[NodeInput]: PolarsEdgeDataFrameInput, List[PolarsEdgeDataFrameInput], ] -EdgeInputWithGroup = Tuple[EdgeInput, Group] + + +def is_edge_input_builder(value: object) -> TypeIs[EdgeInputBuilder]: + """Check if a value is a valid edge input. + + Args: + value (object): The value to check. + + Returns: + TypeIs[EdgeInput]: True if the value is a valid edge input, otherwise False. + """ + return ( + is_edge_tuple(value) + or is_edge_tuple_list(value) + or is_pandas_edge_dataframe_input(value) + or is_pandas_edge_dataframe_input_list(value) + or is_polars_edge_dataframe_input(value) + or is_polars_edge_dataframe_input_list(value) + ) + + +NodeInputWithGroup = Tuple[NodeInputBuilder, Group] +EdgeInputWithGroup = Tuple[EdgeInputBuilder, Group] class MedRecordBuilder: @@ -71,8 +100,8 @@ class MedRecordBuilder: specifying a schema. """ - _nodes: List[Union[NodeInput, NodeInputWithGroup]] - _edges: List[Union[EdgeInput, EdgeInputWithGroup]] + _nodes: List[Union[NodeInputBuilder, NodeInputWithGroup]] + _edges: List[Union[EdgeInputBuilder, EdgeInputWithGroup]] _groups: Dict[Group, GroupInfo] _schema: Optional[Schema] @@ -85,7 +114,7 @@ def __init__(self) -> None: def add_nodes( self, - nodes: NodeInput, + nodes: NodeInputBuilder, *, group: Optional[Group] = None, ) -> MedRecordBuilder: @@ -107,7 +136,7 @@ def add_nodes( def add_edges( self, - edges: EdgeInput, + edges: EdgeInputBuilder, *, group: Optional[Group] = None, ) -> MedRecordBuilder: @@ -163,64 +192,29 @@ def build(self) -> mm.MedRecord: medrecord = mm.MedRecord() for node in self._nodes: - if is_node_tuple(node): - medrecord.add_node(*node) - continue - - if ( - is_node_tuple_list(node) - or is_pandas_node_dataframe_input(node) - or is_pandas_node_dataframe_input_list(node) - or is_polars_node_dataframe_input(node) - or is_polars_node_dataframe_input_list(node) - ): + if is_node_input_builder(node): medrecord.add_nodes(node) continue group = node[1] node = node[0] - if is_node_tuple(node): - medrecord.add_node(*node, group) - continue - medrecord.add_nodes(node, group) for edge in self._edges: - if is_edge_tuple(edge): - medrecord.add_edge(*edge) - continue - - if ( - is_edge_tuple_list(edge) - or is_pandas_edge_dataframe_input(edge) - or is_pandas_edge_dataframe_input_list(edge) - or is_polars_edge_dataframe_input(edge) - or is_polars_edge_dataframe_input_list(edge) - ): + if is_edge_input_builder(edge): medrecord.add_edges(edge) continue group = edge[1] edge = edge[0] - if is_edge_tuple(edge): - medrecord.add_edge(*edge, group) - continue - - if ( - is_edge_tuple_list(edge) - or is_pandas_edge_dataframe_input(edge) - or is_pandas_edge_dataframe_input_list(edge) - or is_polars_edge_dataframe_input(edge) - or is_polars_edge_dataframe_input_list(edge) - ): - medrecord.add_edges(edge, group) + medrecord.add_edges(edge, group) for group in self._groups: if medrecord.contains_group(group): - medrecord.add_node_to_group(group, self._groups[group]["nodes"]) - medrecord.add_edge_to_group(group, self._groups[group]["edges"]) + medrecord.add_nodes_to_group(group, self._groups[group]["nodes"]) + medrecord.add_edges_to_group(group, self._groups[group]["edges"]) else: medrecord.add_group( group, self._groups[group]["nodes"], self._groups[group]["edges"] diff --git a/medmodels/medrecord/datatype.py b/medmodels/medrecord/datatype.py index 2dcfdea3..e8b78a6d 100644 --- a/medmodels/medrecord/datatype.py +++ b/medmodels/medrecord/datatype.py @@ -50,7 +50,7 @@ def __repr__(self) -> str: ... def __eq__(self, value: object) -> bool: ... @staticmethod - def _from_pydatatype(datatype: PyDataType) -> DataType: + def _from_py_data_type(datatype: PyDataType) -> DataType: if isinstance(datatype, PyString): return String() elif isinstance(datatype, PyInt): @@ -67,11 +67,11 @@ def _from_pydatatype(datatype: PyDataType) -> DataType: return Any() elif isinstance(datatype, PyUnion): return Union( - DataType._from_pydatatype(datatype.dtype1), - DataType._from_pydatatype(datatype.dtype2), + DataType._from_py_data_type(datatype.dtype1), + DataType._from_py_data_type(datatype.dtype2), ) else: - return Option(DataType._from_pydatatype(datatype.dtype)) + return Option(DataType._from_py_data_type(datatype.dtype)) class String(DataType): @@ -222,18 +222,18 @@ def _inner(self) -> PyDataType: return self._union def __str__(self) -> str: - return f"Union({DataType._from_pydatatype(self._union.dtype1).__str__()}, {DataType._from_pydatatype(self._union.dtype2).__str__()})" + return f"Union({DataType._from_py_data_type(self._union.dtype1).__str__()}, {DataType._from_py_data_type(self._union.dtype2).__str__()})" def __repr__(self) -> str: - return f"DataType.Union({DataType._from_pydatatype(self._union.dtype1).__repr__()}, {DataType._from_pydatatype(self._union.dtype2).__repr__()})" + return f"DataType.Union({DataType._from_py_data_type(self._union.dtype1).__repr__()}, {DataType._from_py_data_type(self._union.dtype2).__repr__()})" def __eq__(self, value: object) -> bool: return ( isinstance(value, Union) - and DataType._from_pydatatype(self._union.dtype1) - == DataType._from_pydatatype(value._union.dtype1) - and DataType._from_pydatatype(self._union.dtype2) - == DataType._from_pydatatype(value._union.dtype2) + and DataType._from_py_data_type(self._union.dtype1) + == DataType._from_py_data_type(value._union.dtype1) + and DataType._from_py_data_type(self._union.dtype2) + == DataType._from_py_data_type(value._union.dtype2) ) @@ -247,12 +247,12 @@ def _inner(self) -> PyDataType: return self._option def __str__(self) -> str: - return f"Option({DataType._from_pydatatype(self._option.dtype).__str__()})" + return f"Option({DataType._from_py_data_type(self._option.dtype).__str__()})" def __repr__(self) -> str: - return f"DataType.Option({DataType._from_pydatatype(self._option.dtype).__repr__()})" + return f"DataType.Option({DataType._from_py_data_type(self._option.dtype).__repr__()})" def __eq__(self, value: object) -> bool: - return isinstance(value, Option) and DataType._from_pydatatype( + return isinstance(value, Option) and DataType._from_py_data_type( self._option.dtype - ) == DataType._from_pydatatype(value._option.dtype) + ) == DataType._from_py_data_type(value._option.dtype) diff --git a/medmodels/medrecord/indexers.py b/medmodels/medrecord/indexers.py index b76404e1..33a0f7df 100644 --- a/medmodels/medrecord/indexers.py +++ b/medmodels/medrecord/indexers.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Tuple, Union, overload +from typing import TYPE_CHECKING, Callable, Dict, Tuple, Union, overload -from medmodels.medrecord.querying import EdgeOperation, NodeOperation +from medmodels.medrecord.querying import EdgeQuery, NodeQuery from medmodels.medrecord.types import ( Attributes, AttributesInput, @@ -48,10 +48,10 @@ def __getitem__( self, key: Union[ NodeIndexInputList, - NodeOperation, + NodeQuery, slice, Tuple[ - Union[NodeIndexInputList, NodeOperation, slice], + Union[NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttributeInputList, slice], ], ], @@ -60,7 +60,7 @@ def __getitem__( @overload def __getitem__( self, - key: Tuple[Union[NodeIndexInputList, NodeOperation, slice], MedRecordAttribute], + key: Tuple[Union[NodeIndexInputList, NodeQuery, slice], MedRecordAttribute], ) -> Dict[NodeIndex, MedRecordValue]: ... def __getitem__( @@ -68,10 +68,10 @@ def __getitem__( key: Union[ NodeIndex, NodeIndexInputList, - NodeOperation, + NodeQuery, slice, Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -87,7 +87,7 @@ def __getitem__( if isinstance(key, list): return self._medrecord._medrecord.node(key) - if isinstance(key, NodeOperation): + if isinstance(key, Callable): return self._medrecord._medrecord.node(self._medrecord.select_nodes(key)) if isinstance(key, slice): @@ -112,7 +112,7 @@ def __getitem__( return {x: attributes[x][attribute_selection] for x in attributes.keys()} - if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): attributes = self._medrecord._medrecord.node( @@ -151,7 +151,7 @@ def __getitem__( for x in attributes.keys() } - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): attributes = self._medrecord._medrecord.node( @@ -198,7 +198,7 @@ def __getitem__( return self._medrecord._medrecord.node(index_selection) - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -230,7 +230,7 @@ def __getitem__( @overload def __setitem__( self, - key: Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + key: Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], value: AttributesInput, ) -> None: ... @@ -238,7 +238,7 @@ def __setitem__( def __setitem__( self, key: Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], value: MedRecordValue, @@ -249,10 +249,10 @@ def __setitem__( key: Union[ NodeIndex, NodeIndexInputList, - NodeOperation, + NodeQuery, slice, Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -270,7 +270,7 @@ def __setitem__( return self._medrecord._medrecord.replace_node_attributes(key, value) - if isinstance(key, NodeOperation): + if isinstance(key, Callable): if not is_attributes(value): raise ValueError("Invalid value type. Expected Attributes") @@ -311,7 +311,7 @@ def __setitem__( index_selection, attribute_selection, value ) - if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): if not is_medrecord_value(value): @@ -364,7 +364,7 @@ def __setitem__( return - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): @@ -440,7 +440,7 @@ def __setitem__( return - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -494,7 +494,7 @@ def __setitem__( def __delitem__( self, key: Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ) -> None: @@ -514,7 +514,7 @@ def __delitem__( index_selection, attribute_selection ) - if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): return self._medrecord._medrecord.remove_node_attribute( @@ -553,7 +553,7 @@ def __delitem__( return - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): for attribute in attribute_selection: @@ -602,7 +602,7 @@ def __delitem__( index_selection, {} ) - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -658,10 +658,10 @@ def __getitem__( self, key: Union[ EdgeIndexInputList, - EdgeOperation, + EdgeQuery, slice, Tuple[ - Union[EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttributeInputList, slice], ], ], @@ -670,7 +670,7 @@ def __getitem__( @overload def __getitem__( self, - key: Tuple[Union[EdgeIndexInputList, EdgeOperation, slice], MedRecordAttribute], + key: Tuple[Union[EdgeIndexInputList, EdgeQuery, slice], MedRecordAttribute], ) -> Dict[EdgeIndex, MedRecordValue]: ... def __getitem__( @@ -678,10 +678,10 @@ def __getitem__( key: Union[ EdgeIndex, EdgeIndexInputList, - EdgeOperation, + EdgeQuery, slice, Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -697,7 +697,7 @@ def __getitem__( if isinstance(key, list): return self._medrecord._medrecord.edge(key) - if isinstance(key, EdgeOperation): + if isinstance(key, Callable): return self._medrecord._medrecord.edge(self._medrecord.select_edges(key)) if isinstance(key, slice): @@ -722,7 +722,7 @@ def __getitem__( return {x: attributes[x][attribute_selection] for x in attributes.keys()} - if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): attributes = self._medrecord._medrecord.edge( @@ -761,7 +761,7 @@ def __getitem__( for x in attributes.keys() } - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): attributes = self._medrecord._medrecord.edge( @@ -808,7 +808,7 @@ def __getitem__( return self._medrecord._medrecord.edge(index_selection) - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -840,7 +840,7 @@ def __getitem__( @overload def __setitem__( self, - key: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + key: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], value: AttributesInput, ) -> None: ... @@ -848,7 +848,7 @@ def __setitem__( def __setitem__( self, key: Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], value: MedRecordValue, @@ -859,10 +859,10 @@ def __setitem__( key: Union[ EdgeIndex, EdgeIndexInputList, - EdgeOperation, + EdgeQuery, slice, Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -880,7 +880,7 @@ def __setitem__( return self._medrecord._medrecord.replace_edge_attributes(key, value) - if isinstance(key, EdgeOperation): + if isinstance(key, Callable): if not is_attributes(value): raise ValueError("Invalid value type. Expected Attributes") @@ -921,7 +921,7 @@ def __setitem__( index_selection, attribute_selection, value ) - if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): if not is_medrecord_value(value): @@ -974,7 +974,7 @@ def __setitem__( return - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): @@ -1048,7 +1048,7 @@ def __setitem__( return - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -1102,7 +1102,7 @@ def __setitem__( def __delitem__( self, key: Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ) -> None: @@ -1122,7 +1122,7 @@ def __delitem__( index_selection, attribute_selection ) - if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): return self._medrecord._medrecord.remove_edge_attribute( @@ -1161,7 +1161,7 @@ def __delitem__( return - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): for attribute in attribute_selection: @@ -1210,7 +1210,7 @@ def __delitem__( index_selection, {} ) - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index e4068111..a7aebbcb 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Union, overload +from typing import Callable, Dict, List, Optional, Sequence, Union, overload import polars as pl @@ -8,7 +8,7 @@ from medmodels.medrecord._overview import extract_attribute_summary, prettify_table from medmodels.medrecord.builder import MedRecordBuilder from medmodels.medrecord.indexers import EdgeIndexer, NodeIndexer -from medmodels.medrecord.querying import EdgeOperation, NodeOperation +from medmodels.medrecord.querying import EdgeOperand, EdgeQuery, NodeOperand, NodeQuery from medmodels.medrecord.schema import Schema from medmodels.medrecord.types import ( Attributes, @@ -16,17 +16,21 @@ AttributeSummary, EdgeIndex, EdgeIndexInputList, + EdgeInput, EdgeTuple, Group, GroupInfo, GroupInputList, NodeIndex, NodeIndexInputList, + NodeInput, NodeTuple, PandasEdgeDataFrameInput, PandasNodeDataFrameInput, PolarsEdgeDataFrameInput, PolarsNodeDataFrameInput, + is_edge_tuple, + is_node_tuple, is_pandas_edge_dataframe_input, is_pandas_edge_dataframe_input_list, is_pandas_node_dataframe_input, @@ -297,7 +301,7 @@ def schema(self) -> Schema: Returns: Schema: The schema of the MedRecord. """ - return Schema._from_pyschema(self._medrecord.schema) + return Schema._from_py_schema(self._medrecord.schema) @schema.setter def schema(self, schema: Schema) -> None: @@ -412,11 +416,11 @@ def outgoing_edges(self, node: NodeIndex) -> List[EdgeIndex]: ... @overload def outgoing_edges( - self, node: Union[NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, List[EdgeIndex]]: ... def outgoing_edges( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: """Lists the outgoing edges of the specified node(s) in the MedRecord. @@ -425,14 +429,14 @@ def outgoing_edges( its list of outgoing edge indices. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: Outgoing edge indices for each specified node. """ - if isinstance(node, NodeOperation): + if isinstance(node, Callable): return self._medrecord.outgoing_edges(self.select_nodes(node)) indices = self._medrecord.outgoing_edges( @@ -449,11 +453,11 @@ def incoming_edges(self, node: NodeIndex) -> List[EdgeIndex]: ... @overload def incoming_edges( - self, node: Union[NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, List[EdgeIndex]]: ... def incoming_edges( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: """Lists the incoming edges of the specified node(s) in the MedRecord. @@ -462,14 +466,14 @@ def incoming_edges( its list of incoming edge indices. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: Incoming edge indices for each specified node. """ - if isinstance(node, NodeOperation): + if isinstance(node, Callable): return self._medrecord.incoming_edges(self.select_nodes(node)) indices = self._medrecord.incoming_edges( @@ -486,11 +490,11 @@ def edge_endpoints(self, edge: EdgeIndex) -> tuple[NodeIndex, NodeIndex]: ... @overload def edge_endpoints( - self, edge: Union[EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndexInputList, EdgeQuery] ) -> Dict[EdgeIndex, tuple[NodeIndex, NodeIndex]]: ... def edge_endpoints( - self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> Union[ tuple[NodeIndex, NodeIndex], Dict[EdgeIndex, tuple[NodeIndex, NodeIndex]] ]: @@ -501,8 +505,8 @@ def edge_endpoints( a dictionary mapping each edge index to its tuple of node indices. Args: - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices. + edge (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query. Returns: Union[tuple[NodeIndex, NodeIndex], @@ -510,7 +514,7 @@ def edge_endpoints( Tuple of node indices or a dictionary mapping each edge to its node indices. """ - if isinstance(edge, EdgeOperation): + if isinstance(edge, Callable): return self._medrecord.edge_endpoints(self.select_edges(edge)) endpoints = self._medrecord.edge_endpoints( @@ -524,8 +528,8 @@ def edge_endpoints( def edges_connecting( self, - source_node: Union[NodeIndex, NodeIndexInputList, NodeOperation], - target_node: Union[NodeIndex, NodeIndexInputList, NodeOperation], + source_node: Union[NodeIndex, NodeIndexInputList, NodeQuery], + target_node: Union[NodeIndex, NodeIndexInputList, NodeQuery], directed: bool = True, ) -> List[EdgeIndex]: """Retrieves the edges connecting the specified source and target nodes in the MedRecord. @@ -536,11 +540,11 @@ def edges_connecting( target nodes. Args: - source_node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): - The index or indices of the source node(s), or a NodeOperation to + source_node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): + The index or indices of the source node(s), or a node query to select source nodes. - target_node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): - The index or indices of the target node(s), or a NodeOperation to + target_node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): + The index or indices of the target node(s), or a node query to select target nodes. directed (bool, optional): Whether to consider edges as directed. @@ -549,10 +553,10 @@ def edges_connecting( target nodes. """ - if isinstance(source_node, NodeOperation): + if isinstance(source_node, Callable): source_node = self.select_nodes(source_node) - if isinstance(target_node, NodeOperation): + if isinstance(target_node, Callable): target_node = self.select_nodes(target_node) if directed: @@ -566,42 +570,16 @@ def edges_connecting( (target_node if isinstance(target_node, list) else [target_node]), ) - def add_node( - self, - node: NodeIndex, - attributes: AttributesInput, - group: Optional[Group] = None, - ) -> None: - """Adds a node with specified attributes to the MedRecord instance. Optionally adds the node to a group. - - Args: - node (NodeIndex): The index of the node to add. - attributes (Attributes): A dictionary of the node's attributes. - group (Optional[Group]): The name of the group to add the node to, optional. - - Returns: - None - """ - self._medrecord.add_node(node, attributes) - - if group is None: - return - - if not self.contains_group(group): - self.add_group(group) - - self.add_node_to_group(group, node) - @overload - def remove_node(self, node: NodeIndex) -> Attributes: ... + def remove_nodes(self, nodes: NodeIndex) -> Attributes: ... @overload - def remove_node( - self, node: Union[NodeIndexInputList, NodeOperation] + def remove_nodes( + self, nodes: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, Attributes]: ... - def remove_node( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + def remove_nodes( + self, nodes: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[Attributes, Dict[NodeIndex, Attributes]]: """Removes a node or multiple nodes from the MedRecord and returns their attributes. @@ -610,47 +588,40 @@ def remove_node( index to its attributes. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + nodes (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[Attributes, Dict[NodeIndex, Attributes]]: Attributes of the removed node(s). """ - if isinstance(node, NodeOperation): - return self._medrecord.remove_node(self.select_nodes(node)) + if isinstance(nodes, Callable): + return self._medrecord.remove_nodes(self.select_nodes(nodes)) - attributes = self._medrecord.remove_node( - node if isinstance(node, list) else [node] + attributes = self._medrecord.remove_nodes( + nodes if isinstance(nodes, list) else [nodes] ) - if isinstance(node, list): + if isinstance(nodes, list): return attributes - return attributes[node] + return attributes[nodes] def add_nodes( self, - nodes: Union[ - Sequence[NodeTuple], - PandasNodeDataFrameInput, - List[PandasNodeDataFrameInput], - PolarsNodeDataFrameInput, - List[PolarsNodeDataFrameInput], - ], + nodes: NodeInput, group: Optional[Group] = None, ) -> None: - """Adds multiple nodes to the MedRecord from different data formats and optionally assigns them to a group. + """Adds nodes to the MedRecord from different data formats and optionally assigns them to a group. - Accepts a list of tuples, DataFrame(s), or PolarsNodeDataFrameInput(s) to add - nodes. If a DataFrame or list of DataFrames is used, the add_nodes_pandas method - is called. If PolarsNodeDataFrameInput(s) are provided, each tuple must include - a DataFrame and the index column. If a group is specified, the nodes are added - to the group. + Accepts a node tuple (single node added), a list of tuples, DataFrame(s), or + PolarsNodeDataFrameInput(s) to add nodes. If a DataFrame or list of DataFrames + is used, the add_nodes_pandas method is called. If PolarsNodeDataFrameInput(s) + are provided, each tuple must include a DataFrame and the index column. If a + group is specified, the nodes are added to the group. Args: - nodes (Union[Sequence[NodeTuple], PandasNodeDataFrameInput, List[PandasNodeDataFrameInput], PolarsNodeDataFrameInput, List[PolarsNodeDataFrameInput]]): - Data representing nodes in various formats. + nodes (NodeInput): Data representing nodes in various formats. group (Optional[Group]): The name of the group to add the nodes to. If not specified, the nodes are added to the MedRecord without a group. @@ -666,6 +637,9 @@ def add_nodes( ) or is_polars_node_dataframe_input_list(nodes): self.add_nodes_polars(nodes, group) else: + if is_node_tuple(nodes): + nodes = [nodes] + self._medrecord.add_nodes(nodes) if group is None: @@ -674,7 +648,7 @@ def add_nodes( if not self.contains_group(group): self.add_group(group) - self.add_node_to_group(group, [node[0] for node in nodes]) + self.add_nodes_to_group(group, [node[0] for node in nodes]) def add_nodes_pandas( self, @@ -740,49 +714,18 @@ def add_nodes_polars( else: node_indices = nodes[0][nodes[1]].to_list() - self.add_node_to_group(group, node_indices) - - def add_edge( - self, - source_node: NodeIndex, - target_node: NodeIndex, - attributes: AttributesInput, - group: Optional[Group] = None, - ) -> EdgeIndex: - """Adds an edge between two specified nodes with given attributes. Optionally assigns the edge to a group. - - Args: - source_node (NodeIndex): Index of the source node. - target_node (NodeIndex): Index of the target node. - attributes (AttributesInput): Dictionary or mapping of edge attributes. - group (Optional[Group]): The name of the group to add the edge to. If not - specified, the edge is added to the MedRecord without a group. - - Returns: - EdgeIndex: The index of the added edge. - """ - edge_index = self._medrecord.add_edge(source_node, target_node, attributes) - - if group is None: - return edge_index - - if not self.contains_group(group): - self.add_group(group) - - self.add_edge_to_group(group, edge_index) - - return edge_index + self.add_nodes_to_group(group, node_indices) @overload - def remove_edge(self, edge: EdgeIndex) -> Attributes: ... + def remove_edges(self, edges: EdgeIndex) -> Attributes: ... @overload - def remove_edge( - self, edge: Union[EdgeIndexInputList, EdgeOperation] + def remove_edges( + self, edges: Union[EdgeIndexInputList, EdgeQuery] ) -> Dict[EdgeIndex, Attributes]: ... - def remove_edge( - self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + def remove_edges( + self, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> Union[Attributes, Dict[EdgeIndex, Attributes]]: """Removes an edge or multiple edges from the MedRecord and returns their attributes. @@ -791,49 +734,41 @@ def remove_edge( index to its attributes. Args: - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation. + edge (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query. Returns: Union[Attributes, Dict[EdgeIndex, Attributes]]: Attributes of the removed edge(s). """ - if isinstance(edge, EdgeOperation): - return self._medrecord.remove_edge(self.select_edges(edge)) + if isinstance(edges, Callable): + return self._medrecord.remove_edges(self.select_edges(edges)) - attributes = self._medrecord.remove_edge( - edge if isinstance(edge, list) else [edge] + attributes = self._medrecord.remove_edges( + edges if isinstance(edges, list) else [edges] ) - if isinstance(edge, list): + if isinstance(edges, list): return attributes - return attributes[edge] + return attributes[edges] def add_edges( self, - edges: Union[ - Sequence[EdgeTuple], - PandasEdgeDataFrameInput, - List[PandasEdgeDataFrameInput], - PolarsEdgeDataFrameInput, - List[PolarsEdgeDataFrameInput], - ], + edges: EdgeInput, group: Optional[Group] = None, ) -> List[EdgeIndex]: """Adds edges to the MedRecord instance from various data formats. Optionally assigns them to a group. - Accepts lists of tuples, DataFrame(s), or EdgeDataFrameInput(s) to add edges. - Each tuple must have indices for source and target nodes and a dictionary of - attributes. If a DataFrame or list of DataFrames is used, - the add_edges_dataframe method is invoked. If PolarsEdgeDataFrameInput(s) are + Accepts edge tuple, lists of tuples, DataFrame(s), or EdgeDataFrameInput(s) to + add edges. Each tuple must have indices for source and target nodes and a + dictionary of attributes. If a DataFrame or list of DataFrames is used, the + add_edges_dataframe method is invoked. If PolarsEdgeDataFrameInput(s) are provided, each tuple must include a DataFrame and index columns for source and target nodes. If a group is specified, the edges are added to the group. Args: - edges (Union[Sequence[EdgeTuple], PandasEdgeDataFrameInput, List[PolarsEdgeDataFrameInput]]): - List[PandasEdgeDataFrameInput], PolarsEdgeDataFrameInput, - Data representing edges in several formats. + edges (EdgeInput): Data representing edges in several formats. group (Optional[Group]): The name of the group to add the edges to. If not specified, the edges are added to the MedRecord without a group. @@ -849,6 +784,9 @@ def add_edges( ) or is_polars_edge_dataframe_input_list(edges): return self.add_edges_polars(edges, group) else: + if is_edge_tuple(edges): + edges = [edges] + edge_indices = self._medrecord.add_edges(edges) if group is None: @@ -857,7 +795,7 @@ def add_edges( if not self.contains_group(group): self.add_group(group) - self.add_edge_to_group(group, edge_indices) + self.add_edges_to_group(group, edge_indices) return edge_indices @@ -920,15 +858,15 @@ def add_edges_polars( if not self.contains_group(group): self.add_group(group) - self.add_edge_to_group(group, edge_indices) + self.add_edges_to_group(group, edge_indices) return edge_indices def add_group( self, group: Group, - nodes: Optional[Union[NodeIndex, NodeIndexInputList, NodeOperation]] = None, - edges: Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]] = None, + nodes: Optional[Union[NodeIndex, NodeIndexInputList, NodeQuery]] = None, + edges: Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]] = None, ) -> None: """Adds a group to the MedRecord instance with an optional list of node indices. @@ -937,20 +875,20 @@ def add_group( Args: group (Group): The name of the group to add. - nodes (Optional[Union[NodeIndex, NodeIndexInputList, NodeOperation]]): - One or more node indices or a node operation to add + nodes (Optional[Union[NodeIndex, NodeIndexInputList, NodeQuery]]): + One or more node indices or a node query to add to the group, optional. - edges (Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]]): - One or more edge indices or an edge operation to add + edges (Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]]): + One or more edge indices or an edge query to add to the group, optional. Returns: None """ - if isinstance(nodes, NodeOperation): + if isinstance(nodes, Callable): nodes = self.select_nodes(nodes) - if isinstance(edges, EdgeOperation): + if isinstance(edges, Callable): edges = self.select_edges(edges) if nodes is not None and edges is not None: @@ -970,101 +908,101 @@ def add_group( else: return self._medrecord.add_group(group, None, None) - def remove_group(self, group: Union[Group, GroupInputList]) -> None: + def remove_groups(self, groups: Union[Group, GroupInputList]) -> None: """Removes one or more groups from the MedRecord instance. Args: - group (Union[Group, GroupInputList]): One or more group names to remove. + groups (Union[Group, GroupInputList]): One or more group names to remove. Returns: None """ - return self._medrecord.remove_group( - group if isinstance(group, list) else [group] + return self._medrecord.remove_groups( + groups if isinstance(groups, list) else [groups] ) - def add_node_to_group( - self, group: Group, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + def add_nodes_to_group( + self, group: Group, nodes: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> None: """Adds one or more nodes to a specified group in the MedRecord. Args: group (Group): The name of the group to add nodes to. - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation to add to the group. + nodes (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query to add to the group. Returns: None """ - if isinstance(node, NodeOperation): - return self._medrecord.add_node_to_group(group, self.select_nodes(node)) + if isinstance(nodes, Callable): + return self._medrecord.add_nodes_to_group(group, self.select_nodes(nodes)) - return self._medrecord.add_node_to_group( - group, node if isinstance(node, list) else [node] + return self._medrecord.add_nodes_to_group( + group, nodes if isinstance(nodes, list) else [nodes] ) - def add_edge_to_group( - self, group: Group, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + def add_edges_to_group( + self, group: Group, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> None: """Adds one or more edges to a specified group in the MedRecord. Args: group (Group): The name of the group to add edges to. - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation to add to the group. + edges (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query to add to the group. Returns: None """ - if isinstance(edge, EdgeOperation): - return self._medrecord.add_edge_to_group(group, self.select_edges(edge)) + if isinstance(edges, Callable): + return self._medrecord.add_edges_to_group(group, self.select_edges(edges)) - return self._medrecord.add_edge_to_group( - group, edge if isinstance(edge, list) else [edge] + return self._medrecord.add_edges_to_group( + group, edges if isinstance(edges, list) else [edges] ) - def remove_node_from_group( - self, group: Group, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + def remove_nodes_from_group( + self, group: Group, nodes: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> None: """Removes one or more nodes from a specified group in the MedRecord. Args: group (Group): The name of the group from which to remove nodes. - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation to remove from the group. + nodes (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query to remove from the group. Returns: None """ - if isinstance(node, NodeOperation): - return self._medrecord.remove_node_from_group( - group, self.select_nodes(node) + if isinstance(nodes, Callable): + return self._medrecord.remove_nodes_from_group( + group, self.select_nodes(nodes) ) - return self._medrecord.remove_node_from_group( - group, node if isinstance(node, list) else [node] + return self._medrecord.remove_nodes_from_group( + group, nodes if isinstance(nodes, list) else [nodes] ) - def remove_edge_from_group( - self, group: Group, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + def remove_edges_from_group( + self, group: Group, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> None: """Removes one or more edges from a specified group in the MedRecord. Args: group (Group): The name of the group from which to remove edges. - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation to remove from the group. + edges (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query to remove from the group. Returns: None """ - if isinstance(edge, EdgeOperation): - return self._medrecord.remove_edge_from_group( - group, self.select_edges(edge) + if isinstance(edges, Callable): + return self._medrecord.remove_edges_from_group( + group, self.select_edges(edges) ) - return self._medrecord.remove_edge_from_group( - group, edge if isinstance(edge, list) else [edge] + return self._medrecord.remove_edges_from_group( + group, edges if isinstance(edges, list) else [edges] ) @overload @@ -1134,11 +1072,11 @@ def groups_of_node(self, node: NodeIndex) -> List[Group]: ... @overload def groups_of_node( - self, node: Union[NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, List[Group]]: ... def groups_of_node( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[List[Group], Dict[NodeIndex, List[Group]]]: """Retrieves the groups associated with the specified node(s) in the MedRecord. @@ -1147,14 +1085,14 @@ def groups_of_node( its list of groups. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[List[Group], Dict[NodeIndex, List[Group]]]: Groups associated with each node. """ - if isinstance(node, NodeOperation): + if isinstance(node, Callable): return self._medrecord.groups_of_node(self.select_nodes(node)) groups = self._medrecord.groups_of_node( @@ -1171,11 +1109,11 @@ def groups_of_edge(self, edge: EdgeIndex) -> List[Group]: ... @overload def groups_of_edge( - self, edge: Union[EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndexInputList, EdgeQuery] ) -> Dict[EdgeIndex, List[Group]]: ... def groups_of_edge( - self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> Union[List[Group], Dict[EdgeIndex, List[Group]]]: """Retrieves the groups associated with the specified edge(s) in the MedRecord. @@ -1184,14 +1122,14 @@ def groups_of_edge( its list of groups. Args: - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation. + edge (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query. Returns: Union[List[Group], Dict[EdgeIndex, List[Group]]]: Groups associated with each edge. """ - if isinstance(edge, EdgeOperation): + if isinstance(edge, Callable): return self._medrecord.groups_of_edge(self.select_edges(edge)) groups = self._medrecord.groups_of_edge( @@ -1270,13 +1208,13 @@ def neighbors( @overload def neighbors( self, - node: Union[NodeIndexInputList, NodeOperation], + node: Union[NodeIndexInputList, NodeQuery], directed: bool = True, ) -> Dict[NodeIndex, List[NodeIndex]]: ... def neighbors( self, - node: Union[NodeIndex, NodeIndexInputList, NodeOperation], + node: Union[NodeIndex, NodeIndexInputList, NodeQuery], directed: bool = True, ) -> Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]: """Retrieves the neighbors of the specified node(s) in the MedRecord. @@ -1286,14 +1224,14 @@ def neighbors( each node index to its list of neighboring nodes. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. - directed (bool, optional): Whether to consider edges as directed + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. + directed (bool, optional): Whether to consider edges as directed. Returns: Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]: Neighboring nodes. """ - if isinstance(node, NodeOperation): + if isinstance(node, Callable): node = self.select_nodes(node) if directed: @@ -1320,50 +1258,15 @@ def clear(self) -> None: """ return self._medrecord.clear() - def select_nodes(self, operation: NodeOperation) -> List[NodeIndex]: - """Selects nodes based on a specified operation and returns their indices. - - Args: - operation (NodeOperation): The operation to apply to select nodes. - - Returns: - List[NodeIndex]: A list of node indices that satisfy the operation. - """ - return self._medrecord.select_nodes(operation._node_operation) - - def select_edges(self, operation: EdgeOperation) -> List[EdgeIndex]: - """Selects edges based on a specified operation and returns their indices. - - Args: - operation (EdgeOperation): The operation to apply to select edges. - - Returns: - List[EdgeIndex]: A list of edge indices that satisfy the operation. - """ - return self._medrecord.select_edges(operation._edge_operation) - - @overload - def __getitem__(self, key: NodeOperation) -> List[NodeIndex]: ... - - @overload - def __getitem__(self, key: EdgeOperation) -> List[EdgeIndex]: ... - - def __getitem__( - self, key: Union[NodeOperation, EdgeOperation] - ) -> Union[List[NodeIndex], List[EdgeIndex]]: - """Allows selection of nodes or edges using operations directly via indexing. - - Args: - key (Union[NodeOperation, EdgeOperation]): Operation to select nodes - or edges. - - Returns: - Union[List[NodeIndex], List[EdgeIndex]]: Node or edge indices selected. - """ - if isinstance(key, NodeOperation): - return self.select_nodes(key) + def select_nodes(self, query: NodeQuery) -> List[NodeIndex]: + return self._medrecord.select_nodes( + lambda node: query(NodeOperand._from_py_node_operand(node)) + ) - return self.select_edges(key) + def select_edges(self, query: EdgeQuery) -> List[EdgeIndex]: + return self._medrecord.select_edges( + lambda edge: query(EdgeOperand._from_py_edge_operand(edge)) + ) def clone(self) -> MedRecord: """Clones the MedRecord instance. diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py index 5a6e8385..f274dd52 100644 --- a/medmodels/medrecord/querying.py +++ b/medmodels/medrecord/querying.py @@ -1,18 +1,22 @@ from __future__ import annotations -from typing import List, Union +import sys +from enum import Enum +from typing import TYPE_CHECKING, Callable, List, Union from medmodels._medmodels import ( - PyEdgeAttributeOperand, + PyAttributesTreeOperand, + PyEdgeDirection, PyEdgeIndexOperand, + PyEdgeIndicesOperand, PyEdgeOperand, - PyEdgeOperation, - PyNodeAttributeOperand, + PyMultipleAttributesOperand, + PyMultipleValuesOperand, PyNodeIndexOperand, + PyNodeIndicesOperand, PyNodeOperand, - PyNodeOperation, - PyValueArithmeticOperation, - PyValueTransformationOperation, + PySingleAttributeOperand, + PySingleValueOperand, ) from medmodels.medrecord.types import ( EdgeIndex, @@ -22,1562 +26,1917 @@ NodeIndex, ) -ValueOperand = Union[ - MedRecordValue, - MedRecordAttribute, - PyValueArithmeticOperation, - PyValueTransformationOperation, -] +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias +NodeQuery: TypeAlias = Callable[["NodeOperand"], None] +EdgeQuery: TypeAlias = Callable[["EdgeOperand"], None] -class NodeOperation: - _node_operation: PyNodeOperation - - def __init__(self, node_operation: PyNodeOperation): - self._node_operation = node_operation - - def logical_and(self, operation: NodeOperation) -> NodeOperation: - """Combines this NodeOperation with another using a logical AND, resulting in a new NodeOperation that is true only if both original operations are true. +SingleValueComparisonOperand: TypeAlias = Union["SingleValueOperand", MedRecordValue] +SingleValueArithmeticOperand: TypeAlias = SingleValueComparisonOperand +MultipleValuesComparisonOperand: TypeAlias = Union[ + "MultipleValuesOperand", List[MedRecordValue] +] - This method allows for the chaining of conditions to refine queries on nodes. - Args: - operation (NodeOperation): Another NodeOperation to be combined with the - current one. +def _py_single_value_comparison_operand_from_single_value_comparison_operand( + single_value_comparison_operand: SingleValueComparisonOperand, +) -> Union[MedRecordValue, PySingleValueOperand]: + if isinstance(single_value_comparison_operand, SingleValueOperand): + return single_value_comparison_operand._single_value_operand + return single_value_comparison_operand - Returns: - NodeOperation: A new NodeOperation representing the logical AND of this - operation with another. - """ - return NodeOperation( - self._node_operation.logical_and(operation._node_operation) - ) - def __and__(self, operation: NodeOperation) -> NodeOperation: - return self.logical_and(operation) +def _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + multiple_values_comparison_operand: MultipleValuesComparisonOperand, +) -> Union[List[MedRecordValue], PyMultipleValuesOperand]: + if isinstance(multiple_values_comparison_operand, MultipleValuesOperand): + return multiple_values_comparison_operand._multiple_values_operand + return multiple_values_comparison_operand - def logical_or(self, operation: NodeOperation) -> NodeOperation: - """Combines this NodeOperation with another using a logical OR, resulting in a new NodeOperation that is true if either of the original operations is true. - This method enables the combination of conditions to expand queries on nodes. +SingleAttributeComparisonOperand: TypeAlias = Union[ + "SingleAttributeOperand", + MedRecordAttribute, +] +SingleAttributeArithmeticOperand: TypeAlias = SingleAttributeComparisonOperand +MultipleAttributesComparisonOperand: TypeAlias = Union[ + "MultipleAttributesOperand", List[MedRecordAttribute] +] - Args: - operation (NodeOperation): Another NodeOperation to be combined with the - current one. - Returns: - NodeOperation: A new NodeOperation representing the logical OR of this - operation with another. - """ - return NodeOperation(self._node_operation.logical_or(operation._node_operation)) +def _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + single_attribute_comparison_operand: SingleAttributeComparisonOperand, +) -> Union[MedRecordAttribute, PySingleAttributeOperand]: + if isinstance(single_attribute_comparison_operand, SingleAttributeOperand): + return single_attribute_comparison_operand._single_attribute_operand + return single_attribute_comparison_operand - def __or__(self, operation: NodeOperation) -> NodeOperation: - return self.logical_or(operation) - def logical_xor(self, operation: NodeOperation) -> NodeOperation: - """Combines this NodeOperation with another using a logical XOR, resulting in a new NodeOperation that is true only if exactly one of the original operations is true. +def _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + multiple_attributes_comparison_operand: MultipleAttributesComparisonOperand, +) -> Union[List[MedRecordAttribute], PyMultipleAttributesOperand]: + if isinstance(multiple_attributes_comparison_operand, MultipleAttributesOperand): + return multiple_attributes_comparison_operand._multiple_attributes_operand + return multiple_attributes_comparison_operand - This method is useful for creating conditions that must be - exclusively true. - Args: - operation (NodeOperation): Another NodeOperation to be combined with the - current one. +NodeIndexComparisonOperand: TypeAlias = Union["NodeIndexOperand", NodeIndex] +NodeIndexArithmeticOperand: TypeAlias = NodeIndexComparisonOperand +NodeIndicesComparisonOperand: TypeAlias = Union["NodeIndicesOperand", List[NodeIndex]] - Returns: - NodeOperation: A new NodeOperation representing the logical XOR of this - operation with another. - """ - return NodeOperation( - self._node_operation.logical_xor(operation._node_operation) - ) - def __xor__(self, operation: NodeOperation) -> NodeOperation: - return self.logical_xor(operation) +def _py_node_index_comparison_operand_from_node_index_comparison_operand( + node_index_comparison_operand: NodeIndexComparisonOperand, +) -> Union[NodeIndex, PyNodeIndexOperand]: + if isinstance(node_index_comparison_operand, NodeIndexOperand): + return node_index_comparison_operand._node_index_operand + return node_index_comparison_operand - def logical_not(self) -> NodeOperation: - """Creates a new NodeOperation that is the logical NOT of this operation, inversing the current condition. - This method is useful for negating a condition - to create queries on nodes. +def _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + node_indices_comparison_operand: NodeIndicesComparisonOperand, +) -> Union[List[NodeIndex], PyNodeIndicesOperand]: + if isinstance(node_indices_comparison_operand, NodeIndicesOperand): + return node_indices_comparison_operand._node_indices_operand + return node_indices_comparison_operand - Returns: - NodeOperation: A new NodeOperation representing the logical NOT of - this operation. - """ - return NodeOperation(self._node_operation.logical_not()) - def __invert__(self) -> NodeOperation: - return self.logical_not() +EdgeIndexComparisonOperand: TypeAlias = Union[ + "EdgeIndexOperand", + EdgeIndex, +] +EdgeIndexArithmeticOperand: TypeAlias = EdgeIndexComparisonOperand +EdgeIndicesComparisonOperand: TypeAlias = Union[ + "EdgeIndicesOperand", + List[EdgeIndex], +] -class EdgeOperation: - _edge_operation: PyEdgeOperation +def _py_edge_index_comparison_operand_from_edge_index_comparison_operand( + edge_index_comparison_operand: EdgeIndexComparisonOperand, +) -> Union[EdgeIndex, PyEdgeIndexOperand]: + if isinstance(edge_index_comparison_operand, EdgeIndexOperand): + return edge_index_comparison_operand._edge_index_operand + return edge_index_comparison_operand - def __init__(self, edge_operation: PyEdgeOperation) -> None: - self._edge_operation = edge_operation - def logical_and(self, operation: EdgeOperation) -> EdgeOperation: - """Combines this EdgeOperation with another using a logical AND, resulting in a new EdgeOperation that is true only if both original operations are true. +def _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + edge_indices_comparison_operand: EdgeIndicesComparisonOperand, +) -> Union[List[EdgeIndex], PyEdgeIndicesOperand]: + if isinstance(edge_indices_comparison_operand, EdgeIndicesOperand): + return edge_indices_comparison_operand._edge_indices_operand + return edge_indices_comparison_operand - This method allows for the chaining of conditions to refine queries on nodes. - Args: - operation (EdgeOperation): Another EdgeOperation to be combined with the - current one. +class EdgeDirection(Enum): + INCOMING = 0 + OUTGOING = 1 + BOTH = 2 - Returns: - EdgeOperation: A new EdgeOperation representing the logical AND of this - operation with another. - """ - return EdgeOperation( - self._edge_operation.logical_and(operation._edge_operation) + def _into_py_edge_direction(self) -> PyEdgeDirection: + return ( + PyEdgeDirection.Incoming + if self == EdgeDirection.INCOMING + else PyEdgeDirection.Outgoing + if self == EdgeDirection.OUTGOING + else PyEdgeDirection.Both ) - def __and__(self, operation: EdgeOperation) -> EdgeOperation: - return self.logical_and(operation) - def logical_or(self, operation: EdgeOperation) -> EdgeOperation: - """Combines this EdgeOperation with another using a logical OR, resulting in a new EdgeOperation that is true if either of the original operations is true. +class NodeOperand: + _node_operand: PyNodeOperand - This method enables the combination of conditions to expand queries on nodes. + def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._node_operand.attribute(attribute) + ) - Args: - operation (EdgeOperation): Another EdgeOperation to be combined with the - current one. + def attributes(self) -> AttributesTreeOperand: + return AttributesTreeOperand._from_py_attributes_tree_operand( + self._node_operand.attributes() + ) - Returns: - EdgeOperation: A new EdgeOperation representing the logical OR of this - operation with another. - """ - return EdgeOperation(self._edge_operation.logical_or(operation._edge_operation)) + def index(self) -> NodeIndicesOperand: + return NodeIndicesOperand._from_py_node_indices_operand( + self._node_operand.index() + ) - def __or__(self, operation: EdgeOperation) -> EdgeOperation: - return self.logical_or(operation) + def in_group(self, group: Union[Group, List[Group]]) -> None: + self._node_operand.in_group(group) - def logical_xor(self, operation: EdgeOperation) -> EdgeOperation: - """Combines this EdgeOperation with another using a logical XOR, resulting in a new EdgeOperation that is true only if exactly one of the original operations is true. + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: + self._node_operand.has_attribute(attribute) - This method is useful for creating conditions that must be - exclusively true. + def outgoing_edges(self) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand(self._node_operand.outgoing_edges()) - Args: - operation (EdgeOperation): Another EdgeOperation to be combined with the - current one. + def incoming_edges(self) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand(self._node_operand.incoming_edges()) - Returns: - EdgeOperation: A new EdgeOperation representing the logical XOR of this - operation with another. - """ - return EdgeOperation( - self._edge_operation.logical_xor(operation._edge_operation) + def neighbors( + self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING + ) -> NodeOperand: + return NodeOperand._from_py_node_operand( + self._node_operand.neighbors(edge_direction._into_py_edge_direction()) ) - def __xor__(self, operation: EdgeOperation) -> EdgeOperation: - return self.logical_xor(operation) - - def logical_not(self) -> EdgeOperation: - """Creates a new EdgeOperation that is the logical NOT of this operation, inversing the current condition. - - This method is useful for negating a condition - to create queries on nodes. + def either_or(self, either: NodeQuery, or_: NodeQuery) -> None: + self._node_operand.either_or( + lambda node: either(NodeOperand._from_py_node_operand(node)), + lambda node: or_(NodeOperand._from_py_node_operand(node)), + ) - Returns: - EdgeOperation: A new EdgeOperation representing the logical NOT of - this operation. - """ - return EdgeOperation(self._edge_operation.logical_not()) + def exclude(self, query: NodeQuery) -> None: + self._node_operand.exclude( + lambda node: query(NodeOperand._from_py_node_operand(node)) + ) - def __invert__(self) -> EdgeOperation: - return self.logical_not() + def clone(self) -> NodeOperand: + return NodeOperand._from_py_node_operand(self._node_operand.deep_clone()) + @classmethod + def _from_py_node_operand(cls, py_node_operand: PyNodeOperand) -> NodeOperand: + node_operand = cls() + node_operand._node_operand = py_node_operand + return node_operand -class NodeAttributeOperand: - _node_attribute_operand: PyNodeAttributeOperand - def __init__(self, node_attribute_operand: PyNodeAttributeOperand) -> None: - self._node_attribute_operand = node_attribute_operand +class EdgeOperand: + _edge_operand: PyEdgeOperand - def greater( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is greater than the specified value or operand. + def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._edge_operand.attribute(attribute) + ) - Args: - operand (ValueOperand): The value or operand to compare against. + def attributes(self) -> AttributesTreeOperand: + return AttributesTreeOperand._from_py_attributes_tree_operand( + self._edge_operand.attributes() + ) - Returns: - NodeOperation: A NodeOperation representing the greater-than comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.greater(operand._node_attribute_operand) - ) + def index(self) -> EdgeIndicesOperand: + return EdgeIndicesOperand._from_edge_indices_operand(self._edge_operand.index()) - return NodeOperation(self._node_attribute_operand.greater(operand)) + def in_group(self, group: Union[Group, List[Group]]) -> None: + self._edge_operand.in_group(group) - def __gt__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.greater(operand) + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: + self._edge_operand.has_attribute(attribute) - def less(self, operand: Union[ValueOperand, NodeAttributeOperand]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is less than the specified value or operand. + def source_node(self) -> NodeOperand: + return NodeOperand._from_py_node_operand(self._edge_operand.source_node()) - Args: - operand (ValueOperand): The value or operand to compare against. + def target_node(self) -> NodeOperand: + return NodeOperand._from_py_node_operand(self._edge_operand.target_node()) - Returns: - NodeOperation: A NodeOperation representing the less-than comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.less(operand._node_attribute_operand) - ) + def either_or(self, either: EdgeQuery, or_: EdgeQuery) -> None: + self._edge_operand.either_or( + lambda edge: either(EdgeOperand._from_py_edge_operand(edge)), + lambda edge: or_(EdgeOperand._from_py_edge_operand(edge)), + ) - return NodeOperation(self._node_attribute_operand.less(operand)) + def exclude(self, query: EdgeQuery) -> None: + self._edge_operand.exclude( + lambda edge: query(EdgeOperand._from_py_edge_operand(edge)) + ) - def __lt__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.less(operand) + def clone(self) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand(self._edge_operand.deep_clone()) - def greater_or_equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is greater than or equal to the specified value or operand. + @classmethod + def _from_py_edge_operand(cls, py_edge_operand: PyEdgeOperand) -> EdgeOperand: + edge_operand = cls() + edge_operand._edge_operand = py_edge_operand + return edge_operand - Args: - operand (ValueOperand): The value or operand to compare against. - Returns: - NodeOperation: A NodeOperation representing the - greater-than-or-equal-to comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.greater_or_equal( - operand._node_attribute_operand - ) - ) +class MultipleValuesOperand: + _multiple_values_operand: PyMultipleValuesOperand - return NodeOperation(self._node_attribute_operand.greater_or_equal(operand)) + def max(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.max() + ) - def __ge__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.greater_or_equal(operand) + def min(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.min() + ) - def less_or_equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is less than or equal to the specified value or operand. + def mean(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.mean() + ) - Args: - operand (ValueOperand): The value or operand to compare against. + def median(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.median() + ) - Returns: - NodeOperation: A NodeOperation representing the - less-than-or-equal-to comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.less_or_equal( - operand._node_attribute_operand - ) - ) + def mode(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.mode() + ) - return NodeOperation(self._node_attribute_operand.less_or_equal(operand)) + def std(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.std() + ) - def __le__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.less_or_equal(operand) + def var(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.var() + ) - def equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is equal to the specified value or operand. + def count(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.count() + ) - Args: - operand (ValueOperand): The value or operand to compare against.y + def sum(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.sum() + ) - Returns: - NodeOperation: A NodeOperation representing the equality comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.equal(operand._node_attribute_operand) - ) + def first(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.first() + ) - return NodeOperation(self._node_attribute_operand.equal(operand)) + def last(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.last() + ) - def __eq__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.equal(operand) + def is_string(self) -> None: + self._multiple_values_operand.is_string() - def not_equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is not equal to the specified value or operand. + def is_int(self) -> None: + self._multiple_values_operand.is_int() - Args: - operand (ValueOperand): The value or operand to compare against. + def is_float(self) -> None: + self._multiple_values_operand.is_float() - Returns: - NodeOperation: A NodeOperation representing the not-equal comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.not_equal(operand._node_attribute_operand) - ) + def is_bool(self) -> None: + self._multiple_values_operand.is_bool() - return NodeOperation(self._node_attribute_operand.not_equal(operand)) + def is_datetime(self) -> None: + self._multiple_values_operand.is_datetime() - def __ne__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.not_equal(operand) + def is_null(self) -> None: + self._multiple_values_operand.is_null() - def is_in(self, values: List[MedRecordValue]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is found within the specified list of values. + def is_max(self) -> None: + self._multiple_values_operand.is_max() - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. + def is_min(self) -> None: + self._multiple_values_operand.is_min() - Returns: - NodeOperation: A NodeOperation representing the is-in comparison. - """ - return NodeOperation(self._node_attribute_operand.is_in(values)) + def greater_than(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.greater_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def not_in(self, values: List[MedRecordValue]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is not found within the specified list of values. + def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.greater_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. + def less_than(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.less_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - NodeOperation: A NodeOperation representing the not-in comparison. - """ - return NodeOperation(self._node_attribute_operand.not_in(values)) + def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.less_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def starts_with(self, operand: ValueOperand) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand starts with the specified value or operand. + def equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Args: - operand (ValueOperand): The value or operand to compare - the starting sequence against. + def not_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.not_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - NodeOperation: A NodeOperation representing the starts-with condition. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.starts_with( - operand._node_attribute_operand - ) + def is_in(self, values: MultipleValuesComparisonOperand) -> None: + self._multiple_values_operand.is_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values ) + ) - return NodeOperation(self._node_attribute_operand.starts_with(operand)) + def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: + self._multiple_values_operand.is_not_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) - def ends_with( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand ends with the specified value or operand. + def starts_with(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.starts_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Args: - operand (ValueOperand): The value or operand to compare - the ending sequence against. + def ends_with(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.ends_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - NodeOperation: A NodeOperation representing the ends-with condition. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.ends_with(operand._node_attribute_operand) + def contains(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.contains( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value ) + ) - return NodeOperation(self._node_attribute_operand.ends_with(operand)) + def add(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.add( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def contains( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand contains the specified value or operand within it. + def subtract(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.sub( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Args: - operand (ValueOperand): The value or operand to check for containment. + def multiply(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.mul( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - NodeOperation: A NodeOperation representing the contains condition. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.contains(operand._node_attribute_operand) + def divide(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.div( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value ) + ) - return NodeOperation(self._node_attribute_operand.contains(operand)) + def modulo(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.mod( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def add(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the sum of the attribute's value and the specified value. + def power(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.pow( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Args: - value (MedRecordValue): The value to add to the attribute's value. + def round(self) -> None: + self._multiple_values_operand.round() - Returns: - ValueOperand: The result of the addition operation. - """ - return self._node_attribute_operand.add(value) + def ceil(self) -> None: + self._multiple_values_operand.ceil() - def __add__(self, value: MedRecordValue) -> ValueOperand: - return self.add(value) + def floor(self) -> None: + self._multiple_values_operand.floor() - def sub(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the difference between the attribute's value and the specified value. + def absolute(self) -> None: + self._multiple_values_operand.abs() - Args: - value (MedRecordValue): The value to subtract from the attribute's value. + def sqrt(self) -> None: + self._multiple_values_operand.sqrt() - Returns: - ValueOperand: The result of the subtraction operation. - """ - return self._node_attribute_operand.sub(value) + def trim(self) -> None: + self._multiple_values_operand.trim() - def __sub__(self, value: MedRecordValue) -> ValueOperand: - return self.sub(value) + def trim_start(self) -> None: + self._multiple_values_operand.trim_start() - def mul(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the product of the attribute's value and the specified value. + def trim_end(self) -> None: + self._multiple_values_operand.trim_end() - Args: - value (MedRecordValue): The value to multiply the attribute's value by. + def lowercase(self) -> None: + self._multiple_values_operand.lowercase() - Returns: - ValueOperand: The result of the multiplication operation. - """ - return self._node_attribute_operand.mul(value) + def uppercase(self) -> None: + self._multiple_values_operand.uppercase() - def __mul__(self, value: MedRecordValue) -> ValueOperand: - return self.mul(value) + def slice(self, start: int, end: int) -> None: + self._multiple_values_operand.slice(start, end) - def div(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the division of the attribute's value by the specified value. + def either_or( + self, + either: Callable[[MultipleValuesOperand], None], + or_: Callable[[MultipleValuesOperand], None], + ) -> None: + self._multiple_values_operand.either_or( + lambda values: either( + MultipleValuesOperand._from_py_multiple_values_operand(values) + ), + lambda values: or_( + MultipleValuesOperand._from_py_multiple_values_operand(values) + ), + ) - Args: - value (MedRecordValue): The value to divide the attribute's value by. + def exclude(self, query: Callable[[MultipleValuesOperand], None]) -> None: + self._multiple_values_operand.exclude( + lambda values: query( + MultipleValuesOperand._from_py_multiple_values_operand(values) + ) + ) - Returns: - ValueOperand: The result of the division operation. - """ - return self._node_attribute_operand.div(value) + def clone(self) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._multiple_values_operand.deep_clone() + ) - def __truediv__(self, value: MedRecordValue) -> ValueOperand: - return self.div(value) + @classmethod + def _from_py_multiple_values_operand( + cls, py_multiple_values_operand: PyMultipleValuesOperand + ) -> MultipleValuesOperand: + multiple_values_operand = cls() + multiple_values_operand._multiple_values_operand = py_multiple_values_operand + return multiple_values_operand - def pow(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the result of raising the attribute's value to the power of the specified value. - Args: - value (MedRecordValue): The value to raise the attribute's value to. +class SingleValueOperand: + _single_value_operand: PySingleValueOperand - Returns: - ValueOperand: The result of the exponentiation operation. - """ - return self._node_attribute_operand.pow(value) + def is_string(self) -> None: + self._single_value_operand.is_string() - def __pow__(self, value: MedRecordValue) -> ValueOperand: - return self.pow(value) + def is_int(self) -> None: + self._single_value_operand.is_int() - def mod(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the remainder of dividing the attribute's value by the specified value. + def is_float(self) -> None: + self._single_value_operand.is_float() - Args: - value (MedRecordValue): The value to divide the attribute's value by. + def is_bool(self) -> None: + self._single_value_operand.is_bool() - Returns: - ValueOperand: The result of the modulo operation. - """ - return self._node_attribute_operand.mod(value) + def is_datetime(self) -> None: + self._single_value_operand.is_datetime() - def __mod__(self, value: MedRecordValue) -> ValueOperand: - return self.mod(value) + def is_null(self) -> None: + self._single_value_operand.is_null() - def round(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of rounding the attribute's value. + def greater_than(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.greater_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The result of the rounding operation. - """ - return self._node_attribute_operand.round() + def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.greater_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def ceil(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the ceiling function to the attribute's value, effectively rounding it up to the nearest whole number. + def less_than(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.less_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The result of the ceiling operation. - """ - return self._node_attribute_operand.ceil() + def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.less_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def floor(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the floor function to the attribute's value, effectively rounding it down to the nearest whole number. + def equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The result of the floor operation. - """ - return self._node_attribute_operand.floor() + def not_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.not_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def abs(self) -> ValueOperand: - """Creates a new ValueOperand representing the absolute value of the attribute's value. + def is_in(self, values: MultipleValuesComparisonOperand) -> None: + self._single_value_operand.is_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) - Returns: - ValueOperand: The absolute value of the attribute's value. - """ - return self._node_attribute_operand.abs() + def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: + self._single_value_operand.is_not_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) - def sqrt(self) -> ValueOperand: - """Creates a new ValueOperand representing the square root of the attribute's value. + def starts_with(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.starts_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The square root of the attribute's value. - """ - return self._node_attribute_operand.sqrt() + def ends_with(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.ends_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def trim(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from both ends of the attribute's value. + def contains(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.contains( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The attribute's value with leading and trailing - whitespace removed. - """ - return self._node_attribute_operand.trim() + def add(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.add( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def trim_start(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the start (left side) of the attribute's value. + def subtract(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.sub( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The attribute's value with leading whitespace removed. - """ - return self._node_attribute_operand.trim_start() + def multiply(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.mul( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def trim_end(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the end (right side) of the attribute's value. + def modulo(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.mod( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - Returns: - ValueOperand: The attribute's value with trailing whitespace removed. - """ - return self._node_attribute_operand.trim_end() + def power(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.pow( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) - def lowercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to lowercase. + def round(self) -> None: + self._single_value_operand.round() - Returns: - ValueOperand: The attribute's value in lowercase letters. - """ - return self._node_attribute_operand.lowercase() + def ceil(self) -> None: + self._single_value_operand.ceil() - def uppercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to uppercase. + def floor(self) -> None: + self._single_value_operand.floor() - Returns: - ValueOperand: The attribute's value in uppercase letters. - """ - return self._node_attribute_operand.uppercase() + def absolute(self) -> None: + self._single_value_operand.abs() - def slice(self, start: int, end: int) -> ValueOperand: - """Creates a new ValueOperand representing the result of slicing the attribute's value using the specified start and end indices. + def sqrt(self) -> None: + self._single_value_operand.sqrt() - Args: - start (int): The index at which to start the slice. - end (int): The index at which to end the slice. + def trim(self) -> None: + self._single_value_operand.trim() - Returns: - ValueOperand: The attribute's value with the specified slice applied. - """ - return self._node_attribute_operand.slice(start, end) + def trim_start(self) -> None: + self._single_value_operand.trim_start() + def trim_end(self) -> None: + self._single_value_operand.trim_end() -class EdgeAttributeOperand: - _edge_attribute_operand: PyEdgeAttributeOperand + def lowercase(self) -> None: + self._single_value_operand.lowercase() - def __init__(self, edge_attribute_operand: PyEdgeAttributeOperand) -> None: - self._edge_attribute_operand = edge_attribute_operand + def uppercase(self) -> None: + self._single_value_operand.uppercase() - def greater( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is greater than the specified value or operand. + def slice(self, start: int, end: int) -> None: + self._single_value_operand.slice(start, end) - Args: - operand (ValueOperand): The value or operand to compare against. + def either_or( + self, + either: Callable[[SingleValueOperand], None], + or_: Callable[[SingleValueOperand], None], + ) -> None: + self._single_value_operand.either_or( + lambda value: either( + SingleValueOperand._from_py_single_value_operand(value) + ), + lambda value: or_(SingleValueOperand._from_py_single_value_operand(value)), + ) - Returns: - EdgeOperation: A EdgeOperation representing the greater-than comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.greater(operand._edge_attribute_operand) - ) + def exclude(self, query: Callable[[SingleValueOperand], None]) -> None: + self._single_value_operand.exclude( + lambda value: query(SingleValueOperand._from_py_single_value_operand(value)) + ) - return EdgeOperation(self._edge_attribute_operand.greater(operand)) + def clone(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._single_value_operand.deep_clone() + ) - def __gt__( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - return self.greater(operand) + @classmethod + def _from_py_single_value_operand( + cls, py_single_value_operand: PySingleValueOperand + ) -> SingleValueOperand: + single_value_operand = cls() + single_value_operand._single_value_operand = py_single_value_operand + return single_value_operand - def less(self, operand: Union[ValueOperand, EdgeAttributeOperand]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is less than the specified value or operand. - Args: - operand (ValueOperand): The value or operand to compare against. +class AttributesTreeOperand: + _attributes_tree_operand: PyAttributesTreeOperand - Returns: - EdgeOperation: A EdgeOperation representing the less-than comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.less(operand._edge_attribute_operand) - ) + def max(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.max() + ) - return EdgeOperation(self._edge_attribute_operand.less(operand)) + def min(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.min() + ) - def __lt__(self, operand: ValueOperand) -> EdgeOperation: - return self.less(operand) + def count(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.count() + ) - def greater_or_equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is greater than or equal to the specified value or operand. + def sum(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.sum() + ) - Args: - operand (ValueOperand): The value or operand to compare against. + def first(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.first() + ) - Returns: - EdgeOperation: A EdgeOperation representing the - greater-than-or-equal-to comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.greater_or_equal( - operand._edge_attribute_operand - ) - ) + def last(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.last() + ) - return EdgeOperation(self._edge_attribute_operand.greater_or_equal(operand)) + def is_string(self) -> None: + self._attributes_tree_operand.is_string() - def __ge__(self, operand: ValueOperand) -> EdgeOperation: - return self.greater_or_equal(operand) + def is_int(self) -> None: + self._attributes_tree_operand.is_int() - def less_or_equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is less than or equal to the specified value or operand. + def is_max(self) -> None: + self._attributes_tree_operand.is_max() - Args: - operand (ValueOperand): The value or operand to compare against. + def is_min(self) -> None: + self._attributes_tree_operand.is_min() - Returns: - EdgeOperation: A EdgeOperation representing the - less-than-or-equal-to comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.less_or_equal( - operand._edge_attribute_operand - ) + def greater_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.greater_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute ) + ) - return EdgeOperation(self._edge_attribute_operand.less_or_equal(operand)) - - def __le__(self, operand: ValueOperand) -> EdgeOperation: - return self.less_or_equal(operand) + def greater_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._attributes_tree_operand.greater_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is equal to the specified value or operand. + def less_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.less_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - operand (ValueOperand): The value or operand to compare against. + def less_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._attributes_tree_operand.less_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - EdgeOperation: A EdgeOperation representing the equality comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.equal(operand._edge_attribute_operand) + def equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute ) + ) - return EdgeOperation(self._edge_attribute_operand.equal(operand)) + def not_equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.not_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __eq__( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - return self.equal(operand) + def is_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._attributes_tree_operand.is_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) - def not_equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is not equal to the specified value or operand. + def is_not_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._attributes_tree_operand.is_not_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) - Args: - operand (ValueOperand): The value or operand to compare against. + def starts_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.starts_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - EdgeOperation: A EdgeOperation representing the not-equal comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.not_equal(operand._edge_attribute_operand) + def ends_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.ends_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute ) + ) - return EdgeOperation(self._edge_attribute_operand.not_equal(operand)) + def contains(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.contains( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __ne__( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - return self.not_equal(operand) + def add(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.add( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def is_in(self, values: List[MedRecordValue]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is found within the specified list of values. + def subtract(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.sub( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. + def multiply(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.mul( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - EdgeOperation: A EdgeOperation representing the is-in comparison. - """ - return EdgeOperation(self._edge_attribute_operand.is_in(values)) + def modulo(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.mod( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def not_in(self, values: List[MedRecordValue]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is not found within the specified list of values. + def power(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.pow( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. + def absolute(self) -> None: + self._attributes_tree_operand.abs() - Returns: - EdgeOperation: A EdgeOperation representing the not-in comparison. - """ - return EdgeOperation(self._edge_attribute_operand.not_in(values)) + def trim(self) -> None: + self._attributes_tree_operand.trim() - def starts_with( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand starts with the specified value or operand. + def trim_start(self) -> None: + self._attributes_tree_operand.trim_start() - Args: - operand (ValueOperand): The value or operand to compare - the starting sequence against. + def trim_end(self) -> None: + self._attributes_tree_operand.trim_end() - Returns: - EdgeOperation: A EdgeOperation representing the starts-with condition. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.starts_with( - operand._edge_attribute_operand - ) - ) + def lowercase(self) -> None: + self._attributes_tree_operand.lowercase() - return EdgeOperation(self._edge_attribute_operand.starts_with(operand)) + def uppercase(self) -> None: + self._attributes_tree_operand.uppercase() - def ends_with( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand ends with the specified value or operand. + def slice(self, start: int, end: int) -> None: + self._attributes_tree_operand.slice(start, end) - Args: - operand (ValueOperand): The value or operand to compare - the ending sequence against. + def either_or( + self, + either: Callable[[AttributesTreeOperand], None], + or_: Callable[[AttributesTreeOperand], None], + ) -> None: + self._attributes_tree_operand.either_or( + lambda attributes: either( + AttributesTreeOperand._from_py_attributes_tree_operand(attributes) + ), + lambda attributes: or_( + AttributesTreeOperand._from_py_attributes_tree_operand(attributes) + ), + ) - Returns: - EdgeOperation: A EdgeOperation representing the ends-with condition. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.ends_with(operand._edge_attribute_operand) + def exclude(self, query: Callable[[AttributesTreeOperand], None]) -> None: + self._attributes_tree_operand.exclude( + lambda attributes: query( + AttributesTreeOperand._from_py_attributes_tree_operand(attributes) ) + ) - return EdgeOperation(self._edge_attribute_operand.ends_with(operand)) + def clone(self) -> AttributesTreeOperand: + return AttributesTreeOperand._from_py_attributes_tree_operand( + self._attributes_tree_operand.deep_clone() + ) - def contains( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand contains the specified value or operand within it. + @classmethod + def _from_py_attributes_tree_operand( + cls, py_attributes_tree_operand: PyAttributesTreeOperand + ) -> AttributesTreeOperand: + attributes_tree_operand = cls() + attributes_tree_operand._attributes_tree_operand = py_attributes_tree_operand + return attributes_tree_operand - Args: - operand (ValueOperand): The value or operand to check for containment. - Returns: - EdgeOperation: A EdgeOperation representing the contains condition. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.contains(operand._edge_attribute_operand) - ) +class MultipleAttributesOperand: + _multiple_attributes_operand: PyMultipleAttributesOperand - return EdgeOperation(self._edge_attribute_operand.contains(operand)) + def max(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.max() + ) - def add(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the sum of the attribute's value and the specified value. + def min(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.min() + ) - Args: - value (MedRecordValue): The value to add to the attribute's value. + def count(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.count() + ) - Returns: - ValueOperand: The result of the addition operation. - """ - return self._edge_attribute_operand.add(value) + def sum(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.sum() + ) - def __add__(self, value: MedRecordValue) -> ValueOperand: - return self.add(value) + def first(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.first() + ) - def sub(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the difference between the attribute's value and the specified value. + def last(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.last() + ) - Args: - value (MedRecordValue): The value to subtract from the attribute's value. + def is_string(self) -> None: + self._multiple_attributes_operand.is_string() - Returns: - ValueOperand: The result of the subtraction operation. - """ - return self._edge_attribute_operand.sub(value) + def is_int(self) -> None: + self._multiple_attributes_operand.is_int() - def __sub__(self, value: MedRecordValue) -> ValueOperand: - return self.sub(value) + def is_max(self) -> None: + self._multiple_attributes_operand.is_max() - def mul(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the product of the attribute's value and the specified value. + def is_min(self) -> None: + self._multiple_attributes_operand.is_min() - Args: - value (MedRecordValue): The value to multiply the attribute's value by. + def greater_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.greater_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The result of the multiplication operation. - """ - return self._edge_attribute_operand.mul(value) + def greater_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._multiple_attributes_operand.greater_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __mul__(self, value: MedRecordValue) -> ValueOperand: - return self.mul(value) + def less_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.less_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def div(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the division of the attribute's value by the specified value. + def less_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._multiple_attributes_operand.less_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - value (MedRecordValue): The value to divide the attribute's value by. + def equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The result of the division operation. - """ - return self._edge_attribute_operand.div(value) + def not_equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.not_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __truediv__(self, value: MedRecordValue) -> ValueOperand: - return self.div(value) + def is_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._multiple_attributes_operand.is_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) - def pow(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the result of raising the attribute's value to the power of the specified value. + def is_not_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._multiple_attributes_operand.is_not_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) - Args: - value (MedRecordValue): The value to raise the attribute's value to. + def starts_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.starts_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The result of the exponentiation operation. - """ - return self._edge_attribute_operand.pow(value) + def ends_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.ends_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __pow__(self, value: MedRecordValue) -> ValueOperand: - return self.pow(value) + def contains(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.contains( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def mod(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the remainder of dividing the attribute's value by the specified value. + def add(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.add( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - value (MedRecordValue): The value to divide the attribute's value by. + def subtract(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.sub( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The result of the modulo operation. - """ - return self._edge_attribute_operand.mod(value) + def multiply(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.mul( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __mod__(self, value: MedRecordValue) -> ValueOperand: - return self.mod(value) + def modulo(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.mod( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def round(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of rounding the attribute's value. + def power(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.pow( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The result of the rounding operation. - """ - return self._edge_attribute_operand.round() + def absolute(self) -> None: + self._multiple_attributes_operand.abs() - def ceil(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the ceiling function to the attribute's value, effectively rounding it up to the nearest whole number. + def trim(self) -> None: + self._multiple_attributes_operand.trim() - Returns: - ValueOperand: The result of the ceiling operation. - """ - return self._edge_attribute_operand.ceil() + def trim_start(self) -> None: + self._multiple_attributes_operand.trim_start() - def floor(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the floor function to the attribute's value, effectively rounding it down to the nearest whole number. + def trim_end(self) -> None: + self._multiple_attributes_operand.trim_end() - Returns: - ValueOperand: The result of the floor operation. - """ - return self._edge_attribute_operand.floor() + def lowercase(self) -> None: + self._multiple_attributes_operand.lowercase() - def abs(self) -> ValueOperand: - """Creates a new ValueOperand representing the absolute value of the attribute's value. + def uppercase(self) -> None: + self._multiple_attributes_operand.uppercase() - Returns: - ValueOperand: The absolute value of the attribute's value. - """ - return self._edge_attribute_operand.abs() + def to_values(self) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._multiple_attributes_operand.to_values() + ) - def sqrt(self) -> ValueOperand: - """Creates a new ValueOperand representing the square root of the attribute's value. + def slice(self, start: int, end: int) -> None: + self._multiple_attributes_operand.slice(start, end) + + def either_or( + self, + either: Callable[[MultipleAttributesOperand], None], + or_: Callable[[MultipleAttributesOperand], None], + ) -> None: + self._multiple_attributes_operand.either_or( + lambda attributes: either( + MultipleAttributesOperand._from_py_multiple_attributes_operand( + attributes + ) + ), + lambda attributes: or_( + MultipleAttributesOperand._from_py_multiple_attributes_operand( + attributes + ) + ), + ) - Returns: - ValueOperand: The square root of the attribute's value. - """ - return self._edge_attribute_operand.sqrt() + def exclude(self, query: Callable[[MultipleAttributesOperand], None]) -> None: + self._multiple_attributes_operand.exclude( + lambda attributes: query( + MultipleAttributesOperand._from_py_multiple_attributes_operand( + attributes + ) + ) + ) - def trim(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from both ends of the attribute's value. + def clone(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._multiple_attributes_operand.deep_clone() + ) - Returns: - ValueOperand: The attribute's value with leading and trailing - whitespace removed. - """ - return self._edge_attribute_operand.trim() + @classmethod + def _from_py_multiple_attributes_operand( + cls, py_multiple_attributes_operand: PyMultipleAttributesOperand + ) -> MultipleAttributesOperand: + multiple_attributes_operand = cls() + multiple_attributes_operand._multiple_attributes_operand = ( + py_multiple_attributes_operand + ) + return multiple_attributes_operand - def trim_start(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the start (left side) of the attribute's value. - Returns: - ValueOperand: The attribute's value with leading whitespace removed. - """ - return self._edge_attribute_operand.trim_start() +class SingleAttributeOperand: + _single_attribute_operand: PySingleAttributeOperand - def trim_end(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the end (right side) of the attribute's value. + def is_string(self) -> None: + self._single_attribute_operand.is_string() - Returns: - ValueOperand: The attribute's value with trailing whitespace removed. - """ - return self._edge_attribute_operand.trim_end() + def is_int(self) -> None: + self._single_attribute_operand.is_int() - def lowercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to lowercase. + def greater_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.greater_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The attribute's value in lowercase letters. - """ - return self._edge_attribute_operand.lowercase() + def greater_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._single_attribute_operand.greater_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def uppercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to uppercase. + def less_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.less_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The attribute's value in uppercase letters. - """ - return self._edge_attribute_operand.uppercase() + def less_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._single_attribute_operand.less_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def slice(self, start: int, end: int) -> ValueOperand: - """Creates a new ValueOperand representing the result of slicing the attribute's value using the specified start and end indices. + def equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - start (int): The index at which to start the slice. - end (int): The index at which to end the slice. + def not_equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.not_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - ValueOperand: The attribute's value with the specified slice applied. - """ - return self._edge_attribute_operand.slice(start, end) + def is_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._single_attribute_operand.is_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + def is_not_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._single_attribute_operand.is_not_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) -class NodeIndexOperand: - _node_index_operand: PyNodeIndexOperand + def starts_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.starts_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __init__(self, node_index_operand: PyNodeIndexOperand) -> None: - self._node_index_operand = node_index_operand + def ends_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.ends_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def greater(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is greater than the specified index. + def contains(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.contains( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - operand (NodeIndex): The index to compare against. + def add(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.add( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - NodeOperation: A NodeOperation representing the greater-than comparison. - """ - return NodeOperation(self._node_index_operand.greater(operand)) + def subtract(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.sub( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def __gt__(self, operand: NodeIndex) -> NodeOperation: - return self.greater(operand) + def multiply(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.mul( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - def less(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is less than the specified index. + def modulo(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.mod( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Args: - operand (NodeIndex): The index to compare against. + def power(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.pow( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) - Returns: - NodeOperation: A NodeOperation representing the less-than comparison. - """ - return NodeOperation(self._node_index_operand.less(operand)) + def absolute(self) -> None: + self._single_attribute_operand.abs() - def __lt__(self, operand: NodeIndex) -> NodeOperation: - return self.less(operand) + def trim(self) -> None: + self._single_attribute_operand.trim() - def greater_or_equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is greater than or equal to the specified index. + def trim_start(self) -> None: + self._single_attribute_operand.trim_start() - Args: - operand (NodeIndex): The index to compare against. + def trim_end(self) -> None: + self._single_attribute_operand.trim_end() - Returns: - NodeOperation: A NodeOperation representing the - greater-than-or-equal-to comparison. - """ - return NodeOperation(self._node_index_operand.greater_or_equal(operand)) + def lowercase(self) -> None: + self._single_attribute_operand.lowercase() - def __ge__(self, operand: NodeIndex) -> NodeOperation: - return self.greater_or_equal(operand) + def uppercase(self) -> None: + self._single_attribute_operand.uppercase() - def less_or_equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is less than or equal to the specified index. + def slice(self, start: int, end: int) -> None: + self._single_attribute_operand.slice(start, end) - Args: - operand (NodeIndex): The index to compare against. + def either_or( + self, + either: Callable[[SingleAttributeOperand], None], + or_: Callable[[SingleAttributeOperand], None], + ) -> None: + self._single_attribute_operand.either_or( + lambda attribute: either( + SingleAttributeOperand._from_py_single_attribute_operand(attribute) + ), + lambda attribute: or_( + SingleAttributeOperand._from_py_single_attribute_operand(attribute) + ), + ) - Returns: - NodeOperation: A NodeOperation representing the - less-than-or-equal-to comparison. - """ - return NodeOperation(self._node_index_operand.less_or_equal(operand)) + def exclude(self, query: Callable[[SingleAttributeOperand], None]) -> None: + self._single_attribute_operand.exclude( + lambda attribute: query( + SingleAttributeOperand._from_py_single_attribute_operand(attribute) + ) + ) - def __le__(self, operand: NodeIndex) -> NodeOperation: - return self.less_or_equal(operand) + def clone(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._single_attribute_operand.deep_clone() + ) - def equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is equal to the specified index. + @classmethod + def _from_py_single_attribute_operand( + cls, py_single_attribute_operand: PySingleAttributeOperand + ) -> SingleAttributeOperand: + single_attribute_operand = cls() + single_attribute_operand._single_attribute_operand = py_single_attribute_operand + return single_attribute_operand - Args: - operand (NodeIndex): The index to compare against. - Returns: - NodeOperation: A NodeOperation representing the equality comparison. - """ - return NodeOperation(self._node_index_operand.equal(operand)) +class NodeIndicesOperand: + _node_indices_operand: PyNodeIndicesOperand - def __eq__(self, operand: NodeIndex) -> NodeOperation: - return self.equal(operand) + def max(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.max() + ) - def not_equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is not equal to the specified index. + def min(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.min() + ) - Args: - operand (NodeIndex): The index to compare against. + def count(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.count() + ) - Returns: - NodeOperation: A NodeOperation representing the not-equal comparison. - """ - return NodeOperation(self._node_index_operand.not_equal(operand)) + def sum(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.sum() + ) - def __ne__(self, operand: NodeIndex) -> NodeOperation: - return self.not_equal(operand) + def first(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.first() + ) - def is_in(self, values: List[NodeIndex]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is found within the list of indices. + def last(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.last() + ) - Args: - values (List[NodeIndex]): The list of indices to check the node index - against. + def greater_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.greater_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - NodeOperation: A NodeOperation representing the is-in comparison. - """ - return NodeOperation(self._node_index_operand.is_in(values)) + def greater_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.greater_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def not_in(self, values: List[NodeIndex]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is not found within the list of indices. + def less_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.less_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - values (List[NodeIndex]): The list of indices to check the node index - against. + def less_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.less_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - NodeOperation: A NodeOperation representing the not-in comparison. - """ - return NodeOperation(self._node_index_operand.not_in(values)) + def equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def starts_with(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index starts with the specified index. + def not_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.not_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - operand (NodeIndex): The index to compare against. + def is_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_indices_operand.is_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) - Returns: - NodeOperation: A NodeOperation representing the starts-with condition. - """ - return NodeOperation(self._node_index_operand.starts_with(operand)) + def is_not_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_indices_operand.is_not_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) - def ends_with(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index ends with the specified index. + def starts_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.starts_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - operand (NodeIndex): The index to compare against. + def ends_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.ends_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - NodeOperation: A NodeOperation representing the ends-with condition. - """ - return NodeOperation(self._node_index_operand.ends_with(operand)) + def contains(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.contains( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def contains(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index contains the specified index. + def add(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.add( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - operand (NodeIndex): The index to compare against. + def subtract(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.sub( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - NodeOperation: A NodeOperation representing the contains condition. - """ - return NodeOperation(self._node_index_operand.contains(operand)) + def multiply(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.mul( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + def modulo(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.mod( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) -class EdgeIndexOperand: - _edge_index_operand: PyEdgeIndexOperand + def power(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.pow( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def __init__(self, edge_index_operand: PyEdgeIndexOperand) -> None: - self._edge_index_operand = edge_index_operand + def absolute(self) -> None: + self._node_indices_operand.abs() - def greater(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is greater than the specified index. + def trim(self) -> None: + self._node_indices_operand.trim() - Args: - operand (EdgeIndex): The index to compare against. + def trim_start(self) -> None: + self._node_indices_operand.trim_start() - Returns: - EdgeOperation: A EdgeOperation representing the greater-than comparison. - """ - return EdgeOperation(self._edge_index_operand.greater(operand)) + def trim_end(self) -> None: + self._node_indices_operand.trim_end() - def __gt__(self, operand: EdgeIndex) -> EdgeOperation: - return self.greater(operand) + def lowercase(self) -> None: + self._node_indices_operand.lowercase() - def less(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is less than the specified index. + def uppercase(self) -> None: + self._node_indices_operand.uppercase() - Args: - operand (EdgeIndex): The index to compare against. + def slice(self, start: int, end: int) -> None: + self._node_indices_operand.slice(start, end) - Returns: - EdgeOperation: A EdgeOperation representing the less-than comparison. - """ - return EdgeOperation(self._edge_index_operand.less(operand)) + def either_or( + self, + either: Callable[[NodeIndicesOperand], None], + or_: Callable[[NodeIndicesOperand], None], + ) -> None: + self._node_indices_operand.either_or( + lambda node_indices: either( + NodeIndicesOperand._from_py_node_indices_operand(node_indices) + ), + lambda node_indices: or_( + NodeIndicesOperand._from_py_node_indices_operand(node_indices) + ), + ) - def __lt__(self, operand: EdgeIndex) -> EdgeOperation: - return self.less(operand) + def exclude(self, query: Callable[[NodeIndicesOperand], None]) -> None: + self._node_indices_operand.exclude( + lambda node_indices: query( + NodeIndicesOperand._from_py_node_indices_operand(node_indices) + ) + ) - def greater_or_equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is greater than or equal to the specified index. + def clone(self) -> NodeIndicesOperand: + return NodeIndicesOperand._from_py_node_indices_operand( + self._node_indices_operand.deep_clone() + ) - Args: - operand (EdgeIndex): The index to compare against. + @classmethod + def _from_py_node_indices_operand( + cls, py_node_indices_operand: PyNodeIndicesOperand + ) -> NodeIndicesOperand: + node_indices_operand = cls() + node_indices_operand._node_indices_operand = py_node_indices_operand + return node_indices_operand - Returns: - EdgeOperation: A EdgeOperation representing the - greater-than-or-equal-to comparison. - """ - return EdgeOperation(self._edge_index_operand.greater_or_equal(operand)) - def __ge__(self, operand: EdgeIndex) -> EdgeOperation: - return self.greater_or_equal(operand) +class NodeIndexOperand: + _node_index_operand: PyNodeIndexOperand - def less_or_equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is less than or equal to the specified index. + def greater_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.greater_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - operand (EdgeIndex): The index to compare against. + def greater_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.greater_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - EdgeOperation: A EdgeOperation representing the - less-than-or-equal-to comparison. - """ - return EdgeOperation(self._edge_index_operand.less_or_equal(operand)) + def less_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.less_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def __le__(self, operand: EdgeIndex) -> EdgeOperation: - return self.less_or_equal(operand) + def less_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.less_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is equal to the specified index. + def equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - operand (EdgeIndex): The index to compare against. + def not_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.not_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - EdgeOperation: A EdgeOperation representing the equality comparison. - """ - return EdgeOperation(self._edge_index_operand.equal(operand)) + def is_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_index_operand.is_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) - def __eq__(self, operand: EdgeIndex) -> EdgeOperation: - return self.equal(operand) + def is_not_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_index_operand.is_not_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) - def not_equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is not equal to the specified index. + def starts_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.starts_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - operand (EdgeIndex): The index to compare against. + def ends_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.ends_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - EdgeOperation: A EdgeOperation representing the not-equal comparison. - """ - return EdgeOperation(self._edge_index_operand.not_equal(operand)) + def contains(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.contains( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def __ne__(self, operand: EdgeIndex) -> EdgeOperation: - return self.not_equal(operand) + def add(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.add( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def is_in(self, values: List[EdgeIndex]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is found within the list of indices. + def subtract(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.sub( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - values (List[EdgeIndex]): The list of indices to check the edge index - against. + def multiply(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.mul( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Returns: - EdgeOperation: A EdgeOperation representing the is-in comparison. - """ - return EdgeOperation(self._edge_index_operand.is_in(values)) + def modulo(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.mod( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - def not_in(self, values: List[EdgeIndex]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is not found within the list of indices. + def power(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.pow( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) - Args: - values (List[EdgeIndex]): The list of indices to check the edge index - against. + def absolute(self) -> None: + self._node_index_operand.abs() - Returns: - EdgeOperation: A EdgeOperation representing the not-in comparison. - """ - return EdgeOperation(self._edge_index_operand.not_in(values)) + def trim(self) -> None: + self._node_index_operand.trim() + def trim_start(self) -> None: + self._node_index_operand.trim_start() -class NodeOperand: - _node_operand: PyNodeOperand + def trim_end(self) -> None: + self._node_index_operand.trim_end() - def __init__(self) -> None: - self._node_operand = PyNodeOperand() + def lowercase(self) -> None: + self._node_index_operand.lowercase() - def in_group(self, operand: Group) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node is part of the specified group. + def uppercase(self) -> None: + self._node_index_operand.uppercase() - Args: - operand (Group): The group to check the node against. + def slice(self, start: int, end: int) -> None: + self._node_index_operand.slice(start, end) - Returns: - NodeOperation: A NodeOperation indicating if the node is part of the - specified group. - """ - return NodeOperation(self._node_operand.in_group(operand)) + def either_or( + self, + either: Callable[[NodeIndexOperand], None], + or_: Callable[[NodeIndexOperand], None], + ) -> None: + self._node_index_operand.either_or( + lambda node_index: either( + NodeIndexOperand._from_py_node_index_operand(node_index) + ), + lambda node_index: or_( + NodeIndexOperand._from_py_node_index_operand(node_index) + ), + ) - def has_attribute(self, operand: MedRecordAttribute) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has the specified attribute. + def exclude(self, query: Callable[[NodeIndexOperand], None]) -> None: + self._node_index_operand.exclude( + lambda node_index: query( + NodeIndexOperand._from_py_node_index_operand(node_index) + ) + ) - Args: - operand (MedRecordAttribute): The attribute to check on the node. + def clone(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_index_operand.deep_clone() + ) - Returns: - NodeOperation: A NodeOperation indicating if the node has the - specified attribute. - """ - return NodeOperation(self._node_operand.has_attribute(operand)) + @classmethod + def _from_py_node_index_operand( + cls, py_node_index_operand: PyNodeIndexOperand + ) -> NodeIndexOperand: + node_index_operand = cls() + node_index_operand._node_index_operand = py_node_index_operand + return node_index_operand - def has_outgoing_edge_with(self, operation: EdgeOperation) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has an outgoing edge that satisfies the specified EdgeOperation. - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - outgoing edges. +class EdgeIndicesOperand: + _edge_indices_operand: PyEdgeIndicesOperand - Returns: - NodeOperation: A NodeOperation indicating if the node has an - outgoing edge satisfying the specified operation. - """ - return NodeOperation( - self._node_operand.has_outgoing_edge_with(operation._edge_operation) + def max(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.max() ) - def has_incoming_edge_with(self, operation: EdgeOperation) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has an incoming edge that satisfies the specified EdgeOperation. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - incoming edges. - - Returns: - NodeOperation: A NodeOperation indicating if the node has an - incoming edge satisfying the specified operation. - """ - return NodeOperation( - self._node_operand.has_incoming_edge_with(operation._edge_operation) + def min(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.min() ) - def has_edge_with(self, operation: EdgeOperation) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has any edge (incoming or outgoing) that satisfies the specified EdgeOperation. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - edges connected to the node. - - Returns: - NodeOperation: A NodeOperation indicating if the node has any edge - satisfying the specified operation. - """ - return NodeOperation( - self._node_operand.has_edge_with(operation._edge_operation) + def count(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.count() ) - def has_neighbor_with( - self, operation: NodeOperation, *, directed: bool = True - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has a neighboring node that satisfies the specified NodeOperation. - - Args: - operation (NodeOperation): A NodeOperation to evaluate against - neighboring nodes. - directed (bool): Whether to consider edges as directed - - Returns: - NodeOperation: A NodeOperation indicating if the node has a neighboring node - satisfying the specified operation. - """ - if directed: - return NodeOperation( - self._node_operand.has_neighbor_with(operation._node_operation) - ) - else: - return NodeOperation( - self._node_operand.has_neighbor_undirected_with( - operation._node_operation - ) - ) + def sum(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.sum() + ) - def attribute(self, attribute: MedRecordAttribute) -> NodeAttributeOperand: - """Accesses an NodeAttributeOperand for the specified attribute, allowing for the creation of operations based on node attributes. + def first(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.first() + ) - Args: - attribute (MedRecordAttribute): The attribute of the node to perform - operations on. + def last(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.last() + ) - Returns: - NodeAttributeOperand: An operand that represents the specified node - attribute, enabling further operations such as comparisons and - arithmetic operations. - """ - return NodeAttributeOperand(self._node_operand.attribute(attribute)) + def greater_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.greater_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - def index(self) -> NodeIndexOperand: - """Accesses an NodeIndexOperand, allowing for the creation of operations based on the node index. + def greater_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.greater_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - NodeIndexOperand: An operand that represents the specified node - index, enabling further operations such as comparisons and - arithmetic operations. - """ - return NodeIndexOperand(self._node_operand.index()) + def less_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.less_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + def less_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.less_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) -def node() -> NodeOperand: - """Factory function to create and return a new NodeOperand instance. + def equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - NodeOperand: An instance of NodeOperand for constructing node-based operations. - """ - return NodeOperand() + def not_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.not_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + def is_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_indices_operand.is_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) -class EdgeOperand: - _edge_operand: PyEdgeOperand + def is_not_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_indices_operand.is_not_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) - def __init__(self) -> None: - self._edge_operand = PyEdgeOperand() + def starts_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.starts_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - def connected_target(self, operand: NodeIndex) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to a target node with the specified index. + def ends_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.ends_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operand (NodeIndex): The index of the target node to check for a connection. + def contains(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.contains( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is connected to the - specified target node. - """ - return EdgeOperation(self._edge_operand.connected_target(operand)) + def add(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.add( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - def connected_source(self, operand: NodeIndex) -> EdgeOperation: - """Generates an EdgeOperation that evaluates to true if the edge originates from a source node with the given index. + def subtract(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.sub( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operand (NodeIndex): The index of the source node to check for a connection. + def multiply(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.mul( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is connected from the - specified source node. - """ - return EdgeOperation(self._edge_operand.connected_source(operand)) + def modulo(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.mod( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - def connected(self, operand: NodeIndex) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to or from a node with the specified index. + def power(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.pow( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operand (NodeIndex): The index of the node to check for a connection. + def either_or( + self, + either: Callable[[EdgeIndicesOperand], None], + or_: Callable[[EdgeIndicesOperand], None], + ) -> None: + self._edge_indices_operand.either_or( + lambda edge_indices: either( + EdgeIndicesOperand._from_edge_indices_operand(edge_indices) + ), + lambda edge_indices: or_( + EdgeIndicesOperand._from_edge_indices_operand(edge_indices) + ), + ) - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is connected to the - specified node. - """ - return EdgeOperation(self._edge_operand.connected(operand)) + def exclude(self, query: Callable[[EdgeIndicesOperand], None]) -> None: + self._edge_indices_operand.exclude( + lambda edge_indices: query( + EdgeIndicesOperand._from_edge_indices_operand(edge_indices) + ) + ) - def in_group(self, operand: Group) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is part of the specified group. + def clone(self) -> EdgeIndicesOperand: + return EdgeIndicesOperand._from_edge_indices_operand( + self._edge_indices_operand.deep_clone() + ) - Args: - operand (Group): The group to check the edge against. + @classmethod + def _from_edge_indices_operand( + cls, py_edge_indices_operand: PyEdgeIndicesOperand + ) -> EdgeIndicesOperand: + edge_indices_operand = cls() + edge_indices_operand._edge_indices_operand = py_edge_indices_operand + return edge_indices_operand - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is part of the - specified group. - """ - return EdgeOperation(self._edge_operand.in_group(operand)) - def has_attribute(self, operand: MedRecordAttribute) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge has the specified attribute. +class EdgeIndexOperand: + _edge_index_operand: PyEdgeIndexOperand - Args: - operand (MedRecordAttribute): The attribute to check on the edge. + def greater_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.greater_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if the edge has the - specified attribute. - """ - return EdgeOperation(self._edge_operand.has_attribute(operand)) + def greater_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.greater_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - def connected_source_with(self, operation: NodeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge originates from a source node that satisfies the specified NodeOperation. + def less_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.less_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operation (NodeOperation): A NodeOperation to evaluate against the - source node. + def less_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.less_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if the source node of the - edge satisfies the specified operation. - """ - return EdgeOperation( - self._edge_operand.connected_source_with(operation._node_operation) + def equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) ) - def connected_target_with(self, operation: NodeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to a target node that satisfies the specified NodeOperation. + def not_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.not_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operation (NodeOperation): A NodeOperation to evaluate against the - target node. + def is_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_index_operand.is_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if the target node of the - edge satisfies the specified operation. - """ - return EdgeOperation( - self._edge_operand.connected_target_with(operation._node_operation) + def is_not_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_index_operand.is_not_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) ) - def connected_with(self, operation: NodeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to or from a node that satisfies the specified NodeOperation. + def starts_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.starts_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operation (NodeOperation): A NodeOperation to evaluate against the - connected node. + def ends_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.ends_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if either the source or - target node of the edge satisfies the specified operation. - """ - return EdgeOperation( - self._edge_operand.connected_with(operation._node_operation) + def contains(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.contains( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) ) - def has_parallel_edges_with(self, operation: EdgeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if there are parallel edges that satisfy the specified EdgeOperation. + def add(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.add( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - parallel edges. + def subtract(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.sub( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Returns: - EdgeOperation: An EdgeOperation indicating if there are parallel edges - satisfying the specified operation. - """ - return EdgeOperation( - self._edge_operand.has_parallel_edges_with(operation._edge_operation) + def multiply(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.mul( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) ) - def has_parallel_edges_with_self_comparison( - self, operation: EdgeOperation - ) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if there are parallel edges that satisfy the specified EdgeOperation. + def modulo(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.mod( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Using `edge().attribute(...)` in the operation will compare to the attribute of - this edge, not the parallel edge. + def power(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.pow( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - parallel edges. + def either_or( + self, + either: Callable[[EdgeIndexOperand], None], + or_: Callable[[EdgeIndexOperand], None], + ) -> None: + self._edge_index_operand.either_or( + lambda edge_index: either( + EdgeIndexOperand._from_py_edge_index_operand(edge_index) + ), + lambda edge_index: or_( + EdgeIndexOperand._from_py_edge_index_operand(edge_index) + ), + ) - Returns: - EdgeOperation: An EdgeOperation indicating if there are parallel edges - satisfying the specified operation. - """ - return EdgeOperation( - self._edge_operand.has_parallel_edges_with_self_comparison( - operation._edge_operation + def exclude(self, query: Callable[[EdgeIndexOperand], None]) -> None: + self._edge_index_operand.exclude( + lambda edge_index: query( + EdgeIndexOperand._from_py_edge_index_operand(edge_index) ) ) - def attribute(self, attribute: MedRecordAttribute) -> EdgeAttributeOperand: - """Accesses an EdgeAttributeOperand for the specified attribute, allowing for the creation of operations based on edge attributes. - - Args: - attribute (MedRecordAttribute): The attribute of the edge to perform - operations on. - - Returns: - EdgeAttributeOperand: An operand that represents the specified edge - attribute, enabling further operations such as comparisons and - arithmetic operations. - """ - return EdgeAttributeOperand(self._edge_operand.attribute(attribute)) - - def index(self) -> EdgeIndexOperand: - """Accesses an EdgeIndexOperand, allowing for the creation of operations based on the edge index. - - Returns: - EdgeIndexOperand: An operand that represents the specified edge - index, enabling further operations such as comparisons and - arithmetic operations. - """ - return EdgeIndexOperand(self._edge_operand.index()) - - -def edge() -> EdgeOperand: - """Factory function to create and return a new EdgeOperand instance. + def clone(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_index_operand.deep_clone() + ) - Returns: - EdgeOperand: An instance of EdgeOperand for constructing edge-based operations. - """ - return EdgeOperand() + @classmethod + def _from_py_edge_index_operand( + cls, py_edge_index_operand: PyEdgeIndexOperand + ) -> EdgeIndexOperand: + edge_index_operand = cls() + edge_index_operand._edge_index_operand = py_edge_index_operand + return edge_index_operand diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index 2dd28890..690dd601 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -19,7 +19,7 @@ class AttributeType(Enum): Temporal = auto() @staticmethod - def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: + def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType: """ Converts a PyAttributeType to an AttributeType. @@ -36,7 +36,7 @@ def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: elif py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal - def _into_pyattributetype(self) -> PyAttributeType: + def _into_py_attribute_type(self) -> PyAttributeType: """ Converts an AttributeType to a PyAttributeType. @@ -81,7 +81,7 @@ def __eq__(self, value: object) -> bool: bool: True if the objects are equal, False otherwise. """ if isinstance(value, PyAttributeType): - return self._into_pyattributetype() == value + return self._into_py_attribute_type() == value elif isinstance(value, AttributeType): return str(self) == str(value) @@ -295,7 +295,7 @@ def _convert_input( ) -> PyAttributeDataType: if isinstance(input, tuple): return PyAttributeDataType( - input[0]._inner(), input[1]._into_pyattributetype() + input[0]._inner(), input[1]._into_py_attribute_type() ) return PyAttributeDataType(input._inner(), None) @@ -334,8 +334,8 @@ def _convert_node( input: PyAttributeDataType, ) -> Tuple[DataType, Optional[AttributeType]]: return ( - DataType._from_pydatatype(input.data_type), - AttributeType._from_pyattributetype(input.attribute_type) + DataType._from_py_data_type(input.data_type), + AttributeType._from_py_attribute_type(input.attribute_type) if input.attribute_type is not None else None, ) @@ -361,8 +361,8 @@ def _convert_edge( input: PyAttributeDataType, ) -> Tuple[DataType, Optional[AttributeType]]: return ( - DataType._from_pydatatype(input.data_type), - AttributeType._from_pyattributetype(input.attribute_type) + DataType._from_py_data_type(input.data_type), + AttributeType._from_py_attribute_type(input.attribute_type) if input.attribute_type is not None else None, ) @@ -422,7 +422,7 @@ def __init__( ) @classmethod - def _from_pyschema(cls, schema: PySchema) -> Schema: + def _from_py_schema(cls, schema: PySchema) -> Schema: """ Creates a Schema instance from an existing PySchema. diff --git a/medmodels/medrecord/tests/test_builder.py b/medmodels/medrecord/tests/test_builder.py index 5b261b35..71ef4ce2 100644 --- a/medmodels/medrecord/tests/test_builder.py +++ b/medmodels/medrecord/tests/test_builder.py @@ -58,10 +58,10 @@ def test_with_schema(self): medrecord = mr.MedRecord.builder().with_schema(schema).build() - medrecord.add_node("node", {"attribute": 1}) + medrecord.add_nodes(("node1", {"attribute": 1})) with self.assertRaises(ValueError): - medrecord.add_node("node", {"attribute": "1"}) + medrecord.add_nodes(("node2", {"attribute": "1"})) if __name__ == "__main__": diff --git a/medmodels/medrecord/tests/test_indexers.py b/medmodels/medrecord/tests/test_indexers.py index 17492c6d..b5c47a32 100644 --- a/medmodels/medrecord/tests/test_indexers.py +++ b/medmodels/medrecord/tests/test_indexers.py @@ -1,7 +1,7 @@ import unittest from medmodels import MedRecord -from medmodels.medrecord import edge, node +from medmodels.medrecord.querying import EdgeOperand, NodeOperand def create_medrecord(): @@ -21,6 +21,30 @@ def create_medrecord(): ) +def node_greater_than_or_equal_two(node: NodeOperand): + node.index().greater_than_or_equal_to(2) + + +def node_greater_than_three(node: NodeOperand): + node.index().greater_than(3) + + +def node_less_than_two(node: NodeOperand): + node.index().less_than(2) + + +def edge_greater_than_or_equal_two(edge: EdgeOperand): + edge.index().greater_than_or_equal_to(2) + + +def edge_greater_than_three(edge: EdgeOperand): + edge.index().greater_than(3) + + +def edge_less_than_two(edge: EdgeOperand): + edge.index().less_than(2) + + class TestMedRecord(unittest.TestCase): def test_node_getitem(self): medrecord = create_medrecord() @@ -118,54 +142,54 @@ def test_node_getitem(self): self.assertEqual( {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.node[node().index() >= 2], + medrecord.node[node_greater_than_or_equal_two], ) # Empty query should not fail self.assertEqual( {}, - medrecord.node[node().index() > 3], + medrecord.node[node_greater_than_three], ) self.assertEqual( {2: "bar", 3: "bar"}, - medrecord.node[node().index() >= 2, "foo"], + medrecord.node[node_greater_than_or_equal_two, "foo"], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.node[node().index() >= 2, "test"] + medrecord.node[node_greater_than_or_equal_two, "test"] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.node[node().index() >= 2, ["foo", "bar"]], + medrecord.node[node_greater_than_or_equal_two, ["foo", "bar"]], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.node[node().index() >= 2, ["foo", "test"]] + medrecord.node[node_greater_than_or_equal_two, ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail with self.assertRaises(KeyError): - medrecord.node[node().index() < 2, ["foo", "lorem"]] + medrecord.node[node_less_than_two, ["foo", "lorem"]] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.node[node().index() >= 2, :], + medrecord.node[node_greater_than_or_equal_two, :], ) with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, 1:] + medrecord.node[node_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, :1] + medrecord.node[node_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, ::1] + medrecord.node[node_greater_than_or_equal_two, ::1] self.assertEqual( { @@ -360,7 +384,7 @@ def test_node_setitem(self): medrecord.node[[0, 1], ::1] = "test" medrecord = create_medrecord() - medrecord.node[node().index() >= 2] = {"foo": "bar", "bar": "test"} + medrecord.node[node_greater_than_or_equal_two] = {"foo": "bar", "bar": "test"} self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -373,10 +397,10 @@ def test_node_setitem(self): medrecord = create_medrecord() # Empty query should not fail - medrecord.node[node().index() > 3] = {"foo": "bar", "bar": "test"} + medrecord.node[node_greater_than_three] = {"foo": "bar", "bar": "test"} medrecord = create_medrecord() - medrecord.node[node().index() >= 2, "foo"] = "test" + medrecord.node[node_greater_than_or_equal_two, "foo"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -388,7 +412,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, ["foo", "bar"]] = "test" + medrecord.node[node_greater_than_or_equal_two, ["foo", "bar"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -400,7 +424,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, :] = "test" + medrecord.node[node_greater_than_or_equal_two, :] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -412,11 +436,11 @@ def test_node_setitem(self): ) with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, 1:] = "test" + medrecord.node[node_greater_than_or_equal_two, 1:] = "test" with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, :1] = "test" + medrecord.node[node_greater_than_or_equal_two, :1] = "test" with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, ::1] = "test" + medrecord.node[node_greater_than_or_equal_two, ::1] = "test" medrecord = create_medrecord() medrecord.node[:, "foo"] = "test" @@ -544,7 +568,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, "test"] = "test" + medrecord.node[node_greater_than_or_equal_two, "test"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -556,7 +580,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, ["test", "test2"]] = "test" + medrecord.node[node_greater_than_or_equal_two, ["test", "test2"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -634,7 +658,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() < 2, "lorem"] = "test" + medrecord.node[node_less_than_two, "lorem"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, @@ -646,7 +670,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() < 2, ["lorem", "test"]] = "test" + medrecord.node[node_less_than_two, ["lorem", "test"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, @@ -804,7 +828,7 @@ def test_node_delitem(self): del medrecord.node[[0, 1], ::1] medrecord = create_medrecord() - del medrecord.node[node().index() >= 2, "foo"] + del medrecord.node[node_greater_than_or_equal_two, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -817,7 +841,7 @@ def test_node_delitem(self): medrecord = create_medrecord() # Empty query should not fail - del medrecord.node[node().index() > 3, "foo"] + del medrecord.node[node_greater_than_three, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -831,10 +855,10 @@ def test_node_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.node[node().index() >= 2, "test"] + del medrecord.node[node_greater_than_or_equal_two, "test"] medrecord = create_medrecord() - del medrecord.node[node().index() >= 2, ["foo", "bar"]] + del medrecord.node[node_greater_than_or_equal_two, ["foo", "bar"]] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -848,15 +872,15 @@ def test_node_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.node[node().index() >= 2, ["foo", "test"]] + del medrecord.node[node_greater_than_or_equal_two, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail with self.assertRaises(KeyError): - del medrecord.node[node().index() < 2, ["foo", "lorem"]] + del medrecord.node[node_less_than_two, ["foo", "lorem"]] medrecord = create_medrecord() - del medrecord.node[node().index() >= 2, :] + del medrecord.node[node_greater_than_or_equal_two, :] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -868,11 +892,11 @@ def test_node_delitem(self): ) with self.assertRaises(ValueError): - del medrecord.node[node().index() >= 2, 1:] + del medrecord.node[node_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - del medrecord.node[node().index() >= 2, :1] + del medrecord.node[node_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - del medrecord.node[node().index() >= 2, ::1] + del medrecord.node[node_greater_than_or_equal_two, ::1] medrecord = create_medrecord() del medrecord.node[:, "foo"] @@ -1048,54 +1072,54 @@ def test_edge_getitem(self): self.assertEqual( {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.edge[edge().index() >= 2], + medrecord.edge[edge_greater_than_or_equal_two], ) # Empty query should not fail self.assertEqual( {}, - medrecord.edge[edge().index() > 3], + medrecord.edge[edge_greater_than_three], ) self.assertEqual( {2: "bar", 3: "bar"}, - medrecord.edge[edge().index() >= 2, "foo"], + medrecord.edge[edge_greater_than_or_equal_two, "foo"], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.edge[edge().index() >= 2, "test"] + medrecord.edge[edge_greater_than_or_equal_two, "test"] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.edge[edge().index() >= 2, ["foo", "bar"]], + medrecord.edge[edge_greater_than_or_equal_two, ["foo", "bar"]], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.edge[edge().index() >= 2, ["foo", "test"]] + medrecord.edge[edge_greater_than_or_equal_two, ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail with self.assertRaises(KeyError): - medrecord.edge[edge().index() < 2, ["foo", "lorem"]] + medrecord.edge[edge_less_than_two, ["foo", "lorem"]] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.edge[edge().index() >= 2, :], + medrecord.edge[edge_greater_than_or_equal_two, :], ) with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, 1:] + medrecord.edge[edge_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, :1] + medrecord.edge[edge_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, ::1] + medrecord.edge[edge_greater_than_or_equal_two, ::1] self.assertEqual( { @@ -1290,7 +1314,7 @@ def test_edge_setitem(self): medrecord.edge[[0, 1], ::1] = "test" medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2] = {"foo": "bar", "bar": "test"} + medrecord.edge[edge_greater_than_or_equal_two] = {"foo": "bar", "bar": "test"} self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1303,10 +1327,10 @@ def test_edge_setitem(self): medrecord = create_medrecord() # Empty query should not fail - medrecord.edge[edge().index() > 3] = {"foo": "bar", "bar": "test"} + medrecord.edge[edge_greater_than_three] = {"foo": "bar", "bar": "test"} medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, "foo"] = "test" + medrecord.edge[edge_greater_than_or_equal_two, "foo"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1318,7 +1342,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, ["foo", "bar"]] = "test" + medrecord.edge[edge_greater_than_or_equal_two, ["foo", "bar"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1330,7 +1354,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, :] = "test" + medrecord.edge[edge_greater_than_or_equal_two, :] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1342,11 +1366,11 @@ def test_edge_setitem(self): ) with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, 1:] = "test" + medrecord.edge[edge_greater_than_or_equal_two, 1:] = "test" with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, :1] = "test" + medrecord.edge[edge_greater_than_or_equal_two, :1] = "test" with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, ::1] = "test" + medrecord.edge[edge_greater_than_or_equal_two, ::1] = "test" medrecord = create_medrecord() medrecord.edge[:, "foo"] = "test" @@ -1474,7 +1498,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, "test"] = "test" + medrecord.edge[edge_greater_than_or_equal_two, "test"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1486,7 +1510,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, ["test", "test2"]] = "test" + medrecord.edge[edge_greater_than_or_equal_two, ["test", "test2"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1564,7 +1588,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() < 2, "lorem"] = "test" + medrecord.edge[edge_less_than_two, "lorem"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, @@ -1576,7 +1600,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() < 2, ["lorem", "test"]] = "test" + medrecord.edge[edge_less_than_two, ["lorem", "test"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, @@ -1734,7 +1758,7 @@ def test_edge_delitem(self): del medrecord.edge[[0, 1], ::1] medrecord = create_medrecord() - del medrecord.edge[edge().index() >= 2, "foo"] + del medrecord.edge[edge_greater_than_or_equal_two, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1747,7 +1771,7 @@ def test_edge_delitem(self): medrecord = create_medrecord() # Empty query should not fail - del medrecord.edge[edge().index() > 3, "foo"] + del medrecord.edge[edge_greater_than_three, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1761,10 +1785,10 @@ def test_edge_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.edge[edge().index() >= 2, "test"] + del medrecord.edge[edge_greater_than_or_equal_two, "test"] medrecord = create_medrecord() - del medrecord.edge[edge().index() >= 2, ["foo", "bar"]] + del medrecord.edge[edge_greater_than_or_equal_two, ["foo", "bar"]] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1778,15 +1802,15 @@ def test_edge_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.edge[edge().index() >= 2, ["foo", "test"]] + del medrecord.edge[edge_greater_than_or_equal_two, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail with self.assertRaises(KeyError): - del medrecord.edge[edge().index() < 2, ["foo", "lorem"]] + del medrecord.edge[edge_less_than_two, ["foo", "lorem"]] medrecord = create_medrecord() - del medrecord.edge[edge().index() >= 2, :] + del medrecord.edge[edge_greater_than_or_equal_two, :] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1798,11 +1822,11 @@ def test_edge_delitem(self): ) with self.assertRaises(ValueError): - del medrecord.edge[edge().index() >= 2, 1:] + del medrecord.edge[edge_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - del medrecord.edge[edge().index() >= 2, :1] + del medrecord.edge[edge_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - del medrecord.edge[edge().index() >= 2, ::1] + del medrecord.edge[edge_greater_than_or_equal_two, ::1] medrecord = create_medrecord() del medrecord.edge[:, "foo"] diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index 2157af2f..f4b92319 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -7,8 +7,7 @@ import medmodels.medrecord as mr from medmodels import MedRecord -from medmodels.medrecord import edge as edge_select -from medmodels.medrecord import node as node_select +from medmodels.medrecord.querying import EdgeOperand, NodeOperand from medmodels.medrecord.types import Attributes, NodeIndex @@ -262,33 +261,35 @@ def test_schema(self): medrecord = MedRecord.with_schema(schema) medrecord.add_group("group") - medrecord.add_node("0", {"attribute": 1}) + medrecord.add_nodes(("0", {"attribute": 1})) with self.assertRaises(ValueError): - medrecord.add_node("1", {"attribute": "1"}) + medrecord.add_nodes(("1", {"attribute": "1"})) - medrecord.add_node("1", {"attribute": 1, "attribute2": 1}) + medrecord.add_nodes(("1", {"attribute": 1, "attribute2": 1})) - medrecord.add_node_to_group("group", "1") + medrecord.add_nodes_to_group("group", "1") - medrecord.add_node("2", {"attribute": 1, "attribute2": "1"}) + medrecord.add_nodes(("2", {"attribute": 1, "attribute2": "1"})) with self.assertRaises(ValueError): - medrecord.add_node_to_group("group", "2") + medrecord.add_nodes_to_group("group", "2") - medrecord.add_edge("0", "1", {"attribute": 1}) + medrecord.add_edges(("0", "1", {"attribute": 1})) with self.assertRaises(ValueError): - medrecord.add_edge("0", "1", {"attribute": "1"}) + medrecord.add_edges(("0", "1", {"attribute": "1"})) - edge_index = medrecord.add_edge("0", "1", {"attribute": 1, "attribute2": 1}) + edge_index = medrecord.add_edges(("0", "1", {"attribute": 1, "attribute2": 1})) - medrecord.add_edge_to_group("group", edge_index) + medrecord.add_edges_to_group("group", edge_index) - edge_index = medrecord.add_edge("0", "1", {"attribute": 1, "attribute2": "1"}) + edge_index = medrecord.add_edges( + ("0", "1", {"attribute": 1, "attribute2": "1"}) + ) with self.assertRaises(ValueError): - medrecord.add_edge_to_group("group", edge_index) + medrecord.add_edges_to_group("group", edge_index) def test_nodes(self): medrecord = create_medrecord() @@ -359,7 +360,10 @@ def test_outgoing_edges(self): {key: sorted(value) for (key, value) in edges.items()}, ) - edges = medrecord.outgoing_edges(node_select().index().is_in(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + + edges = medrecord.outgoing_edges(query) self.assertEqual( {"0": sorted([0, 3]), "1": [1, 2]}, @@ -388,7 +392,10 @@ def test_incoming_edges(self): self.assertEqual({"1": [0], "2": [2]}, edges) - edges = medrecord.incoming_edges(node_select().index().is_in(["1", "2"])) + def query(node: NodeOperand): + node.index().is_in(["1", "2"]) + + edges = medrecord.incoming_edges(query) self.assertEqual({"1": [0], "2": [2]}, edges) @@ -414,7 +421,10 @@ def test_edge_endpoints(self): self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) - endpoints = medrecord.edge_endpoints(edge_select().index().is_in([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + + endpoints = medrecord.edge_endpoints(query) self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) @@ -440,7 +450,10 @@ def test_edges_connecting(self): self.assertEqual([0], edges) - edges = medrecord.edges_connecting(node_select().index().is_in(["0", "1"]), "1") + def query1(node: NodeOperand): + node.index().is_in(["0", "1"]) + + edges = medrecord.edges_connecting(query1, "1") self.assertEqual([0], edges) @@ -448,7 +461,10 @@ def test_edges_connecting(self): self.assertEqual(sorted([0, 3]), sorted(edges)) - edges = medrecord.edges_connecting("0", node_select().index().is_in(["1", "3"])) + def query2(node: NodeOperand): + node.index().is_in(["1", "3"]) + + edges = medrecord.edges_connecting("0", query2) self.assertEqual(sorted([0, 3]), sorted(edges)) @@ -456,10 +472,13 @@ def test_edges_connecting(self): self.assertEqual(sorted([0, 2, 3]), sorted(edges)) - edges = medrecord.edges_connecting( - node_select().index().is_in(["0", "1"]), - node_select().index().is_in(["1", "2", "3"]), - ) + def query3(node: NodeOperand): + node.index().is_in(["0", "1"]) + + def query4(node: NodeOperand): + node.index().is_in(["1", "2", "3"]) + + edges = medrecord.edges_connecting(query3, query4) self.assertEqual(sorted([0, 2, 3]), sorted(edges)) @@ -467,40 +486,17 @@ def test_edges_connecting(self): self.assertEqual([0, 1], sorted(edges)) - def test_add_node(self): - medrecord = MedRecord() - - self.assertEqual(0, medrecord.node_count()) - - medrecord.add_node("0", {}) - - self.assertEqual(1, medrecord.node_count()) - self.assertEqual(0, len(medrecord.groups)) - - medrecord = MedRecord() - - medrecord.add_node("0", {}, "0") - - self.assertIn("0", medrecord.nodes_in_group("0")) - self.assertEqual(1, len(medrecord.groups)) - - def test_invalid_add_node(self): - medrecord = create_medrecord() - - with self.assertRaises(AssertionError): - medrecord.add_node("0", {}) - - def test_remove_node(self): + def test_remove_nodes(self): medrecord = create_medrecord() self.assertEqual(4, medrecord.node_count()) - attributes = medrecord.remove_node("0") + attributes = medrecord.remove_nodes("0") self.assertEqual(3, medrecord.node_count()) self.assertEqual(create_nodes()[0][1], attributes) - attributes = medrecord.remove_node(["1", "2"]) + attributes = medrecord.remove_nodes(["1", "2"]) self.assertEqual(1, medrecord.node_count()) self.assertEqual( @@ -511,23 +507,26 @@ def test_remove_node(self): self.assertEqual(4, medrecord.node_count()) - attributes = medrecord.remove_node(node_select().index().is_in(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + + attributes = medrecord.remove_nodes(query) self.assertEqual(2, medrecord.node_count()) self.assertEqual( {"0": create_nodes()[0][1], "1": create_nodes()[1][1]}, attributes ) - def test_invalid_remove_node(self): + def test_invalid_remove_nodes(self): medrecord = create_medrecord() # Removing a non-existing node should fail with self.assertRaises(IndexError): - medrecord.remove_node("50") + medrecord.remove_nodes("50") # Removing a non-existing node should fail with self.assertRaises(IndexError): - medrecord.remove_node(["0", "50"]) + medrecord.remove_nodes(["0", "50"]) def test_add_nodes(self): medrecord = MedRecord() @@ -538,6 +537,23 @@ def test_add_nodes(self): self.assertEqual(4, medrecord.node_count()) + # Adding node tuple + medrecord = MedRecord() + + self.assertEqual(0, medrecord.node_count()) + + medrecord.add_nodes(("0", {})) + + self.assertEqual(1, medrecord.node_count()) + self.assertEqual(0, len(medrecord.groups)) + + medrecord = MedRecord() + + medrecord.add_nodes(("0", {}), "0") + + self.assertIn("0", medrecord.nodes_in_group("0")) + self.assertEqual(1, len(medrecord.groups)) + # Adding tuple to a group medrecord = MedRecord() @@ -784,46 +800,17 @@ def test_invalid_add_nodes_polars(self): with self.assertRaises(RuntimeError): medrecord.add_nodes_polars([(nodes, "index"), (second_nodes, "invalid")]) - def test_add_edge(self): - medrecord = create_medrecord() - - self.assertEqual(4, medrecord.edge_count()) - - medrecord.add_edge("0", "3", {}) - - self.assertEqual(5, medrecord.edge_count()) - - medrecord.add_edge("3", "0", {}, group="0") - - self.assertEqual(6, medrecord.edge_count()) - self.assertIn(5, medrecord.edges_in_group("0")) - - def test_invalid_add_edge(self): - medrecord = MedRecord() - - nodes = create_nodes() - - medrecord.add_nodes(nodes) - - # Adding an edge pointing to a non-existent node should fail - with self.assertRaises(IndexError): - medrecord.add_edge("0", "50", {}) - - # Adding an edge from a non-existing node should fail - with self.assertRaises(IndexError): - medrecord.add_edge("50", "0", {}) - - def test_remove_edge(self): + def test_remove_edges(self): medrecord = create_medrecord() self.assertEqual(4, medrecord.edge_count()) - attributes = medrecord.remove_edge(0) + attributes = medrecord.remove_edges(0) self.assertEqual(3, medrecord.edge_count()) self.assertEqual(create_edges()[0][2], attributes) - attributes = medrecord.remove_edge([1, 2]) + attributes = medrecord.remove_edges([1, 2]) self.assertEqual(1, medrecord.edge_count()) self.assertEqual({1: create_edges()[1][2], 2: create_edges()[2][2]}, attributes) @@ -832,17 +819,20 @@ def test_remove_edge(self): self.assertEqual(4, medrecord.edge_count()) - attributes = medrecord.remove_edge(edge_select().index().is_in([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + + attributes = medrecord.remove_edges(query) self.assertEqual(2, medrecord.edge_count()) self.assertEqual({0: create_edges()[0][2], 1: create_edges()[1][2]}, attributes) - def test_invalid_remove_edge(self): + def test_invalid_remove_edges(self): medrecord = create_medrecord() # Removing a non-existing edge should fail with self.assertRaises(IndexError): - medrecord.remove_edge(50) + medrecord.remove_edges(50) def test_add_edges(self): medrecord = MedRecord() @@ -857,8 +847,21 @@ def test_add_edges(self): self.assertEqual(4, medrecord.edge_count()) - # Adding tuple to a group + # Adding single edge tuple + medrecord = create_medrecord() + + self.assertEqual(4, medrecord.edge_count()) + + medrecord.add_edges(("0", "3", {})) + + self.assertEqual(5, medrecord.edge_count()) + + medrecord.add_edges(("3", "0", {}), group="0") + self.assertEqual(6, medrecord.edge_count()) + self.assertIn(5, medrecord.edges_in_group("0")) + + # Adding tuple to a group medrecord = MedRecord() medrecord.add_nodes(nodes) @@ -984,6 +987,21 @@ def test_add_edges(self): self.assertIn(2, medrecord.edges_in_group("0")) self.assertIn(3, medrecord.edges_in_group("0")) + def test_invalid_add_edges(self): + medrecord = MedRecord() + + nodes = create_nodes() + + medrecord.add_nodes(nodes) + + # Adding an edge pointing to a non-existent node should fail + with self.assertRaises(IndexError): + medrecord.add_edges(("0", "50", {})) + + # Adding an edge from a non-existing node should fail + with self.assertRaises(IndexError): + medrecord.add_edges(("50", "0", {})) + def test_add_edges_pandas(self): medrecord = MedRecord() @@ -1100,10 +1118,16 @@ def test_add_group(self): self.assertEqual(sorted(["0", "1"]), sorted(nodes_and_edges["nodes"])) self.assertEqual(sorted([0, 1]), sorted(nodes_and_edges["edges"])) + def query1(node: NodeOperand): + node.index().is_in(["0", "1"]) + + def query2(edge: EdgeOperand): + edge.index().is_in([0, 1]) + medrecord.add_group( "3", - node_select().index().is_in(["0", "1"]), - edge_select().index().is_in([0, 1]), + query1, + query2, ) self.assertEqual(4, medrecord.group_count()) @@ -1136,145 +1160,160 @@ def test_invalid_add_group(self): with self.assertRaises(AssertionError): medrecord.add_group("0", ["1", "0"]) + def query(node: NodeOperand): + node.index().equal_to("0") + # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_group("0", node_select().index() == "0") + medrecord.add_group("0", query) - def test_remove_group(self): + def test_remove_groups(self): medrecord = create_medrecord() medrecord.add_group("0") self.assertEqual(1, medrecord.group_count()) - medrecord.remove_group("0") + medrecord.remove_groups("0") self.assertEqual(0, medrecord.group_count()) - def test_invalid_remove_group(self): + def test_invalid_remove_groups(self): medrecord = create_medrecord() # Removing a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_group("0") + medrecord.remove_groups("0") - def test_add_node_to_group(self): + def test_add_nodes_to_group(self): medrecord = create_medrecord() medrecord.add_group("0") self.assertEqual([], medrecord.nodes_in_group("0")) - medrecord.add_node_to_group("0", "0") + medrecord.add_nodes_to_group("0", "0") self.assertEqual(["0"], medrecord.nodes_in_group("0")) - medrecord.add_node_to_group("0", ["1", "2"]) + medrecord.add_nodes_to_group("0", ["1", "2"]) self.assertEqual( sorted(["0", "1", "2"]), sorted(medrecord.nodes_in_group("0")), ) - medrecord.add_node_to_group("0", node_select().index() == "3") + def query(node: NodeOperand): + node.index().equal_to("3") + + medrecord.add_nodes_to_group("0", query) self.assertEqual( sorted(["0", "1", "2", "3"]), sorted(medrecord.nodes_in_group("0")), ) - def test_invalid_add_node_to_group(self): + def test_invalid_add_nodes_to_group(self): medrecord = create_medrecord() medrecord.add_group("0", ["0"]) # Adding to a non-existing group should fail with self.assertRaises(IndexError): - medrecord.add_node_to_group("50", "1") + medrecord.add_nodes_to_group("50", "1") # Adding to a non-existing group should fail with self.assertRaises(IndexError): - medrecord.add_node_to_group("50", ["1", "2"]) + medrecord.add_nodes_to_group("50", ["1", "2"]) # Adding a non-existing node to a group should fail with self.assertRaises(IndexError): - medrecord.add_node_to_group("0", "50") + medrecord.add_nodes_to_group("0", "50") # Adding a non-existing node to a group should fail with self.assertRaises(IndexError): - medrecord.add_node_to_group("0", ["1", "50"]) + medrecord.add_nodes_to_group("0", ["1", "50"]) # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_node_to_group("0", "0") + medrecord.add_nodes_to_group("0", "0") # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_node_to_group("0", ["1", "0"]) + medrecord.add_nodes_to_group("0", ["1", "0"]) + + def query(node: NodeOperand): + node.index().equal_to("0") # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_node_to_group("0", node_select().index() == "0") + medrecord.add_nodes_to_group("0", query) - def test_add_edge_to_group(self): + def test_add_edges_to_group(self): medrecord = create_medrecord() medrecord.add_group("0") self.assertEqual([], medrecord.edges_in_group("0")) - medrecord.add_edge_to_group("0", 0) + medrecord.add_edges_to_group("0", 0) self.assertEqual([0], medrecord.edges_in_group("0")) - medrecord.add_edge_to_group("0", [1, 2]) + medrecord.add_edges_to_group("0", [1, 2]) self.assertEqual( sorted([0, 1, 2]), sorted(medrecord.edges_in_group("0")), ) - medrecord.add_edge_to_group("0", edge_select().index() == 3) + def query(edge: EdgeOperand): + edge.index().equal_to(3) + + medrecord.add_edges_to_group("0", query) self.assertEqual( sorted([0, 1, 2, 3]), sorted(medrecord.edges_in_group("0")), ) - def test_invalid_add_edge_to_group(self): + def test_invalid_add_edges_to_group(self): medrecord = create_medrecord() medrecord.add_group("0", edges=[0]) # Adding to a non-existing group should fail with self.assertRaises(IndexError): - medrecord.add_edge_to_group("50", 1) + medrecord.add_edges_to_group("50", 1) # Adding to a non-existing group should fail with self.assertRaises(IndexError): - medrecord.add_edge_to_group("50", [1, 2]) + medrecord.add_edges_to_group("50", [1, 2]) # Adding a non-existing edge to a group should fail with self.assertRaises(IndexError): - medrecord.add_edge_to_group("0", 50) + medrecord.add_edges_to_group("0", 50) # Adding a non-existing edge to a group should fail with self.assertRaises(IndexError): - medrecord.add_edge_to_group("0", [1, 50]) + medrecord.add_edges_to_group("0", [1, 50]) # Adding an edge to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_edge_to_group("0", 0) + medrecord.add_edges_to_group("0", 0) # Adding an edge to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_edge_to_group("0", [1, 0]) + medrecord.add_edges_to_group("0", [1, 0]) + + def query(edge: EdgeOperand): + edge.index().equal_to(0) # Adding an edge to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_edge_to_group("0", edge_select().index() == 0) + medrecord.add_edges_to_group("0", query) - def test_remove_node_from_group(self): + def test_remove_nodes_from_group(self): medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) @@ -1284,58 +1323,64 @@ def test_remove_node_from_group(self): sorted(medrecord.nodes_in_group("0")), ) - medrecord.remove_node_from_group("0", "1") + medrecord.remove_nodes_from_group("0", "1") self.assertEqual(["0"], medrecord.nodes_in_group("0")) - medrecord.add_node_to_group("0", "1") + medrecord.add_nodes_to_group("0", "1") self.assertEqual( sorted(["0", "1"]), sorted(medrecord.nodes_in_group("0")), ) - medrecord.remove_node_from_group("0", ["0", "1"]) + medrecord.remove_nodes_from_group("0", ["0", "1"]) self.assertEqual([], medrecord.nodes_in_group("0")) - medrecord.add_node_to_group("0", ["0", "1"]) + medrecord.add_nodes_to_group("0", ["0", "1"]) self.assertEqual( sorted(["0", "1"]), sorted(medrecord.nodes_in_group("0")), ) - medrecord.remove_node_from_group("0", node_select().index().is_in(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + + medrecord.remove_nodes_from_group("0", query) self.assertEqual([], medrecord.nodes_in_group("0")) - def test_invalid_remove_node_from_group(self): + def test_invalid_remove_nodes_from_group(self): medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) # Removing a node from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_node_from_group("50", "0") + medrecord.remove_nodes_from_group("50", "0") # Removing a node from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_node_from_group("50", ["0", "1"]) + medrecord.remove_nodes_from_group("50", ["0", "1"]) + + def query(node: NodeOperand): + node.index().equal_to("0") # Removing a node from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_node_from_group("50", node_select().index() == "0") + medrecord.remove_nodes_from_group("50", query) # Removing a non-existing node from a group should fail with self.assertRaises(IndexError): - medrecord.remove_node_from_group("0", "50") + medrecord.remove_nodes_from_group("0", "50") # Removing a non-existing node from a group should fail with self.assertRaises(IndexError): - medrecord.remove_node_from_group("0", ["0", "50"]) + medrecord.remove_nodes_from_group("0", ["0", "50"]) - def test_remove_edge_from_group(self): + def test_remove_edges_from_group(self): medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) @@ -1345,56 +1390,62 @@ def test_remove_edge_from_group(self): sorted(medrecord.edges_in_group("0")), ) - medrecord.remove_edge_from_group("0", 1) + medrecord.remove_edges_from_group("0", 1) self.assertEqual([0], medrecord.edges_in_group("0")) - medrecord.add_edge_to_group("0", 1) + medrecord.add_edges_to_group("0", 1) self.assertEqual( sorted([0, 1]), sorted(medrecord.edges_in_group("0")), ) - medrecord.remove_edge_from_group("0", [0, 1]) + medrecord.remove_edges_from_group("0", [0, 1]) self.assertEqual([], medrecord.edges_in_group("0")) - medrecord.add_edge_to_group("0", [0, 1]) + medrecord.add_edges_to_group("0", [0, 1]) self.assertEqual( sorted([0, 1]), sorted(medrecord.edges_in_group("0")), ) - medrecord.remove_edge_from_group("0", edge_select().index().is_in([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + + medrecord.remove_edges_from_group("0", query) self.assertEqual([], medrecord.edges_in_group("0")) - def test_invalid_remove_edge_from_group(self): + def test_invalid_remove_edges_from_group(self): medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) # Removing an edge from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_edge_from_group("50", 0) + medrecord.remove_edges_from_group("50", 0) # Removing an edge from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_edge_from_group("50", [0, 1]) + medrecord.remove_edges_from_group("50", [0, 1]) + + def query(edge: EdgeOperand): + edge.index().equal_to(0) # Removing an edge from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_edge_from_group("50", edge_select().index() == 0) + medrecord.remove_edges_from_group("50", query) # Removing a non-existing edge from a group should fail with self.assertRaises(IndexError): - medrecord.remove_edge_from_group("0", 50) + medrecord.remove_edges_from_group("0", 50) # Removing a non-existing edge from a group should fail with self.assertRaises(IndexError): - medrecord.remove_edge_from_group("0", [0, 50]) + medrecord.remove_edges_from_group("0", [0, 50]) def test_nodes_in_group(self): medrecord = create_medrecord() @@ -1439,9 +1490,12 @@ def test_groups_of_node(self): self.assertEqual({"0": ["0"], "1": ["0"]}, medrecord.groups_of_node(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + self.assertEqual( {"0": ["0"], "1": ["0"]}, - medrecord.groups_of_node(node_select().index().is_in(["0", "1"])), + medrecord.groups_of_node(query), ) def test_invalid_groups_of_node(self): @@ -1464,9 +1518,12 @@ def test_groups_of_edge(self): self.assertEqual({0: ["0"], 1: ["0"]}, medrecord.groups_of_edge([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + self.assertEqual( {0: ["0"], 1: ["0"]}, - medrecord.groups_of_edge(edge_select().index().is_in([0, 1])), + medrecord.groups_of_edge(query), ) def test_invalid_groups_of_edge(self): @@ -1485,19 +1542,19 @@ def test_node_count(self): self.assertEqual(0, medrecord.node_count()) - medrecord.add_node("0", {}) + medrecord.add_nodes([("0", {})]) self.assertEqual(1, medrecord.node_count()) def test_edge_count(self): medrecord = MedRecord() - medrecord.add_node("0", {}) - medrecord.add_node("1", {}) + medrecord.add_nodes(("0", {})) + medrecord.add_nodes(("1", {})) self.assertEqual(0, medrecord.edge_count()) - medrecord.add_edge("0", "1", {}) + medrecord.add_edges(("0", "1", {})) self.assertEqual(1, medrecord.edge_count()) @@ -1550,7 +1607,10 @@ def test_neighbors(self): {key: sorted(value) for (key, value) in neighbors.items()}, ) - neighbors = medrecord.neighbors(node_select().index().is_in(["0", "1"])) + def query1(node: NodeOperand): + node.index().is_in(["0", "1"]) + + neighbors = medrecord.neighbors(query1) self.assertEqual( {"0": sorted(["1", "3"]), "1": ["0", "2"]}, @@ -1571,9 +1631,10 @@ def test_neighbors(self): {key: sorted(value) for (key, value) in neighbors.items()}, ) - neighbors = medrecord.neighbors( - node_select().index().is_in(["0", "1"]), directed=False - ) + def query2(node: NodeOperand): + node.index().is_in(["0", "1"]) + + neighbors = medrecord.neighbors(query2, directed=False) self.assertEqual( {"0": sorted(["1", "3"]), "1": ["0", "2"]}, @@ -1621,8 +1682,8 @@ def test_clone(self): self.assertEqual(medrecord.edge_count(), cloned_medrecord.edge_count()) self.assertEqual(medrecord.group_count(), cloned_medrecord.group_count()) - cloned_medrecord.add_node("new_node", {"attribute": "value"}) - cloned_medrecord.add_edge("0", "new_node", {"attribute": "value"}) + cloned_medrecord.add_nodes(("new_node", {"attribute": "value"})) + cloned_medrecord.add_edges(("0", "new_node", {"attribute": "value"})) cloned_medrecord.add_group("new_group", ["new_node"]) self.assertNotEqual(medrecord.node_count(), cloned_medrecord.node_count()) diff --git a/medmodels/medrecord/tests/test_overview.py b/medmodels/medrecord/tests/test_overview.py index 3120433d..2357c59f 100644 --- a/medmodels/medrecord/tests/test_overview.py +++ b/medmodels/medrecord/tests/test_overview.py @@ -7,7 +7,7 @@ import medmodels as mm from medmodels.medrecord._overview import extract_attribute_summary, prettify_table -from medmodels.medrecord.querying import edge, node +from medmodels.medrecord.querying import EdgeOperand, NodeOperand def create_medrecord(): @@ -70,46 +70,48 @@ def test_extract_attribute_summary(self): # medrecord without schema medrecord = create_medrecord() + def query1(node: NodeOperand): + node.in_group("Stroke") + # No attributes - no_attributes = extract_attribute_summary( - medrecord.node[node().in_group("Stroke")] - ) + no_attributes = extract_attribute_summary(medrecord.node[query1]) self.assertDictEqual(no_attributes, {}) + def query2(node: NodeOperand): + node.in_group("Patients") + # numeric type - numeric_attribute = extract_attribute_summary( - medrecord.node[node().in_group("Patients")] - ) + numeric_attribute = extract_attribute_summary(medrecord.node[query2]) numeric_expected = {"age": {"min": 20, "max": 70, "mean": 40.0}} self.assertDictEqual(numeric_attribute, numeric_expected) + def query3(node: NodeOperand): + node.in_group("Medications") + # string attributes - str_attributes = extract_attribute_summary( - medrecord.node[node().in_group("Medications")] - ) + str_attributes = extract_attribute_summary(medrecord.node[query3]) self.assertDictEqual( str_attributes, {"ATC": {"values": "Values: B01AA03, B01AF01"}} ) + def query4(node: NodeOperand): + node.in_group("Aspirin") + # nan attribute - nan_attributes = extract_attribute_summary( - medrecord.node[node().in_group("Aspirin")] - ) + nan_attributes = extract_attribute_summary(medrecord.node[query4]) + self.assertDictEqual(nan_attributes, {"ATC": {"values": "-"}}) + def query5(edge: EdgeOperand): + edge.source_node().in_group("Medications") + edge.target_node().in_group("Patients") + # temporal attributes - temp_attributes = extract_attribute_summary( - medrecord.edge[ - medrecord.select_edges( - edge().connected_source_with(node().in_group("Medications")) - & edge().connected_target_with(node().in_group("Patients")) - ) - ] - ) + temp_attributes = extract_attribute_summary(medrecord.edge[query5]) self.assertDictEqual( temp_attributes, @@ -121,14 +123,13 @@ def test_extract_attribute_summary(self): }, ) + def query6(edge: EdgeOperand): + edge.source_node().in_group("Stroke") + edge.target_node().in_group("Patients") + # mixed attributes mixed_attributes = extract_attribute_summary( - medrecord.edge[ - medrecord.select_edges( - edge().connected_source_with(node().in_group("Stroke")) - & edge().connected_target_with(node().in_group("Patients")) - ) - ] + medrecord.edge[medrecord.select_edges(query6)] ) self.assertDictEqual( mixed_attributes, @@ -158,9 +159,12 @@ def test_extract_attribute_summary(self): }, ) + def query7(edge: EdgeOperand): + edge.in_group("patient_diagnosis") + # compare schema and not schema patient_diagnosis = extract_attribute_summary( - mr_schema.edge[edge().in_group("patient_diagnosis")], + mr_schema.edge[query7], schema=mr_schema.schema.group("patient_diagnosis").edges, ) diff --git a/medmodels/medrecord/tests/test_querying.py b/medmodels/medrecord/tests/test_querying.py deleted file mode 100644 index 808961ab..00000000 --- a/medmodels/medrecord/tests/test_querying.py +++ /dev/null @@ -1,1170 +0,0 @@ -import unittest -from typing import List, Tuple - -from medmodels import MedRecord -from medmodels.medrecord import edge, node -from medmodels.medrecord.types import Attributes, NodeIndex - - -def create_nodes() -> List[Tuple[NodeIndex, Attributes]]: - return [ - ( - "0", - { - "lorem": "ipsum", - "dolor": " ipsum ", - "test": "Ipsum", - "integer": 1, - "float": 0.5, - }, - ), - ("1", {"amet": "consectetur"}), - ("2", {"adipiscing": "elit"}), - ("3", {}), - ] - - -def create_edges() -> List[Tuple[NodeIndex, NodeIndex, Attributes]]: - return [ - ("0", "1", {"sed": "do", "eiusmod": "tempor", "dolor": " do ", "test": "DO"}), - ("1", "2", {"incididunt": "ut"}), - ("0", "2", {"test": 1, "integer": 1, "float": 0.5}), - ("0", "2", {"test": 0}), - ] - - -def create_medrecord() -> MedRecord: - return MedRecord.from_tuples(create_nodes(), create_edges()) - - -class TestMedRecord(unittest.TestCase): - def test_select_nodes_node(self): - medrecord = create_medrecord() - - medrecord.add_group("test", ["0"]) - - # Node in group - self.assertEqual(["0"], medrecord.select_nodes(node().in_group("test"))) - - # Node has attribute - self.assertEqual(["0"], medrecord.select_nodes(node().has_attribute("lorem"))) - - # Node has outgoing edge with - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().has_outgoing_edge_with(edge().index().equal(0)) - ), - ) - - # Node has incoming edge with - self.assertEqual( - ["1"], - medrecord.select_nodes( - node().has_incoming_edge_with(edge().index().equal(0)) - ), - ) - - # Node has edge with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes(node().has_edge_with(edge().index().equal(0))) - ), - ) - - # Node has neighbor with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("2")) - ) - ), - ) - self.assertEqual( - sorted(["0"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=True) - ) - ), - ) - - # Node has neighbor with - self.assertEqual( - sorted(["0", "2"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=False) - ) - ), - ) - - def test_select_nodes_node_index(self): - medrecord = create_medrecord() - - # Index greater - self.assertEqual( - sorted(["2", "3"]), - sorted(medrecord.select_nodes(node().index().greater("1"))), - ) - - # Index less - self.assertEqual( - sorted(["0", "1"]), sorted(medrecord.select_nodes(node().index().less("2"))) - ) - - # Index greater or equal - self.assertEqual( - sorted(["1", "2", "3"]), - sorted(medrecord.select_nodes(node().index().greater_or_equal("1"))), - ) - - # Index less or equal - self.assertEqual( - sorted(["0", "1", "2"]), - sorted(medrecord.select_nodes(node().index().less_or_equal("2"))), - ) - - # Index equal - self.assertEqual(["1"], medrecord.select_nodes(node().index().equal("1"))) - - # Index not equal - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_equal("1"))), - ) - - # Index in - self.assertEqual(["1"], medrecord.select_nodes(node().index().is_in(["1"]))) - - # Index not in - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_in(["1"]))), - ) - - # Index starts with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().starts_with("1")), - ) - - # Index ends with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().ends_with("1")), - ) - - # Index contains - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().contains("1")), - ) - - def test_select_nodes_node_attribute(self): - medrecord = create_medrecord() - - # Attribute greater - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").greater("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") > "ipsum") - ) - - # Attribute less - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").less("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") < "ipsum") - ) - - # Attribute greater or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") >= "ipsum") - ) - - # Attribute less or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").less_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") <= "ipsum") - ) - - # Attribute equal - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").equal("ipsum")) - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") == "ipsum") - ) - - # Attribute not equal - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") != "ipsum") - ) - - # Attribute in - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) - ) - - # Attribute not in - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) - ) - - # Attribute starts with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").starts_with("ip")), - ) - - # Attribute ends with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").ends_with("um")), - ) - - # Attribute contains - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").contains("su")), - ) - - # Attribute compare to attribute - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem")) - ), - ) - - # Attribute compare to attribute add - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") + "10" - ), - ) - - # Attribute compare to attribute sub - # Returns nothing because can't sub a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") - "10" - ), - ) - - # Attribute compare to attribute sub - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ), - ) - - # Attribute compare to attribute mul - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") * 2 - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") * 2 - ), - ) - - # Attribute compare to attribute div - # Returns nothing because can't div a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") / "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") / "10" - ), - ) - - # Attribute compare to attribute div - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").div(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ), - ) - - # Attribute compare to attribute pow - # Returns nothing because can't pow a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") ** "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") ** "10" - ), - ) - - # Attribute compare to attribute pow - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").pow(2)) - ), - ) - - # Attribute compare to attribute mod - # Returns nothing because can't mod a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") % "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") % "10" - ), - ) - - # Attribute compare to attribute mod - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").mod(2)) - ), - ) - - # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").round()) - ), - ) - - # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").ceil()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").ceil()) - ), - ) - - # Attribute compare to attribute floor - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").floor()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").floor()) - ), - ) - - # Attribute compare to attribute abs - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").not_equal(node().attribute("integer").abs()) - ), - ) - - # Attribute compare to attribute sqrt - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sqrt()) - ), - ) - - # Attribute compare to attribute trim - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim()) - ), - ) - - # Attribute compare to attribute trim_start - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim_start()) - ), - ) - - # Attribute compare to attribute trim_end - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim_end()) - ), - ) - - # Attribute compare to attribute lowercase - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").lowercase()) - ), - ) - - # Attribute compare to attribute uppercase - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").uppercase()) - ), - ) - - def test_select_edges_edge(self): - medrecord = create_medrecord() - - medrecord.add_group("test", edges=[0]) - - # Edge connected to target - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().connected_target("2"))), - ) - - # Edge connected to source - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().connected_source("0"))), - ) - - # Edge connected - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().connected("1"))), - ) - - # Edge in group - self.assertEqual( - [0], - medrecord.select_edges(edge().in_group("test")), - ) - - # Edge has attribute - self.assertEqual( - [0], - medrecord.select_edges(edge().has_attribute("sed")), - ) - - # Edge connected to target with - self.assertEqual( - [0], - medrecord.select_edges( - edge().connected_target_with(node().index().equal("1")) - ), - ) - - # Edge connected to source with - self.assertEqual( - sorted([0, 2, 3]), - sorted( - medrecord.select_edges( - edge().connected_source_with(node().index().equal("0")) - ) - ), - ) - - # Edge connected with - self.assertEqual( - sorted([0, 1]), - sorted( - medrecord.select_edges(edge().connected_with(node().index().equal("1"))) - ), - ) - - # Edge has parallel edges with - self.assertEqual( - sorted([2, 3]), - sorted( - medrecord.select_edges( - edge().has_parallel_edges_with(edge().has_attribute("test")) - ) - ), - ) - - # Edge has parallel edges with self comparison - self.assertEqual( - [2], - medrecord.select_edges( - edge().has_parallel_edges_with_self_comparison( - edge().attribute("test").equal(edge().attribute("test").sub(1)) - ) - ), - ) - - def test_select_edges_edge_index(self): - medrecord = create_medrecord() - - # Index greater - self.assertEqual( - sorted([2, 3]), - sorted(medrecord.select_edges(edge().index().greater(1))), - ) - - # Index less - self.assertEqual( - [0], - medrecord.select_edges(edge().index().less(1)), - ) - - # Index greater or equal - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().index().greater_or_equal(1))), - ) - - # Index less or equal - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().index().less_or_equal(1))), - ) - - # Index equal - self.assertEqual( - [1], - medrecord.select_edges(edge().index().equal(1)), - ) - - # Index not equal - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_equal(1))), - ) - - # Index in - self.assertEqual( - [1], - medrecord.select_edges(edge().index().is_in([1])), - ) - - # Index not in - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_in([1]))), - ) - - def test_select_edges_edges_attribute(self): - medrecord = create_medrecord() - - # Attribute greater - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").greater("do")), - ) - - # Attribute less - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").less("do")), - ) - - # Attribute greater or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").greater_or_equal("do")), - ) - - # Attribute less or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").less_or_equal("do")), - ) - - # Attribute equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").equal("do")), - ) - - # Attribute not equal - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_equal("do")), - ) - - # Attribute in - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").is_in(["do"])), - ) - - # Attribute not in - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_in(["do"])), - ) - - # Attribute starts with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").starts_with("d")), - ) - - # Attribute ends with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").ends_with("o")), - ) - - # Attribute contains - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").contains("d")), - ) - - # Attribute compare to attribute - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed")) - ), - ) - - # Attribute compare to attribute add - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") + "10" - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") + "10" - ), - ) - - # Attribute compare to attribute sub - # Returns nothing because can't sub a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") - "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") - "10" - ), - ) - - # Attribute compare to attribute sub - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ), - ) - - # Attribute compare to attribute mul - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") * 2 - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") * 2 - ), - ) - - # Attribute compare to attribute div - # Returns nothing because can't div a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") / "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") / "10" - ), - ) - - # Attribute compare to attribute div - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").div(2)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ), - ) - - # Attribute compare to attribute pow - # Returns nothing because can't pow a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").equal(edge().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") == edge().attribute("lorem") ** "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") != edge().attribute("lorem") ** "10" - ), - ) - - # Attribute compare to attribute pow - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").pow(2)) - ), - ) - - # Attribute compare to attribute mod - # Returns nothing because can't mod a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").equal(edge().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") == edge().attribute("lorem") % "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") != edge().attribute("lorem") % "10" - ), - ) - - # Attribute compare to attribute mod - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").mod(2)) - ), - ) - - # Attribute compare to attribute round - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").round()) - ), - ) - - # Attribute compare to attribute ceil - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").ceil()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").ceil()) - ), - ) - - # Attribute compare to attribute floor - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").floor()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").floor()) - ), - ) - - # Attribute compare to attribute abs - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").not_equal(edge().attribute("integer").abs()) - ), - ) - - # Attribute compare to attribute sqrt - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sqrt()) - ), - ) - - # Attribute compare to attribute trim - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim()) - ), - ) - - # Attribute compare to attribute trim_start - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim_start()) - ), - ) - - # Attribute compare to attribute trim_end - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim_end()) - ), - ) - - # Attribute compare to attribute lowercase - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").lowercase()) - ), - ) - - # Attribute compare to attribute uppercase - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").uppercase()) - ), - ) diff --git a/medmodels/medrecord/types.py b/medmodels/medrecord/types.py index ac46ef93..9c017da9 100644 --- a/medmodels/medrecord/types.py +++ b/medmodels/medrecord/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Mapping, Tuple, TypedDict, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Sequence, Tuple, TypedDict, Union import pandas as pd import polars as pl @@ -92,6 +92,26 @@ "TemporalAttributeInfo", "NumericAttributeInfo", "StringAttributeInfo" ] +#: A type alias for input to a node. +NodeInput = Union[ + NodeTuple, + Sequence[NodeTuple], + PandasNodeDataFrameInput, + List[PandasNodeDataFrameInput], + PolarsNodeDataFrameInput, + List[PolarsNodeDataFrameInput], +] + +#: A type alias for input to an edge. +EdgeInput = Union[ + EdgeTuple, + Sequence[EdgeTuple], + PandasEdgeDataFrameInput, + List[PandasEdgeDataFrameInput], + PolarsEdgeDataFrameInput, + List[PolarsEdgeDataFrameInput], +] + class GroupInfo(TypedDict): """A dictionary containing lists of node and edge indices for a group.""" diff --git a/medmodels/statistic_evaluations/evaluate_compare/compare.pyi b/medmodels/statistic_evaluations/evaluate_compare/compare.pyi index cf38c4c1..0b36ae2f 100644 --- a/medmodels/statistic_evaluations/evaluate_compare/compare.pyi +++ b/medmodels/statistic_evaluations/evaluate_compare/compare.pyi @@ -25,13 +25,15 @@ class DistanceSummary(TypedDict): distance: float class ComparerSummary(TypedDict): - """Dictionary for the comparing results.""" + """Dictionary for comparing results.""" attribute_tests: Dict[MedRecordAttribute, List[TestSummary]] concepts_tests: Dict[Group, List[TestSummary]] concepts_distance: Dict[Group, DistanceSummary] class TestSummary(TypedDict): + """Dictionary for hypothesis test results.""" + test: str Hypothesis: str not_reject: bool diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index 9a6e5be6..15ceb4c7 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional import medmodels.treatment_effect.treatment_effect as tee -from medmodels.medrecord.querying import NodeOperation +from medmodels.medrecord.querying import NodeQuery from medmodels.medrecord.types import ( Group, MedRecordAttribute, @@ -31,7 +31,7 @@ class TreatmentEffectBuilder: outcome_before_treatment_days: Optional[int] - filter_controls_operation: Optional[NodeOperation] + filter_controls_query: Optional[NodeQuery] matching_method: Optional[MatchingMethod] matching_essential_covariates: Optional[MedRecordAttributeInputList] @@ -202,17 +202,17 @@ def with_outcome_before_treatment_exclusion( return self - def filter_controls(self, operation: NodeOperation) -> TreatmentEffectBuilder: - """Filter the control group based on the provided operation. + def filter_controls(self, query: NodeQuery) -> TreatmentEffectBuilder: + """Filter the control group based on the provided query. Args: - operation (NodeOperation): The operation to be applied to the control group. + query (NodeQuery): The query to be applied to the control group. Returns: TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder with updated time attribute. """ - self.filter_controls_operation = operation + self.filter_controls_query = query return self diff --git a/medmodels/treatment_effect/tests/test_temporal_analysis.py b/medmodels/treatment_effect/tests/test_temporal_analysis.py index 9a118754..41121f5d 100644 --- a/medmodels/treatment_effect/tests/test_temporal_analysis.py +++ b/medmodels/treatment_effect/tests/test_temporal_analysis.py @@ -150,9 +150,7 @@ def test_find_reference_time(self): self.assertEqual(0, edge) # adding medication time - self.medrecord.add_edge( - source_node="M1", target_node="P1", attributes={"time": "2000-01-15"} - ) + self.medrecord.add_edges(("M1", "P1", {"time": "2000-01-15"})) edge = find_reference_edge( self.medrecord, diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index aca74858..0a830b63 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -6,7 +6,7 @@ import pandas as pd from medmodels import MedRecord -from medmodels.medrecord import edge, node +from medmodels.medrecord.querying import EdgeDirection, NodeOperand from medmodels.medrecord.types import NodeIndex from medmodels.treatment_effect.estimate import ContingencyTable, SubjectIndices from medmodels.treatment_effect.treatment_effect import TreatmentEffect @@ -245,8 +245,8 @@ def assert_treatment_effects_equal( treatment_effect2._outcome_before_treatment_days, ) test_case.assertEqual( - treatment_effect1._filter_controls_operation, - treatment_effect2._filter_controls_operation, + treatment_effect1._filter_controls_query, + treatment_effect2._filter_controls_query, ) test_case.assertEqual( treatment_effect1._matching_method, treatment_effect2._matching_method @@ -620,14 +620,14 @@ def test_outcome_before_treatment(self): tee3._find_outcomes(medrecord=self.medrecord, treated_group=treated_group) def test_filter_controls(self): + def query1(node: NodeOperand): + node.neighbors(EdgeDirection.BOTH).index().equal_to("M2") + tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls( - node().has_outgoing_edge_with(edge().connected_target("M2")) - | node().has_incoming_edge_with(edge().connected_source("M2")) - ) + .filter_controls(query1) .build() ) counts_tee = tee.estimate._compute_subject_counts(self.medrecord) @@ -635,11 +635,15 @@ def test_filter_controls(self): self.assertEqual(counts_tee, (2, 1, 1, 2)) # filter females only + + def query2(node: NodeOperand): + node.attribute("gender").equal_to("female") + tee2 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls(node().attribute("gender").equal("female")) + .filter_controls(query2) .build() ) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 94a3980e..6f22c8a9 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -14,8 +14,7 @@ from typing import Any, Dict, Literal, Optional, Set, Tuple from medmodels import MedRecord -from medmodels.medrecord import node -from medmodels.medrecord.querying import NodeOperation +from medmodels.medrecord.querying import EdgeDirection, NodeOperand, NodeQuery from medmodels.medrecord.types import ( Group, MedRecordAttribute, @@ -50,7 +49,7 @@ class TreatmentEffect: _outcome_before_treatment_days: Optional[int] - _filter_controls_operation: Optional[NodeOperation] + _filter_controls_query: Optional[NodeQuery] _matching_method: Optional[MatchingMethod] _matching_essential_covariates: MedRecordAttributeInputList @@ -92,7 +91,7 @@ def _set_configuration( follow_up_period_days: int = 365, follow_up_period_reference: Literal["first", "last"] = "last", outcome_before_treatment_days: Optional[int] = None, - filter_controls_operation: Optional[NodeOperation] = None, + filter_controls_query: Optional[NodeQuery] = None, matching_method: Optional[MatchingMethod] = None, matching_essential_covariates: MedRecordAttributeInputList = ["gender", "age"], matching_one_hot_covariates: MedRecordAttributeInputList = ["gender"], @@ -127,8 +126,8 @@ def _set_configuration( reference point for the follow-up period. Defaults to "last". outcome_before_treatment_days (Optional[int], optional): The number of days before the treatment to consider for outcomes. Defaults to None. - filter_controls_operation (Optional[NodeOperation], optional): An optional - operation to filter the control group based on specified criteria. + filter_controls_query (Optional[NodeQuery], optional): An optional + query to filter the control group based on specified criteria. Defaults to None. matching_method (Optional[MatchingMethod]): The method to match treatment and control groups. Defaults to None. @@ -158,7 +157,7 @@ def _set_configuration( treatment_effect._follow_up_period_days = follow_up_period_days treatment_effect._follow_up_period_reference = follow_up_period_reference treatment_effect._outcome_before_treatment_days = outcome_before_treatment_days - treatment_effect._filter_controls_operation = filter_controls_operation + treatment_effect._filter_controls_query = filter_controls_query treatment_effect._matching_method = matching_method treatment_effect._matching_essential_covariates = matching_essential_covariates @@ -206,7 +205,7 @@ def _find_groups( control_group=control_group, treated_group=treated_group, rejected_nodes=washout_nodes | outcome_before_treatment_nodes, - filter_controls_operation=self._filter_controls_operation, + filter_controls_query=self._filter_controls_query, ) return ( @@ -234,18 +233,16 @@ def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]: treatments = medrecord.nodes_in_group(self._treatments_group) + def query(node: NodeOperand): + node.in_group(self._patients_group) + + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equal_to( + treatment + ) + # Create the group with all the patients that underwent the treatment for treatment in treatments: - treated_group.update( - set( - medrecord.select_nodes( - node().in_group(self._patients_group) - & node().has_neighbor_with( - node().index() == treatment, directed=False - ) - ) - ) - ) + treated_group.update(set(medrecord.select_nodes(query))) if not treated_group: raise ValueError( "No patients found for the treatment groups in this MedRecord." @@ -288,14 +285,14 @@ def _find_outcomes( f"No outcomes found in the MedRecord for group {self._outcomes_group}" ) + def query(node: NodeOperand): + node.index().is_in(list(treated_group)) + + # This could probably be refactored to a proper query + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equal_to(outcome) + for outcome in outcomes: - nodes_to_check = set( - medrecord.select_nodes( - node().has_neighbor_with(node().index() == outcome, directed=False) - # This could probably be refactored to a proper query - & node().index().is_in(list(treated_group)) - ) - ) + nodes_to_check = set(medrecord.select_nodes(query)) # Find patients that had the outcome before the treatment if self._outcome_before_treatment_days: @@ -399,12 +396,12 @@ def _find_controls( control_group: Set[NodeIndex], treated_group: Set[NodeIndex], rejected_nodes: Set[NodeIndex] = set(), - filter_controls_operation: Optional[NodeOperation] = None, + filter_controls_query: Optional[NodeQuery] = None, ) -> Tuple[Set[NodeIndex], Set[NodeIndex]]: """Identifies control groups among patients who did not undergo the specified treatments. It takes the control group and removes the rejected nodes, the treated nodes, - and applies the filter_controls_operation if specified. + and applies the filter_controls_query if specified. Control groups are divided into those who had the outcome (control_outcome_true) and those who did not (control_outcome_false), @@ -419,8 +416,8 @@ def _find_controls( treatment. rejected_nodes (Set[NodeIndex]): A set of patient nodes that were rejected due to the washout period or outcome before treatment. - filter_controls_operation (Optional[NodeOperation], optional): An optional - operation to filter the control group based on specified criteria. + filter_controls_query (Optional[NodeQuery], optional): An optional + query to filter the control group based on specified criteria. Defaults to None. Returns: @@ -436,9 +433,9 @@ def _find_controls( outcome group. """ # Apply the filter to the control group if specified - if filter_controls_operation: + if filter_controls_query: control_group = ( - set(medrecord.select_nodes(filter_controls_operation)) & control_group + set(medrecord.select_nodes(filter_controls_query)) & control_group ) control_group = control_group - treated_group - rejected_nodes @@ -453,17 +450,15 @@ def _find_controls( f"No outcomes found in the MedRecord for group {self._outcomes_group}" ) + def query(node: NodeOperand): + node.index().is_in(list(control_group)) + + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equal_to(outcome) + # Finding the patients that had the outcome in the control group for outcome in outcomes: - control_outcome_true.update( - medrecord.select_nodes( - # This could probably be refactored to a proper query - node().index().is_in(list(control_group)) - & node().has_neighbor_with( - node().index() == outcome, directed=False - ) - ) - ) + control_outcome_true.update(medrecord.select_nodes(query)) + control_outcome_false = control_group - control_outcome_true return control_outcome_true, control_outcome_false diff --git a/rustmodels/Cargo.toml b/rustmodels/Cargo.toml index a6640f90..70922829 100644 --- a/rustmodels/Cargo.toml +++ b/rustmodels/Cargo.toml @@ -11,7 +11,7 @@ crate-type = ["cdylib"] medmodels-core = { workspace = true } medmodels-utils = { workspace = true } -pyo3 = { workspace = true } -pyo3-polars = { workspace = true } +pyo3 = { version = "0.21.2", features = ["chrono"] } +pyo3-polars = "0.14.0" polars = { workspace = true } chrono = { workspace = true } diff --git a/rustmodels/src/lib.rs b/rustmodels/src/lib.rs index 751ce907..2584bae3 100644 --- a/rustmodels/src/lib.rs +++ b/rustmodels/src/lib.rs @@ -4,9 +4,12 @@ mod medrecord; use medrecord::{ datatype::{PyAny, PyBool, PyDateTime, PyFloat, PyInt, PyNull, PyOption, PyString, PyUnion}, querying::{ - PyEdgeAttributeOperand, PyEdgeIndexOperand, PyEdgeOperand, PyEdgeOperation, - PyNodeAttributeOperand, PyNodeIndexOperand, PyNodeOperand, PyNodeOperation, - PyValueArithmeticOperation, PyValueTransformationOperation, + attributes::{ + PyAttributesTreeOperand, PyMultipleAttributesOperand, PySingleAttributeOperand, + }, + edges::{PyEdgeIndexOperand, PyEdgeIndicesOperand, PyEdgeOperand}, + nodes::{PyEdgeDirection, PyNodeIndexOperand, PyNodeIndicesOperand, PyNodeOperand}, + values::{PyMultipleValuesOperand, PySingleValueOperand}, }, schema::{PyAttributeDataType, PyAttributeType, PyGroupSchema, PySchema}, PyMedRecord, @@ -32,20 +35,22 @@ fn _medmodels(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - - m.add_class::()?; - m.add_class::()?; - - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + m.add_class::()?; + m.add_class::()?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rustmodels/src/medrecord/attribute.rs b/rustmodels/src/medrecord/attribute.rs index 9615dc6f..cc9fe513 100644 --- a/rustmodels/src/medrecord/attribute.rs +++ b/rustmodels/src/medrecord/attribute.rs @@ -6,7 +6,7 @@ use std::{hash::Hash, ops::Deref}; #[repr(transparent)] #[derive(PartialEq, Eq, Hash, Clone, Debug)] -pub(crate) struct PyMedRecordAttribute(MedRecordAttribute); +pub struct PyMedRecordAttribute(MedRecordAttribute); impl From for PyMedRecordAttribute { fn from(value: MedRecordAttribute) -> Self { diff --git a/rustmodels/src/medrecord/errors.rs b/rustmodels/src/medrecord/errors.rs index 6965791e..f96f9a32 100644 --- a/rustmodels/src/medrecord/errors.rs +++ b/rustmodels/src/medrecord/errors.rs @@ -21,6 +21,7 @@ impl From for PyErr { MedRecordError::ConversionError(message) => PyRuntimeError::new_err(message), MedRecordError::AssertionError(message) => PyAssertionError::new_err(message), MedRecordError::SchemaError(message) => PyValueError::new_err(message), + MedRecordError::QueryError(message) => PyRuntimeError::new_err(message), } } } diff --git a/rustmodels/src/medrecord/mod.rs b/rustmodels/src/medrecord/mod.rs index c5cf687a..b6bba790 100644 --- a/rustmodels/src/medrecord/mod.rs +++ b/rustmodels/src/medrecord/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::new_without_default)] + mod attribute; pub mod datatype; mod errors; @@ -13,9 +15,9 @@ use medmodels_core::{ errors::MedRecordError, medrecord::{Attributes, EdgeIndex, MedRecord, MedRecordAttribute, MedRecordValue}, }; -use pyo3::prelude::*; +use pyo3::{prelude::*, types::PyFunction}; use pyo3_polars::PyDataFrame; -use querying::{PyEdgeOperation, PyNodeOperation}; +use querying::{edges::PyEdgeOperand, nodes::PyNodeOperand}; use schema::PySchema; use std::collections::HashMap; use traits::DeepInto; @@ -33,17 +35,17 @@ pub struct PyMedRecord(MedRecord); #[pymethods] impl PyMedRecord { #[new] - fn new() -> Self { + pub fn new() -> Self { Self(MedRecord::new()) } #[staticmethod] - fn with_schema(schema: PySchema) -> Self { + pub fn with_schema(schema: PySchema) -> Self { Self(MedRecord::with_schema(schema.into())) } #[staticmethod] - fn from_tuples( + pub fn from_tuples( nodes: Vec<(PyNodeIndex, PyAttributes)>, edges: Option>, ) -> PyResult { @@ -54,7 +56,7 @@ impl PyMedRecord { } #[staticmethod] - fn from_dataframes( + pub fn from_dataframes( nodes_dataframes: Vec<(PyDataFrame, String)>, edges_dataframes: Vec<(PyDataFrame, String, String)>, ) -> PyResult { @@ -65,7 +67,7 @@ impl PyMedRecord { } #[staticmethod] - fn from_nodes_dataframes(nodes_dataframes: Vec<(PyDataFrame, String)>) -> PyResult { + pub fn from_nodes_dataframes(nodes_dataframes: Vec<(PyDataFrame, String)>) -> PyResult { Ok(Self( MedRecord::from_nodes_dataframes(nodes_dataframes, None) .map_err(PyMedRecordError::from)?, @@ -73,22 +75,22 @@ impl PyMedRecord { } #[staticmethod] - fn from_example_dataset() -> Self { + pub fn from_example_dataset() -> Self { Self(MedRecord::from_example_dataset()) } #[staticmethod] - fn from_ron(path: &str) -> PyResult { + pub fn from_ron(path: &str) -> PyResult { Ok(Self( MedRecord::from_ron(path).map_err(PyMedRecordError::from)?, )) } - fn to_ron(&self, path: &str) -> PyResult<()> { + pub fn to_ron(&self, path: &str) -> PyResult<()> { Ok(self.0.to_ron(path).map_err(PyMedRecordError::from)?) } - fn update_schema(&mut self, schema: PySchema) -> PyResult<()> { + pub fn update_schema(&mut self, schema: PySchema) -> PyResult<()> { Ok(self .0 .update_schema(schema.into()) @@ -96,19 +98,22 @@ impl PyMedRecord { } #[getter] - fn schema(&self) -> PySchema { + pub fn schema(&self) -> PySchema { self.0.get_schema().clone().into() } #[getter] - fn nodes(&self) -> Vec { + pub fn nodes(&self) -> Vec { self.0 .node_indices() .map(|node_index| node_index.clone().into()) .collect() } - fn node(&self, node_index: Vec) -> PyResult> { + pub fn node( + &self, + node_index: Vec, + ) -> PyResult> { node_index .into_iter() .map(|node_index| { @@ -123,11 +128,11 @@ impl PyMedRecord { } #[getter] - fn edges(&self) -> Vec { + pub fn edges(&self) -> Vec { self.0.edge_indices().copied().collect() } - fn edge(&self, edge_index: Vec) -> PyResult> { + pub fn edge(&self, edge_index: Vec) -> PyResult> { edge_index .into_iter() .map(|edge_index| { @@ -142,11 +147,11 @@ impl PyMedRecord { } #[getter] - fn groups(&self) -> Vec { + pub fn groups(&self) -> Vec { self.0.groups().map(|group| group.clone().into()).collect() } - fn outgoing_edges( + pub fn outgoing_edges( &self, node_index: Vec, ) -> PyResult>> { @@ -165,7 +170,7 @@ impl PyMedRecord { .collect() } - fn incoming_edges( + pub fn incoming_edges( &self, node_index: Vec, ) -> PyResult>> { @@ -184,7 +189,7 @@ impl PyMedRecord { .collect() } - fn edge_endpoints( + pub fn edge_endpoints( &self, edge_index: Vec, ) -> PyResult> { @@ -207,7 +212,7 @@ impl PyMedRecord { .collect() } - fn edges_connecting( + pub fn edges_connecting( &self, source_node_indices: Vec, target_node_indices: Vec, @@ -224,7 +229,7 @@ impl PyMedRecord { .collect() } - fn edges_connecting_undirected( + pub fn edges_connecting_undirected( &self, first_node_indices: Vec, second_node_indices: Vec, @@ -241,14 +246,7 @@ impl PyMedRecord { .collect() } - fn add_node(&mut self, node_index: PyNodeIndex, attributes: PyAttributes) -> PyResult<()> { - Ok(self - .0 - .add_node(node_index.into(), attributes.deep_into()) - .map_err(PyMedRecordError::from)?) - } - - fn remove_node( + pub fn remove_nodes( &mut self, node_index: Vec, ) -> PyResult> { @@ -265,7 +263,7 @@ impl PyMedRecord { .collect() } - fn replace_node_attributes( + pub fn replace_node_attributes( &mut self, node_index: Vec, attributes: PyAttributes, @@ -284,7 +282,7 @@ impl PyMedRecord { Ok(()) } - fn update_node_attribute( + pub fn update_node_attribute( &mut self, node_index: Vec, attribute: PyMedRecordAttribute, @@ -308,7 +306,7 @@ impl PyMedRecord { Ok(()) } - fn remove_node_attribute( + pub fn remove_node_attribute( &mut self, node_index: Vec, attribute: PyMedRecordAttribute, @@ -332,14 +330,14 @@ impl PyMedRecord { Ok(()) } - fn add_nodes(&mut self, nodes: Vec<(PyNodeIndex, PyAttributes)>) -> PyResult<()> { + pub fn add_nodes(&mut self, nodes: Vec<(PyNodeIndex, PyAttributes)>) -> PyResult<()> { Ok(self .0 .add_nodes(nodes.deep_into()) .map_err(PyMedRecordError::from)?) } - fn add_nodes_dataframes( + pub fn add_nodes_dataframes( &mut self, nodes_dataframes: Vec<(PyDataFrame, String)>, ) -> PyResult<()> { @@ -349,23 +347,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn add_edge( - &mut self, - source_node_index: PyNodeIndex, - target_node_index: PyNodeIndex, - attributes: PyAttributes, - ) -> PyResult { - Ok(self - .0 - .add_edge( - source_node_index.into(), - target_node_index.into(), - attributes.deep_into(), - ) - .map_err(PyMedRecordError::from)?) - } - - fn remove_edge( + pub fn remove_edges( &mut self, edge_index: Vec, ) -> PyResult> { @@ -382,7 +364,7 @@ impl PyMedRecord { .collect() } - fn replace_edge_attributes( + pub fn replace_edge_attributes( &mut self, edge_index: Vec, attributes: PyAttributes, @@ -401,7 +383,7 @@ impl PyMedRecord { Ok(()) } - fn update_edge_attribute( + pub fn update_edge_attribute( &mut self, edge_index: Vec, attribute: PyMedRecordAttribute, @@ -422,7 +404,7 @@ impl PyMedRecord { Ok(()) } - fn remove_edge_attribute( + pub fn remove_edge_attribute( &mut self, edge_index: Vec, attribute: PyMedRecordAttribute, @@ -444,7 +426,7 @@ impl PyMedRecord { Ok(()) } - fn add_edges( + pub fn add_edges( &mut self, relations: Vec<(PyNodeIndex, PyNodeIndex, PyAttributes)>, ) -> PyResult> { @@ -454,7 +436,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn add_edges_dataframes( + pub fn add_edges_dataframes( &mut self, edges_dataframes: Vec<(PyDataFrame, String, String)>, ) -> PyResult> { @@ -464,7 +446,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn add_group( + pub fn add_group( &mut self, group: PyGroup, node_indices_to_add: Option>, @@ -480,7 +462,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn remove_group(&mut self, group: Vec) -> PyResult<()> { + pub fn remove_groups(&mut self, group: Vec) -> PyResult<()> { group.into_iter().try_for_each(|group| { self.0 .remove_group(&group) @@ -490,7 +472,11 @@ impl PyMedRecord { }) } - fn add_node_to_group(&mut self, group: PyGroup, node_index: Vec) -> PyResult<()> { + pub fn add_nodes_to_group( + &mut self, + group: PyGroup, + node_index: Vec, + ) -> PyResult<()> { node_index.into_iter().try_for_each(|node_index| { Ok(self .0 @@ -499,7 +485,11 @@ impl PyMedRecord { }) } - fn add_edge_to_group(&mut self, group: PyGroup, edge_index: Vec) -> PyResult<()> { + pub fn add_edges_to_group( + &mut self, + group: PyGroup, + edge_index: Vec, + ) -> PyResult<()> { edge_index.into_iter().try_for_each(|edge_index| { Ok(self .0 @@ -508,7 +498,7 @@ impl PyMedRecord { }) } - fn remove_node_from_group( + pub fn remove_nodes_from_group( &mut self, group: PyGroup, node_index: Vec, @@ -521,7 +511,7 @@ impl PyMedRecord { }) } - fn remove_edge_from_group( + pub fn remove_edges_from_group( &mut self, group: PyGroup, edge_index: Vec, @@ -534,7 +524,10 @@ impl PyMedRecord { }) } - fn nodes_in_group(&self, group: Vec) -> PyResult>> { + pub fn nodes_in_group( + &self, + group: Vec, + ) -> PyResult>> { group .into_iter() .map(|group| { @@ -550,7 +543,10 @@ impl PyMedRecord { .collect() } - fn edges_in_group(&self, group: Vec) -> PyResult>> { + pub fn edges_in_group( + &self, + group: Vec, + ) -> PyResult>> { group .into_iter() .map(|group| { @@ -566,7 +562,7 @@ impl PyMedRecord { .collect() } - fn groups_of_node( + pub fn groups_of_node( &self, node_index: Vec, ) -> PyResult>> { @@ -585,7 +581,7 @@ impl PyMedRecord { .collect() } - fn groups_of_edge( + pub fn groups_of_edge( &self, edge_index: Vec, ) -> PyResult>> { @@ -604,31 +600,31 @@ impl PyMedRecord { .collect() } - fn node_count(&self) -> usize { + pub fn node_count(&self) -> usize { self.0.node_count() } - fn edge_count(&self) -> usize { + pub fn edge_count(&self) -> usize { self.0.edge_count() } - fn group_count(&self) -> usize { + pub fn group_count(&self) -> usize { self.0.group_count() } - fn contains_node(&self, node_index: PyNodeIndex) -> bool { + pub fn contains_node(&self, node_index: PyNodeIndex) -> bool { self.0.contains_node(&node_index.into()) } - fn contains_edge(&self, edge_index: EdgeIndex) -> bool { + pub fn contains_edge(&self, edge_index: EdgeIndex) -> bool { self.0.contains_edge(&edge_index) } - fn contains_group(&self, group: PyGroup) -> bool { + pub fn contains_group(&self, group: PyGroup) -> bool { self.0.contains_group(&group.into()) } - fn neighbors( + pub fn neighbors( &self, node_indices: Vec, ) -> PyResult>> { @@ -637,7 +633,7 @@ impl PyMedRecord { .map(|node_index| { let neighbors = self .0 - .neighbors(&node_index) + .neighbors_outgoing(&node_index) .map_err(PyMedRecordError::from)? .map(|neighbor| neighbor.clone().into()) .collect(); @@ -647,7 +643,7 @@ impl PyMedRecord { .collect() } - fn neighbors_undirected( + pub fn neighbors_undirected( &self, node_indices: Vec, ) -> PyResult>> { @@ -666,27 +662,39 @@ impl PyMedRecord { .collect() } - fn clear(&mut self) { + pub fn clear(&mut self) { self.0.clear(); } - fn select_nodes(&self, operation: PyNodeOperation) -> Vec { - self.0 - .select_nodes(operation.into()) + pub fn select_nodes(&self, query: &Bound<'_, PyFunction>) -> PyResult> { + Ok(self + .0 + .select_nodes(|node| { + query + .call1((PyNodeOperand::from(node.clone()),)) + .expect("Call must succeed"); + }) .iter() - .map(|index| index.clone().into()) - .collect() + .map_err(PyMedRecordError::from)? + .map(|node_index| node_index.clone().into()) + .collect()) } - fn select_edges(&self, operation: PyEdgeOperation) -> Vec { - self.0 - .select_edges(operation.into()) + pub fn select_edges(&self, query: &Bound<'_, PyFunction>) -> PyResult> { + Ok(self + .0 + .select_edges(|edge| { + query + .call1((PyEdgeOperand::from(edge.clone()),)) + .expect("Call must succeed"); + }) .iter() + .map_err(PyMedRecordError::from)? .copied() - .collect() + .collect()) } - fn clone(&self) -> Self { + pub fn clone(&self) -> Self { Self(self.0.clone()) } } diff --git a/rustmodels/src/medrecord/querying.rs b/rustmodels/src/medrecord/querying.rs deleted file mode 100644 index b7965517..00000000 --- a/rustmodels/src/medrecord/querying.rs +++ /dev/null @@ -1,732 +0,0 @@ -use super::{attribute::PyMedRecordAttribute, value::PyMedRecordValue, Lut}; -use crate::{ - gil_hash_map::GILHashMap, - medrecord::{ - errors::PyMedRecordError, value::convert_pyobject_to_medrecordvalue, PyGroup, PyNodeIndex, - }, -}; -use medmodels_core::{ - errors::MedRecordError, - medrecord::{ - ArithmeticOperation, EdgeAttributeOperand, EdgeIndex, EdgeIndexOperand, EdgeOperand, - EdgeOperation, MedRecordAttribute, MedRecordValue, NodeAttributeOperand, NodeIndexOperand, - NodeOperand, NodeOperation, TransformationOperation, ValueOperand, - }, -}; -use pyo3::{ - pyclass, pymethods, types::PyAnyMethods, Bound, FromPyObject, IntoPy, PyAny, PyObject, - PyResult, Python, -}; -use std::ops::Range; - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyValueArithmeticOperation(ArithmeticOperation, MedRecordAttribute, MedRecordValue); - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyValueTransformationOperation(TransformationOperation, MedRecordAttribute); - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyValueSliceOperation(MedRecordAttribute, Range); - -#[repr(transparent)] -#[derive(Clone, Debug)] -pub(crate) struct PyValueOperand(ValueOperand); - -impl From for PyValueOperand { - fn from(value: ValueOperand) -> Self { - PyValueOperand(value) - } -} - -impl From for ValueOperand { - fn from(value: PyValueOperand) -> Self { - value.0 - } -} - -static PYVALUEOPERAND_CONVERSION_LUT: Lut = GILHashMap::new(); - -fn convert_pyobject_to_valueoperand(ob: &Bound<'_, PyAny>) -> PyResult { - if let Ok(value) = convert_pyobject_to_medrecordvalue(ob) { - return Ok(ValueOperand::Value(value)); - }; - - fn convert_node_attribute_operand(ob: &Bound<'_, PyAny>) -> PyResult { - Ok(ValueOperand::Evaluate(MedRecordAttribute::from( - ob.extract::()?.0, - ))) - } - - fn convert_edge_attribute_operand(ob: &Bound<'_, PyAny>) -> PyResult { - Ok(ValueOperand::Evaluate(MedRecordAttribute::from( - ob.extract::()?.0, - ))) - } - - fn convert_arithmetic_operation(ob: &Bound<'_, PyAny>) -> PyResult { - let operation = ob.extract::()?; - - Ok(ValueOperand::ArithmeticOperation( - operation.0, - operation.1, - operation.2, - )) - } - - fn convert_transformation_operation(ob: &Bound<'_, PyAny>) -> PyResult { - let operation = ob.extract::()?; - - Ok(ValueOperand::TransformationOperation( - operation.0, - operation.1, - )) - } - - fn convert_slice_operation(ob: &Bound<'_, PyAny>) -> PyResult { - let operation = ob.extract::()?; - - Ok(ValueOperand::Slice(operation.0, operation.1)) - } - - fn throw_error(ob: &Bound<'_, PyAny>) -> PyResult { - Err( - PyMedRecordError::from(MedRecordError::ConversionError(format!( - "Failed to convert {} into ValueOperand", - ob, - ))) - .into(), - ) - } - - let type_pointer = ob.get_type_ptr() as usize; - - Python::with_gil(|py| { - PYVALUEOPERAND_CONVERSION_LUT.map(py, |lut| { - let conversion_function = lut.entry(type_pointer).or_insert_with(|| { - if ob.is_instance_of::() { - convert_node_attribute_operand - } else if ob.is_instance_of::() { - convert_edge_attribute_operand - } else if ob.is_instance_of::() { - convert_arithmetic_operation - } else if ob.is_instance_of::() { - convert_transformation_operation - } else if ob.is_instance_of::() { - convert_slice_operation - } else { - throw_error - } - }); - - conversion_function(ob) - }) - }) -} - -impl<'a> FromPyObject<'a> for PyValueOperand { - fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { - convert_pyobject_to_valueoperand(ob).map(PyValueOperand::from) - } -} - -impl IntoPy for PyValueOperand { - fn into_py(self, py: pyo3::prelude::Python<'_>) -> PyObject { - match self.0 { - ValueOperand::Value(value) => PyMedRecordValue::from(value).into_py(py), - ValueOperand::Evaluate(attribute) => PyMedRecordAttribute::from(attribute).into_py(py), - ValueOperand::ArithmeticOperation(operation, attribute, value) => { - PyValueArithmeticOperation(operation, attribute, value).into_py(py) - } - ValueOperand::TransformationOperation(operation, attribute) => { - PyValueTransformationOperation(operation, attribute).into_py(py) - } - ValueOperand::Slice(attribute, range) => { - PyValueSliceOperation(attribute, range).into_py(py) - } - } - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeOperation(NodeOperation); - -impl From for PyNodeOperation { - fn from(value: NodeOperation) -> Self { - PyNodeOperation(value) - } -} - -impl From for NodeOperation { - fn from(value: PyNodeOperation) -> Self { - value.0 - } -} - -#[pymethods] -impl PyNodeOperation { - fn logical_and(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.and(operation.into()).into() - } - - fn logical_or(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.or(operation.into()).into() - } - - fn logical_xor(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.xor(operation.into()).into() - } - - fn logical_not(&self) -> PyNodeOperation { - self.clone().0.not().into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeOperation(EdgeOperation); - -impl From for PyEdgeOperation { - fn from(value: EdgeOperation) -> Self { - PyEdgeOperation(value) - } -} - -impl From for EdgeOperation { - fn from(value: PyEdgeOperation) -> Self { - value.0 - } -} - -#[pymethods] -impl PyEdgeOperation { - fn logical_and(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone().0.and(operation.into()).into() - } - - fn logical_or(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone().0.or(operation.into()).into() - } - - fn logical_xor(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone().0.xor(operation.into()).into() - } - - fn logical_not(&self) -> PyEdgeOperation { - self.clone().0.not().into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeAttributeOperand(pub NodeAttributeOperand); - -impl From for PyNodeAttributeOperand { - fn from(value: NodeAttributeOperand) -> Self { - PyNodeAttributeOperand(value) - } -} - -impl From for NodeAttributeOperand { - fn from(value: PyNodeAttributeOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyNodeAttributeOperand { - fn greater(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.greater(ValueOperand::from(operand)).into() - } - fn less(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.less(ValueOperand::from(operand)).into() - } - fn greater_or_equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone() - .0 - .greater_or_equal(ValueOperand::from(operand)) - .into() - } - fn less_or_equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone() - .0 - .less_or_equal(ValueOperand::from(operand)) - .into() - } - - fn equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.equal(ValueOperand::from(operand)).into() - } - fn not_equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.not_equal(ValueOperand::from(operand)).into() - } - - fn is_in(&self, operands: Vec) -> PyNodeOperation { - self.clone().0.r#in(operands).into() - } - fn not_in(&self, operands: Vec) -> PyNodeOperation { - self.clone().0.not_in(operands).into() - } - - fn starts_with(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone() - .0 - .starts_with(ValueOperand::from(operand)) - .into() - } - - fn ends_with(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.ends_with(ValueOperand::from(operand)).into() - } - - fn contains(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.contains(ValueOperand::from(operand)).into() - } - - fn add(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.add(value).into() - } - - fn sub(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.sub(value).into() - } - - fn mul(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.mul(value).into() - } - - fn div(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.div(value).into() - } - - fn pow(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.pow(value).into() - } - - fn r#mod(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.r#mod(value).into() - } - - fn round(&self) -> PyValueOperand { - self.clone().0.round().into() - } - - fn ceil(&self) -> PyValueOperand { - self.clone().0.ceil().into() - } - - fn floor(&self) -> PyValueOperand { - self.clone().0.floor().into() - } - - fn abs(&self) -> PyValueOperand { - self.clone().0.abs().into() - } - - fn sqrt(&self) -> PyValueOperand { - self.clone().0.sqrt().into() - } - - fn trim(&self) -> PyValueOperand { - self.clone().0.trim().into() - } - - fn trim_start(&self) -> PyValueOperand { - self.clone().0.trim_start().into() - } - - fn trim_end(&self) -> PyValueOperand { - self.clone().0.trim_end().into() - } - - fn lowercase(&self) -> PyValueOperand { - self.clone().0.lowercase().into() - } - - fn uppercase(&self) -> PyValueOperand { - self.clone().0.uppercase().into() - } - - fn slice(&self, start: usize, end: usize) -> PyResult { - Ok(self.clone().0.slice(Range { start, end }).into()) - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeAttributeOperand(EdgeAttributeOperand); - -impl From for PyEdgeAttributeOperand { - fn from(value: EdgeAttributeOperand) -> Self { - PyEdgeAttributeOperand(value) - } -} - -impl From for EdgeAttributeOperand { - fn from(value: PyEdgeAttributeOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyEdgeAttributeOperand { - fn greater(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.greater(ValueOperand::from(operand)).into() - } - fn less(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.less(ValueOperand::from(operand)).into() - } - fn greater_or_equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone() - .0 - .greater_or_equal(ValueOperand::from(operand)) - .into() - } - fn less_or_equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone() - .0 - .less_or_equal(ValueOperand::from(operand)) - .into() - } - - fn equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.equal(ValueOperand::from(operand)).into() - } - fn not_equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.not_equal(ValueOperand::from(operand)).into() - } - - fn is_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.r#in(operand).into() - } - fn not_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.not_in(operand).into() - } - - fn starts_with(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone() - .0 - .starts_with(ValueOperand::from(operand)) - .into() - } - - fn ends_with(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.ends_with(ValueOperand::from(operand)).into() - } - - fn contains(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.contains(ValueOperand::from(operand)).into() - } - - fn add(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.add(value).into() - } - - fn sub(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.sub(value).into() - } - - fn mul(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.mul(value).into() - } - - fn div(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.div(value).into() - } - - fn pow(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.pow(value).into() - } - - fn r#mod(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.r#mod(value).into() - } - - fn round(&self) -> PyValueOperand { - self.clone().0.round().into() - } - - fn ceil(&self) -> PyValueOperand { - self.clone().0.ceil().into() - } - - fn floor(&self) -> PyValueOperand { - self.clone().0.floor().into() - } - - fn abs(&self) -> PyValueOperand { - self.clone().0.abs().into() - } - - fn sqrt(&self) -> PyValueOperand { - self.clone().0.sqrt().into() - } - - fn trim(&self) -> PyValueOperand { - self.clone().0.trim().into() - } - - fn trim_start(&self) -> PyValueOperand { - self.clone().0.trim_start().into() - } - - fn trim_end(&self) -> PyValueOperand { - self.clone().0.trim_end().into() - } - - fn lowercase(&self) -> PyValueOperand { - self.clone().0.lowercase().into() - } - - fn uppercase(&self) -> PyValueOperand { - self.clone().0.uppercase().into() - } - - fn slice(&self, start: usize, end: usize) -> PyResult { - Ok(self.clone().0.slice(Range { start, end }).into()) - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeIndexOperand(NodeIndexOperand); - -impl From for PyNodeIndexOperand { - fn from(value: NodeIndexOperand) -> Self { - PyNodeIndexOperand(value) - } -} - -impl From for NodeIndexOperand { - fn from(value: PyNodeIndexOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyNodeIndexOperand { - fn greater(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.greater(operand).into() - } - fn less(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.less(operand).into() - } - fn greater_or_equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.greater_or_equal(operand).into() - } - fn less_or_equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.less_or_equal(operand).into() - } - - fn equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.equal(operand).into() - } - fn not_equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.not_equal(operand).into() - } - - fn is_in(&self, operand: Vec) -> PyNodeOperation { - self.clone().0.r#in(operand).into() - } - fn not_in(&self, operand: Vec) -> PyNodeOperation { - self.clone().0.not_in(operand).into() - } - - fn starts_with(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.starts_with(operand).into() - } - - fn ends_with(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.ends_with(operand).into() - } - - fn contains(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.contains(operand).into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeIndexOperand(EdgeIndexOperand); - -impl From for PyEdgeIndexOperand { - fn from(value: EdgeIndexOperand) -> Self { - PyEdgeIndexOperand(value) - } -} - -impl From for EdgeIndexOperand { - fn from(value: PyEdgeIndexOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyEdgeIndexOperand { - fn greater(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.greater(operand).into() - } - fn less(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.less(operand).into() - } - fn greater_or_equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.greater_or_equal(operand).into() - } - fn less_or_equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.less_or_equal(operand).into() - } - - fn equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.equal(operand).into() - } - fn not_equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.not_equal(operand).into() - } - - fn is_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.r#in(operand).into() - } - fn not_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.not_in(operand).into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeOperand(NodeOperand); - -#[pymethods] -impl PyNodeOperand { - #[new] - fn new() -> Self { - Self(NodeOperand) - } - - fn in_group(&self, operand: PyGroup) -> PyNodeOperation { - self.clone().0.in_group(operand).into() - } - - fn has_attribute(&self, operand: PyMedRecordAttribute) -> PyNodeOperation { - self.clone().0.has_attribute(operand).into() - } - - fn has_outgoing_edge_with(&self, operation: PyEdgeOperation) -> PyNodeOperation { - self.clone() - .0 - .has_outgoing_edge_with(operation.into()) - .into() - } - fn has_incoming_edge_with(&self, operation: PyEdgeOperation) -> PyNodeOperation { - self.clone() - .0 - .has_incoming_edge_with(operation.into()) - .into() - } - fn has_edge_with(&self, operation: PyEdgeOperation) -> PyNodeOperation { - self.clone().0.has_edge_with(operation.into()).into() - } - - fn has_neighbor_with(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.has_neighbor_with(operation.into()).into() - } - fn has_neighbor_undirected_with(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone() - .0 - .has_neighbor_undirected_with(operation.into()) - .into() - } - - fn attribute(&self, attribute: PyMedRecordAttribute) -> PyNodeAttributeOperand { - self.clone().0.attribute(attribute).into() - } - - fn index(&self) -> PyNodeIndexOperand { - self.clone().0.index().into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeOperand(EdgeOperand); - -#[pymethods] -impl PyEdgeOperand { - #[new] - fn new() -> Self { - Self(EdgeOperand) - } - - fn connected_target(&self, operand: PyNodeIndex) -> PyEdgeOperation { - self.clone().0.connected_target(operand).into() - } - - fn connected_source(&self, operand: PyNodeIndex) -> PyEdgeOperation { - self.clone().0.connected_source(operand).into() - } - - fn connected(&self, operand: PyNodeIndex) -> PyEdgeOperation { - self.clone().0.connected(operand).into() - } - - fn in_group(&self, operand: PyGroup) -> PyEdgeOperation { - self.clone().0.in_group(operand).into() - } - - fn has_attribute(&self, operand: PyMedRecordAttribute) -> PyEdgeOperation { - self.clone().0.has_attribute(operand).into() - } - - fn connected_source_with(&self, operation: PyNodeOperation) -> PyEdgeOperation { - self.clone() - .0 - .connected_source_with(operation.into()) - .into() - } - - fn connected_target_with(&self, operation: PyNodeOperation) -> PyEdgeOperation { - self.clone() - .0 - .connected_target_with(operation.into()) - .into() - } - - fn connected_with(&self, operation: PyNodeOperation) -> PyEdgeOperation { - self.clone().0.connected_with(operation.into()).into() - } - - fn has_parallel_edges_with(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone() - .0 - .has_parallel_edges_with(operation.into()) - .into() - } - - fn has_parallel_edges_with_self_comparison( - &self, - operation: PyEdgeOperation, - ) -> PyEdgeOperation { - self.clone() - .0 - .has_parallel_edges_with_self_comparison(operation.into()) - .into() - } - - fn attribute(&self, attribute: PyMedRecordAttribute) -> PyEdgeAttributeOperand { - self.clone().0.attribute(attribute).into() - } - - fn index(&self) -> PyEdgeIndexOperand { - self.clone().0.index().into() - } -} diff --git a/rustmodels/src/medrecord/querying/attributes.rs b/rustmodels/src/medrecord/querying/attributes.rs new file mode 100644 index 00000000..49fc07f2 --- /dev/null +++ b/rustmodels/src/medrecord/querying/attributes.rs @@ -0,0 +1,589 @@ +use super::values::PyMultipleValuesOperand; +use crate::medrecord::{attribute::PyMedRecordAttribute, errors::PyMedRecordError}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + AttributesTreeOperand, DeepClone, MedRecordAttribute, MultipleAttributesComparisonOperand, + MultipleAttributesOperand, SingleAttributeComparisonOperand, SingleAttributeOperand, + Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[repr(transparent)] +pub struct PySingleAttributeComparisonOperand(SingleAttributeComparisonOperand); + +impl From for PySingleAttributeComparisonOperand { + fn from(operand: SingleAttributeComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for SingleAttributeComparisonOperand { + fn from(operand: PySingleAttributeComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PySingleAttributeComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(attribute) = ob.extract::() { + Ok(SingleAttributeComparisonOperand::Attribute(attribute.into()).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PySingleAttributeComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into MedRecordValue or SingleValueOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyMultipleAttributesComparisonOperand(MultipleAttributesComparisonOperand); + +impl From for PyMultipleAttributesComparisonOperand { + fn from(operand: MultipleAttributesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for MultipleAttributesComparisonOperand { + fn from(operand: PyMultipleAttributesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyMultipleAttributesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(values) = ob.extract::>() { + Ok(MultipleAttributesComparisonOperand::Attributes( + values.into_iter().map(MedRecordAttribute::from).collect(), + ) + .into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyMultipleAttributesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[MedRecordAttribute] or MultipleAttributesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +pub struct PyAttributesTreeOperand(Wrapper); + +impl From> for PyAttributesTreeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyAttributesTreeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyAttributesTreeOperand { + pub fn max(&self) -> PyMultipleAttributesOperand { + self.0.max().into() + } + + pub fn min(&self) -> PyMultipleAttributesOperand { + self.0.min().into() + } + + pub fn count(&self) -> PyMultipleAttributesOperand { + self.0.count().into() + } + + pub fn sum(&self) -> PyMultipleAttributesOperand { + self.0.sum().into() + } + + pub fn first(&self) -> PyMultipleAttributesOperand { + self.0.first().into() + } + + pub fn last(&self) -> PyMultipleAttributesOperand { + self.0.last().into() + } + + pub fn greater_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than(attribute); + } + + pub fn greater_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than_or_equal_to(attribute); + } + + pub fn less_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than(attribute); + } + + pub fn less_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than_or_equal_to(attribute); + } + + pub fn equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.equal_to(attribute); + } + + pub fn not_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.not_equal_to(attribute); + } + + pub fn starts_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.starts_with(attribute); + } + + pub fn ends_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.ends_with(attribute); + } + + pub fn contains(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.contains(attribute); + } + + pub fn is_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_in(attributes); + } + + pub fn is_not_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_not_in(attributes); + } + + pub fn add(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.add(attribute); + } + + pub fn sub(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.sub(attribute); + } + + pub fn mul(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.mul(attribute); + } + + pub fn pow(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.pow(attribute); + } + + pub fn r#mod(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.r#mod(attribute); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_max(&self) { + self.0.is_max(); + } + + pub fn is_min(&self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyAttributesTreeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyAttributesTreeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyAttributesTreeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PyAttributesTreeOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyMultipleAttributesOperand(Wrapper); + +impl From> for PyMultipleAttributesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyMultipleAttributesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyMultipleAttributesOperand { + pub fn max(&self) -> PySingleAttributeOperand { + self.0.max().into() + } + + pub fn min(&self) -> PySingleAttributeOperand { + self.0.min().into() + } + + pub fn count(&self) -> PySingleAttributeOperand { + self.0.count().into() + } + + pub fn sum(&self) -> PySingleAttributeOperand { + self.0.sum().into() + } + + pub fn first(&self) -> PySingleAttributeOperand { + self.0.first().into() + } + + pub fn last(&self) -> PySingleAttributeOperand { + self.0.last().into() + } + + pub fn greater_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than(attribute); + } + + pub fn greater_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than_or_equal_to(attribute); + } + + pub fn less_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than(attribute); + } + + pub fn less_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than_or_equal_to(attribute); + } + + pub fn equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.equal_to(attribute); + } + + pub fn not_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.not_equal_to(attribute); + } + + pub fn starts_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.starts_with(attribute); + } + + pub fn ends_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.ends_with(attribute); + } + + pub fn contains(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.contains(attribute); + } + + pub fn is_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_in(attributes); + } + + pub fn is_not_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_not_in(attributes); + } + + pub fn add(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.add(attribute); + } + + pub fn sub(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.sub(attribute); + } + + pub fn mul(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.mul(attribute); + } + + pub fn pow(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.pow(attribute); + } + + pub fn r#mod(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.r#mod(attribute); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn to_values(&self) -> PyMultipleValuesOperand { + self.0.to_values().into() + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_max(&self) { + self.0.is_max(); + } + + pub fn is_min(&self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyMultipleAttributesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyMultipleAttributesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyMultipleAttributesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PyMultipleAttributesOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PySingleAttributeOperand(Wrapper); + +impl From> for PySingleAttributeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PySingleAttributeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PySingleAttributeOperand { + pub fn greater_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than(attribute); + } + + pub fn greater_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than_or_equal_to(attribute); + } + + pub fn less_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than(attribute); + } + + pub fn less_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than_or_equal_to(attribute); + } + + pub fn equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.equal_to(attribute); + } + + pub fn not_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.not_equal_to(attribute); + } + + pub fn starts_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.starts_with(attribute); + } + + pub fn ends_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.ends_with(attribute); + } + + pub fn contains(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.contains(attribute); + } + + pub fn is_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_in(attributes); + } + + pub fn is_not_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_not_in(attributes); + } + + pub fn add(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.add(attribute); + } + + pub fn sub(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.sub(attribute); + } + + pub fn mul(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.mul(attribute); + } + + pub fn pow(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.pow(attribute); + } + + pub fn r#mod(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.r#mod(attribute); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PySingleAttributeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PySingleAttributeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PySingleAttributeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PySingleAttributeOperand { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/querying/edges.rs b/rustmodels/src/medrecord/querying/edges.rs new file mode 100644 index 00000000..d0d86e22 --- /dev/null +++ b/rustmodels/src/medrecord/querying/edges.rs @@ -0,0 +1,408 @@ +use super::{ + attributes::PyAttributesTreeOperand, nodes::PyNodeOperand, values::PyMultipleValuesOperand, + PyGroupCardinalityWrapper, PyMedRecordAttributeCardinalityWrapper, +}; +use crate::medrecord::{attribute::PyMedRecordAttribute, errors::PyMedRecordError}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + DeepClone, EdgeIndex, EdgeIndexComparisonOperand, EdgeIndexOperand, + EdgeIndicesComparisonOperand, EdgeIndicesOperand, EdgeOperand, Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[pyclass] +#[repr(transparent)] +pub struct PyEdgeOperand(Wrapper); + +impl From> for PyEdgeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyEdgeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyEdgeOperand { + pub fn attribute(&mut self, attribute: PyMedRecordAttribute) -> PyMultipleValuesOperand { + self.0.attribute(attribute).into() + } + + pub fn attributes(&mut self) -> PyAttributesTreeOperand { + self.0.attributes().into() + } + + pub fn index(&mut self) -> PyEdgeIndicesOperand { + self.0.index().into() + } + + pub fn in_group(&mut self, group: PyGroupCardinalityWrapper) { + self.0.in_group(group); + } + + pub fn has_attribute(&mut self, attribute: PyMedRecordAttributeCardinalityWrapper) { + self.0.has_attribute(attribute); + } + + pub fn source_node(&mut self) -> PyNodeOperand { + self.0.source_node().into() + } + + pub fn target_node(&mut self) -> PyNodeOperand { + self.0.target_node().into() + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyEdgeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyEdgeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyEdgeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PyEdgeOperand { + self.0.deep_clone().into() + } +} + +#[repr(transparent)] +pub struct PyEdgeIndexComparisonOperand(EdgeIndexComparisonOperand); + +impl From for PyEdgeIndexComparisonOperand { + fn from(operand: EdgeIndexComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for EdgeIndexComparisonOperand { + fn from(operand: PyEdgeIndexComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyEdgeIndexComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(index) = ob.extract::() { + Ok(EdgeIndexComparisonOperand::Index(index).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyEdgeIndexComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into EdgeIndex or EdgeIndexOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyEdgeIndicesComparisonOperand(EdgeIndicesComparisonOperand); + +impl From for PyEdgeIndicesComparisonOperand { + fn from(operand: EdgeIndicesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for EdgeIndicesComparisonOperand { + fn from(operand: PyEdgeIndicesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyEdgeIndicesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(indices) = ob.extract::>() { + Ok(EdgeIndicesComparisonOperand::Indices(indices).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyEdgeIndicesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[EdgeIndex] or EdgeIndicesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyEdgeIndicesOperand(Wrapper); + +impl From> for PyEdgeIndicesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyEdgeIndicesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyEdgeIndicesOperand { + pub fn max(&mut self) -> PyEdgeIndexOperand { + self.0.max().into() + } + + pub fn min(&mut self) -> PyEdgeIndexOperand { + self.0.min().into() + } + + pub fn count(&mut self) -> PyEdgeIndexOperand { + self.0.count().into() + } + + pub fn sum(&mut self) -> PyEdgeIndexOperand { + self.0.sum().into() + } + + pub fn first(&mut self) -> PyEdgeIndexOperand { + self.0.first().into() + } + + pub fn last(&mut self) -> PyEdgeIndexOperand { + self.0.last().into() + } + + pub fn greater_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn is_max(&mut self) { + self.0.is_max() + } + + pub fn is_min(&mut self) { + self.0.is_min() + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyEdgeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyEdgeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyEdgeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PyEdgeIndicesOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyEdgeIndexOperand(Wrapper); + +impl From> for PyEdgeIndexOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyEdgeIndexOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyEdgeIndexOperand { + pub fn greater_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyEdgeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyEdgeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyEdgeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PyEdgeIndexOperand { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/querying/mod.rs b/rustmodels/src/medrecord/querying/mod.rs new file mode 100644 index 00000000..cfd7b868 --- /dev/null +++ b/rustmodels/src/medrecord/querying/mod.rs @@ -0,0 +1,52 @@ +pub mod attributes; +pub mod edges; +pub mod nodes; +pub mod values; + +use super::{attribute::PyMedRecordAttribute, errors::PyMedRecordError}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{CardinalityWrapper, MedRecordAttribute}, +}; +use pyo3::{types::PyAnyMethods, Bound, FromPyObject, PyAny, PyResult}; + +#[repr(transparent)] +pub struct PyMedRecordAttributeCardinalityWrapper(CardinalityWrapper); + +impl From> for PyMedRecordAttributeCardinalityWrapper { + fn from(attribute: CardinalityWrapper) -> Self { + Self(attribute) + } +} + +impl From for CardinalityWrapper { + fn from(attribute: PyMedRecordAttributeCardinalityWrapper) -> Self { + attribute.0 + } +} + +impl<'a> FromPyObject<'a> for PyMedRecordAttributeCardinalityWrapper { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(attribute) = ob.extract::() { + Ok(CardinalityWrapper::Single(MedRecordAttribute::from(attribute)).into()) + } else if let Ok(attributes) = ob.extract::>() { + Ok(CardinalityWrapper::Multiple( + attributes + .into_iter() + .map(MedRecordAttribute::from) + .collect(), + ) + .into()) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into MedRecordAttribute or List[MedREcordAttribute]", + ob, + ))) + .into(), + ) + } + } +} + +type PyGroupCardinalityWrapper = PyMedRecordAttributeCardinalityWrapper; diff --git a/rustmodels/src/medrecord/querying/nodes.rs b/rustmodels/src/medrecord/querying/nodes.rs new file mode 100644 index 00000000..ec4120f9 --- /dev/null +++ b/rustmodels/src/medrecord/querying/nodes.rs @@ -0,0 +1,515 @@ +use super::{ + attributes::PyAttributesTreeOperand, edges::PyEdgeOperand, values::PyMultipleValuesOperand, + PyGroupCardinalityWrapper, PyMedRecordAttributeCardinalityWrapper, +}; +use crate::medrecord::{attribute::PyMedRecordAttribute, errors::PyMedRecordError, PyNodeIndex}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + DeepClone, EdgeDirection, NodeIndex, NodeIndexComparisonOperand, NodeIndexOperand, + NodeIndicesComparisonOperand, NodeIndicesOperand, NodeOperand, Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[pyclass] +#[derive(Clone)] +pub enum PyEdgeDirection { + Incoming = 0, + Outgoing = 1, + Both = 2, +} + +impl From for PyEdgeDirection { + fn from(value: EdgeDirection) -> Self { + match value { + EdgeDirection::Incoming => Self::Incoming, + EdgeDirection::Outgoing => Self::Outgoing, + EdgeDirection::Both => Self::Both, + } + } +} + +impl From for EdgeDirection { + fn from(value: PyEdgeDirection) -> Self { + match value { + PyEdgeDirection::Incoming => Self::Incoming, + PyEdgeDirection::Outgoing => Self::Outgoing, + PyEdgeDirection::Both => Self::Both, + } + } +} + +#[pyclass] +#[repr(transparent)] +pub struct PyNodeOperand(Wrapper); + +impl From> for PyNodeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyNodeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyNodeOperand { + pub fn attribute(&mut self, attribute: PyMedRecordAttribute) -> PyMultipleValuesOperand { + self.0.attribute(attribute).into() + } + + pub fn attributes(&mut self) -> PyAttributesTreeOperand { + self.0.attributes().into() + } + + pub fn index(&mut self) -> PyNodeIndicesOperand { + self.0.index().into() + } + + pub fn in_group(&mut self, group: PyGroupCardinalityWrapper) { + self.0.in_group(group); + } + + pub fn has_attribute(&mut self, attribute: PyMedRecordAttributeCardinalityWrapper) { + self.0.has_attribute(attribute); + } + + pub fn outgoing_edges(&mut self) -> PyEdgeOperand { + self.0.outgoing_edges().into() + } + + pub fn incoming_edges(&mut self) -> PyEdgeOperand { + self.0.incoming_edges().into() + } + + pub fn neighbors(&mut self, direction: PyEdgeDirection) -> PyNodeOperand { + self.0.neighbors(direction.into()).into() + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyNodeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyNodeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyNodeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> Self { + self.0.deep_clone().into() + } +} + +#[repr(transparent)] +pub struct PyNodeIndexComparisonOperand(NodeIndexComparisonOperand); + +impl From for PyNodeIndexComparisonOperand { + fn from(operand: NodeIndexComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for NodeIndexComparisonOperand { + fn from(operand: PyNodeIndexComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyNodeIndexComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(index) = ob.extract::() { + Ok(NodeIndexComparisonOperand::Index(NodeIndex::from(index)).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyNodeIndexComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into NodeIndex or NodeIndexOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyNodeIndicesComparisonOperand(NodeIndicesComparisonOperand); + +impl From for PyNodeIndicesComparisonOperand { + fn from(operand: NodeIndicesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for NodeIndicesComparisonOperand { + fn from(operand: PyNodeIndicesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyNodeIndicesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(indices) = ob.extract::>() { + Ok(NodeIndicesComparisonOperand::Indices( + indices.into_iter().map(NodeIndex::from).collect(), + ) + .into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyNodeIndicesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[NodeIndex] or NodeIndicesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyNodeIndicesOperand(Wrapper); + +impl From> for PyNodeIndicesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyNodeIndicesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyNodeIndicesOperand { + pub fn max(&mut self) -> PyNodeIndexOperand { + self.0.max().into() + } + + pub fn min(&mut self) -> PyNodeIndexOperand { + self.0.min().into() + } + + pub fn count(&mut self) -> PyNodeIndexOperand { + self.0.count().into() + } + + pub fn sum(&mut self) -> PyNodeIndexOperand { + self.0.sum().into() + } + + pub fn first(&mut self) -> PyNodeIndexOperand { + self.0.first().into() + } + + pub fn last(&mut self) -> PyNodeIndexOperand { + self.0.last().into() + } + + pub fn greater_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn abs(&mut self) { + self.0.abs(); + } + + pub fn trim(&mut self) { + self.0.trim(); + } + + pub fn trim_start(&mut self) { + self.0.trim_start(); + } + + pub fn trim_end(&mut self) { + self.0.trim_end(); + } + + pub fn lowercase(&mut self) { + self.0.lowercase(); + } + + pub fn uppercase(&mut self) { + self.0.uppercase(); + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&mut self) { + self.0.is_string(); + } + + pub fn is_int(&mut self) { + self.0.is_int(); + } + + pub fn is_max(&mut self) { + self.0.is_max(); + } + + pub fn is_min(&mut self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyNodeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyNodeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyNodeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> Self { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyNodeIndexOperand(Wrapper); + +impl From> for PyNodeIndexOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyNodeIndexOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyNodeIndexOperand { + pub fn greater_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn abs(&mut self) { + self.0.abs(); + } + + pub fn trim(&mut self) { + self.0.trim(); + } + + pub fn trim_start(&mut self) { + self.0.trim_start(); + } + + pub fn trim_end(&mut self) { + self.0.trim_end(); + } + + pub fn lowercase(&mut self) { + self.0.lowercase(); + } + + pub fn uppercase(&mut self) { + self.0.uppercase(); + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&mut self) { + self.0.is_string(); + } + + pub fn is_int(&mut self) { + self.0.is_int(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyNodeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyNodeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyNodeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> Self { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/querying/values.rs b/rustmodels/src/medrecord/querying/values.rs new file mode 100644 index 00000000..af99a6ea --- /dev/null +++ b/rustmodels/src/medrecord/querying/values.rs @@ -0,0 +1,498 @@ +use crate::medrecord::{errors::PyMedRecordError, value::PyMedRecordValue}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + DeepClone, MedRecordValue, MultipleValuesComparisonOperand, MultipleValuesOperand, + SingleValueComparisonOperand, SingleValueOperand, Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[repr(transparent)] +pub struct PySingleValueComparisonOperand(SingleValueComparisonOperand); + +impl From for PySingleValueComparisonOperand { + fn from(operand: SingleValueComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for SingleValueComparisonOperand { + fn from(operand: PySingleValueComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PySingleValueComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(value) = ob.extract::() { + Ok(SingleValueComparisonOperand::Value(value.into()).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PySingleValueComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into MedRecordValue or SingleValueOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyMultipleValuesComparisonOperand(MultipleValuesComparisonOperand); + +impl From for PyMultipleValuesComparisonOperand { + fn from(operand: MultipleValuesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for MultipleValuesComparisonOperand { + fn from(operand: PyMultipleValuesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyMultipleValuesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(values) = ob.extract::>() { + Ok(MultipleValuesComparisonOperand::Values( + values.into_iter().map(MedRecordValue::from).collect(), + ) + .into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyMultipleValuesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[MedRecordValue] or MultipleValuesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyMultipleValuesOperand(Wrapper); + +impl From> for PyMultipleValuesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyMultipleValuesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyMultipleValuesOperand { + pub fn max(&self) -> PySingleValueOperand { + self.0.max().into() + } + + pub fn min(&self) -> PySingleValueOperand { + self.0.min().into() + } + + pub fn mean(&self) -> PySingleValueOperand { + self.0.mean().into() + } + + pub fn median(&self) -> PySingleValueOperand { + self.0.median().into() + } + + pub fn mode(&self) -> PySingleValueOperand { + self.0.mode().into() + } + + pub fn std(&self) -> PySingleValueOperand { + self.0.std().into() + } + + pub fn var(&self) -> PySingleValueOperand { + self.0.var().into() + } + + pub fn count(&self) -> PySingleValueOperand { + self.0.count().into() + } + + pub fn sum(&self) -> PySingleValueOperand { + self.0.sum().into() + } + + pub fn first(&self) -> PySingleValueOperand { + self.0.first().into() + } + + pub fn last(&self) -> PySingleValueOperand { + self.0.last().into() + } + + pub fn greater_than(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than(value); + } + + pub fn greater_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than_or_equal_to(value); + } + + pub fn less_than(&self, value: PySingleValueComparisonOperand) { + self.0.less_than(value); + } + + pub fn less_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.less_than_or_equal_to(value); + } + + pub fn equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.equal_to(value); + } + + pub fn not_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.not_equal_to(value); + } + + pub fn starts_with(&self, value: PySingleValueComparisonOperand) { + self.0.starts_with(value); + } + + pub fn ends_with(&self, value: PySingleValueComparisonOperand) { + self.0.ends_with(value); + } + + pub fn contains(&self, value: PySingleValueComparisonOperand) { + self.0.contains(value); + } + + pub fn is_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_in(values); + } + + pub fn is_not_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_not_in(values); + } + + pub fn add(&self, value: PySingleValueComparisonOperand) { + self.0.add(value); + } + + pub fn sub(&self, value: PySingleValueComparisonOperand) { + self.0.sub(value); + } + + pub fn mul(&self, value: PySingleValueComparisonOperand) { + self.0.mul(value); + } + + pub fn div(&self, value: PySingleValueComparisonOperand) { + self.0.div(value); + } + + pub fn pow(&self, value: PySingleValueComparisonOperand) { + self.0.pow(value); + } + + pub fn r#mod(&self, value: PySingleValueComparisonOperand) { + self.0.r#mod(value); + } + + pub fn round(&self) { + self.0.round(); + } + + pub fn ceil(&self) { + self.0.ceil(); + } + + pub fn floor(&self) { + self.0.floor(); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn sqrt(&self) { + self.0.sqrt(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_float(&self) { + self.0.is_float(); + } + + pub fn is_bool(&self) { + self.0.is_bool(); + } + + pub fn is_datetime(&self) { + self.0.is_datetime(); + } + + pub fn is_null(&self) { + self.0.is_null(); + } + + pub fn is_max(&self) { + self.0.is_max(); + } + + pub fn is_min(&self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyMultipleValuesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyMultipleValuesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PyMultipleValuesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PyMultipleValuesOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PySingleValueOperand(Wrapper); + +impl From> for PySingleValueOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PySingleValueOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PySingleValueOperand { + pub fn greater_than(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than(value); + } + + pub fn greater_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than_or_equal_to(value); + } + + pub fn less_than(&self, value: PySingleValueComparisonOperand) { + self.0.less_than(value); + } + + pub fn less_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.less_than_or_equal_to(value); + } + + pub fn equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.equal_to(value); + } + + pub fn not_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.not_equal_to(value); + } + + pub fn starts_with(&self, value: PySingleValueComparisonOperand) { + self.0.starts_with(value); + } + + pub fn ends_with(&self, value: PySingleValueComparisonOperand) { + self.0.ends_with(value); + } + + pub fn contains(&self, value: PySingleValueComparisonOperand) { + self.0.contains(value); + } + + pub fn is_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_in(values); + } + + pub fn is_not_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_not_in(values); + } + + pub fn add(&self, value: PySingleValueComparisonOperand) { + self.0.add(value); + } + + pub fn sub(&self, value: PySingleValueComparisonOperand) { + self.0.sub(value); + } + + pub fn mul(&self, value: PySingleValueComparisonOperand) { + self.0.mul(value); + } + + pub fn div(&self, value: PySingleValueComparisonOperand) { + self.0.div(value); + } + + pub fn pow(&self, value: PySingleValueComparisonOperand) { + self.0.pow(value); + } + + pub fn r#mod(&self, value: PySingleValueComparisonOperand) { + self.0.r#mod(value); + } + + pub fn round(&self) { + self.0.round(); + } + + pub fn ceil(&self) { + self.0.ceil(); + } + + pub fn floor(&self) { + self.0.floor(); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn sqrt(&self) { + self.0.sqrt(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_float(&self) { + self.0.is_float(); + } + + pub fn is_bool(&self) { + self.0.is_bool(); + } + + pub fn is_datetime(&self) { + self.0.is_datetime(); + } + + pub fn is_null(&self) { + self.0.is_null(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PySingleValueOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PySingleValueOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn exclude(&mut self, query: &Bound<'_, PyFunction>) { + self.0.exclude(|operand| { + query + .call1((PySingleValueOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }); + } + + pub fn deep_clone(&self) -> PySingleValueOperand { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/value.rs b/rustmodels/src/medrecord/value.rs index 1ae6b5c7..489e7a97 100644 --- a/rustmodels/src/medrecord/value.rs +++ b/rustmodels/src/medrecord/value.rs @@ -10,7 +10,7 @@ use std::ops::Deref; #[repr(transparent)] #[derive(Clone, Debug)] -pub(crate) struct PyMedRecordValue(MedRecordValue); +pub struct PyMedRecordValue(MedRecordValue); impl From for PyMedRecordValue { fn from(value: MedRecordValue) -> Self {