Skip to content

Commit

Permalink
fix: linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Oct 17, 2024
1 parent 68555e7 commit cdee623
Show file tree
Hide file tree
Showing 12 changed files with 644 additions and 750 deletions.
14 changes: 6 additions & 8 deletions medmodels/_medmodels.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union
from typing import Callable, Dict, List, Optional, Sequence, Union

from medmodels.medrecord.types import (
Attributes,
Expand All @@ -18,13 +19,10 @@ from medmodels.medrecord.types import (
PolarsNodeDataFrameInput,
)

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

PyDataType: TypeAlias = Union[
PyString,
Expand Down
14 changes: 7 additions & 7 deletions medmodels/medrecord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
"Bool",
"DateTime",
"EdgeIndex",
"EdgeIndex",
"EdgeOperand",
"EdgeQuery",
"Float",
"GroupSchema",
"GroupSchema",
"Int",
"MedRecord",
"NodeIndex",
"NodeIndex",
"NodeOperand",
"NodeQuery",
"Null",
"Option",
"Schema",
"GroupSchema",
"NodeIndex",
"EdgeIndex",
"EdgeQuery",
"NodeQuery",
"NodeOperand",
"EdgeOperand",
"String",
"Union",
]
30 changes: 17 additions & 13 deletions medmodels/medrecord/_overview.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from __future__ import annotations

import copy
from datetime import datetime
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import polars as pl

from medmodels.medrecord.schema import AttributesSchema, AttributeType
from medmodels.medrecord.types import (
AttributeInfo,
Attributes,
EdgeIndex,
Group,
MedRecordAttribute,
NodeIndex,
NumericAttributeInfo,
StringAttributeInfo,
TemporalAttributeInfo,
)

if TYPE_CHECKING:
from medmodels.medrecord.types import (
AttributeInfo,
Attributes,
EdgeIndex,
Group,
MedRecordAttribute,
NodeIndex,
NumericAttributeInfo,
StringAttributeInfo,
TemporalAttributeInfo,
)


def extract_attribute_summary(

Check failure on line 25 in medmodels/medrecord/_overview.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (D417)

medmodels/medrecord/_overview.py:25:5: D417 Missing argument description in the docstring for `extract_attribute_summary`: `attribute_dictionary`
Expand Down Expand Up @@ -210,7 +214,7 @@ def prettify_table(
row[2] = str(attribute) if first_line else ""

# displaying info based on the type
if "values" in info.keys():
if "values" in info:
row[3] = info[key]
else:
if isinstance(info[key], float):
Expand Down
14 changes: 7 additions & 7 deletions medmodels/medrecord/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING, Callable, Dict, Tuple, Union, overload

from medmodels.medrecord.querying import EdgeQuery, NodeQuery
from medmodels.medrecord.types import (
Attributes,
AttributesInput,
Expand All @@ -22,6 +21,7 @@

if TYPE_CHECKING:
from medmodels import MedRecord
from medmodels.medrecord.querying import EdgeQuery, NodeQuery


class NodeIndexer:
Expand Down Expand Up @@ -63,7 +63,7 @@ def __getitem__(
key: Tuple[Union[NodeIndexInputList, NodeQuery, slice], MedRecordAttribute],
) -> Dict[NodeIndex, MedRecordValue]: ...

def __getitem__(
def __getitem__( # noqa: C901
self,
key: Union[
NodeIndex,
Expand Down Expand Up @@ -252,7 +252,7 @@ def __setitem__(
value: MedRecordValue,
) -> None: ...

def __setitem__(
def __setitem__( # noqa: C901
self,
key: Union[
NodeIndex,
Expand Down Expand Up @@ -523,7 +523,7 @@ def __setitem__(
return None
return None

def __delitem__(
def __delitem__( # noqa: C901
self,
key: Tuple[
Union[NodeIndex, NodeIndexInputList, NodeQuery, slice],
Expand Down Expand Up @@ -712,7 +712,7 @@ def __getitem__(
key: Tuple[Union[EdgeIndexInputList, EdgeQuery, slice], MedRecordAttribute],
) -> Dict[EdgeIndex, MedRecordValue]: ...

def __getitem__(
def __getitem__( # noqa: C901
self,
key: Union[
EdgeIndex,
Expand Down Expand Up @@ -901,7 +901,7 @@ def __setitem__(
value: MedRecordValue,
) -> None: ...

def __setitem__(
def __setitem__( # noqa: C901
self,
key: Union[
EdgeIndex,
Expand Down Expand Up @@ -1170,7 +1170,7 @@ def __setitem__(
return None
return None

def __delitem__(
def __delitem__( # noqa: C901
self,
key: Tuple[
Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice],
Expand Down
24 changes: 16 additions & 8 deletions medmodels/medrecord/medrecord.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Sequence, Union, overload

import polars as pl
Expand Down Expand Up @@ -86,7 +87,7 @@ def __init__(
data: Dict[Group, AttributeInfo],
group_header: str,
decimal: int,
):
) -> None:
"""Initializes the OverviewTable class.
Args:
Expand All @@ -105,6 +106,11 @@ def __repr__(self) -> str:
return "\n".join(prettify_table(self.data, header=header, decimal=self.decimal))


class EdgesDirected(Enum):
DIRECTED = auto()
UNDIRECTED = auto()


class MedRecord:
"""A class to manage medical records with node and edge data structures.
Expand Down Expand Up @@ -528,7 +534,7 @@ def edges_connecting(
self,
source_node: Union[NodeIndex, NodeIndexInputList, NodeQuery],
target_node: Union[NodeIndex, NodeIndexInputList, NodeQuery],
directed: bool = True,
directed: EdgesDirected = EdgesDirected.DIRECTED,
) -> List[EdgeIndex]:
"""Retrieves the edges connecting the specified source and target nodes in the MedRecord.
Expand All @@ -544,7 +550,8 @@ def edges_connecting(
target_node (Union[NodeIndex, NodeIndexInputList, NodeQuery]):
The index or indices of the target node(s), or a node query to
select target nodes.
directed (bool, optional): Whether to consider edges as directed.
directed (EdgesDirected, optional): Whether to consider edges as directed.
Defaults to EdgesDirected.DIRECTED.
Returns:
List[EdgeIndex]: A list of edge indices connecting the specified source and
Expand Down Expand Up @@ -1197,20 +1204,20 @@ def contains_group(self, group: Group) -> bool:
def neighbors(
self,
node: NodeIndex,
directed: bool = True,
directed: EdgesDirected = EdgesDirected.DIRECTED,
) -> List[NodeIndex]: ...

@overload
def neighbors(
self,
node: Union[NodeIndexInputList, NodeQuery],
directed: bool = True,
directed: EdgesDirected = EdgesDirected.DIRECTED,
) -> Dict[NodeIndex, List[NodeIndex]]: ...

def neighbors(
self,
node: Union[NodeIndex, NodeIndexInputList, NodeQuery],
directed: bool = True,
directed: EdgesDirected = EdgesDirected.DIRECTED,
) -> Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]:
"""Retrieves the neighbors of the specified node(s) in the MedRecord.
Expand All @@ -1221,15 +1228,16 @@ def neighbors(
Args:
node (Union[NodeIndex, NodeIndexInputList, NodeQuery]): One or more
node indices or a node query.
directed (bool, optional): Whether to consider edges as directed.
directed (EdgesDirected, optional): Whether to consider edges as directed.
Defaults to EdgesDirected.DIRECTED.
Returns:
Union[List[NodeIndex], Dict[NodeIndex, List[NodeIndex]]]: Neighboring nodes.
"""
if isinstance(node, Callable):
node = self.select_nodes(node)

if directed:
if directed == EdgesDirected.DIRECTED:
neighbors = self._medrecord.neighbors(
node if isinstance(node, list) else [node]
)
Expand Down
3 changes: 2 additions & 1 deletion medmodels/medrecord/querying.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import sys
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Union

Expand All @@ -27,6 +26,8 @@
)

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
Expand Down
10 changes: 5 additions & 5 deletions medmodels/medrecord/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, overload
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union, overload

from medmodels._medmodels import (
PyAttributeDataType,
Expand Down Expand Up @@ -143,7 +143,7 @@ def __contains__(self, key: MedRecordAttribute) -> bool:
"""
return key in self._attributes_schema

def __iter__(self):
def __iter__(self) -> Iterator[MedRecordAttribute]:
"""Returns an iterator over the attributes schema.
Returns:
Expand Down Expand Up @@ -190,23 +190,23 @@ def __eq__(self, value: object) -> bool:

return True

def keys(self):
def keys(self): # noqa: ANN201
"""Returns the attribute keys in the schema.
Returns:
KeysView: A view object displaying a list of dictionary's keys.
"""
return self._attributes_schema.keys()

def values(self):
def values(self): # noqa: ANN201
"""Returns the attribute values in the schema.
Returns:
ValuesView: A view object displaying a list of dictionary's values.
"""
return self._attributes_schema.values()

def items(self):
def items(self): # noqa: ANN201
"""Returns the attribute key-value pairs in the schema.
Returns:
Expand Down
7 changes: 6 additions & 1 deletion medmodels/medrecord/tests/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

import pytest

import medmodels.medrecord as mr


Expand Down Expand Up @@ -60,7 +62,10 @@ def test_with_schema(self) -> None:

medrecord.add_nodes(("node1", {"attribute": 1}))

with self.assertRaises(ValueError):
with pytest.raises(
ValueError,
match="Attribute attribute of node with index node2 is of type String. Expected Integer.",
):
medrecord.add_nodes(("node2", {"attribute": "1"}))


Expand Down
2 changes: 1 addition & 1 deletion medmodels/medrecord/tests/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_union(self) -> None:
assert mr.Union(mr.String(), mr.Int()) != mr.Union(mr.Int(), mr.String())

def test_invalid_union(self) -> None:
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Union must have at least two arguments"):
mr.Union(mr.String())

def test_option(self) -> None:
Expand Down
Loading

0 comments on commit cdee623

Please sign in to comment.