From 8000ce192ad00a0b46c4e3dd85079caa823b73be Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:57:05 +0200 Subject: [PATCH 1/8] refactor: implement new querying interface (#191) --- medmodels/medrecord/__init__.py | 14 +- medmodels/medrecord/indexers.py | 86 +- medmodels/medrecord/medrecord.py | 193 +- medmodels/medrecord/querying.py | 1583 ----------------- medmodels/medrecord/querying.pyi | 181 ++ medmodels/medrecord/tests/test_indexers.py | 158 +- medmodels/medrecord/tests/test_medrecord.py | 126 +- medmodels/medrecord/tests/test_overview.py | 60 +- medmodels/medrecord/tests/test_querying.py | 1170 ------------ medmodels/treatment_effect/builder.py | 12 +- .../tests/test_treatment_effect.py | 20 +- .../treatment_effect/treatment_effect.py | 73 +- 12 files changed, 577 insertions(+), 3099 deletions(-) delete mode 100644 medmodels/medrecord/querying.py create mode 100644 medmodels/medrecord/querying.pyi delete mode 100644 medmodels/medrecord/tests/test_querying.py diff --git a/medmodels/medrecord/__init__.py b/medmodels/medrecord/__init__.py index 3a9f3f6f..09b4e6dd 100644 --- a/medmodels/medrecord/__init__.py +++ b/medmodels/medrecord/__init__.py @@ -11,12 +11,12 @@ ) from medmodels.medrecord.medrecord import ( EdgeIndex, - EdgeOperation, + EdgeQuery, MedRecord, NodeIndex, - NodeOperation, + NodeQuery, ) -from medmodels.medrecord.querying import edge, node +from medmodels.medrecord.querying import EdgeOperand, NodeOperand from medmodels.medrecord.schema import AttributeType, GroupSchema, Schema __all__ = [ @@ -33,10 +33,10 @@ "AttributeType", "Schema", "GroupSchema", - "node", - "edge", "NodeIndex", "EdgeIndex", - "NodeOperation", - "EdgeOperation", + "EdgeQuery", + "NodeQuery", + "NodeOperand", + "EdgeOperand", ] diff --git a/medmodels/medrecord/indexers.py b/medmodels/medrecord/indexers.py index b76404e1..0496b1a8 100644 --- a/medmodels/medrecord/indexers.py +++ b/medmodels/medrecord/indexers.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict, Tuple, Union, overload -from medmodels.medrecord.querying import EdgeOperation, NodeOperation +from medmodels.medrecord.querying import EdgeQuery, NodeQuery from medmodels.medrecord.types import ( Attributes, AttributesInput, @@ -48,10 +48,10 @@ def __getitem__( self, key: Union[ NodeIndexInputList, - NodeOperation, + NodeQuery, slice, Tuple[ - Union[NodeIndexInputList, NodeOperation, slice], + Union[NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttributeInputList, slice], ], ], @@ -60,7 +60,7 @@ def __getitem__( @overload def __getitem__( self, - key: Tuple[Union[NodeIndexInputList, NodeOperation, slice], MedRecordAttribute], + key: Tuple[Union[NodeIndexInputList, NodeQuery, slice], MedRecordAttribute], ) -> Dict[NodeIndex, MedRecordValue]: ... def __getitem__( @@ -68,10 +68,10 @@ def __getitem__( key: Union[ NodeIndex, NodeIndexInputList, - NodeOperation, + NodeQuery, slice, Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -87,7 +87,7 @@ def __getitem__( if isinstance(key, list): return self._medrecord._medrecord.node(key) - if isinstance(key, NodeOperation): + if isinstance(key, NodeQuery): return self._medrecord._medrecord.node(self._medrecord.select_nodes(key)) if isinstance(key, slice): @@ -112,7 +112,7 @@ def __getitem__( return {x: attributes[x][attribute_selection] for x in attributes.keys()} - if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( + if isinstance(index_selection, NodeQuery) and is_medrecord_attribute( attribute_selection ): attributes = self._medrecord._medrecord.node( @@ -151,7 +151,7 @@ def __getitem__( for x in attributes.keys() } - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, NodeQuery) and isinstance( attribute_selection, list ): attributes = self._medrecord._medrecord.node( @@ -198,7 +198,7 @@ def __getitem__( return self._medrecord._medrecord.node(index_selection) - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, NodeQuery) and isinstance( attribute_selection, slice ): if ( @@ -230,7 +230,7 @@ def __getitem__( @overload def __setitem__( self, - key: Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + key: Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], value: AttributesInput, ) -> None: ... @@ -238,7 +238,7 @@ def __setitem__( def __setitem__( self, key: Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], value: MedRecordValue, @@ -249,10 +249,10 @@ def __setitem__( key: Union[ NodeIndex, NodeIndexInputList, - NodeOperation, + NodeQuery, slice, Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -270,7 +270,7 @@ def __setitem__( return self._medrecord._medrecord.replace_node_attributes(key, value) - if isinstance(key, NodeOperation): + if isinstance(key, NodeQuery): if not is_attributes(value): raise ValueError("Invalid value type. Expected Attributes") @@ -311,7 +311,7 @@ def __setitem__( index_selection, attribute_selection, value ) - if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( + if isinstance(index_selection, NodeQuery) and is_medrecord_attribute( attribute_selection ): if not is_medrecord_value(value): @@ -364,7 +364,7 @@ def __setitem__( return - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, NodeQuery) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): @@ -440,7 +440,7 @@ def __setitem__( return - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, NodeQuery) and isinstance( attribute_selection, slice ): if ( @@ -494,7 +494,7 @@ def __setitem__( def __delitem__( self, key: Tuple[ - Union[NodeIndex, NodeIndexInputList, NodeOperation, slice], + Union[NodeIndex, NodeIndexInputList, NodeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ) -> None: @@ -514,7 +514,7 @@ def __delitem__( index_selection, attribute_selection ) - if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( + if isinstance(index_selection, NodeQuery) and is_medrecord_attribute( attribute_selection ): return self._medrecord._medrecord.remove_node_attribute( @@ -553,7 +553,7 @@ def __delitem__( return - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, NodeQuery) and isinstance( attribute_selection, list ): for attribute in attribute_selection: @@ -602,7 +602,7 @@ def __delitem__( index_selection, {} ) - if isinstance(index_selection, NodeOperation) and isinstance( + if isinstance(index_selection, NodeQuery) and isinstance( attribute_selection, slice ): if ( @@ -658,10 +658,10 @@ def __getitem__( self, key: Union[ EdgeIndexInputList, - EdgeOperation, + EdgeQuery, slice, Tuple[ - Union[EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttributeInputList, slice], ], ], @@ -670,7 +670,7 @@ def __getitem__( @overload def __getitem__( self, - key: Tuple[Union[EdgeIndexInputList, EdgeOperation, slice], MedRecordAttribute], + key: Tuple[Union[EdgeIndexInputList, EdgeQuery, slice], MedRecordAttribute], ) -> Dict[EdgeIndex, MedRecordValue]: ... def __getitem__( @@ -678,10 +678,10 @@ def __getitem__( key: Union[ EdgeIndex, EdgeIndexInputList, - EdgeOperation, + EdgeQuery, slice, Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -697,7 +697,7 @@ def __getitem__( if isinstance(key, list): return self._medrecord._medrecord.edge(key) - if isinstance(key, EdgeOperation): + if isinstance(key, EdgeQuery): return self._medrecord._medrecord.edge(self._medrecord.select_edges(key)) if isinstance(key, slice): @@ -722,7 +722,7 @@ def __getitem__( return {x: attributes[x][attribute_selection] for x in attributes.keys()} - if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( + if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute( attribute_selection ): attributes = self._medrecord._medrecord.edge( @@ -761,7 +761,7 @@ def __getitem__( for x in attributes.keys() } - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, EdgeQuery) and isinstance( attribute_selection, list ): attributes = self._medrecord._medrecord.edge( @@ -808,7 +808,7 @@ def __getitem__( return self._medrecord._medrecord.edge(index_selection) - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, EdgeQuery) and isinstance( attribute_selection, slice ): if ( @@ -840,7 +840,7 @@ def __getitem__( @overload def __setitem__( self, - key: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + key: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], value: AttributesInput, ) -> None: ... @@ -848,7 +848,7 @@ def __setitem__( def __setitem__( self, key: Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], value: MedRecordValue, @@ -859,10 +859,10 @@ def __setitem__( key: Union[ EdgeIndex, EdgeIndexInputList, - EdgeOperation, + EdgeQuery, slice, Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ], @@ -880,7 +880,7 @@ def __setitem__( return self._medrecord._medrecord.replace_edge_attributes(key, value) - if isinstance(key, EdgeOperation): + if isinstance(key, EdgeQuery): if not is_attributes(value): raise ValueError("Invalid value type. Expected Attributes") @@ -921,7 +921,7 @@ def __setitem__( index_selection, attribute_selection, value ) - if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( + if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute( attribute_selection ): if not is_medrecord_value(value): @@ -974,7 +974,7 @@ def __setitem__( return - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, EdgeQuery) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): @@ -1048,7 +1048,7 @@ def __setitem__( return - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, EdgeQuery) and isinstance( attribute_selection, slice ): if ( @@ -1102,7 +1102,7 @@ def __setitem__( def __delitem__( self, key: Tuple[ - Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice], + Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice], Union[MedRecordAttribute, MedRecordAttributeInputList, slice], ], ) -> None: @@ -1122,7 +1122,7 @@ def __delitem__( index_selection, attribute_selection ) - if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( + if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute( attribute_selection ): return self._medrecord._medrecord.remove_edge_attribute( @@ -1161,7 +1161,7 @@ def __delitem__( return - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, EdgeQuery) and isinstance( attribute_selection, list ): for attribute in attribute_selection: @@ -1210,7 +1210,7 @@ def __delitem__( index_selection, {} ) - if isinstance(index_selection, EdgeOperation) and isinstance( + if isinstance(index_selection, EdgeQuery) and isinstance( attribute_selection, slice ): if ( diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index 69a5bcae..c114ab8a 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -8,7 +8,7 @@ from medmodels.medrecord._overview import extract_attribute_summary, prettify_table from medmodels.medrecord.builder import MedRecordBuilder from medmodels.medrecord.indexers import EdgeIndexer, NodeIndexer -from medmodels.medrecord.querying import EdgeOperation, NodeOperation +from medmodels.medrecord.querying import EdgeOperand, EdgeQuery, NodeQuery from medmodels.medrecord.schema import Schema from medmodels.medrecord.types import ( AttributeInfo, @@ -415,11 +415,11 @@ def outgoing_edges(self, node: NodeIndex) -> List[EdgeIndex]: ... @overload def outgoing_edges( - self, node: Union[NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, List[EdgeIndex]]: ... def outgoing_edges( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: """Lists the outgoing edges of the specified node(s) in the MedRecord. @@ -428,14 +428,14 @@ def outgoing_edges( its list of outgoing edge indices. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: Outgoing edge indices for each specified node. """ - if isinstance(node, NodeOperation): + if isinstance(node, NodeQuery): return self._medrecord.outgoing_edges(self.select_nodes(node)) indices = self._medrecord.outgoing_edges( @@ -452,11 +452,11 @@ def incoming_edges(self, node: NodeIndex) -> List[EdgeIndex]: ... @overload def incoming_edges( - self, node: Union[NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, List[EdgeIndex]]: ... def incoming_edges( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: """Lists the incoming edges of the specified node(s) in the MedRecord. @@ -465,14 +465,14 @@ def incoming_edges( its list of incoming edge indices. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: Incoming edge indices for each specified node. """ - if isinstance(node, NodeOperation): + if isinstance(node, NodeQuery): return self._medrecord.incoming_edges(self.select_nodes(node)) indices = self._medrecord.incoming_edges( @@ -489,11 +489,11 @@ def edge_endpoints(self, edge: EdgeIndex) -> tuple[NodeIndex, NodeIndex]: ... @overload def edge_endpoints( - self, edge: Union[EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndexInputList, EdgeQuery] ) -> Dict[EdgeIndex, tuple[NodeIndex, NodeIndex]]: ... def edge_endpoints( - self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> Union[ tuple[NodeIndex, NodeIndex], Dict[EdgeIndex, tuple[NodeIndex, NodeIndex]] ]: @@ -504,8 +504,8 @@ def edge_endpoints( a dictionary mapping each edge index to its tuple of node indices. Args: - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices. + edge (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query. Returns: Union[tuple[NodeIndex, NodeIndex], @@ -513,7 +513,7 @@ def edge_endpoints( Tuple of node indices or a dictionary mapping each edge to its node indices. """ - if isinstance(edge, EdgeOperation): + if isinstance(edge, EdgeQuery): return self._medrecord.edge_endpoints(self.select_edges(edge)) endpoints = self._medrecord.edge_endpoints( @@ -527,8 +527,8 @@ def edge_endpoints( def edges_connecting( self, - source_node: Union[NodeIndex, NodeIndexInputList, NodeOperation], - target_node: Union[NodeIndex, NodeIndexInputList, NodeOperation], + source_node: Union[NodeIndex, NodeIndexInputList, NodeQuery], + target_node: Union[NodeIndex, NodeIndexInputList, NodeQuery], directed: bool = True, ) -> List[EdgeIndex]: """Retrieves the edges connecting the specified source and target nodes in the MedRecord. @@ -539,11 +539,11 @@ def edges_connecting( target nodes. Args: - source_node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): - The index or indices of the source node(s), or a NodeOperation to + source_node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): + The index or indices of the source node(s), or a node query to select source nodes. - target_node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): - The index or indices of the target node(s), or a NodeOperation to + target_node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): + The index or indices of the target node(s), or a node query to select target nodes. directed (bool, optional): Whether to consider edges as directed. @@ -552,10 +552,10 @@ def edges_connecting( target nodes. """ - if isinstance(source_node, NodeOperation): + if isinstance(source_node, NodeQuery): source_node = self.select_nodes(source_node) - if isinstance(target_node, NodeOperation): + if isinstance(target_node, NodeQuery): target_node = self.select_nodes(target_node) if directed: @@ -574,11 +574,11 @@ def remove_nodes(self, nodes: NodeIndex) -> Attributes: ... @overload def remove_nodes( - self, nodes: Union[NodeIndexInputList, NodeOperation] + self, nodes: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, Attributes]: ... def remove_nodes( - self, nodes: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, nodes: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[Attributes, Dict[NodeIndex, Attributes]]: """Removes a node or multiple nodes from the MedRecord and returns their attributes. @@ -587,14 +587,14 @@ def remove_nodes( index to its attributes. Args: - nodes (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + nodes (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[Attributes, Dict[NodeIndex, Attributes]]: Attributes of the removed node(s). """ - if isinstance(nodes, NodeOperation): + if isinstance(nodes, NodeQuery): return self._medrecord.remove_nodes(self.select_nodes(nodes)) attributes = self._medrecord.remove_nodes( @@ -720,11 +720,11 @@ def remove_edges(self, edges: EdgeIndex) -> Attributes: ... @overload def remove_edges( - self, edges: Union[EdgeIndexInputList, EdgeOperation] + self, edges: Union[EdgeIndexInputList, EdgeQuery] ) -> Dict[EdgeIndex, Attributes]: ... def remove_edges( - self, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> Union[Attributes, Dict[EdgeIndex, Attributes]]: """Removes an edge or multiple edges from the MedRecord and returns their attributes. @@ -733,14 +733,14 @@ def remove_edges( index to its attributes. Args: - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation. + edge (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query. Returns: Union[Attributes, Dict[EdgeIndex, Attributes]]: Attributes of the removed edge(s). """ - if isinstance(edges, EdgeOperation): + if isinstance(edges, EdgeQuery): return self._medrecord.remove_edges(self.select_edges(edges)) attributes = self._medrecord.remove_edges( @@ -864,8 +864,8 @@ def add_edges_polars( def add_group( self, group: Group, - nodes: Optional[Union[NodeIndex, NodeIndexInputList, NodeOperation]] = None, - edges: Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]] = None, + nodes: Optional[Union[NodeIndex, NodeIndexInputList, NodeQuery]] = None, + edges: Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]] = None, ) -> None: """Adds a group to the MedRecord instance with an optional list of node indices. @@ -874,20 +874,20 @@ def add_group( Args: group (Group): The name of the group to add. - nodes (Optional[Union[NodeIndex, NodeIndexInputList, NodeOperation]]): - One or more node indices or a node operation to add + nodes (Optional[Union[NodeIndex, NodeIndexInputList, NodeQuery]]): + One or more node indices or a node query to add to the group, optional. - edges (Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]]): - One or more edge indices or an edge operation to add + edges (Optional[Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]]): + One or more edge indices or an edge query to add to the group, optional. Returns: None """ - if isinstance(nodes, NodeOperation): + if isinstance(nodes, NodeQuery): nodes = self.select_nodes(nodes) - if isinstance(edges, EdgeOperation): + if isinstance(edges, NodeQuery): edges = self.select_edges(edges) if nodes is not None and edges is not None: @@ -921,19 +921,19 @@ def remove_groups(self, groups: Union[Group, GroupInputList]) -> None: ) def add_nodes_to_group( - self, group: Group, nodes: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, group: Group, nodes: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> None: """Adds one or more nodes to a specified group in the MedRecord. Args: group (Group): The name of the group to add nodes to. - nodes (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation to add to the group. + nodes (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query to add to the group. Returns: None """ - if isinstance(nodes, NodeOperation): + if isinstance(nodes, NodeQuery): return self._medrecord.add_nodes_to_group(group, self.select_nodes(nodes)) return self._medrecord.add_nodes_to_group( @@ -941,19 +941,19 @@ def add_nodes_to_group( ) def add_edges_to_group( - self, group: Group, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, group: Group, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> None: """Adds one or more edges to a specified group in the MedRecord. Args: group (Group): The name of the group to add edges to. - edges (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation to add to the group. + edges (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query to add to the group. Returns: None """ - if isinstance(edges, EdgeOperation): + if isinstance(edges, EdgeQuery): return self._medrecord.add_edges_to_group(group, self.select_edges(edges)) return self._medrecord.add_edges_to_group( @@ -961,19 +961,19 @@ def add_edges_to_group( ) def remove_nodes_from_group( - self, group: Group, nodes: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, group: Group, nodes: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> None: """Removes one or more nodes from a specified group in the MedRecord. Args: group (Group): The name of the group from which to remove nodes. - nodes (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation to remove from the group. + nodes (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query to remove from the group. Returns: None """ - if isinstance(nodes, NodeOperation): + if isinstance(nodes, NodeQuery): return self._medrecord.remove_nodes_from_group( group, self.select_nodes(nodes) ) @@ -983,19 +983,19 @@ def remove_nodes_from_group( ) def remove_edges_from_group( - self, group: Group, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, group: Group, edges: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> None: """Removes one or more edges from a specified group in the MedRecord. Args: group (Group): The name of the group from which to remove edges. - edges (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation to remove from the group. + edges (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query to remove from the group. Returns: None """ - if isinstance(edges, EdgeOperation): + if isinstance(edges, EdgeQuery): return self._medrecord.remove_edges_from_group( group, self.select_edges(edges) ) @@ -1071,11 +1071,11 @@ def groups_of_node(self, node: NodeIndex) -> List[Group]: ... @overload def groups_of_node( - self, node: Union[NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndexInputList, NodeQuery] ) -> Dict[NodeIndex, List[Group]]: ... def groups_of_node( - self, node: Union[NodeIndex, NodeIndexInputList, NodeOperation] + self, node: Union[NodeIndex, NodeIndexInputList, NodeQuery] ) -> Union[List[Group], Dict[NodeIndex, List[Group]]]: """Retrieves the groups associated with the specified node(s) in the MedRecord. @@ -1084,14 +1084,14 @@ def groups_of_node( its list of groups. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. Returns: Union[List[Group], Dict[NodeIndex, List[Group]]]: Groups associated with each node. """ - if isinstance(node, NodeOperation): + if isinstance(node, NodeQuery): return self._medrecord.groups_of_node(self.select_nodes(node)) groups = self._medrecord.groups_of_node( @@ -1108,11 +1108,11 @@ def groups_of_edge(self, edge: EdgeIndex) -> List[Group]: ... @overload def groups_of_edge( - self, edge: Union[EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndexInputList, EdgeQuery] ) -> Dict[EdgeIndex, List[Group]]: ... def groups_of_edge( - self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation] + self, edge: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery] ) -> Union[List[Group], Dict[EdgeIndex, List[Group]]]: """Retrieves the groups associated with the specified edge(s) in the MedRecord. @@ -1121,14 +1121,14 @@ def groups_of_edge( its list of groups. Args: - edge (Union[EdgeIndex, EdgeIndexInputList, EdgeOperation]): One or more - edge indices or an edge operation. + edge (Union[EdgeIndex, EdgeIndexInputList, EdgeQuery]): One or more + edge indices or an edge query. Returns: Union[List[Group], Dict[EdgeIndex, List[Group]]]: Groups associated with each edge. """ - if isinstance(edge, EdgeOperation): + if isinstance(edge, EdgeQuery): return self._medrecord.groups_of_edge(self.select_edges(edge)) groups = self._medrecord.groups_of_edge( @@ -1207,13 +1207,13 @@ def neighbors( @overload def neighbors( self, - node: Union[NodeIndexInputList, NodeOperation], + node: Union[NodeIndexInputList, NodeQuery], directed: bool = True, ) -> Dict[NodeIndex, List[NodeIndex]]: ... def neighbors( self, - node: Union[NodeIndex, NodeIndexInputList, NodeOperation], + node: Union[NodeIndex, NodeIndexInputList, NodeQuery], directed: bool = True, ) -> Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]: """Retrieves the neighbors of the specified node(s) in the MedRecord. @@ -1223,14 +1223,14 @@ def neighbors( each node index to its list of neighboring nodes. Args: - node (Union[NodeIndex, NodeIndexInputList, NodeOperation]): One or more - node indices or a node operation. - directed (bool, optional): Whether to consider edges as directed + node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more + node indices or a node query. + directed (bool, optional): Whether to consider edges as directed. Returns: Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]: Neighboring nodes. """ - if isinstance(node, NodeOperation): + if isinstance(node, NodeQuery): node = self.select_nodes(node) if directed: @@ -1257,50 +1257,9 @@ def clear(self) -> None: """ return self._medrecord.clear() - def select_nodes(self, operation: NodeOperation) -> List[NodeIndex]: - """Selects nodes based on a specified operation and returns their indices. - - Args: - operation (NodeOperation): The operation to apply to select nodes. - - Returns: - List[NodeIndex]: A list of node indices that satisfy the operation. - """ - return self._medrecord.select_nodes(operation._node_operation) - - def select_edges(self, operation: EdgeOperation) -> List[EdgeIndex]: - """Selects edges based on a specified operation and returns their indices. - - Args: - operation (EdgeOperation): The operation to apply to select edges. - - Returns: - List[EdgeIndex]: A list of edge indices that satisfy the operation. - """ - return self._medrecord.select_edges(operation._edge_operation) - - @overload - def __getitem__(self, key: NodeOperation) -> List[NodeIndex]: ... - - @overload - def __getitem__(self, key: EdgeOperation) -> List[EdgeIndex]: ... - - def __getitem__( - self, key: Union[NodeOperation, EdgeOperation] - ) -> Union[List[NodeIndex], List[EdgeIndex]]: - """Allows selection of nodes or edges using operations directly via indexing. - - Args: - key (Union[NodeOperation, EdgeOperation]): Operation to select nodes - or edges. - - Returns: - Union[List[NodeIndex], List[EdgeIndex]]: Node or edge indices selected. - """ - if isinstance(key, NodeOperation): - return self.select_nodes(key) + def select_nodes(self, query: NodeQuery) -> List[NodeIndex]: ... - return self.select_edges(key) + def select_edges(self, query: EdgeQuery) -> List[EdgeIndex]: ... def clone(self) -> MedRecord: """Clones the MedRecord instance. diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py deleted file mode 100644 index 5a6e8385..00000000 --- a/medmodels/medrecord/querying.py +++ /dev/null @@ -1,1583 +0,0 @@ -from __future__ import annotations - -from typing import List, Union - -from medmodels._medmodels import ( - PyEdgeAttributeOperand, - PyEdgeIndexOperand, - PyEdgeOperand, - PyEdgeOperation, - PyNodeAttributeOperand, - PyNodeIndexOperand, - PyNodeOperand, - PyNodeOperation, - PyValueArithmeticOperation, - PyValueTransformationOperation, -) -from medmodels.medrecord.types import ( - EdgeIndex, - Group, - MedRecordAttribute, - MedRecordValue, - NodeIndex, -) - -ValueOperand = Union[ - MedRecordValue, - MedRecordAttribute, - PyValueArithmeticOperation, - PyValueTransformationOperation, -] - - -class NodeOperation: - _node_operation: PyNodeOperation - - def __init__(self, node_operation: PyNodeOperation): - self._node_operation = node_operation - - def logical_and(self, operation: NodeOperation) -> NodeOperation: - """Combines this NodeOperation with another using a logical AND, resulting in a new NodeOperation that is true only if both original operations are true. - - This method allows for the chaining of conditions to refine queries on nodes. - - Args: - operation (NodeOperation): Another NodeOperation to be combined with the - current one. - - Returns: - NodeOperation: A new NodeOperation representing the logical AND of this - operation with another. - """ - return NodeOperation( - self._node_operation.logical_and(operation._node_operation) - ) - - def __and__(self, operation: NodeOperation) -> NodeOperation: - return self.logical_and(operation) - - def logical_or(self, operation: NodeOperation) -> NodeOperation: - """Combines this NodeOperation with another using a logical OR, resulting in a new NodeOperation that is true if either of the original operations is true. - - This method enables the combination of conditions to expand queries on nodes. - - Args: - operation (NodeOperation): Another NodeOperation to be combined with the - current one. - - Returns: - NodeOperation: A new NodeOperation representing the logical OR of this - operation with another. - """ - return NodeOperation(self._node_operation.logical_or(operation._node_operation)) - - def __or__(self, operation: NodeOperation) -> NodeOperation: - return self.logical_or(operation) - - def logical_xor(self, operation: NodeOperation) -> NodeOperation: - """Combines this NodeOperation with another using a logical XOR, resulting in a new NodeOperation that is true only if exactly one of the original operations is true. - - This method is useful for creating conditions that must be - exclusively true. - - Args: - operation (NodeOperation): Another NodeOperation to be combined with the - current one. - - Returns: - NodeOperation: A new NodeOperation representing the logical XOR of this - operation with another. - """ - return NodeOperation( - self._node_operation.logical_xor(operation._node_operation) - ) - - def __xor__(self, operation: NodeOperation) -> NodeOperation: - return self.logical_xor(operation) - - def logical_not(self) -> NodeOperation: - """Creates a new NodeOperation that is the logical NOT of this operation, inversing the current condition. - - This method is useful for negating a condition - to create queries on nodes. - - Returns: - NodeOperation: A new NodeOperation representing the logical NOT of - this operation. - """ - return NodeOperation(self._node_operation.logical_not()) - - def __invert__(self) -> NodeOperation: - return self.logical_not() - - -class EdgeOperation: - _edge_operation: PyEdgeOperation - - def __init__(self, edge_operation: PyEdgeOperation) -> None: - self._edge_operation = edge_operation - - def logical_and(self, operation: EdgeOperation) -> EdgeOperation: - """Combines this EdgeOperation with another using a logical AND, resulting in a new EdgeOperation that is true only if both original operations are true. - - This method allows for the chaining of conditions to refine queries on nodes. - - Args: - operation (EdgeOperation): Another EdgeOperation to be combined with the - current one. - - Returns: - EdgeOperation: A new EdgeOperation representing the logical AND of this - operation with another. - """ - return EdgeOperation( - self._edge_operation.logical_and(operation._edge_operation) - ) - - def __and__(self, operation: EdgeOperation) -> EdgeOperation: - return self.logical_and(operation) - - def logical_or(self, operation: EdgeOperation) -> EdgeOperation: - """Combines this EdgeOperation with another using a logical OR, resulting in a new EdgeOperation that is true if either of the original operations is true. - - This method enables the combination of conditions to expand queries on nodes. - - Args: - operation (EdgeOperation): Another EdgeOperation to be combined with the - current one. - - Returns: - EdgeOperation: A new EdgeOperation representing the logical OR of this - operation with another. - """ - return EdgeOperation(self._edge_operation.logical_or(operation._edge_operation)) - - def __or__(self, operation: EdgeOperation) -> EdgeOperation: - return self.logical_or(operation) - - def logical_xor(self, operation: EdgeOperation) -> EdgeOperation: - """Combines this EdgeOperation with another using a logical XOR, resulting in a new EdgeOperation that is true only if exactly one of the original operations is true. - - This method is useful for creating conditions that must be - exclusively true. - - Args: - operation (EdgeOperation): Another EdgeOperation to be combined with the - current one. - - Returns: - EdgeOperation: A new EdgeOperation representing the logical XOR of this - operation with another. - """ - return EdgeOperation( - self._edge_operation.logical_xor(operation._edge_operation) - ) - - def __xor__(self, operation: EdgeOperation) -> EdgeOperation: - return self.logical_xor(operation) - - def logical_not(self) -> EdgeOperation: - """Creates a new EdgeOperation that is the logical NOT of this operation, inversing the current condition. - - This method is useful for negating a condition - to create queries on nodes. - - Returns: - EdgeOperation: A new EdgeOperation representing the logical NOT of - this operation. - """ - return EdgeOperation(self._edge_operation.logical_not()) - - def __invert__(self) -> EdgeOperation: - return self.logical_not() - - -class NodeAttributeOperand: - _node_attribute_operand: PyNodeAttributeOperand - - def __init__(self, node_attribute_operand: PyNodeAttributeOperand) -> None: - self._node_attribute_operand = node_attribute_operand - - def greater( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is greater than the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - NodeOperation: A NodeOperation representing the greater-than comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.greater(operand._node_attribute_operand) - ) - - return NodeOperation(self._node_attribute_operand.greater(operand)) - - def __gt__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.greater(operand) - - def less(self, operand: Union[ValueOperand, NodeAttributeOperand]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is less than the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - NodeOperation: A NodeOperation representing the less-than comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.less(operand._node_attribute_operand) - ) - - return NodeOperation(self._node_attribute_operand.less(operand)) - - def __lt__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.less(operand) - - def greater_or_equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is greater than or equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - NodeOperation: A NodeOperation representing the - greater-than-or-equal-to comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.greater_or_equal( - operand._node_attribute_operand - ) - ) - - return NodeOperation(self._node_attribute_operand.greater_or_equal(operand)) - - def __ge__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.greater_or_equal(operand) - - def less_or_equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is less than or equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - NodeOperation: A NodeOperation representing the - less-than-or-equal-to comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.less_or_equal( - operand._node_attribute_operand - ) - ) - - return NodeOperation(self._node_attribute_operand.less_or_equal(operand)) - - def __le__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.less_or_equal(operand) - - def equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against.y - - Returns: - NodeOperation: A NodeOperation representing the equality comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.equal(operand._node_attribute_operand) - ) - - return NodeOperation(self._node_attribute_operand.equal(operand)) - - def __eq__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.equal(operand) - - def not_equal( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is not equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - NodeOperation: A NodeOperation representing the not-equal comparison. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.not_equal(operand._node_attribute_operand) - ) - - return NodeOperation(self._node_attribute_operand.not_equal(operand)) - - def __ne__( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - return self.not_equal(operand) - - def is_in(self, values: List[MedRecordValue]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is found within the specified list of values. - - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. - - Returns: - NodeOperation: A NodeOperation representing the is-in comparison. - """ - return NodeOperation(self._node_attribute_operand.is_in(values)) - - def not_in(self, values: List[MedRecordValue]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand is not found within the specified list of values. - - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. - - Returns: - NodeOperation: A NodeOperation representing the not-in comparison. - """ - return NodeOperation(self._node_attribute_operand.not_in(values)) - - def starts_with(self, operand: ValueOperand) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand starts with the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare - the starting sequence against. - - Returns: - NodeOperation: A NodeOperation representing the starts-with condition. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.starts_with( - operand._node_attribute_operand - ) - ) - - return NodeOperation(self._node_attribute_operand.starts_with(operand)) - - def ends_with( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand ends with the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare - the ending sequence against. - - Returns: - NodeOperation: A NodeOperation representing the ends-with condition. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.ends_with(operand._node_attribute_operand) - ) - - return NodeOperation(self._node_attribute_operand.ends_with(operand)) - - def contains( - self, operand: Union[ValueOperand, NodeAttributeOperand] - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the attribute represented by this operand contains the specified value or operand within it. - - Args: - operand (ValueOperand): The value or operand to check for containment. - - Returns: - NodeOperation: A NodeOperation representing the contains condition. - """ - if isinstance(operand, NodeAttributeOperand): - return NodeOperation( - self._node_attribute_operand.contains(operand._node_attribute_operand) - ) - - return NodeOperation(self._node_attribute_operand.contains(operand)) - - def add(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the sum of the attribute's value and the specified value. - - Args: - value (MedRecordValue): The value to add to the attribute's value. - - Returns: - ValueOperand: The result of the addition operation. - """ - return self._node_attribute_operand.add(value) - - def __add__(self, value: MedRecordValue) -> ValueOperand: - return self.add(value) - - def sub(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the difference between the attribute's value and the specified value. - - Args: - value (MedRecordValue): The value to subtract from the attribute's value. - - Returns: - ValueOperand: The result of the subtraction operation. - """ - return self._node_attribute_operand.sub(value) - - def __sub__(self, value: MedRecordValue) -> ValueOperand: - return self.sub(value) - - def mul(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the product of the attribute's value and the specified value. - - Args: - value (MedRecordValue): The value to multiply the attribute's value by. - - Returns: - ValueOperand: The result of the multiplication operation. - """ - return self._node_attribute_operand.mul(value) - - def __mul__(self, value: MedRecordValue) -> ValueOperand: - return self.mul(value) - - def div(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the division of the attribute's value by the specified value. - - Args: - value (MedRecordValue): The value to divide the attribute's value by. - - Returns: - ValueOperand: The result of the division operation. - """ - return self._node_attribute_operand.div(value) - - def __truediv__(self, value: MedRecordValue) -> ValueOperand: - return self.div(value) - - def pow(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the result of raising the attribute's value to the power of the specified value. - - Args: - value (MedRecordValue): The value to raise the attribute's value to. - - Returns: - ValueOperand: The result of the exponentiation operation. - """ - return self._node_attribute_operand.pow(value) - - def __pow__(self, value: MedRecordValue) -> ValueOperand: - return self.pow(value) - - def mod(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the remainder of dividing the attribute's value by the specified value. - - Args: - value (MedRecordValue): The value to divide the attribute's value by. - - Returns: - ValueOperand: The result of the modulo operation. - """ - return self._node_attribute_operand.mod(value) - - def __mod__(self, value: MedRecordValue) -> ValueOperand: - return self.mod(value) - - def round(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of rounding the attribute's value. - - Returns: - ValueOperand: The result of the rounding operation. - """ - return self._node_attribute_operand.round() - - def ceil(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the ceiling function to the attribute's value, effectively rounding it up to the nearest whole number. - - Returns: - ValueOperand: The result of the ceiling operation. - """ - return self._node_attribute_operand.ceil() - - def floor(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the floor function to the attribute's value, effectively rounding it down to the nearest whole number. - - Returns: - ValueOperand: The result of the floor operation. - """ - return self._node_attribute_operand.floor() - - def abs(self) -> ValueOperand: - """Creates a new ValueOperand representing the absolute value of the attribute's value. - - Returns: - ValueOperand: The absolute value of the attribute's value. - """ - return self._node_attribute_operand.abs() - - def sqrt(self) -> ValueOperand: - """Creates a new ValueOperand representing the square root of the attribute's value. - - Returns: - ValueOperand: The square root of the attribute's value. - """ - return self._node_attribute_operand.sqrt() - - def trim(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from both ends of the attribute's value. - - Returns: - ValueOperand: The attribute's value with leading and trailing - whitespace removed. - """ - return self._node_attribute_operand.trim() - - def trim_start(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the start (left side) of the attribute's value. - - Returns: - ValueOperand: The attribute's value with leading whitespace removed. - """ - return self._node_attribute_operand.trim_start() - - def trim_end(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the end (right side) of the attribute's value. - - Returns: - ValueOperand: The attribute's value with trailing whitespace removed. - """ - return self._node_attribute_operand.trim_end() - - def lowercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to lowercase. - - Returns: - ValueOperand: The attribute's value in lowercase letters. - """ - return self._node_attribute_operand.lowercase() - - def uppercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to uppercase. - - Returns: - ValueOperand: The attribute's value in uppercase letters. - """ - return self._node_attribute_operand.uppercase() - - def slice(self, start: int, end: int) -> ValueOperand: - """Creates a new ValueOperand representing the result of slicing the attribute's value using the specified start and end indices. - - Args: - start (int): The index at which to start the slice. - end (int): The index at which to end the slice. - - Returns: - ValueOperand: The attribute's value with the specified slice applied. - """ - return self._node_attribute_operand.slice(start, end) - - -class EdgeAttributeOperand: - _edge_attribute_operand: PyEdgeAttributeOperand - - def __init__(self, edge_attribute_operand: PyEdgeAttributeOperand) -> None: - self._edge_attribute_operand = edge_attribute_operand - - def greater( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is greater than the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the greater-than comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.greater(operand._edge_attribute_operand) - ) - - return EdgeOperation(self._edge_attribute_operand.greater(operand)) - - def __gt__( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - return self.greater(operand) - - def less(self, operand: Union[ValueOperand, EdgeAttributeOperand]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is less than the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the less-than comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.less(operand._edge_attribute_operand) - ) - - return EdgeOperation(self._edge_attribute_operand.less(operand)) - - def __lt__(self, operand: ValueOperand) -> EdgeOperation: - return self.less(operand) - - def greater_or_equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is greater than or equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the - greater-than-or-equal-to comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.greater_or_equal( - operand._edge_attribute_operand - ) - ) - - return EdgeOperation(self._edge_attribute_operand.greater_or_equal(operand)) - - def __ge__(self, operand: ValueOperand) -> EdgeOperation: - return self.greater_or_equal(operand) - - def less_or_equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is less than or equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the - less-than-or-equal-to comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.less_or_equal( - operand._edge_attribute_operand - ) - ) - - return EdgeOperation(self._edge_attribute_operand.less_or_equal(operand)) - - def __le__(self, operand: ValueOperand) -> EdgeOperation: - return self.less_or_equal(operand) - - def equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the equality comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.equal(operand._edge_attribute_operand) - ) - - return EdgeOperation(self._edge_attribute_operand.equal(operand)) - - def __eq__( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - return self.equal(operand) - - def not_equal( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is not equal to the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the not-equal comparison. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.not_equal(operand._edge_attribute_operand) - ) - - return EdgeOperation(self._edge_attribute_operand.not_equal(operand)) - - def __ne__( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - return self.not_equal(operand) - - def is_in(self, values: List[MedRecordValue]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is found within the specified list of values. - - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. - - Returns: - EdgeOperation: A EdgeOperation representing the is-in comparison. - """ - return EdgeOperation(self._edge_attribute_operand.is_in(values)) - - def not_in(self, values: List[MedRecordValue]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand is not found within the specified list of values. - - Args: - values (List[MedRecordValue]): The list of values to check the - attribute against. - - Returns: - EdgeOperation: A EdgeOperation representing the not-in comparison. - """ - return EdgeOperation(self._edge_attribute_operand.not_in(values)) - - def starts_with( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand starts with the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare - the starting sequence against. - - Returns: - EdgeOperation: A EdgeOperation representing the starts-with condition. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.starts_with( - operand._edge_attribute_operand - ) - ) - - return EdgeOperation(self._edge_attribute_operand.starts_with(operand)) - - def ends_with( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand ends with the specified value or operand. - - Args: - operand (ValueOperand): The value or operand to compare - the ending sequence against. - - Returns: - EdgeOperation: A EdgeOperation representing the ends-with condition. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.ends_with(operand._edge_attribute_operand) - ) - - return EdgeOperation(self._edge_attribute_operand.ends_with(operand)) - - def contains( - self, operand: Union[ValueOperand, EdgeAttributeOperand] - ) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the attribute represented by this operand contains the specified value or operand within it. - - Args: - operand (ValueOperand): The value or operand to check for containment. - - Returns: - EdgeOperation: A EdgeOperation representing the contains condition. - """ - if isinstance(operand, EdgeAttributeOperand): - return EdgeOperation( - self._edge_attribute_operand.contains(operand._edge_attribute_operand) - ) - - return EdgeOperation(self._edge_attribute_operand.contains(operand)) - - def add(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the sum of the attribute's value and the specified value. - - Args: - value (MedRecordValue): The value to add to the attribute's value. - - Returns: - ValueOperand: The result of the addition operation. - """ - return self._edge_attribute_operand.add(value) - - def __add__(self, value: MedRecordValue) -> ValueOperand: - return self.add(value) - - def sub(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the difference between the attribute's value and the specified value. - - Args: - value (MedRecordValue): The value to subtract from the attribute's value. - - Returns: - ValueOperand: The result of the subtraction operation. - """ - return self._edge_attribute_operand.sub(value) - - def __sub__(self, value: MedRecordValue) -> ValueOperand: - return self.sub(value) - - def mul(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the product of the attribute's value and the specified value. - - Args: - value (MedRecordValue): The value to multiply the attribute's value by. - - Returns: - ValueOperand: The result of the multiplication operation. - """ - return self._edge_attribute_operand.mul(value) - - def __mul__(self, value: MedRecordValue) -> ValueOperand: - return self.mul(value) - - def div(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the division of the attribute's value by the specified value. - - Args: - value (MedRecordValue): The value to divide the attribute's value by. - - Returns: - ValueOperand: The result of the division operation. - """ - return self._edge_attribute_operand.div(value) - - def __truediv__(self, value: MedRecordValue) -> ValueOperand: - return self.div(value) - - def pow(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the result of raising the attribute's value to the power of the specified value. - - Args: - value (MedRecordValue): The value to raise the attribute's value to. - - Returns: - ValueOperand: The result of the exponentiation operation. - """ - return self._edge_attribute_operand.pow(value) - - def __pow__(self, value: MedRecordValue) -> ValueOperand: - return self.pow(value) - - def mod(self, value: MedRecordValue) -> ValueOperand: - """Creates a new ValueOperand representing the remainder of dividing the attribute's value by the specified value. - - Args: - value (MedRecordValue): The value to divide the attribute's value by. - - Returns: - ValueOperand: The result of the modulo operation. - """ - return self._edge_attribute_operand.mod(value) - - def __mod__(self, value: MedRecordValue) -> ValueOperand: - return self.mod(value) - - def round(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of rounding the attribute's value. - - Returns: - ValueOperand: The result of the rounding operation. - """ - return self._edge_attribute_operand.round() - - def ceil(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the ceiling function to the attribute's value, effectively rounding it up to the nearest whole number. - - Returns: - ValueOperand: The result of the ceiling operation. - """ - return self._edge_attribute_operand.ceil() - - def floor(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of applying the floor function to the attribute's value, effectively rounding it down to the nearest whole number. - - Returns: - ValueOperand: The result of the floor operation. - """ - return self._edge_attribute_operand.floor() - - def abs(self) -> ValueOperand: - """Creates a new ValueOperand representing the absolute value of the attribute's value. - - Returns: - ValueOperand: The absolute value of the attribute's value. - """ - return self._edge_attribute_operand.abs() - - def sqrt(self) -> ValueOperand: - """Creates a new ValueOperand representing the square root of the attribute's value. - - Returns: - ValueOperand: The square root of the attribute's value. - """ - return self._edge_attribute_operand.sqrt() - - def trim(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from both ends of the attribute's value. - - Returns: - ValueOperand: The attribute's value with leading and trailing - whitespace removed. - """ - return self._edge_attribute_operand.trim() - - def trim_start(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the start (left side) of the attribute's value. - - Returns: - ValueOperand: The attribute's value with leading whitespace removed. - """ - return self._edge_attribute_operand.trim_start() - - def trim_end(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of trimming whitespace from the end (right side) of the attribute's value. - - Returns: - ValueOperand: The attribute's value with trailing whitespace removed. - """ - return self._edge_attribute_operand.trim_end() - - def lowercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to lowercase. - - Returns: - ValueOperand: The attribute's value in lowercase letters. - """ - return self._edge_attribute_operand.lowercase() - - def uppercase(self) -> ValueOperand: - """Creates a new ValueOperand representing the result of converting all characters in the attribute's value to uppercase. - - Returns: - ValueOperand: The attribute's value in uppercase letters. - """ - return self._edge_attribute_operand.uppercase() - - def slice(self, start: int, end: int) -> ValueOperand: - """Creates a new ValueOperand representing the result of slicing the attribute's value using the specified start and end indices. - - Args: - start (int): The index at which to start the slice. - end (int): The index at which to end the slice. - - Returns: - ValueOperand: The attribute's value with the specified slice applied. - """ - return self._edge_attribute_operand.slice(start, end) - - -class NodeIndexOperand: - _node_index_operand: PyNodeIndexOperand - - def __init__(self, node_index_operand: PyNodeIndexOperand) -> None: - self._node_index_operand = node_index_operand - - def greater(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is greater than the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the greater-than comparison. - """ - return NodeOperation(self._node_index_operand.greater(operand)) - - def __gt__(self, operand: NodeIndex) -> NodeOperation: - return self.greater(operand) - - def less(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is less than the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the less-than comparison. - """ - return NodeOperation(self._node_index_operand.less(operand)) - - def __lt__(self, operand: NodeIndex) -> NodeOperation: - return self.less(operand) - - def greater_or_equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is greater than or equal to the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the - greater-than-or-equal-to comparison. - """ - return NodeOperation(self._node_index_operand.greater_or_equal(operand)) - - def __ge__(self, operand: NodeIndex) -> NodeOperation: - return self.greater_or_equal(operand) - - def less_or_equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is less than or equal to the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the - less-than-or-equal-to comparison. - """ - return NodeOperation(self._node_index_operand.less_or_equal(operand)) - - def __le__(self, operand: NodeIndex) -> NodeOperation: - return self.less_or_equal(operand) - - def equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is equal to the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the equality comparison. - """ - return NodeOperation(self._node_index_operand.equal(operand)) - - def __eq__(self, operand: NodeIndex) -> NodeOperation: - return self.equal(operand) - - def not_equal(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is not equal to the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the not-equal comparison. - """ - return NodeOperation(self._node_index_operand.not_equal(operand)) - - def __ne__(self, operand: NodeIndex) -> NodeOperation: - return self.not_equal(operand) - - def is_in(self, values: List[NodeIndex]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is found within the list of indices. - - Args: - values (List[NodeIndex]): The list of indices to check the node index - against. - - Returns: - NodeOperation: A NodeOperation representing the is-in comparison. - """ - return NodeOperation(self._node_index_operand.is_in(values)) - - def not_in(self, values: List[NodeIndex]) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index is not found within the list of indices. - - Args: - values (List[NodeIndex]): The list of indices to check the node index - against. - - Returns: - NodeOperation: A NodeOperation representing the not-in comparison. - """ - return NodeOperation(self._node_index_operand.not_in(values)) - - def starts_with(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index starts with the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the starts-with condition. - """ - return NodeOperation(self._node_index_operand.starts_with(operand)) - - def ends_with(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index ends with the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the ends-with condition. - """ - return NodeOperation(self._node_index_operand.ends_with(operand)) - - def contains(self, operand: NodeIndex) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node index contains the specified index. - - Args: - operand (NodeIndex): The index to compare against. - - Returns: - NodeOperation: A NodeOperation representing the contains condition. - """ - return NodeOperation(self._node_index_operand.contains(operand)) - - -class EdgeIndexOperand: - _edge_index_operand: PyEdgeIndexOperand - - def __init__(self, edge_index_operand: PyEdgeIndexOperand) -> None: - self._edge_index_operand = edge_index_operand - - def greater(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is greater than the specified index. - - Args: - operand (EdgeIndex): The index to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the greater-than comparison. - """ - return EdgeOperation(self._edge_index_operand.greater(operand)) - - def __gt__(self, operand: EdgeIndex) -> EdgeOperation: - return self.greater(operand) - - def less(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is less than the specified index. - - Args: - operand (EdgeIndex): The index to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the less-than comparison. - """ - return EdgeOperation(self._edge_index_operand.less(operand)) - - def __lt__(self, operand: EdgeIndex) -> EdgeOperation: - return self.less(operand) - - def greater_or_equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is greater than or equal to the specified index. - - Args: - operand (EdgeIndex): The index to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the - greater-than-or-equal-to comparison. - """ - return EdgeOperation(self._edge_index_operand.greater_or_equal(operand)) - - def __ge__(self, operand: EdgeIndex) -> EdgeOperation: - return self.greater_or_equal(operand) - - def less_or_equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is less than or equal to the specified index. - - Args: - operand (EdgeIndex): The index to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the - less-than-or-equal-to comparison. - """ - return EdgeOperation(self._edge_index_operand.less_or_equal(operand)) - - def __le__(self, operand: EdgeIndex) -> EdgeOperation: - return self.less_or_equal(operand) - - def equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is equal to the specified index. - - Args: - operand (EdgeIndex): The index to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the equality comparison. - """ - return EdgeOperation(self._edge_index_operand.equal(operand)) - - def __eq__(self, operand: EdgeIndex) -> EdgeOperation: - return self.equal(operand) - - def not_equal(self, operand: EdgeIndex) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is not equal to the specified index. - - Args: - operand (EdgeIndex): The index to compare against. - - Returns: - EdgeOperation: A EdgeOperation representing the not-equal comparison. - """ - return EdgeOperation(self._edge_index_operand.not_equal(operand)) - - def __ne__(self, operand: EdgeIndex) -> EdgeOperation: - return self.not_equal(operand) - - def is_in(self, values: List[EdgeIndex]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is found within the list of indices. - - Args: - values (List[EdgeIndex]): The list of indices to check the edge index - against. - - Returns: - EdgeOperation: A EdgeOperation representing the is-in comparison. - """ - return EdgeOperation(self._edge_index_operand.is_in(values)) - - def not_in(self, values: List[EdgeIndex]) -> EdgeOperation: - """Creates a EdgeOperation that evaluates to true if the edge index is not found within the list of indices. - - Args: - values (List[EdgeIndex]): The list of indices to check the edge index - against. - - Returns: - EdgeOperation: A EdgeOperation representing the not-in comparison. - """ - return EdgeOperation(self._edge_index_operand.not_in(values)) - - -class NodeOperand: - _node_operand: PyNodeOperand - - def __init__(self) -> None: - self._node_operand = PyNodeOperand() - - def in_group(self, operand: Group) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node is part of the specified group. - - Args: - operand (Group): The group to check the node against. - - Returns: - NodeOperation: A NodeOperation indicating if the node is part of the - specified group. - """ - return NodeOperation(self._node_operand.in_group(operand)) - - def has_attribute(self, operand: MedRecordAttribute) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has the specified attribute. - - Args: - operand (MedRecordAttribute): The attribute to check on the node. - - Returns: - NodeOperation: A NodeOperation indicating if the node has the - specified attribute. - """ - return NodeOperation(self._node_operand.has_attribute(operand)) - - def has_outgoing_edge_with(self, operation: EdgeOperation) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has an outgoing edge that satisfies the specified EdgeOperation. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - outgoing edges. - - Returns: - NodeOperation: A NodeOperation indicating if the node has an - outgoing edge satisfying the specified operation. - """ - return NodeOperation( - self._node_operand.has_outgoing_edge_with(operation._edge_operation) - ) - - def has_incoming_edge_with(self, operation: EdgeOperation) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has an incoming edge that satisfies the specified EdgeOperation. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - incoming edges. - - Returns: - NodeOperation: A NodeOperation indicating if the node has an - incoming edge satisfying the specified operation. - """ - return NodeOperation( - self._node_operand.has_incoming_edge_with(operation._edge_operation) - ) - - def has_edge_with(self, operation: EdgeOperation) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has any edge (incoming or outgoing) that satisfies the specified EdgeOperation. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - edges connected to the node. - - Returns: - NodeOperation: A NodeOperation indicating if the node has any edge - satisfying the specified operation. - """ - return NodeOperation( - self._node_operand.has_edge_with(operation._edge_operation) - ) - - def has_neighbor_with( - self, operation: NodeOperation, *, directed: bool = True - ) -> NodeOperation: - """Creates a NodeOperation that evaluates to true if the node has a neighboring node that satisfies the specified NodeOperation. - - Args: - operation (NodeOperation): A NodeOperation to evaluate against - neighboring nodes. - directed (bool): Whether to consider edges as directed - - Returns: - NodeOperation: A NodeOperation indicating if the node has a neighboring node - satisfying the specified operation. - """ - if directed: - return NodeOperation( - self._node_operand.has_neighbor_with(operation._node_operation) - ) - else: - return NodeOperation( - self._node_operand.has_neighbor_undirected_with( - operation._node_operation - ) - ) - - def attribute(self, attribute: MedRecordAttribute) -> NodeAttributeOperand: - """Accesses an NodeAttributeOperand for the specified attribute, allowing for the creation of operations based on node attributes. - - Args: - attribute (MedRecordAttribute): The attribute of the node to perform - operations on. - - Returns: - NodeAttributeOperand: An operand that represents the specified node - attribute, enabling further operations such as comparisons and - arithmetic operations. - """ - return NodeAttributeOperand(self._node_operand.attribute(attribute)) - - def index(self) -> NodeIndexOperand: - """Accesses an NodeIndexOperand, allowing for the creation of operations based on the node index. - - Returns: - NodeIndexOperand: An operand that represents the specified node - index, enabling further operations such as comparisons and - arithmetic operations. - """ - return NodeIndexOperand(self._node_operand.index()) - - -def node() -> NodeOperand: - """Factory function to create and return a new NodeOperand instance. - - Returns: - NodeOperand: An instance of NodeOperand for constructing node-based operations. - """ - return NodeOperand() - - -class EdgeOperand: - _edge_operand: PyEdgeOperand - - def __init__(self) -> None: - self._edge_operand = PyEdgeOperand() - - def connected_target(self, operand: NodeIndex) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to a target node with the specified index. - - Args: - operand (NodeIndex): The index of the target node to check for a connection. - - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is connected to the - specified target node. - """ - return EdgeOperation(self._edge_operand.connected_target(operand)) - - def connected_source(self, operand: NodeIndex) -> EdgeOperation: - """Generates an EdgeOperation that evaluates to true if the edge originates from a source node with the given index. - - Args: - operand (NodeIndex): The index of the source node to check for a connection. - - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is connected from the - specified source node. - """ - return EdgeOperation(self._edge_operand.connected_source(operand)) - - def connected(self, operand: NodeIndex) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to or from a node with the specified index. - - Args: - operand (NodeIndex): The index of the node to check for a connection. - - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is connected to the - specified node. - """ - return EdgeOperation(self._edge_operand.connected(operand)) - - def in_group(self, operand: Group) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is part of the specified group. - - Args: - operand (Group): The group to check the edge against. - - Returns: - EdgeOperation: An EdgeOperation indicating if the edge is part of the - specified group. - """ - return EdgeOperation(self._edge_operand.in_group(operand)) - - def has_attribute(self, operand: MedRecordAttribute) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge has the specified attribute. - - Args: - operand (MedRecordAttribute): The attribute to check on the edge. - - Returns: - EdgeOperation: An EdgeOperation indicating if the edge has the - specified attribute. - """ - return EdgeOperation(self._edge_operand.has_attribute(operand)) - - def connected_source_with(self, operation: NodeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge originates from a source node that satisfies the specified NodeOperation. - - Args: - operation (NodeOperation): A NodeOperation to evaluate against the - source node. - - Returns: - EdgeOperation: An EdgeOperation indicating if the source node of the - edge satisfies the specified operation. - """ - return EdgeOperation( - self._edge_operand.connected_source_with(operation._node_operation) - ) - - def connected_target_with(self, operation: NodeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to a target node that satisfies the specified NodeOperation. - - Args: - operation (NodeOperation): A NodeOperation to evaluate against the - target node. - - Returns: - EdgeOperation: An EdgeOperation indicating if the target node of the - edge satisfies the specified operation. - """ - return EdgeOperation( - self._edge_operand.connected_target_with(operation._node_operation) - ) - - def connected_with(self, operation: NodeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if the edge is connected to or from a node that satisfies the specified NodeOperation. - - Args: - operation (NodeOperation): A NodeOperation to evaluate against the - connected node. - - Returns: - EdgeOperation: An EdgeOperation indicating if either the source or - target node of the edge satisfies the specified operation. - """ - return EdgeOperation( - self._edge_operand.connected_with(operation._node_operation) - ) - - def has_parallel_edges_with(self, operation: EdgeOperation) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if there are parallel edges that satisfy the specified EdgeOperation. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - parallel edges. - - Returns: - EdgeOperation: An EdgeOperation indicating if there are parallel edges - satisfying the specified operation. - """ - return EdgeOperation( - self._edge_operand.has_parallel_edges_with(operation._edge_operation) - ) - - def has_parallel_edges_with_self_comparison( - self, operation: EdgeOperation - ) -> EdgeOperation: - """Creates an EdgeOperation that evaluates to true if there are parallel edges that satisfy the specified EdgeOperation. - - Using `edge().attribute(...)` in the operation will compare to the attribute of - this edge, not the parallel edge. - - Args: - operation (EdgeOperation): An EdgeOperation to evaluate against - parallel edges. - - Returns: - EdgeOperation: An EdgeOperation indicating if there are parallel edges - satisfying the specified operation. - """ - return EdgeOperation( - self._edge_operand.has_parallel_edges_with_self_comparison( - operation._edge_operation - ) - ) - - def attribute(self, attribute: MedRecordAttribute) -> EdgeAttributeOperand: - """Accesses an EdgeAttributeOperand for the specified attribute, allowing for the creation of operations based on edge attributes. - - Args: - attribute (MedRecordAttribute): The attribute of the edge to perform - operations on. - - Returns: - EdgeAttributeOperand: An operand that represents the specified edge - attribute, enabling further operations such as comparisons and - arithmetic operations. - """ - return EdgeAttributeOperand(self._edge_operand.attribute(attribute)) - - def index(self) -> EdgeIndexOperand: - """Accesses an EdgeIndexOperand, allowing for the creation of operations based on the edge index. - - Returns: - EdgeIndexOperand: An operand that represents the specified edge - index, enabling further operations such as comparisons and - arithmetic operations. - """ - return EdgeIndexOperand(self._edge_operand.index()) - - -def edge() -> EdgeOperand: - """Factory function to create and return a new EdgeOperand instance. - - Returns: - EdgeOperand: An instance of EdgeOperand for constructing edge-based operations. - """ - return EdgeOperand() diff --git a/medmodels/medrecord/querying.pyi b/medmodels/medrecord/querying.pyi new file mode 100644 index 00000000..9b3a7c07 --- /dev/null +++ b/medmodels/medrecord/querying.pyi @@ -0,0 +1,181 @@ +from __future__ import annotations + +import sys +from enum import Enum, auto +from typing import Callable, List, Union + +from medmodels.medrecord.types import Group, MedRecordAttribute, MedRecordValue + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + +NodeQuery: TypeAlias = Callable[[NodeOperand], None] +EdgeQuery: TypeAlias = Callable[[EdgeOperand], None] + +ValueOperand: TypeAlias = Union[NodeValueOperand, EdgeValueOperand, MedRecordValue] +ValuesOperand: TypeAlias = Union[ + NodeValuesOperand, EdgeValuesOperand, List[MedRecordValue] +] +ComparisonOperand: TypeAlias = Union[ValueOperand, ValuesOperand] + +class EdgeDirection(Enum): + INCOMING = auto() + OUTGOING = auto() + BOTH = auto() + +class NodeOperand: + def attribute(self, attribute: MedRecordAttribute) -> NodeValuesOperand: ... + def index(self) -> NodeValuesOperand: ... + def in_group(self, group: Union[Group, List[Group]]) -> None: ... + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: ... + def incoming_edges(self) -> EdgeOperand: ... + def outgoing_edges(self) -> EdgeOperand: ... + def neighbors( + self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING + ) -> NodeOperand: ... + def either_or(self, either: NodeQuery, or_: NodeQuery) -> None: ... + +class EdgeOperand: + def attribute(self, attribute: MedRecordAttribute) -> EdgeValuesOperand: ... + def index(self) -> EdgeValuesOperand: ... + def in_group(self, group: Union[Group, List[Group]]) -> None: ... + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: ... + def source_node(self) -> NodeOperand: ... + def target_node(self) -> NodeOperand: ... + def either_or(self, either: EdgeQuery, or_: EdgeQuery) -> None: ... + +class NodeValuesOperand: + def max(self) -> NodeValueOperand: ... + def min(self) -> NodeValueOperand: ... + def mean(self) -> NodeValueOperand: ... + def all(self) -> NodeValueOperand: ... + def any(self) -> NodeValueOperand: ... + def greater_than(self, value: ComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... + def less_than(self, value: ComparisonOperand) -> None: ... + def less_than_or_equal(self, value: ComparisonOperand) -> None: ... + def equals(self, value: ComparisonOperand) -> None: ... + def not_equals(self, value: ComparisonOperand) -> None: ... + def is_in(self, values: ValuesOperand) -> None: ... + def is_not_in(self, values: ValuesOperand) -> None: ... + def starts_with(self, value: ComparisonOperand) -> None: ... + def ends_with(self, value: ComparisonOperand) -> None: ... + def contains(self, value: ComparisonOperand) -> None: ... + def add(self, value: ComparisonOperand) -> NodeValuesOperand: ... + def subtract(self, value: ComparisonOperand) -> NodeValuesOperand: ... + def multiply(self, value: ComparisonOperand) -> NodeValuesOperand: ... + def divide(self, value: ComparisonOperand) -> NodeValuesOperand: ... + def modulo(self, value: ComparisonOperand) -> NodeValuesOperand: ... + def power(self, value: ComparisonOperand) -> NodeValuesOperand: ... + def round(self) -> NodeValuesOperand: ... + def ceil(self) -> NodeValuesOperand: ... + def floor(self) -> NodeValuesOperand: ... + def absolute(self) -> NodeValuesOperand: ... + def sqrt(self) -> NodeValuesOperand: ... + def trim(self) -> NodeValuesOperand: ... + def trim_start(self) -> NodeValuesOperand: ... + def trim_end(self) -> NodeValuesOperand: ... + def lowercase(self) -> NodeValuesOperand: ... + def uppercase(self) -> NodeValuesOperand: ... + def slice(self, start: int, end: int) -> NodeValuesOperand: ... + +class EdgeValuesOperand: + def max(self) -> EdgeValueOperand: ... + def min(self) -> EdgeValueOperand: ... + def mean(self) -> EdgeValueOperand: ... + def all(self) -> EdgeValueOperand: ... + def any(self) -> EdgeValueOperand: ... + def greater_than(self, value: ComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... + def less_than(self, value: ComparisonOperand) -> None: ... + def less_than_or_equal(self, value: ComparisonOperand) -> None: ... + def equals(self, value: ComparisonOperand) -> None: ... + def not_equals(self, value: ComparisonOperand) -> None: ... + def is_in(self, values: ValuesOperand) -> None: ... + def is_not_in(self, values: ValuesOperand) -> None: ... + def starts_with(self, value: ComparisonOperand) -> None: ... + def ends_with(self, value: ComparisonOperand) -> None: ... + def contains(self, value: ComparisonOperand) -> None: ... + def add(self, value: ComparisonOperand) -> EdgeValuesOperand: ... + def subtract(self, value: ComparisonOperand) -> EdgeValuesOperand: ... + def multiply(self, value: ComparisonOperand) -> EdgeValuesOperand: ... + def divide(self, value: ComparisonOperand) -> EdgeValuesOperand: ... + def modulo(self, value: ComparisonOperand) -> EdgeValuesOperand: ... + def power(self, value: ComparisonOperand) -> EdgeValuesOperand: ... + def round(self) -> EdgeValuesOperand: ... + def ceil(self) -> EdgeValuesOperand: ... + def floor(self) -> EdgeValuesOperand: ... + def absolute(self) -> EdgeValuesOperand: ... + def sqrt(self) -> EdgeValuesOperand: ... + def trim(self) -> EdgeValuesOperand: ... + def trim_start(self) -> EdgeValuesOperand: ... + def trim_end(self) -> EdgeValuesOperand: ... + def lowercase(self) -> EdgeValuesOperand: ... + def uppercase(self) -> EdgeValuesOperand: ... + def slice(self, start: int, end: int) -> EdgeValuesOperand: ... + +class NodeValueOperand: + def greater_than(self, value: ComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... + def less_than(self, value: ComparisonOperand) -> None: ... + def less_than_or_equal(self, value: ComparisonOperand) -> None: ... + def equals(self, value: ComparisonOperand) -> None: ... + def not_equals(self, value: ComparisonOperand) -> None: ... + def is_in(self, values: ValuesOperand) -> None: ... + def is_not_in(self, values: ValuesOperand) -> None: ... + def starts_with(self, value: ComparisonOperand) -> None: ... + def ends_with(self, value: ComparisonOperand) -> None: ... + def contains(self, value: ComparisonOperand) -> None: ... + def add(self, value: ComparisonOperand) -> NodeValueOperand: ... + def subtract(self, value: ComparisonOperand) -> NodeValueOperand: ... + def multiply(self, value: ComparisonOperand) -> NodeValueOperand: ... + def divide(self, value: ComparisonOperand) -> NodeValueOperand: ... + def modulo(self, value: ComparisonOperand) -> NodeValueOperand: ... + def power(self, value: ComparisonOperand) -> NodeValueOperand: ... + def round(self) -> NodeValueOperand: ... + def ceil(self) -> NodeValueOperand: ... + def floor(self) -> NodeValueOperand: ... + def absolute(self) -> NodeValueOperand: ... + def sqrt(self) -> NodeValueOperand: ... + def trim(self) -> NodeValueOperand: ... + def trim_start(self) -> NodeValueOperand: ... + def trim_end(self) -> NodeValueOperand: ... + def lowercase(self) -> NodeValueOperand: ... + def uppercase(self) -> NodeValueOperand: ... + def slice(self, start: int, end: int) -> NodeValueOperand: ... + +class EdgeValueOperand: + def greater_than(self, value: ComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... + def less_than(self, value: ComparisonOperand) -> None: ... + def less_than_or_equal(self, value: ComparisonOperand) -> None: ... + def equals(self, value: ComparisonOperand) -> None: ... + def not_equals(self, value: ComparisonOperand) -> None: ... + def is_in(self, values: ValuesOperand) -> None: ... + def is_not_in(self, values: ValuesOperand) -> None: ... + def starts_with(self, value: ComparisonOperand) -> None: ... + def ends_with(self, value: ComparisonOperand) -> None: ... + def contains(self, value: ComparisonOperand) -> None: ... + def add(self, value: ComparisonOperand) -> EdgeValueOperand: ... + def subtract(self, value: ComparisonOperand) -> EdgeValueOperand: ... + def multiply(self, value: ComparisonOperand) -> EdgeValueOperand: ... + def divide(self, value: ComparisonOperand) -> EdgeValueOperand: ... + def modulo(self, value: ComparisonOperand) -> EdgeValueOperand: ... + def power(self, value: ComparisonOperand) -> EdgeValueOperand: ... + def round(self) -> EdgeValueOperand: ... + def ceil(self) -> EdgeValueOperand: ... + def floor(self) -> EdgeValueOperand: ... + def absolute(self) -> EdgeValueOperand: ... + def sqrt(self) -> EdgeValueOperand: ... + def trim(self) -> EdgeValueOperand: ... + def trim_start(self) -> EdgeValueOperand: ... + def trim_end(self) -> EdgeValueOperand: ... + def lowercase(self) -> EdgeValueOperand: ... + def uppercase(self) -> EdgeValueOperand: ... + def slice(self, start: int, end: int) -> EdgeValueOperand: ... diff --git a/medmodels/medrecord/tests/test_indexers.py b/medmodels/medrecord/tests/test_indexers.py index 17492c6d..6616b866 100644 --- a/medmodels/medrecord/tests/test_indexers.py +++ b/medmodels/medrecord/tests/test_indexers.py @@ -1,7 +1,7 @@ import unittest from medmodels import MedRecord -from medmodels.medrecord import edge, node +from medmodels.medrecord.querying import EdgeOperand, NodeOperand def create_medrecord(): @@ -21,6 +21,30 @@ def create_medrecord(): ) +def node_greater_than_or_equal_two(node: NodeOperand): + node.index().greater_than_or_equal(2) + + +def node_greater_than_three(node: NodeOperand): + node.index().greater_than(3) + + +def node_less_than_two(node: NodeOperand): + node.index().less_than(2) + + +def edge_greater_than_or_equal_two(edge: EdgeOperand): + edge.index().greater_than_or_equal(2) + + +def edge_greater_than_three(edge: EdgeOperand): + edge.index().greater_than(3) + + +def edge_less_than_two(edge: EdgeOperand): + edge.index().less_than(2) + + class TestMedRecord(unittest.TestCase): def test_node_getitem(self): medrecord = create_medrecord() @@ -118,54 +142,54 @@ def test_node_getitem(self): self.assertEqual( {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.node[node().index() >= 2], + medrecord.node[node_greater_than_or_equal_two], ) # Empty query should not fail self.assertEqual( {}, - medrecord.node[node().index() > 3], + medrecord.node[node_greater_than_three], ) self.assertEqual( {2: "bar", 3: "bar"}, - medrecord.node[node().index() >= 2, "foo"], + medrecord.node[node_greater_than_or_equal_two, "foo"], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.node[node().index() >= 2, "test"] + medrecord.node[node_greater_than_or_equal_two, "test"] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.node[node().index() >= 2, ["foo", "bar"]], + medrecord.node[node_greater_than_or_equal_two, ["foo", "bar"]], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.node[node().index() >= 2, ["foo", "test"]] + medrecord.node[node_greater_than_or_equal_two, ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail with self.assertRaises(KeyError): - medrecord.node[node().index() < 2, ["foo", "lorem"]] + medrecord.node[node_less_than_two, ["foo", "lorem"]] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.node[node().index() >= 2, :], + medrecord.node[node_greater_than_or_equal_two, :], ) with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, 1:] + medrecord.node[node_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, :1] + medrecord.node[node_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, ::1] + medrecord.node[node_greater_than_or_equal_two, ::1] self.assertEqual( { @@ -360,7 +384,7 @@ def test_node_setitem(self): medrecord.node[[0, 1], ::1] = "test" medrecord = create_medrecord() - medrecord.node[node().index() >= 2] = {"foo": "bar", "bar": "test"} + medrecord.node[node_greater_than_or_equal_two] = {"foo": "bar", "bar": "test"} self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -373,10 +397,10 @@ def test_node_setitem(self): medrecord = create_medrecord() # Empty query should not fail - medrecord.node[node().index() > 3] = {"foo": "bar", "bar": "test"} + medrecord.node[node_greater_than_three] = {"foo": "bar", "bar": "test"} medrecord = create_medrecord() - medrecord.node[node().index() >= 2, "foo"] = "test" + medrecord.node[node_greater_than_or_equal_two, "foo"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -388,7 +412,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, ["foo", "bar"]] = "test" + medrecord.node[node_greater_than_or_equal_two, ["foo", "bar"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -400,7 +424,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, :] = "test" + medrecord.node[node_greater_than_or_equal_two, :] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -412,11 +436,11 @@ def test_node_setitem(self): ) with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, 1:] = "test" + medrecord.node[node_greater_than_or_equal_two, 1:] = "test" with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, :1] = "test" + medrecord.node[node_greater_than_or_equal_two, :1] = "test" with self.assertRaises(ValueError): - medrecord.node[node().index() >= 2, ::1] = "test" + medrecord.node[node_greater_than_or_equal_two, ::1] = "test" medrecord = create_medrecord() medrecord.node[:, "foo"] = "test" @@ -544,7 +568,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, "test"] = "test" + medrecord.node[node_greater_than_or_equal_two, "test"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -556,7 +580,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() >= 2, ["test", "test2"]] = "test" + medrecord.node[node_greater_than_or_equal_two, ["test", "test2"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -634,7 +658,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() < 2, "lorem"] = "test" + medrecord.node[node_less_than_two, "lorem"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, @@ -646,7 +670,7 @@ def test_node_setitem(self): ) medrecord = create_medrecord() - medrecord.node[node().index() < 2, ["lorem", "test"]] = "test" + medrecord.node[node_less_than_two, ["lorem", "test"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, @@ -804,7 +828,7 @@ def test_node_delitem(self): del medrecord.node[[0, 1], ::1] medrecord = create_medrecord() - del medrecord.node[node().index() >= 2, "foo"] + del medrecord.node[node_greater_than_or_equal_two, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -817,7 +841,7 @@ def test_node_delitem(self): medrecord = create_medrecord() # Empty query should not fail - del medrecord.node[node().index() > 3, "foo"] + del medrecord.node[node_greater_than_three, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -831,10 +855,10 @@ def test_node_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.node[node().index() >= 2, "test"] + del medrecord.node[node_greater_than_or_equal_two, "test"] medrecord = create_medrecord() - del medrecord.node[node().index() >= 2, ["foo", "bar"]] + del medrecord.node[node_greater_than_or_equal_two, ["foo", "bar"]] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -848,15 +872,15 @@ def test_node_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.node[node().index() >= 2, ["foo", "test"]] + del medrecord.node[node_greater_than_or_equal_two, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail with self.assertRaises(KeyError): - del medrecord.node[node().index() < 2, ["foo", "lorem"]] + del medrecord.node[node_less_than_two, ["foo", "lorem"]] medrecord = create_medrecord() - del medrecord.node[node().index() >= 2, :] + del medrecord.node[node_greater_than_or_equal_two, :] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -868,11 +892,11 @@ def test_node_delitem(self): ) with self.assertRaises(ValueError): - del medrecord.node[node().index() >= 2, 1:] + del medrecord.node[node_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - del medrecord.node[node().index() >= 2, :1] + del medrecord.node[node_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - del medrecord.node[node().index() >= 2, ::1] + del medrecord.node[node_greater_than_or_equal_two, ::1] medrecord = create_medrecord() del medrecord.node[:, "foo"] @@ -1048,54 +1072,54 @@ def test_edge_getitem(self): self.assertEqual( {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.edge[edge().index() >= 2], + medrecord.edge[edge_greater_than_or_equal_two], ) # Empty query should not fail self.assertEqual( {}, - medrecord.edge[edge().index() > 3], + medrecord.edge[edge_greater_than_three], ) self.assertEqual( {2: "bar", 3: "bar"}, - medrecord.edge[edge().index() >= 2, "foo"], + medrecord.edge[edge_greater_than_or_equal_two, "foo"], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.edge[edge().index() >= 2, "test"] + medrecord.edge[edge_greater_than_or_equal_two, "test"] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.edge[edge().index() >= 2, ["foo", "bar"]], + medrecord.edge[edge_greater_than_or_equal_two, ["foo", "bar"]], ) # Accessing a non-existing key should fail with self.assertRaises(KeyError): - medrecord.edge[edge().index() >= 2, ["foo", "test"]] + medrecord.edge[edge_greater_than_or_equal_two, ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail with self.assertRaises(KeyError): - medrecord.edge[edge().index() < 2, ["foo", "lorem"]] + medrecord.edge[edge_less_than_two, ["foo", "lorem"]] self.assertEqual( { 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}, }, - medrecord.edge[edge().index() >= 2, :], + medrecord.edge[edge_greater_than_or_equal_two, :], ) with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, 1:] + medrecord.edge[edge_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, :1] + medrecord.edge[edge_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, ::1] + medrecord.edge[edge_greater_than_or_equal_two, ::1] self.assertEqual( { @@ -1290,7 +1314,7 @@ def test_edge_setitem(self): medrecord.edge[[0, 1], ::1] = "test" medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2] = {"foo": "bar", "bar": "test"} + medrecord.edge[edge_greater_than_or_equal_two] = {"foo": "bar", "bar": "test"} self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1303,10 +1327,10 @@ def test_edge_setitem(self): medrecord = create_medrecord() # Empty query should not fail - medrecord.edge[edge().index() > 3] = {"foo": "bar", "bar": "test"} + medrecord.edge[edge_greater_than_three] = {"foo": "bar", "bar": "test"} medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, "foo"] = "test" + medrecord.edge[edge_greater_than_or_equal_two, "foo"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1318,7 +1342,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, ["foo", "bar"]] = "test" + medrecord.edge[edge_greater_than_or_equal_two, ["foo", "bar"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1330,7 +1354,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, :] = "test" + medrecord.edge[edge_greater_than_or_equal_two, :] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1342,11 +1366,11 @@ def test_edge_setitem(self): ) with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, 1:] = "test" + medrecord.edge[edge_greater_than_or_equal_two, 1:] = "test" with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, :1] = "test" + medrecord.edge[edge_greater_than_or_equal_two, :1] = "test" with self.assertRaises(ValueError): - medrecord.edge[edge().index() >= 2, ::1] = "test" + medrecord.edge[edge_greater_than_or_equal_two, ::1] = "test" medrecord = create_medrecord() medrecord.edge[:, "foo"] = "test" @@ -1474,7 +1498,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, "test"] = "test" + medrecord.edge[edge_greater_than_or_equal_two, "test"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1486,7 +1510,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() >= 2, ["test", "test2"]] = "test" + medrecord.edge[edge_greater_than_or_equal_two, ["test", "test2"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1564,7 +1588,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() < 2, "lorem"] = "test" + medrecord.edge[edge_less_than_two, "lorem"] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, @@ -1576,7 +1600,7 @@ def test_edge_setitem(self): ) medrecord = create_medrecord() - medrecord.edge[edge().index() < 2, ["lorem", "test"]] = "test" + medrecord.edge[edge_less_than_two, ["lorem", "test"]] = "test" self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, @@ -1734,7 +1758,7 @@ def test_edge_delitem(self): del medrecord.edge[[0, 1], ::1] medrecord = create_medrecord() - del medrecord.edge[edge().index() >= 2, "foo"] + del medrecord.edge[edge_greater_than_or_equal_two, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1747,7 +1771,7 @@ def test_edge_delitem(self): medrecord = create_medrecord() # Empty query should not fail - del medrecord.edge[edge().index() > 3, "foo"] + del medrecord.edge[edge_greater_than_three, "foo"] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1761,10 +1785,10 @@ def test_edge_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.edge[edge().index() >= 2, "test"] + del medrecord.edge[edge_greater_than_or_equal_two, "test"] medrecord = create_medrecord() - del medrecord.edge[edge().index() >= 2, ["foo", "bar"]] + del medrecord.edge[edge_greater_than_or_equal_two, ["foo", "bar"]] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1778,15 +1802,15 @@ def test_edge_delitem(self): medrecord = create_medrecord() # Removing a non-existing key should fail with self.assertRaises(KeyError): - del medrecord.edge[edge().index() >= 2, ["foo", "test"]] + del medrecord.edge[edge_greater_than_or_equal_two, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail with self.assertRaises(KeyError): - del medrecord.edge[edge().index() < 2, ["foo", "lorem"]] + del medrecord.edge[edge_less_than_two, ["foo", "lorem"]] medrecord = create_medrecord() - del medrecord.edge[edge().index() >= 2, :] + del medrecord.edge[edge_greater_than_or_equal_two, :] self.assertEqual( { 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, @@ -1798,11 +1822,11 @@ def test_edge_delitem(self): ) with self.assertRaises(ValueError): - del medrecord.edge[edge().index() >= 2, 1:] + del medrecord.edge[edge_greater_than_or_equal_two, 1:] with self.assertRaises(ValueError): - del medrecord.edge[edge().index() >= 2, :1] + del medrecord.edge[edge_greater_than_or_equal_two, :1] with self.assertRaises(ValueError): - del medrecord.edge[edge().index() >= 2, ::1] + del medrecord.edge[edge_greater_than_or_equal_two, ::1] medrecord = create_medrecord() del medrecord.edge[:, "foo"] diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index bcecd58a..3a9a4e5c 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -7,8 +7,7 @@ import medmodels.medrecord as mr from medmodels import MedRecord -from medmodels.medrecord import edge as edge_select -from medmodels.medrecord import node as node_select +from medmodels.medrecord.querying import EdgeOperand, NodeOperand from medmodels.medrecord.types import Attributes, NodeIndex @@ -361,7 +360,10 @@ def test_outgoing_edges(self): {key: sorted(value) for (key, value) in edges.items()}, ) - edges = medrecord.outgoing_edges(node_select().index().is_in(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + + edges = medrecord.outgoing_edges(query) self.assertEqual( {"0": sorted([0, 3]), "1": [1, 2]}, @@ -390,7 +392,10 @@ def test_incoming_edges(self): self.assertEqual({"1": [0], "2": [2]}, edges) - edges = medrecord.incoming_edges(node_select().index().is_in(["1", "2"])) + def query(node: NodeOperand): + node.index().is_in(["1", "2"]) + + edges = medrecord.incoming_edges(query) self.assertEqual({"1": [0], "2": [2]}, edges) @@ -416,7 +421,10 @@ def test_edge_endpoints(self): self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) - endpoints = medrecord.edge_endpoints(edge_select().index().is_in([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + + endpoints = medrecord.edge_endpoints(query) self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) @@ -442,7 +450,10 @@ def test_edges_connecting(self): self.assertEqual([0], edges) - edges = medrecord.edges_connecting(node_select().index().is_in(["0", "1"]), "1") + def query1(node: NodeOperand): + node.index().is_in(["0", "1"]) + + edges = medrecord.edges_connecting(query1, "1") self.assertEqual([0], edges) @@ -450,7 +461,10 @@ def test_edges_connecting(self): self.assertEqual(sorted([0, 3]), sorted(edges)) - edges = medrecord.edges_connecting("0", node_select().index().is_in(["1", "3"])) + def query2(node: NodeOperand): + node.index().is_in(["1", "3"]) + + edges = medrecord.edges_connecting("0", query2) self.assertEqual(sorted([0, 3]), sorted(edges)) @@ -458,10 +472,13 @@ def test_edges_connecting(self): self.assertEqual(sorted([0, 2, 3]), sorted(edges)) - edges = medrecord.edges_connecting( - node_select().index().is_in(["0", "1"]), - node_select().index().is_in(["1", "2", "3"]), - ) + def query3(node: NodeOperand): + node.index().is_in(["0", "1"]) + + def query4(node: NodeOperand): + node.index().is_in(["1", "2", "3"]) + + edges = medrecord.edges_connecting(query3, query4) self.assertEqual(sorted([0, 2, 3]), sorted(edges)) @@ -490,7 +507,10 @@ def test_remove_nodes(self): self.assertEqual(4, medrecord.node_count()) - attributes = medrecord.remove_nodes(node_select().index().is_in(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + + attributes = medrecord.remove_nodes(query) self.assertEqual(2, medrecord.node_count()) self.assertEqual( @@ -799,7 +819,10 @@ def test_remove_edges(self): self.assertEqual(4, medrecord.edge_count()) - attributes = medrecord.remove_edges(edge_select().index().is_in([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + + attributes = medrecord.remove_edges(query) self.assertEqual(2, medrecord.edge_count()) self.assertEqual({0: create_edges()[0][2], 1: create_edges()[1][2]}, attributes) @@ -1095,10 +1118,16 @@ def test_add_group(self): self.assertEqual(sorted(["0", "1"]), sorted(nodes_and_edges["nodes"])) self.assertEqual(sorted([0, 1]), sorted(nodes_and_edges["edges"])) + def query1(node: NodeOperand): + node.index().is_in(["0", "1"]) + + def query2(edge: EdgeOperand): + edge.index().is_in([0, 1]) + medrecord.add_group( "3", - node_select().index().is_in(["0", "1"]), - edge_select().index().is_in([0, 1]), + query1, + query2, ) self.assertEqual(4, medrecord.group_count()) @@ -1131,9 +1160,12 @@ def test_invalid_add_group(self): with self.assertRaises(AssertionError): medrecord.add_group("0", ["1", "0"]) + def query(node: NodeOperand): + node.index().equals("0") + # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_group("0", node_select().index() == "0") + medrecord.add_group("0", query) def test_remove_groups(self): medrecord = create_medrecord() @@ -1171,7 +1203,10 @@ def test_add_nodes_to_group(self): sorted(medrecord.nodes_in_group("0")), ) - medrecord.add_nodes_to_group("0", node_select().index() == "3") + def query(node: NodeOperand): + node.index().equals("3") + + medrecord.add_nodes_to_group("0", query) self.assertEqual( sorted(["0", "1", "2", "3"]), @@ -1207,9 +1242,12 @@ def test_invalid_add_nodes_to_group(self): with self.assertRaises(AssertionError): medrecord.add_nodes_to_group("0", ["1", "0"]) + def query(node: NodeOperand): + node.index().equals("0") + # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_nodes_to_group("0", node_select().index() == "0") + medrecord.add_nodes_to_group("0", query) def test_add_edges_to_group(self): medrecord = create_medrecord() @@ -1229,7 +1267,10 @@ def test_add_edges_to_group(self): sorted(medrecord.edges_in_group("0")), ) - medrecord.add_edges_to_group("0", edge_select().index() == 3) + def query(edge: EdgeOperand): + edge.index().equals(3) + + medrecord.add_edges_to_group("0", query) self.assertEqual( sorted([0, 1, 2, 3]), @@ -1265,9 +1306,12 @@ def test_invalid_add_edges_to_group(self): with self.assertRaises(AssertionError): medrecord.add_edges_to_group("0", [1, 0]) + def query(edge: EdgeOperand): + edge.index().equals(0) + # Adding an edge to a group that already is in the group should fail with self.assertRaises(AssertionError): - medrecord.add_edges_to_group("0", edge_select().index() == 0) + medrecord.add_edges_to_group("0", query) def test_remove_nodes_from_group(self): medrecord = create_medrecord() @@ -1301,7 +1345,10 @@ def test_remove_nodes_from_group(self): sorted(medrecord.nodes_in_group("0")), ) - medrecord.remove_nodes_from_group("0", node_select().index().is_in(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + + medrecord.remove_nodes_from_group("0", query) self.assertEqual([], medrecord.nodes_in_group("0")) @@ -1318,9 +1365,12 @@ def test_invalid_remove_nodes_from_group(self): with self.assertRaises(IndexError): medrecord.remove_nodes_from_group("50", ["0", "1"]) + def query(node: NodeOperand): + node.index().equals("0") + # Removing a node from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_nodes_from_group("50", node_select().index() == "0") + medrecord.remove_nodes_from_group("50", query) # Removing a non-existing node from a group should fail with self.assertRaises(IndexError): @@ -1362,7 +1412,10 @@ def test_remove_edges_from_group(self): sorted(medrecord.edges_in_group("0")), ) - medrecord.remove_edges_from_group("0", edge_select().index().is_in([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + + medrecord.remove_edges_from_group("0", query) self.assertEqual([], medrecord.edges_in_group("0")) @@ -1379,9 +1432,12 @@ def test_invalid_remove_edges_from_group(self): with self.assertRaises(IndexError): medrecord.remove_edges_from_group("50", [0, 1]) + def query(edge: EdgeOperand): + edge.index().equals(0) + # Removing an edge from a non-existing group should fail with self.assertRaises(IndexError): - medrecord.remove_edges_from_group("50", edge_select().index() == 0) + medrecord.remove_edges_from_group("50", query) # Removing a non-existing edge from a group should fail with self.assertRaises(IndexError): @@ -1434,9 +1490,12 @@ def test_groups_of_node(self): self.assertEqual({"0": ["0"], "1": ["0"]}, medrecord.groups_of_node(["0", "1"])) + def query(node: NodeOperand): + node.index().is_in(["0", "1"]) + self.assertEqual( {"0": ["0"], "1": ["0"]}, - medrecord.groups_of_node(node_select().index().is_in(["0", "1"])), + medrecord.groups_of_node(query), ) def test_invalid_groups_of_node(self): @@ -1459,9 +1518,12 @@ def test_groups_of_edge(self): self.assertEqual({0: ["0"], 1: ["0"]}, medrecord.groups_of_edge([0, 1])) + def query(edge: EdgeOperand): + edge.index().is_in([0, 1]) + self.assertEqual( {0: ["0"], 1: ["0"]}, - medrecord.groups_of_edge(edge_select().index().is_in([0, 1])), + medrecord.groups_of_edge(query), ) def test_invalid_groups_of_edge(self): @@ -1545,7 +1607,10 @@ def test_neighbors(self): {key: sorted(value) for (key, value) in neighbors.items()}, ) - neighbors = medrecord.neighbors(node_select().index().is_in(["0", "1"])) + def query1(node: NodeOperand): + node.index().is_in(["0", "1"]) + + neighbors = medrecord.neighbors(query1) self.assertEqual( {"0": sorted(["1", "3"]), "1": ["0", "2"]}, @@ -1566,9 +1631,10 @@ def test_neighbors(self): {key: sorted(value) for (key, value) in neighbors.items()}, ) - neighbors = medrecord.neighbors( - node_select().index().is_in(["0", "1"]), directed=False - ) + def query2(node: NodeOperand): + node.index().is_in(["0", "1"]) + + neighbors = medrecord.neighbors(query2, directed=False) self.assertEqual( {"0": sorted(["1", "3"]), "1": ["0", "2"]}, diff --git a/medmodels/medrecord/tests/test_overview.py b/medmodels/medrecord/tests/test_overview.py index 3120433d..2357c59f 100644 --- a/medmodels/medrecord/tests/test_overview.py +++ b/medmodels/medrecord/tests/test_overview.py @@ -7,7 +7,7 @@ import medmodels as mm from medmodels.medrecord._overview import extract_attribute_summary, prettify_table -from medmodels.medrecord.querying import edge, node +from medmodels.medrecord.querying import EdgeOperand, NodeOperand def create_medrecord(): @@ -70,46 +70,48 @@ def test_extract_attribute_summary(self): # medrecord without schema medrecord = create_medrecord() + def query1(node: NodeOperand): + node.in_group("Stroke") + # No attributes - no_attributes = extract_attribute_summary( - medrecord.node[node().in_group("Stroke")] - ) + no_attributes = extract_attribute_summary(medrecord.node[query1]) self.assertDictEqual(no_attributes, {}) + def query2(node: NodeOperand): + node.in_group("Patients") + # numeric type - numeric_attribute = extract_attribute_summary( - medrecord.node[node().in_group("Patients")] - ) + numeric_attribute = extract_attribute_summary(medrecord.node[query2]) numeric_expected = {"age": {"min": 20, "max": 70, "mean": 40.0}} self.assertDictEqual(numeric_attribute, numeric_expected) + def query3(node: NodeOperand): + node.in_group("Medications") + # string attributes - str_attributes = extract_attribute_summary( - medrecord.node[node().in_group("Medications")] - ) + str_attributes = extract_attribute_summary(medrecord.node[query3]) self.assertDictEqual( str_attributes, {"ATC": {"values": "Values: B01AA03, B01AF01"}} ) + def query4(node: NodeOperand): + node.in_group("Aspirin") + # nan attribute - nan_attributes = extract_attribute_summary( - medrecord.node[node().in_group("Aspirin")] - ) + nan_attributes = extract_attribute_summary(medrecord.node[query4]) + self.assertDictEqual(nan_attributes, {"ATC": {"values": "-"}}) + def query5(edge: EdgeOperand): + edge.source_node().in_group("Medications") + edge.target_node().in_group("Patients") + # temporal attributes - temp_attributes = extract_attribute_summary( - medrecord.edge[ - medrecord.select_edges( - edge().connected_source_with(node().in_group("Medications")) - & edge().connected_target_with(node().in_group("Patients")) - ) - ] - ) + temp_attributes = extract_attribute_summary(medrecord.edge[query5]) self.assertDictEqual( temp_attributes, @@ -121,14 +123,13 @@ def test_extract_attribute_summary(self): }, ) + def query6(edge: EdgeOperand): + edge.source_node().in_group("Stroke") + edge.target_node().in_group("Patients") + # mixed attributes mixed_attributes = extract_attribute_summary( - medrecord.edge[ - medrecord.select_edges( - edge().connected_source_with(node().in_group("Stroke")) - & edge().connected_target_with(node().in_group("Patients")) - ) - ] + medrecord.edge[medrecord.select_edges(query6)] ) self.assertDictEqual( mixed_attributes, @@ -158,9 +159,12 @@ def test_extract_attribute_summary(self): }, ) + def query7(edge: EdgeOperand): + edge.in_group("patient_diagnosis") + # compare schema and not schema patient_diagnosis = extract_attribute_summary( - mr_schema.edge[edge().in_group("patient_diagnosis")], + mr_schema.edge[query7], schema=mr_schema.schema.group("patient_diagnosis").edges, ) diff --git a/medmodels/medrecord/tests/test_querying.py b/medmodels/medrecord/tests/test_querying.py deleted file mode 100644 index 808961ab..00000000 --- a/medmodels/medrecord/tests/test_querying.py +++ /dev/null @@ -1,1170 +0,0 @@ -import unittest -from typing import List, Tuple - -from medmodels import MedRecord -from medmodels.medrecord import edge, node -from medmodels.medrecord.types import Attributes, NodeIndex - - -def create_nodes() -> List[Tuple[NodeIndex, Attributes]]: - return [ - ( - "0", - { - "lorem": "ipsum", - "dolor": " ipsum ", - "test": "Ipsum", - "integer": 1, - "float": 0.5, - }, - ), - ("1", {"amet": "consectetur"}), - ("2", {"adipiscing": "elit"}), - ("3", {}), - ] - - -def create_edges() -> List[Tuple[NodeIndex, NodeIndex, Attributes]]: - return [ - ("0", "1", {"sed": "do", "eiusmod": "tempor", "dolor": " do ", "test": "DO"}), - ("1", "2", {"incididunt": "ut"}), - ("0", "2", {"test": 1, "integer": 1, "float": 0.5}), - ("0", "2", {"test": 0}), - ] - - -def create_medrecord() -> MedRecord: - return MedRecord.from_tuples(create_nodes(), create_edges()) - - -class TestMedRecord(unittest.TestCase): - def test_select_nodes_node(self): - medrecord = create_medrecord() - - medrecord.add_group("test", ["0"]) - - # Node in group - self.assertEqual(["0"], medrecord.select_nodes(node().in_group("test"))) - - # Node has attribute - self.assertEqual(["0"], medrecord.select_nodes(node().has_attribute("lorem"))) - - # Node has outgoing edge with - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().has_outgoing_edge_with(edge().index().equal(0)) - ), - ) - - # Node has incoming edge with - self.assertEqual( - ["1"], - medrecord.select_nodes( - node().has_incoming_edge_with(edge().index().equal(0)) - ), - ) - - # Node has edge with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes(node().has_edge_with(edge().index().equal(0))) - ), - ) - - # Node has neighbor with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("2")) - ) - ), - ) - self.assertEqual( - sorted(["0"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=True) - ) - ), - ) - - # Node has neighbor with - self.assertEqual( - sorted(["0", "2"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=False) - ) - ), - ) - - def test_select_nodes_node_index(self): - medrecord = create_medrecord() - - # Index greater - self.assertEqual( - sorted(["2", "3"]), - sorted(medrecord.select_nodes(node().index().greater("1"))), - ) - - # Index less - self.assertEqual( - sorted(["0", "1"]), sorted(medrecord.select_nodes(node().index().less("2"))) - ) - - # Index greater or equal - self.assertEqual( - sorted(["1", "2", "3"]), - sorted(medrecord.select_nodes(node().index().greater_or_equal("1"))), - ) - - # Index less or equal - self.assertEqual( - sorted(["0", "1", "2"]), - sorted(medrecord.select_nodes(node().index().less_or_equal("2"))), - ) - - # Index equal - self.assertEqual(["1"], medrecord.select_nodes(node().index().equal("1"))) - - # Index not equal - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_equal("1"))), - ) - - # Index in - self.assertEqual(["1"], medrecord.select_nodes(node().index().is_in(["1"]))) - - # Index not in - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_in(["1"]))), - ) - - # Index starts with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().starts_with("1")), - ) - - # Index ends with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().ends_with("1")), - ) - - # Index contains - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().contains("1")), - ) - - def test_select_nodes_node_attribute(self): - medrecord = create_medrecord() - - # Attribute greater - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").greater("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") > "ipsum") - ) - - # Attribute less - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").less("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") < "ipsum") - ) - - # Attribute greater or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") >= "ipsum") - ) - - # Attribute less or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").less_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") <= "ipsum") - ) - - # Attribute equal - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").equal("ipsum")) - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") == "ipsum") - ) - - # Attribute not equal - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") != "ipsum") - ) - - # Attribute in - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) - ) - - # Attribute not in - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) - ) - - # Attribute starts with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").starts_with("ip")), - ) - - # Attribute ends with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").ends_with("um")), - ) - - # Attribute contains - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").contains("su")), - ) - - # Attribute compare to attribute - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem")) - ), - ) - - # Attribute compare to attribute add - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") + "10" - ), - ) - - # Attribute compare to attribute sub - # Returns nothing because can't sub a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") - "10" - ), - ) - - # Attribute compare to attribute sub - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ), - ) - - # Attribute compare to attribute mul - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") * 2 - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") * 2 - ), - ) - - # Attribute compare to attribute div - # Returns nothing because can't div a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") / "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") / "10" - ), - ) - - # Attribute compare to attribute div - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").div(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ), - ) - - # Attribute compare to attribute pow - # Returns nothing because can't pow a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") ** "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") ** "10" - ), - ) - - # Attribute compare to attribute pow - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").pow(2)) - ), - ) - - # Attribute compare to attribute mod - # Returns nothing because can't mod a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") % "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") % "10" - ), - ) - - # Attribute compare to attribute mod - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").mod(2)) - ), - ) - - # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").round()) - ), - ) - - # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").ceil()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").ceil()) - ), - ) - - # Attribute compare to attribute floor - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").floor()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").floor()) - ), - ) - - # Attribute compare to attribute abs - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").not_equal(node().attribute("integer").abs()) - ), - ) - - # Attribute compare to attribute sqrt - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sqrt()) - ), - ) - - # Attribute compare to attribute trim - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim()) - ), - ) - - # Attribute compare to attribute trim_start - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim_start()) - ), - ) - - # Attribute compare to attribute trim_end - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim_end()) - ), - ) - - # Attribute compare to attribute lowercase - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").lowercase()) - ), - ) - - # Attribute compare to attribute uppercase - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").uppercase()) - ), - ) - - def test_select_edges_edge(self): - medrecord = create_medrecord() - - medrecord.add_group("test", edges=[0]) - - # Edge connected to target - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().connected_target("2"))), - ) - - # Edge connected to source - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().connected_source("0"))), - ) - - # Edge connected - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().connected("1"))), - ) - - # Edge in group - self.assertEqual( - [0], - medrecord.select_edges(edge().in_group("test")), - ) - - # Edge has attribute - self.assertEqual( - [0], - medrecord.select_edges(edge().has_attribute("sed")), - ) - - # Edge connected to target with - self.assertEqual( - [0], - medrecord.select_edges( - edge().connected_target_with(node().index().equal("1")) - ), - ) - - # Edge connected to source with - self.assertEqual( - sorted([0, 2, 3]), - sorted( - medrecord.select_edges( - edge().connected_source_with(node().index().equal("0")) - ) - ), - ) - - # Edge connected with - self.assertEqual( - sorted([0, 1]), - sorted( - medrecord.select_edges(edge().connected_with(node().index().equal("1"))) - ), - ) - - # Edge has parallel edges with - self.assertEqual( - sorted([2, 3]), - sorted( - medrecord.select_edges( - edge().has_parallel_edges_with(edge().has_attribute("test")) - ) - ), - ) - - # Edge has parallel edges with self comparison - self.assertEqual( - [2], - medrecord.select_edges( - edge().has_parallel_edges_with_self_comparison( - edge().attribute("test").equal(edge().attribute("test").sub(1)) - ) - ), - ) - - def test_select_edges_edge_index(self): - medrecord = create_medrecord() - - # Index greater - self.assertEqual( - sorted([2, 3]), - sorted(medrecord.select_edges(edge().index().greater(1))), - ) - - # Index less - self.assertEqual( - [0], - medrecord.select_edges(edge().index().less(1)), - ) - - # Index greater or equal - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().index().greater_or_equal(1))), - ) - - # Index less or equal - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().index().less_or_equal(1))), - ) - - # Index equal - self.assertEqual( - [1], - medrecord.select_edges(edge().index().equal(1)), - ) - - # Index not equal - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_equal(1))), - ) - - # Index in - self.assertEqual( - [1], - medrecord.select_edges(edge().index().is_in([1])), - ) - - # Index not in - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_in([1]))), - ) - - def test_select_edges_edges_attribute(self): - medrecord = create_medrecord() - - # Attribute greater - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").greater("do")), - ) - - # Attribute less - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").less("do")), - ) - - # Attribute greater or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").greater_or_equal("do")), - ) - - # Attribute less or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").less_or_equal("do")), - ) - - # Attribute equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").equal("do")), - ) - - # Attribute not equal - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_equal("do")), - ) - - # Attribute in - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").is_in(["do"])), - ) - - # Attribute not in - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_in(["do"])), - ) - - # Attribute starts with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").starts_with("d")), - ) - - # Attribute ends with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").ends_with("o")), - ) - - # Attribute contains - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").contains("d")), - ) - - # Attribute compare to attribute - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed")) - ), - ) - - # Attribute compare to attribute add - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") + "10" - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") + "10" - ), - ) - - # Attribute compare to attribute sub - # Returns nothing because can't sub a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") - "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") - "10" - ), - ) - - # Attribute compare to attribute sub - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ), - ) - - # Attribute compare to attribute mul - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") * 2 - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") * 2 - ), - ) - - # Attribute compare to attribute div - # Returns nothing because can't div a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") / "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") / "10" - ), - ) - - # Attribute compare to attribute div - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").div(2)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ), - ) - - # Attribute compare to attribute pow - # Returns nothing because can't pow a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").equal(edge().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") == edge().attribute("lorem") ** "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") != edge().attribute("lorem") ** "10" - ), - ) - - # Attribute compare to attribute pow - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").pow(2)) - ), - ) - - # Attribute compare to attribute mod - # Returns nothing because can't mod a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").equal(edge().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") == edge().attribute("lorem") % "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") != edge().attribute("lorem") % "10" - ), - ) - - # Attribute compare to attribute mod - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").mod(2)) - ), - ) - - # Attribute compare to attribute round - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").round()) - ), - ) - - # Attribute compare to attribute ceil - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").ceil()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").ceil()) - ), - ) - - # Attribute compare to attribute floor - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").floor()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").floor()) - ), - ) - - # Attribute compare to attribute abs - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").not_equal(edge().attribute("integer").abs()) - ), - ) - - # Attribute compare to attribute sqrt - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sqrt()) - ), - ) - - # Attribute compare to attribute trim - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim()) - ), - ) - - # Attribute compare to attribute trim_start - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim_start()) - ), - ) - - # Attribute compare to attribute trim_end - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim_end()) - ), - ) - - # Attribute compare to attribute lowercase - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").lowercase()) - ), - ) - - # Attribute compare to attribute uppercase - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").uppercase()) - ), - ) diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index 9a6e5be6..15ceb4c7 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional import medmodels.treatment_effect.treatment_effect as tee -from medmodels.medrecord.querying import NodeOperation +from medmodels.medrecord.querying import NodeQuery from medmodels.medrecord.types import ( Group, MedRecordAttribute, @@ -31,7 +31,7 @@ class TreatmentEffectBuilder: outcome_before_treatment_days: Optional[int] - filter_controls_operation: Optional[NodeOperation] + filter_controls_query: Optional[NodeQuery] matching_method: Optional[MatchingMethod] matching_essential_covariates: Optional[MedRecordAttributeInputList] @@ -202,17 +202,17 @@ def with_outcome_before_treatment_exclusion( return self - def filter_controls(self, operation: NodeOperation) -> TreatmentEffectBuilder: - """Filter the control group based on the provided operation. + def filter_controls(self, query: NodeQuery) -> TreatmentEffectBuilder: + """Filter the control group based on the provided query. Args: - operation (NodeOperation): The operation to be applied to the control group. + query (NodeQuery): The query to be applied to the control group. Returns: TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder with updated time attribute. """ - self.filter_controls_operation = operation + self.filter_controls_query = query return self diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index aca74858..78c56da4 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -6,7 +6,7 @@ import pandas as pd from medmodels import MedRecord -from medmodels.medrecord import edge, node +from medmodels.medrecord.querying import NodeOperand from medmodels.medrecord.types import NodeIndex from medmodels.treatment_effect.estimate import ContingencyTable, SubjectIndices from medmodels.treatment_effect.treatment_effect import TreatmentEffect @@ -245,8 +245,8 @@ def assert_treatment_effects_equal( treatment_effect2._outcome_before_treatment_days, ) test_case.assertEqual( - treatment_effect1._filter_controls_operation, - treatment_effect2._filter_controls_operation, + treatment_effect1._filter_controls_query, + treatment_effect2._filter_controls_query, ) test_case.assertEqual( treatment_effect1._matching_method, treatment_effect2._matching_method @@ -620,14 +620,14 @@ def test_outcome_before_treatment(self): tee3._find_outcomes(medrecord=self.medrecord, treated_group=treated_group) def test_filter_controls(self): + def query1(node: NodeOperand): + node.neighbors().index().equals("M2") + tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls( - node().has_outgoing_edge_with(edge().connected_target("M2")) - | node().has_incoming_edge_with(edge().connected_source("M2")) - ) + .filter_controls(query1) .build() ) counts_tee = tee.estimate._compute_subject_counts(self.medrecord) @@ -635,11 +635,15 @@ def test_filter_controls(self): self.assertEqual(counts_tee, (2, 1, 1, 2)) # filter females only + + def query2(node: NodeOperand): + node.attribute("gender").equals("female") + tee2 = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") .with_outcome("Stroke") - .filter_controls(node().attribute("gender").equal("female")) + .filter_controls(query2) .build() ) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 94a3980e..7a0d5aca 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -14,8 +14,7 @@ from typing import Any, Dict, Literal, Optional, Set, Tuple from medmodels import MedRecord -from medmodels.medrecord import node -from medmodels.medrecord.querying import NodeOperation +from medmodels.medrecord.querying import EdgeDirection, NodeOperand, NodeQuery from medmodels.medrecord.types import ( Group, MedRecordAttribute, @@ -50,7 +49,7 @@ class TreatmentEffect: _outcome_before_treatment_days: Optional[int] - _filter_controls_operation: Optional[NodeOperation] + _filter_controls_query: Optional[NodeQuery] _matching_method: Optional[MatchingMethod] _matching_essential_covariates: MedRecordAttributeInputList @@ -92,7 +91,7 @@ def _set_configuration( follow_up_period_days: int = 365, follow_up_period_reference: Literal["first", "last"] = "last", outcome_before_treatment_days: Optional[int] = None, - filter_controls_operation: Optional[NodeOperation] = None, + filter_controls_query: Optional[NodeQuery] = None, matching_method: Optional[MatchingMethod] = None, matching_essential_covariates: MedRecordAttributeInputList = ["gender", "age"], matching_one_hot_covariates: MedRecordAttributeInputList = ["gender"], @@ -127,8 +126,8 @@ def _set_configuration( reference point for the follow-up period. Defaults to "last". outcome_before_treatment_days (Optional[int], optional): The number of days before the treatment to consider for outcomes. Defaults to None. - filter_controls_operation (Optional[NodeOperation], optional): An optional - operation to filter the control group based on specified criteria. + filter_controls_query (Optional[NodeQuery], optional): An optional + query to filter the control group based on specified criteria. Defaults to None. matching_method (Optional[MatchingMethod]): The method to match treatment and control groups. Defaults to None. @@ -158,7 +157,7 @@ def _set_configuration( treatment_effect._follow_up_period_days = follow_up_period_days treatment_effect._follow_up_period_reference = follow_up_period_reference treatment_effect._outcome_before_treatment_days = outcome_before_treatment_days - treatment_effect._filter_controls_operation = filter_controls_operation + treatment_effect._filter_controls_query = filter_controls_query treatment_effect._matching_method = matching_method treatment_effect._matching_essential_covariates = matching_essential_covariates @@ -206,7 +205,7 @@ def _find_groups( control_group=control_group, treated_group=treated_group, rejected_nodes=washout_nodes | outcome_before_treatment_nodes, - filter_controls_operation=self._filter_controls_operation, + filter_controls_query=self._filter_controls_query, ) return ( @@ -234,18 +233,14 @@ def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]: treatments = medrecord.nodes_in_group(self._treatments_group) + def query(node: NodeOperand): + node.in_group(self._patients_group) + + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equals(treatment) + # Create the group with all the patients that underwent the treatment for treatment in treatments: - treated_group.update( - set( - medrecord.select_nodes( - node().in_group(self._patients_group) - & node().has_neighbor_with( - node().index() == treatment, directed=False - ) - ) - ) - ) + treated_group.update(set(medrecord.select_nodes(query))) if not treated_group: raise ValueError( "No patients found for the treatment groups in this MedRecord." @@ -288,14 +283,14 @@ def _find_outcomes( f"No outcomes found in the MedRecord for group {self._outcomes_group}" ) + def query(node: NodeOperand): + node.index().is_in(list(treated_group)) + + # This could probably be refactored to a proper query + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equals(outcome) + for outcome in outcomes: - nodes_to_check = set( - medrecord.select_nodes( - node().has_neighbor_with(node().index() == outcome, directed=False) - # This could probably be refactored to a proper query - & node().index().is_in(list(treated_group)) - ) - ) + nodes_to_check = set(medrecord.select_nodes(query)) # Find patients that had the outcome before the treatment if self._outcome_before_treatment_days: @@ -399,12 +394,12 @@ def _find_controls( control_group: Set[NodeIndex], treated_group: Set[NodeIndex], rejected_nodes: Set[NodeIndex] = set(), - filter_controls_operation: Optional[NodeOperation] = None, + filter_controls_query: Optional[NodeQuery] = None, ) -> Tuple[Set[NodeIndex], Set[NodeIndex]]: """Identifies control groups among patients who did not undergo the specified treatments. It takes the control group and removes the rejected nodes, the treated nodes, - and applies the filter_controls_operation if specified. + and applies the filter_controls_query if specified. Control groups are divided into those who had the outcome (control_outcome_true) and those who did not (control_outcome_false), @@ -419,8 +414,8 @@ def _find_controls( treatment. rejected_nodes (Set[NodeIndex]): A set of patient nodes that were rejected due to the washout period or outcome before treatment. - filter_controls_operation (Optional[NodeOperation], optional): An optional - operation to filter the control group based on specified criteria. + filter_controls_query (Optional[NodeQuery], optional): An optional + query to filter the control group based on specified criteria. Defaults to None. Returns: @@ -436,9 +431,9 @@ def _find_controls( outcome group. """ # Apply the filter to the control group if specified - if filter_controls_operation: + if filter_controls_query: control_group = ( - set(medrecord.select_nodes(filter_controls_operation)) & control_group + set(medrecord.select_nodes(filter_controls_query)) & control_group ) control_group = control_group - treated_group - rejected_nodes @@ -453,17 +448,15 @@ def _find_controls( f"No outcomes found in the MedRecord for group {self._outcomes_group}" ) + def query(node: NodeOperand): + node.index().is_in(list(control_group)) + + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equals(outcome) + # Finding the patients that had the outcome in the control group for outcome in outcomes: - control_outcome_true.update( - medrecord.select_nodes( - # This could probably be refactored to a proper query - node().index().is_in(list(control_group)) - & node().has_neighbor_with( - node().index() == outcome, directed=False - ) - ) - ) + control_outcome_true.update(medrecord.select_nodes(query)) + control_outcome_false = control_group - control_outcome_true return control_outcome_true, control_outcome_false From 04e469d2dc43dbea649ac327d3ad2d604d135db2 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:28:11 +0200 Subject: [PATCH 2/8] refactor: update interface (#212) --- medmodels/medrecord/querying.pyi | 391 ++++++++++++------ medmodels/medrecord/tests/test_medrecord.py | 14 +- .../tests/test_treatment_effect.py | 4 +- .../treatment_effect/treatment_effect.py | 8 +- 4 files changed, 270 insertions(+), 147 deletions(-) diff --git a/medmodels/medrecord/querying.pyi b/medmodels/medrecord/querying.pyi index 9b3a7c07..64e2a316 100644 --- a/medmodels/medrecord/querying.pyi +++ b/medmodels/medrecord/querying.pyi @@ -4,7 +4,12 @@ import sys from enum import Enum, auto from typing import Callable, List, Union -from medmodels.medrecord.types import Group, MedRecordAttribute, MedRecordValue +from medmodels.medrecord.types import ( + EdgeIndex, + Group, + MedRecordAttribute, + MedRecordValue, +) if sys.version_info >= (3, 10): from typing import TypeAlias @@ -14,11 +19,26 @@ else: NodeQuery: TypeAlias = Callable[[NodeOperand], None] EdgeQuery: TypeAlias = Callable[[EdgeOperand], None] -ValueOperand: TypeAlias = Union[NodeValueOperand, EdgeValueOperand, MedRecordValue] -ValuesOperand: TypeAlias = Union[ - NodeValuesOperand, EdgeValuesOperand, List[MedRecordValue] +ValueComparisonOperand: TypeAlias = Union[ + SingleValueOperand, MedRecordValue, MultipleValuesOperand, List[MedRecordValue] ] -ComparisonOperand: TypeAlias = Union[ValueOperand, ValuesOperand] +ValueArithmeticOperand: TypeAlias = Union[SingleValueOperand, MedRecordValue] +AttributeComparisonOperand: TypeAlias = Union[ + SingleAttributeOperand, + MedRecordAttribute, + MultipleAttributesOperand, + List[MedRecordAttribute], +] +AttributeArithmeticOperand: TypeAlias = Union[ + SingleAttributeOperand, MedRecordAttribute +] +EdgeIndexComparisonOperand: TypeAlias = Union[ + EdgeIndexOperand, + EdgeIndex, + EdgeIndicesOperand, + List[EdgeIndex], +] +EdgeIndexArithmeticOperand: TypeAlias = Union[EdgeIndexOperand, EdgeIndex] class EdgeDirection(Enum): INCOMING = auto() @@ -26,8 +46,9 @@ class EdgeDirection(Enum): BOTH = auto() class NodeOperand: - def attribute(self, attribute: MedRecordAttribute) -> NodeValuesOperand: ... - def index(self) -> NodeValuesOperand: ... + def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: ... + def attributes(self) -> MultipleAttributesOperand: ... + def index(self) -> NodeIndexOperand: ... def in_group(self, group: Union[Group, List[Group]]) -> None: ... def has_attribute( self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] @@ -40,8 +61,9 @@ class NodeOperand: def either_or(self, either: NodeQuery, or_: NodeQuery) -> None: ... class EdgeOperand: - def attribute(self, attribute: MedRecordAttribute) -> EdgeValuesOperand: ... - def index(self) -> EdgeValuesOperand: ... + def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: ... + def attributes(self) -> MultipleAttributesOperand: ... + def index(self) -> EdgeIndexOperand: ... def in_group(self, group: Union[Group, List[Group]]) -> None: ... def has_attribute( self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] @@ -50,132 +72,231 @@ class EdgeOperand: def target_node(self) -> NodeOperand: ... def either_or(self, either: EdgeQuery, or_: EdgeQuery) -> None: ... -class NodeValuesOperand: - def max(self) -> NodeValueOperand: ... - def min(self) -> NodeValueOperand: ... - def mean(self) -> NodeValueOperand: ... - def all(self) -> NodeValueOperand: ... - def any(self) -> NodeValueOperand: ... - def greater_than(self, value: ComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... - def less_than(self, value: ComparisonOperand) -> None: ... - def less_than_or_equal(self, value: ComparisonOperand) -> None: ... - def equals(self, value: ComparisonOperand) -> None: ... - def not_equals(self, value: ComparisonOperand) -> None: ... - def is_in(self, values: ValuesOperand) -> None: ... - def is_not_in(self, values: ValuesOperand) -> None: ... - def starts_with(self, value: ComparisonOperand) -> None: ... - def ends_with(self, value: ComparisonOperand) -> None: ... - def contains(self, value: ComparisonOperand) -> None: ... - def add(self, value: ComparisonOperand) -> NodeValuesOperand: ... - def subtract(self, value: ComparisonOperand) -> NodeValuesOperand: ... - def multiply(self, value: ComparisonOperand) -> NodeValuesOperand: ... - def divide(self, value: ComparisonOperand) -> NodeValuesOperand: ... - def modulo(self, value: ComparisonOperand) -> NodeValuesOperand: ... - def power(self, value: ComparisonOperand) -> NodeValuesOperand: ... - def round(self) -> NodeValuesOperand: ... - def ceil(self) -> NodeValuesOperand: ... - def floor(self) -> NodeValuesOperand: ... - def absolute(self) -> NodeValuesOperand: ... - def sqrt(self) -> NodeValuesOperand: ... - def trim(self) -> NodeValuesOperand: ... - def trim_start(self) -> NodeValuesOperand: ... - def trim_end(self) -> NodeValuesOperand: ... - def lowercase(self) -> NodeValuesOperand: ... - def uppercase(self) -> NodeValuesOperand: ... - def slice(self, start: int, end: int) -> NodeValuesOperand: ... +class MultipleValuesOperand: + def max(self) -> SingleValueOperand: ... + def min(self) -> SingleValueOperand: ... + def mean(self) -> SingleValueOperand: ... + def any(self, query: Callable[[SingleValueOperand], None]) -> None: ... + def greater_than(self, value: ValueComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: ValueComparisonOperand) -> None: ... + def less_than(self, value: ValueComparisonOperand) -> None: ... + def less_than_or_equal(self, value: ValueComparisonOperand) -> None: ... + def equal_to(self, value: ValueComparisonOperand) -> None: ... + def not_equal_to(self, value: ValueComparisonOperand) -> None: ... + def is_in(self, values: ValueComparisonOperand) -> None: ... + def is_not_in(self, values: ValueComparisonOperand) -> None: ... + def starts_with(self, value: ValueComparisonOperand) -> None: ... + def ends_with(self, value: ValueComparisonOperand) -> None: ... + def contains(self, value: ValueComparisonOperand) -> None: ... + def add(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... + def subtract(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... + def multiply(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... + def divide(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... + def modulo(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... + def power(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... + def round(self) -> MultipleValuesOperand: ... + def ceil(self) -> MultipleValuesOperand: ... + def floor(self) -> MultipleValuesOperand: ... + def absolute(self) -> MultipleValuesOperand: ... + def sqrt(self) -> MultipleValuesOperand: ... + def trim(self) -> MultipleValuesOperand: ... + def trim_start(self) -> MultipleValuesOperand: ... + def trim_end(self) -> MultipleValuesOperand: ... + def lowercase(self) -> MultipleValuesOperand: ... + def uppercase(self) -> MultipleValuesOperand: ... + def slice(self, start: int, end: int) -> MultipleValuesOperand: ... + def either_or( + self, either: MultipleValuesOperand, or_: MultipleValuesOperand + ) -> None: ... + +class SingleValueOperand: + def greater_than(self, value: ValueComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: ValueComparisonOperand) -> None: ... + def less_than(self, value: ValueComparisonOperand) -> None: ... + def less_than_or_equal(self, value: ValueComparisonOperand) -> None: ... + def equal_to(self, value: ValueComparisonOperand) -> None: ... + def not_equal_to(self, value: ValueComparisonOperand) -> None: ... + def is_in(self, values: ValueComparisonOperand) -> None: ... + def is_not_in(self, values: ValueComparisonOperand) -> None: ... + def starts_with(self, value: ValueComparisonOperand) -> None: ... + def ends_with(self, value: ValueComparisonOperand) -> None: ... + def contains(self, value: ValueComparisonOperand) -> None: ... + def add(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... + def subtract(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... + def multiply(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... + def divide(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... + def modulo(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... + def power(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... + def round(self) -> SingleValueOperand: ... + def ceil(self) -> SingleValueOperand: ... + def floor(self) -> SingleValueOperand: ... + def absolute(self) -> SingleValueOperand: ... + def sqrt(self) -> SingleValueOperand: ... + def trim(self) -> SingleValueOperand: ... + def trim_start(self) -> SingleValueOperand: ... + def trim_end(self) -> SingleValueOperand: ... + def lowercase(self) -> SingleValueOperand: ... + def uppercase(self) -> SingleValueOperand: ... + def slice(self, start: int, end: int) -> SingleValueOperand: ... + def either_or( + self, either: SingleValueOperand, or_: SingleValueOperand + ) -> None: ... + +class MultipleAttributesOperand: + def max(self) -> SingleAttributeOperand: ... + def min(self) -> SingleAttributeOperand: ... + def mean(self) -> SingleAttributeOperand: ... + def any(self, query: Callable[[SingleAttributeOperand], None]) -> None: ... + def greater_than(self, value: AttributeComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... + def less_than(self, value: AttributeComparisonOperand) -> None: ... + def less_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... + def equal_to(self, value: AttributeComparisonOperand) -> None: ... + def not_equal_to(self, value: AttributeComparisonOperand) -> None: ... + def is_in(self, values: AttributeComparisonOperand) -> None: ... + def is_not_in(self, values: AttributeComparisonOperand) -> None: ... + def starts_with(self, value: AttributeComparisonOperand) -> None: ... + def ends_with(self, value: AttributeComparisonOperand) -> None: ... + def contains(self, value: AttributeComparisonOperand) -> None: ... + def add(self, value: AttributeArithmeticOperand) -> MultipleAttributesOperand: ... + def subtract( + self, value: AttributeArithmeticOperand + ) -> MultipleAttributesOperand: ... + def multiply( + self, value: AttributeArithmeticOperand + ) -> MultipleAttributesOperand: ... + def divide( + self, value: AttributeArithmeticOperand + ) -> MultipleAttributesOperand: ... + def modulo( + self, value: AttributeArithmeticOperand + ) -> MultipleAttributesOperand: ... + def power(self, value: AttributeArithmeticOperand) -> MultipleAttributesOperand: ... + def round(self) -> MultipleAttributesOperand: ... + def ceil(self) -> MultipleAttributesOperand: ... + def floor(self) -> MultipleAttributesOperand: ... + def absolute(self) -> MultipleAttributesOperand: ... + def sqrt(self) -> MultipleAttributesOperand: ... + def trim(self) -> MultipleAttributesOperand: ... + def trim_start(self) -> MultipleAttributesOperand: ... + def trim_end(self) -> MultipleAttributesOperand: ... + def lowercase(self) -> MultipleAttributesOperand: ... + def uppercase(self) -> MultipleAttributesOperand: ... + def slice(self, start: int, end: int) -> MultipleAttributesOperand: ... + def either_or( + self, either: MultipleAttributesOperand, or_: MultipleAttributesOperand + ) -> None: ... + +class SingleAttributeOperand: + def greater_than(self, value: AttributeComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... + def less_than(self, value: AttributeComparisonOperand) -> None: ... + def less_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... + def equal_to(self, value: AttributeComparisonOperand) -> None: ... + def not_equal_to(self, value: AttributeComparisonOperand) -> None: ... + def is_in(self, values: AttributeComparisonOperand) -> None: ... + def is_not_in(self, values: AttributeComparisonOperand) -> None: ... + def starts_with(self, value: AttributeComparisonOperand) -> None: ... + def ends_with(self, value: AttributeComparisonOperand) -> None: ... + def contains(self, value: AttributeComparisonOperand) -> None: ... + def add(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... + def subtract(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... + def multiply(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... + def divide(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... + def modulo(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... + def power(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... + def round(self) -> SingleAttributeOperand: ... + def ceil(self) -> SingleAttributeOperand: ... + def floor(self) -> SingleAttributeOperand: ... + def absolute(self) -> SingleAttributeOperand: ... + def sqrt(self) -> SingleAttributeOperand: ... + def trim(self) -> SingleAttributeOperand: ... + def trim_start(self) -> SingleAttributeOperand: ... + def trim_end(self) -> SingleAttributeOperand: ... + def lowercase(self) -> SingleAttributeOperand: ... + def uppercase(self) -> SingleAttributeOperand: ... + def slice(self, start: int, end: int) -> SingleAttributeOperand: ... + def either_or( + self, either: SingleAttributeOperand, or_: SingleAttributeOperand + ) -> None: ... -class EdgeValuesOperand: - def max(self) -> EdgeValueOperand: ... - def min(self) -> EdgeValueOperand: ... - def mean(self) -> EdgeValueOperand: ... - def all(self) -> EdgeValueOperand: ... - def any(self) -> EdgeValueOperand: ... - def greater_than(self, value: ComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... - def less_than(self, value: ComparisonOperand) -> None: ... - def less_than_or_equal(self, value: ComparisonOperand) -> None: ... - def equals(self, value: ComparisonOperand) -> None: ... - def not_equals(self, value: ComparisonOperand) -> None: ... - def is_in(self, values: ValuesOperand) -> None: ... - def is_not_in(self, values: ValuesOperand) -> None: ... - def starts_with(self, value: ComparisonOperand) -> None: ... - def ends_with(self, value: ComparisonOperand) -> None: ... - def contains(self, value: ComparisonOperand) -> None: ... - def add(self, value: ComparisonOperand) -> EdgeValuesOperand: ... - def subtract(self, value: ComparisonOperand) -> EdgeValuesOperand: ... - def multiply(self, value: ComparisonOperand) -> EdgeValuesOperand: ... - def divide(self, value: ComparisonOperand) -> EdgeValuesOperand: ... - def modulo(self, value: ComparisonOperand) -> EdgeValuesOperand: ... - def power(self, value: ComparisonOperand) -> EdgeValuesOperand: ... - def round(self) -> EdgeValuesOperand: ... - def ceil(self) -> EdgeValuesOperand: ... - def floor(self) -> EdgeValuesOperand: ... - def absolute(self) -> EdgeValuesOperand: ... - def sqrt(self) -> EdgeValuesOperand: ... - def trim(self) -> EdgeValuesOperand: ... - def trim_start(self) -> EdgeValuesOperand: ... - def trim_end(self) -> EdgeValuesOperand: ... - def lowercase(self) -> EdgeValuesOperand: ... - def uppercase(self) -> EdgeValuesOperand: ... - def slice(self, start: int, end: int) -> EdgeValuesOperand: ... +NodeIndicesOperand: TypeAlias = MultipleAttributesOperand +NodeIndexOperand: TypeAlias = SingleAttributeOperand -class NodeValueOperand: - def greater_than(self, value: ComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... - def less_than(self, value: ComparisonOperand) -> None: ... - def less_than_or_equal(self, value: ComparisonOperand) -> None: ... - def equals(self, value: ComparisonOperand) -> None: ... - def not_equals(self, value: ComparisonOperand) -> None: ... - def is_in(self, values: ValuesOperand) -> None: ... - def is_not_in(self, values: ValuesOperand) -> None: ... - def starts_with(self, value: ComparisonOperand) -> None: ... - def ends_with(self, value: ComparisonOperand) -> None: ... - def contains(self, value: ComparisonOperand) -> None: ... - def add(self, value: ComparisonOperand) -> NodeValueOperand: ... - def subtract(self, value: ComparisonOperand) -> NodeValueOperand: ... - def multiply(self, value: ComparisonOperand) -> NodeValueOperand: ... - def divide(self, value: ComparisonOperand) -> NodeValueOperand: ... - def modulo(self, value: ComparisonOperand) -> NodeValueOperand: ... - def power(self, value: ComparisonOperand) -> NodeValueOperand: ... - def round(self) -> NodeValueOperand: ... - def ceil(self) -> NodeValueOperand: ... - def floor(self) -> NodeValueOperand: ... - def absolute(self) -> NodeValueOperand: ... - def sqrt(self) -> NodeValueOperand: ... - def trim(self) -> NodeValueOperand: ... - def trim_start(self) -> NodeValueOperand: ... - def trim_end(self) -> NodeValueOperand: ... - def lowercase(self) -> NodeValueOperand: ... - def uppercase(self) -> NodeValueOperand: ... - def slice(self, start: int, end: int) -> NodeValueOperand: ... +class EdgeIndicesOperand: + def max(self) -> EdgeIndexOperand: ... + def min(self) -> EdgeIndexOperand: ... + def mean(self) -> EdgeIndexOperand: ... + def any(self, query: Callable[[EdgeIndexOperand], None]) -> None: ... + def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... + def not_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... + def is_in(self, values: EdgeIndexComparisonOperand) -> None: ... + def is_not_in(self, values: EdgeIndexComparisonOperand) -> None: ... + def starts_with(self, value: EdgeIndexComparisonOperand) -> None: ... + def ends_with(self, value: EdgeIndexComparisonOperand) -> None: ... + def contains(self, value: EdgeIndexComparisonOperand) -> None: ... + def add(self, value: EdgeIndexArithmeticOperand) -> MultipleAttributesOperand: ... + def subtract( + self, value: EdgeIndexArithmeticOperand + ) -> MultipleAttributesOperand: ... + def multiply( + self, value: EdgeIndexArithmeticOperand + ) -> MultipleAttributesOperand: ... + def divide( + self, value: EdgeIndexArithmeticOperand + ) -> MultipleAttributesOperand: ... + def modulo( + self, value: EdgeIndexArithmeticOperand + ) -> MultipleAttributesOperand: ... + def power(self, value: EdgeIndexArithmeticOperand) -> MultipleAttributesOperand: ... + def round(self) -> MultipleAttributesOperand: ... + def ceil(self) -> MultipleAttributesOperand: ... + def floor(self) -> MultipleAttributesOperand: ... + def absolute(self) -> MultipleAttributesOperand: ... + def sqrt(self) -> MultipleAttributesOperand: ... + def trim(self) -> MultipleAttributesOperand: ... + def trim_start(self) -> MultipleAttributesOperand: ... + def trim_end(self) -> MultipleAttributesOperand: ... + def lowercase(self) -> MultipleAttributesOperand: ... + def uppercase(self) -> MultipleAttributesOperand: ... + def slice(self, start: int, end: int) -> MultipleAttributesOperand: ... + def either_or( + self, either: MultipleAttributesOperand, or_: MultipleAttributesOperand + ) -> None: ... -class EdgeValueOperand: - def greater_than(self, value: ComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: ComparisonOperand) -> None: ... - def less_than(self, value: ComparisonOperand) -> None: ... - def less_than_or_equal(self, value: ComparisonOperand) -> None: ... - def equals(self, value: ComparisonOperand) -> None: ... - def not_equals(self, value: ComparisonOperand) -> None: ... - def is_in(self, values: ValuesOperand) -> None: ... - def is_not_in(self, values: ValuesOperand) -> None: ... - def starts_with(self, value: ComparisonOperand) -> None: ... - def ends_with(self, value: ComparisonOperand) -> None: ... - def contains(self, value: ComparisonOperand) -> None: ... - def add(self, value: ComparisonOperand) -> EdgeValueOperand: ... - def subtract(self, value: ComparisonOperand) -> EdgeValueOperand: ... - def multiply(self, value: ComparisonOperand) -> EdgeValueOperand: ... - def divide(self, value: ComparisonOperand) -> EdgeValueOperand: ... - def modulo(self, value: ComparisonOperand) -> EdgeValueOperand: ... - def power(self, value: ComparisonOperand) -> EdgeValueOperand: ... - def round(self) -> EdgeValueOperand: ... - def ceil(self) -> EdgeValueOperand: ... - def floor(self) -> EdgeValueOperand: ... - def absolute(self) -> EdgeValueOperand: ... - def sqrt(self) -> EdgeValueOperand: ... - def trim(self) -> EdgeValueOperand: ... - def trim_start(self) -> EdgeValueOperand: ... - def trim_end(self) -> EdgeValueOperand: ... - def lowercase(self) -> EdgeValueOperand: ... - def uppercase(self) -> EdgeValueOperand: ... - def slice(self, start: int, end: int) -> EdgeValueOperand: ... +class EdgeIndexOperand: + def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... + def not_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... + def is_in(self, values: EdgeIndexComparisonOperand) -> None: ... + def is_not_in(self, values: EdgeIndexComparisonOperand) -> None: ... + def starts_with(self, value: EdgeIndexComparisonOperand) -> None: ... + def ends_with(self, value: EdgeIndexComparisonOperand) -> None: ... + def contains(self, value: EdgeIndexComparisonOperand) -> None: ... + def add(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... + def subtract(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... + def multiply(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... + def divide(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... + def modulo(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... + def power(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... + def round(self) -> SingleAttributeOperand: ... + def ceil(self) -> SingleAttributeOperand: ... + def floor(self) -> SingleAttributeOperand: ... + def absolute(self) -> SingleAttributeOperand: ... + def sqrt(self) -> SingleAttributeOperand: ... + def trim(self) -> SingleAttributeOperand: ... + def trim_start(self) -> SingleAttributeOperand: ... + def trim_end(self) -> SingleAttributeOperand: ... + def lowercase(self) -> SingleAttributeOperand: ... + def uppercase(self) -> SingleAttributeOperand: ... + def slice(self, start: int, end: int) -> SingleAttributeOperand: ... + def either_or( + self, either: SingleAttributeOperand, or_: SingleAttributeOperand + ) -> None: ... diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index 3a9a4e5c..f4b92319 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -1161,7 +1161,7 @@ def test_invalid_add_group(self): medrecord.add_group("0", ["1", "0"]) def query(node: NodeOperand): - node.index().equals("0") + node.index().equal_to("0") # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): @@ -1204,7 +1204,7 @@ def test_add_nodes_to_group(self): ) def query(node: NodeOperand): - node.index().equals("3") + node.index().equal_to("3") medrecord.add_nodes_to_group("0", query) @@ -1243,7 +1243,7 @@ def test_invalid_add_nodes_to_group(self): medrecord.add_nodes_to_group("0", ["1", "0"]) def query(node: NodeOperand): - node.index().equals("0") + node.index().equal_to("0") # Adding a node to a group that already is in the group should fail with self.assertRaises(AssertionError): @@ -1268,7 +1268,7 @@ def test_add_edges_to_group(self): ) def query(edge: EdgeOperand): - edge.index().equals(3) + edge.index().equal_to(3) medrecord.add_edges_to_group("0", query) @@ -1307,7 +1307,7 @@ def test_invalid_add_edges_to_group(self): medrecord.add_edges_to_group("0", [1, 0]) def query(edge: EdgeOperand): - edge.index().equals(0) + edge.index().equal_to(0) # Adding an edge to a group that already is in the group should fail with self.assertRaises(AssertionError): @@ -1366,7 +1366,7 @@ def test_invalid_remove_nodes_from_group(self): medrecord.remove_nodes_from_group("50", ["0", "1"]) def query(node: NodeOperand): - node.index().equals("0") + node.index().equal_to("0") # Removing a node from a non-existing group should fail with self.assertRaises(IndexError): @@ -1433,7 +1433,7 @@ def test_invalid_remove_edges_from_group(self): medrecord.remove_edges_from_group("50", [0, 1]) def query(edge: EdgeOperand): - edge.index().equals(0) + edge.index().equal_to(0) # Removing an edge from a non-existing group should fail with self.assertRaises(IndexError): diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index 78c56da4..caeec59a 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -621,7 +621,7 @@ def test_outcome_before_treatment(self): def test_filter_controls(self): def query1(node: NodeOperand): - node.neighbors().index().equals("M2") + node.neighbors().index().equal_to("M2") tee = ( TreatmentEffect.builder() @@ -637,7 +637,7 @@ def query1(node: NodeOperand): # filter females only def query2(node: NodeOperand): - node.attribute("gender").equals("female") + node.attribute("gender").equal_to("female") tee2 = ( TreatmentEffect.builder() diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 7a0d5aca..6f22c8a9 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -236,7 +236,9 @@ def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]: def query(node: NodeOperand): node.in_group(self._patients_group) - node.neighbors(edge_direction=EdgeDirection.BOTH).index().equals(treatment) + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equal_to( + treatment + ) # Create the group with all the patients that underwent the treatment for treatment in treatments: @@ -287,7 +289,7 @@ def query(node: NodeOperand): node.index().is_in(list(treated_group)) # This could probably be refactored to a proper query - node.neighbors(edge_direction=EdgeDirection.BOTH).index().equals(outcome) + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equal_to(outcome) for outcome in outcomes: nodes_to_check = set(medrecord.select_nodes(query)) @@ -451,7 +453,7 @@ def _find_controls( def query(node: NodeOperand): node.index().is_in(list(control_group)) - node.neighbors(edge_direction=EdgeDirection.BOTH).index().equals(outcome) + node.neighbors(edge_direction=EdgeDirection.BOTH).index().equal_to(outcome) # Finding the patients that had the outcome in the control group for outcome in outcomes: From 9330f88343254909992a4fafd08670707ef73442 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:45:23 +0200 Subject: [PATCH 3/8] feat: add type narrowing to interface (#214) --- medmodels/medrecord/querying.pyi | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/medmodels/medrecord/querying.pyi b/medmodels/medrecord/querying.pyi index 64e2a316..dbf16026 100644 --- a/medmodels/medrecord/querying.pyi +++ b/medmodels/medrecord/querying.pyi @@ -77,6 +77,11 @@ class MultipleValuesOperand: def min(self) -> SingleValueOperand: ... def mean(self) -> SingleValueOperand: ... def any(self, query: Callable[[SingleValueOperand], None]) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... def greater_than(self, value: ValueComparisonOperand) -> None: ... def greater_than_or_equal(self, value: ValueComparisonOperand) -> None: ... def less_than(self, value: ValueComparisonOperand) -> None: ... @@ -110,6 +115,11 @@ class MultipleValuesOperand: ) -> None: ... class SingleValueOperand: + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... def greater_than(self, value: ValueComparisonOperand) -> None: ... def greater_than_or_equal(self, value: ValueComparisonOperand) -> None: ... def less_than(self, value: ValueComparisonOperand) -> None: ... @@ -146,6 +156,11 @@ class MultipleAttributesOperand: def max(self) -> SingleAttributeOperand: ... def min(self) -> SingleAttributeOperand: ... def mean(self) -> SingleAttributeOperand: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... def any(self, query: Callable[[SingleAttributeOperand], None]) -> None: ... def greater_than(self, value: AttributeComparisonOperand) -> None: ... def greater_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... @@ -188,6 +203,11 @@ class MultipleAttributesOperand: ) -> None: ... class SingleAttributeOperand: + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... def greater_than(self, value: AttributeComparisonOperand) -> None: ... def greater_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... def less_than(self, value: AttributeComparisonOperand) -> None: ... @@ -228,6 +248,11 @@ class EdgeIndicesOperand: def min(self) -> EdgeIndexOperand: ... def mean(self) -> EdgeIndexOperand: ... def any(self, query: Callable[[EdgeIndexOperand], None]) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... def greater_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... @@ -269,6 +294,11 @@ class EdgeIndicesOperand: ) -> None: ... class EdgeIndexOperand: + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... def greater_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... From 25505dbf93bf55f84aef5a90ca83461cbb31d745 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:28:03 +0200 Subject: [PATCH 4/8] refactor: update querying interface (#217) --- medmodels/medrecord/querying.pyi | 494 ++++++++++++--------- medmodels/medrecord/tests/test_indexers.py | 4 +- 2 files changed, 281 insertions(+), 217 deletions(-) diff --git a/medmodels/medrecord/querying.pyi b/medmodels/medrecord/querying.pyi index dbf16026..a9952ef1 100644 --- a/medmodels/medrecord/querying.pyi +++ b/medmodels/medrecord/querying.pyi @@ -9,6 +9,7 @@ from medmodels.medrecord.types import ( Group, MedRecordAttribute, MedRecordValue, + NodeIndex, ) if sys.version_info >= (3, 10): @@ -19,26 +20,30 @@ else: NodeQuery: TypeAlias = Callable[[NodeOperand], None] EdgeQuery: TypeAlias = Callable[[EdgeOperand], None] -ValueComparisonOperand: TypeAlias = Union[ - SingleValueOperand, MedRecordValue, MultipleValuesOperand, List[MedRecordValue] +SingleValueComparisonOperand: TypeAlias = Union[SingleValueOperand, MedRecordValue] +MultipleValuesComparisonOperand: TypeAlias = Union[ + MultipleValuesOperand, List[MedRecordValue] ] -ValueArithmeticOperand: TypeAlias = Union[SingleValueOperand, MedRecordValue] -AttributeComparisonOperand: TypeAlias = Union[ + +SingleAttributeComparisonOperand: TypeAlias = Union[ SingleAttributeOperand, MedRecordAttribute, - MultipleAttributesOperand, - List[MedRecordAttribute], ] -AttributeArithmeticOperand: TypeAlias = Union[ - SingleAttributeOperand, MedRecordAttribute +MultipleAttributesComparisonOperand: TypeAlias = Union[ + MultipleAttributesOperand, List[MedRecordAttribute] ] + +NodeIndexComparisonOperand: TypeAlias = Union[NodeIndexOperand, NodeIndex] +NodeIndicesComparisonOperand: TypeAlias = Union[NodeIndicesOperand, List[NodeIndex]] + EdgeIndexComparisonOperand: TypeAlias = Union[ EdgeIndexOperand, EdgeIndex, +] +EdgeIndicesComparisonOperand: TypeAlias = Union[ EdgeIndicesOperand, List[EdgeIndex], ] -EdgeIndexArithmeticOperand: TypeAlias = Union[EdgeIndexOperand, EdgeIndex] class EdgeDirection(Enum): INCOMING = auto() @@ -46,12 +51,14 @@ class EdgeDirection(Enum): BOTH = auto() class NodeOperand: - def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: ... + def attribute( + self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] + ) -> MultipleValuesOperand: ... def attributes(self) -> MultipleAttributesOperand: ... def index(self) -> NodeIndexOperand: ... def in_group(self, group: Union[Group, List[Group]]) -> None: ... def has_attribute( - self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] ) -> None: ... def incoming_edges(self) -> EdgeOperand: ... def outgoing_edges(self) -> EdgeOperand: ... @@ -59,60 +66,78 @@ class NodeOperand: self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING ) -> NodeOperand: ... def either_or(self, either: NodeQuery, or_: NodeQuery) -> None: ... + def clone(self) -> NodeOperand: ... class EdgeOperand: - def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: ... + def attribute( + self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] + ) -> MultipleValuesOperand: ... def attributes(self) -> MultipleAttributesOperand: ... def index(self) -> EdgeIndexOperand: ... def in_group(self, group: Union[Group, List[Group]]) -> None: ... def has_attribute( - self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] ) -> None: ... def source_node(self) -> NodeOperand: ... def target_node(self) -> NodeOperand: ... def either_or(self, either: EdgeQuery, or_: EdgeQuery) -> None: ... + def clone(self) -> EdgeOperand: ... class MultipleValuesOperand: def max(self) -> SingleValueOperand: ... def min(self) -> SingleValueOperand: ... + def first_where( + self, query: Callable[[SingleValueOperand], None] + ) -> SingleValueOperand: ... + def last_where( + self, query: Callable[[SingleValueOperand], None] + ) -> SingleValueOperand: ... def mean(self) -> SingleValueOperand: ... - def any(self, query: Callable[[SingleValueOperand], None]) -> None: ... + def median(self) -> SingleValueOperand: ... + def mode(self) -> SingleValueOperand: ... + def std(self) -> SingleValueOperand: ... + def var(self) -> SingleValueOperand: ... + def count(self) -> SingleValueOperand: ... + def sum(self) -> SingleValueOperand: ... def is_string(self) -> None: ... def is_int(self) -> None: ... def is_float(self) -> None: ... def is_bool(self) -> None: ... def is_datetime(self) -> None: ... - def greater_than(self, value: ValueComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: ValueComparisonOperand) -> None: ... - def less_than(self, value: ValueComparisonOperand) -> None: ... - def less_than_or_equal(self, value: ValueComparisonOperand) -> None: ... - def equal_to(self, value: ValueComparisonOperand) -> None: ... - def not_equal_to(self, value: ValueComparisonOperand) -> None: ... - def is_in(self, values: ValueComparisonOperand) -> None: ... - def is_not_in(self, values: ValueComparisonOperand) -> None: ... - def starts_with(self, value: ValueComparisonOperand) -> None: ... - def ends_with(self, value: ValueComparisonOperand) -> None: ... - def contains(self, value: ValueComparisonOperand) -> None: ... - def add(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... - def subtract(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... - def multiply(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... - def divide(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... - def modulo(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... - def power(self, value: ValueArithmeticOperand) -> MultipleValuesOperand: ... - def round(self) -> MultipleValuesOperand: ... - def ceil(self) -> MultipleValuesOperand: ... - def floor(self) -> MultipleValuesOperand: ... - def absolute(self) -> MultipleValuesOperand: ... - def sqrt(self) -> MultipleValuesOperand: ... - def trim(self) -> MultipleValuesOperand: ... - def trim_start(self) -> MultipleValuesOperand: ... - def trim_end(self) -> MultipleValuesOperand: ... - def lowercase(self) -> MultipleValuesOperand: ... - def uppercase(self) -> MultipleValuesOperand: ... - def slice(self, start: int, end: int) -> MultipleValuesOperand: ... + def greater_than(self, value: SingleValueComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def less_than(self, value: SingleValueComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def not_equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def is_in(self, values: MultipleValuesComparisonOperand) -> None: ... + def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: ... + def starts_with(self, value: SingleValueComparisonOperand) -> None: ... + def ends_with(self, value: SingleValueComparisonOperand) -> None: ... + def contains(self, value: SingleValueComparisonOperand) -> None: ... + def add(self, value: SingleValueComparisonOperand) -> None: ... + def subtract(self, value: SingleValueComparisonOperand) -> None: ... + def multiply(self, value: SingleValueComparisonOperand) -> None: ... + def divide(self, value: SingleValueComparisonOperand) -> None: ... + def modulo(self, value: SingleValueComparisonOperand) -> None: ... + def power(self, value: SingleValueComparisonOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... + def absolute(self) -> None: ... + def sqrt(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... def either_or( - self, either: MultipleValuesOperand, or_: MultipleValuesOperand + self, + either: Callable[[MultipleValuesOperand], None], + or_: Callable[[MultipleValuesOperand], None], ) -> None: ... + def clone(self) -> MultipleValuesOperand: ... class SingleValueOperand: def is_string(self) -> None: ... @@ -120,213 +145,252 @@ class SingleValueOperand: def is_float(self) -> None: ... def is_bool(self) -> None: ... def is_datetime(self) -> None: ... - def greater_than(self, value: ValueComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: ValueComparisonOperand) -> None: ... - def less_than(self, value: ValueComparisonOperand) -> None: ... - def less_than_or_equal(self, value: ValueComparisonOperand) -> None: ... - def equal_to(self, value: ValueComparisonOperand) -> None: ... - def not_equal_to(self, value: ValueComparisonOperand) -> None: ... - def is_in(self, values: ValueComparisonOperand) -> None: ... - def is_not_in(self, values: ValueComparisonOperand) -> None: ... - def starts_with(self, value: ValueComparisonOperand) -> None: ... - def ends_with(self, value: ValueComparisonOperand) -> None: ... - def contains(self, value: ValueComparisonOperand) -> None: ... - def add(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... - def subtract(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... - def multiply(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... - def divide(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... - def modulo(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... - def power(self, value: ValueArithmeticOperand) -> SingleValueOperand: ... - def round(self) -> SingleValueOperand: ... - def ceil(self) -> SingleValueOperand: ... - def floor(self) -> SingleValueOperand: ... - def absolute(self) -> SingleValueOperand: ... - def sqrt(self) -> SingleValueOperand: ... - def trim(self) -> SingleValueOperand: ... - def trim_start(self) -> SingleValueOperand: ... - def trim_end(self) -> SingleValueOperand: ... - def lowercase(self) -> SingleValueOperand: ... - def uppercase(self) -> SingleValueOperand: ... - def slice(self, start: int, end: int) -> SingleValueOperand: ... + def greater_than(self, value: SingleValueComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def less_than(self, value: SingleValueComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def not_equal_to(self, value: SingleValueComparisonOperand) -> None: ... + def is_in(self, values: MultipleValuesComparisonOperand) -> None: ... + def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: ... + def starts_with(self, value: SingleValueComparisonOperand) -> None: ... + def ends_with(self, value: SingleValueComparisonOperand) -> None: ... + def contains(self, value: SingleValueComparisonOperand) -> None: ... + def add(self, value: SingleValueComparisonOperand) -> None: ... + def subtract(self, value: SingleValueComparisonOperand) -> None: ... + def multiply(self, value: SingleValueComparisonOperand) -> None: ... + def modulo(self, value: SingleValueComparisonOperand) -> None: ... + def power(self, value: SingleValueComparisonOperand) -> None: ... + def absolute(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... def either_or( - self, either: SingleValueOperand, or_: SingleValueOperand + self, + either: Callable[[SingleValueOperand], None], + or_: Callable[[SingleValueOperand], None], ) -> None: ... + def clone(self) -> SingleValueOperand: ... class MultipleAttributesOperand: def max(self) -> SingleAttributeOperand: ... def min(self) -> SingleAttributeOperand: ... - def mean(self) -> SingleAttributeOperand: ... + def first_where( + self, query: Callable[[SingleAttributeOperand], None] + ) -> SingleAttributeOperand: ... + def last_where( + self, query: Callable[[SingleAttributeOperand], None] + ) -> SingleAttributeOperand: ... + def count(self) -> SingleAttributeOperand: ... + def sum(self) -> SingleAttributeOperand: ... def is_string(self) -> None: ... def is_int(self) -> None: ... - def is_float(self) -> None: ... - def is_bool(self) -> None: ... - def is_datetime(self) -> None: ... - def any(self, query: Callable[[SingleAttributeOperand], None]) -> None: ... - def greater_than(self, value: AttributeComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... - def less_than(self, value: AttributeComparisonOperand) -> None: ... - def less_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... - def equal_to(self, value: AttributeComparisonOperand) -> None: ... - def not_equal_to(self, value: AttributeComparisonOperand) -> None: ... - def is_in(self, values: AttributeComparisonOperand) -> None: ... - def is_not_in(self, values: AttributeComparisonOperand) -> None: ... - def starts_with(self, value: AttributeComparisonOperand) -> None: ... - def ends_with(self, value: AttributeComparisonOperand) -> None: ... - def contains(self, value: AttributeComparisonOperand) -> None: ... - def add(self, value: AttributeArithmeticOperand) -> MultipleAttributesOperand: ... - def subtract( - self, value: AttributeArithmeticOperand - ) -> MultipleAttributesOperand: ... - def multiply( - self, value: AttributeArithmeticOperand - ) -> MultipleAttributesOperand: ... - def divide( - self, value: AttributeArithmeticOperand - ) -> MultipleAttributesOperand: ... - def modulo( - self, value: AttributeArithmeticOperand - ) -> MultipleAttributesOperand: ... - def power(self, value: AttributeArithmeticOperand) -> MultipleAttributesOperand: ... - def round(self) -> MultipleAttributesOperand: ... - def ceil(self) -> MultipleAttributesOperand: ... - def floor(self) -> MultipleAttributesOperand: ... - def absolute(self) -> MultipleAttributesOperand: ... - def sqrt(self) -> MultipleAttributesOperand: ... - def trim(self) -> MultipleAttributesOperand: ... - def trim_start(self) -> MultipleAttributesOperand: ... - def trim_end(self) -> MultipleAttributesOperand: ... - def lowercase(self) -> MultipleAttributesOperand: ... - def uppercase(self) -> MultipleAttributesOperand: ... - def slice(self, start: int, end: int) -> MultipleAttributesOperand: ... + def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: SingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, value: SingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, value: SingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... + def is_in(self, values: MultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, values: MultipleAttributesComparisonOperand) -> None: ... + def starts_with(self, value: SingleAttributeComparisonOperand) -> None: ... + def ends_with(self, value: SingleAttributeComparisonOperand) -> None: ... + def contains(self, value: SingleAttributeComparisonOperand) -> None: ... + def add(self, value: SingleAttributeComparisonOperand) -> None: ... + def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... + def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... + def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... + def power(self, value: SingleAttributeComparisonOperand) -> None: ... + def absolute(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... def either_or( - self, either: MultipleAttributesOperand, or_: MultipleAttributesOperand + self, + either: Callable[[MultipleAttributesOperand], None], + or_: Callable[[MultipleAttributesOperand], None], ) -> None: ... + def clone(self) -> MultipleAttributesOperand: ... class SingleAttributeOperand: + def to_values(self) -> MultipleValuesOperand: ... def is_string(self) -> None: ... def is_int(self) -> None: ... - def is_float(self) -> None: ... - def is_bool(self) -> None: ... - def is_datetime(self) -> None: ... - def greater_than(self, value: AttributeComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... - def less_than(self, value: AttributeComparisonOperand) -> None: ... - def less_than_or_equal(self, value: AttributeComparisonOperand) -> None: ... - def equal_to(self, value: AttributeComparisonOperand) -> None: ... - def not_equal_to(self, value: AttributeComparisonOperand) -> None: ... - def is_in(self, values: AttributeComparisonOperand) -> None: ... - def is_not_in(self, values: AttributeComparisonOperand) -> None: ... - def starts_with(self, value: AttributeComparisonOperand) -> None: ... - def ends_with(self, value: AttributeComparisonOperand) -> None: ... - def contains(self, value: AttributeComparisonOperand) -> None: ... - def add(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... - def subtract(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... - def multiply(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... - def divide(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... - def modulo(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... - def power(self, value: AttributeArithmeticOperand) -> SingleAttributeOperand: ... - def round(self) -> SingleAttributeOperand: ... - def ceil(self) -> SingleAttributeOperand: ... - def floor(self) -> SingleAttributeOperand: ... - def absolute(self) -> SingleAttributeOperand: ... - def sqrt(self) -> SingleAttributeOperand: ... - def trim(self) -> SingleAttributeOperand: ... - def trim_start(self) -> SingleAttributeOperand: ... - def trim_end(self) -> SingleAttributeOperand: ... - def lowercase(self) -> SingleAttributeOperand: ... - def uppercase(self) -> SingleAttributeOperand: ... - def slice(self, start: int, end: int) -> SingleAttributeOperand: ... + def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: SingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, value: SingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, value: SingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... + def is_in(self, values: MultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, values: MultipleAttributesComparisonOperand) -> None: ... + def starts_with(self, value: SingleAttributeComparisonOperand) -> None: ... + def ends_with(self, value: SingleAttributeComparisonOperand) -> None: ... + def contains(self, value: SingleAttributeComparisonOperand) -> None: ... + def add(self, value: SingleAttributeComparisonOperand) -> None: ... + def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... + def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... + def divide(self, value: SingleAttributeComparisonOperand) -> None: ... + def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... + def power(self, value: SingleAttributeComparisonOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... + def absolute(self) -> None: ... + def sqrt(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def either_or( + self, + either: Callable[[SingleAttributeOperand], None], + or_: Callable[[SingleAttributeOperand], None], + ) -> None: ... + def clone(self) -> SingleAttributeOperand: ... + +class NodeIndicesOperand: + def max(self) -> NodeIndexOperand: ... + def min(self) -> NodeIndexOperand: ... + def first_where( + self, query: Callable[[NodeIndexOperand], None] + ) -> NodeIndexOperand: ... + def last_where( + self, query: Callable[[NodeIndexOperand], None] + ) -> NodeIndexOperand: ... + def count(self) -> NodeIndexOperand: ... + def sum(self) -> NodeIndexOperand: ... + def greater_than(self, value: NodeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def less_than(self, value: NodeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def not_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def is_in(self, values: NodeIndicesComparisonOperand) -> None: ... + def is_not_in(self, values: NodeIndicesComparisonOperand) -> None: ... + def starts_with(self, value: NodeIndexComparisonOperand) -> None: ... + def ends_with(self, value: NodeIndexComparisonOperand) -> None: ... + def contains(self, value: NodeIndexComparisonOperand) -> None: ... + def add(self, value: NodeIndexComparisonOperand) -> None: ... + def subtract(self, value: NodeIndexComparisonOperand) -> None: ... + def multiply(self, value: NodeIndexComparisonOperand) -> None: ... + def modulo(self, value: NodeIndexComparisonOperand) -> None: ... + def power(self, value: NodeIndexComparisonOperand) -> None: ... + def absolute(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... def either_or( - self, either: SingleAttributeOperand, or_: SingleAttributeOperand + self, + either: Callable[[NodeIndicesOperand], None], + or_: Callable[[NodeIndicesOperand], None], ) -> None: ... + def clone(self) -> NodeIndicesOperand: ... -NodeIndicesOperand: TypeAlias = MultipleAttributesOperand -NodeIndexOperand: TypeAlias = SingleAttributeOperand +class NodeIndexOperand: + def greater_than(self, value: NodeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def less_than(self, value: NodeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def not_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... + def is_in(self, values: NodeIndicesComparisonOperand) -> None: ... + def is_not_in(self, values: NodeIndicesComparisonOperand) -> None: ... + def starts_with(self, value: NodeIndexComparisonOperand) -> None: ... + def ends_with(self, value: NodeIndexComparisonOperand) -> None: ... + def contains(self, value: NodeIndexComparisonOperand) -> None: ... + def add(self, value: NodeIndexComparisonOperand) -> None: ... + def subtract(self, value: NodeIndexComparisonOperand) -> None: ... + def multiply(self, value: NodeIndexComparisonOperand) -> None: ... + def modulo(self, value: NodeIndexComparisonOperand) -> None: ... + def power(self, value: NodeIndexComparisonOperand) -> None: ... + def absolute(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def either_or( + self, + either: Callable[[NodeIndexOperand], None], + or_: Callable[[NodeIndexOperand], None], + ) -> None: ... + def clone(self) -> NodeIndexOperand: ... class EdgeIndicesOperand: def max(self) -> EdgeIndexOperand: ... def min(self) -> EdgeIndexOperand: ... - def mean(self) -> EdgeIndexOperand: ... - def any(self, query: Callable[[EdgeIndexOperand], None]) -> None: ... - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def is_float(self) -> None: ... - def is_bool(self) -> None: ... - def is_datetime(self) -> None: ... + def first_where( + self, query: Callable[[EdgeIndexOperand], None] + ) -> EdgeIndexOperand: ... + def last_where( + self, query: Callable[[EdgeIndexOperand], None] + ) -> EdgeIndexOperand: ... + def count(self) -> EdgeIndexOperand: ... + def sum(self) -> EdgeIndexOperand: ... def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def less_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def not_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def is_in(self, values: EdgeIndexComparisonOperand) -> None: ... - def is_not_in(self, values: EdgeIndexComparisonOperand) -> None: ... + def is_in(self, values: EdgeIndicesComparisonOperand) -> None: ... + def is_not_in(self, values: EdgeIndicesComparisonOperand) -> None: ... def starts_with(self, value: EdgeIndexComparisonOperand) -> None: ... def ends_with(self, value: EdgeIndexComparisonOperand) -> None: ... def contains(self, value: EdgeIndexComparisonOperand) -> None: ... - def add(self, value: EdgeIndexArithmeticOperand) -> MultipleAttributesOperand: ... - def subtract( - self, value: EdgeIndexArithmeticOperand - ) -> MultipleAttributesOperand: ... - def multiply( - self, value: EdgeIndexArithmeticOperand - ) -> MultipleAttributesOperand: ... - def divide( - self, value: EdgeIndexArithmeticOperand - ) -> MultipleAttributesOperand: ... - def modulo( - self, value: EdgeIndexArithmeticOperand - ) -> MultipleAttributesOperand: ... - def power(self, value: EdgeIndexArithmeticOperand) -> MultipleAttributesOperand: ... - def round(self) -> MultipleAttributesOperand: ... - def ceil(self) -> MultipleAttributesOperand: ... - def floor(self) -> MultipleAttributesOperand: ... - def absolute(self) -> MultipleAttributesOperand: ... - def sqrt(self) -> MultipleAttributesOperand: ... - def trim(self) -> MultipleAttributesOperand: ... - def trim_start(self) -> MultipleAttributesOperand: ... - def trim_end(self) -> MultipleAttributesOperand: ... - def lowercase(self) -> MultipleAttributesOperand: ... - def uppercase(self) -> MultipleAttributesOperand: ... - def slice(self, start: int, end: int) -> MultipleAttributesOperand: ... + def add(self, value: EdgeIndexComparisonOperand) -> None: ... + def subtract(self, value: EdgeIndexComparisonOperand) -> None: ... + def multiply(self, value: EdgeIndexComparisonOperand) -> None: ... + def modulo(self, value: EdgeIndexComparisonOperand) -> None: ... + def power(self, value: EdgeIndexComparisonOperand) -> None: ... def either_or( - self, either: MultipleAttributesOperand, or_: MultipleAttributesOperand + self, + either: Callable[[EdgeIndicesOperand], None], + or_: Callable[[EdgeIndicesOperand], None], ) -> None: ... + def clone(self) -> EdgeIndicesOperand: ... class EdgeIndexOperand: - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def is_float(self) -> None: ... - def is_bool(self) -> None: ... - def is_datetime(self) -> None: ... def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def greater_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def less_than_or_equal(self, value: EdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def not_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def is_in(self, values: EdgeIndexComparisonOperand) -> None: ... - def is_not_in(self, values: EdgeIndexComparisonOperand) -> None: ... - def starts_with(self, value: EdgeIndexComparisonOperand) -> None: ... - def ends_with(self, value: EdgeIndexComparisonOperand) -> None: ... - def contains(self, value: EdgeIndexComparisonOperand) -> None: ... - def add(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... - def subtract(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... - def multiply(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... - def divide(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... - def modulo(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... - def power(self, value: EdgeIndexArithmeticOperand) -> SingleAttributeOperand: ... - def round(self) -> SingleAttributeOperand: ... - def ceil(self) -> SingleAttributeOperand: ... - def floor(self) -> SingleAttributeOperand: ... - def absolute(self) -> SingleAttributeOperand: ... - def sqrt(self) -> SingleAttributeOperand: ... - def trim(self) -> SingleAttributeOperand: ... - def trim_start(self) -> SingleAttributeOperand: ... - def trim_end(self) -> SingleAttributeOperand: ... - def lowercase(self) -> SingleAttributeOperand: ... - def uppercase(self) -> SingleAttributeOperand: ... - def slice(self, start: int, end: int) -> SingleAttributeOperand: ... + def is_in(self, values: EdgeIndicesComparisonOperand) -> None: ... + def is_not_in(self, values: EdgeIndicesComparisonOperand) -> None: ... + def starts_with(self, value: EdgeIndicesComparisonOperand) -> None: ... + def ends_with(self, value: EdgeIndicesComparisonOperand) -> None: ... + def contains(self, value: EdgeIndicesComparisonOperand) -> None: ... + def add(self, value: EdgeIndexComparisonOperand) -> None: ... + def subtract(self, value: EdgeIndexComparisonOperand) -> None: ... + def multiply(self, value: EdgeIndexComparisonOperand) -> None: ... + def modulo(self, value: EdgeIndexComparisonOperand) -> None: ... + def power(self, value: EdgeIndexComparisonOperand) -> None: ... def either_or( - self, either: SingleAttributeOperand, or_: SingleAttributeOperand + self, + either: Callable[[EdgeIndexOperand], None], + or_: Callable[[EdgeIndexOperand], None], ) -> None: ... + def clone(self) -> EdgeIndexOperand: ... diff --git a/medmodels/medrecord/tests/test_indexers.py b/medmodels/medrecord/tests/test_indexers.py index 6616b866..b5c47a32 100644 --- a/medmodels/medrecord/tests/test_indexers.py +++ b/medmodels/medrecord/tests/test_indexers.py @@ -22,7 +22,7 @@ def create_medrecord(): def node_greater_than_or_equal_two(node: NodeOperand): - node.index().greater_than_or_equal(2) + node.index().greater_than_or_equal_to(2) def node_greater_than_three(node: NodeOperand): @@ -34,7 +34,7 @@ def node_less_than_two(node: NodeOperand): def edge_greater_than_or_equal_two(edge: EdgeOperand): - edge.index().greater_than_or_equal(2) + edge.index().greater_than_or_equal_to(2) def edge_greater_than_three(edge: EdgeOperand): From 729291772c1611f2e695dc1ef018eabc54bd369f Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Wed, 9 Oct 2024 13:10:46 +0200 Subject: [PATCH 5/8] refactor: update interface to reflect final (for now) query engine (#226) --- medmodels/medrecord/querying.pyi | 94 ++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 30 deletions(-) diff --git a/medmodels/medrecord/querying.pyi b/medmodels/medrecord/querying.pyi index a9952ef1..d2258fca 100644 --- a/medmodels/medrecord/querying.pyi +++ b/medmodels/medrecord/querying.pyi @@ -86,12 +86,6 @@ class EdgeOperand: class MultipleValuesOperand: def max(self) -> SingleValueOperand: ... def min(self) -> SingleValueOperand: ... - def first_where( - self, query: Callable[[SingleValueOperand], None] - ) -> SingleValueOperand: ... - def last_where( - self, query: Callable[[SingleValueOperand], None] - ) -> SingleValueOperand: ... def mean(self) -> SingleValueOperand: ... def median(self) -> SingleValueOperand: ... def mode(self) -> SingleValueOperand: ... @@ -99,11 +93,16 @@ class MultipleValuesOperand: def var(self) -> SingleValueOperand: ... def count(self) -> SingleValueOperand: ... def sum(self) -> SingleValueOperand: ... + def first(self) -> SingleValueOperand: ... + def last(self) -> SingleValueOperand: ... def is_string(self) -> None: ... def is_int(self) -> None: ... def is_float(self) -> None: ... def is_bool(self) -> None: ... def is_datetime(self) -> None: ... + def is_null(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... def greater_than(self, value: SingleValueComparisonOperand) -> None: ... def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... def less_than(self, value: SingleValueComparisonOperand) -> None: ... @@ -145,6 +144,7 @@ class SingleValueOperand: def is_float(self) -> None: ... def is_bool(self) -> None: ... def is_datetime(self) -> None: ... + def is_null(self) -> None: ... def greater_than(self, value: SingleValueComparisonOperand) -> None: ... def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... def less_than(self, value: SingleValueComparisonOperand) -> None: ... @@ -161,7 +161,11 @@ class SingleValueOperand: def multiply(self, value: SingleValueComparisonOperand) -> None: ... def modulo(self, value: SingleValueComparisonOperand) -> None: ... def power(self, value: SingleValueComparisonOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... def absolute(self) -> None: ... + def sqrt(self) -> None: ... def trim(self) -> None: ... def trim_start(self) -> None: ... def trim_end(self) -> None: ... @@ -175,19 +179,62 @@ class SingleValueOperand: ) -> None: ... def clone(self) -> SingleValueOperand: ... +class AttributesTreeOperand: + def max(self) -> MultipleAttributesOperand: ... + def min(self) -> MultipleAttributesOperand: ... + def count(self) -> MultipleAttributesOperand: ... + def sum(self) -> MultipleAttributesOperand: ... + def first(self) -> MultipleAttributesOperand: ... + def last(self) -> MultipleAttributesOperand: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: SingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, value: SingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, value: SingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... + def is_in(self, values: MultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, values: MultipleAttributesComparisonOperand) -> None: ... + def starts_with(self, value: SingleAttributeComparisonOperand) -> None: ... + def ends_with(self, value: SingleAttributeComparisonOperand) -> None: ... + def contains(self, value: SingleAttributeComparisonOperand) -> None: ... + def add(self, value: SingleAttributeComparisonOperand) -> None: ... + def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... + def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... + def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... + def power(self, value: SingleAttributeComparisonOperand) -> None: ... + def absolute(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def either_or( + self, + either: Callable[[AttributesTreeOperand], None], + or_: Callable[[AttributesTreeOperand], None], + ) -> None: ... + def clone(self) -> AttributesTreeOperand: ... + class MultipleAttributesOperand: def max(self) -> SingleAttributeOperand: ... def min(self) -> SingleAttributeOperand: ... - def first_where( - self, query: Callable[[SingleAttributeOperand], None] - ) -> SingleAttributeOperand: ... - def last_where( - self, query: Callable[[SingleAttributeOperand], None] - ) -> SingleAttributeOperand: ... def count(self) -> SingleAttributeOperand: ... def sum(self) -> SingleAttributeOperand: ... + def first(self) -> SingleAttributeOperand: ... + def last(self) -> SingleAttributeOperand: ... def is_string(self) -> None: ... def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... def greater_than_or_equal_to( self, value: SingleAttributeComparisonOperand @@ -214,6 +261,7 @@ class MultipleAttributesOperand: def trim_end(self) -> None: ... def lowercase(self) -> None: ... def uppercase(self) -> None: ... + def to_values(self) -> MultipleValuesOperand: ... def slice(self, start: int, end: int) -> None: ... def either_or( self, @@ -223,7 +271,6 @@ class MultipleAttributesOperand: def clone(self) -> MultipleAttributesOperand: ... class SingleAttributeOperand: - def to_values(self) -> MultipleValuesOperand: ... def is_string(self) -> None: ... def is_int(self) -> None: ... def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... @@ -244,14 +291,9 @@ class SingleAttributeOperand: def add(self, value: SingleAttributeComparisonOperand) -> None: ... def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... - def divide(self, value: SingleAttributeComparisonOperand) -> None: ... def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... def power(self, value: SingleAttributeComparisonOperand) -> None: ... - def round(self) -> None: ... - def ceil(self) -> None: ... - def floor(self) -> None: ... def absolute(self) -> None: ... - def sqrt(self) -> None: ... def trim(self) -> None: ... def trim_start(self) -> None: ... def trim_end(self) -> None: ... @@ -268,14 +310,10 @@ class SingleAttributeOperand: class NodeIndicesOperand: def max(self) -> NodeIndexOperand: ... def min(self) -> NodeIndexOperand: ... - def first_where( - self, query: Callable[[NodeIndexOperand], None] - ) -> NodeIndexOperand: ... - def last_where( - self, query: Callable[[NodeIndexOperand], None] - ) -> NodeIndexOperand: ... def count(self) -> NodeIndexOperand: ... def sum(self) -> NodeIndexOperand: ... + def first(self) -> NodeIndexOperand: ... + def last(self) -> NodeIndexOperand: ... def greater_than(self, value: NodeIndexComparisonOperand) -> None: ... def greater_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... def less_than(self, value: NodeIndexComparisonOperand) -> None: ... @@ -340,14 +378,10 @@ class NodeIndexOperand: class EdgeIndicesOperand: def max(self) -> EdgeIndexOperand: ... def min(self) -> EdgeIndexOperand: ... - def first_where( - self, query: Callable[[EdgeIndexOperand], None] - ) -> EdgeIndexOperand: ... - def last_where( - self, query: Callable[[EdgeIndexOperand], None] - ) -> EdgeIndexOperand: ... def count(self) -> EdgeIndexOperand: ... def sum(self) -> EdgeIndexOperand: ... + def first(self) -> EdgeIndexOperand: ... + def last(self) -> EdgeIndexOperand: ... def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... def greater_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... From 618129711364a82f3dd004d6b1fc65d6d3090dd2 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:12:19 +0200 Subject: [PATCH 6/8] refactor: rust query engine (#193) --- Cargo.lock | 31 +- Cargo.toml | 5 +- crates/medmodels-core/Cargo.toml | 4 +- crates/medmodels-core/src/errors/medrecord.rs | 3 + crates/medmodels-core/src/errors/mod.rs | 2 + .../src/medrecord/datatypes/attribute.rs | 223 ++- .../src/medrecord/datatypes/mod.rs | 72 +- .../src/medrecord/datatypes/value.rs | 201 +- .../src/medrecord/example_dataset/mod.rs | 11 +- .../src/medrecord/graph/edge.rs | 6 +- .../medmodels-core/src/medrecord/graph/mod.rs | 47 +- .../src/medrecord/graph/node.rs | 6 +- crates/medmodels-core/src/medrecord/mod.rs | 42 +- .../src/medrecord/querying/attributes/mod.rs | 132 ++ .../medrecord/querying/attributes/operand.rs | 874 +++++++++ .../querying/attributes/operation.rs | 1357 +++++++++++++ .../src/medrecord/querying/edges/mod.rs | 58 + .../src/medrecord/querying/edges/operand.rs | 655 +++++++ .../src/medrecord/querying/edges/operation.rs | 762 ++++++++ .../src/medrecord/querying/edges/selection.rs | 32 + .../src/medrecord/querying/mod.rs | 15 +- .../src/medrecord/querying/nodes/mod.rs | 68 + .../src/medrecord/querying/nodes/operand.rs | 732 +++++++ .../src/medrecord/querying/nodes/operation.rs | 971 +++++++++ .../src/medrecord/querying/nodes/selection.rs | 35 + .../querying/operation/edge_operation.rs | 475 ----- .../src/medrecord/querying/operation/mod.rs | 394 ---- .../querying/operation/node_operation.rs | 246 --- .../medrecord/querying/operation/operand.rs | 649 ------ .../src/medrecord/querying/selection.rs | 1741 ----------------- .../src/medrecord/querying/traits.rs | 21 + .../src/medrecord/querying/values/mod.rs | 185 ++ .../src/medrecord/querying/values/operand.rs | 590 ++++++ .../medrecord/querying/values/operation.rs | 934 +++++++++ .../src/medrecord/querying/wrapper.rs | 45 + crates/medmodels-core/src/medrecord/schema.rs | 2 - rustmodels/Cargo.toml | 4 +- rustmodels/src/medrecord/mod.rs | 2 +- 38 files changed, 7960 insertions(+), 3672 deletions(-) create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/selection.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/selection.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/mod.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/operand.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/selection.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/traits.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/values/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/values/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/values/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/wrapper.rs diff --git a/Cargo.lock b/Cargo.lock index 60095bba..280c2399 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,9 +116,9 @@ checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" [[package]] name = "bytemuck" -version = "1.14.3" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" dependencies = [ "bytemuck_derive", ] @@ -134,6 +134,12 @@ dependencies = [ "syn 2.0.50", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.5.0" @@ -427,6 +433,15 @@ version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -517,8 +532,10 @@ name = "medmodels-core" version = "0.1.2" dependencies = [ "chrono", + "itertools", "medmodels-utils", "polars", + "roaring", "ron", "serde", ] @@ -1335,6 +1352,16 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "roaring" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +dependencies = [ + "bytemuck", + "byteorder", +] + [[package]] name = "ron" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index eac0a8c7..40154170 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,11 +13,8 @@ description = "Limebit MedModels Crate" [workspace.dependencies] hashbrown = { version = "0.14.5", features = ["serde"] } serde = { version = "1.0.203", features = ["derive"] } -ron = "0.8.1" -chrono = { version = "0.4.38", features = ["serde"] } -pyo3 = { version = "0.21.2", features = ["chrono"] } polars = { version = "0.40.0", features = ["polars-io"] } -pyo3-polars = "0.14.0" +chrono = { version = "0.4.38", features = ["serde"] } medmodels = { version = "0.1.2", path = "crates/medmodels" } medmodels-core = { version = "0.1.2", path = "crates/medmodels-core" } diff --git a/crates/medmodels-core/Cargo.toml b/crates/medmodels-core/Cargo.toml index 58097fcf..48225587 100644 --- a/crates/medmodels-core/Cargo.toml +++ b/crates/medmodels-core/Cargo.toml @@ -12,5 +12,7 @@ medmodels-utils = { workspace = true } polars = { workspace = true } serde = { workspace = true } -ron = { workspace = true } chrono = { workspace = true } +ron = "0.8.1" +roaring = "0.10.6" +itertools = "0.13.0" diff --git a/crates/medmodels-core/src/errors/medrecord.rs b/crates/medmodels-core/src/errors/medrecord.rs index f7afb230..3ad22a14 100644 --- a/crates/medmodels-core/src/errors/medrecord.rs +++ b/crates/medmodels-core/src/errors/medrecord.rs @@ -10,6 +10,7 @@ pub enum MedRecordError { ConversionError(String), AssertionError(String), SchemaError(String), + QueryError(String), } impl Error for MedRecordError { @@ -20,6 +21,7 @@ impl Error for MedRecordError { MedRecordError::ConversionError(message) => message, MedRecordError::AssertionError(message) => message, MedRecordError::SchemaError(message) => message, + MedRecordError::QueryError(message) => message, } } } @@ -32,6 +34,7 @@ impl Display for MedRecordError { Self::ConversionError(message) => write!(f, "ConversionError: {}", message), Self::AssertionError(message) => write!(f, "AssertionError: {}", message), Self::SchemaError(message) => write!(f, "SchemaError: {}", message), + Self::QueryError(message) => write!(f, "QueryError: {}", message), } } } diff --git a/crates/medmodels-core/src/errors/mod.rs b/crates/medmodels-core/src/errors/mod.rs index b0c37588..069281ca 100644 --- a/crates/medmodels-core/src/errors/mod.rs +++ b/crates/medmodels-core/src/errors/mod.rs @@ -14,6 +14,8 @@ impl From for MedRecordError { } } +pub type MedRecordResult = Result; + #[cfg(test)] mod test { use super::{GraphError, MedRecordError}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs index f02f12d4..bdb2f12d 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs @@ -1,8 +1,16 @@ -use super::{Contains, EndsWith, MedRecordValue, StartsWith}; -use crate::errors::MedRecordError; +use super::{ + Abs, Contains, EndsWith, Lowercase, MedRecordValue, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, +}; +use crate::errors::{MedRecordError, MedRecordResult}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; -use std::{cmp::Ordering, fmt::Display, hash::Hash}; +use std::{ + cmp::Ordering, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Sub}, +}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum MedRecordAttribute { @@ -43,15 +51,6 @@ impl TryFrom for MedRecordAttribute { } } -impl Display for MedRecordAttribute { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::String(value) => write!(f, "{}", value), - Self::Int(value) => write!(f, "{}", value), - } - } -} - impl PartialEq for MedRecordAttribute { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -80,6 +79,140 @@ impl PartialOrd for MedRecordAttribute { } } +impl Display for MedRecordAttribute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::String(value) => write!(f, "{}", value), + Self::Int(value) => write!(f, "{}", value), + } + } +} + +// TODO: Add tests +impl Add for MedRecordAttribute { + type Output = MedRecordResult; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Ok(MedRecordAttribute::String(value + rhs.as_str())) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value + rhs)) + } + } + } +} + +// TODO: Add tests +impl Sub for MedRecordAttribute { + type Output = MedRecordResult; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value - rhs)) + } + } + } +} + +// TODO: Add tests +impl Mul for MedRecordAttribute { + type Output = MedRecordResult; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value * rhs)) + } + } + } +} + +// TODO: Add tests +impl Pow for MedRecordAttribute { + fn pow(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value.pow(rhs as u32))) + } + } + } +} + +// TODO: Add tests +impl Mod for MedRecordAttribute { + fn r#mod(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value % rhs)) + } + } + } +} + +// TODO: Add tests +impl Abs for MedRecordAttribute { + fn abs(self) -> Self { + match self { + MedRecordAttribute::Int(value) => MedRecordAttribute::Int(value.abs()), + _ => self, + } + } +} + impl StartsWith for MedRecordAttribute { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -137,6 +270,72 @@ impl Contains for MedRecordAttribute { } } +// TODO: Add tests +impl Slice for MedRecordAttribute { + fn slice(self, range: std::ops::Range) -> Self { + match self { + MedRecordAttribute::String(value) => value[range].into(), + MedRecordAttribute::Int(value) => value.to_string()[range].into(), + } + } +} + +// TODO: Add tests +impl Trim for MedRecordAttribute { + fn trim(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimStart for MedRecordAttribute { + fn trim_start(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_start().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimEnd for MedRecordAttribute { + fn trim_end(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_end().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl Lowercase for MedRecordAttribute { + fn lowercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_lowercase()), + _ => self, + } + } +} + +// TODO: Add tests +impl Uppercase for MedRecordAttribute { + fn uppercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_uppercase()), + _ => self, + } + } +} + #[cfg(test)] mod test { use super::MedRecordAttribute; diff --git a/crates/medmodels-core/src/medrecord/datatypes/mod.rs b/crates/medmodels-core/src/medrecord/datatypes/mod.rs index 0beca37e..ada0f6c0 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/mod.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/mod.rs @@ -2,6 +2,7 @@ mod attribute; mod value; pub use self::{attribute::MedRecordAttribute, value::MedRecordValue}; +use super::EdgeIndex; use crate::errors::MedRecordError; use serde::{Deserialize, Serialize}; use std::{fmt::Display, ops::Range}; @@ -51,6 +52,24 @@ impl From<&MedRecordValue> for DataType { } } +impl From for DataType { + fn from(value: MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + +impl From<&MedRecordAttribute> for DataType { + fn from(value: &MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + impl PartialEq for DataType { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -126,28 +145,52 @@ impl DataType { } } -pub trait Pow: Sized { - fn pow(self, exp: Self) -> Result; -} - -pub trait Mod: Sized { - fn r#mod(self, other: Self) -> Result; -} - pub trait StartsWith { fn starts_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl StartsWith for EdgeIndex { + fn starts_with(&self, other: &Self) -> bool { + self.to_string().starts_with(&other.to_string()) + } +} + pub trait EndsWith { fn ends_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl EndsWith for EdgeIndex { + fn ends_with(&self, other: &Self) -> bool { + self.to_string().ends_with(&other.to_string()) + } +} + pub trait Contains { fn contains(&self, other: &Self) -> bool; } -pub trait PartialNeq: PartialEq { - fn neq(&self, other: &Self) -> bool; +// TODO: Add tests +impl Contains for EdgeIndex { + fn contains(&self, other: &Self) -> bool { + self.to_string().contains(&other.to_string()) + } +} + +pub trait Pow: Sized { + fn pow(self, exp: Self) -> Result; +} + +pub trait Mod: Sized { + fn r#mod(self, other: Self) -> Result; +} + +// TODO: Add tests +impl Mod for EdgeIndex { + fn r#mod(self, other: Self) -> Result { + Ok(self % other) + } } pub trait Round { @@ -194,15 +237,6 @@ pub trait Slice { fn slice(self, range: Range) -> Self; } -impl PartialNeq for T -where - T: PartialOrd, -{ - fn neq(&self, other: &Self) -> bool { - self != other - } -} - #[cfg(test)] mod test { use super::{DataType, MedRecordValue}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/value.rs b/crates/medmodels-core/src/medrecord/datatypes/value.rs index 792d879d..f3995102 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/value.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/value.rs @@ -3,7 +3,7 @@ use super::{ Trim, TrimEnd, TrimStart, Uppercase, }; use crate::errors::MedRecordError; -use chrono::NaiveDateTime; +use chrono::{DateTime, NaiveDateTime}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; use std::{ @@ -210,9 +210,17 @@ impl Add for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() + rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot add None to {}", value)), ), @@ -327,9 +335,17 @@ impl Sub for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() - rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot subtract None from {}", value)), ), @@ -621,9 +637,17 @@ impl Div for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::String(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), - (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => Err( - MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => { + Ok(DateTime::from_timestamp( + (value.and_utc().timestamp() as f64 / other as f64).floor() as i64, + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Float(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), @@ -966,6 +990,53 @@ impl Mod for MedRecordValue { } } +impl Round for MedRecordValue { + fn round(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.round()), + _ => self, + } + } +} + +impl Ceil for MedRecordValue { + fn ceil(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.ceil()), + _ => self, + } + } +} + +impl Floor for MedRecordValue { + fn floor(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.floor()), + _ => self, + } + } +} + +impl Abs for MedRecordValue { + fn abs(self) -> Self { + match self { + MedRecordValue::Int(value) => MedRecordValue::Int(value.abs()), + MedRecordValue::Float(value) => MedRecordValue::Float(value.abs()), + _ => self, + } + } +} + +impl Sqrt for MedRecordValue { + fn sqrt(self) -> Self { + match self { + MedRecordValue::Int(value) => MedRecordValue::Float((value as f64).sqrt()), + MedRecordValue::Float(value) => MedRecordValue::Float(value.sqrt()), + _ => self, + } + } +} + impl StartsWith for MedRecordValue { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -1081,53 +1152,6 @@ impl Slice for MedRecordValue { } } -impl Round for MedRecordValue { - fn round(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.round()), - _ => self, - } - } -} - -impl Ceil for MedRecordValue { - fn ceil(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.ceil()), - _ => self, - } - } -} - -impl Floor for MedRecordValue { - fn floor(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.floor()), - _ => self, - } - } -} - -impl Abs for MedRecordValue { - fn abs(self) -> Self { - match self { - MedRecordValue::Int(value) => MedRecordValue::Int(value.abs()), - MedRecordValue::Float(value) => MedRecordValue::Float(value.abs()), - _ => self, - } - } -} - -impl Sqrt for MedRecordValue { - fn sqrt(self) -> Self { - match self { - MedRecordValue::Int(value) => MedRecordValue::Float((value as f64).sqrt()), - MedRecordValue::Float(value) => MedRecordValue::Float(value.sqrt()), - _ => self, - } - } -} - impl Trim for MedRecordValue { fn trim(self) -> Self { match self { @@ -1183,7 +1207,7 @@ mod test { Uppercase, }, }; - use chrono::NaiveDateTime; + use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; #[test] fn test_default() { @@ -1669,9 +1693,23 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - + MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 4) + .unwrap() + .and_time(NaiveTime::MIN) + ), + (MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 2) + .unwrap() + .and_time(NaiveTime::MIN) + ) + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 3) + .unwrap() + .and_time(NaiveTime::MIN) + )) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1794,9 +1832,12 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - - MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime(DateTime::from_timestamp(0, 0).unwrap().naive_utc()), + (MedRecordValue::DateTime(NaiveDateTime::MAX) + - MedRecordValue::DateTime(NaiveDateTime::MAX)) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1951,15 +1992,15 @@ mod test { / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(0)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(0_f64)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(false)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::String("value".to_string()) @@ -1982,7 +2023,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Int(5) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Int(0) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2003,7 +2044,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Float(5_f64) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Float(0_f64) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2016,11 +2057,11 @@ mod test { (MedRecordValue::Bool(false) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(0)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Bool(false) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2032,16 +2073,16 @@ mod test { assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(0)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) + assert_eq!( + MedRecordValue::DateTime(NaiveDateTime::MIN), + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(1)).unwrap() ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(0_f64)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(false)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) @@ -2056,11 +2097,11 @@ mod test { (MedRecordValue::Null / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Null / MedRecordValue::Int(0)) + assert!((MedRecordValue::Null / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Null / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Null / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Null / MedRecordValue::DateTime(NaiveDateTime::MIN)) diff --git a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs index e4879307..2a0f3354 100644 --- a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs +++ b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs @@ -71,7 +71,7 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_diagnosis_ids = (0..patient_diagnosis.height()).collect::>(); + let patient_diagnosis_ids = (0..patient_diagnosis.height() as u32).collect::>(); let cursor = Cursor::new(PATIENT_DRUG); let patient_drug = CsvReadOptions::default() @@ -79,8 +79,8 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_drug_ids = (patient_diagnosis.height() - ..patient_diagnosis.height() + patient_drug.height()) + let patient_drug_ids = (patient_diagnosis.height() as u32 + ..(patient_diagnosis.height() + patient_drug.height()) as u32) .collect::>(); let cursor = Cursor::new(PATIENT_PROCEDURE); @@ -89,8 +89,9 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_procedure_ids = (patient_diagnosis.height() + patient_drug.height() - ..patient_diagnosis.height() + patient_drug.height() + patient_procedure.height()) + let patient_procedure_ids = ((patient_diagnosis.height() + patient_drug.height()) as u32 + ..(patient_diagnosis.height() + patient_drug.height() + patient_procedure.height()) + as u32) .collect::>(); let mut medrecord = Self::from_dataframes( diff --git a/crates/medmodels-core/src/medrecord/graph/edge.rs b/crates/medmodels-core/src/medrecord/graph/edge.rs index a45b6c4d..36b790d8 100644 --- a/crates/medmodels-core/src/medrecord/graph/edge.rs +++ b/crates/medmodels-core/src/medrecord/graph/edge.rs @@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Edge { - pub attributes: Attributes, - pub(super) source_node_index: NodeIndex, - pub(super) target_node_index: NodeIndex, + pub(crate) attributes: Attributes, + pub(crate) source_node_index: NodeIndex, + pub(crate) target_node_index: NodeIndex, } impl Edge { diff --git a/crates/medmodels-core/src/medrecord/graph/mod.rs b/crates/medmodels-core/src/medrecord/graph/mod.rs index 0a7da3de..96a82584 100644 --- a/crates/medmodels-core/src/medrecord/graph/mod.rs +++ b/crates/medmodels-core/src/medrecord/graph/mod.rs @@ -9,18 +9,18 @@ use node::Node; use serde::{Deserialize, Serialize}; use std::{ collections::{HashMap, HashSet}, - sync::atomic::AtomicUsize, + sync::atomic::AtomicU32, }; pub type NodeIndex = MedRecordAttribute; -pub type EdgeIndex = usize; +pub type EdgeIndex = u32; pub type Attributes = HashMap; #[derive(Serialize, Deserialize, Debug)] pub(super) struct Graph { pub(crate) nodes: MrHashMap, pub(crate) edges: MrHashMap, - edge_index_counter: AtomicUsize, + edge_index_counter: AtomicU32, } impl Clone for Graph { @@ -28,7 +28,7 @@ impl Clone for Graph { Self { nodes: self.nodes.clone(), edges: self.edges.clone(), - edge_index_counter: AtomicUsize::new( + edge_index_counter: AtomicU32::new( self.edge_index_counter .load(std::sync::atomic::Ordering::Relaxed), ), @@ -42,7 +42,7 @@ impl Graph { Self { nodes: MrHashMap::new(), edges: MrHashMap::new(), - edge_index_counter: AtomicUsize::new(0), + edge_index_counter: AtomicU32::new(0), } } @@ -50,7 +50,7 @@ impl Graph { Self { nodes: MrHashMap::with_capacity(node_capacity), edges: MrHashMap::with_capacity(edge_capacity), - edge_index_counter: AtomicUsize::new(0), + edge_index_counter: AtomicU32::new(0), } } @@ -58,13 +58,13 @@ impl Graph { self.nodes.clear(); self.edges.clear(); - self.edge_index_counter = AtomicUsize::new(0); + self.edge_index_counter = AtomicU32::new(0); } pub fn clear_edges(&mut self) { self.edges.clear(); - self.edge_index_counter = AtomicUsize::new(0); + self.edge_index_counter = AtomicU32::new(0); } pub fn node_count(&self) -> usize { @@ -359,7 +359,7 @@ impl Graph { self.edges.contains_key(edge_index) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, GraphError> { @@ -381,6 +381,29 @@ impl Graph { })) } + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, GraphError> { + Ok(self + .nodes + .get(node_index) + .ok_or(GraphError::IndexError(format!( + "Cannot find node with index {}", + node_index + )))? + .incoming_edge_indices + .iter() + .map(|edge_index| { + &self + .edges + .get(edge_index) + .expect("Edge must exist") + .source_node_index + })) + } + pub fn neighbors_undirected( &self, node_index: &NodeIndex, @@ -913,7 +936,7 @@ mod test { fn test_neighbors() { let graph = create_graph(); - let neighbors = graph.neighbors(&"0".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -923,7 +946,7 @@ mod test { let graph = create_graph(); assert!(graph - .neighbors(&"50".into()) + .neighbors_outgoing(&"50".into()) .is_err_and(|e| matches!(e, GraphError::IndexError(_)))); } @@ -931,7 +954,7 @@ mod test { fn test_neighbors_undirected() { let graph = create_graph(); - let neighbors = graph.neighbors(&"2".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = graph.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/graph/node.rs b/crates/medmodels-core/src/medrecord/graph/node.rs index 9af16851..4d90ee0f 100644 --- a/crates/medmodels-core/src/medrecord/graph/node.rs +++ b/crates/medmodels-core/src/medrecord/graph/node.rs @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Node { - pub attributes: Attributes, - pub(super) outgoing_edge_indices: MrHashSet, - pub(super) incoming_edge_indices: MrHashSet, + pub(crate) attributes: Attributes, + pub(crate) outgoing_edge_indices: MrHashSet, + pub(crate) incoming_edge_indices: MrHashSet, } impl Node { diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index ee4e8ea0..f4000061 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -11,9 +11,9 @@ pub use self::{ graph::{Attributes, EdgeIndex, NodeIndex}, group_mapping::Group, querying::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation, - TransformationOperation, ValueOperand, + edges::EdgeOperand, + nodes::NodeOperand, + wrapper::{CardinalityWrapper, Wrapper}, }, schema::{AttributeDataType, AttributeType, GroupSchema, Schema}, }; @@ -22,7 +22,7 @@ use ::polars::frame::DataFrame; use graph::Graph; use group_mapping::GroupMapping; use polars::{dataframe_to_edges, dataframe_to_nodes}; -use querying::{EdgeSelection, NodeSelection}; +use querying::{edges::EdgeSelection, nodes::NodeSelection}; use serde::{Deserialize, Serialize}; use std::{fs, mem, path::Path}; @@ -683,12 +683,22 @@ impl MedRecord { self.group_mapping.contains_group(group) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, MedRecordError> { self.graph - .neighbors(node_index) + .neighbors_outgoing(node_index) + .map_err(MedRecordError::from) + } + + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, MedRecordError> { + self.graph + .neighbors_incoming(node_index) .map_err(MedRecordError::from) } @@ -706,12 +716,18 @@ impl MedRecord { self.group_mapping.clear(); } - pub fn select_nodes(&self, operation: NodeOperation) -> NodeSelection { - NodeSelection::new(self, operation) + pub fn select_nodes(&self, query: Q) -> NodeSelection + where + Q: FnOnce(&mut Wrapper), + { + NodeSelection::new(self, query) } - pub fn select_edges(&self, operation: EdgeOperation) -> EdgeSelection { - EdgeSelection::new(self, operation) + pub fn select_edges(&self, query: Q) -> EdgeSelection + where + Q: FnOnce(&mut Wrapper), + { + EdgeSelection::new(self, query) } } @@ -1870,7 +1886,7 @@ mod test { fn test_neighbors() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"0".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -1881,7 +1897,7 @@ mod test { // Querying neighbors of a non-existing node sohuld fail assert!(medrecord - .neighbors(&"0".into()) + .neighbors_outgoing(&"0".into()) .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); } @@ -1889,7 +1905,7 @@ mod test { fn test_neighbors_undirected() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"2".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = medrecord.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs new file mode 100644 index 00000000..d16fcabd --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs @@ -0,0 +1,132 @@ +mod operand; +mod operation; + +use super::{ + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{Attributes, EdgeIndex, MedRecordAttribute, NodeIndex}, + MedRecord, +}; +pub use operand::{AttributesTreeOperand, MultipleAttributesOperand}; +pub use operation::{AttributesTreeOperation, MultipleAttributesOperation}; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum MultipleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +pub(crate) trait GetAttributes { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes>; +} + +impl GetAttributes for NodeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.node_attributes(self) + } +} + +impl GetAttributes for EdgeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.edge_attributes(self) + } +} + +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), +} + +impl Context { + pub(crate) fn get_attributes<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult>> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_attributes(medrecord, node_indices).map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_attributes(medrecord, edge_indices).map(|(_, value)| value), + ) + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs new file mode 100644 index 00000000..83af4393 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs @@ -0,0 +1,874 @@ +use super::{ + operation::{AttributesTreeOperation, MultipleAttributesOperation, SingleAttributeOperation}, + BinaryArithmeticKind, Context, GetAttributes, MultipleComparisonKind, MultipleKind, + SingleComparisonKind, SingleKind, UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + BoxedIterator, + }, + MedRecordAttribute, Wrapper, + }, + MedRecord, +}; +use std::{fmt::Display, hash::Hash}; + +macro_rules! implement_attributes_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new( + self.deep_clone(), + MultipleKind::$variant, + ); + + self.operations + .push(AttributesTreeOperation::AttributesOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_attribute_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleAttributesOperation::AttributeOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_attribute_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations + .push($operation::SingleAttributeComparisonOperation { + operand: attribute.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: attribute.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $attribute_type:ty) => { + pub fn $name(&self, attribute: $attribute_type) { + self.0.write_or_panic().$name(attribute) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeComparisonOperand { + Operand(SingleAttributeOperand), + Attribute(MedRecordAttribute), +} + +impl DeepClone for SingleAttributeComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attribute(attribute) => Self::Attribute(attribute.clone()), + } + } +} + +impl From> for SingleAttributeComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleAttributeComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleAttributeComparisonOperand { + fn from(value: V) -> Self { + Self::Attribute(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesComparisonOperand { + Operand(MultipleAttributesOperand), + Attributes(Vec), +} + +impl DeepClone for MultipleAttributesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attributes(attribute) => Self::Attributes(attribute.clone()), + } + } +} + +impl From> for MultipleAttributesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleAttributesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleAttributesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Attributes(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleAttributesComparisonOperand +{ + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct AttributesTreeOperand { + pub(crate) context: Context, + operations: Vec, +} + +impl DeepClone for AttributesTreeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl AttributesTreeOperand { + pub(crate) fn new(context: Context) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, Vec)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attributes_operation!(max, Max); + implement_attributes_operation!(min, Min); + implement_attributes_operation!(count, Count); + implement_attributes_operation!(sum, Sum); + implement_attributes_operation!(first, First); + implement_attributes_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + AttributesTreeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + AttributesTreeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, AttributesTreeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + AttributesTreeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, AttributesTreeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + AttributesTreeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + AttributesTreeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, AttributesTreeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, AttributesTreeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, AttributesTreeOperation, Add); + implement_binary_arithmetic_operation!(sub, AttributesTreeOperation, Sub); + implement_binary_arithmetic_operation!(mul, AttributesTreeOperation, Mul); + implement_binary_arithmetic_operation!(pow, AttributesTreeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, AttributesTreeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, AttributesTreeOperation, Abs); + implement_unary_arithmetic_operation!(trim, AttributesTreeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, AttributesTreeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, AttributesTreeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, AttributesTreeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, AttributesTreeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(AttributesTreeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, AttributesTreeOperation::IsString); + implement_assertion_operation!(is_int, AttributesTreeOperation::IsInt); + implement_assertion_operation!(is_max, AttributesTreeOperation::IsMax); + implement_assertion_operation!(is_min, AttributesTreeOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(AttributesTreeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context) -> Self { + AttributesTreeOperand::new(context).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(min, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(count, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(sum, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(first, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(last, MultipleAttributesOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct MultipleAttributesOperand { + pub(crate) context: AttributesTreeOperand, + pub(crate) kind: MultipleKind, + operations: Vec, +} + +impl DeepClone for MultipleAttributesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleAttributesOperand { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, MedRecordAttribute)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attribute_operation!(max, Max); + implement_attribute_operation!(min, Min); + implement_attribute_operation!(count, Count); + implement_attribute_operation!(sum, Sum); + implement_attribute_operation!(first, First); + implement_attribute_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + MultipleAttributesOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + MultipleAttributesOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + less_than, + MultipleAttributesOperation, + LessThan + ); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + MultipleAttributesOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + equal_to, + MultipleAttributesOperation, + EqualTo + ); + implement_single_attribute_comparison_operation!( + not_equal_to, + MultipleAttributesOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + MultipleAttributesOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!( + ends_with, + MultipleAttributesOperation, + EndsWith + ); + implement_single_attribute_comparison_operation!( + contains, + MultipleAttributesOperation, + Contains + ); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, MultipleAttributesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleAttributesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleAttributesOperation, Mul); + implement_binary_arithmetic_operation!(pow, MultipleAttributesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleAttributesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, MultipleAttributesOperation, Abs); + implement_unary_arithmetic_operation!(trim, MultipleAttributesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleAttributesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleAttributesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleAttributesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleAttributesOperation, Uppercase); + + #[allow(clippy::wrong_self_convention)] + pub fn to_values(&mut self) -> Wrapper { + let operand = Wrapper::::new( + values::Context::MultipleAttributesOperand(self.deep_clone()), + "unused".into(), + ); + + self.operations.push(MultipleAttributesOperation::ToValues { + operand: operand.clone(), + }); + + operand + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleAttributesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleAttributesOperation::IsString); + implement_assertion_operation!(is_int, MultipleAttributesOperation::IsInt); + implement_assertion_operation!(is_max, MultipleAttributesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleAttributesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleAttributesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + MultipleAttributesOperand::new(context, kind).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, SingleAttributeOperand); + implement_wrapper_operand_with_return!(min, SingleAttributeOperand); + implement_wrapper_operand_with_return!(count, SingleAttributeOperand); + implement_wrapper_operand_with_return!(sum, SingleAttributeOperand); + implement_wrapper_operand_with_return!(first, SingleAttributeOperand); + implement_wrapper_operand_with_return!(last, SingleAttributeOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + implement_wrapper_operand_with_return!(to_values, MultipleValuesOperand); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct SingleAttributeOperand { + pub(crate) context: MultipleAttributesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleAttributeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleAttributeOperand { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(attribute), |attribute, operation| { + if let Some(attribute) = attribute { + operation.evaluate(medrecord, attribute) + } else { + Ok(None) + } + }) + } + + implement_single_attribute_comparison_operation!( + greater_than, + SingleAttributeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + SingleAttributeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, SingleAttributeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + SingleAttributeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, SingleAttributeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + SingleAttributeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + SingleAttributeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, SingleAttributeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, SingleAttributeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, SingleAttributeOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleAttributeOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleAttributeOperation, Mul); + implement_binary_arithmetic_operation!(pow, SingleAttributeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleAttributeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, SingleAttributeOperation, Abs); + implement_unary_arithmetic_operation!(trim, SingleAttributeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleAttributeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleAttributeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleAttributeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleAttributeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleAttributeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleAttributeOperation::IsString); + implement_assertion_operation!(is_int, SingleAttributeOperation::IsInt); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleAttributeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + SingleAttributeOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attribute) + } + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs new file mode 100644 index 00000000..71479dff --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs @@ -0,0 +1,1357 @@ +use super::{ + operand::{ + MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, + }, + AttributesTreeOperand, BinaryArithmeticKind, GetAttributes, MultipleComparisonKind, + SingleComparisonKind, UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::{MultipleKind, SingleKind}, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + BoxedIterator, + }, + DataType, MedRecordAttribute, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + collections::HashMap, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Range, Sub}, +}; + +macro_rules! get_multiple_operand_attributes { + ($kind:ident, $attributes:expr) => { + match $kind { + MultipleKind::Max => Box::new(AttributesTreeOperation::get_max($attributes)?), + MultipleKind::Min => Box::new(AttributesTreeOperation::get_min($attributes)?), + MultipleKind::Count => Box::new(AttributesTreeOperation::get_count($attributes)?), + MultipleKind::Sum => Box::new(AttributesTreeOperation::get_sum($attributes)?), + MultipleKind::First => Box::new(AttributesTreeOperation::get_first($attributes)?), + MultipleKind::Last => Box::new(AttributesTreeOperation::get_last($attributes)?), + } + }; +} + +macro_rules! get_single_operand_attribute { + ($kind:ident, $attributes:expr) => { + match $kind { + SingleKind::Max => MultipleAttributesOperation::get_max($attributes)?.1, + SingleKind::Min => MultipleAttributesOperation::get_min($attributes)?.1, + SingleKind::Count => MultipleAttributesOperation::get_count($attributes), + SingleKind::Sum => MultipleAttributesOperation::get_sum($attributes)?, + SingleKind::First => MultipleAttributesOperation::get_first($attributes)?, + SingleKind::Last => MultipleAttributesOperation::get_last($attributes)?, + } + }; +} + +macro_rules! get_single_attribute_comparison_operand_attribute { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleAttributeComparisonOperand::Operand(operand) => { + let context = &operand.context.context.context; + let kind = &operand.context.kind; + + let comparison_attributes = context + .get_attributes($medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + let kind = &operand.kind; + + get_single_operand_attribute!(kind, comparison_attributes) + } + SingleAttributeComparisonOperand::Attribute(attribute) => attribute.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum AttributesTreeOperation { + AttributesOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for AttributesTreeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributesOperation { operand } => Self::AttributesOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl AttributesTreeOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + match self { + Self::AttributesOperation { operand } => Ok(Box::new( + Self::evaluate_attributes_operation(medrecord, attributes, operand)?, + )), + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsInt => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsMax => { + let max_attributes = Self::get_max(attributes)?; + + Ok(Box::new( + max_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::IsMin => { + let min_attributes = Self::get_min(attributes)?; + + Ok(Box::new( + min_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_min<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| (index, MedRecordAttribute::Int(attribute.len() as i64)))) + } + + #[inline] + pub(crate) fn get_sum<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |sum, attribute| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_attributes_operation<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + operand: &Wrapper, + ) -> MedRecordResult)>> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attributes.collect::>(); + + let multiple_operand_attributes: Box> = + get_multiple_operand_attributes!(kind, attributes.clone().into_iter()); + + let result = operand.evaluate(medrecord, multiple_operand_attributes)?; + + let mut attributes = attributes.into_iter().collect::>(); + + Ok(result + .map(move |(index, _)| (index, attributes.remove(&index).expect("Index must exist")))) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute > &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute >= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute < &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute <= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute == &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute != &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.starts_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.ends_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.contains(&comparison_attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + comparison_attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| !comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult)>> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes: Box< + dyn Iterator)>>, + > = match kind { + BinaryArithmeticKind::Add => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.add(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Sub => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.sub(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mul => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.mul(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Pow => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.pow(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mod => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.r#mod(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + }; + + Ok(Box::new( + attributes.collect::>>()?.into_iter(), + )) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator)>, + kind: UnaryArithmeticKind, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator)>, + range: Range, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| attribute.slice(range.clone())) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult)>> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesOperation { + AttributeOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + ToValues { + operand: Wrapper, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for MultipleAttributesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributeOperation { operand } => Self::AttributeOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::ToValues { operand } => Self::ToValues { + operand: operand.deep_clone(), + }, + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl MultipleAttributesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::AttributeOperation { operand } => { + Self::evaluate_attribute_operation(medrecord, attributes, operand) + } + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::ToValues { operand } => Ok(Box::new(Self::evaluate_to_values( + medrecord, attributes, operand, + )?)), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_attribute = Self::get_max(attributes)?; + + Ok(Box::new(std::iter::once(max_attribute))) + } + Self::IsMin => { + let min_attribute = Self::get_min(attributes)?; + + Ok(Box::new(std::iter::once(min_attribute))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let max_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(max_attribute, |max_attribute, attribute| { + match attribute.1.partial_cmp(&max_attribute.1) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(max_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let min_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(min_attribute, |min_attribute, attribute| { + match attribute.1.partial_cmp(&min_attribute.1) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(min_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordAttribute { + MedRecordAttribute::Int(attributes.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(first_attribute.1, |sum, (_, attribute)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + fn evaluate_attribute_operation<'a, T>( + medrecord: &'a MedRecord, + attribtues: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attribtues.collect::>(); + + let attribute = get_single_operand_attribute!(kind, attributes.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, attribute)? { + Some(_) => Box::new(attributes.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute > &comparison_attribute + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute >= &comparison_attribute + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute < &comparison_attribute + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute <= &comparison_attribute + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute == &comparison_attribute + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute != &comparison_attribute + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.starts_with(&comparison_attribute) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.ends_with(&comparison_attribute) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.contains(&comparison_attribute) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + comparison_attributes.contains(attribute) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + !comparison_attributes.contains(attribute) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes = attributes + .map(move |(t, attribute)| { + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute.clone()), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute.clone()), + BinaryArithmeticKind::Mul => { + attribute.clone().mul(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Pow => { + attribute.clone().pow(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Mod => { + attribute.clone().r#mod(arithmetic_attribute.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the attributes using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(attributes.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| { + let attribute = match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }; + (t, attribute) + }) + } + + pub(crate) fn get_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| { + let value = index.get_attributes(medrecord)?.get(&attribute).ok_or( + MedRecordError::QueryError(format!( + "Cannot find attribute {} for index {}", + attribute, index + )), + )?; + + Ok((index, value.clone())) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_to_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let values = Self::get_values(medrecord, attributes.clone().into_iter())?; + + let mut attributes = attributes.into_iter().collect::>(); + + let values = operand.evaluate(medrecord, values.into_iter())?; + + Ok(values.map(move |(index, _)| { + ( + index, + attributes.remove(&index).expect("Attribute must exist"), + ) + })) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator, + range: Range, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| (t, attribute.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeOperation { + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for SingleAttributeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl SingleAttributeOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attribute, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + })), + Self::Slice(range) => Ok(Some(attribute.slice(range.clone()))), + Self::IsString => Ok(match attribute { + MedRecordAttribute::String(_) => Some(attribute), + _ => None, + }), + Self::IsInt => Ok(match attribute { + MedRecordAttribute::Int(_) => Some(attribute), + _ => None, + }), + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attribute, either, or) + } + } + } + + #[inline] + fn evaluate_single_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => attribute > comparison_attribute, + SingleComparisonKind::GreaterThanOrEqualTo => attribute >= comparison_attribute, + SingleComparisonKind::LessThan => attribute < comparison_attribute, + SingleComparisonKind::LessThanOrEqualTo => attribute <= comparison_attribute, + SingleComparisonKind::EqualTo => attribute == comparison_attribute, + SingleComparisonKind::NotEqualTo => attribute != comparison_attribute, + SingleComparisonKind::StartsWith => attribute.starts_with(&comparison_attribute), + SingleComparisonKind::EndsWith => attribute.ends_with(&comparison_attribute), + SingleComparisonKind::Contains => attribute.contains(&comparison_attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_multiple_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_attributes.contains(&attribute), + MultipleComparisonKind::IsNotIn => !comparison_attributes.contains(&attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute), + BinaryArithmeticKind::Mul => attribute.mul(arithmetic_attribute), + BinaryArithmeticKind::Pow => attribute.pow(arithmetic_attribute), + BinaryArithmeticKind::Mod => attribute.r#mod(arithmetic_attribute), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, attribute.clone())?; + let or_result = or.evaluate(medrecord, attribute)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs new file mode 100644 index 00000000..1045e83e --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs @@ -0,0 +1,58 @@ +mod operand; +mod operation; +mod selection; + +pub use operand::EdgeOperand; +pub use operation::EdgeOperation; +pub use selection::EdgeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs new file mode 100644 index 00000000..4b7b4f85 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs @@ -0,0 +1,655 @@ +use super::{ + operation::{EdgeIndexOperation, EdgeIndicesOperation, EdgeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + attributes::{self, AttributesTreeOperand}, + nodes::NodeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + wrapper::Wrapper, + BoxedIterator, + }, + CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, + }, + MedRecord, +}; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct EdgeOperand { + pub(crate) operations: Vec, +} + +impl DeepClone for EdgeOperand { + fn deep_clone(&self) -> Self { + Self { + operations: self + .operations + .iter() + .map(|operation| operation.deep_clone()) + .collect(), + } + } +} + +impl EdgeOperand { + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + let edge_indices = Box::new(medrecord.edge_indices()) as BoxedIterator<&'a EdgeIndex>; + + self.operations + .iter() + .try_fold(edge_indices, |edge_indices, operation| { + operation.evaluate(medrecord, edge_indices) + }) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::EdgeOperand(self.deep_clone()), + attribute, + ); + + self.operations.push(EdgeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::EdgeOperand( + self.deep_clone(), + )); + + self.operations.push(EdgeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(EdgeOperation::Indices { + operand: operand.clone(), + }); + + operand + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.operations.push(EdgeOperation::InGroup { + group: group.into(), + }); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.operations.push(EdgeOperation::HasAttribute { + attribute: attribute.into(), + }); + } + + pub fn source_node(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(EdgeOperation::SourceNode { + operand: operand.clone(), + }); + + operand + } + + pub fn target_node(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(EdgeOperation::TargetNode { + operand: operand.clone(), + }); + + operand + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new() -> Self { + EdgeOperand::new().into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord) + } + + pub fn attribute(&self, attribute: A) -> Wrapper + where + A: Into, + { + self.0.write_or_panic().attribute(attribute.into()) + } + + pub fn attributes(&self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&self) -> Wrapper { + self.0.write_or_panic().index() + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.0.write_or_panic().in_group(group); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.0.write_or_panic().has_attribute(attribute); + } + + pub fn source_node(&self) -> Wrapper { + self.0.write_or_panic().source_node() + } + + pub fn target_node(&self) -> Wrapper { + self.0.write_or_panic().target_node() + } + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(EdgeIndicesOperation::EdgeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::EdgeIndexComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexComparisonOperand { + Operand(EdgeIndexOperand), + Index(EdgeIndex), +} + +impl DeepClone for EdgeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(value) => Self::Index(*value), + } + } +} + +impl From> for EdgeIndexComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndexComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for EdgeIndexComparisonOperand { + fn from(value: V) -> Self { + Self::Index(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesComparisonOperand { + Operand(EdgeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for EdgeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(value) => Self::Indices(value.clone()), + } + } +} + +impl From> for EdgeIndicesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndicesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for EdgeIndicesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Indices(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for EdgeIndicesComparisonOperand { + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndicesOperand { + pub(crate) context: EdgeOperand, + operations: Vec, +} + +impl DeepClone for EdgeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndicesOperand { + pub(crate) fn new(context: EdgeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let values = Box::new(values) as BoxedIterator; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!(greater_than, EdgeIndicesOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + EdgeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, EdgeIndicesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + EdgeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, EdgeIndicesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, EdgeIndicesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, EdgeIndicesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, EdgeIndicesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, EdgeIndicesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndicesOperation, Mod); + + implement_assertion_operation!(is_max, EdgeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, EdgeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeOperand) -> Self { + EdgeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, EdgeIndexOperand); + implement_wrapper_operand_with_return!(min, EdgeIndexOperand); + implement_wrapper_operand_with_return!(count, EdgeIndexOperand); + implement_wrapper_operand_with_return!(sum, EdgeIndexOperand); + implement_wrapper_operand_with_return!(first, EdgeIndexOperand); + implement_wrapper_operand_with_return!(last, EdgeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndexOperand { + pub(crate) context: EdgeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for EdgeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndexOperand { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: EdgeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, EdgeIndexOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + EdgeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, EdgeIndexOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + EdgeIndexOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, EdgeIndexOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, EdgeIndexOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, EdgeIndexOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, EdgeIndexOperation, EndsWith); + implement_single_value_comparison_operation!(contains, EdgeIndexOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndexOperation, Mod); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + EdgeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: EdgeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs new file mode 100644 index 00000000..0d36db8f --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs @@ -0,0 +1,762 @@ +use super::{ + operand::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, + EdgeIndicesOperand, + }, + BinaryArithmeticKind, EdgeOperand, MultipleComparisonKind, SingleComparisonKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{Contains, EndsWith, Mod, StartsWith}, + querying::{ + attributes::AttributesTreeOperand, + edges::SingleKind, + nodes::NodeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + wrapper::Wrapper, + BoxedIterator, + }, + CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, MedRecordValue, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + collections::HashSet, + ops::{Add, Mul, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeOperation { + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, + }, + + InGroup { + group: CardinalityWrapper, + }, + HasAttribute { + attribute: CardinalityWrapper, + }, + + SourceNode { + operand: Wrapper, + }, + TargetNode { + operand: Wrapper, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { + operand: operand.deep_clone(), + }, + Self::InGroup { group } => Self::InGroup { + group: group.clone(), + }, + Self::HasAttribute { attribute } => Self::HasAttribute { + attribute: attribute.clone(), + }, + Self::SourceNode { operand } => Self::SourceNode { + operand: operand.deep_clone(), + }, + Self::TargetNode { operand } => Self::TargetNode { + operand: operand.deep_clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + ) -> MedRecordResult> { + Ok(match self { + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( + medrecord, + edge_indices, + operand.clone(), + )?), + Self::InGroup { group } => Box::new(Self::evaluate_in_group( + medrecord, + edge_indices, + group.clone(), + )), + Self::HasAttribute { attribute } => Box::new(Self::evaluate_has_attribute( + medrecord, + edge_indices, + attribute.clone(), + )), + Self::SourceNode { operand } => Box::new(Self::evaluate_source_node( + medrecord, + edge_indices, + operand, + )?), + Self::TargetNode { operand } => Box::new(Self::evaluate_target_node( + medrecord, + edge_indices, + operand, + )?), + Self::EitherOr { either, or } => { + Box::new(Self::evaluate_either_or(medrecord, either, or)?) + } + }) + } + + #[inline] + pub(crate) fn get_values<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + attribute: MedRecordAttribute, + ) -> impl Iterator { + edge_indices.flat_map(move |edge_index| { + Some(( + edge_index, + medrecord + .edge_attributes(edge_index) + .expect("Edge must exist") + .get(&attribute)? + .clone(), + )) + }) + } + + #[inline] + fn evaluate_values<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let values = Self::get_values( + medrecord, + edge_indices, + operand.0.read_or_panic().attribute.clone(), + ); + + Ok(operand.evaluate(medrecord, values)?.map(|value| value.0)) + } + + #[inline] + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + ) -> impl Iterator)> { + edge_indices.map(move |edge_index| { + let attributes = medrecord + .edge_attributes(edge_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (edge_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, edge_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let edge_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, edge_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(edge_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + group: CardinalityWrapper, + ) -> impl Iterator { + edge_indices.filter(move |edge_index| { + let groups_of_edge = medrecord + .groups_of_edge(edge_index) + .expect("Node must exist"); + + let groups_of_edge = groups_of_edge.collect::>(); + + match &group { + CardinalityWrapper::Single(group) => groups_of_edge.contains(&group), + CardinalityWrapper::Multiple(groups) => { + groups.iter().all(|group| groups_of_edge.contains(&group)) + } + } + }) + } + + #[inline] + fn evaluate_has_attribute<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + attribute: CardinalityWrapper, + ) -> impl Iterator { + edge_indices.filter(move |edge_index| { + let attributes_of_edge = medrecord + .edge_attributes(edge_index) + .expect("Node must exist") + .keys(); + + let attributes_of_edge = attributes_of_edge.collect::>(); + + match &attribute { + CardinalityWrapper::Single(attribute) => attributes_of_edge.contains(&attribute), + CardinalityWrapper::Multiple(attributes) => attributes + .iter() + .all(|attribute| attributes_of_edge.contains(&attribute)), + } + }) + } + + #[inline] + fn evaluate_source_node<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let node_indices = operand.evaluate(medrecord)?.collect::>(); + + Ok(edge_indices.filter(move |edge_index| { + let edge_endpoints = medrecord + .edge_endpoints(edge_index) + .expect("Edge must exist"); + + node_indices.contains(edge_endpoints.1) + })) + } + + #[inline] + fn evaluate_target_node<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let node_indices = operand.evaluate(medrecord)?.collect::>(); + + Ok(edge_indices.filter(move |edge_index| { + let edge_endpoints = medrecord + .edge_endpoints(edge_index) + .expect("Edge must exist"); + + node_indices.contains(edge_endpoints.1) + })) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord)?; + let or_result = or.evaluate(medrecord)?; + + Ok(either_result.chain(or_result).unique()) + } +} + +macro_rules! get_edge_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => EdgeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => EdgeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => EdgeIndicesOperation::get_count($indices), + SingleKind::Sum => EdgeIndicesOperation::get_sum($indices), + SingleKind::First => EdgeIndicesOperation::get_first($indices)?, + SingleKind::Last => EdgeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_edge_index_comparison_operand_index { + ($operand:ident, $medrecord:ident) => { + match $operand { + EdgeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. + let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_edge_index!(kind, comparison_indices); + + comparison_index + } + EdgeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesOperation { + EdgeIndexOperation { + operand: Wrapper, + }, + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexOperation { operand } => Self::EdgeIndexOperation { + operand: operand.deep_clone(), + }, + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexOperation { operand } => { + Self::evaluate_edge_index_operation(medrecord, indices, operand) + } + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max(indices: impl Iterator) -> MedRecordResult { + indices.max().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + + #[inline] + pub(crate) fn get_min(indices: impl Iterator) -> MedRecordResult { + indices.min().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> EdgeIndex { + indices.count() as EdgeIndex + } + + #[inline] + pub(crate) fn get_sum(indices: impl Iterator) -> EdgeIndex { + indices.sum() + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_edge_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_edge_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_edge_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_edge_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(indices + .map(move |index| match kind { + BinaryArithmeticKind::Add => Ok(index.add(arithmetic_index)), + BinaryArithmeticKind::Sub => Ok(index.sub(arithmetic_index)), + BinaryArithmeticKind::Mul => Ok(index.mul(arithmetic_index)), + BinaryArithmeticKind::Pow => Ok(index.pow(arithmetic_index)), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index), + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexOperation { + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: EdgeIndex, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, index, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indcies_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + } + } + + #[inline] + fn evaluate_edge_index_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_edge_indcies_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: EdgeIndex, + operand: &EdgeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index), + BinaryArithmeticKind::Mul => index.mul(arithmetic_index), + BinaryArithmeticKind::Pow => index.pow(arithmetic_index), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: EdgeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index)?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/selection.rs b/crates/medmodels-core/src/medrecord/querying/edges/selection.rs new file mode 100644 index 00000000..a0d0a519 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/selection.rs @@ -0,0 +1,32 @@ +use super::EdgeOperand; +use crate::{ + errors::MedRecordResult, + medrecord::{querying::wrapper::Wrapper, EdgeIndex, MedRecord}, +}; + +#[derive(Debug, Clone)] +pub struct EdgeSelection<'a> { + medrecord: &'a MedRecord, + operand: Wrapper, +} + +impl<'a> EdgeSelection<'a> { + pub fn new(medrecord: &'a MedRecord, query: Q) -> Self + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + Self { medrecord, operand } + } + + pub fn iter(&'a self) -> MedRecordResult> { + self.operand.evaluate(self.medrecord) + } + + pub fn collect>(&'a self) -> MedRecordResult { + Ok(FromIterator::from_iter(self.iter()?)) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/mod.rs b/crates/medmodels-core/src/medrecord/querying/mod.rs index 1f999f78..94728fe4 100644 --- a/crates/medmodels-core/src/medrecord/querying/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/mod.rs @@ -1,9 +1,8 @@ -mod operation; -mod selection; +pub mod attributes; +pub mod edges; +pub mod nodes; +mod traits; +pub mod values; +pub mod wrapper; -pub use self::operation::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation, - TransformationOperation, ValueOperand, -}; -pub(super) use self::selection::{EdgeSelection, NodeSelection}; +pub(crate) type BoxedIterator<'a, T> = Box + 'a>; diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs new file mode 100644 index 00000000..1041a7e9 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs @@ -0,0 +1,68 @@ +mod operand; +mod operation; +mod selection; + +pub use operand::NodeOperand; +pub use operation::NodeOperation; +pub use selection::NodeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs new file mode 100644 index 00000000..1800bc00 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs @@ -0,0 +1,732 @@ +use super::{ + operation::{EdgeDirection, NodeIndexOperation, NodeIndicesOperation, NodeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + attributes::{self, AttributesTreeOperand}, + edges::EdgeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, + }, + Group, MedRecordAttribute, NodeIndex, + }, + MedRecord, +}; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct NodeOperand { + operations: Vec, +} + +impl DeepClone for NodeOperand { + fn deep_clone(&self) -> Self { + Self { + operations: self + .operations + .iter() + .map(|operation| operation.deep_clone()) + .collect(), + } + } +} + +impl NodeOperand { + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + let node_indices = Box::new(medrecord.node_indices()) as BoxedIterator<'a, &'a NodeIndex>; + + self.operations + .iter() + .try_fold(node_indices, |node_indices, operation| { + operation.evaluate(medrecord, node_indices) + }) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::NodeOperand(self.deep_clone()), + attribute, + ); + + self.operations.push(NodeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::NodeOperand( + self.deep_clone(), + )); + + self.operations.push(NodeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(NodeOperation::Indices { + operand: operand.clone(), + }); + + operand + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.operations.push(NodeOperation::InGroup { + group: group.into(), + }); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.operations.push(NodeOperation::HasAttribute { + attribute: attribute.into(), + }); + } + + pub fn outgoing_edges(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::OutgoingEdges { + operand: operand.clone(), + }); + + operand + } + + pub fn incoming_edges(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::IncomingEdges { + operand: operand.clone(), + }); + + operand + } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::Neighbors { + operand: operand.clone(), + direction, + }); + + operand + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new() -> Self { + NodeOperand::new().into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + self.0.write_or_panic().attribute(attribute) + } + + pub fn attributes(&mut self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&mut self) -> Wrapper { + self.0.write_or_panic().index() + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.0.write_or_panic().in_group(group); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.0.write_or_panic().has_attribute(attribute); + } + + pub fn outgoing_edges(&mut self) -> Wrapper { + self.0.write_or_panic().outgoing_edges() + } + + pub fn incoming_edges(&mut self) -> Wrapper { + self.0.write_or_panic().incoming_edges() + } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + self.0.write_or_panic().neighbors(direction) + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(NodeIndicesOperation::NodeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::NodeIndexComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndexComparisonOperand { + Operand(NodeIndexOperand), + Index(NodeIndex), +} + +impl DeepClone for NodeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(value) => Self::Index(value.clone()), + } + } +} + +impl From> for NodeIndexComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndexComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for NodeIndexComparisonOperand { + fn from(value: V) -> Self { + Self::Index(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesComparisonOperand { + Operand(NodeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for NodeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(value) => Self::Indices(value.clone()), + } + } +} + +impl From> for NodeIndicesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndicesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for NodeIndicesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Indices(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for NodeIndicesComparisonOperand { + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndicesOperand { + pub(crate) context: NodeOperand, + operations: Vec, +} + +impl DeepClone for NodeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndicesOperand { + pub(crate) fn new(context: NodeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let values = Box::new(values) as BoxedIterator; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!(greater_than, NodeIndicesOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + NodeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, NodeIndicesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + NodeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, NodeIndicesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, NodeIndicesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, NodeIndicesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, NodeIndicesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, NodeIndicesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndicesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndicesOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndicesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndicesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndicesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndicesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndicesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(NodeIndicesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndicesOperation::IsString); + implement_assertion_operation!(is_int, NodeIndicesOperation::IsInt); + implement_assertion_operation!(is_max, NodeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, NodeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeOperand) -> Self { + NodeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, NodeIndexOperand); + implement_wrapper_operand_with_return!(min, NodeIndexOperand); + implement_wrapper_operand_with_return!(count, NodeIndexOperand); + implement_wrapper_operand_with_return!(sum, NodeIndexOperand); + implement_wrapper_operand_with_return!(first, NodeIndexOperand); + implement_wrapper_operand_with_return!(last, NodeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndexOperand { + pub(crate) context: NodeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for NodeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndexOperand { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: NodeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, NodeIndexOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + NodeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, NodeIndexOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + NodeIndexOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, NodeIndexOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, NodeIndexOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, NodeIndexOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, NodeIndexOperation, EndsWith); + implement_single_value_comparison_operation!(contains, NodeIndexOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndexOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndexOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndexOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndexOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndexOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndexOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndexOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations.push(NodeIndexOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndexOperation::IsString); + implement_assertion_operation!(is_int, NodeIndexOperation::IsInt); + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + NodeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: NodeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs new file mode 100644 index 00000000..90e6692c --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs @@ -0,0 +1,971 @@ +use super::{ + operand::{ + NodeIndexComparisonOperand, NodeIndexOperand, NodeIndicesComparisonOperand, + NodeIndicesOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, NodeOperand, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::AttributesTreeOperand, + edges::EdgeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, + }, + DataType, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, + }, +}; +use itertools::Itertools; +use roaring::RoaringBitmap; +use std::{ + cmp::Ordering, + collections::HashSet, + ops::{Add, Mul, Range, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeDirection { + Incoming, + Outgoing, + Both, +} + +#[derive(Debug, Clone)] +pub enum NodeOperation { + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, + }, + + InGroup { + group: CardinalityWrapper, + }, + HasAttribute { + attribute: CardinalityWrapper, + }, + + OutgoingEdges { + operand: Wrapper, + }, + IncomingEdges { + operand: Wrapper, + }, + + Neighbors { + operand: Wrapper, + direction: EdgeDirection, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { + operand: operand.deep_clone(), + }, + Self::InGroup { group } => Self::InGroup { + group: group.clone(), + }, + Self::HasAttribute { attribute } => Self::HasAttribute { + attribute: attribute.clone(), + }, + Self::OutgoingEdges { operand } => Self::OutgoingEdges { + operand: operand.deep_clone(), + }, + Self::IncomingEdges { operand } => Self::IncomingEdges { + operand: operand.deep_clone(), + }, + Self::Neighbors { + operand, + direction: drection, + } => Self::Neighbors { + operand: operand.deep_clone(), + direction: drection.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + ) -> MedRecordResult> { + Ok(match self { + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( + medrecord, + node_indices, + operand.clone(), + )?), + Self::InGroup { group } => Box::new(Self::evaluate_in_group( + medrecord, + node_indices, + group.clone(), + )), + Self::HasAttribute { attribute } => Box::new(Self::evaluate_has_attribute( + medrecord, + node_indices, + attribute.clone(), + )), + Self::OutgoingEdges { operand } => Box::new(Self::evaluate_outgoing_edges( + medrecord, + node_indices, + operand.clone(), + )?), + Self::IncomingEdges { operand } => Box::new(Self::evaluate_incoming_edges( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Neighbors { + operand, + direction: drection, + } => Box::new(Self::evaluate_neighbors( + medrecord, + node_indices, + operand.clone(), + drection.clone(), + )?), + Self::EitherOr { either, or } => { + // TODO: This is a temporary solution. It should be optimized. + let either_result = either.evaluate(medrecord)?.collect::>(); + let or_result = or.evaluate(medrecord)?.collect::>(); + + Box::new(either_result.into_iter().chain(or_result).unique()) + } + }) + } + + #[inline] + pub(crate) fn get_values<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + attribute: MedRecordAttribute, + ) -> impl Iterator { + node_indices.flat_map(move |node_index| { + Some(( + node_index, + medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .get(&attribute)? + .clone(), + )) + }) + } + + #[inline] + fn evaluate_values<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let values = Self::get_values( + medrecord, + node_indices, + operand.0.read_or_panic().attribute.clone(), + ); + + Ok(operand.evaluate(medrecord, values)?.map(|value| value.0)) + } + + #[inline] + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + ) -> impl Iterator)> { + node_indices.map(move |node_index| { + let attributes = medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (node_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, node_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let node_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, node_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(node_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + group: CardinalityWrapper, + ) -> impl Iterator { + node_indices.filter(move |node_index| { + let groups_of_node = medrecord + .groups_of_node(node_index) + .expect("Node must exist"); + + let groups_of_node = groups_of_node.collect::>(); + + match &group { + CardinalityWrapper::Single(group) => groups_of_node.contains(&group), + CardinalityWrapper::Multiple(groups) => { + groups.iter().all(|group| groups_of_node.contains(&group)) + } + } + }) + } + + #[inline] + fn evaluate_has_attribute<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + attribute: CardinalityWrapper, + ) -> impl Iterator { + node_indices.filter(move |node_index| { + let attributes_of_node = medrecord + .node_attributes(node_index) + .expect("Node must exist") + .keys(); + + let attributes_of_node = attributes_of_node.collect::>(); + + match &attribute { + CardinalityWrapper::Single(attribute) => attributes_of_node.contains(&attribute), + CardinalityWrapper::Multiple(attributes) => attributes + .iter() + .all(|attribute| attributes_of_node.contains(&attribute)), + } + }) + } + + #[inline] + fn evaluate_outgoing_edges<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + let edge_indices = operand.evaluate(medrecord)?.collect::(); + + Ok(node_indices.filter(move |node_index| { + let outgoing_edge_indices = medrecord + .outgoing_edges(node_index) + .expect("Node must exist"); + + let outgoing_edge_indices = outgoing_edge_indices.collect::(); + + !outgoing_edge_indices.is_disjoint(&edge_indices) + })) + } + + #[inline] + fn evaluate_incoming_edges<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + let edge_indices = operand.evaluate(medrecord)?.collect::(); + + Ok(node_indices.filter(move |node_index| { + let incoming_edge_indices = medrecord + .incoming_edges(node_index) + .expect("Node must exist"); + + let incoming_edge_indices = incoming_edge_indices.collect::(); + + !incoming_edge_indices.is_disjoint(&edge_indices) + })) + } + + #[inline] + fn evaluate_neighbors<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + direction: EdgeDirection, + ) -> MedRecordResult> { + let result = operand.evaluate(medrecord)?.collect::>(); + + Ok(node_indices.filter(move |node_index| { + let mut neighbors: Box> = match direction { + EdgeDirection::Incoming => Box::new( + medrecord + .neighbors_incoming(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Outgoing => Box::new( + medrecord + .neighbors_outgoing(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Both => Box::new( + medrecord + .neighbors_undirected(node_index) + .expect("Node must exist"), + ), + }; + + neighbors.any(|neighbor| result.contains(&neighbor)) + })) + } +} + +macro_rules! get_node_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => NodeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => NodeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => NodeIndicesOperation::get_count($indices), + SingleKind::Sum => NodeIndicesOperation::get_sum($indices)?, + SingleKind::First => NodeIndicesOperation::get_first($indices)?, + SingleKind::Last => NodeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_node_index_comparison_operand { + ($operand:ident, $medrecord:ident) => { + match $operand { + NodeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. + let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_node_index!(kind, comparison_indices); + + comparison_index + } + NodeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesOperation { + NodeIndexOperation { + operand: Wrapper, + }, + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexOperation { operand } => Self::NodeIndexOperation { + operand: operand.deep_clone(), + }, + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::NodeIndexOperation { operand } => { + Self::evaluate_node_index_operation(medrecord, indices, operand) + } + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(indices, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(indices, range.clone()))), + Self::IsString => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max( + mut indices: impl Iterator, + ) -> MedRecordResult { + let max_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(max_index, |max_index, index| { + match index + .partial_cmp(&max_index) { + Some(Ordering::Greater) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(max_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_index), + } + }) + } + + #[inline] + pub(crate) fn get_min( + mut indices: impl Iterator, + ) -> MedRecordResult { + let min_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(min_index, |min_index, index| { + match index.partial_cmp(&min_index) { + Some(Ordering::Less) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(min_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_index), + } + }) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> NodeIndex { + MedRecordAttribute::Int(indices.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum( + mut indices: impl Iterator, + ) -> MedRecordResult { + let first_value = indices + .next() + .ok_or(MedRecordError::QueryError("No indices to sum".to_string()))?; + + indices.try_fold(first_value, |sum, index| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&index); + + sum.add(index).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_node_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_node_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_node_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_node_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + let indices = indices + .map(move |index| { + match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index.clone()), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index.clone()), + BinaryArithmeticKind::Mul => { + index.clone().mul(arithmetic_index.clone()) + } + BinaryArithmeticKind::Pow => { + index.clone().pow(arithmetic_index.clone()) + } + BinaryArithmeticKind::Mod => { + index.clone().r#mod(arithmetic_index.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the indices using .is_string() or .is_int()", + kind, + )) + }) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(indices.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation( + indices: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + indices.map(move |index| match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + }) + } + + #[inline] + fn evaluate_slice( + indices: impl Iterator, + range: Range, + ) -> impl Iterator { + indices.map(move |index| index.slice(range.clone())) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndexOperation { + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: NodeIndex, + ) -> MedRecordResult> { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, index, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + })), + Self::Slice(range) => Ok(Some(index.slice(range.clone()))), + Self::IsString => Ok(match index { + MedRecordAttribute::String(_) => Some(index), + _ => None, + }), + Self::IsInt => Ok(match index { + MedRecordAttribute::Int(_) => Some(index), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + } + } + + #[inline] + fn evaluate_node_index_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_node_indices_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: NodeIndex, + operand: &NodeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index)?, + BinaryArithmeticKind::Sub => index.sub(arithmetic_index)?, + BinaryArithmeticKind::Mul => index.mul(arithmetic_index)?, + BinaryArithmeticKind::Pow => index.pow(arithmetic_index)?, + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: NodeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index.clone())?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs b/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs new file mode 100644 index 00000000..d994543d --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs @@ -0,0 +1,35 @@ +use super::NodeOperand; +use crate::{ + errors::MedRecordResult, + medrecord::{querying::wrapper::Wrapper, MedRecord, NodeIndex}, +}; + +#[derive(Debug, Clone)] +pub struct NodeSelection<'a> { + medrecord: &'a MedRecord, + operand: Wrapper, +} + +impl<'a> NodeSelection<'a> { + pub fn new(medrecord: &'a MedRecord, query: Q) -> Self + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + Self { medrecord, operand } + } + + pub fn iter(self) -> MedRecordResult> { + self.operand.evaluate(self.medrecord) + } + + pub fn collect(self) -> MedRecordResult + where + B: FromIterator<&'a NodeIndex>, + { + Ok(FromIterator::from_iter(self.iter()?)) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs b/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs deleted file mode 100644 index f005c53f..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs +++ /dev/null @@ -1,475 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - operand::{ArithmeticOperation, EdgeIndexInOperand, IntoVecEdgeIndex, ValueOperand}, - AttributeOperation, NodeOperation, Operation, -}; -use crate::medrecord::{ - datatypes::{ - Abs, Ceil, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, Trim, TrimEnd, TrimStart, - Uppercase, - }, - EdgeIndex, MedRecord, MedRecordAttribute, -}; - -#[derive(Debug, Clone)] -pub enum EdgeIndexOperation { - Gt(EdgeIndex), - Lt(EdgeIndex), - Gte(EdgeIndex), - Lte(EdgeIndex), - Eq(EdgeIndex), - In(Box), -} - -#[derive(Debug, Clone)] -pub enum EdgeOperation { - Attribute(AttributeOperation), - Index(EdgeIndexOperation), - - ConnectedSource(MedRecordAttribute), - ConnectedTarget(MedRecordAttribute), - InGroup(MedRecordAttribute), - HasAttribute(MedRecordAttribute), - - ConnectedSourceWith(Box), - ConnectedTargetWith(Box), - - HasParallelEdgesWith(Box), - HasParallelEdgesWithSelfComparison(Box), - - And(Box<(EdgeOperation, EdgeOperation)>), - Or(Box<(EdgeOperation, EdgeOperation)>), - Not(Box), -} - -impl Operation for EdgeOperation { - type IndexType = EdgeIndex; - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a> { - match self { - EdgeOperation::Attribute(attribute_operation) => { - Self::evaluate_attribute(indices, attribute_operation, |index| { - medrecord.edge_attributes(index) - }) - } - EdgeOperation::Index(index_operation) => { - Self::evaluate_index(medrecord, indices, index_operation) - } - - EdgeOperation::ConnectedSource(attribute_operand) => Box::new( - Self::evaluate_connected_target(medrecord, indices, attribute_operand), - ), - EdgeOperation::ConnectedTarget(attribute_operand) => Box::new( - Self::evaluate_connected_source(medrecord, indices, attribute_operand), - ), - EdgeOperation::InGroup(attribute_operand) => Box::new(Self::evaluate_in_group( - medrecord, - indices, - attribute_operand, - )), - EdgeOperation::HasAttribute(attribute_operand) => Box::new( - Self::evaluate_has_attribute(indices, attribute_operand, |index| { - medrecord.edge_attributes(index) - }), - ), - - EdgeOperation::ConnectedSourceWith(operation) => Box::new( - Self::evaluate_connected_source_with(medrecord, indices, *operation), - ), - EdgeOperation::ConnectedTargetWith(operation) => Box::new( - Self::evaluate_connected_target_with(medrecord, indices, *operation), - ), - - EdgeOperation::HasParallelEdgesWith(operation) => { - Self::evaluate_has_parallel_edges_with(medrecord, Box::new(indices), *operation) - } - EdgeOperation::HasParallelEdgesWithSelfComparison(operation) => { - Self::evaluate_has_parallel_edges_with_compare_to_self( - medrecord, - Box::new(indices), - *operation, - ) - } - - EdgeOperation::And(operations) => Box::new(Self::evaluate_and( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - EdgeOperation::Or(operations) => Box::new(Self::evaluate_or( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - EdgeOperation::Not(operation) => Box::new(Self::evaluate_not( - medrecord, - indices.collect::>(), - *operation, - )), - } - } -} - -impl EdgeOperation { - pub fn and(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::And(Box::new((self, operation))) - } - - pub fn or(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::Or(Box::new((self, operation))) - } - - pub fn xor(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::And(Box::new((self, operation))).not() - } - - pub fn not(self) -> EdgeOperation { - EdgeOperation::Not(Box::new(self)) - } - - fn evaluate_index<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator + 'a, - operation: EdgeIndexOperation, - ) -> Box + 'a> { - match operation { - EdgeIndexOperation::Gt(operand) => { - Box::new(Self::evaluate_index_gt(edge_indices, operand)) - } - EdgeIndexOperation::Lt(operand) => { - Box::new(Self::evaluate_index_lt(edge_indices, operand)) - } - EdgeIndexOperation::Gte(operand) => { - Box::new(Self::evaluate_index_gte(edge_indices, operand)) - } - EdgeIndexOperation::Lte(operand) => { - Box::new(Self::evaluate_index_lte(edge_indices, operand)) - } - EdgeIndexOperation::Eq(operand) => { - Box::new(Self::evaluate_index_eq(edge_indices, operand)) - } - EdgeIndexOperation::In(operands) => Box::new(Self::evaluate_index_in( - edge_indices, - operands.into_vec_edge_index(medrecord), - )), - } - } - - fn evaluate_connected_target<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - *endpoints.1 == attribute_operand - }) - } - - fn evaluate_connected_source<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - *endpoints.0 == attribute_operand - }) - } - - fn evaluate_in_group<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - let edges_in_group = match medrecord.edges_in_group(&attribute_operand) { - Ok(edges_in_group) => edges_in_group.collect::>(), - Err(_) => Vec::new(), - }; - - edge_indices.filter(move |index| edges_in_group.contains(index)) - } - - fn evaluate_connected_target_with<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - operation - .clone() - .evaluate(medrecord, vec![endpoints.1].into_iter()) - .count() - > 0 - }) - } - - fn evaluate_connected_source_with<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - operation - .clone() - .evaluate(medrecord, vec![endpoints.0].into_iter()) - .count() - > 0 - }) - } - - fn evaluate_has_parallel_edges_with<'a>( - medrecord: &'a MedRecord, - edge_indices: Box + 'a>, - operation: EdgeOperation, - ) -> Box + 'a> { - Box::new(edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - let edges = medrecord - .edges_connecting(vec![endpoints.0], vec![endpoints.1]) - .filter(|other_index| other_index != index); - - operation.clone().evaluate(medrecord, edges).count() > 0 - })) - } - - fn convert_value_operand<'a>( - medrecord: &'a MedRecord, - index: &'a EdgeIndex, - value_operand: ValueOperand, - ) -> Option { - match value_operand { - ValueOperand::Value(value) => Some(ValueOperand::Value(value)), - ValueOperand::Evaluate(attribute) => Some(ValueOperand::Value( - medrecord - .edge_attributes(index) - .ok()? - .get(&attribute)? - .clone(), - )), - ValueOperand::ArithmeticOperation(operation, attribute, other_value) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - let result = match operation { - ArithmeticOperation::Addition => value.clone() + other_value, - ArithmeticOperation::Subtraction => value.clone() - other_value, - ArithmeticOperation::Multiplication => value.clone() * other_value, - ArithmeticOperation::Division => value.clone() / other_value, - ArithmeticOperation::Power => value.clone().pow(other_value), - ArithmeticOperation::Modulo => value.clone().r#mod(other_value), - } - .ok()?; - - Some(ValueOperand::Value(result)) - } - ValueOperand::Slice(attribute, range) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - Some(ValueOperand::Value(value.clone().slice(range))) - } - ValueOperand::TransformationOperation(operation, attribute) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - let result = match operation { - super::operand::TransformationOperation::Round => value.clone().round(), - super::operand::TransformationOperation::Ceil => value.clone().ceil(), - super::operand::TransformationOperation::Floor => value.clone().floor(), - super::operand::TransformationOperation::Abs => value.clone().abs(), - super::operand::TransformationOperation::Sqrt => value.clone().sqrt(), - super::operand::TransformationOperation::Trim => value.clone().trim(), - super::operand::TransformationOperation::TrimStart => { - value.clone().trim_start() - } - super::operand::TransformationOperation::TrimEnd => value.clone().trim_end(), - super::operand::TransformationOperation::Lowercase => value.clone().lowercase(), - super::operand::TransformationOperation::Uppercase => value.clone().uppercase(), - }; - - Some(ValueOperand::Value(result)) - } - } - } - fn evaluate_has_parallel_edges_with_compare_to_self<'a>( - medrecord: &'a MedRecord, - edge_indices: Box + 'a>, - operation: EdgeOperation, - ) -> Box + 'a> { - Box::new(edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - let edges = medrecord - .edges_connecting(vec![endpoints.0], vec![endpoints.1]) - .filter(|other_index| other_index != index); - - let operation = operation.clone(); - - let EdgeOperation::Attribute(operation) = operation else { - return operation.evaluate(medrecord, edges).count() > 0; - }; - - match operation { - AttributeOperation::Gt(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Gt(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Lt(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Lt(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Gte(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Gte(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Lte(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Lte(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Eq(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Eq(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Neq(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Neq(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::In(attribute, value) => { - Self::evaluate_attribute( - edges, - AttributeOperation::In(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::NotIn(attribute, value) => { - Self::evaluate_attribute( - edges, - AttributeOperation::In(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::StartsWith(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::StartsWith(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::EndsWith(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::EndsWith(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Contains(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Contains(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - } - })) - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/mod.rs b/crates/medmodels-core/src/medrecord/querying/operation/mod.rs deleted file mode 100644 index 174adeda..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/mod.rs +++ /dev/null @@ -1,394 +0,0 @@ -mod edge_operation; -mod node_operation; -mod operand; - -pub use self::{ - edge_operation::EdgeOperation, - node_operation::NodeOperation, - operand::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - NodeAttributeOperand, NodeIndexOperand, NodeOperand, TransformationOperation, ValueOperand, - }, -}; -use crate::{ - errors::MedRecordError, - medrecord::{ - datatypes::{ - Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, PartialNeq, Pow, Round, Slice, - Sqrt, StartsWith, Trim, TrimEnd, TrimStart, Uppercase, - }, - Attributes, MedRecord, MedRecordAttribute, MedRecordValue, - }, -}; - -macro_rules! implement_attribute_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operand: ValueOperand, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - match &value_operand { - ValueOperand::Value(value_operand) => value.$evaluate(value_operand), - ValueOperand::Evaluate(value_attribute) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - value.$evaluate(other) - } - ValueOperand::ArithmeticOperation( - operation, - value_attribute, - value_operand, - ) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - let operation = match operation { - ArithmeticOperation::Addition => other.clone() + value_operand.clone(), - ArithmeticOperation::Subtraction => { - other.clone() - value_operand.clone() - } - ArithmeticOperation::Multiplication => { - other.clone() * value_operand.clone() - } - ArithmeticOperation::Division => other.clone() / value_operand.clone(), - ArithmeticOperation::Power => other.clone().pow(value_operand.clone()), - ArithmeticOperation::Modulo => { - other.clone().r#mod(value_operand.clone()) - } - }; - - match operation { - Ok(operation) => value.$evaluate(&operation), - Err(_) => false, - } - } - ValueOperand::TransformationOperation(operation, value_attribute) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - let operation = match operation { - TransformationOperation::Round => other.clone().round(), - TransformationOperation::Ceil => other.clone().ceil(), - TransformationOperation::Floor => other.clone().floor(), - TransformationOperation::Abs => other.clone().abs(), - TransformationOperation::Sqrt => other.clone().sqrt(), - TransformationOperation::Trim => other.clone().trim(), - TransformationOperation::TrimStart => other.clone().trim_start(), - TransformationOperation::TrimEnd => other.clone().trim_end(), - TransformationOperation::Lowercase => other.clone().lowercase(), - TransformationOperation::Uppercase => other.clone().uppercase(), - }; - - value.$evaluate(&operation) - } - ValueOperand::Slice(value_attribute, range) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - value.$evaluate(&other.clone().slice(range.clone())) - } - } - }) - } - }; -} - -macro_rules! implement_index_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a>( - indices: impl Iterator, - operand: Self::IndexType, - ) -> impl Iterator - where - Self::IndexType: 'a, - { - indices.filter(move |index| (*index).$evaluate(&operand)) - } - }; -} - -pub(super) trait Operation: Sized { - type IndexType: PartialEq + PartialNeq + PartialOrd; - - fn evaluate_and<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation1: Self, - operation2: Self, - ) -> impl Iterator { - let operation1_indices = operation1 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - let operation2_indices = operation2 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices.into_iter().filter(move |index| { - operation1_indices.contains(index) && operation2_indices.contains(index) - }) - } - - fn evaluate_or<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation1: Self, - operation2: Self, - ) -> impl Iterator { - let operation1_indices = operation1 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - let operation2_indices = operation2 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices.into_iter().filter(move |index| { - operation1_indices.contains(index) || operation2_indices.contains(index) - }) - } - - fn evaluate_not<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation: Self, - ) -> impl Iterator { - let operation_indices = operation - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices - .into_iter() - .filter(move |index| !operation_indices.contains(index)) - } - - fn evaluate_attribute_in<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operands: Vec, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - value_operands.contains(value) - }) - } - - fn evaluate_attribute_not_in<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operands: Vec, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - !value_operands.contains(value) - }) - } - - implement_attribute_evaluate!(evaluate_attribute_gt, gt); - implement_attribute_evaluate!(evaluate_attribute_lt, lt); - implement_attribute_evaluate!(evaluate_attribute_gte, ge); - implement_attribute_evaluate!(evaluate_attribute_lte, le); - implement_attribute_evaluate!(evaluate_attribute_eq, eq); - implement_attribute_evaluate!(evaluate_attribute_neq, neq); - implement_attribute_evaluate!(evaluate_attribute_starts_with, starts_with); - implement_attribute_evaluate!(evaluate_attribute_ends_with, ends_with); - implement_attribute_evaluate!(evaluate_attribute_contains, contains); - - fn evaluate_has_attribute<'a, P>( - indices: impl Iterator, - attribute_operand: MedRecordAttribute, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - attributes.contains_key(&attribute_operand) - }) - } - - fn evaluate_attribute<'a, P>( - indices: impl Iterator + 'a, - operation: AttributeOperation, - attributes_for_index_fn: P, - ) -> Box + 'a> - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError> + 'a, - Self: 'a, - { - match operation { - AttributeOperation::Gt(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_gt( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Lt(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_lt( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Gte(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_gte( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Lte(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_lte( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Eq(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_eq( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Neq(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_neq( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::In(attribute_operand, value_operands) => { - Box::new(Self::evaluate_attribute_in( - indices, - attribute_operand, - value_operands, - attributes_for_index_fn, - )) - } - AttributeOperation::NotIn(attribute_operand, value_operands) => { - Box::new(Self::evaluate_attribute_not_in( - indices, - attribute_operand, - value_operands, - attributes_for_index_fn, - )) - } - AttributeOperation::StartsWith(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_starts_with( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::EndsWith(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_ends_with( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Contains(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_contains( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - } - } - - implement_index_evaluate!(evaluate_index_gt, gt); - implement_index_evaluate!(evaluate_index_lt, lt); - implement_index_evaluate!(evaluate_index_gte, ge); - implement_index_evaluate!(evaluate_index_lte, le); - implement_index_evaluate!(evaluate_index_eq, eq); - - fn evaluate_index_in<'a>( - indices: impl Iterator, - operands: Vec, - ) -> impl Iterator - where - Self::IndexType: 'a, - { - indices.filter(move |index| operands.contains(index)) - } - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a>; -} - -#[derive(Debug, Clone)] -pub enum AttributeOperation { - Gt(MedRecordAttribute, ValueOperand), - Lt(MedRecordAttribute, ValueOperand), - Gte(MedRecordAttribute, ValueOperand), - Lte(MedRecordAttribute, ValueOperand), - Eq(MedRecordAttribute, ValueOperand), - Neq(MedRecordAttribute, ValueOperand), - In(MedRecordAttribute, Vec), - NotIn(MedRecordAttribute, Vec), - StartsWith(MedRecordAttribute, ValueOperand), - EndsWith(MedRecordAttribute, ValueOperand), - Contains(MedRecordAttribute, ValueOperand), -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs b/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs deleted file mode 100644 index 677db205..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs +++ /dev/null @@ -1,246 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - edge_operation::EdgeOperation, - operand::{IntoVecNodeIndex, NodeIndexInOperand}, - AttributeOperation, Operation, -}; -use crate::medrecord::{ - datatypes::{Contains, EndsWith, StartsWith}, - MedRecord, MedRecordAttribute, NodeIndex, -}; - -macro_rules! implement_index_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a>( - indices: impl Iterator, - operand: NodeIndex, - ) -> impl Iterator { - indices.filter(move |index| (*index).$evaluate(&operand)) - } - }; -} - -#[derive(Debug, Clone)] -pub enum NodeIndexOperation { - Gt(NodeIndex), - Lt(NodeIndex), - Gte(NodeIndex), - Lte(NodeIndex), - Eq(NodeIndex), - In(Box), - StartsWith(NodeIndex), - EndsWith(NodeIndex), - Contains(NodeIndex), -} - -#[derive(Debug, Clone)] -pub enum NodeOperation { - Attribute(AttributeOperation), - Index(NodeIndexOperation), - - InGroup(MedRecordAttribute), - HasAttribute(MedRecordAttribute), - - HasIncomingEdgeWith(Box), - HasOutgoingEdgeWith(Box), - HasNeighborWith(Box), - HasNeighborUndirectedWith(Box), - - And(Box<(NodeOperation, NodeOperation)>), - Or(Box<(NodeOperation, NodeOperation)>), - Not(Box), -} - -impl Operation for NodeOperation { - type IndexType = NodeIndex; - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a> { - match self { - NodeOperation::Attribute(attribute_operation) => { - Self::evaluate_attribute(indices, attribute_operation, |index| { - medrecord.node_attributes(index) - }) - } - NodeOperation::Index(index_operation) => { - Self::evaluate_index(medrecord, indices, index_operation) - } - - NodeOperation::InGroup(attribute_operand) => Box::new(Self::evaluate_in_group( - medrecord, - indices, - attribute_operand, - )), - NodeOperation::HasAttribute(attribute_operand) => Box::new( - Self::evaluate_has_attribute(indices, attribute_operand, |index| { - medrecord.node_attributes(index) - }), - ), - - NodeOperation::HasOutgoingEdgeWith(operation) => Box::new( - Self::evaluate_has_outgoing_edge_with(medrecord, indices, *operation), - ), - NodeOperation::HasIncomingEdgeWith(operation) => Box::new( - Self::evaluate_has_incoming_edge_with(medrecord, indices, *operation), - ), - NodeOperation::HasNeighborWith(operation) => Box::new( - Self::evaluate_has_neighbor_with(medrecord, indices, *operation), - ), - NodeOperation::HasNeighborUndirectedWith(operation) => Box::new( - Self::evaluate_has_neighbor_undirected_with(medrecord, indices, *operation), - ), - - NodeOperation::And(operations) => Box::new(Self::evaluate_and( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - NodeOperation::Or(operations) => Box::new(Self::evaluate_or( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - NodeOperation::Not(operation) => Box::new(Self::evaluate_not( - medrecord, - indices.collect::>(), - *operation, - )), - } - } -} - -impl NodeOperation { - pub fn and(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::And(Box::new((self, operation))) - } - - pub fn or(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::Or(Box::new((self, operation))) - } - - pub fn xor(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::And(Box::new((self, operation))).not() - } - - pub fn not(self) -> NodeOperation { - NodeOperation::Not(Box::new(self)) - } - - fn evaluate_index<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, - operation: NodeIndexOperation, - ) -> Box + 'a> { - match operation { - NodeIndexOperation::Gt(operand) => { - Box::new(Self::evaluate_index_gt(node_indices, operand)) - } - NodeIndexOperation::Lt(operand) => { - Box::new(Self::evaluate_index_lt(node_indices, operand)) - } - NodeIndexOperation::Gte(operand) => { - Box::new(Self::evaluate_index_gte(node_indices, operand)) - } - NodeIndexOperation::Lte(operand) => { - Box::new(Self::evaluate_index_lte(node_indices, operand)) - } - NodeIndexOperation::Eq(operand) => { - Box::new(Self::evaluate_index_eq(node_indices, operand)) - } - NodeIndexOperation::In(operands) => Box::new(Self::evaluate_index_in( - node_indices, - operands.into_vec_node_index(medrecord), - )), - NodeIndexOperation::StartsWith(operand) => { - Box::new(Self::evaluate_index_starts_with(node_indices, operand)) - } - NodeIndexOperation::EndsWith(operand) => { - Box::new(Self::evaluate_index_ends_with(node_indices, operand)) - } - NodeIndexOperation::Contains(operand) => { - Box::new(Self::evaluate_index_contains(node_indices, operand)) - } - } - } - - implement_index_evaluate!(evaluate_index_starts_with, starts_with); - implement_index_evaluate!(evaluate_index_ends_with, ends_with); - implement_index_evaluate!(evaluate_index_contains, contains); - - fn evaluate_in_group<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - let nodes_in_group = match medrecord.nodes_in_group(&attribute_operand) { - Ok(nodes_in_group) => nodes_in_group.collect::>(), - Err(_) => Vec::new(), - }; - - node_indices.filter(move |index| nodes_in_group.contains(index)) - } - - fn evaluate_has_outgoing_edge_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: EdgeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(edges) = medrecord.outgoing_edges(index) else { - return false; - }; - - let edge_indices = operation.clone().evaluate(medrecord, edges); - - edge_indices.count() > 0 - }) - } - - fn evaluate_has_incoming_edge_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: EdgeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(edges) = medrecord.incoming_edges(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, edges).count() > 0 - }) - } - - fn evaluate_has_neighbor_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(neighbors) = medrecord.neighbors(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, neighbors).count() > 0 - }) - } - - fn evaluate_has_neighbor_undirected_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(neighbors) = medrecord.neighbors_undirected(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, neighbors).count() > 0 - }) - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/operand.rs b/crates/medmodels-core/src/medrecord/querying/operation/operand.rs deleted file mode 100644 index c7b7849e..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/operand.rs +++ /dev/null @@ -1,649 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - edge_operation::EdgeIndexOperation, - node_operation::{NodeIndexOperation, NodeOperation}, - AttributeOperation, EdgeOperation, Operation, -}; -use crate::medrecord::{ - EdgeIndex, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, -}; -use std::{fmt::Debug, ops::Range}; - -#[derive(Debug, Clone)] -pub enum ArithmeticOperation { - Addition, - Subtraction, - Multiplication, - Division, - Power, - Modulo, -} - -#[derive(Debug, Clone)] -pub enum TransformationOperation { - Round, - Ceil, - Floor, - Abs, - Sqrt, - - Trim, - TrimStart, - TrimEnd, - - Lowercase, - Uppercase, -} - -#[derive(Debug, Clone)] -pub enum ValueOperand { - Value(MedRecordValue), - Evaluate(MedRecordAttribute), - ArithmeticOperation(ArithmeticOperation, MedRecordAttribute, MedRecordValue), - TransformationOperation(TransformationOperation, MedRecordAttribute), - Slice(MedRecordAttribute, Range), -} - -pub trait IntoValueOperand { - fn into_value_operand(self) -> ValueOperand; -} - -impl> IntoValueOperand for T { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Value(self.into()) - } -} -impl IntoValueOperand for NodeAttributeOperand { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Evaluate(self.into()) - } -} -impl IntoValueOperand for EdgeAttributeOperand { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Evaluate(self.into()) - } -} -impl IntoValueOperand for ValueOperand { - fn into_value_operand(self) -> ValueOperand { - self - } -} - -#[derive(Debug, Clone)] -pub struct NodeAttributeOperand(MedRecordAttribute); - -impl From for NodeAttributeOperand { - fn from(value: MedRecordAttribute) -> Self { - NodeAttributeOperand(value) - } -} - -impl From for MedRecordAttribute { - fn from(val: NodeAttributeOperand) -> Self { - val.0 - } -} - -impl NodeAttributeOperand { - pub fn greater(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Gt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Lt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn greater_or_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Gte( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less_or_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Lte( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Eq( - self.into(), - operand.into_value_operand(), - )) - } - pub fn not_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Neq( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn r#in(self, operand: Vec>) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::In( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - pub fn not_in(self, operand: Vec>) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::NotIn( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - - pub fn starts_with(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::StartsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn ends_with(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::EndsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn contains(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Contains( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn add(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Addition, self.into(), value.into()) - } - - pub fn sub(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Subtraction, - self.into(), - value.into(), - ) - } - - pub fn mul(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Multiplication, - self.into(), - value.into(), - ) - } - - pub fn div(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Division, self.into(), value.into()) - } - - pub fn pow(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Power, self.into(), value.into()) - } - - pub fn r#mod(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Modulo, self.into(), value.into()) - } - - pub fn round(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Round, self.into()) - } - - pub fn ceil(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Ceil, self.into()) - } - - pub fn floor(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Floor, self.into()) - } - - pub fn abs(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Abs, self.into()) - } - - pub fn sqrt(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Sqrt, self.into()) - } - - pub fn trim(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Trim, self.into()) - } - - pub fn trim_start(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimStart, self.into()) - } - - pub fn trim_end(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimEnd, self.into()) - } - - pub fn lowercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Lowercase, self.into()) - } - - pub fn uppercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Uppercase, self.into()) - } - - pub fn slice(self, range: Range) -> ValueOperand { - ValueOperand::Slice(self.into(), range) - } -} - -#[derive(Debug, Clone)] -pub struct EdgeAttributeOperand(MedRecordAttribute); - -impl From for MedRecordAttribute { - fn from(val: EdgeAttributeOperand) -> Self { - val.0 - } -} - -impl EdgeAttributeOperand { - pub fn greater(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Gt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Lt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn greater_or_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Gte( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less_or_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Lte( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Eq( - self.into(), - operand.into_value_operand(), - )) - } - pub fn not_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Neq( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn r#in(self, operand: Vec>) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::In( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - pub fn not_in(self, operand: Vec>) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::NotIn( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - - pub fn starts_with(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::StartsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn ends_with(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::EndsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn contains(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Contains( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn add(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Addition, self.into(), value.into()) - } - - pub fn sub(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Subtraction, - self.into(), - value.into(), - ) - } - - pub fn mul(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Multiplication, - self.into(), - value.into(), - ) - } - - pub fn div(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Division, self.into(), value.into()) - } - - pub fn pow(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Power, self.into(), value.into()) - } - - pub fn r#mod(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Modulo, self.into(), value.into()) - } - - pub fn round(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Round, self.into()) - } - - pub fn ceil(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Ceil, self.into()) - } - - pub fn floor(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Floor, self.into()) - } - - pub fn abs(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Abs, self.into()) - } - - pub fn sqrt(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Sqrt, self.into()) - } - - pub fn trim(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Trim, self.into()) - } - - pub fn trim_start(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimStart, self.into()) - } - - pub fn trim_end(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimEnd, self.into()) - } - - pub fn lowercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Lowercase, self.into()) - } - - pub fn uppercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Uppercase, self.into()) - } - - pub fn slice(self, range: Range) -> ValueOperand { - ValueOperand::Slice(self.into(), range) - } -} - -#[derive(Debug, Clone)] -pub enum NodeIndexInOperand { - Vector(Vec), - Operation(NodeOperation), -} - -impl From> for NodeIndexInOperand -where - T: Into, -{ - fn from(value: Vec) -> NodeIndexInOperand { - NodeIndexInOperand::Vector(value.into_iter().map(|value| value.into()).collect()) - } -} - -impl From for NodeIndexInOperand { - fn from(value: NodeOperation) -> Self { - NodeIndexInOperand::Operation(value) - } -} - -pub(super) trait IntoVecNodeIndex { - fn into_vec_node_index(self, medrecord: &MedRecord) -> Vec; -} - -impl IntoVecNodeIndex for NodeIndexInOperand { - fn into_vec_node_index(self, medrecord: &MedRecord) -> Vec { - match self { - NodeIndexInOperand::Vector(value) => value, - NodeIndexInOperand::Operation(operation) => operation - .evaluate(medrecord, medrecord.node_indices()) - .cloned() - .collect(), - } - } -} - -#[derive(Debug, Clone)] -pub struct NodeIndexOperand; - -impl NodeIndexOperand { - pub fn greater(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Gt(operand.into())) - } - pub fn less(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Lt(operand.into())) - } - pub fn greater_or_equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Gte(operand.into())) - } - pub fn less_or_equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Lte(operand.into())) - } - - pub fn equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Eq(operand.into())) - } - pub fn not_equal(self, operand: impl Into) -> NodeOperation { - self.equal(operand).not() - } - - pub fn r#in(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::In(Box::new(operand.into()))) - } - pub fn not_in(self, operand: impl Into) -> NodeOperation { - self.r#in(operand).not() - } - - pub fn starts_with(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::StartsWith(operand.into())) - } - - pub fn ends_with(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::EndsWith(operand.into())) - } - - pub fn contains(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Contains(operand.into())) - } -} - -#[derive(Debug, Clone)] -pub struct NodeOperand; - -impl NodeOperand { - pub fn in_group(self, operand: impl Into) -> NodeOperation { - NodeOperation::InGroup(operand.into()) - } - - pub fn has_attribute(self, operand: impl Into) -> NodeOperation { - NodeOperation::HasAttribute(operand.into()) - } - - pub fn has_outgoing_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasOutgoingEdgeWith(operation.into()) - } - pub fn has_incoming_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasIncomingEdgeWith(operation.into()) - } - pub fn has_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasOutgoingEdgeWith(operation.clone().into()) - .or(NodeOperation::HasIncomingEdgeWith(operation.into())) - } - - pub fn has_neighbor_with(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::HasNeighborWith(Box::new(operation)) - } - pub fn has_neighbor_undirected_with(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::HasNeighborUndirectedWith(Box::new(operation)) - } - - pub fn attribute(self, attribute: impl Into) -> NodeAttributeOperand { - NodeAttributeOperand(attribute.into()) - } - - pub fn index(self) -> NodeIndexOperand { - NodeIndexOperand - } -} - -pub fn node() -> NodeOperand { - NodeOperand -} - -#[derive(Debug, Clone)] -pub enum EdgeIndexInOperand { - Vector(Vec), - Operation(EdgeOperation), -} - -impl> From> for EdgeIndexInOperand { - fn from(value: Vec) -> EdgeIndexInOperand { - EdgeIndexInOperand::Vector(value.into_iter().map(|value| value.into()).collect()) - } -} - -impl From for EdgeIndexInOperand { - fn from(value: EdgeOperation) -> Self { - EdgeIndexInOperand::Operation(value) - } -} - -pub(super) trait IntoVecEdgeIndex { - fn into_vec_edge_index(self, medrecord: &MedRecord) -> Vec; -} - -impl IntoVecEdgeIndex for EdgeIndexInOperand { - fn into_vec_edge_index(self, medrecord: &MedRecord) -> Vec { - match self { - EdgeIndexInOperand::Vector(value) => value, - EdgeIndexInOperand::Operation(operation) => operation - .evaluate(medrecord, medrecord.edge_indices()) - .copied() - .collect(), - } - } -} - -#[derive(Debug, Clone)] -pub struct EdgeIndexOperand; - -impl EdgeIndexOperand { - pub fn greater(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Gt(operand)) - } - pub fn less(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Lt(operand)) - } - pub fn greater_or_equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Gte(operand)) - } - pub fn less_or_equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Lte(operand)) - } - - pub fn equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Eq(operand)) - } - pub fn not_equal(self, operand: EdgeIndex) -> EdgeOperation { - self.equal(operand).not() - } - - pub fn r#in(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::In(Box::new(operand.into()))) - } - pub fn not_in(self, operand: impl Into) -> EdgeOperation { - self.r#in(operand).not() - } -} - -#[derive(Debug, Clone)] -pub struct EdgeOperand; - -impl EdgeOperand { - pub fn connected_target(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::ConnectedSource(operand.into()) - } - - pub fn connected_source(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::ConnectedTarget(operand.into()) - } - - pub fn connected(self, operand: impl Into) -> EdgeOperation { - let attribute = operand.into(); - - EdgeOperation::ConnectedSource(attribute.clone()) - .or(EdgeOperation::ConnectedTarget(attribute)) - } - - pub fn in_group(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::InGroup(operand.into()) - } - - pub fn has_attribute(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::HasAttribute(operand.into()) - } - - pub fn connected_source_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedSourceWith(operation.into()) - } - - pub fn connected_target_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedTargetWith(operation.into()) - } - - pub fn connected_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedSourceWith(operation.clone().into()) - .or(EdgeOperation::ConnectedTargetWith(operation.into())) - } - - pub fn has_parallel_edges_with(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::HasParallelEdgesWith(Box::new(operation)) - } - - pub fn has_parallel_edges_with_self_comparison( - self, - operation: EdgeOperation, - ) -> EdgeOperation { - EdgeOperation::HasParallelEdgesWithSelfComparison(Box::new(operation)) - } - - pub fn attribute(self, attribute: impl Into) -> EdgeAttributeOperand { - EdgeAttributeOperand(attribute.into()) - } - - pub fn index(self) -> EdgeIndexOperand { - EdgeIndexOperand - } -} - -pub fn edge() -> EdgeOperand { - EdgeOperand -} diff --git a/crates/medmodels-core/src/medrecord/querying/selection.rs b/crates/medmodels-core/src/medrecord/querying/selection.rs deleted file mode 100644 index 82e8356e..00000000 --- a/crates/medmodels-core/src/medrecord/querying/selection.rs +++ /dev/null @@ -1,1741 +0,0 @@ -use super::operation::{EdgeOperation, NodeOperation, Operation}; -use crate::medrecord::{EdgeIndex, MedRecord, NodeIndex}; - -#[derive(Debug)] -pub struct NodeSelection<'a> { - medrecord: &'a MedRecord, - operation: NodeOperation, -} - -impl<'a> NodeSelection<'a> { - pub fn new(medrecord: &'a MedRecord, operation: NodeOperation) -> Self { - Self { - medrecord, - operation, - } - } - - pub fn iter(self) -> impl Iterator { - self.operation - .evaluate(self.medrecord, self.medrecord.node_indices()) - } - - pub fn collect>(self) -> B { - FromIterator::from_iter(self.iter()) - } -} - -#[derive(Debug)] -pub struct EdgeSelection<'a> { - medrecord: &'a MedRecord, - operation: EdgeOperation, -} - -impl<'a> EdgeSelection<'a> { - pub fn new(medrecord: &'a MedRecord, operation: EdgeOperation) -> Self { - Self { - medrecord, - operation, - } - } - - pub fn iter(self) -> impl Iterator { - self.operation - .evaluate(self.medrecord, self.medrecord.edge_indices()) - } - - pub fn collect>(self) -> B { - FromIterator::from_iter(self.iter()) - } -} - -#[cfg(test)] -mod test { - use crate::medrecord::{edge, node, Attributes, MedRecord, MedRecordAttribute, NodeIndex}; - use std::collections::HashMap; - - fn create_nodes() -> Vec<(NodeIndex, Attributes)> { - vec![ - ( - "0".into(), - HashMap::from([ - ("lorem".into(), "ipsum".into()), - ("dolor".into(), " ipsum ".into()), - ("test".into(), "Ipsum".into()), - ("integer".into(), 1.into()), - ("float".into(), 0.5.into()), - ]), - ), - ( - "1".into(), - HashMap::from([("amet".into(), "consectetur".into())]), - ), - ( - "2".into(), - HashMap::from([("adipiscing".into(), "elit".into())]), - ), - ("3".into(), HashMap::new()), - ] - } - - fn create_edges() -> Vec<(NodeIndex, NodeIndex, Attributes)> { - vec![ - ( - "0".into(), - "1".into(), - HashMap::from([ - ("sed".into(), "do".into()), - ("eiusmod".into(), "tempor".into()), - ("dolor".into(), " do ".into()), - ("test".into(), "DO".into()), - ]), - ), - ( - "1".into(), - "2".into(), - HashMap::from([("incididunt".into(), "ut".into())]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([ - ("test".into(), 1.into()), - ("integer".into(), 1.into()), - ("float".into(), 0.5.into()), - ]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([("test".into(), 0.into())]), - ), - ] - } - - fn create_medrecord() -> MedRecord { - let nodes = create_nodes(); - let edges = create_edges(); - - MedRecord::from_tuples(nodes, Some(edges), None).unwrap() - } - - #[test] - fn test_iter() { - let medrecord = create_medrecord(); - - assert_eq!( - 1, - medrecord - .select_nodes(node().has_attribute("lorem")) - .iter() - .count(), - ); - - assert_eq!( - 1, - medrecord - .select_edges(edge().has_attribute("sed")) - .iter() - .count(), - ); - } - - #[test] - fn test_collect() { - let medrecord = create_medrecord(); - - assert_eq!( - vec![&MedRecordAttribute::from("0")], - medrecord - .select_nodes(node().has_attribute("lorem")) - .collect::>(), - ); - - assert_eq!( - vec![&0], - medrecord - .select_edges(edge().has_attribute("sed")) - .collect::>(), - ); - } - - #[test] - fn test_select_nodes_node() { - let mut medrecord = create_medrecord(); - - medrecord - .add_group("test".into(), Some(vec!["0".into()]), None) - .unwrap(); - - // Node in group - assert_eq!( - 1, - medrecord - .select_nodes(node().in_group("test")) - .iter() - .count(), - ); - - // Node has attribute - assert_eq!( - 1, - medrecord - .select_nodes(node().has_attribute("lorem")) - .iter() - .count(), - ); - - // Node has outgoing edge with - assert_eq!( - 1, - medrecord - .select_nodes(node().has_outgoing_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has incoming edge with - assert_eq!( - 1, - medrecord - .select_nodes(node().has_incoming_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has edge with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has neighbor with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_neighbor_with(node().index().equal("2"))) - .iter() - .count(), - ); - assert_eq!( - 1, - medrecord - .select_nodes(node().has_neighbor_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Node has undirected neighbor with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_neighbor_undirected_with(node().index().equal("1"))) - .iter() - .count(), - ); - } - - #[test] - fn test_select_nodes_node_index() { - let medrecord = create_medrecord(); - - // Index greater - assert_eq!( - 2, - medrecord - .select_nodes(node().index().greater("1")) - .iter() - .count(), - ); - - // Index less - assert_eq!( - 1, - medrecord - .select_nodes(node().index().less("1")) - .iter() - .count(), - ); - - // Index greater or equal - assert_eq!( - 3, - medrecord - .select_nodes(node().index().greater_or_equal("1")) - .iter() - .count(), - ); - - // Index less or equal - assert_eq!( - 2, - medrecord - .select_nodes(node().index().less_or_equal("1")) - .iter() - .count(), - ); - - // Index equal - assert_eq!( - 1, - medrecord - .select_nodes(node().index().equal("1")) - .iter() - .count(), - ); - - // Index not equal - assert_eq!( - 3, - medrecord - .select_nodes(node().index().not_equal("1")) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_nodes(node().index().r#in(vec!["1"])) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_nodes(node().index().r#in(node().has_attribute("lorem"))) - .iter() - .count(), - ); - - // Index not in - assert_eq!( - 3, - medrecord - .select_nodes(node().index().not_in(node().has_attribute("lorem"))) - .iter() - .count(), - ); - - // Index starts with - assert_eq!( - 1, - medrecord - .select_nodes(node().index().starts_with("1")) - .iter() - .count(), - ); - - // Index ends with - assert_eq!( - 1, - medrecord - .select_nodes(node().index().ends_with("1")) - .iter() - .count(), - ); - - // Index contains - assert_eq!( - 1, - medrecord - .select_nodes(node().index().contains("1")) - .iter() - .count(), - ); - } - - #[test] - fn test_select_nodes_node_attribute() { - let medrecord = create_medrecord(); - - // Attribute greater - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").greater("ipsum")) - .iter() - .count(), - ); - - // Attribute less - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").less("ipsum")) - .iter() - .count(), - ); - - // Attribute greater or equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").greater_or_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute less or equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").less_or_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").equal("ipsum")) - .iter() - .count(), - ); - - // Attribute not equal - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").not_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute in - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").r#in(vec!["ipsum"])) - .iter() - .count(), - ); - - // Attribute not in - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").not_in(vec!["ipsum"])) - .iter() - .count(), - ); - - // Attribute starts with - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").starts_with("ip")) - .iter() - .count(), - ); - - // Attribute ends with - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").ends_with("um")) - .iter() - .count(), - ); - - // Attribute contains - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").contains("su")) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").equal(node().attribute("lorem"))) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Returns nothing because can't sub a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Doesn't work because can't sub a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Returns nothing because can't div a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Doesn't work because can't div a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Returns nothing because can't pow a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Doesn't work because can't pow a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Returns nothing because can't mod a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Doesn't work because can't mod a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute abs - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").abs()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sqrt - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").sqrt()) // sqrt(1) = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_start - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim_start()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_end - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim_end()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute lowercase - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("test").lowercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute uppercase - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("test").uppercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute slice - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").slice(2..7)) - ) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge() { - let mut medrecord = create_medrecord(); - - medrecord - .add_group("test".into(), None, Some(vec![0])) - .unwrap(); - - // Edge connected to target - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_target("2")) - .iter() - .count(), - ); - - // Edge connected to source - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_source("0")) - .iter() - .count(), - ); - - // Edge connected - assert_eq!( - 2, - medrecord.select_edges(edge().connected("1")).iter().count(), - ); - - // Edge in group - assert_eq!( - 1, - medrecord - .select_edges(edge().in_group("test")) - .iter() - .count(), - ); - - // Edge has attribute - assert_eq!( - 1, - medrecord - .select_edges(edge().has_attribute("sed")) - .iter() - .count(), - ); - - // Edge connected to target with - assert_eq!( - 1, - medrecord - .select_edges(edge().connected_target_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Edge connected to source with - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_source_with(node().index().equal("0"))) - .iter() - .count(), - ); - - // Edge connected with - assert_eq!( - 2, - medrecord - .select_edges(edge().connected_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Edge has parallel edges with - assert_eq!( - 2, - medrecord - .select_edges(edge().has_parallel_edges_with(edge().has_attribute("test"))) - .iter() - .count(), - ); - - // Edge has parallel edges with self comparison - assert_eq!( - 1, - medrecord - .select_edges( - edge().has_parallel_edges_with_self_comparison( - edge() - .attribute("test") - .equal(edge().attribute("test").sub(1)) - ) - ) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge_index() { - let medrecord = create_medrecord(); - - // Index greater - assert_eq!( - 2, - medrecord - .select_edges(edge().index().greater(1)) - .iter() - .count(), - ); - - // Index less - assert_eq!( - 1, - medrecord - .select_edges(edge().index().less(1)) - .iter() - .count(), - ); - - // Index greater or equal - assert_eq!( - 3, - medrecord - .select_edges(edge().index().greater_or_equal(1)) - .iter() - .count(), - ); - - // Index less or equal - assert_eq!( - 2, - medrecord - .select_edges(edge().index().less_or_equal(1)) - .iter() - .count(), - ); - - // Index equal - assert_eq!( - 1, - medrecord - .select_edges(edge().index().equal(1)) - .iter() - .count(), - ); - - // Index not equal - assert_eq!( - 3, - medrecord - .select_edges(edge().index().not_equal(1)) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_edges(edge().index().r#in(vec![1_usize])) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_edges(edge().index().r#in(edge().has_attribute("sed"))) - .iter() - .count(), - ); - - // Index not in - assert_eq!( - 3, - medrecord - .select_edges(edge().index().not_in(edge().has_attribute("sed"))) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge_attribute() { - let medrecord = create_medrecord(); - - // Attribute greater - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").greater("do")) - .iter() - .count(), - ); - - // Attribute less - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").less("do")) - .iter() - .count(), - ); - - // Attribute greater or equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").greater_or_equal("do")) - .iter() - .count(), - ); - - // Attribute less or equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").less_or_equal("do")) - .iter() - .count(), - ); - - // Attribute equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").equal("do")) - .iter() - .count(), - ); - - // Attribute not equal - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_equal("do")) - .iter() - .count(), - ); - - // Attribute in - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").r#in(vec!["do"])) - .iter() - .count(), - ); - - // Attribute not in - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_in(vec!["do"])) - .iter() - .count(), - ); - - // Attribute starts with - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").starts_with("d")) - .iter() - .count(), - ); - - // Attribute ends with - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").ends_with("o")) - .iter() - .count(), - ); - - // Attribute contains - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").contains("do")) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").equal(edge().attribute("sed"))) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_equal(edge().attribute("sed"))) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Returns nothing because can't sub a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Doesn't work because can't sub a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Returns nothing because can't div a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Doesn't work because can't div a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Returns nothing because can't pow a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .equal(edge().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Doesn't work because can't pow a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .not_equal(edge().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Returns nothing because can't mod a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .equal(edge().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Doesn't work because can't mod a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .not_equal(edge().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute abs - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").abs()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sqrt - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").sqrt()) // sqrt(1) = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_start - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim_start()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_end - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim_end()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute lowercase - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("test").lowercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute uppercase - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("test").uppercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute slice - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").slice(2..4)) - ) - .iter() - .count(), - ); - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/traits.rs b/crates/medmodels-core/src/medrecord/querying/traits.rs new file mode 100644 index 00000000..4e8d33e8 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/traits.rs @@ -0,0 +1,21 @@ +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +pub trait DeepClone { + fn deep_clone(&self) -> Self; +} + +pub(crate) trait ReadWriteOrPanic { + fn read_or_panic(&self) -> RwLockReadGuard<'_, T>; + + fn write_or_panic(&self) -> RwLockWriteGuard<'_, T>; +} + +impl ReadWriteOrPanic for RwLock { + fn read_or_panic(&self) -> RwLockReadGuard<'_, T> { + self.read().unwrap() + } + + fn write_or_panic(&self) -> RwLockWriteGuard<'_, T> { + self.write().unwrap() + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/mod.rs b/crates/medmodels-core/src/medrecord/querying/values/mod.rs new file mode 100644 index 00000000..bf2e2f4a --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/mod.rs @@ -0,0 +1,185 @@ +mod operand; +mod operation; + +use super::{ + attributes::{ + self, AttributesTreeOperation, MultipleAttributesOperand, MultipleAttributesOperation, + }, + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{MedRecordAttribute, MedRecordValue}, + MedRecord, +}; +pub use operand::MultipleValuesOperand; +use std::fmt::Display; + +macro_rules! get_attributes { + ($operand:ident, $medrecord:ident, $operation:ident, $multiple_attributes_operand:ident) => {{ + let indices = $operand.evaluate($medrecord)?; + + let attributes = $operation::get_attributes($medrecord, indices); + + let attributes = $multiple_attributes_operand + .context + .evaluate($medrecord, attributes)?; + + let attributes: Box> = + match $multiple_attributes_operand.kind { + attributes::MultipleKind::Max => { + Box::new(AttributesTreeOperation::get_max(attributes)?) + } + attributes::MultipleKind::Min => { + Box::new(AttributesTreeOperation::get_min(attributes)?) + } + attributes::MultipleKind::Count => { + Box::new(AttributesTreeOperation::get_count(attributes)?) + } + attributes::MultipleKind::Sum => { + Box::new(AttributesTreeOperation::get_sum(attributes)?) + } + attributes::MultipleKind::First => { + Box::new(AttributesTreeOperation::get_first(attributes)?) + } + attributes::MultipleKind::Last => { + Box::new(AttributesTreeOperation::get_last(attributes)?) + } + }; + + let attributes = $multiple_attributes_operand.evaluate($medrecord, attributes)?; + + Box::new( + MultipleAttributesOperation::get_values($medrecord, attributes)? + .map(|(_, value)| value), + ) + }}; +} + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Mean, + Median, + Mode, + Std, + Var, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Div, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Div => write!(f, "div"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Round, + Ceil, + Floor, + Abs, + Sqrt, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), + MultipleAttributesOperand(MultipleAttributesOperand), +} + +impl Context { + pub(crate) fn get_values<'a>( + &self, + medrecord: &'a MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_values(medrecord, node_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_values(medrecord, edge_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::MultipleAttributesOperand(multiple_attributes_operand) => { + match &multiple_attributes_operand.context.context { + attributes::Context::NodeOperand(node_operand) => { + get_attributes!( + node_operand, + medrecord, + NodeOperation, + multiple_attributes_operand + ) + } + attributes::Context::EdgeOperand(edge_operand) => { + get_attributes!( + edge_operand, + medrecord, + EdgeOperation, + multiple_attributes_operand + ) + } + } + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operand.rs b/crates/medmodels-core/src/medrecord/querying/values/operand.rs new file mode 100644 index 00000000..01796ed7 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/operand.rs @@ -0,0 +1,590 @@ +use super::{ + operation::{MultipleValuesOperation, SingleValueOperation}, + BinaryArithmeticKind, Context, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, + MedRecordAttribute, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use std::hash::Hash; + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleValuesOperation::ValueOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::SingleValueComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleValueComparisonOperand { + Operand(SingleValueOperand), + Value(MedRecordValue), +} + +impl DeepClone for SingleValueComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Value(value) => Self::Value(value.clone()), + } + } +} + +impl From> for SingleValueComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleValueComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleValueComparisonOperand { + fn from(value: V) -> Self { + Self::Value(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesComparisonOperand { + Operand(MultipleValuesOperand), + Values(Vec), +} + +impl DeepClone for MultipleValuesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Values(value) => Self::Values(value.clone()), + } + } +} + +impl From> for MultipleValuesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleValuesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleValuesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Values(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleValuesComparisonOperand +{ + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct MultipleValuesOperand { + pub(crate) context: Context, + pub(crate) attribute: MedRecordAttribute, + operations: Vec, +} + +impl DeepClone for MultipleValuesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + attribute: self.attribute.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleValuesOperand { + pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { + Self { + context, + attribute, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + let values = Box::new(values) as BoxedIterator<(&'a T, MedRecordValue)>; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(mean, Mean); + implement_value_operation!(median, Median); + implement_value_operation!(mode, Mode); + implement_value_operation!(std, Std); + implement_value_operation!(var, Var); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!( + greater_than, + MultipleValuesOperation, + GreaterThan + ); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + MultipleValuesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, MultipleValuesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + MultipleValuesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, MultipleValuesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, MultipleValuesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, MultipleValuesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, MultipleValuesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, MultipleValuesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, MultipleValuesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleValuesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleValuesOperation, Mul); + implement_binary_arithmetic_operation!(div, MultipleValuesOperation, Div); + implement_binary_arithmetic_operation!(pow, MultipleValuesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleValuesOperation, Mod); + + implement_unary_arithmetic_operation!(round, MultipleValuesOperation, Round); + implement_unary_arithmetic_operation!(ceil, MultipleValuesOperation, Ceil); + implement_unary_arithmetic_operation!(floor, MultipleValuesOperation, Floor); + implement_unary_arithmetic_operation!(abs, MultipleValuesOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, MultipleValuesOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, MultipleValuesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleValuesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleValuesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleValuesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleValuesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleValuesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleValuesOperation::IsString); + implement_assertion_operation!(is_int, MultipleValuesOperation::IsInt); + implement_assertion_operation!(is_float, MultipleValuesOperation::IsFloat); + implement_assertion_operation!(is_bool, MultipleValuesOperation::IsBool); + implement_assertion_operation!(is_datetime, MultipleValuesOperation::IsDateTime); + implement_assertion_operation!(is_null, MultipleValuesOperation::IsNull); + implement_assertion_operation!(is_max, MultipleValuesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleValuesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleValuesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { + MultipleValuesOperand::new(context, attribute).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, SingleValueOperand); + implement_wrapper_operand_with_return!(min, SingleValueOperand); + implement_wrapper_operand_with_return!(mean, SingleValueOperand); + implement_wrapper_operand_with_return!(median, SingleValueOperand); + implement_wrapper_operand_with_return!(mode, SingleValueOperand); + implement_wrapper_operand_with_return!(std, SingleValueOperand); + implement_wrapper_operand_with_return!(var, SingleValueOperand); + implement_wrapper_operand_with_return!(count, SingleValueOperand); + implement_wrapper_operand_with_return!(sum, SingleValueOperand); + implement_wrapper_operand_with_return!(first, SingleValueOperand); + implement_wrapper_operand_with_return!(last, SingleValueOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct SingleValueOperand { + pub(crate) context: MultipleValuesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleValueOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleValueOperand { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, SingleValueOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + SingleValueOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, SingleValueOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + SingleValueOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, SingleValueOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, SingleValueOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, SingleValueOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, SingleValueOperation, EndsWith); + implement_single_value_comparison_operation!(contains, SingleValueOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, SingleValueOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleValueOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleValueOperation, Mul); + implement_binary_arithmetic_operation!(div, SingleValueOperation, Div); + implement_binary_arithmetic_operation!(pow, SingleValueOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleValueOperation, Mod); + + implement_unary_arithmetic_operation!(round, SingleValueOperation, Round); + implement_unary_arithmetic_operation!(ceil, SingleValueOperation, Ceil); + implement_unary_arithmetic_operation!(floor, SingleValueOperation, Floor); + implement_unary_arithmetic_operation!(abs, SingleValueOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, SingleValueOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, SingleValueOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleValueOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleValueOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleValueOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleValueOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleValueOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleValueOperation::IsString); + implement_assertion_operation!(is_int, SingleValueOperation::IsInt); + implement_assertion_operation!(is_float, SingleValueOperation::IsFloat); + implement_assertion_operation!(is_bool, SingleValueOperation::IsBool); + implement_assertion_operation!(is_datetime, SingleValueOperation::IsDateTime); + implement_assertion_operation!(is_null, SingleValueOperation::IsNull); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleValueOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + SingleValueOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operation.rs b/crates/medmodels-core/src/medrecord/querying/values/operation.rs new file mode 100644 index 00000000..2a559d99 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/operation.rs @@ -0,0 +1,934 @@ +use super::{ + operand::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, + StartsWith, Trim, TrimEnd, TrimStart, Uppercase, + }, + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, + DataType, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + hash::Hash, + ops::{Add, Div, Mul, Range, Sub}, +}; + +macro_rules! get_single_operand_value { + ($kind:ident, $values:expr) => { + match $kind { + SingleKind::Max => MultipleValuesOperation::get_max($values)?.1, + SingleKind::Min => MultipleValuesOperation::get_min($values)?.1, + SingleKind::Mean => MultipleValuesOperation::get_mean($values)?, + SingleKind::Median => MultipleValuesOperation::get_median($values)?, + SingleKind::Mode => MultipleValuesOperation::get_mode($values)?, + SingleKind::Std => MultipleValuesOperation::get_std($values)?, + SingleKind::Var => MultipleValuesOperation::get_var($values)?, + SingleKind::Count => MultipleValuesOperation::get_count($values), + SingleKind::Sum => MultipleValuesOperation::get_sum($values)?, + SingleKind::First => MultipleValuesOperation::get_first($values)?, + SingleKind::Last => MultipleValuesOperation::get_last($values)?, + } + }; +} + +macro_rules! get_single_value_comparison_operand_value { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleValueComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let attribute = operand.context.attribute.clone(); + let kind = &operand.kind; + + let comparison_values = context + .get_values($medrecord, attribute)? + .map(|value| (&0, value)); + + let comparison_value = get_single_operand_value!(kind, comparison_values); + + comparison_value + } + SingleValueComparisonOperand::Value(value) => value.clone(), + } + }; +} + +macro_rules! get_median { + ($values:ident, $variant:ident) => { + if $values.len() % 2 == 0 { + let middle = $values.len() / 2; + + let first = $values.get(middle - 1).unwrap(); + let second = $values.get(middle).unwrap(); + + let first = MedRecordValue::$variant(*first); + let second = MedRecordValue::$variant(*second); + + first.add(second).unwrap().div(MedRecordValue::Int(2)) + } else { + let middle = $values.len() / 2; + + Ok(MedRecordValue::$variant( + $values.get(middle).unwrap().clone(), + )) + } + }; +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesOperation { + ValueOperation { + operand: Wrapper, + }, + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for MultipleValuesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::ValueOperation { operand } => Self::ValueOperation { + operand: operand.deep_clone(), + }, + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl MultipleValuesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::ValueOperation { operand } => { + Self::evaluate_value_operation(medrecord, values, operand) + } + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, values, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation( + medrecord, values, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, values, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(values, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(values, range.clone()))), + Self::IsString => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Int(_)) + }))) + } + Self::IsFloat => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Float(_)) + }))) + } + Self::IsBool => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Bool(_)) + }))) + } + Self::IsDateTime => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::DateTime(_)) + }))) + } + Self::IsNull => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Null) + }))) + } + Self::IsMax => { + let max_value = Self::get_max(values)?; + + Ok(Box::new(std::iter::once(max_value))) + } + Self::IsMin => { + let min_value = Self::get_min(values)?; + + Ok(Box::new(std::iter::once(min_value))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, values, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { + let max_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(max_value, |max_value, value| { + match value.1.partial_cmp(&max_value.1) { + Some(Ordering::Greater) => Ok(value), + None => { + let first_dtype = DataType::from(value.1); + let second_dtype = DataType::from(max_value.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_value), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { + let min_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(min_value, |min_value, value| { + match value.1.partial_cmp(&min_value.1) { + Some(Ordering::Less) => Ok(value), + None => { + let first_dtype = DataType::from(value.1); + let second_dtype = DataType::from(min_value.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_value), + } + }) + } + + #[inline] + pub(crate) fn get_mean<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let (sum, count) = values.try_fold((first_value.1, 1), |(sum, count), (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + match sum.add(value) { + Ok(sum) => Ok((sum, count + 1)), + Err(_) => Err(MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_dtype, second_dtype + ))), + } + })?; + + sum.div(MedRecordValue::Int(count as i64)) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_median<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let first_data_type = DataType::from(&first_value.1); + + match first_value.1 { + MedRecordValue::Int(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value as f64); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::Float(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::DateTime(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::DateTime(naive_date_time) => Ok(naive_date_time), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort(); + + get_median!(values, DateTime) + } + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of data type {}", + first_data_type + )))?, + } + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_mode<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.map(|(_, value)| value).collect::>(); + + let most_common_value = values + .first() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))? + .clone(); + let most_common_count = values + .iter() + .filter(|value| **value == most_common_value) + .count(); + + let (_, most_common_value) = values.clone().into_iter().fold( + (most_common_count, most_common_value), + |acc, value| { + let count = values.iter().filter(|v| **v == value).count(); + + if count > acc.0 { + (count, value) + } else { + acc + } + }, + ); + + Ok(most_common_value) + } + + #[inline] + // 👀 + pub(crate) fn get_std<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let variance = Self::get_var(values)?; + + let MedRecordValue::Float(variance) = variance else { + unreachable!() + }; + + Ok(MedRecordValue::Float(variance.sqrt())) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_var<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.collect::>(); + + let mean = Self::get_mean(values.clone().into_iter())?; + + let MedRecordValue::Float(mean) = mean else { + let data_type = DataType::from(mean); + + return Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )); + }; + + let values = values + .into_iter() + .map(|value| { + let data_type = DataType::from(&value.1); + + match value.1 { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )), + }}) + .collect::>>()?; + + let values_length = values.len(); + + let variance = values + .into_iter() + .map(|value| (value - mean).powi(2)) + .sum::() + / values_length as f64; + + Ok(MedRecordValue::Float(variance)) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordValue { + MedRecordValue::Int(values.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(first_value.1, |sum, (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + sum.add(value).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + values + .next() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_last<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + values + .last() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + fn evaluate_value_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let values = values.collect::>(); + + let value = get_single_operand_value!(kind, values.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, value)? { + Some(_) => Box::new(values.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_value_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + values.filter(move |(_, value)| value > &comparison_value), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value >= &comparison_value), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + values.filter(move |(_, value)| value < &comparison_value), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value <= &comparison_value), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + values.filter(move |(_, value)| value == &comparison_value), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value != &comparison_value), + )), + SingleComparisonKind::StartsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.starts_with(&comparison_value) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.ends_with(&comparison_value) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(values.filter(move |(_, value)| { + value.contains(&comparison_value) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_values_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); + + context + .get_values(medrecord, attribute)? + .collect::>() + } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(values.filter(move |(_, value)| { + comparison_values.contains(value) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(values.filter(move |(_, value)| { + !comparison_values.contains(value) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + let values = values + .map(move |(t, value)| { + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value.clone()), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value.clone()), + BinaryArithmeticKind::Mul => { + value.clone().mul(arithmetic_value.clone()) + } + BinaryArithmeticKind::Div => { + value.clone().div(arithmetic_value.clone()) + } + BinaryArithmeticKind::Pow => { + value.clone().pow(arithmetic_value.clone()) + } + BinaryArithmeticKind::Mod => { + value.clone().r#mod(arithmetic_value.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the values using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(values.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + values: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + values.map(move |(t, value)| { + let value = match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + }; + (t, value) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + values: impl Iterator, + range: Range, + ) -> impl Iterator { + values.map(move |(t, value)| (t, value.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash>( + medrecord: &'a MedRecord, + values: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let values = values.collect::>(); + + let either_values = either.evaluate(medrecord, values.clone().into_iter())?; + let or_values = or.evaluate(medrecord, values.into_iter())?; + + Ok(Box::new( + either_values.chain(or_values).unique_by(|value| value.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleValueOperation { + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for SingleValueOperation { + fn deep_clone(&self) -> Self { + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl SingleValueOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, value, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation(medrecord, value, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, value, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + })), + Self::Slice(range) => Ok(Some(value.slice(range.clone()))), + Self::IsString => Ok(match value { + MedRecordValue::String(_) => Some(value), + _ => None, + }), + Self::IsInt => Ok(match value { + MedRecordValue::Int(_) => Some(value), + _ => None, + }), + Self::IsFloat => Ok(match value { + MedRecordValue::Float(_) => Some(value), + _ => None, + }), + Self::IsBool => Ok(match value { + MedRecordValue::Bool(_) => Some(value), + _ => None, + }), + Self::IsDateTime => Ok(match value { + MedRecordValue::DateTime(_) => Some(value), + _ => None, + }), + Self::IsNull => Ok(match value { + MedRecordValue::Null => Some(value), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, value, either, or), + } + } + + #[inline] + fn evaluate_single_value_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => value > comparison_value, + SingleComparisonKind::GreaterThanOrEqualTo => value >= comparison_value, + SingleComparisonKind::LessThan => value < comparison_value, + SingleComparisonKind::LessThanOrEqualTo => value <= comparison_value, + SingleComparisonKind::EqualTo => value == comparison_value, + SingleComparisonKind::NotEqualTo => value != comparison_value, + SingleComparisonKind::StartsWith => value.starts_with(&comparison_value), + SingleComparisonKind::EndsWith => value.ends_with(&comparison_value), + SingleComparisonKind::Contains => value.contains(&comparison_value), + }; + + Ok(if comparison_result { Some(value) } else { None }) + } + + #[inline] + fn evaluate_multiple_values_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); + + context + .get_values(medrecord, attribute)? + .collect::>() + } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_values.contains(&value), + MultipleComparisonKind::IsNotIn => !comparison_values.contains(&value), + }; + + Ok(if comparison_result { Some(value) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + value: MedRecordValue, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value), + BinaryArithmeticKind::Mul => value.mul(arithmetic_value), + BinaryArithmeticKind::Div => value.div(arithmetic_value), + BinaryArithmeticKind::Pow => value.pow(arithmetic_value), + BinaryArithmeticKind::Mod => value.r#mod(arithmetic_value), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + value: MedRecordValue, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, value.clone())?; + let or_result = or.evaluate(medrecord, value)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/wrapper.rs b/crates/medmodels-core/src/medrecord/querying/wrapper.rs new file mode 100644 index 00000000..a5d338bc --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/wrapper.rs @@ -0,0 +1,45 @@ +use super::traits::{DeepClone, ReadWriteOrPanic}; +use std::sync::{Arc, RwLock}; + +#[repr(transparent)] +#[derive(Debug, Clone)] +pub struct Wrapper(pub(crate) Arc>); + +impl From for Wrapper { + fn from(value: T) -> Self { + Self(Arc::new(RwLock::new(value))) + } +} + +impl DeepClone for Wrapper +where + T: DeepClone, +{ + fn deep_clone(&self) -> Self { + self.0.read_or_panic().deep_clone().into() + } +} + +#[derive(Debug, Clone)] +pub enum CardinalityWrapper { + Single(T), + Multiple(Vec), +} + +impl From for CardinalityWrapper { + fn from(value: T) -> Self { + Self::Single(value) + } +} + +impl From> for CardinalityWrapper { + fn from(value: Vec) -> Self { + Self::Multiple(value) + } +} + +impl From<[T; N]> for CardinalityWrapper { + fn from(value: [T; N]) -> Self { + Self::Multiple(value.to_vec()) + } +} diff --git a/crates/medmodels-core/src/medrecord/schema.rs b/crates/medmodels-core/src/medrecord/schema.rs index 2bcdd562..8015870e 100644 --- a/crates/medmodels-core/src/medrecord/schema.rs +++ b/crates/medmodels-core/src/medrecord/schema.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use super::{Attributes, EdgeIndex, NodeIndex}; use crate::{ errors::GraphError, diff --git a/rustmodels/Cargo.toml b/rustmodels/Cargo.toml index a6640f90..70922829 100644 --- a/rustmodels/Cargo.toml +++ b/rustmodels/Cargo.toml @@ -11,7 +11,7 @@ crate-type = ["cdylib"] medmodels-core = { workspace = true } medmodels-utils = { workspace = true } -pyo3 = { workspace = true } -pyo3-polars = { workspace = true } +pyo3 = { version = "0.21.2", features = ["chrono"] } +pyo3-polars = "0.14.0" polars = { workspace = true } chrono = { workspace = true } diff --git a/rustmodels/src/medrecord/mod.rs b/rustmodels/src/medrecord/mod.rs index 73b4670b..059106ee 100644 --- a/rustmodels/src/medrecord/mod.rs +++ b/rustmodels/src/medrecord/mod.rs @@ -614,7 +614,7 @@ impl PyMedRecord { .map(|node_index| { let neighbors = self .0 - .neighbors(&node_index) + .neighbors_outgoing(&node_index) .map_err(PyMedRecordError::from)? .map(|neighbor| neighbor.clone().into()) .collect(); From ed2af4b7e76bd3a0364497f20c5a340538b82447 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Fri, 11 Oct 2024 14:04:42 +0200 Subject: [PATCH 7/8] feat: implement python bindings of query engine (#227) --- crates/medmodels-core/src/medrecord/mod.rs | 19 +- .../src/medrecord/querying/attributes/mod.rs | 5 +- .../src/medrecord/querying/edges/mod.rs | 5 +- .../src/medrecord/querying/edges/operand.rs | 138 ++-- .../src/medrecord/querying/mod.rs | 2 +- .../src/medrecord/querying/nodes/mod.rs | 7 +- .../src/medrecord/querying/nodes/operand.rs | 145 ++-- .../src/medrecord/querying/nodes/operation.rs | 4 +- .../src/medrecord/querying/values/mod.rs | 5 +- medmodels/_medmodels.pyi | 589 ++++++++++---- medmodels/medrecord/indexers.py | 46 +- rustmodels/src/lib.rs | 31 +- rustmodels/src/medrecord/attribute.rs | 2 +- rustmodels/src/medrecord/errors.rs | 1 + rustmodels/src/medrecord/mod.rs | 157 ++-- rustmodels/src/medrecord/querying.rs | 732 ------------------ .../src/medrecord/querying/attributes.rs | 565 ++++++++++++++ rustmodels/src/medrecord/querying/edges.rs | 384 +++++++++ rustmodels/src/medrecord/querying/mod.rs | 52 ++ rustmodels/src/medrecord/querying/nodes.rs | 491 ++++++++++++ rustmodels/src/medrecord/querying/values.rs | 482 ++++++++++++ rustmodels/src/medrecord/value.rs | 2 +- 22 files changed, 2712 insertions(+), 1152 deletions(-) delete mode 100644 rustmodels/src/medrecord/querying.rs create mode 100644 rustmodels/src/medrecord/querying/attributes.rs create mode 100644 rustmodels/src/medrecord/querying/edges.rs create mode 100644 rustmodels/src/medrecord/querying/mod.rs create mode 100644 rustmodels/src/medrecord/querying/nodes.rs create mode 100644 rustmodels/src/medrecord/querying/values.rs diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index f4000061..f9b4b03f 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -11,8 +11,23 @@ pub use self::{ graph::{Attributes, EdgeIndex, NodeIndex}, group_mapping::Group, querying::{ - edges::EdgeOperand, - nodes::NodeOperand, + attributes::{ + AttributesTreeOperand, MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, + }, + edges::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, + EdgeIndicesOperand, EdgeOperand, + }, + nodes::{ + EdgeDirection, NodeIndexComparisonOperand, NodeIndexOperand, + NodeIndicesComparisonOperand, NodeIndicesOperand, NodeOperand, + }, + traits::DeepClone, + values::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, + }, wrapper::{CardinalityWrapper, Wrapper}, }, schema::{AttributeDataType, AttributeType, GroupSchema, Schema}, diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs index d16fcabd..8e60945f 100644 --- a/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs @@ -11,7 +11,10 @@ use crate::{ medrecord::{Attributes, EdgeIndex, MedRecordAttribute, NodeIndex}, MedRecord, }; -pub use operand::{AttributesTreeOperand, MultipleAttributesOperand}; +pub use operand::{ + AttributesTreeOperand, MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, +}; pub use operation::{AttributesTreeOperation, MultipleAttributesOperation}; use std::fmt::Display; diff --git a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs index 1045e83e..f78eabb0 100644 --- a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs @@ -2,7 +2,10 @@ mod operand; mod operation; mod selection; -pub use operand::EdgeOperand; +pub use operand::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, EdgeIndicesOperand, + EdgeOperand, +}; pub use operation::EdgeOperation; pub use selection::EdgeSelection; use std::fmt::Display; diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs index 4b7b4f85..f2c50513 100644 --- a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs +++ b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs @@ -205,7 +205,7 @@ impl Wrapper { } } -macro_rules! implement_value_operation { +macro_rules! implement_index_operation { ($name:ident, $variant:ident) => { pub fn $name(&mut self) -> Wrapper { let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); @@ -220,12 +220,12 @@ macro_rules! implement_value_operation { }; } -macro_rules! implement_single_value_comparison_operation { +macro_rules! implement_single_index_comparison_operation { ($name:ident, $operation:ident, $kind:ident) => { - pub fn $name>(&mut self, value: V) { + pub fn $name>(&mut self, index: V) { self.operations .push($operation::EdgeIndexComparisonOperation { - operand: value.into(), + operand: index.into(), kind: SingleComparisonKind::$kind, }); } @@ -234,9 +234,9 @@ macro_rules! implement_single_value_comparison_operation { macro_rules! implement_binary_arithmetic_operation { ($name:ident, $operation:ident, $kind:ident) => { - pub fn $name>(&mut self, value: V) { + pub fn $name>(&mut self, index: V) { self.operations.push($operation::BinaryArithmeticOpration { - operand: value.into(), + operand: index.into(), kind: BinaryArithmeticKind::$kind, }); } @@ -268,9 +268,9 @@ macro_rules! implement_wrapper_operand_with_return { } macro_rules! implement_wrapper_operand_with_argument { - ($name:ident, $value_type:ty) => { - pub fn $name(&self, value: $value_type) { - self.0.write_or_panic().$name(value) + ($name:ident, $index_type:ty) => { + pub fn $name(&self, index: $index_type) { + self.0.write_or_panic().$name(index) } }; } @@ -285,26 +285,26 @@ impl DeepClone for EdgeIndexComparisonOperand { fn deep_clone(&self) -> Self { match self { Self::Operand(operand) => Self::Operand(operand.deep_clone()), - Self::Index(value) => Self::Index(*value), + Self::Index(index) => Self::Index(*index), } } } impl From> for EdgeIndexComparisonOperand { - fn from(value: Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(index: Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) } } impl From<&Wrapper> for EdgeIndexComparisonOperand { - fn from(value: &Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(index: &Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) } } impl> From for EdgeIndexComparisonOperand { - fn from(value: V) -> Self { - Self::Index(value.into()) + fn from(index: V) -> Self { + Self::Index(index.into()) } } @@ -318,32 +318,32 @@ impl DeepClone for EdgeIndicesComparisonOperand { fn deep_clone(&self) -> Self { match self { Self::Operand(operand) => Self::Operand(operand.deep_clone()), - Self::Indices(value) => Self::Indices(value.clone()), + Self::Indices(indices) => Self::Indices(indices.clone()), } } } impl From> for EdgeIndicesComparisonOperand { - fn from(value: Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(indices: Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) } } impl From<&Wrapper> for EdgeIndicesComparisonOperand { - fn from(value: &Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(indices: &Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) } } impl> From> for EdgeIndicesComparisonOperand { - fn from(value: Vec) -> Self { - Self::Indices(value.into_iter().map(Into::into).collect()) + fn from(indices: Vec) -> Self { + Self::Indices(indices.into_iter().map(Into::into).collect()) } } impl + Clone, const N: usize> From<[V; N]> for EdgeIndicesComparisonOperand { - fn from(value: [V; N]) -> Self { - value.to_vec().into() + fn from(indices: [V; N]) -> Self { + indices.to_vec().into() } } @@ -373,54 +373,54 @@ impl EdgeIndicesOperand { pub(crate) fn evaluate<'a>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, + indices: impl Iterator + 'a, ) -> MedRecordResult + 'a> { - let values = Box::new(values) as BoxedIterator; + let indices = Box::new(indices) as BoxedIterator; self.operations .iter() - .try_fold(values, |value_tuples, operation| { - operation.evaluate(medrecord, value_tuples) + .try_fold(indices, |index_tuples, operation| { + operation.evaluate(medrecord, index_tuples) }) } - implement_value_operation!(max, Max); - implement_value_operation!(min, Min); - implement_value_operation!(count, Count); - implement_value_operation!(sum, Sum); - implement_value_operation!(first, First); - implement_value_operation!(last, Last); + implement_index_operation!(max, Max); + implement_index_operation!(min, Min); + implement_index_operation!(count, Count); + implement_index_operation!(sum, Sum); + implement_index_operation!(first, First); + implement_index_operation!(last, Last); - implement_single_value_comparison_operation!(greater_than, EdgeIndicesOperation, GreaterThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(greater_than, EdgeIndicesOperation, GreaterThan); + implement_single_index_comparison_operation!( greater_than_or_equal_to, EdgeIndicesOperation, GreaterThanOrEqualTo ); - implement_single_value_comparison_operation!(less_than, EdgeIndicesOperation, LessThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(less_than, EdgeIndicesOperation, LessThan); + implement_single_index_comparison_operation!( less_than_or_equal_to, EdgeIndicesOperation, LessThanOrEqualTo ); - implement_single_value_comparison_operation!(equal_to, EdgeIndicesOperation, EqualTo); - implement_single_value_comparison_operation!(not_equal_to, EdgeIndicesOperation, NotEqualTo); - implement_single_value_comparison_operation!(starts_with, EdgeIndicesOperation, StartsWith); - implement_single_value_comparison_operation!(ends_with, EdgeIndicesOperation, EndsWith); - implement_single_value_comparison_operation!(contains, EdgeIndicesOperation, Contains); + implement_single_index_comparison_operation!(equal_to, EdgeIndicesOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, EdgeIndicesOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, EdgeIndicesOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, EdgeIndicesOperation, EndsWith); + implement_single_index_comparison_operation!(contains, EdgeIndicesOperation, Contains); - pub fn is_in>(&mut self, values: V) { + pub fn is_in>(&mut self, indices: V) { self.operations .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsIn, }); } - pub fn is_not_in>(&mut self, values: V) { + pub fn is_not_in>(&mut self, indices: V) { self.operations .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsNotIn, }); } @@ -460,9 +460,9 @@ impl Wrapper { pub(crate) fn evaluate<'a>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, + indices: impl Iterator + 'a, ) -> MedRecordResult + 'a> { - self.0.read_or_panic().evaluate(medrecord, values) + self.0.read_or_panic().evaluate(medrecord, indices) } implement_wrapper_operand_with_return!(max, EdgeIndexOperand); @@ -536,49 +536,49 @@ impl EdgeIndexOperand { pub(crate) fn evaluate( &self, medrecord: &MedRecord, - value: EdgeIndex, + index: EdgeIndex, ) -> MedRecordResult> { self.operations .iter() - .try_fold(Some(value), |value, operation| { - if let Some(value) = value { - operation.evaluate(medrecord, value) + .try_fold(Some(index), |index, operation| { + if let Some(index) = index { + operation.evaluate(medrecord, index) } else { Ok(None) } }) } - implement_single_value_comparison_operation!(greater_than, EdgeIndexOperation, GreaterThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(greater_than, EdgeIndexOperation, GreaterThan); + implement_single_index_comparison_operation!( greater_than_or_equal_to, EdgeIndexOperation, GreaterThanOrEqualTo ); - implement_single_value_comparison_operation!(less_than, EdgeIndexOperation, LessThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(less_than, EdgeIndexOperation, LessThan); + implement_single_index_comparison_operation!( less_than_or_equal_to, EdgeIndexOperation, LessThanOrEqualTo ); - implement_single_value_comparison_operation!(equal_to, EdgeIndexOperation, EqualTo); - implement_single_value_comparison_operation!(not_equal_to, EdgeIndexOperation, NotEqualTo); - implement_single_value_comparison_operation!(starts_with, EdgeIndexOperation, StartsWith); - implement_single_value_comparison_operation!(ends_with, EdgeIndexOperation, EndsWith); - implement_single_value_comparison_operation!(contains, EdgeIndexOperation, Contains); + implement_single_index_comparison_operation!(equal_to, EdgeIndexOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, EdgeIndexOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, EdgeIndexOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, EdgeIndexOperation, EndsWith); + implement_single_index_comparison_operation!(contains, EdgeIndexOperation, Contains); - pub fn is_in>(&mut self, values: V) { + pub fn is_in>(&mut self, indices: V) { self.operations .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsIn, }); } - pub fn is_not_in>(&mut self, values: V) { + pub fn is_not_in>(&mut self, indices: V) { self.operations .push(EdgeIndexOperation::EdgeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsNotIn, }); } @@ -617,9 +617,9 @@ impl Wrapper { pub(crate) fn evaluate( &self, medrecord: &MedRecord, - value: EdgeIndex, + index: EdgeIndex, ) -> MedRecordResult> { - self.0.read_or_panic().evaluate(medrecord, value) + self.0.read_or_panic().evaluate(medrecord, index) } implement_wrapper_operand_with_argument!(greater_than, impl Into); diff --git a/crates/medmodels-core/src/medrecord/querying/mod.rs b/crates/medmodels-core/src/medrecord/querying/mod.rs index 94728fe4..0096f87e 100644 --- a/crates/medmodels-core/src/medrecord/querying/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/mod.rs @@ -1,7 +1,7 @@ pub mod attributes; pub mod edges; pub mod nodes; -mod traits; +pub mod traits; pub mod values; pub mod wrapper; diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs index 1041a7e9..4714ccd4 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs @@ -2,8 +2,11 @@ mod operand; mod operation; mod selection; -pub use operand::NodeOperand; -pub use operation::NodeOperation; +pub use operand::{ + NodeIndexComparisonOperand, NodeIndexOperand, NodeIndicesComparisonOperand, NodeIndicesOperand, + NodeOperand, +}; +pub use operation::{EdgeDirection, NodeOperation}; pub use selection::NodeSelection; use std::fmt::Display; diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs index 1800bc00..17eccb6d 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs @@ -171,8 +171,11 @@ impl Wrapper { self.0.read_or_panic().evaluate(medrecord) } - pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { - self.0.write_or_panic().attribute(attribute) + pub fn attribute(&mut self, attribute: A) -> Wrapper + where + A: Into, + { + self.0.write_or_panic().attribute(attribute.into()) } pub fn attributes(&mut self) -> Wrapper { @@ -218,7 +221,7 @@ impl Wrapper { } } -macro_rules! implement_value_operation { +macro_rules! implement_index_operation { ($name:ident, $variant:ident) => { pub fn $name(&mut self) -> Wrapper { let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); @@ -233,12 +236,12 @@ macro_rules! implement_value_operation { }; } -macro_rules! implement_single_value_comparison_operation { +macro_rules! implement_single_index_comparison_operation { ($name:ident, $operation:ident, $kind:ident) => { - pub fn $name>(&mut self, value: V) { + pub fn $name>(&mut self, index: V) { self.operations .push($operation::NodeIndexComparisonOperation { - operand: value.into(), + operand: index.into(), kind: SingleComparisonKind::$kind, }); } @@ -247,9 +250,9 @@ macro_rules! implement_single_value_comparison_operation { macro_rules! implement_binary_arithmetic_operation { ($name:ident, $operation:ident, $kind:ident) => { - pub fn $name>(&mut self, value: V) { + pub fn $name>(&mut self, index: V) { self.operations.push($operation::BinaryArithmeticOpration { - operand: value.into(), + operand: index.into(), kind: BinaryArithmeticKind::$kind, }); } @@ -291,9 +294,9 @@ macro_rules! implement_wrapper_operand_with_return { } macro_rules! implement_wrapper_operand_with_argument { - ($name:ident, $value_type:ty) => { - pub fn $name(&self, value: $value_type) { - self.0.write_or_panic().$name(value) + ($name:ident, $index_type:ty) => { + pub fn $name(&self, index: $index_type) { + self.0.write_or_panic().$name(index) } }; } @@ -308,26 +311,26 @@ impl DeepClone for NodeIndexComparisonOperand { fn deep_clone(&self) -> Self { match self { Self::Operand(operand) => Self::Operand(operand.deep_clone()), - Self::Index(value) => Self::Index(value.clone()), + Self::Index(index) => Self::Index(index.clone()), } } } impl From> for NodeIndexComparisonOperand { - fn from(value: Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(index: Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) } } impl From<&Wrapper> for NodeIndexComparisonOperand { - fn from(value: &Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(index: &Wrapper) -> Self { + Self::Operand(index.0.read_or_panic().deep_clone()) } } impl> From for NodeIndexComparisonOperand { - fn from(value: V) -> Self { - Self::Index(value.into()) + fn from(index: V) -> Self { + Self::Index(index.into()) } } @@ -341,32 +344,32 @@ impl DeepClone for NodeIndicesComparisonOperand { fn deep_clone(&self) -> Self { match self { Self::Operand(operand) => Self::Operand(operand.deep_clone()), - Self::Indices(value) => Self::Indices(value.clone()), + Self::Indices(indices) => Self::Indices(indices.clone()), } } } impl From> for NodeIndicesComparisonOperand { - fn from(value: Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(indices: Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) } } impl From<&Wrapper> for NodeIndicesComparisonOperand { - fn from(value: &Wrapper) -> Self { - Self::Operand(value.0.read_or_panic().deep_clone()) + fn from(indices: &Wrapper) -> Self { + Self::Operand(indices.0.read_or_panic().deep_clone()) } } impl> From> for NodeIndicesComparisonOperand { - fn from(value: Vec) -> Self { - Self::Indices(value.into_iter().map(Into::into).collect()) + fn from(indices: Vec) -> Self { + Self::Indices(indices.into_iter().map(Into::into).collect()) } } impl + Clone, const N: usize> From<[V; N]> for NodeIndicesComparisonOperand { - fn from(value: [V; N]) -> Self { - value.to_vec().into() + fn from(indices: [V; N]) -> Self { + indices.to_vec().into() } } @@ -396,54 +399,54 @@ impl NodeIndicesOperand { pub(crate) fn evaluate<'a>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, + indices: impl Iterator + 'a, ) -> MedRecordResult + 'a> { - let values = Box::new(values) as BoxedIterator; + let indices = Box::new(indices) as BoxedIterator; self.operations .iter() - .try_fold(values, |value_tuples, operation| { - operation.evaluate(medrecord, value_tuples) + .try_fold(indices, |index_tuples, operation| { + operation.evaluate(medrecord, index_tuples) }) } - implement_value_operation!(max, Max); - implement_value_operation!(min, Min); - implement_value_operation!(count, Count); - implement_value_operation!(sum, Sum); - implement_value_operation!(first, First); - implement_value_operation!(last, Last); + implement_index_operation!(max, Max); + implement_index_operation!(min, Min); + implement_index_operation!(count, Count); + implement_index_operation!(sum, Sum); + implement_index_operation!(first, First); + implement_index_operation!(last, Last); - implement_single_value_comparison_operation!(greater_than, NodeIndicesOperation, GreaterThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(greater_than, NodeIndicesOperation, GreaterThan); + implement_single_index_comparison_operation!( greater_than_or_equal_to, NodeIndicesOperation, GreaterThanOrEqualTo ); - implement_single_value_comparison_operation!(less_than, NodeIndicesOperation, LessThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(less_than, NodeIndicesOperation, LessThan); + implement_single_index_comparison_operation!( less_than_or_equal_to, NodeIndicesOperation, LessThanOrEqualTo ); - implement_single_value_comparison_operation!(equal_to, NodeIndicesOperation, EqualTo); - implement_single_value_comparison_operation!(not_equal_to, NodeIndicesOperation, NotEqualTo); - implement_single_value_comparison_operation!(starts_with, NodeIndicesOperation, StartsWith); - implement_single_value_comparison_operation!(ends_with, NodeIndicesOperation, EndsWith); - implement_single_value_comparison_operation!(contains, NodeIndicesOperation, Contains); + implement_single_index_comparison_operation!(equal_to, NodeIndicesOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, NodeIndicesOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, NodeIndicesOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, NodeIndicesOperation, EndsWith); + implement_single_index_comparison_operation!(contains, NodeIndicesOperation, Contains); - pub fn is_in>(&mut self, values: V) { + pub fn is_in>(&mut self, indices: V) { self.operations .push(NodeIndicesOperation::NodeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsIn, }); } - pub fn is_not_in>(&mut self, values: V) { + pub fn is_not_in>(&mut self, indices: V) { self.operations .push(NodeIndicesOperation::NodeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsNotIn, }); } @@ -497,9 +500,9 @@ impl Wrapper { pub(crate) fn evaluate<'a>( &self, medrecord: &'a MedRecord, - values: impl Iterator + 'a, + indices: impl Iterator + 'a, ) -> MedRecordResult + 'a> { - self.0.read_or_panic().evaluate(medrecord, values) + self.0.read_or_panic().evaluate(medrecord, indices) } implement_wrapper_operand_with_return!(max, NodeIndexOperand); @@ -586,49 +589,49 @@ impl NodeIndexOperand { pub(crate) fn evaluate( &self, medrecord: &MedRecord, - value: NodeIndex, + index: NodeIndex, ) -> MedRecordResult> { self.operations .iter() - .try_fold(Some(value), |value, operation| { - if let Some(value) = value { - operation.evaluate(medrecord, value) + .try_fold(Some(index), |index, operation| { + if let Some(index) = index { + operation.evaluate(medrecord, index) } else { Ok(None) } }) } - implement_single_value_comparison_operation!(greater_than, NodeIndexOperation, GreaterThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(greater_than, NodeIndexOperation, GreaterThan); + implement_single_index_comparison_operation!( greater_than_or_equal_to, NodeIndexOperation, GreaterThanOrEqualTo ); - implement_single_value_comparison_operation!(less_than, NodeIndexOperation, LessThan); - implement_single_value_comparison_operation!( + implement_single_index_comparison_operation!(less_than, NodeIndexOperation, LessThan); + implement_single_index_comparison_operation!( less_than_or_equal_to, NodeIndexOperation, LessThanOrEqualTo ); - implement_single_value_comparison_operation!(equal_to, NodeIndexOperation, EqualTo); - implement_single_value_comparison_operation!(not_equal_to, NodeIndexOperation, NotEqualTo); - implement_single_value_comparison_operation!(starts_with, NodeIndexOperation, StartsWith); - implement_single_value_comparison_operation!(ends_with, NodeIndexOperation, EndsWith); - implement_single_value_comparison_operation!(contains, NodeIndexOperation, Contains); + implement_single_index_comparison_operation!(equal_to, NodeIndexOperation, EqualTo); + implement_single_index_comparison_operation!(not_equal_to, NodeIndexOperation, NotEqualTo); + implement_single_index_comparison_operation!(starts_with, NodeIndexOperation, StartsWith); + implement_single_index_comparison_operation!(ends_with, NodeIndexOperation, EndsWith); + implement_single_index_comparison_operation!(contains, NodeIndexOperation, Contains); - pub fn is_in>(&mut self, values: V) { + pub fn is_in>(&mut self, indices: V) { self.operations .push(NodeIndexOperation::NodeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsIn, }); } - pub fn is_not_in>(&mut self, values: V) { + pub fn is_not_in>(&mut self, indices: V) { self.operations .push(NodeIndexOperation::NodeIndicesComparisonOperation { - operand: values.into(), + operand: indices.into(), kind: MultipleComparisonKind::IsNotIn, }); } @@ -680,9 +683,9 @@ impl Wrapper { pub(crate) fn evaluate( &self, medrecord: &MedRecord, - value: NodeIndex, + index: NodeIndex, ) -> MedRecordResult> { - self.0.read_or_panic().evaluate(medrecord, value) + self.0.read_or_panic().evaluate(medrecord, index) } implement_wrapper_operand_with_argument!(greater_than, impl Into); diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs index 90e6692c..fc06cf39 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs @@ -591,11 +591,11 @@ impl NodeIndicesOperation { pub(crate) fn get_sum( mut indices: impl Iterator, ) -> MedRecordResult { - let first_value = indices + let first_index = indices .next() .ok_or(MedRecordError::QueryError("No indices to sum".to_string()))?; - indices.try_fold(first_value, |sum, index| { + indices.try_fold(first_index, |sum, index| { let first_dtype = DataType::from(&sum); let second_dtype = DataType::from(&index); diff --git a/crates/medmodels-core/src/medrecord/querying/values/mod.rs b/crates/medmodels-core/src/medrecord/querying/values/mod.rs index bf2e2f4a..893fa7c9 100644 --- a/crates/medmodels-core/src/medrecord/querying/values/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/values/mod.rs @@ -14,7 +14,10 @@ use crate::{ medrecord::{MedRecordAttribute, MedRecordValue}, MedRecord, }; -pub use operand::MultipleValuesOperand; +pub use operand::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, +}; use std::fmt::Display; macro_rules! get_attributes { diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index 4762b4a5..3bf73e47 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union from medmodels.medrecord.types import ( Attributes, @@ -28,13 +28,6 @@ if TYPE_CHECKING: else: from typing_extensions import TypeAlias -ValueOperand: TypeAlias = Union[ - MedRecordValue, - MedRecordAttribute, - PyValueArithmeticOperation, - PyValueTransformationOperation, -] - PyDataType: TypeAlias = Union[ PyString, PyInt, @@ -234,171 +227,431 @@ class PyMedRecord: self, node_indices: NodeIndexInputList ) -> Dict[NodeIndex, List[NodeIndex]]: ... def clear(self) -> None: ... - def select_nodes(self, operation: PyNodeOperation) -> List[NodeIndex]: ... - def select_edges(self, operation: PyEdgeOperation) -> List[EdgeIndex]: ... + def select_nodes( + self, query: Callable[[PyNodeOperand], None] + ) -> List[NodeIndex]: ... + def select_edges( + self, query: Callable[[PyNodeOperand], None] + ) -> List[EdgeIndex]: ... def clone(self) -> PyMedRecord: ... -class PyValueArithmeticOperation: ... -class PyValueTransformationOperation: ... - -class PyNodeOperation: - def logical_and(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def logical_or(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def logical_xor(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def logical_not(self) -> PyNodeOperation: ... - -class PyEdgeOperation: - def logical_and(self, operation: PyEdgeOperation) -> PyEdgeOperation: ... - def logical_or(self, operation: PyEdgeOperation) -> PyEdgeOperation: ... - def logical_xor(self, operation: PyEdgeOperation) -> PyEdgeOperation: ... - def logical_not(self) -> PyEdgeOperation: ... - -class PyNodeAttributeOperand: - def greater( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def less( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def greater_or_equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def less_or_equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def not_equal( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def is_in(self, operands: List[MedRecordValue]) -> PyNodeOperation: ... - def not_in(self, operands: List[MedRecordValue]) -> PyNodeOperation: ... - def starts_with( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def ends_with( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def contains( - self, operand: Union[ValueOperand, PyNodeAttributeOperand] - ) -> PyNodeOperation: ... - def add(self, value: MedRecordValue) -> ValueOperand: ... - def sub(self, value: MedRecordValue) -> ValueOperand: ... - def mul(self, value: MedRecordValue) -> ValueOperand: ... - def div(self, value: MedRecordValue) -> ValueOperand: ... - def pow(self, value: MedRecordValue) -> ValueOperand: ... - def mod(self, value: MedRecordValue) -> ValueOperand: ... - def round(self) -> ValueOperand: ... - def ceil(self) -> ValueOperand: ... - def floor(self) -> ValueOperand: ... - def abs(self) -> ValueOperand: ... - def sqrt(self) -> ValueOperand: ... - def trim(self) -> ValueOperand: ... - def trim_start(self) -> ValueOperand: ... - def trim_end(self) -> ValueOperand: ... - def lowercase(self) -> ValueOperand: ... - def uppercase(self) -> ValueOperand: ... - def slice(self, start: int, end: int) -> ValueOperand: ... - -class PyEdgeAttributeOperand: - def greater( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def less( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def greater_or_equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def less_or_equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def not_equal( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def is_in(self, operands: List[MedRecordValue]) -> PyEdgeOperation: ... - def not_in(self, operands: List[MedRecordValue]) -> PyEdgeOperation: ... - def starts_with( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def ends_with( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def contains( - self, operand: Union[ValueOperand, PyEdgeAttributeOperand] - ) -> PyEdgeOperation: ... - def add(self, value: MedRecordValue) -> ValueOperand: ... - def sub(self, value: MedRecordValue) -> ValueOperand: ... - def mul(self, value: MedRecordValue) -> ValueOperand: ... - def div(self, value: MedRecordValue) -> ValueOperand: ... - def pow(self, value: MedRecordValue) -> ValueOperand: ... - def mod(self, value: MedRecordValue) -> ValueOperand: ... - def round(self) -> ValueOperand: ... - def ceil(self) -> ValueOperand: ... - def floor(self) -> ValueOperand: ... - def abs(self) -> ValueOperand: ... - def sqrt(self) -> ValueOperand: ... - def trim(self) -> ValueOperand: ... - def trim_start(self) -> ValueOperand: ... - def trim_end(self) -> ValueOperand: ... - def lowercase(self) -> ValueOperand: ... - def uppercase(self) -> ValueOperand: ... - def slice(self, start: int, end: int) -> ValueOperand: ... +class PyEdgeDirection(Enum): + Incoming = 0 + Outgoing = 1 + Both = 2 + +class PyNodeOperand: + def attribute(self, attribute: MedRecordAttribute) -> PyMultipleValuesOperand: ... + def attributes(self) -> PyAttributesTreeOperand: ... + def index(self) -> PyNodeIndicesOperand: ... + def in_group(self, group: Union[Group, List[Group]]) -> None: ... + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: ... + def outgoing_edges(self) -> PyEdgeOperand: ... + def incoming_edges(self) -> PyEdgeOperand: ... + def neighbors(self, direction: PyEdgeDirection) -> PyNodeOperand: ... + def either_or( + self, + either: Callable[[PyNodeOperand], None], + or_: Callable[[PyNodeOperand], None], + ) -> None: ... + def deep_clone(self) -> PyNodeOperand: ... + +PyNodeIndexComparisonOperand: TypeAlias = Union[NodeIndex, PyNodeIndexOperand] +PyNodeIndexArithmeticOperand: TypeAlias = PyNodeIndexComparisonOperand +PyNodeIndicesComparisonOperand: TypeAlias = Union[List[NodeIndex], PyNodeIndicesOperand] + +class PyNodeIndicesOperand: + def max(self) -> PyNodeIndexOperand: ... + def min(self) -> PyNodeIndexOperand: ... + def count(self) -> PyNodeIndexOperand: ... + def sum(self) -> PyNodeIndexOperand: ... + def first(self) -> PyNodeIndexOperand: ... + def last(self) -> PyNodeIndexOperand: ... + def greater_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def contains(self, index: PyNodeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def add(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyNodeIndicesOperand], None], + or_: Callable[[PyNodeIndicesOperand], None], + ) -> None: ... + def deep_clone(self) -> PyNodeIndicesOperand: ... class PyNodeIndexOperand: - def greater(self, operand: NodeIndex) -> PyNodeOperation: ... - def less(self, operand: NodeIndex) -> PyNodeOperation: ... - def greater_or_equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def less_or_equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def not_equal(self, operand: NodeIndex) -> PyNodeOperation: ... - def is_in(self, operand: List[NodeIndex]) -> PyNodeOperation: ... - def not_in(self, operand: List[NodeIndex]) -> PyNodeOperation: ... - def starts_with(self, operand: NodeIndex) -> PyNodeOperation: ... - def ends_with(self, operand: NodeIndex) -> PyNodeOperation: ... - def contains(self, operand: NodeIndex) -> PyNodeOperation: ... + def greater_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyNodeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyNodeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyNodeIndexComparisonOperand) -> None: ... + def contains(self, index: PyNodeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyNodeIndicesComparisonOperand) -> None: ... + def add(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyNodeIndexArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def either_or( + self, + either: Callable[[PyNodeIndexOperand], None], + or_: Callable[[PyNodeIndexOperand], None], + ) -> None: ... + def deep_clone(self) -> PyNodeIndexOperand: ... + +class PyEdgeOperand: + def attribute(self, attribute: MedRecordAttribute) -> PyMultipleValuesOperand: ... + def attributes(self) -> PyAttributesTreeOperand: ... + def index(self) -> PyEdgeIndicesOperand: ... + def in_group(self, group: Union[Group, List[Group]]) -> None: ... + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: ... + def source_node(self) -> PyNodeOperand: ... + def target_node(self) -> PyNodeOperand: ... + def either_or( + self, + either: Callable[[PyEdgeOperand], None], + or_: Callable[[PyEdgeOperand], None], + ) -> None: ... + def deep_clone(self) -> PyEdgeOperand: ... + +PyEdgeIndexComparisonOperand: TypeAlias = Union[EdgeIndex, PyEdgeIndexOperand] +PyEdgeIndexArithmeticOperand: TypeAlias = PyEdgeIndexComparisonOperand +PyEdgeIndicesComparisonOperand: TypeAlias = Union[List[EdgeIndex], PyEdgeIndicesOperand] + +class PyEdgeIndicesOperand: + def max(self) -> PyEdgeIndexOperand: ... + def min(self) -> PyEdgeIndexOperand: ... + def count(self) -> PyEdgeIndexOperand: ... + def sum(self) -> PyEdgeIndexOperand: ... + def first(self) -> PyEdgeIndexOperand: ... + def last(self) -> PyEdgeIndexOperand: ... + def greater_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def contains(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def add(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyEdgeIndicesOperand], None], + or_: Callable[[PyEdgeIndicesOperand], None], + ) -> None: ... + def deep_clone(self) -> PyEdgeIndicesOperand: ... class PyEdgeIndexOperand: - def greater(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def less(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def greater_or_equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def less_or_equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def not_equal(self, operand: EdgeIndex) -> PyEdgeOperation: ... - def is_in(self, operand: List[EdgeIndex]) -> PyEdgeOperation: ... - def not_in(self, operand: List[EdgeIndex]) -> PyEdgeOperation: ... + def greater_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def greater_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def less_than_or_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def not_equal_to(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def starts_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def ends_with(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def contains(self, index: PyEdgeIndexComparisonOperand) -> None: ... + def is_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def is_not_in(self, indices: PyEdgeIndicesComparisonOperand) -> None: ... + def add(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def sub(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mul(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def pow(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def mod(self, index: PyEdgeIndexArithmeticOperand) -> None: ... + def either_or( + self, + either: Callable[[PyEdgeIndexOperand], None], + or_: Callable[[PyEdgeIndexOperand], None], + ) -> None: ... + def deep_clone(self) -> PyEdgeIndexOperand: ... -class PyNodeOperand: - def in_group(self, operand: Group) -> PyNodeOperation: ... - def has_attribute(self, operand: MedRecordAttribute) -> PyNodeOperation: ... - def has_outgoing_edge_with(self, operation: PyEdgeOperation) -> PyNodeOperation: ... - def has_incoming_edge_with(self, operation: PyEdgeOperation) -> PyNodeOperation: ... - def has_edge_with(self, operation: PyEdgeOperation) -> PyNodeOperation: ... - def has_neighbor_with(self, operation: PyNodeOperation) -> PyNodeOperation: ... - def has_neighbor_undirected_with( - self, operation: PyNodeOperation - ) -> PyNodeOperation: ... - def attribute(self, attribute: MedRecordAttribute) -> PyNodeAttributeOperand: ... - def index(self) -> PyNodeIndexOperand: ... +PySingleValueComparisonOperand: TypeAlias = Union[MedRecordValue, PySingleValueOperand] +PySingleValueArithmeticOperand: TypeAlias = PySingleValueComparisonOperand +PyMultipleValuesComparisonOperand: TypeAlias = Union[ + List[MedRecordValue], PyMultipleValuesOperand +] -class PyEdgeOperand: - def connected_target(self, operand: NodeIndex) -> PyEdgeOperation: ... - def connected_source(self, operand: NodeIndex) -> PyEdgeOperation: ... - def connected(self, operand: NodeIndex) -> PyEdgeOperation: ... - def in_group(self, operand: Group) -> PyEdgeOperation: ... - def has_attribute(self, operand: MedRecordAttribute) -> PyEdgeOperation: ... - def connected_source_with(self, operation: PyNodeOperation) -> PyEdgeOperation: ... - def connected_target_with(self, operation: PyNodeOperation) -> PyEdgeOperation: ... - def connected_with(self, operation: PyNodeOperation) -> PyEdgeOperation: ... - def has_parallel_edges_with( - self, operation: PyEdgeOperation - ) -> PyEdgeOperation: ... - def has_parallel_edges_with_self_comparison( - self, operation: PyEdgeOperation - ) -> PyEdgeOperation: ... - def attribute(self, attribute: MedRecordAttribute) -> PyEdgeAttributeOperand: ... - def index(self) -> PyEdgeIndexOperand: ... +class PyMultipleValuesOperand: + def max(self) -> PySingleValueOperand: ... + def min(self) -> PySingleValueOperand: ... + def mean(self) -> PySingleValueOperand: ... + def median(self) -> PySingleValueOperand: ... + def mode(self) -> PySingleValueOperand: ... + def std(self) -> PySingleValueOperand: ... + def var(self) -> PySingleValueOperand: ... + def count(self) -> PySingleValueOperand: ... + def sum(self) -> PySingleValueOperand: ... + def first(self) -> PySingleValueOperand: ... + def last(self) -> PySingleValueOperand: ... + def greater_than(self, value: PySingleValueComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: PySingleValueComparisonOperand + ) -> None: ... + def less_than(self, value: PySingleValueComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def not_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def starts_with(self, value: PySingleValueComparisonOperand) -> None: ... + def ends_with(self, value: PySingleValueComparisonOperand) -> None: ... + def contains(self, value: PySingleValueComparisonOperand) -> None: ... + def is_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def is_not_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def add(self, value: PySingleValueArithmeticOperand) -> None: ... + def sub(self, value: PySingleValueArithmeticOperand) -> None: ... + def mul(self, value: PySingleValueArithmeticOperand) -> None: ... + def div(self, value: PySingleValueArithmeticOperand) -> None: ... + def pow(self, value: PySingleValueArithmeticOperand) -> None: ... + def mod(self, value: PySingleValueArithmeticOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... + def abs(self) -> None: ... + def sqrt(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... + def is_null(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyMultipleValuesOperand], None], + or_: Callable[[PyMultipleValuesOperand], None], + ) -> None: ... + def deep_clone(self) -> PyMultipleValuesOperand: ... + +class PySingleValueOperand: + def greater_than(self, value: PySingleValueComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, value: PySingleValueComparisonOperand + ) -> None: ... + def less_than(self, value: PySingleValueComparisonOperand) -> None: ... + def less_than_or_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def not_equal_to(self, value: PySingleValueComparisonOperand) -> None: ... + def starts_with(self, value: PySingleValueComparisonOperand) -> None: ... + def ends_with(self, value: PySingleValueComparisonOperand) -> None: ... + def contains(self, value: PySingleValueComparisonOperand) -> None: ... + def is_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def is_not_in(self, values: PyMultipleValuesComparisonOperand) -> None: ... + def add(self, value: PySingleValueArithmeticOperand) -> None: ... + def sub(self, value: PySingleValueArithmeticOperand) -> None: ... + def mul(self, value: PySingleValueArithmeticOperand) -> None: ... + def div(self, value: PySingleValueArithmeticOperand) -> None: ... + def pow(self, value: PySingleValueArithmeticOperand) -> None: ... + def mod(self, value: PySingleValueArithmeticOperand) -> None: ... + def round(self) -> None: ... + def ceil(self) -> None: ... + def floor(self) -> None: ... + def abs(self) -> None: ... + def sqrt(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_float(self) -> None: ... + def is_bool(self) -> None: ... + def is_datetime(self) -> None: ... + def is_null(self) -> None: ... + def either_or( + self, + either: Callable[[PySingleValueOperand], None], + or_: Callable[[PySingleValueOperand], None], + ) -> None: ... + def deep_clone(self) -> PySingleValueOperand: ... + +PySingleAttributeComparisonOperand: TypeAlias = Union[ + MedRecordAttribute, PySingleAttributeOperand +] +PySingleAttributeArithmeticOperand: TypeAlias = PySingleAttributeComparisonOperand +PyMultipleAttributesComparisonOperand: TypeAlias = Union[ + List[MedRecordAttribute], PyMultipleAttributesOperand +] + +class PyAttributesTreeOperand: + def max(self) -> PyMultipleAttributesOperand: ... + def min(self) -> PyMultipleAttributesOperand: ... + def count(self) -> PyMultipleAttributesOperand: ... + def sum(self) -> PyMultipleAttributesOperand: ... + def first(self) -> PyMultipleAttributesOperand: ... + def last(self) -> PyMultipleAttributesOperand: ... + def greater_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def starts_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def ends_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def contains(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def is_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def add(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def sub(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mul(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def pow(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mod(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyAttributesTreeOperand], None], + or_: Callable[[PyAttributesTreeOperand], None], + ) -> None: ... + def deep_clone(self) -> PyAttributesTreeOperand: ... + +class PyMultipleAttributesOperand: + def max(self) -> PySingleAttributeOperand: ... + def min(self) -> PySingleAttributeOperand: ... + def count(self) -> PySingleAttributeOperand: ... + def sum(self) -> PySingleAttributeOperand: ... + def first(self) -> PySingleAttributeOperand: ... + def last(self) -> PySingleAttributeOperand: ... + def greater_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def starts_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def ends_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def contains(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def is_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def add(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def sub(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mul(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def pow(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mod(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def to_values(self) -> PyMultipleValuesOperand: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def is_max(self) -> None: ... + def is_min(self) -> None: ... + def either_or( + self, + either: Callable[[PyMultipleAttributesOperand], None], + or_: Callable[[PyMultipleAttributesOperand], None], + ) -> None: ... + def deep_clone(self) -> PyMultipleAttributesOperand: ... + +class PySingleAttributeOperand: + def greater_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def greater_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def less_than(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def less_than_or_equal_to( + self, attribute: PySingleAttributeComparisonOperand + ) -> None: ... + def equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def not_equal_to(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def starts_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def ends_with(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def contains(self, attribute: PySingleAttributeComparisonOperand) -> None: ... + def is_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def is_not_in(self, attributes: PyMultipleAttributesComparisonOperand) -> None: ... + def add(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def sub(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mul(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def pow(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def mod(self, attribute: PySingleAttributeArithmeticOperand) -> None: ... + def abs(self) -> None: ... + def trim(self) -> None: ... + def trim_start(self) -> None: ... + def trim_end(self) -> None: ... + def lowercase(self) -> None: ... + def uppercase(self) -> None: ... + def slice(self, start: int, end: int) -> None: ... + def is_string(self) -> None: ... + def is_int(self) -> None: ... + def either_or( + self, + either: Callable[[PySingleAttributeOperand], None], + or_: Callable[[PySingleAttributeOperand], None], + ) -> None: ... + def deep_clone(self) -> PySingleAttributeOperand: ... diff --git a/medmodels/medrecord/indexers.py b/medmodels/medrecord/indexers.py index 0496b1a8..33a0f7df 100644 --- a/medmodels/medrecord/indexers.py +++ b/medmodels/medrecord/indexers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Tuple, Union, overload +from typing import TYPE_CHECKING, Callable, Dict, Tuple, Union, overload from medmodels.medrecord.querying import EdgeQuery, NodeQuery from medmodels.medrecord.types import ( @@ -87,7 +87,7 @@ def __getitem__( if isinstance(key, list): return self._medrecord._medrecord.node(key) - if isinstance(key, NodeQuery): + if isinstance(key, Callable): return self._medrecord._medrecord.node(self._medrecord.select_nodes(key)) if isinstance(key, slice): @@ -112,7 +112,7 @@ def __getitem__( return {x: attributes[x][attribute_selection] for x in attributes.keys()} - if isinstance(index_selection, NodeQuery) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): attributes = self._medrecord._medrecord.node( @@ -151,7 +151,7 @@ def __getitem__( for x in attributes.keys() } - if isinstance(index_selection, NodeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): attributes = self._medrecord._medrecord.node( @@ -198,7 +198,7 @@ def __getitem__( return self._medrecord._medrecord.node(index_selection) - if isinstance(index_selection, NodeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -270,7 +270,7 @@ def __setitem__( return self._medrecord._medrecord.replace_node_attributes(key, value) - if isinstance(key, NodeQuery): + if isinstance(key, Callable): if not is_attributes(value): raise ValueError("Invalid value type. Expected Attributes") @@ -311,7 +311,7 @@ def __setitem__( index_selection, attribute_selection, value ) - if isinstance(index_selection, NodeQuery) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): if not is_medrecord_value(value): @@ -364,7 +364,7 @@ def __setitem__( return - if isinstance(index_selection, NodeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): @@ -440,7 +440,7 @@ def __setitem__( return - if isinstance(index_selection, NodeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -514,7 +514,7 @@ def __delitem__( index_selection, attribute_selection ) - if isinstance(index_selection, NodeQuery) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): return self._medrecord._medrecord.remove_node_attribute( @@ -553,7 +553,7 @@ def __delitem__( return - if isinstance(index_selection, NodeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): for attribute in attribute_selection: @@ -602,7 +602,7 @@ def __delitem__( index_selection, {} ) - if isinstance(index_selection, NodeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -697,7 +697,7 @@ def __getitem__( if isinstance(key, list): return self._medrecord._medrecord.edge(key) - if isinstance(key, EdgeQuery): + if isinstance(key, Callable): return self._medrecord._medrecord.edge(self._medrecord.select_edges(key)) if isinstance(key, slice): @@ -722,7 +722,7 @@ def __getitem__( return {x: attributes[x][attribute_selection] for x in attributes.keys()} - if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): attributes = self._medrecord._medrecord.edge( @@ -761,7 +761,7 @@ def __getitem__( for x in attributes.keys() } - if isinstance(index_selection, EdgeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): attributes = self._medrecord._medrecord.edge( @@ -808,7 +808,7 @@ def __getitem__( return self._medrecord._medrecord.edge(index_selection) - if isinstance(index_selection, EdgeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -880,7 +880,7 @@ def __setitem__( return self._medrecord._medrecord.replace_edge_attributes(key, value) - if isinstance(key, EdgeQuery): + if isinstance(key, Callable): if not is_attributes(value): raise ValueError("Invalid value type. Expected Attributes") @@ -921,7 +921,7 @@ def __setitem__( index_selection, attribute_selection, value ) - if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): if not is_medrecord_value(value): @@ -974,7 +974,7 @@ def __setitem__( return - if isinstance(index_selection, EdgeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): if not is_medrecord_value(value): @@ -1048,7 +1048,7 @@ def __setitem__( return - if isinstance(index_selection, EdgeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( @@ -1122,7 +1122,7 @@ def __delitem__( index_selection, attribute_selection ) - if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute( + if isinstance(index_selection, Callable) and is_medrecord_attribute( attribute_selection ): return self._medrecord._medrecord.remove_edge_attribute( @@ -1161,7 +1161,7 @@ def __delitem__( return - if isinstance(index_selection, EdgeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, list ): for attribute in attribute_selection: @@ -1210,7 +1210,7 @@ def __delitem__( index_selection, {} ) - if isinstance(index_selection, EdgeQuery) and isinstance( + if isinstance(index_selection, Callable) and isinstance( attribute_selection, slice ): if ( diff --git a/rustmodels/src/lib.rs b/rustmodels/src/lib.rs index 751ce907..9e33e547 100644 --- a/rustmodels/src/lib.rs +++ b/rustmodels/src/lib.rs @@ -4,9 +4,12 @@ mod medrecord; use medrecord::{ datatype::{PyAny, PyBool, PyDateTime, PyFloat, PyInt, PyNull, PyOption, PyString, PyUnion}, querying::{ - PyEdgeAttributeOperand, PyEdgeIndexOperand, PyEdgeOperand, PyEdgeOperation, - PyNodeAttributeOperand, PyNodeIndexOperand, PyNodeOperand, PyNodeOperation, - PyValueArithmeticOperation, PyValueTransformationOperation, + attributes::{ + PyAttributesTreeOperand, PyMultipleAttributesOperand, PySingleAttributeOperand, + }, + edges::{PyEdgeIndexOperand, PyEdgeIndicesOperand, PyEdgeOperand}, + nodes::{PyNodeIndexOperand, PyNodeIndicesOperand, PyNodeOperand}, + values::{PyMultipleValuesOperand, PySingleValueOperand}, }, schema::{PyAttributeDataType, PyAttributeType, PyGroupSchema, PySchema}, PyMedRecord, @@ -32,20 +35,20 @@ fn _medmodels(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - - m.add_class::()?; - m.add_class::()?; - - m.add_class::()?; - m.add_class::()?; - + m.add_class::()?; + m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + m.add_class::()?; + m.add_class::()?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rustmodels/src/medrecord/attribute.rs b/rustmodels/src/medrecord/attribute.rs index 9615dc6f..cc9fe513 100644 --- a/rustmodels/src/medrecord/attribute.rs +++ b/rustmodels/src/medrecord/attribute.rs @@ -6,7 +6,7 @@ use std::{hash::Hash, ops::Deref}; #[repr(transparent)] #[derive(PartialEq, Eq, Hash, Clone, Debug)] -pub(crate) struct PyMedRecordAttribute(MedRecordAttribute); +pub struct PyMedRecordAttribute(MedRecordAttribute); impl From for PyMedRecordAttribute { fn from(value: MedRecordAttribute) -> Self { diff --git a/rustmodels/src/medrecord/errors.rs b/rustmodels/src/medrecord/errors.rs index 6965791e..f96f9a32 100644 --- a/rustmodels/src/medrecord/errors.rs +++ b/rustmodels/src/medrecord/errors.rs @@ -21,6 +21,7 @@ impl From for PyErr { MedRecordError::ConversionError(message) => PyRuntimeError::new_err(message), MedRecordError::AssertionError(message) => PyAssertionError::new_err(message), MedRecordError::SchemaError(message) => PyValueError::new_err(message), + MedRecordError::QueryError(message) => PyRuntimeError::new_err(message), } } } diff --git a/rustmodels/src/medrecord/mod.rs b/rustmodels/src/medrecord/mod.rs index 059106ee..b6bba790 100644 --- a/rustmodels/src/medrecord/mod.rs +++ b/rustmodels/src/medrecord/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::new_without_default)] + mod attribute; pub mod datatype; mod errors; @@ -13,9 +15,9 @@ use medmodels_core::{ errors::MedRecordError, medrecord::{Attributes, EdgeIndex, MedRecord, MedRecordAttribute, MedRecordValue}, }; -use pyo3::prelude::*; +use pyo3::{prelude::*, types::PyFunction}; use pyo3_polars::PyDataFrame; -use querying::{PyEdgeOperation, PyNodeOperation}; +use querying::{edges::PyEdgeOperand, nodes::PyNodeOperand}; use schema::PySchema; use std::collections::HashMap; use traits::DeepInto; @@ -33,17 +35,17 @@ pub struct PyMedRecord(MedRecord); #[pymethods] impl PyMedRecord { #[new] - fn new() -> Self { + pub fn new() -> Self { Self(MedRecord::new()) } #[staticmethod] - fn with_schema(schema: PySchema) -> Self { + pub fn with_schema(schema: PySchema) -> Self { Self(MedRecord::with_schema(schema.into())) } #[staticmethod] - fn from_tuples( + pub fn from_tuples( nodes: Vec<(PyNodeIndex, PyAttributes)>, edges: Option>, ) -> PyResult { @@ -54,7 +56,7 @@ impl PyMedRecord { } #[staticmethod] - fn from_dataframes( + pub fn from_dataframes( nodes_dataframes: Vec<(PyDataFrame, String)>, edges_dataframes: Vec<(PyDataFrame, String, String)>, ) -> PyResult { @@ -65,7 +67,7 @@ impl PyMedRecord { } #[staticmethod] - fn from_nodes_dataframes(nodes_dataframes: Vec<(PyDataFrame, String)>) -> PyResult { + pub fn from_nodes_dataframes(nodes_dataframes: Vec<(PyDataFrame, String)>) -> PyResult { Ok(Self( MedRecord::from_nodes_dataframes(nodes_dataframes, None) .map_err(PyMedRecordError::from)?, @@ -73,22 +75,22 @@ impl PyMedRecord { } #[staticmethod] - fn from_example_dataset() -> Self { + pub fn from_example_dataset() -> Self { Self(MedRecord::from_example_dataset()) } #[staticmethod] - fn from_ron(path: &str) -> PyResult { + pub fn from_ron(path: &str) -> PyResult { Ok(Self( MedRecord::from_ron(path).map_err(PyMedRecordError::from)?, )) } - fn to_ron(&self, path: &str) -> PyResult<()> { + pub fn to_ron(&self, path: &str) -> PyResult<()> { Ok(self.0.to_ron(path).map_err(PyMedRecordError::from)?) } - fn update_schema(&mut self, schema: PySchema) -> PyResult<()> { + pub fn update_schema(&mut self, schema: PySchema) -> PyResult<()> { Ok(self .0 .update_schema(schema.into()) @@ -96,19 +98,22 @@ impl PyMedRecord { } #[getter] - fn schema(&self) -> PySchema { + pub fn schema(&self) -> PySchema { self.0.get_schema().clone().into() } #[getter] - fn nodes(&self) -> Vec { + pub fn nodes(&self) -> Vec { self.0 .node_indices() .map(|node_index| node_index.clone().into()) .collect() } - fn node(&self, node_index: Vec) -> PyResult> { + pub fn node( + &self, + node_index: Vec, + ) -> PyResult> { node_index .into_iter() .map(|node_index| { @@ -123,11 +128,11 @@ impl PyMedRecord { } #[getter] - fn edges(&self) -> Vec { + pub fn edges(&self) -> Vec { self.0.edge_indices().copied().collect() } - fn edge(&self, edge_index: Vec) -> PyResult> { + pub fn edge(&self, edge_index: Vec) -> PyResult> { edge_index .into_iter() .map(|edge_index| { @@ -142,11 +147,11 @@ impl PyMedRecord { } #[getter] - fn groups(&self) -> Vec { + pub fn groups(&self) -> Vec { self.0.groups().map(|group| group.clone().into()).collect() } - fn outgoing_edges( + pub fn outgoing_edges( &self, node_index: Vec, ) -> PyResult>> { @@ -165,7 +170,7 @@ impl PyMedRecord { .collect() } - fn incoming_edges( + pub fn incoming_edges( &self, node_index: Vec, ) -> PyResult>> { @@ -184,7 +189,7 @@ impl PyMedRecord { .collect() } - fn edge_endpoints( + pub fn edge_endpoints( &self, edge_index: Vec, ) -> PyResult> { @@ -207,7 +212,7 @@ impl PyMedRecord { .collect() } - fn edges_connecting( + pub fn edges_connecting( &self, source_node_indices: Vec, target_node_indices: Vec, @@ -224,7 +229,7 @@ impl PyMedRecord { .collect() } - fn edges_connecting_undirected( + pub fn edges_connecting_undirected( &self, first_node_indices: Vec, second_node_indices: Vec, @@ -241,7 +246,7 @@ impl PyMedRecord { .collect() } - fn remove_nodes( + pub fn remove_nodes( &mut self, node_index: Vec, ) -> PyResult> { @@ -258,7 +263,7 @@ impl PyMedRecord { .collect() } - fn replace_node_attributes( + pub fn replace_node_attributes( &mut self, node_index: Vec, attributes: PyAttributes, @@ -277,7 +282,7 @@ impl PyMedRecord { Ok(()) } - fn update_node_attribute( + pub fn update_node_attribute( &mut self, node_index: Vec, attribute: PyMedRecordAttribute, @@ -301,7 +306,7 @@ impl PyMedRecord { Ok(()) } - fn remove_node_attribute( + pub fn remove_node_attribute( &mut self, node_index: Vec, attribute: PyMedRecordAttribute, @@ -325,14 +330,14 @@ impl PyMedRecord { Ok(()) } - fn add_nodes(&mut self, nodes: Vec<(PyNodeIndex, PyAttributes)>) -> PyResult<()> { + pub fn add_nodes(&mut self, nodes: Vec<(PyNodeIndex, PyAttributes)>) -> PyResult<()> { Ok(self .0 .add_nodes(nodes.deep_into()) .map_err(PyMedRecordError::from)?) } - fn add_nodes_dataframes( + pub fn add_nodes_dataframes( &mut self, nodes_dataframes: Vec<(PyDataFrame, String)>, ) -> PyResult<()> { @@ -342,7 +347,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn remove_edges( + pub fn remove_edges( &mut self, edge_index: Vec, ) -> PyResult> { @@ -359,7 +364,7 @@ impl PyMedRecord { .collect() } - fn replace_edge_attributes( + pub fn replace_edge_attributes( &mut self, edge_index: Vec, attributes: PyAttributes, @@ -378,7 +383,7 @@ impl PyMedRecord { Ok(()) } - fn update_edge_attribute( + pub fn update_edge_attribute( &mut self, edge_index: Vec, attribute: PyMedRecordAttribute, @@ -399,7 +404,7 @@ impl PyMedRecord { Ok(()) } - fn remove_edge_attribute( + pub fn remove_edge_attribute( &mut self, edge_index: Vec, attribute: PyMedRecordAttribute, @@ -421,7 +426,7 @@ impl PyMedRecord { Ok(()) } - fn add_edges( + pub fn add_edges( &mut self, relations: Vec<(PyNodeIndex, PyNodeIndex, PyAttributes)>, ) -> PyResult> { @@ -431,7 +436,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn add_edges_dataframes( + pub fn add_edges_dataframes( &mut self, edges_dataframes: Vec<(PyDataFrame, String, String)>, ) -> PyResult> { @@ -441,7 +446,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn add_group( + pub fn add_group( &mut self, group: PyGroup, node_indices_to_add: Option>, @@ -457,7 +462,7 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } - fn remove_groups(&mut self, group: Vec) -> PyResult<()> { + pub fn remove_groups(&mut self, group: Vec) -> PyResult<()> { group.into_iter().try_for_each(|group| { self.0 .remove_group(&group) @@ -467,7 +472,11 @@ impl PyMedRecord { }) } - fn add_nodes_to_group(&mut self, group: PyGroup, node_index: Vec) -> PyResult<()> { + pub fn add_nodes_to_group( + &mut self, + group: PyGroup, + node_index: Vec, + ) -> PyResult<()> { node_index.into_iter().try_for_each(|node_index| { Ok(self .0 @@ -476,7 +485,11 @@ impl PyMedRecord { }) } - fn add_edges_to_group(&mut self, group: PyGroup, edge_index: Vec) -> PyResult<()> { + pub fn add_edges_to_group( + &mut self, + group: PyGroup, + edge_index: Vec, + ) -> PyResult<()> { edge_index.into_iter().try_for_each(|edge_index| { Ok(self .0 @@ -485,7 +498,7 @@ impl PyMedRecord { }) } - fn remove_nodes_from_group( + pub fn remove_nodes_from_group( &mut self, group: PyGroup, node_index: Vec, @@ -498,7 +511,7 @@ impl PyMedRecord { }) } - fn remove_edges_from_group( + pub fn remove_edges_from_group( &mut self, group: PyGroup, edge_index: Vec, @@ -511,7 +524,10 @@ impl PyMedRecord { }) } - fn nodes_in_group(&self, group: Vec) -> PyResult>> { + pub fn nodes_in_group( + &self, + group: Vec, + ) -> PyResult>> { group .into_iter() .map(|group| { @@ -527,7 +543,10 @@ impl PyMedRecord { .collect() } - fn edges_in_group(&self, group: Vec) -> PyResult>> { + pub fn edges_in_group( + &self, + group: Vec, + ) -> PyResult>> { group .into_iter() .map(|group| { @@ -543,7 +562,7 @@ impl PyMedRecord { .collect() } - fn groups_of_node( + pub fn groups_of_node( &self, node_index: Vec, ) -> PyResult>> { @@ -562,7 +581,7 @@ impl PyMedRecord { .collect() } - fn groups_of_edge( + pub fn groups_of_edge( &self, edge_index: Vec, ) -> PyResult>> { @@ -581,31 +600,31 @@ impl PyMedRecord { .collect() } - fn node_count(&self) -> usize { + pub fn node_count(&self) -> usize { self.0.node_count() } - fn edge_count(&self) -> usize { + pub fn edge_count(&self) -> usize { self.0.edge_count() } - fn group_count(&self) -> usize { + pub fn group_count(&self) -> usize { self.0.group_count() } - fn contains_node(&self, node_index: PyNodeIndex) -> bool { + pub fn contains_node(&self, node_index: PyNodeIndex) -> bool { self.0.contains_node(&node_index.into()) } - fn contains_edge(&self, edge_index: EdgeIndex) -> bool { + pub fn contains_edge(&self, edge_index: EdgeIndex) -> bool { self.0.contains_edge(&edge_index) } - fn contains_group(&self, group: PyGroup) -> bool { + pub fn contains_group(&self, group: PyGroup) -> bool { self.0.contains_group(&group.into()) } - fn neighbors( + pub fn neighbors( &self, node_indices: Vec, ) -> PyResult>> { @@ -624,7 +643,7 @@ impl PyMedRecord { .collect() } - fn neighbors_undirected( + pub fn neighbors_undirected( &self, node_indices: Vec, ) -> PyResult>> { @@ -643,27 +662,39 @@ impl PyMedRecord { .collect() } - fn clear(&mut self) { + pub fn clear(&mut self) { self.0.clear(); } - fn select_nodes(&self, operation: PyNodeOperation) -> Vec { - self.0 - .select_nodes(operation.into()) + pub fn select_nodes(&self, query: &Bound<'_, PyFunction>) -> PyResult> { + Ok(self + .0 + .select_nodes(|node| { + query + .call1((PyNodeOperand::from(node.clone()),)) + .expect("Call must succeed"); + }) .iter() - .map(|index| index.clone().into()) - .collect() + .map_err(PyMedRecordError::from)? + .map(|node_index| node_index.clone().into()) + .collect()) } - fn select_edges(&self, operation: PyEdgeOperation) -> Vec { - self.0 - .select_edges(operation.into()) + pub fn select_edges(&self, query: &Bound<'_, PyFunction>) -> PyResult> { + Ok(self + .0 + .select_edges(|edge| { + query + .call1((PyEdgeOperand::from(edge.clone()),)) + .expect("Call must succeed"); + }) .iter() + .map_err(PyMedRecordError::from)? .copied() - .collect() + .collect()) } - fn clone(&self) -> Self { + pub fn clone(&self) -> Self { Self(self.0.clone()) } } diff --git a/rustmodels/src/medrecord/querying.rs b/rustmodels/src/medrecord/querying.rs deleted file mode 100644 index b7965517..00000000 --- a/rustmodels/src/medrecord/querying.rs +++ /dev/null @@ -1,732 +0,0 @@ -use super::{attribute::PyMedRecordAttribute, value::PyMedRecordValue, Lut}; -use crate::{ - gil_hash_map::GILHashMap, - medrecord::{ - errors::PyMedRecordError, value::convert_pyobject_to_medrecordvalue, PyGroup, PyNodeIndex, - }, -}; -use medmodels_core::{ - errors::MedRecordError, - medrecord::{ - ArithmeticOperation, EdgeAttributeOperand, EdgeIndex, EdgeIndexOperand, EdgeOperand, - EdgeOperation, MedRecordAttribute, MedRecordValue, NodeAttributeOperand, NodeIndexOperand, - NodeOperand, NodeOperation, TransformationOperation, ValueOperand, - }, -}; -use pyo3::{ - pyclass, pymethods, types::PyAnyMethods, Bound, FromPyObject, IntoPy, PyAny, PyObject, - PyResult, Python, -}; -use std::ops::Range; - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyValueArithmeticOperation(ArithmeticOperation, MedRecordAttribute, MedRecordValue); - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyValueTransformationOperation(TransformationOperation, MedRecordAttribute); - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyValueSliceOperation(MedRecordAttribute, Range); - -#[repr(transparent)] -#[derive(Clone, Debug)] -pub(crate) struct PyValueOperand(ValueOperand); - -impl From for PyValueOperand { - fn from(value: ValueOperand) -> Self { - PyValueOperand(value) - } -} - -impl From for ValueOperand { - fn from(value: PyValueOperand) -> Self { - value.0 - } -} - -static PYVALUEOPERAND_CONVERSION_LUT: Lut = GILHashMap::new(); - -fn convert_pyobject_to_valueoperand(ob: &Bound<'_, PyAny>) -> PyResult { - if let Ok(value) = convert_pyobject_to_medrecordvalue(ob) { - return Ok(ValueOperand::Value(value)); - }; - - fn convert_node_attribute_operand(ob: &Bound<'_, PyAny>) -> PyResult { - Ok(ValueOperand::Evaluate(MedRecordAttribute::from( - ob.extract::()?.0, - ))) - } - - fn convert_edge_attribute_operand(ob: &Bound<'_, PyAny>) -> PyResult { - Ok(ValueOperand::Evaluate(MedRecordAttribute::from( - ob.extract::()?.0, - ))) - } - - fn convert_arithmetic_operation(ob: &Bound<'_, PyAny>) -> PyResult { - let operation = ob.extract::()?; - - Ok(ValueOperand::ArithmeticOperation( - operation.0, - operation.1, - operation.2, - )) - } - - fn convert_transformation_operation(ob: &Bound<'_, PyAny>) -> PyResult { - let operation = ob.extract::()?; - - Ok(ValueOperand::TransformationOperation( - operation.0, - operation.1, - )) - } - - fn convert_slice_operation(ob: &Bound<'_, PyAny>) -> PyResult { - let operation = ob.extract::()?; - - Ok(ValueOperand::Slice(operation.0, operation.1)) - } - - fn throw_error(ob: &Bound<'_, PyAny>) -> PyResult { - Err( - PyMedRecordError::from(MedRecordError::ConversionError(format!( - "Failed to convert {} into ValueOperand", - ob, - ))) - .into(), - ) - } - - let type_pointer = ob.get_type_ptr() as usize; - - Python::with_gil(|py| { - PYVALUEOPERAND_CONVERSION_LUT.map(py, |lut| { - let conversion_function = lut.entry(type_pointer).or_insert_with(|| { - if ob.is_instance_of::() { - convert_node_attribute_operand - } else if ob.is_instance_of::() { - convert_edge_attribute_operand - } else if ob.is_instance_of::() { - convert_arithmetic_operation - } else if ob.is_instance_of::() { - convert_transformation_operation - } else if ob.is_instance_of::() { - convert_slice_operation - } else { - throw_error - } - }); - - conversion_function(ob) - }) - }) -} - -impl<'a> FromPyObject<'a> for PyValueOperand { - fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { - convert_pyobject_to_valueoperand(ob).map(PyValueOperand::from) - } -} - -impl IntoPy for PyValueOperand { - fn into_py(self, py: pyo3::prelude::Python<'_>) -> PyObject { - match self.0 { - ValueOperand::Value(value) => PyMedRecordValue::from(value).into_py(py), - ValueOperand::Evaluate(attribute) => PyMedRecordAttribute::from(attribute).into_py(py), - ValueOperand::ArithmeticOperation(operation, attribute, value) => { - PyValueArithmeticOperation(operation, attribute, value).into_py(py) - } - ValueOperand::TransformationOperation(operation, attribute) => { - PyValueTransformationOperation(operation, attribute).into_py(py) - } - ValueOperand::Slice(attribute, range) => { - PyValueSliceOperation(attribute, range).into_py(py) - } - } - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeOperation(NodeOperation); - -impl From for PyNodeOperation { - fn from(value: NodeOperation) -> Self { - PyNodeOperation(value) - } -} - -impl From for NodeOperation { - fn from(value: PyNodeOperation) -> Self { - value.0 - } -} - -#[pymethods] -impl PyNodeOperation { - fn logical_and(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.and(operation.into()).into() - } - - fn logical_or(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.or(operation.into()).into() - } - - fn logical_xor(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.xor(operation.into()).into() - } - - fn logical_not(&self) -> PyNodeOperation { - self.clone().0.not().into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeOperation(EdgeOperation); - -impl From for PyEdgeOperation { - fn from(value: EdgeOperation) -> Self { - PyEdgeOperation(value) - } -} - -impl From for EdgeOperation { - fn from(value: PyEdgeOperation) -> Self { - value.0 - } -} - -#[pymethods] -impl PyEdgeOperation { - fn logical_and(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone().0.and(operation.into()).into() - } - - fn logical_or(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone().0.or(operation.into()).into() - } - - fn logical_xor(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone().0.xor(operation.into()).into() - } - - fn logical_not(&self) -> PyEdgeOperation { - self.clone().0.not().into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeAttributeOperand(pub NodeAttributeOperand); - -impl From for PyNodeAttributeOperand { - fn from(value: NodeAttributeOperand) -> Self { - PyNodeAttributeOperand(value) - } -} - -impl From for NodeAttributeOperand { - fn from(value: PyNodeAttributeOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyNodeAttributeOperand { - fn greater(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.greater(ValueOperand::from(operand)).into() - } - fn less(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.less(ValueOperand::from(operand)).into() - } - fn greater_or_equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone() - .0 - .greater_or_equal(ValueOperand::from(operand)) - .into() - } - fn less_or_equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone() - .0 - .less_or_equal(ValueOperand::from(operand)) - .into() - } - - fn equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.equal(ValueOperand::from(operand)).into() - } - fn not_equal(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.not_equal(ValueOperand::from(operand)).into() - } - - fn is_in(&self, operands: Vec) -> PyNodeOperation { - self.clone().0.r#in(operands).into() - } - fn not_in(&self, operands: Vec) -> PyNodeOperation { - self.clone().0.not_in(operands).into() - } - - fn starts_with(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone() - .0 - .starts_with(ValueOperand::from(operand)) - .into() - } - - fn ends_with(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.ends_with(ValueOperand::from(operand)).into() - } - - fn contains(&self, operand: PyValueOperand) -> PyNodeOperation { - self.clone().0.contains(ValueOperand::from(operand)).into() - } - - fn add(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.add(value).into() - } - - fn sub(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.sub(value).into() - } - - fn mul(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.mul(value).into() - } - - fn div(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.div(value).into() - } - - fn pow(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.pow(value).into() - } - - fn r#mod(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.r#mod(value).into() - } - - fn round(&self) -> PyValueOperand { - self.clone().0.round().into() - } - - fn ceil(&self) -> PyValueOperand { - self.clone().0.ceil().into() - } - - fn floor(&self) -> PyValueOperand { - self.clone().0.floor().into() - } - - fn abs(&self) -> PyValueOperand { - self.clone().0.abs().into() - } - - fn sqrt(&self) -> PyValueOperand { - self.clone().0.sqrt().into() - } - - fn trim(&self) -> PyValueOperand { - self.clone().0.trim().into() - } - - fn trim_start(&self) -> PyValueOperand { - self.clone().0.trim_start().into() - } - - fn trim_end(&self) -> PyValueOperand { - self.clone().0.trim_end().into() - } - - fn lowercase(&self) -> PyValueOperand { - self.clone().0.lowercase().into() - } - - fn uppercase(&self) -> PyValueOperand { - self.clone().0.uppercase().into() - } - - fn slice(&self, start: usize, end: usize) -> PyResult { - Ok(self.clone().0.slice(Range { start, end }).into()) - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeAttributeOperand(EdgeAttributeOperand); - -impl From for PyEdgeAttributeOperand { - fn from(value: EdgeAttributeOperand) -> Self { - PyEdgeAttributeOperand(value) - } -} - -impl From for EdgeAttributeOperand { - fn from(value: PyEdgeAttributeOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyEdgeAttributeOperand { - fn greater(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.greater(ValueOperand::from(operand)).into() - } - fn less(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.less(ValueOperand::from(operand)).into() - } - fn greater_or_equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone() - .0 - .greater_or_equal(ValueOperand::from(operand)) - .into() - } - fn less_or_equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone() - .0 - .less_or_equal(ValueOperand::from(operand)) - .into() - } - - fn equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.equal(ValueOperand::from(operand)).into() - } - fn not_equal(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.not_equal(ValueOperand::from(operand)).into() - } - - fn is_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.r#in(operand).into() - } - fn not_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.not_in(operand).into() - } - - fn starts_with(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone() - .0 - .starts_with(ValueOperand::from(operand)) - .into() - } - - fn ends_with(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.ends_with(ValueOperand::from(operand)).into() - } - - fn contains(&self, operand: PyValueOperand) -> PyEdgeOperation { - self.clone().0.contains(ValueOperand::from(operand)).into() - } - - fn add(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.add(value).into() - } - - fn sub(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.sub(value).into() - } - - fn mul(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.mul(value).into() - } - - fn div(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.div(value).into() - } - - fn pow(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.pow(value).into() - } - - fn r#mod(&self, value: PyMedRecordValue) -> PyValueOperand { - self.clone().0.r#mod(value).into() - } - - fn round(&self) -> PyValueOperand { - self.clone().0.round().into() - } - - fn ceil(&self) -> PyValueOperand { - self.clone().0.ceil().into() - } - - fn floor(&self) -> PyValueOperand { - self.clone().0.floor().into() - } - - fn abs(&self) -> PyValueOperand { - self.clone().0.abs().into() - } - - fn sqrt(&self) -> PyValueOperand { - self.clone().0.sqrt().into() - } - - fn trim(&self) -> PyValueOperand { - self.clone().0.trim().into() - } - - fn trim_start(&self) -> PyValueOperand { - self.clone().0.trim_start().into() - } - - fn trim_end(&self) -> PyValueOperand { - self.clone().0.trim_end().into() - } - - fn lowercase(&self) -> PyValueOperand { - self.clone().0.lowercase().into() - } - - fn uppercase(&self) -> PyValueOperand { - self.clone().0.uppercase().into() - } - - fn slice(&self, start: usize, end: usize) -> PyResult { - Ok(self.clone().0.slice(Range { start, end }).into()) - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeIndexOperand(NodeIndexOperand); - -impl From for PyNodeIndexOperand { - fn from(value: NodeIndexOperand) -> Self { - PyNodeIndexOperand(value) - } -} - -impl From for NodeIndexOperand { - fn from(value: PyNodeIndexOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyNodeIndexOperand { - fn greater(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.greater(operand).into() - } - fn less(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.less(operand).into() - } - fn greater_or_equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.greater_or_equal(operand).into() - } - fn less_or_equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.less_or_equal(operand).into() - } - - fn equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.equal(operand).into() - } - fn not_equal(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.not_equal(operand).into() - } - - fn is_in(&self, operand: Vec) -> PyNodeOperation { - self.clone().0.r#in(operand).into() - } - fn not_in(&self, operand: Vec) -> PyNodeOperation { - self.clone().0.not_in(operand).into() - } - - fn starts_with(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.starts_with(operand).into() - } - - fn ends_with(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.ends_with(operand).into() - } - - fn contains(&self, operand: PyNodeIndex) -> PyNodeOperation { - self.clone().0.contains(operand).into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeIndexOperand(EdgeIndexOperand); - -impl From for PyEdgeIndexOperand { - fn from(value: EdgeIndexOperand) -> Self { - PyEdgeIndexOperand(value) - } -} - -impl From for EdgeIndexOperand { - fn from(value: PyEdgeIndexOperand) -> Self { - value.0 - } -} - -#[pymethods] -impl PyEdgeIndexOperand { - fn greater(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.greater(operand).into() - } - fn less(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.less(operand).into() - } - fn greater_or_equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.greater_or_equal(operand).into() - } - fn less_or_equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.less_or_equal(operand).into() - } - - fn equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.equal(operand).into() - } - fn not_equal(&self, operand: EdgeIndex) -> PyEdgeOperation { - self.clone().0.not_equal(operand).into() - } - - fn is_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.r#in(operand).into() - } - fn not_in(&self, operand: Vec) -> PyEdgeOperation { - self.clone().0.not_in(operand).into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyNodeOperand(NodeOperand); - -#[pymethods] -impl PyNodeOperand { - #[new] - fn new() -> Self { - Self(NodeOperand) - } - - fn in_group(&self, operand: PyGroup) -> PyNodeOperation { - self.clone().0.in_group(operand).into() - } - - fn has_attribute(&self, operand: PyMedRecordAttribute) -> PyNodeOperation { - self.clone().0.has_attribute(operand).into() - } - - fn has_outgoing_edge_with(&self, operation: PyEdgeOperation) -> PyNodeOperation { - self.clone() - .0 - .has_outgoing_edge_with(operation.into()) - .into() - } - fn has_incoming_edge_with(&self, operation: PyEdgeOperation) -> PyNodeOperation { - self.clone() - .0 - .has_incoming_edge_with(operation.into()) - .into() - } - fn has_edge_with(&self, operation: PyEdgeOperation) -> PyNodeOperation { - self.clone().0.has_edge_with(operation.into()).into() - } - - fn has_neighbor_with(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone().0.has_neighbor_with(operation.into()).into() - } - fn has_neighbor_undirected_with(&self, operation: PyNodeOperation) -> PyNodeOperation { - self.clone() - .0 - .has_neighbor_undirected_with(operation.into()) - .into() - } - - fn attribute(&self, attribute: PyMedRecordAttribute) -> PyNodeAttributeOperand { - self.clone().0.attribute(attribute).into() - } - - fn index(&self) -> PyNodeIndexOperand { - self.clone().0.index().into() - } -} - -#[pyclass] -#[repr(transparent)] -#[derive(Clone, Debug)] -pub struct PyEdgeOperand(EdgeOperand); - -#[pymethods] -impl PyEdgeOperand { - #[new] - fn new() -> Self { - Self(EdgeOperand) - } - - fn connected_target(&self, operand: PyNodeIndex) -> PyEdgeOperation { - self.clone().0.connected_target(operand).into() - } - - fn connected_source(&self, operand: PyNodeIndex) -> PyEdgeOperation { - self.clone().0.connected_source(operand).into() - } - - fn connected(&self, operand: PyNodeIndex) -> PyEdgeOperation { - self.clone().0.connected(operand).into() - } - - fn in_group(&self, operand: PyGroup) -> PyEdgeOperation { - self.clone().0.in_group(operand).into() - } - - fn has_attribute(&self, operand: PyMedRecordAttribute) -> PyEdgeOperation { - self.clone().0.has_attribute(operand).into() - } - - fn connected_source_with(&self, operation: PyNodeOperation) -> PyEdgeOperation { - self.clone() - .0 - .connected_source_with(operation.into()) - .into() - } - - fn connected_target_with(&self, operation: PyNodeOperation) -> PyEdgeOperation { - self.clone() - .0 - .connected_target_with(operation.into()) - .into() - } - - fn connected_with(&self, operation: PyNodeOperation) -> PyEdgeOperation { - self.clone().0.connected_with(operation.into()).into() - } - - fn has_parallel_edges_with(&self, operation: PyEdgeOperation) -> PyEdgeOperation { - self.clone() - .0 - .has_parallel_edges_with(operation.into()) - .into() - } - - fn has_parallel_edges_with_self_comparison( - &self, - operation: PyEdgeOperation, - ) -> PyEdgeOperation { - self.clone() - .0 - .has_parallel_edges_with_self_comparison(operation.into()) - .into() - } - - fn attribute(&self, attribute: PyMedRecordAttribute) -> PyEdgeAttributeOperand { - self.clone().0.attribute(attribute).into() - } - - fn index(&self) -> PyEdgeIndexOperand { - self.clone().0.index().into() - } -} diff --git a/rustmodels/src/medrecord/querying/attributes.rs b/rustmodels/src/medrecord/querying/attributes.rs new file mode 100644 index 00000000..4ca611db --- /dev/null +++ b/rustmodels/src/medrecord/querying/attributes.rs @@ -0,0 +1,565 @@ +use super::values::PyMultipleValuesOperand; +use crate::medrecord::{attribute::PyMedRecordAttribute, errors::PyMedRecordError}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + AttributesTreeOperand, DeepClone, MedRecordAttribute, MultipleAttributesComparisonOperand, + MultipleAttributesOperand, SingleAttributeComparisonOperand, SingleAttributeOperand, + Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[repr(transparent)] +pub struct PySingleAttributeComparisonOperand(SingleAttributeComparisonOperand); + +impl From for PySingleAttributeComparisonOperand { + fn from(operand: SingleAttributeComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for SingleAttributeComparisonOperand { + fn from(operand: PySingleAttributeComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PySingleAttributeComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(attribute) = ob.extract::() { + Ok(SingleAttributeComparisonOperand::Attribute(attribute.into()).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PySingleAttributeComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into MedRecordValue or SingleValueOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyMultipleAttributesComparisonOperand(MultipleAttributesComparisonOperand); + +impl From for PyMultipleAttributesComparisonOperand { + fn from(operand: MultipleAttributesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for MultipleAttributesComparisonOperand { + fn from(operand: PyMultipleAttributesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyMultipleAttributesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(values) = ob.extract::>() { + Ok(MultipleAttributesComparisonOperand::Attributes( + values.into_iter().map(MedRecordAttribute::from).collect(), + ) + .into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyMultipleAttributesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[MedRecordAttribute] or MultipleAttributesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +pub struct PyAttributesTreeOperand(Wrapper); + +impl From> for PyAttributesTreeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyAttributesTreeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyAttributesTreeOperand { + pub fn max(&self) -> PyMultipleAttributesOperand { + self.0.max().into() + } + + pub fn min(&self) -> PyMultipleAttributesOperand { + self.0.min().into() + } + + pub fn count(&self) -> PyMultipleAttributesOperand { + self.0.count().into() + } + + pub fn sum(&self) -> PyMultipleAttributesOperand { + self.0.sum().into() + } + + pub fn first(&self) -> PyMultipleAttributesOperand { + self.0.first().into() + } + + pub fn last(&self) -> PyMultipleAttributesOperand { + self.0.last().into() + } + + pub fn greater_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than(attribute); + } + + pub fn greater_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than_or_equal_to(attribute); + } + + pub fn less_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than(attribute); + } + + pub fn less_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than_or_equal_to(attribute); + } + + pub fn equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.equal_to(attribute); + } + + pub fn not_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.not_equal_to(attribute); + } + + pub fn starts_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.starts_with(attribute); + } + + pub fn ends_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.ends_with(attribute); + } + + pub fn contains(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.contains(attribute); + } + + pub fn is_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_in(attributes); + } + + pub fn is_not_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_not_in(attributes); + } + + pub fn add(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.add(attribute); + } + + pub fn sub(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.sub(attribute); + } + + pub fn mul(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.mul(attribute); + } + + pub fn pow(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.pow(attribute); + } + + pub fn r#mod(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.r#mod(attribute); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_max(&self) { + self.0.is_max(); + } + + pub fn is_min(&self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyAttributesTreeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyAttributesTreeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PyAttributesTreeOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyMultipleAttributesOperand(Wrapper); + +impl From> for PyMultipleAttributesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyMultipleAttributesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyMultipleAttributesOperand { + pub fn max(&self) -> PySingleAttributeOperand { + self.0.max().into() + } + + pub fn min(&self) -> PySingleAttributeOperand { + self.0.min().into() + } + + pub fn count(&self) -> PySingleAttributeOperand { + self.0.count().into() + } + + pub fn sum(&self) -> PySingleAttributeOperand { + self.0.sum().into() + } + + pub fn first(&self) -> PySingleAttributeOperand { + self.0.first().into() + } + + pub fn last(&self) -> PySingleAttributeOperand { + self.0.last().into() + } + + pub fn greater_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than(attribute); + } + + pub fn greater_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than_or_equal_to(attribute); + } + + pub fn less_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than(attribute); + } + + pub fn less_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than_or_equal_to(attribute); + } + + pub fn equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.equal_to(attribute); + } + + pub fn not_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.not_equal_to(attribute); + } + + pub fn starts_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.starts_with(attribute); + } + + pub fn ends_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.ends_with(attribute); + } + + pub fn contains(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.contains(attribute); + } + + pub fn is_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_in(attributes); + } + + pub fn is_not_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_not_in(attributes); + } + + pub fn add(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.add(attribute); + } + + pub fn sub(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.sub(attribute); + } + + pub fn mul(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.mul(attribute); + } + + pub fn pow(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.pow(attribute); + } + + pub fn r#mod(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.r#mod(attribute); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn to_values(&self) -> PyMultipleValuesOperand { + self.0.to_values().into() + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_max(&self) { + self.0.is_max(); + } + + pub fn is_min(&self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyMultipleAttributesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyMultipleAttributesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PyMultipleAttributesOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PySingleAttributeOperand(Wrapper); + +impl From> for PySingleAttributeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PySingleAttributeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PySingleAttributeOperand { + pub fn greater_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than(attribute); + } + + pub fn greater_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.greater_than_or_equal_to(attribute); + } + + pub fn less_than(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than(attribute); + } + + pub fn less_than_or_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.less_than_or_equal_to(attribute); + } + + pub fn equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.equal_to(attribute); + } + + pub fn not_equal_to(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.not_equal_to(attribute); + } + + pub fn starts_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.starts_with(attribute); + } + + pub fn ends_with(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.ends_with(attribute); + } + + pub fn contains(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.contains(attribute); + } + + pub fn is_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_in(attributes); + } + + pub fn is_not_in(&self, attributes: PyMultipleAttributesComparisonOperand) { + self.0.is_not_in(attributes); + } + + pub fn add(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.add(attribute); + } + + pub fn sub(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.sub(attribute); + } + + pub fn mul(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.mul(attribute); + } + + pub fn pow(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.pow(attribute); + } + + pub fn r#mod(&self, attribute: PySingleAttributeComparisonOperand) { + self.0.r#mod(attribute); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PySingleAttributeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PySingleAttributeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PySingleAttributeOperand { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/querying/edges.rs b/rustmodels/src/medrecord/querying/edges.rs new file mode 100644 index 00000000..0eabf281 --- /dev/null +++ b/rustmodels/src/medrecord/querying/edges.rs @@ -0,0 +1,384 @@ +use super::{ + attributes::PyAttributesTreeOperand, nodes::PyNodeOperand, values::PyMultipleValuesOperand, + PyGroupCardinalityWrapper, PyMedRecordAttributeCardinalityWrapper, +}; +use crate::medrecord::{attribute::PyMedRecordAttribute, errors::PyMedRecordError}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + DeepClone, EdgeIndex, EdgeIndexComparisonOperand, EdgeIndexOperand, + EdgeIndicesComparisonOperand, EdgeIndicesOperand, EdgeOperand, Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[pyclass] +#[repr(transparent)] +pub struct PyEdgeOperand(Wrapper); + +impl From> for PyEdgeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyEdgeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyEdgeOperand { + pub fn attribute(&mut self, attribute: PyMedRecordAttribute) -> PyMultipleValuesOperand { + self.0.attribute(attribute).into() + } + + pub fn attributes(&mut self) -> PyAttributesTreeOperand { + self.0.attributes().into() + } + + pub fn index(&mut self) -> PyEdgeIndicesOperand { + self.0.index().into() + } + + pub fn in_group(&mut self, group: PyGroupCardinalityWrapper) { + self.0.in_group(group); + } + + pub fn has_attribute(&mut self, attribute: PyMedRecordAttributeCardinalityWrapper) { + self.0.has_attribute(attribute); + } + + pub fn source_node(&mut self) -> PyNodeOperand { + self.0.source_node().into() + } + + pub fn target_node(&mut self) -> PyNodeOperand { + self.0.target_node().into() + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyEdgeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyEdgeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PyEdgeOperand { + self.0.deep_clone().into() + } +} + +#[repr(transparent)] +pub struct PyEdgeIndexComparisonOperand(EdgeIndexComparisonOperand); + +impl From for PyEdgeIndexComparisonOperand { + fn from(operand: EdgeIndexComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for EdgeIndexComparisonOperand { + fn from(operand: PyEdgeIndexComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyEdgeIndexComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(index) = ob.extract::() { + Ok(EdgeIndexComparisonOperand::Index(index).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyEdgeIndexComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into EdgeIndex or EdgeIndexOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyEdgeIndicesComparisonOperand(EdgeIndicesComparisonOperand); + +impl From for PyEdgeIndicesComparisonOperand { + fn from(operand: EdgeIndicesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for EdgeIndicesComparisonOperand { + fn from(operand: PyEdgeIndicesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyEdgeIndicesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(indices) = ob.extract::>() { + Ok(EdgeIndicesComparisonOperand::Indices(indices).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyEdgeIndicesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[EdgeIndex] or EdgeIndicesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyEdgeIndicesOperand(Wrapper); + +impl From> for PyEdgeIndicesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyEdgeIndicesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyEdgeIndicesOperand { + pub fn max(&mut self) -> PyEdgeIndexOperand { + self.0.max().into() + } + + pub fn min(&mut self) -> PyEdgeIndexOperand { + self.0.min().into() + } + + pub fn count(&mut self) -> PyEdgeIndexOperand { + self.0.count().into() + } + + pub fn sum(&mut self) -> PyEdgeIndexOperand { + self.0.sum().into() + } + + pub fn first(&mut self) -> PyEdgeIndexOperand { + self.0.first().into() + } + + pub fn last(&mut self) -> PyEdgeIndexOperand { + self.0.last().into() + } + + pub fn greater_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn is_max(&mut self) { + self.0.is_max() + } + + pub fn is_min(&mut self) { + self.0.is_min() + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyEdgeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyEdgeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PyEdgeIndicesOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyEdgeIndexOperand(Wrapper); + +impl From> for PyEdgeIndexOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyEdgeIndexOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyEdgeIndexOperand { + pub fn greater_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyEdgeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyEdgeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyEdgeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyEdgeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PyEdgeIndexOperand { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/querying/mod.rs b/rustmodels/src/medrecord/querying/mod.rs new file mode 100644 index 00000000..cfd7b868 --- /dev/null +++ b/rustmodels/src/medrecord/querying/mod.rs @@ -0,0 +1,52 @@ +pub mod attributes; +pub mod edges; +pub mod nodes; +pub mod values; + +use super::{attribute::PyMedRecordAttribute, errors::PyMedRecordError}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{CardinalityWrapper, MedRecordAttribute}, +}; +use pyo3::{types::PyAnyMethods, Bound, FromPyObject, PyAny, PyResult}; + +#[repr(transparent)] +pub struct PyMedRecordAttributeCardinalityWrapper(CardinalityWrapper); + +impl From> for PyMedRecordAttributeCardinalityWrapper { + fn from(attribute: CardinalityWrapper) -> Self { + Self(attribute) + } +} + +impl From for CardinalityWrapper { + fn from(attribute: PyMedRecordAttributeCardinalityWrapper) -> Self { + attribute.0 + } +} + +impl<'a> FromPyObject<'a> for PyMedRecordAttributeCardinalityWrapper { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(attribute) = ob.extract::() { + Ok(CardinalityWrapper::Single(MedRecordAttribute::from(attribute)).into()) + } else if let Ok(attributes) = ob.extract::>() { + Ok(CardinalityWrapper::Multiple( + attributes + .into_iter() + .map(MedRecordAttribute::from) + .collect(), + ) + .into()) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into MedRecordAttribute or List[MedREcordAttribute]", + ob, + ))) + .into(), + ) + } + } +} + +type PyGroupCardinalityWrapper = PyMedRecordAttributeCardinalityWrapper; diff --git a/rustmodels/src/medrecord/querying/nodes.rs b/rustmodels/src/medrecord/querying/nodes.rs new file mode 100644 index 00000000..9577342e --- /dev/null +++ b/rustmodels/src/medrecord/querying/nodes.rs @@ -0,0 +1,491 @@ +use super::{ + attributes::PyAttributesTreeOperand, edges::PyEdgeOperand, values::PyMultipleValuesOperand, + PyGroupCardinalityWrapper, PyMedRecordAttributeCardinalityWrapper, +}; +use crate::medrecord::{attribute::PyMedRecordAttribute, errors::PyMedRecordError, PyNodeIndex}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + DeepClone, EdgeDirection, NodeIndex, NodeIndexComparisonOperand, NodeIndexOperand, + NodeIndicesComparisonOperand, NodeIndicesOperand, NodeOperand, Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[pyclass] +#[derive(Clone)] +pub enum PyEdgeDirection { + Incoming = 0, + Outgoing = 1, + Both = 2, +} + +impl From for PyEdgeDirection { + fn from(value: EdgeDirection) -> Self { + match value { + EdgeDirection::Incoming => Self::Incoming, + EdgeDirection::Outgoing => Self::Outgoing, + EdgeDirection::Both => Self::Both, + } + } +} + +impl From for EdgeDirection { + fn from(value: PyEdgeDirection) -> Self { + match value { + PyEdgeDirection::Incoming => Self::Incoming, + PyEdgeDirection::Outgoing => Self::Outgoing, + PyEdgeDirection::Both => Self::Both, + } + } +} + +#[pyclass] +#[repr(transparent)] +pub struct PyNodeOperand(Wrapper); + +impl From> for PyNodeOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyNodeOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyNodeOperand { + pub fn attribute(&mut self, attribute: PyMedRecordAttribute) -> PyMultipleValuesOperand { + self.0.attribute(attribute).into() + } + + pub fn attributes(&mut self) -> PyAttributesTreeOperand { + self.0.attributes().into() + } + + pub fn index(&mut self) -> PyNodeIndicesOperand { + self.0.index().into() + } + + pub fn in_group(&mut self, group: PyGroupCardinalityWrapper) { + self.0.in_group(group); + } + + pub fn has_attribute(&mut self, attribute: PyMedRecordAttributeCardinalityWrapper) { + self.0.has_attribute(attribute); + } + + pub fn outgoing_edges(&mut self) -> PyEdgeOperand { + self.0.outgoing_edges().into() + } + + pub fn incoming_edges(&mut self) -> PyEdgeOperand { + self.0.incoming_edges().into() + } + + pub fn neighbors(&mut self, direction: PyEdgeDirection) -> PyNodeOperand { + self.0.neighbors(direction.into()).into() + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyNodeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyNodeOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> Self { + self.0.deep_clone().into() + } +} + +#[repr(transparent)] +pub struct PyNodeIndexComparisonOperand(NodeIndexComparisonOperand); + +impl From for PyNodeIndexComparisonOperand { + fn from(operand: NodeIndexComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for NodeIndexComparisonOperand { + fn from(operand: PyNodeIndexComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyNodeIndexComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(index) = ob.extract::() { + Ok(NodeIndexComparisonOperand::Index(NodeIndex::from(index)).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyNodeIndexComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into NodeIndex or NodeIndexOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyNodeIndicesComparisonOperand(NodeIndicesComparisonOperand); + +impl From for PyNodeIndicesComparisonOperand { + fn from(operand: NodeIndicesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for NodeIndicesComparisonOperand { + fn from(operand: PyNodeIndicesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyNodeIndicesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(indices) = ob.extract::>() { + Ok(NodeIndicesComparisonOperand::Indices( + indices.into_iter().map(NodeIndex::from).collect(), + ) + .into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyNodeIndicesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[NodeIndex] or NodeIndicesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyNodeIndicesOperand(Wrapper); + +impl From> for PyNodeIndicesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyNodeIndicesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyNodeIndicesOperand { + pub fn max(&mut self) -> PyNodeIndexOperand { + self.0.max().into() + } + + pub fn min(&mut self) -> PyNodeIndexOperand { + self.0.min().into() + } + + pub fn count(&mut self) -> PyNodeIndexOperand { + self.0.count().into() + } + + pub fn sum(&mut self) -> PyNodeIndexOperand { + self.0.sum().into() + } + + pub fn first(&mut self) -> PyNodeIndexOperand { + self.0.first().into() + } + + pub fn last(&mut self) -> PyNodeIndexOperand { + self.0.last().into() + } + + pub fn greater_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn abs(&mut self) { + self.0.abs(); + } + + pub fn trim(&mut self) { + self.0.trim(); + } + + pub fn trim_start(&mut self) { + self.0.trim_start(); + } + + pub fn trim_end(&mut self) { + self.0.trim_end(); + } + + pub fn lowercase(&mut self) { + self.0.lowercase(); + } + + pub fn uppercase(&mut self) { + self.0.uppercase(); + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&mut self) { + self.0.is_string(); + } + + pub fn is_int(&mut self) { + self.0.is_int(); + } + + pub fn is_max(&mut self) { + self.0.is_max(); + } + + pub fn is_min(&mut self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyNodeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyNodeIndicesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> Self { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyNodeIndexOperand(Wrapper); + +impl From> for PyNodeIndexOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyNodeIndexOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyNodeIndexOperand { + pub fn greater_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than(index); + } + + pub fn greater_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.greater_than_or_equal_to(index); + } + + pub fn less_than(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than(index); + } + + pub fn less_than_or_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.less_than_or_equal_to(index); + } + + pub fn equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.equal_to(index); + } + + pub fn not_equal_to(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.not_equal_to(index); + } + + pub fn starts_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.starts_with(index); + } + + pub fn ends_with(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.ends_with(index); + } + + pub fn contains(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.contains(index); + } + + pub fn is_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_in(indices); + } + + pub fn is_not_in(&mut self, indices: PyNodeIndicesComparisonOperand) { + self.0.is_not_in(indices); + } + + pub fn add(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.add(index); + } + + pub fn sub(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.sub(index); + } + + pub fn mul(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.mul(index); + } + + pub fn pow(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.pow(index); + } + + pub fn r#mod(&mut self, index: PyNodeIndexComparisonOperand) { + self.0.r#mod(index); + } + + pub fn abs(&mut self) { + self.0.abs(); + } + + pub fn trim(&mut self) { + self.0.trim(); + } + + pub fn trim_start(&mut self) { + self.0.trim_start(); + } + + pub fn trim_end(&mut self) { + self.0.trim_end(); + } + + pub fn lowercase(&mut self) { + self.0.lowercase(); + } + + pub fn uppercase(&mut self) { + self.0.uppercase(); + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&mut self) { + self.0.is_string(); + } + + pub fn is_int(&mut self) { + self.0.is_int(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyNodeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyNodeIndexOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> Self { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/querying/values.rs b/rustmodels/src/medrecord/querying/values.rs new file mode 100644 index 00000000..31a7f802 --- /dev/null +++ b/rustmodels/src/medrecord/querying/values.rs @@ -0,0 +1,482 @@ +use crate::medrecord::{errors::PyMedRecordError, value::PyMedRecordValue}; +use medmodels_core::{ + errors::MedRecordError, + medrecord::{ + DeepClone, MedRecordValue, MultipleValuesComparisonOperand, MultipleValuesOperand, + SingleValueComparisonOperand, SingleValueOperand, Wrapper, + }, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyFunction}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[repr(transparent)] +pub struct PySingleValueComparisonOperand(SingleValueComparisonOperand); + +impl From for PySingleValueComparisonOperand { + fn from(operand: SingleValueComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for SingleValueComparisonOperand { + fn from(operand: PySingleValueComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PySingleValueComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(value) = ob.extract::() { + Ok(SingleValueComparisonOperand::Value(value.into()).into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PySingleValueComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into MedRecordValue or SingleValueOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[repr(transparent)] +pub struct PyMultipleValuesComparisonOperand(MultipleValuesComparisonOperand); + +impl From for PyMultipleValuesComparisonOperand { + fn from(operand: MultipleValuesComparisonOperand) -> Self { + Self(operand) + } +} + +impl From for MultipleValuesComparisonOperand { + fn from(operand: PyMultipleValuesComparisonOperand) -> Self { + operand.0 + } +} + +impl<'a> FromPyObject<'a> for PyMultipleValuesComparisonOperand { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + if let Ok(values) = ob.extract::>() { + Ok(MultipleValuesComparisonOperand::Values( + values.into_iter().map(MedRecordValue::from).collect(), + ) + .into()) + } else if let Ok(operand) = ob.extract::() { + Ok(PyMultipleValuesComparisonOperand(operand.0.into())) + } else { + Err( + PyMedRecordError::from(MedRecordError::ConversionError(format!( + "Failed to convert {} into List[MedRecordValue] or MultipleValuesOperand", + ob, + ))) + .into(), + ) + } + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyMultipleValuesOperand(Wrapper); + +impl From> for PyMultipleValuesOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PyMultipleValuesOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PyMultipleValuesOperand { + pub fn max(&self) -> PySingleValueOperand { + self.0.max().into() + } + + pub fn min(&self) -> PySingleValueOperand { + self.0.min().into() + } + + pub fn mean(&self) -> PySingleValueOperand { + self.0.mean().into() + } + + pub fn median(&self) -> PySingleValueOperand { + self.0.median().into() + } + + pub fn mode(&self) -> PySingleValueOperand { + self.0.mode().into() + } + + pub fn std(&self) -> PySingleValueOperand { + self.0.std().into() + } + + pub fn var(&self) -> PySingleValueOperand { + self.0.var().into() + } + + pub fn count(&self) -> PySingleValueOperand { + self.0.count().into() + } + + pub fn sum(&self) -> PySingleValueOperand { + self.0.sum().into() + } + + pub fn first(&self) -> PySingleValueOperand { + self.0.first().into() + } + + pub fn last(&self) -> PySingleValueOperand { + self.0.last().into() + } + + pub fn greater_than(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than(value); + } + + pub fn greater_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than_or_equal_to(value); + } + + pub fn less_than(&self, value: PySingleValueComparisonOperand) { + self.0.less_than(value); + } + + pub fn less_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.less_than_or_equal_to(value); + } + + pub fn equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.equal_to(value); + } + + pub fn not_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.not_equal_to(value); + } + + pub fn starts_with(&self, value: PySingleValueComparisonOperand) { + self.0.starts_with(value); + } + + pub fn ends_with(&self, value: PySingleValueComparisonOperand) { + self.0.ends_with(value); + } + + pub fn contains(&self, value: PySingleValueComparisonOperand) { + self.0.contains(value); + } + + pub fn is_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_in(values); + } + + pub fn is_not_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_not_in(values); + } + + pub fn add(&self, value: PySingleValueComparisonOperand) { + self.0.add(value); + } + + pub fn sub(&self, value: PySingleValueComparisonOperand) { + self.0.sub(value); + } + + pub fn mul(&self, value: PySingleValueComparisonOperand) { + self.0.mul(value); + } + + pub fn div(&self, value: PySingleValueComparisonOperand) { + self.0.div(value); + } + + pub fn pow(&self, value: PySingleValueComparisonOperand) { + self.0.pow(value); + } + + pub fn r#mod(&self, value: PySingleValueComparisonOperand) { + self.0.r#mod(value); + } + + pub fn round(&self) { + self.0.round(); + } + + pub fn ceil(&self) { + self.0.ceil(); + } + + pub fn floor(&self) { + self.0.floor(); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn sqrt(&self) { + self.0.sqrt(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_float(&self) { + self.0.is_float(); + } + + pub fn is_bool(&self) { + self.0.is_bool(); + } + + pub fn is_datetime(&self) { + self.0.is_datetime(); + } + + pub fn is_null(&self) { + self.0.is_null(); + } + + pub fn is_max(&self) { + self.0.is_max(); + } + + pub fn is_min(&self) { + self.0.is_min(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PyMultipleValuesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PyMultipleValuesOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PyMultipleValuesOperand { + self.0.deep_clone().into() + } +} + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PySingleValueOperand(Wrapper); + +impl From> for PySingleValueOperand { + fn from(operand: Wrapper) -> Self { + Self(operand) + } +} + +impl From for Wrapper { + fn from(operand: PySingleValueOperand) -> Self { + operand.0 + } +} + +#[pymethods] +impl PySingleValueOperand { + pub fn greater_than(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than(value); + } + + pub fn greater_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.greater_than_or_equal_to(value); + } + + pub fn less_than(&self, value: PySingleValueComparisonOperand) { + self.0.less_than(value); + } + + pub fn less_than_or_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.less_than_or_equal_to(value); + } + + pub fn equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.equal_to(value); + } + + pub fn not_equal_to(&self, value: PySingleValueComparisonOperand) { + self.0.not_equal_to(value); + } + + pub fn starts_with(&self, value: PySingleValueComparisonOperand) { + self.0.starts_with(value); + } + + pub fn ends_with(&self, value: PySingleValueComparisonOperand) { + self.0.ends_with(value); + } + + pub fn contains(&self, value: PySingleValueComparisonOperand) { + self.0.contains(value); + } + + pub fn is_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_in(values); + } + + pub fn is_not_in(&self, values: PyMultipleValuesComparisonOperand) { + self.0.is_not_in(values); + } + + pub fn add(&self, value: PySingleValueComparisonOperand) { + self.0.add(value); + } + + pub fn sub(&self, value: PySingleValueComparisonOperand) { + self.0.sub(value); + } + + pub fn mul(&self, value: PySingleValueComparisonOperand) { + self.0.mul(value); + } + + pub fn div(&self, value: PySingleValueComparisonOperand) { + self.0.div(value); + } + + pub fn pow(&self, value: PySingleValueComparisonOperand) { + self.0.pow(value); + } + + pub fn r#mod(&self, value: PySingleValueComparisonOperand) { + self.0.r#mod(value); + } + + pub fn round(&self) { + self.0.round(); + } + + pub fn ceil(&self) { + self.0.ceil(); + } + + pub fn floor(&self) { + self.0.floor(); + } + + pub fn abs(&self) { + self.0.abs(); + } + + pub fn sqrt(&self) { + self.0.sqrt(); + } + + pub fn trim(&self) { + self.0.trim(); + } + + pub fn trim_start(&self) { + self.0.trim_start(); + } + + pub fn trim_end(&self) { + self.0.trim_end(); + } + + pub fn lowercase(&self) { + self.0.lowercase(); + } + + pub fn uppercase(&self) { + self.0.uppercase(); + } + + pub fn slice(&self, start: usize, end: usize) { + self.0.slice(start, end); + } + + pub fn is_string(&self) { + self.0.is_string(); + } + + pub fn is_int(&self) { + self.0.is_int(); + } + + pub fn is_float(&self) { + self.0.is_float(); + } + + pub fn is_bool(&self) { + self.0.is_bool(); + } + + pub fn is_datetime(&self) { + self.0.is_datetime(); + } + + pub fn is_null(&self) { + self.0.is_null(); + } + + pub fn either_or(&mut self, either: &Bound<'_, PyFunction>, or: &Bound<'_, PyFunction>) { + self.0.either_or( + |operand| { + either + .call1((PySingleValueOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + |operand| { + or.call1((PySingleValueOperand::from(operand.clone()),)) + .expect("Call must succeed"); + }, + ); + } + + pub fn deep_clone(&self) -> PySingleValueOperand { + self.0.deep_clone().into() + } +} diff --git a/rustmodels/src/medrecord/value.rs b/rustmodels/src/medrecord/value.rs index 1ae6b5c7..489e7a97 100644 --- a/rustmodels/src/medrecord/value.rs +++ b/rustmodels/src/medrecord/value.rs @@ -10,7 +10,7 @@ use std::ops::Deref; #[repr(transparent)] #[derive(Clone, Debug)] -pub(crate) struct PyMedRecordValue(MedRecordValue); +pub struct PyMedRecordValue(MedRecordValue); impl From for PyMedRecordValue { fn from(value: MedRecordValue) -> Self { From b8f363c14d7818c9de7c37efcf0b131035771c39 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Fri, 11 Oct 2024 18:09:00 +0200 Subject: [PATCH 8/8] feat: implement python query engine wrapper (#229) --- .../src/medrecord/querying/edges/operation.rs | 2 +- medmodels/_medmodels.pyi | 2 +- medmodels/medrecord/datatype.py | 28 +- medmodels/medrecord/medrecord.py | 48 +- medmodels/medrecord/querying.py | 1869 +++++++++++++++++ medmodels/medrecord/querying.pyi | 430 ---- medmodels/medrecord/schema.py | 18 +- .../tests/test_treatment_effect.py | 4 +- rustmodels/src/lib.rs | 4 +- 9 files changed, 1926 insertions(+), 479 deletions(-) create mode 100644 medmodels/medrecord/querying.py delete mode 100644 medmodels/medrecord/querying.pyi diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs index 0d36db8f..9316ddd7 100644 --- a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs +++ b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs @@ -278,7 +278,7 @@ impl EdgeOperation { .edge_endpoints(edge_index) .expect("Edge must exist"); - node_indices.contains(edge_endpoints.1) + node_indices.contains(edge_endpoints.0) })) } diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index 3bf73e47..b4b4dd06 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -231,7 +231,7 @@ class PyMedRecord: self, query: Callable[[PyNodeOperand], None] ) -> List[NodeIndex]: ... def select_edges( - self, query: Callable[[PyNodeOperand], None] + self, query: Callable[[PyEdgeOperand], None] ) -> List[EdgeIndex]: ... def clone(self) -> PyMedRecord: ... diff --git a/medmodels/medrecord/datatype.py b/medmodels/medrecord/datatype.py index 2dcfdea3..e8b78a6d 100644 --- a/medmodels/medrecord/datatype.py +++ b/medmodels/medrecord/datatype.py @@ -50,7 +50,7 @@ def __repr__(self) -> str: ... def __eq__(self, value: object) -> bool: ... @staticmethod - def _from_pydatatype(datatype: PyDataType) -> DataType: + def _from_py_data_type(datatype: PyDataType) -> DataType: if isinstance(datatype, PyString): return String() elif isinstance(datatype, PyInt): @@ -67,11 +67,11 @@ def _from_pydatatype(datatype: PyDataType) -> DataType: return Any() elif isinstance(datatype, PyUnion): return Union( - DataType._from_pydatatype(datatype.dtype1), - DataType._from_pydatatype(datatype.dtype2), + DataType._from_py_data_type(datatype.dtype1), + DataType._from_py_data_type(datatype.dtype2), ) else: - return Option(DataType._from_pydatatype(datatype.dtype)) + return Option(DataType._from_py_data_type(datatype.dtype)) class String(DataType): @@ -222,18 +222,18 @@ def _inner(self) -> PyDataType: return self._union def __str__(self) -> str: - return f"Union({DataType._from_pydatatype(self._union.dtype1).__str__()}, {DataType._from_pydatatype(self._union.dtype2).__str__()})" + return f"Union({DataType._from_py_data_type(self._union.dtype1).__str__()}, {DataType._from_py_data_type(self._union.dtype2).__str__()})" def __repr__(self) -> str: - return f"DataType.Union({DataType._from_pydatatype(self._union.dtype1).__repr__()}, {DataType._from_pydatatype(self._union.dtype2).__repr__()})" + return f"DataType.Union({DataType._from_py_data_type(self._union.dtype1).__repr__()}, {DataType._from_py_data_type(self._union.dtype2).__repr__()})" def __eq__(self, value: object) -> bool: return ( isinstance(value, Union) - and DataType._from_pydatatype(self._union.dtype1) - == DataType._from_pydatatype(value._union.dtype1) - and DataType._from_pydatatype(self._union.dtype2) - == DataType._from_pydatatype(value._union.dtype2) + and DataType._from_py_data_type(self._union.dtype1) + == DataType._from_py_data_type(value._union.dtype1) + and DataType._from_py_data_type(self._union.dtype2) + == DataType._from_py_data_type(value._union.dtype2) ) @@ -247,12 +247,12 @@ def _inner(self) -> PyDataType: return self._option def __str__(self) -> str: - return f"Option({DataType._from_pydatatype(self._option.dtype).__str__()})" + return f"Option({DataType._from_py_data_type(self._option.dtype).__str__()})" def __repr__(self) -> str: - return f"DataType.Option({DataType._from_pydatatype(self._option.dtype).__repr__()})" + return f"DataType.Option({DataType._from_py_data_type(self._option.dtype).__repr__()})" def __eq__(self, value: object) -> bool: - return isinstance(value, Option) and DataType._from_pydatatype( + return isinstance(value, Option) and DataType._from_py_data_type( self._option.dtype - ) == DataType._from_pydatatype(value._option.dtype) + ) == DataType._from_py_data_type(value._option.dtype) diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index c114ab8a..7d69f235 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Union, overload +from typing import Callable, Dict, List, Optional, Sequence, Union, overload import polars as pl @@ -8,7 +8,7 @@ from medmodels.medrecord._overview import extract_attribute_summary, prettify_table from medmodels.medrecord.builder import MedRecordBuilder from medmodels.medrecord.indexers import EdgeIndexer, NodeIndexer -from medmodels.medrecord.querying import EdgeOperand, EdgeQuery, NodeQuery +from medmodels.medrecord.querying import EdgeOperand, EdgeQuery, NodeOperand, NodeQuery from medmodels.medrecord.schema import Schema from medmodels.medrecord.types import ( AttributeInfo, @@ -300,7 +300,7 @@ def schema(self) -> Schema: Returns: Schema: The schema of the MedRecord. """ - return Schema._from_pyschema(self._medrecord.schema) + return Schema._from_py_schema(self._medrecord.schema) @schema.setter def schema(self, schema: Schema) -> None: @@ -435,7 +435,7 @@ def outgoing_edges( Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: Outgoing edge indices for each specified node. """ - if isinstance(node, NodeQuery): + if isinstance(node, Callable): return self._medrecord.outgoing_edges(self.select_nodes(node)) indices = self._medrecord.outgoing_edges( @@ -472,7 +472,7 @@ def incoming_edges( Union[List[EdgeIndex], Dict[NodeIndex, List[EdgeIndex]]]: Incoming edge indices for each specified node. """ - if isinstance(node, NodeQuery): + if isinstance(node, Callable): return self._medrecord.incoming_edges(self.select_nodes(node)) indices = self._medrecord.incoming_edges( @@ -513,7 +513,7 @@ def edge_endpoints( Tuple of node indices or a dictionary mapping each edge to its node indices. """ - if isinstance(edge, EdgeQuery): + if isinstance(edge, Callable): return self._medrecord.edge_endpoints(self.select_edges(edge)) endpoints = self._medrecord.edge_endpoints( @@ -552,10 +552,10 @@ def edges_connecting( target nodes. """ - if isinstance(source_node, NodeQuery): + if isinstance(source_node, Callable): source_node = self.select_nodes(source_node) - if isinstance(target_node, NodeQuery): + if isinstance(target_node, Callable): target_node = self.select_nodes(target_node) if directed: @@ -594,7 +594,7 @@ def remove_nodes( Union[Attributes, Dict[NodeIndex, Attributes]]: Attributes of the removed node(s). """ - if isinstance(nodes, NodeQuery): + if isinstance(nodes, Callable): return self._medrecord.remove_nodes(self.select_nodes(nodes)) attributes = self._medrecord.remove_nodes( @@ -740,7 +740,7 @@ def remove_edges( Union[Attributes, Dict[EdgeIndex, Attributes]]: Attributes of the removed edge(s). """ - if isinstance(edges, EdgeQuery): + if isinstance(edges, Callable): return self._medrecord.remove_edges(self.select_edges(edges)) attributes = self._medrecord.remove_edges( @@ -884,10 +884,10 @@ def add_group( Returns: None """ - if isinstance(nodes, NodeQuery): + if isinstance(nodes, Callable): nodes = self.select_nodes(nodes) - if isinstance(edges, NodeQuery): + if isinstance(edges, Callable): edges = self.select_edges(edges) if nodes is not None and edges is not None: @@ -933,7 +933,7 @@ def add_nodes_to_group( Returns: None """ - if isinstance(nodes, NodeQuery): + if isinstance(nodes, Callable): return self._medrecord.add_nodes_to_group(group, self.select_nodes(nodes)) return self._medrecord.add_nodes_to_group( @@ -953,7 +953,7 @@ def add_edges_to_group( Returns: None """ - if isinstance(edges, EdgeQuery): + if isinstance(edges, Callable): return self._medrecord.add_edges_to_group(group, self.select_edges(edges)) return self._medrecord.add_edges_to_group( @@ -973,7 +973,7 @@ def remove_nodes_from_group( Returns: None """ - if isinstance(nodes, NodeQuery): + if isinstance(nodes, Callable): return self._medrecord.remove_nodes_from_group( group, self.select_nodes(nodes) ) @@ -995,7 +995,7 @@ def remove_edges_from_group( Returns: None """ - if isinstance(edges, EdgeQuery): + if isinstance(edges, Callable): return self._medrecord.remove_edges_from_group( group, self.select_edges(edges) ) @@ -1091,7 +1091,7 @@ def groups_of_node( Union[List[Group], Dict[NodeIndex, List[Group]]]: Groups associated with each node. """ - if isinstance(node, NodeQuery): + if isinstance(node, Callable): return self._medrecord.groups_of_node(self.select_nodes(node)) groups = self._medrecord.groups_of_node( @@ -1128,7 +1128,7 @@ def groups_of_edge( Union[List[Group], Dict[EdgeIndex, List[Group]]]: Groups associated with each edge. """ - if isinstance(edge, EdgeQuery): + if isinstance(edge, Callable): return self._medrecord.groups_of_edge(self.select_edges(edge)) groups = self._medrecord.groups_of_edge( @@ -1230,7 +1230,7 @@ def neighbors( Returns: Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]: Neighboring nodes. """ - if isinstance(node, NodeQuery): + if isinstance(node, Callable): node = self.select_nodes(node) if directed: @@ -1257,9 +1257,15 @@ def clear(self) -> None: """ return self._medrecord.clear() - def select_nodes(self, query: NodeQuery) -> List[NodeIndex]: ... + def select_nodes(self, query: NodeQuery) -> List[NodeIndex]: + return self._medrecord.select_nodes( + lambda node: query(NodeOperand._from_py_node_operand(node)) + ) - def select_edges(self, query: EdgeQuery) -> List[EdgeIndex]: ... + def select_edges(self, query: EdgeQuery) -> List[EdgeIndex]: + return self._medrecord.select_edges( + lambda edge: query(EdgeOperand._from_py_edge_operand(edge)) + ) def clone(self) -> MedRecord: """Clones the MedRecord instance. diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py new file mode 100644 index 00000000..5794c6b2 --- /dev/null +++ b/medmodels/medrecord/querying.py @@ -0,0 +1,1869 @@ +from __future__ import annotations + +import sys +from enum import Enum +from typing import TYPE_CHECKING, Callable, List, Union + +from medmodels._medmodels import ( + PyAttributesTreeOperand, + PyEdgeDirection, + PyEdgeIndexOperand, + PyEdgeIndicesOperand, + PyEdgeOperand, + PyMultipleAttributesOperand, + PyMultipleValuesOperand, + PyNodeIndexOperand, + PyNodeIndicesOperand, + PyNodeOperand, + PySingleAttributeOperand, + PySingleValueOperand, +) +from medmodels.medrecord.types import ( + EdgeIndex, + Group, + MedRecordAttribute, + MedRecordValue, + NodeIndex, +) + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + +NodeQuery: TypeAlias = Callable[["NodeOperand"], None] +EdgeQuery: TypeAlias = Callable[["EdgeOperand"], None] + +SingleValueComparisonOperand: TypeAlias = Union["SingleValueOperand", MedRecordValue] +SingleValueArithmeticOperand: TypeAlias = SingleValueComparisonOperand +MultipleValuesComparisonOperand: TypeAlias = Union[ + "MultipleValuesOperand", List[MedRecordValue] +] + + +def _py_single_value_comparison_operand_from_single_value_comparison_operand( + single_value_comparison_operand: SingleValueComparisonOperand, +) -> Union[MedRecordValue, PySingleValueOperand]: + if isinstance(single_value_comparison_operand, SingleValueOperand): + return single_value_comparison_operand._single_value_operand + return single_value_comparison_operand + + +def _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + multiple_values_comparison_operand: MultipleValuesComparisonOperand, +) -> Union[List[MedRecordValue], PyMultipleValuesOperand]: + if isinstance(multiple_values_comparison_operand, MultipleValuesOperand): + return multiple_values_comparison_operand._multiple_values_operand + return multiple_values_comparison_operand + + +SingleAttributeComparisonOperand: TypeAlias = Union[ + "SingleAttributeOperand", + MedRecordAttribute, +] +SingleAttributeArithmeticOperand: TypeAlias = SingleAttributeComparisonOperand +MultipleAttributesComparisonOperand: TypeAlias = Union[ + "MultipleAttributesOperand", List[MedRecordAttribute] +] + + +def _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + single_attribute_comparison_operand: SingleAttributeComparisonOperand, +) -> Union[MedRecordAttribute, PySingleAttributeOperand]: + if isinstance(single_attribute_comparison_operand, SingleAttributeOperand): + return single_attribute_comparison_operand._single_attribute_operand + return single_attribute_comparison_operand + + +def _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + multiple_attributes_comparison_operand: MultipleAttributesComparisonOperand, +) -> Union[List[MedRecordAttribute], PyMultipleAttributesOperand]: + if isinstance(multiple_attributes_comparison_operand, MultipleAttributesOperand): + return multiple_attributes_comparison_operand._multiple_attributes_operand + return multiple_attributes_comparison_operand + + +NodeIndexComparisonOperand: TypeAlias = Union["NodeIndexOperand", NodeIndex] +NodeIndexArithmeticOperand: TypeAlias = NodeIndexComparisonOperand +NodeIndicesComparisonOperand: TypeAlias = Union["NodeIndicesOperand", List[NodeIndex]] + + +def _py_node_index_comparison_operand_from_node_index_comparison_operand( + node_index_comparison_operand: NodeIndexComparisonOperand, +) -> Union[NodeIndex, PyNodeIndexOperand]: + if isinstance(node_index_comparison_operand, NodeIndexOperand): + return node_index_comparison_operand._node_index_operand + return node_index_comparison_operand + + +def _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + node_indices_comparison_operand: NodeIndicesComparisonOperand, +) -> Union[List[NodeIndex], PyNodeIndicesOperand]: + if isinstance(node_indices_comparison_operand, NodeIndicesOperand): + return node_indices_comparison_operand._node_indices_operand + return node_indices_comparison_operand + + +EdgeIndexComparisonOperand: TypeAlias = Union[ + "EdgeIndexOperand", + EdgeIndex, +] +EdgeIndexArithmeticOperand: TypeAlias = EdgeIndexComparisonOperand +EdgeIndicesComparisonOperand: TypeAlias = Union[ + "EdgeIndicesOperand", + List[EdgeIndex], +] + + +def _py_edge_index_comparison_operand_from_edge_index_comparison_operand( + edge_index_comparison_operand: EdgeIndexComparisonOperand, +) -> Union[EdgeIndex, PyEdgeIndexOperand]: + if isinstance(edge_index_comparison_operand, EdgeIndexOperand): + return edge_index_comparison_operand._edge_index_operand + return edge_index_comparison_operand + + +def _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + edge_indices_comparison_operand: EdgeIndicesComparisonOperand, +) -> Union[List[EdgeIndex], PyEdgeIndicesOperand]: + if isinstance(edge_indices_comparison_operand, EdgeIndicesOperand): + return edge_indices_comparison_operand._edge_indices_operand + return edge_indices_comparison_operand + + +class EdgeDirection(Enum): + INCOMING = 0 + OUTGOING = 1 + BOTH = 2 + + def _into_py_edge_direction(self) -> PyEdgeDirection: + return ( + PyEdgeDirection.Incoming + if self == EdgeDirection.INCOMING + else PyEdgeDirection.Outgoing + if self == EdgeDirection.OUTGOING + else PyEdgeDirection.Both + ) + + +class NodeOperand: + _node_operand: PyNodeOperand + + def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._node_operand.attribute(attribute) + ) + + def attributes(self) -> AttributesTreeOperand: + return AttributesTreeOperand._from_py_attributes_tree_operand( + self._node_operand.attributes() + ) + + def index(self) -> NodeIndicesOperand: + return NodeIndicesOperand._from_py_node_indices_operand( + self._node_operand.index() + ) + + def in_group(self, group: Union[Group, List[Group]]) -> None: + self._node_operand.in_group(group) + + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: + self._node_operand.has_attribute(attribute) + + def outgoing_edges(self) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand(self._node_operand.outgoing_edges()) + + def incoming_edges(self) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand(self._node_operand.incoming_edges()) + + def neighbors( + self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING + ) -> NodeOperand: + return NodeOperand._from_py_node_operand( + self._node_operand.neighbors(edge_direction._into_py_edge_direction()) + ) + + def either_or(self, either: NodeQuery, or_: NodeQuery) -> None: + self._node_operand.either_or( + lambda node: either(NodeOperand._from_py_node_operand(node)), + lambda node: or_(NodeOperand._from_py_node_operand(node)), + ) + + def clone(self) -> NodeOperand: + return NodeOperand._from_py_node_operand(self._node_operand.deep_clone()) + + @classmethod + def _from_py_node_operand(cls, py_node_operand: PyNodeOperand) -> NodeOperand: + node_operand = cls() + node_operand._node_operand = py_node_operand + return node_operand + + +class EdgeOperand: + _edge_operand: PyEdgeOperand + + def attribute(self, attribute: MedRecordAttribute) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._edge_operand.attribute(attribute) + ) + + def attributes(self) -> AttributesTreeOperand: + return AttributesTreeOperand._from_py_attributes_tree_operand( + self._edge_operand.attributes() + ) + + def index(self) -> EdgeIndicesOperand: + return EdgeIndicesOperand._from_edge_indices_operand(self._edge_operand.index()) + + def in_group(self, group: Union[Group, List[Group]]) -> None: + self._edge_operand.in_group(group) + + def has_attribute( + self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] + ) -> None: + self._edge_operand.has_attribute(attribute) + + def source_node(self) -> NodeOperand: + return NodeOperand._from_py_node_operand(self._edge_operand.source_node()) + + def target_node(self) -> NodeOperand: + return NodeOperand._from_py_node_operand(self._edge_operand.target_node()) + + def either_or(self, either: EdgeQuery, or_: EdgeQuery) -> None: + self._edge_operand.either_or( + lambda edge: either(EdgeOperand._from_py_edge_operand(edge)), + lambda edge: or_(EdgeOperand._from_py_edge_operand(edge)), + ) + + def clone(self) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand(self._edge_operand.deep_clone()) + + @classmethod + def _from_py_edge_operand(cls, py_edge_operand: PyEdgeOperand) -> EdgeOperand: + edge_operand = cls() + edge_operand._edge_operand = py_edge_operand + return edge_operand + + +class MultipleValuesOperand: + _multiple_values_operand: PyMultipleValuesOperand + + def max(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.max() + ) + + def min(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.min() + ) + + def mean(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.mean() + ) + + def median(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.median() + ) + + def mode(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.mode() + ) + + def std(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.std() + ) + + def var(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.var() + ) + + def count(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.count() + ) + + def sum(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.sum() + ) + + def first(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.first() + ) + + def last(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._multiple_values_operand.last() + ) + + def is_string(self) -> None: + self._multiple_values_operand.is_string() + + def is_int(self) -> None: + self._multiple_values_operand.is_int() + + def is_float(self) -> None: + self._multiple_values_operand.is_float() + + def is_bool(self) -> None: + self._multiple_values_operand.is_bool() + + def is_datetime(self) -> None: + self._multiple_values_operand.is_datetime() + + def is_null(self) -> None: + self._multiple_values_operand.is_null() + + def is_max(self) -> None: + self._multiple_values_operand.is_max() + + def is_min(self) -> None: + self._multiple_values_operand.is_min() + + def greater_than(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.greater_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.greater_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def less_than(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.less_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.less_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def not_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.not_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def is_in(self, values: MultipleValuesComparisonOperand) -> None: + self._multiple_values_operand.is_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) + + def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: + self._multiple_values_operand.is_not_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) + + def starts_with(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.starts_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def ends_with(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.ends_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def contains(self, value: SingleValueComparisonOperand) -> None: + self._multiple_values_operand.contains( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def add(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.add( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def subtract(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.sub( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def multiply(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.mul( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def divide(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.div( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def modulo(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.mod( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def power(self, value: SingleValueArithmeticOperand) -> None: + self._multiple_values_operand.pow( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def round(self) -> None: + self._multiple_values_operand.round() + + def ceil(self) -> None: + self._multiple_values_operand.ceil() + + def floor(self) -> None: + self._multiple_values_operand.floor() + + def absolute(self) -> None: + self._multiple_values_operand.abs() + + def sqrt(self) -> None: + self._multiple_values_operand.sqrt() + + def trim(self) -> None: + self._multiple_values_operand.trim() + + def trim_start(self) -> None: + self._multiple_values_operand.trim_start() + + def trim_end(self) -> None: + self._multiple_values_operand.trim_end() + + def lowercase(self) -> None: + self._multiple_values_operand.lowercase() + + def uppercase(self) -> None: + self._multiple_values_operand.uppercase() + + def slice(self, start: int, end: int) -> None: + self._multiple_values_operand.slice(start, end) + + def either_or( + self, + either: Callable[[MultipleValuesOperand], None], + or_: Callable[[MultipleValuesOperand], None], + ) -> None: + self._multiple_values_operand.either_or( + lambda values: either( + MultipleValuesOperand._from_py_multiple_values_operand(values) + ), + lambda values: or_( + MultipleValuesOperand._from_py_multiple_values_operand(values) + ), + ) + + def clone(self) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._multiple_values_operand.deep_clone() + ) + + @classmethod + def _from_py_multiple_values_operand( + cls, py_multiple_values_operand: PyMultipleValuesOperand + ) -> MultipleValuesOperand: + multiple_values_operand = cls() + multiple_values_operand._multiple_values_operand = py_multiple_values_operand + return multiple_values_operand + + +class SingleValueOperand: + _single_value_operand: PySingleValueOperand + + def is_string(self) -> None: + self._single_value_operand.is_string() + + def is_int(self) -> None: + self._single_value_operand.is_int() + + def is_float(self) -> None: + self._single_value_operand.is_float() + + def is_bool(self) -> None: + self._single_value_operand.is_bool() + + def is_datetime(self) -> None: + self._single_value_operand.is_datetime() + + def is_null(self) -> None: + self._single_value_operand.is_null() + + def greater_than(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.greater_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.greater_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def less_than(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.less_than( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.less_than_or_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def not_equal_to(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.not_equal_to( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def is_in(self, values: MultipleValuesComparisonOperand) -> None: + self._single_value_operand.is_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) + + def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: + self._single_value_operand.is_not_in( + _py_multiple_values_comparison_operand_from_multiple_values_comparison_operand( + values + ) + ) + + def starts_with(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.starts_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def ends_with(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.ends_with( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def contains(self, value: SingleValueComparisonOperand) -> None: + self._single_value_operand.contains( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def add(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.add( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def subtract(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.sub( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def multiply(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.mul( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def modulo(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.mod( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def power(self, value: SingleValueArithmeticOperand) -> None: + self._single_value_operand.pow( + _py_single_value_comparison_operand_from_single_value_comparison_operand( + value + ) + ) + + def round(self) -> None: + self._single_value_operand.round() + + def ceil(self) -> None: + self._single_value_operand.ceil() + + def floor(self) -> None: + self._single_value_operand.floor() + + def absolute(self) -> None: + self._single_value_operand.abs() + + def sqrt(self) -> None: + self._single_value_operand.sqrt() + + def trim(self) -> None: + self._single_value_operand.trim() + + def trim_start(self) -> None: + self._single_value_operand.trim_start() + + def trim_end(self) -> None: + self._single_value_operand.trim_end() + + def lowercase(self) -> None: + self._single_value_operand.lowercase() + + def uppercase(self) -> None: + self._single_value_operand.uppercase() + + def slice(self, start: int, end: int) -> None: + self._single_value_operand.slice(start, end) + + def either_or( + self, + either: Callable[[SingleValueOperand], None], + or_: Callable[[SingleValueOperand], None], + ) -> None: + self._single_value_operand.either_or( + lambda value: either( + SingleValueOperand._from_py_single_value_operand(value) + ), + lambda value: or_(SingleValueOperand._from_py_single_value_operand(value)), + ) + + def clone(self) -> SingleValueOperand: + return SingleValueOperand._from_py_single_value_operand( + self._single_value_operand.deep_clone() + ) + + @classmethod + def _from_py_single_value_operand( + cls, py_single_value_operand: PySingleValueOperand + ) -> SingleValueOperand: + single_value_operand = cls() + single_value_operand._single_value_operand = py_single_value_operand + return single_value_operand + + +class AttributesTreeOperand: + _attributes_tree_operand: PyAttributesTreeOperand + + def max(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.max() + ) + + def min(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.min() + ) + + def count(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.count() + ) + + def sum(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.sum() + ) + + def first(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.first() + ) + + def last(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._attributes_tree_operand.last() + ) + + def is_string(self) -> None: + self._attributes_tree_operand.is_string() + + def is_int(self) -> None: + self._attributes_tree_operand.is_int() + + def is_max(self) -> None: + self._attributes_tree_operand.is_max() + + def is_min(self) -> None: + self._attributes_tree_operand.is_min() + + def greater_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.greater_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def greater_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._attributes_tree_operand.greater_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def less_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.less_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def less_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._attributes_tree_operand.less_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def not_equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.not_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def is_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._attributes_tree_operand.is_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + + def is_not_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._attributes_tree_operand.is_not_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + + def starts_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.starts_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def ends_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.ends_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def contains(self, attribute: SingleAttributeComparisonOperand) -> None: + self._attributes_tree_operand.contains( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def add(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.add( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def subtract(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.sub( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def multiply(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.mul( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def modulo(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.mod( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def power(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._attributes_tree_operand.pow( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def absolute(self) -> None: + self._attributes_tree_operand.abs() + + def trim(self) -> None: + self._attributes_tree_operand.trim() + + def trim_start(self) -> None: + self._attributes_tree_operand.trim_start() + + def trim_end(self) -> None: + self._attributes_tree_operand.trim_end() + + def lowercase(self) -> None: + self._attributes_tree_operand.lowercase() + + def uppercase(self) -> None: + self._attributes_tree_operand.uppercase() + + def slice(self, start: int, end: int) -> None: + self._attributes_tree_operand.slice(start, end) + + def either_or( + self, + either: Callable[[AttributesTreeOperand], None], + or_: Callable[[AttributesTreeOperand], None], + ) -> None: + self._attributes_tree_operand.either_or( + lambda attributes: either( + AttributesTreeOperand._from_py_attributes_tree_operand(attributes) + ), + lambda attributes: or_( + AttributesTreeOperand._from_py_attributes_tree_operand(attributes) + ), + ) + + def clone(self) -> AttributesTreeOperand: + return AttributesTreeOperand._from_py_attributes_tree_operand( + self._attributes_tree_operand.deep_clone() + ) + + @classmethod + def _from_py_attributes_tree_operand( + cls, py_attributes_tree_operand: PyAttributesTreeOperand + ) -> AttributesTreeOperand: + attributes_tree_operand = cls() + attributes_tree_operand._attributes_tree_operand = py_attributes_tree_operand + return attributes_tree_operand + + +class MultipleAttributesOperand: + _multiple_attributes_operand: PyMultipleAttributesOperand + + def max(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.max() + ) + + def min(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.min() + ) + + def count(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.count() + ) + + def sum(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.sum() + ) + + def first(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.first() + ) + + def last(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._multiple_attributes_operand.last() + ) + + def is_string(self) -> None: + self._multiple_attributes_operand.is_string() + + def is_int(self) -> None: + self._multiple_attributes_operand.is_int() + + def is_max(self) -> None: + self._multiple_attributes_operand.is_max() + + def is_min(self) -> None: + self._multiple_attributes_operand.is_min() + + def greater_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.greater_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def greater_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._multiple_attributes_operand.greater_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def less_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.less_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def less_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._multiple_attributes_operand.less_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def not_equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.not_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def is_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._multiple_attributes_operand.is_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + + def is_not_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._multiple_attributes_operand.is_not_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + + def starts_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.starts_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def ends_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.ends_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def contains(self, attribute: SingleAttributeComparisonOperand) -> None: + self._multiple_attributes_operand.contains( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def add(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.add( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def subtract(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.sub( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def multiply(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.mul( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def modulo(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.mod( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def power(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._multiple_attributes_operand.pow( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def absolute(self) -> None: + self._multiple_attributes_operand.abs() + + def trim(self) -> None: + self._multiple_attributes_operand.trim() + + def trim_start(self) -> None: + self._multiple_attributes_operand.trim_start() + + def trim_end(self) -> None: + self._multiple_attributes_operand.trim_end() + + def lowercase(self) -> None: + self._multiple_attributes_operand.lowercase() + + def uppercase(self) -> None: + self._multiple_attributes_operand.uppercase() + + def to_values(self) -> MultipleValuesOperand: + return MultipleValuesOperand._from_py_multiple_values_operand( + self._multiple_attributes_operand.to_values() + ) + + def slice(self, start: int, end: int) -> None: + self._multiple_attributes_operand.slice(start, end) + + def either_or( + self, + either: Callable[[MultipleAttributesOperand], None], + or_: Callable[[MultipleAttributesOperand], None], + ) -> None: + self._multiple_attributes_operand.either_or( + lambda attributes: either( + MultipleAttributesOperand._from_py_multiple_attributes_operand( + attributes + ) + ), + lambda attributes: or_( + MultipleAttributesOperand._from_py_multiple_attributes_operand( + attributes + ) + ), + ) + + def clone(self) -> MultipleAttributesOperand: + return MultipleAttributesOperand._from_py_multiple_attributes_operand( + self._multiple_attributes_operand.deep_clone() + ) + + @classmethod + def _from_py_multiple_attributes_operand( + cls, py_multiple_attributes_operand: PyMultipleAttributesOperand + ) -> MultipleAttributesOperand: + multiple_attributes_operand = cls() + multiple_attributes_operand._multiple_attributes_operand = ( + py_multiple_attributes_operand + ) + return multiple_attributes_operand + + +class SingleAttributeOperand: + _single_attribute_operand: PySingleAttributeOperand + + def is_string(self) -> None: + self._single_attribute_operand.is_string() + + def is_int(self) -> None: + self._single_attribute_operand.is_int() + + def greater_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.greater_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def greater_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._single_attribute_operand.greater_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def less_than(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.less_than( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def less_than_or_equal_to( + self, attribute: SingleAttributeComparisonOperand + ) -> None: + self._single_attribute_operand.less_than_or_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def not_equal_to(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.not_equal_to( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def is_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._single_attribute_operand.is_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + + def is_not_in(self, attributes: MultipleAttributesComparisonOperand) -> None: + self._single_attribute_operand.is_not_in( + _py_multiple_attributes_comparison_operand_from_multiple_attributes_comparison_operand( + attributes + ) + ) + + def starts_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.starts_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def ends_with(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.ends_with( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def contains(self, attribute: SingleAttributeComparisonOperand) -> None: + self._single_attribute_operand.contains( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def add(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.add( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def subtract(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.sub( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def multiply(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.mul( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def modulo(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.mod( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def power(self, attribute: SingleAttributeArithmeticOperand) -> None: + self._single_attribute_operand.pow( + _py_single_attribute_comparison_operand_from_single_attribute_comparison_operand( + attribute + ) + ) + + def absolute(self) -> None: + self._single_attribute_operand.abs() + + def trim(self) -> None: + self._single_attribute_operand.trim() + + def trim_start(self) -> None: + self._single_attribute_operand.trim_start() + + def trim_end(self) -> None: + self._single_attribute_operand.trim_end() + + def lowercase(self) -> None: + self._single_attribute_operand.lowercase() + + def uppercase(self) -> None: + self._single_attribute_operand.uppercase() + + def slice(self, start: int, end: int) -> None: + self._single_attribute_operand.slice(start, end) + + def either_or( + self, + either: Callable[[SingleAttributeOperand], None], + or_: Callable[[SingleAttributeOperand], None], + ) -> None: + self._single_attribute_operand.either_or( + lambda attribute: either( + SingleAttributeOperand._from_py_single_attribute_operand(attribute) + ), + lambda attribute: or_( + SingleAttributeOperand._from_py_single_attribute_operand(attribute) + ), + ) + + def clone(self) -> SingleAttributeOperand: + return SingleAttributeOperand._from_py_single_attribute_operand( + self._single_attribute_operand.deep_clone() + ) + + @classmethod + def _from_py_single_attribute_operand( + cls, py_single_attribute_operand: PySingleAttributeOperand + ) -> SingleAttributeOperand: + single_attribute_operand = cls() + single_attribute_operand._single_attribute_operand = py_single_attribute_operand + return single_attribute_operand + + +class NodeIndicesOperand: + _node_indices_operand: PyNodeIndicesOperand + + def max(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.max() + ) + + def min(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.min() + ) + + def count(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.count() + ) + + def sum(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.sum() + ) + + def first(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.first() + ) + + def last(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_indices_operand.last() + ) + + def greater_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.greater_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def greater_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.greater_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def less_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.less_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def less_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.less_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def not_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.not_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def is_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_indices_operand.is_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) + + def is_not_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_indices_operand.is_not_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) + + def starts_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.starts_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def ends_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.ends_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def contains(self, index: NodeIndexComparisonOperand) -> None: + self._node_indices_operand.contains( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def add(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.add( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def subtract(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.sub( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def multiply(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.mul( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def modulo(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.mod( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def power(self, index: NodeIndexArithmeticOperand) -> None: + self._node_indices_operand.pow( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def absolute(self) -> None: + self._node_indices_operand.abs() + + def trim(self) -> None: + self._node_indices_operand.trim() + + def trim_start(self) -> None: + self._node_indices_operand.trim_start() + + def trim_end(self) -> None: + self._node_indices_operand.trim_end() + + def lowercase(self) -> None: + self._node_indices_operand.lowercase() + + def uppercase(self) -> None: + self._node_indices_operand.uppercase() + + def slice(self, start: int, end: int) -> None: + self._node_indices_operand.slice(start, end) + + def either_or( + self, + either: Callable[[NodeIndicesOperand], None], + or_: Callable[[NodeIndicesOperand], None], + ) -> None: + self._node_indices_operand.either_or( + lambda node_indices: either( + NodeIndicesOperand._from_py_node_indices_operand(node_indices) + ), + lambda node_indices: or_( + NodeIndicesOperand._from_py_node_indices_operand(node_indices) + ), + ) + + def clone(self) -> NodeIndicesOperand: + return NodeIndicesOperand._from_py_node_indices_operand( + self._node_indices_operand.deep_clone() + ) + + @classmethod + def _from_py_node_indices_operand( + cls, py_node_indices_operand: PyNodeIndicesOperand + ) -> NodeIndicesOperand: + node_indices_operand = cls() + node_indices_operand._node_indices_operand = py_node_indices_operand + return node_indices_operand + + +class NodeIndexOperand: + _node_index_operand: PyNodeIndexOperand + + def greater_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.greater_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def greater_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.greater_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def less_than(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.less_than( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def less_than_or_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.less_than_or_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def not_equal_to(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.not_equal_to( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def is_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_index_operand.is_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) + + def is_not_in(self, indices: NodeIndicesComparisonOperand) -> None: + self._node_index_operand.is_not_in( + _py_node_indices_comparison_operand_from_node_indices_comparison_operand( + indices + ) + ) + + def starts_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.starts_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def ends_with(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.ends_with( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def contains(self, index: NodeIndexComparisonOperand) -> None: + self._node_index_operand.contains( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def add(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.add( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def subtract(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.sub( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def multiply(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.mul( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def modulo(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.mod( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def power(self, index: NodeIndexArithmeticOperand) -> None: + self._node_index_operand.pow( + _py_node_index_comparison_operand_from_node_index_comparison_operand(index) + ) + + def absolute(self) -> None: + self._node_index_operand.abs() + + def trim(self) -> None: + self._node_index_operand.trim() + + def trim_start(self) -> None: + self._node_index_operand.trim_start() + + def trim_end(self) -> None: + self._node_index_operand.trim_end() + + def lowercase(self) -> None: + self._node_index_operand.lowercase() + + def uppercase(self) -> None: + self._node_index_operand.uppercase() + + def slice(self, start: int, end: int) -> None: + self._node_index_operand.slice(start, end) + + def either_or( + self, + either: Callable[[NodeIndexOperand], None], + or_: Callable[[NodeIndexOperand], None], + ) -> None: + self._node_index_operand.either_or( + lambda node_index: either( + NodeIndexOperand._from_py_node_index_operand(node_index) + ), + lambda node_index: or_( + NodeIndexOperand._from_py_node_index_operand(node_index) + ), + ) + + def clone(self) -> NodeIndexOperand: + return NodeIndexOperand._from_py_node_index_operand( + self._node_index_operand.deep_clone() + ) + + @classmethod + def _from_py_node_index_operand( + cls, py_node_index_operand: PyNodeIndexOperand + ) -> NodeIndexOperand: + node_index_operand = cls() + node_index_operand._node_index_operand = py_node_index_operand + return node_index_operand + + +class EdgeIndicesOperand: + _edge_indices_operand: PyEdgeIndicesOperand + + def max(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.max() + ) + + def min(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.min() + ) + + def count(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.count() + ) + + def sum(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.sum() + ) + + def first(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.first() + ) + + def last(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_indices_operand.last() + ) + + def greater_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.greater_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def greater_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.greater_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def less_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.less_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def less_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.less_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def not_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.not_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def is_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_indices_operand.is_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) + + def is_not_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_indices_operand.is_not_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) + + def starts_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.starts_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def ends_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.ends_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def contains(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_indices_operand.contains( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def add(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.add( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def subtract(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.sub( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def multiply(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.mul( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def modulo(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.mod( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def power(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_indices_operand.pow( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def either_or( + self, + either: Callable[[EdgeIndicesOperand], None], + or_: Callable[[EdgeIndicesOperand], None], + ) -> None: + self._edge_indices_operand.either_or( + lambda edge_indices: either( + EdgeIndicesOperand._from_edge_indices_operand(edge_indices) + ), + lambda edge_indices: or_( + EdgeIndicesOperand._from_edge_indices_operand(edge_indices) + ), + ) + + def clone(self) -> EdgeIndicesOperand: + return EdgeIndicesOperand._from_edge_indices_operand( + self._edge_indices_operand.deep_clone() + ) + + @classmethod + def _from_edge_indices_operand( + cls, py_edge_indices_operand: PyEdgeIndicesOperand + ) -> EdgeIndicesOperand: + edge_indices_operand = cls() + edge_indices_operand._edge_indices_operand = py_edge_indices_operand + return edge_indices_operand + + +class EdgeIndexOperand: + _edge_index_operand: PyEdgeIndexOperand + + def greater_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.greater_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def greater_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.greater_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def less_than(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.less_than( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def less_than_or_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.less_than_or_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def not_equal_to(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.not_equal_to( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def is_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_index_operand.is_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) + + def is_not_in(self, indices: EdgeIndicesComparisonOperand) -> None: + self._edge_index_operand.is_not_in( + _py_edge_indices_comparison_operand_from_edge_indices_comparison_operand( + indices + ) + ) + + def starts_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.starts_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def ends_with(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.ends_with( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def contains(self, index: EdgeIndexComparisonOperand) -> None: + self._edge_index_operand.contains( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def add(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.add( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def subtract(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.sub( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def multiply(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.mul( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def modulo(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.mod( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def power(self, index: EdgeIndexArithmeticOperand) -> None: + self._edge_index_operand.pow( + _py_edge_index_comparison_operand_from_edge_index_comparison_operand(index) + ) + + def either_or( + self, + either: Callable[[EdgeIndexOperand], None], + or_: Callable[[EdgeIndexOperand], None], + ) -> None: + self._edge_index_operand.either_or( + lambda edge_index: either( + EdgeIndexOperand._from_py_edge_index_operand(edge_index) + ), + lambda edge_index: or_( + EdgeIndexOperand._from_py_edge_index_operand(edge_index) + ), + ) + + def clone(self) -> EdgeIndexOperand: + return EdgeIndexOperand._from_py_edge_index_operand( + self._edge_index_operand.deep_clone() + ) + + @classmethod + def _from_py_edge_index_operand( + cls, py_edge_index_operand: PyEdgeIndexOperand + ) -> EdgeIndexOperand: + edge_index_operand = cls() + edge_index_operand._edge_index_operand = py_edge_index_operand + return edge_index_operand diff --git a/medmodels/medrecord/querying.pyi b/medmodels/medrecord/querying.pyi deleted file mode 100644 index d2258fca..00000000 --- a/medmodels/medrecord/querying.pyi +++ /dev/null @@ -1,430 +0,0 @@ -from __future__ import annotations - -import sys -from enum import Enum, auto -from typing import Callable, List, Union - -from medmodels.medrecord.types import ( - EdgeIndex, - Group, - MedRecordAttribute, - MedRecordValue, - NodeIndex, -) - -if sys.version_info >= (3, 10): - from typing import TypeAlias -else: - from typing_extensions import TypeAlias - -NodeQuery: TypeAlias = Callable[[NodeOperand], None] -EdgeQuery: TypeAlias = Callable[[EdgeOperand], None] - -SingleValueComparisonOperand: TypeAlias = Union[SingleValueOperand, MedRecordValue] -MultipleValuesComparisonOperand: TypeAlias = Union[ - MultipleValuesOperand, List[MedRecordValue] -] - -SingleAttributeComparisonOperand: TypeAlias = Union[ - SingleAttributeOperand, - MedRecordAttribute, -] -MultipleAttributesComparisonOperand: TypeAlias = Union[ - MultipleAttributesOperand, List[MedRecordAttribute] -] - -NodeIndexComparisonOperand: TypeAlias = Union[NodeIndexOperand, NodeIndex] -NodeIndicesComparisonOperand: TypeAlias = Union[NodeIndicesOperand, List[NodeIndex]] - -EdgeIndexComparisonOperand: TypeAlias = Union[ - EdgeIndexOperand, - EdgeIndex, -] -EdgeIndicesComparisonOperand: TypeAlias = Union[ - EdgeIndicesOperand, - List[EdgeIndex], -] - -class EdgeDirection(Enum): - INCOMING = auto() - OUTGOING = auto() - BOTH = auto() - -class NodeOperand: - def attribute( - self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] - ) -> MultipleValuesOperand: ... - def attributes(self) -> MultipleAttributesOperand: ... - def index(self) -> NodeIndexOperand: ... - def in_group(self, group: Union[Group, List[Group]]) -> None: ... - def has_attribute( - self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] - ) -> None: ... - def incoming_edges(self) -> EdgeOperand: ... - def outgoing_edges(self) -> EdgeOperand: ... - def neighbors( - self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING - ) -> NodeOperand: ... - def either_or(self, either: NodeQuery, or_: NodeQuery) -> None: ... - def clone(self) -> NodeOperand: ... - -class EdgeOperand: - def attribute( - self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] - ) -> MultipleValuesOperand: ... - def attributes(self) -> MultipleAttributesOperand: ... - def index(self) -> EdgeIndexOperand: ... - def in_group(self, group: Union[Group, List[Group]]) -> None: ... - def has_attribute( - self, attribute: Union[MedRecordAttribute, SingleAttributeOperand] - ) -> None: ... - def source_node(self) -> NodeOperand: ... - def target_node(self) -> NodeOperand: ... - def either_or(self, either: EdgeQuery, or_: EdgeQuery) -> None: ... - def clone(self) -> EdgeOperand: ... - -class MultipleValuesOperand: - def max(self) -> SingleValueOperand: ... - def min(self) -> SingleValueOperand: ... - def mean(self) -> SingleValueOperand: ... - def median(self) -> SingleValueOperand: ... - def mode(self) -> SingleValueOperand: ... - def std(self) -> SingleValueOperand: ... - def var(self) -> SingleValueOperand: ... - def count(self) -> SingleValueOperand: ... - def sum(self) -> SingleValueOperand: ... - def first(self) -> SingleValueOperand: ... - def last(self) -> SingleValueOperand: ... - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def is_float(self) -> None: ... - def is_bool(self) -> None: ... - def is_datetime(self) -> None: ... - def is_null(self) -> None: ... - def is_max(self) -> None: ... - def is_min(self) -> None: ... - def greater_than(self, value: SingleValueComparisonOperand) -> None: ... - def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def less_than(self, value: SingleValueComparisonOperand) -> None: ... - def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def not_equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def is_in(self, values: MultipleValuesComparisonOperand) -> None: ... - def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: ... - def starts_with(self, value: SingleValueComparisonOperand) -> None: ... - def ends_with(self, value: SingleValueComparisonOperand) -> None: ... - def contains(self, value: SingleValueComparisonOperand) -> None: ... - def add(self, value: SingleValueComparisonOperand) -> None: ... - def subtract(self, value: SingleValueComparisonOperand) -> None: ... - def multiply(self, value: SingleValueComparisonOperand) -> None: ... - def divide(self, value: SingleValueComparisonOperand) -> None: ... - def modulo(self, value: SingleValueComparisonOperand) -> None: ... - def power(self, value: SingleValueComparisonOperand) -> None: ... - def round(self) -> None: ... - def ceil(self) -> None: ... - def floor(self) -> None: ... - def absolute(self) -> None: ... - def sqrt(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[MultipleValuesOperand], None], - or_: Callable[[MultipleValuesOperand], None], - ) -> None: ... - def clone(self) -> MultipleValuesOperand: ... - -class SingleValueOperand: - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def is_float(self) -> None: ... - def is_bool(self) -> None: ... - def is_datetime(self) -> None: ... - def is_null(self) -> None: ... - def greater_than(self, value: SingleValueComparisonOperand) -> None: ... - def greater_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def less_than(self, value: SingleValueComparisonOperand) -> None: ... - def less_than_or_equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def not_equal_to(self, value: SingleValueComparisonOperand) -> None: ... - def is_in(self, values: MultipleValuesComparisonOperand) -> None: ... - def is_not_in(self, values: MultipleValuesComparisonOperand) -> None: ... - def starts_with(self, value: SingleValueComparisonOperand) -> None: ... - def ends_with(self, value: SingleValueComparisonOperand) -> None: ... - def contains(self, value: SingleValueComparisonOperand) -> None: ... - def add(self, value: SingleValueComparisonOperand) -> None: ... - def subtract(self, value: SingleValueComparisonOperand) -> None: ... - def multiply(self, value: SingleValueComparisonOperand) -> None: ... - def modulo(self, value: SingleValueComparisonOperand) -> None: ... - def power(self, value: SingleValueComparisonOperand) -> None: ... - def round(self) -> None: ... - def ceil(self) -> None: ... - def floor(self) -> None: ... - def absolute(self) -> None: ... - def sqrt(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[SingleValueOperand], None], - or_: Callable[[SingleValueOperand], None], - ) -> None: ... - def clone(self) -> SingleValueOperand: ... - -class AttributesTreeOperand: - def max(self) -> MultipleAttributesOperand: ... - def min(self) -> MultipleAttributesOperand: ... - def count(self) -> MultipleAttributesOperand: ... - def sum(self) -> MultipleAttributesOperand: ... - def first(self) -> MultipleAttributesOperand: ... - def last(self) -> MultipleAttributesOperand: ... - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def is_max(self) -> None: ... - def is_min(self) -> None: ... - def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... - def greater_than_or_equal_to( - self, value: SingleAttributeComparisonOperand - ) -> None: ... - def less_than(self, value: SingleAttributeComparisonOperand) -> None: ... - def less_than_or_equal_to( - self, value: SingleAttributeComparisonOperand - ) -> None: ... - def equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... - def not_equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... - def is_in(self, values: MultipleAttributesComparisonOperand) -> None: ... - def is_not_in(self, values: MultipleAttributesComparisonOperand) -> None: ... - def starts_with(self, value: SingleAttributeComparisonOperand) -> None: ... - def ends_with(self, value: SingleAttributeComparisonOperand) -> None: ... - def contains(self, value: SingleAttributeComparisonOperand) -> None: ... - def add(self, value: SingleAttributeComparisonOperand) -> None: ... - def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... - def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... - def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... - def power(self, value: SingleAttributeComparisonOperand) -> None: ... - def absolute(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[AttributesTreeOperand], None], - or_: Callable[[AttributesTreeOperand], None], - ) -> None: ... - def clone(self) -> AttributesTreeOperand: ... - -class MultipleAttributesOperand: - def max(self) -> SingleAttributeOperand: ... - def min(self) -> SingleAttributeOperand: ... - def count(self) -> SingleAttributeOperand: ... - def sum(self) -> SingleAttributeOperand: ... - def first(self) -> SingleAttributeOperand: ... - def last(self) -> SingleAttributeOperand: ... - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def is_max(self) -> None: ... - def is_min(self) -> None: ... - def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... - def greater_than_or_equal_to( - self, value: SingleAttributeComparisonOperand - ) -> None: ... - def less_than(self, value: SingleAttributeComparisonOperand) -> None: ... - def less_than_or_equal_to( - self, value: SingleAttributeComparisonOperand - ) -> None: ... - def equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... - def not_equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... - def is_in(self, values: MultipleAttributesComparisonOperand) -> None: ... - def is_not_in(self, values: MultipleAttributesComparisonOperand) -> None: ... - def starts_with(self, value: SingleAttributeComparisonOperand) -> None: ... - def ends_with(self, value: SingleAttributeComparisonOperand) -> None: ... - def contains(self, value: SingleAttributeComparisonOperand) -> None: ... - def add(self, value: SingleAttributeComparisonOperand) -> None: ... - def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... - def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... - def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... - def power(self, value: SingleAttributeComparisonOperand) -> None: ... - def absolute(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def to_values(self) -> MultipleValuesOperand: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[MultipleAttributesOperand], None], - or_: Callable[[MultipleAttributesOperand], None], - ) -> None: ... - def clone(self) -> MultipleAttributesOperand: ... - -class SingleAttributeOperand: - def is_string(self) -> None: ... - def is_int(self) -> None: ... - def greater_than(self, value: SingleAttributeComparisonOperand) -> None: ... - def greater_than_or_equal_to( - self, value: SingleAttributeComparisonOperand - ) -> None: ... - def less_than(self, value: SingleAttributeComparisonOperand) -> None: ... - def less_than_or_equal_to( - self, value: SingleAttributeComparisonOperand - ) -> None: ... - def equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... - def not_equal_to(self, value: SingleAttributeComparisonOperand) -> None: ... - def is_in(self, values: MultipleAttributesComparisonOperand) -> None: ... - def is_not_in(self, values: MultipleAttributesComparisonOperand) -> None: ... - def starts_with(self, value: SingleAttributeComparisonOperand) -> None: ... - def ends_with(self, value: SingleAttributeComparisonOperand) -> None: ... - def contains(self, value: SingleAttributeComparisonOperand) -> None: ... - def add(self, value: SingleAttributeComparisonOperand) -> None: ... - def subtract(self, value: SingleAttributeComparisonOperand) -> None: ... - def multiply(self, value: SingleAttributeComparisonOperand) -> None: ... - def modulo(self, value: SingleAttributeComparisonOperand) -> None: ... - def power(self, value: SingleAttributeComparisonOperand) -> None: ... - def absolute(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[SingleAttributeOperand], None], - or_: Callable[[SingleAttributeOperand], None], - ) -> None: ... - def clone(self) -> SingleAttributeOperand: ... - -class NodeIndicesOperand: - def max(self) -> NodeIndexOperand: ... - def min(self) -> NodeIndexOperand: ... - def count(self) -> NodeIndexOperand: ... - def sum(self) -> NodeIndexOperand: ... - def first(self) -> NodeIndexOperand: ... - def last(self) -> NodeIndexOperand: ... - def greater_than(self, value: NodeIndexComparisonOperand) -> None: ... - def greater_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def less_than(self, value: NodeIndexComparisonOperand) -> None: ... - def less_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def not_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def is_in(self, values: NodeIndicesComparisonOperand) -> None: ... - def is_not_in(self, values: NodeIndicesComparisonOperand) -> None: ... - def starts_with(self, value: NodeIndexComparisonOperand) -> None: ... - def ends_with(self, value: NodeIndexComparisonOperand) -> None: ... - def contains(self, value: NodeIndexComparisonOperand) -> None: ... - def add(self, value: NodeIndexComparisonOperand) -> None: ... - def subtract(self, value: NodeIndexComparisonOperand) -> None: ... - def multiply(self, value: NodeIndexComparisonOperand) -> None: ... - def modulo(self, value: NodeIndexComparisonOperand) -> None: ... - def power(self, value: NodeIndexComparisonOperand) -> None: ... - def absolute(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[NodeIndicesOperand], None], - or_: Callable[[NodeIndicesOperand], None], - ) -> None: ... - def clone(self) -> NodeIndicesOperand: ... - -class NodeIndexOperand: - def greater_than(self, value: NodeIndexComparisonOperand) -> None: ... - def greater_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def less_than(self, value: NodeIndexComparisonOperand) -> None: ... - def less_than_or_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def not_equal_to(self, value: NodeIndexComparisonOperand) -> None: ... - def is_in(self, values: NodeIndicesComparisonOperand) -> None: ... - def is_not_in(self, values: NodeIndicesComparisonOperand) -> None: ... - def starts_with(self, value: NodeIndexComparisonOperand) -> None: ... - def ends_with(self, value: NodeIndexComparisonOperand) -> None: ... - def contains(self, value: NodeIndexComparisonOperand) -> None: ... - def add(self, value: NodeIndexComparisonOperand) -> None: ... - def subtract(self, value: NodeIndexComparisonOperand) -> None: ... - def multiply(self, value: NodeIndexComparisonOperand) -> None: ... - def modulo(self, value: NodeIndexComparisonOperand) -> None: ... - def power(self, value: NodeIndexComparisonOperand) -> None: ... - def absolute(self) -> None: ... - def trim(self) -> None: ... - def trim_start(self) -> None: ... - def trim_end(self) -> None: ... - def lowercase(self) -> None: ... - def uppercase(self) -> None: ... - def slice(self, start: int, end: int) -> None: ... - def either_or( - self, - either: Callable[[NodeIndexOperand], None], - or_: Callable[[NodeIndexOperand], None], - ) -> None: ... - def clone(self) -> NodeIndexOperand: ... - -class EdgeIndicesOperand: - def max(self) -> EdgeIndexOperand: ... - def min(self) -> EdgeIndexOperand: ... - def count(self) -> EdgeIndexOperand: ... - def sum(self) -> EdgeIndexOperand: ... - def first(self) -> EdgeIndexOperand: ... - def last(self) -> EdgeIndexOperand: ... - def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def greater_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def less_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def not_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def is_in(self, values: EdgeIndicesComparisonOperand) -> None: ... - def is_not_in(self, values: EdgeIndicesComparisonOperand) -> None: ... - def starts_with(self, value: EdgeIndexComparisonOperand) -> None: ... - def ends_with(self, value: EdgeIndexComparisonOperand) -> None: ... - def contains(self, value: EdgeIndexComparisonOperand) -> None: ... - def add(self, value: EdgeIndexComparisonOperand) -> None: ... - def subtract(self, value: EdgeIndexComparisonOperand) -> None: ... - def multiply(self, value: EdgeIndexComparisonOperand) -> None: ... - def modulo(self, value: EdgeIndexComparisonOperand) -> None: ... - def power(self, value: EdgeIndexComparisonOperand) -> None: ... - def either_or( - self, - either: Callable[[EdgeIndicesOperand], None], - or_: Callable[[EdgeIndicesOperand], None], - ) -> None: ... - def clone(self) -> EdgeIndicesOperand: ... - -class EdgeIndexOperand: - def greater_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def greater_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def less_than(self, value: EdgeIndexComparisonOperand) -> None: ... - def less_than_or_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def not_equal_to(self, value: EdgeIndexComparisonOperand) -> None: ... - def is_in(self, values: EdgeIndicesComparisonOperand) -> None: ... - def is_not_in(self, values: EdgeIndicesComparisonOperand) -> None: ... - def starts_with(self, value: EdgeIndicesComparisonOperand) -> None: ... - def ends_with(self, value: EdgeIndicesComparisonOperand) -> None: ... - def contains(self, value: EdgeIndicesComparisonOperand) -> None: ... - def add(self, value: EdgeIndexComparisonOperand) -> None: ... - def subtract(self, value: EdgeIndexComparisonOperand) -> None: ... - def multiply(self, value: EdgeIndexComparisonOperand) -> None: ... - def modulo(self, value: EdgeIndexComparisonOperand) -> None: ... - def power(self, value: EdgeIndexComparisonOperand) -> None: ... - def either_or( - self, - either: Callable[[EdgeIndexOperand], None], - or_: Callable[[EdgeIndexOperand], None], - ) -> None: ... - def clone(self) -> EdgeIndexOperand: ... diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index 2dd28890..690dd601 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -19,7 +19,7 @@ class AttributeType(Enum): Temporal = auto() @staticmethod - def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: + def _from_py_attribute_type(py_attribute_type: PyAttributeType) -> AttributeType: """ Converts a PyAttributeType to an AttributeType. @@ -36,7 +36,7 @@ def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: elif py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal - def _into_pyattributetype(self) -> PyAttributeType: + def _into_py_attribute_type(self) -> PyAttributeType: """ Converts an AttributeType to a PyAttributeType. @@ -81,7 +81,7 @@ def __eq__(self, value: object) -> bool: bool: True if the objects are equal, False otherwise. """ if isinstance(value, PyAttributeType): - return self._into_pyattributetype() == value + return self._into_py_attribute_type() == value elif isinstance(value, AttributeType): return str(self) == str(value) @@ -295,7 +295,7 @@ def _convert_input( ) -> PyAttributeDataType: if isinstance(input, tuple): return PyAttributeDataType( - input[0]._inner(), input[1]._into_pyattributetype() + input[0]._inner(), input[1]._into_py_attribute_type() ) return PyAttributeDataType(input._inner(), None) @@ -334,8 +334,8 @@ def _convert_node( input: PyAttributeDataType, ) -> Tuple[DataType, Optional[AttributeType]]: return ( - DataType._from_pydatatype(input.data_type), - AttributeType._from_pyattributetype(input.attribute_type) + DataType._from_py_data_type(input.data_type), + AttributeType._from_py_attribute_type(input.attribute_type) if input.attribute_type is not None else None, ) @@ -361,8 +361,8 @@ def _convert_edge( input: PyAttributeDataType, ) -> Tuple[DataType, Optional[AttributeType]]: return ( - DataType._from_pydatatype(input.data_type), - AttributeType._from_pyattributetype(input.attribute_type) + DataType._from_py_data_type(input.data_type), + AttributeType._from_py_attribute_type(input.attribute_type) if input.attribute_type is not None else None, ) @@ -422,7 +422,7 @@ def __init__( ) @classmethod - def _from_pyschema(cls, schema: PySchema) -> Schema: + def _from_py_schema(cls, schema: PySchema) -> Schema: """ Creates a Schema instance from an existing PySchema. diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index caeec59a..0a830b63 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -6,7 +6,7 @@ import pandas as pd from medmodels import MedRecord -from medmodels.medrecord.querying import NodeOperand +from medmodels.medrecord.querying import EdgeDirection, NodeOperand from medmodels.medrecord.types import NodeIndex from medmodels.treatment_effect.estimate import ContingencyTable, SubjectIndices from medmodels.treatment_effect.treatment_effect import TreatmentEffect @@ -621,7 +621,7 @@ def test_outcome_before_treatment(self): def test_filter_controls(self): def query1(node: NodeOperand): - node.neighbors().index().equal_to("M2") + node.neighbors(EdgeDirection.BOTH).index().equal_to("M2") tee = ( TreatmentEffect.builder() diff --git a/rustmodels/src/lib.rs b/rustmodels/src/lib.rs index 9e33e547..2584bae3 100644 --- a/rustmodels/src/lib.rs +++ b/rustmodels/src/lib.rs @@ -8,7 +8,7 @@ use medrecord::{ PyAttributesTreeOperand, PyMultipleAttributesOperand, PySingleAttributeOperand, }, edges::{PyEdgeIndexOperand, PyEdgeIndicesOperand, PyEdgeOperand}, - nodes::{PyNodeIndexOperand, PyNodeIndicesOperand, PyNodeOperand}, + nodes::{PyEdgeDirection, PyNodeIndexOperand, PyNodeIndicesOperand, PyNodeOperand}, values::{PyMultipleValuesOperand, PySingleValueOperand}, }, schema::{PyAttributeDataType, PyAttributeType, PyGroupSchema, PySchema}, @@ -35,6 +35,8 @@ fn _medmodels(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?;