Skip to content

Commit

Permalink
Rename type alias
Browse files Browse the repository at this point in the history
  • Loading branch information
caiw committed Mar 15, 2024
1 parent 6c8a3c0 commit 18f614d
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 25 deletions.
6 changes: 3 additions & 3 deletions kymata/datasets/data_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from typing import Optional

from kymata.io.file import path_type
from kymata.io.file import PathType


_DATA_PATH_ENVIRONMENT_VAR_NAME = "KYMATA_DATA_ROOT"
Expand All @@ -17,13 +17,13 @@
]


def data_root_path(data_root: Optional[path_type] = None) -> Path:
def data_root_path(data_root: Optional[PathType] = None) -> Path:

# Check if the data root has been specified

# Might be in an environmental variable
if data_root is None:
data_root: path_type | None = getenv(_DATA_PATH_ENVIRONMENT_VAR_NAME, default=None)
data_root: PathType | None = getenv(_DATA_PATH_ENVIRONMENT_VAR_NAME, default=None)

# Might have been supplied as an argument
if data_root is not None:
Expand Down
10 changes: 5 additions & 5 deletions kymata/datasets/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from kymata.datasets.data_root import data_root_path
from kymata.entities.expression import HexelExpressionSet, SensorExpressionSet
from kymata.io.file import path_type
from kymata.io.file import PathType
from kymata.io.nkg import load_expression_set

_SAMPLE_DATA_DIR_NAME = "tutorial_nkg_data"
Expand Down Expand Up @@ -62,7 +62,7 @@ def to_expressionset(self) -> HexelExpressionSet:


class KymataMirror2023Q3Dataset(SampleDataset):
def __init__(self, data_root: Optional[path_type] = None, download: bool = True):
def __init__(self, data_root: Optional[PathType] = None, download: bool = True):
name = "kymata_mirror_Q3_2023"
super().__init__(
name=name,
Expand All @@ -81,7 +81,7 @@ def to_expressionset(self) -> HexelExpressionSet:


class TVLInsLoudnessOnlyDataset(SampleDataset):
def __init__(self, data_root: Optional[path_type] = None, download: bool = True):
def __init__(self, data_root: Optional[PathType] = None, download: bool = True):
name = "TVL_2020_ins_loudness_only"
super().__init__(
name=name,
Expand All @@ -100,7 +100,7 @@ def to_expressionset(self) -> HexelExpressionSet:


class TVLDeltaInsTC1LoudnessOnlyDataset(SampleDataset):
def __init__(self, data_root: Optional[path_type] = None, download: bool = True):
def __init__(self, data_root: Optional[PathType] = None, download: bool = True):
name = "TVL_2020_delta_ins_tontop_chan1_loudness_only"
super().__init__(
name=name,
Expand All @@ -119,7 +119,7 @@ def to_expressionset(self) -> HexelExpressionSet:


class TVLDeltaInsTC1LoudnessOnlySensorsDataset(SampleDataset):
def __init__(self, data_root: Optional[path_type] = None, download: bool = True):
def __init__(self, data_root: Optional[PathType] = None, download: bool = True):
name = "TVL_2020_delta_ins_tontop_chan1_loudness_only_sensors"
super().__init__(
name=name,
Expand Down
4 changes: 2 additions & 2 deletions kymata/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import yaml

from kymata.datasets.data_root import data_root_path
from kymata.io.file import path_type, file_type, open_or_use
from kymata.io.file import PathType, FileType, open_or_use


def load_config(config_location: path_type | file_type):
def load_config(config_location: PathType | FileType):
"""Load config parameters"""
with open_or_use(config_location) as stream:
return yaml.safe_load(stream)
Expand Down
6 changes: 3 additions & 3 deletions kymata/io/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from pathlib import Path
from typing import TextIO, BinaryIO, Union

file_type = Union[TextIO, BinaryIO]
path_type = Union[str, Path]
FileType = Union[TextIO, BinaryIO]
PathType = Union[str, Path]


@contextmanager
def open_or_use(path_or_file: path_type | file_type, mode: str = "r") -> file_type:
def open_or_use(path_or_file: PathType | FileType, mode: str = "r") -> FileType:
"""
If passed a path, will open it and return the file handle, and close when done.
if passed a file handle, will keep it open, and return it when done.
Expand Down
4 changes: 2 additions & 2 deletions kymata/io/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from scipy.io import loadmat

from kymata.entities.functions import Function
from kymata.io.file import path_type
from kymata.io.file import PathType


def load_function(function_path_without_suffix: path_type, func_name: str, n_derivatives: int = 0, bruce_neurons: tuple = (0, 10)) -> Function:
def load_function(function_path_without_suffix: PathType, func_name: str, n_derivatives: int = 0, bruce_neurons: tuple = (0, 10)) -> Function:
function_path_without_suffix = Path(function_path_without_suffix)
func: NDArray
if 'neurogram' in func_name:
Expand Down
12 changes: 6 additions & 6 deletions kymata/io/nkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SensorExpressionSet
from kymata.math.p_values import p_to_logp
from kymata.entities.sparse_data import expand_dims
from kymata.io.file import path_type, file_type, open_or_use
from kymata.io.file import PathType, FileType, open_or_use


class _Keys(StrEnum):
Expand Down Expand Up @@ -64,13 +64,13 @@ def block_names(self) -> list[str]:
CURRENT_VERSION = "0.4"


def file_version(from_path_or_file: path_type | file_type) -> version.Version:
def file_version(from_path_or_file: PathType | FileType) -> version.Version:
with open_or_use(from_path_or_file, mode="rb") as archive, ZipFile(archive, "r") as zf:
with TextIOWrapper(zf.open("_metadata/format-version.txt"), encoding="utf-8") as f:
return version.parse(str(f.read()).strip())


def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionSet:
def load_expression_set(from_path_or_file: PathType | FileType) -> ExpressionSet:
_v, data_dict = _load_data(from_path_or_file)

type_identifier = data_dict[_Keys.expressionset_type]
Expand All @@ -97,7 +97,7 @@ def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionS


def save_expression_set(expression_set: ExpressionSet,
to_path_or_file: path_type | file_type,
to_path_or_file: PathType | FileType,
compression=ZIP_LZMA,
overwrite: bool = False):
"""
Expand Down Expand Up @@ -128,7 +128,7 @@ def save_expression_set(expression_set: ExpressionSet,
zf.writestr(f"/{block_name}/coo-shape.txt", "\n".join(str(x) for x in expression_set._data[block_name].data.shape))


def _load_data(from_path_or_file: path_type | file_type) -> tuple[version.Version, dict[str, Any]]:
def _load_data(from_path_or_file: PathType | FileType) -> tuple[version.Version, dict[str, Any]]:
"""
Load an ExpressionSet from an open file, or the file at the specified path.
Expand Down Expand Up @@ -252,7 +252,7 @@ def _load_data(from_path_or_file: path_type | file_type) -> tuple[version.Versio


# noinspection DuplicatedCode
def _load_data_current(from_path_or_file: path_type | file_type) -> dict[str, Any]:
def _load_data_current(from_path_or_file: PathType | FileType) -> dict[str, Any]:
"""
Load data from current version
"""
Expand Down
8 changes: 4 additions & 4 deletions kymata/io/nkg_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from numpy.typing import NDArray

from kymata.entities.datatypes import LatencyDType, FunctionNameDType, HexelDType
from kymata.io.file import path_type, file_type, open_or_use
from kymata.io.file import PathType, FileType, open_or_use


# noinspection DuplicatedCode
def _load_data_0_3(from_path_or_file: path_type | file_type) -> dict[str, Any]:
def _load_data_0_3(from_path_or_file: PathType | FileType) -> dict[str, Any]:
"""
This is a function which loads data format 0.3.
Expand Down Expand Up @@ -60,7 +60,7 @@ def _load_data_0_3(from_path_or_file: path_type | file_type) -> dict[str, Any]:


# noinspection DuplicatedCode
def _load_data_0_2(from_path_or_file: path_type | file_type) -> dict[str, Any]:
def _load_data_0_2(from_path_or_file: PathType | FileType) -> dict[str, Any]:
"""
This is a function which loads data format 0.2.
Expand Down Expand Up @@ -99,7 +99,7 @@ def _load_data_0_2(from_path_or_file: path_type | file_type) -> dict[str, Any]:


# noinspection DuplicatedCode
def _load_data_0_1(from_path_or_file: path_type | file_type) -> dict[str, Any]:
def _load_data_0_1(from_path_or_file: PathType | FileType) -> dict[str, Any]:
"""
This is a function which loads data format 0.1.
Expand Down

0 comments on commit 18f614d

Please sign in to comment.