diff --git a/kymata/datasets/data_root.py b/kymata/datasets/data_root.py index 1e997d0a..bf7062bb 100644 --- a/kymata/datasets/data_root.py +++ b/kymata/datasets/data_root.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Optional -from kymata.io.file import path_type +from kymata.io.file import PathType _DATA_PATH_ENVIRONMENT_VAR_NAME = "KYMATA_DATA_ROOT" @@ -17,13 +17,13 @@ ] -def data_root_path(data_root: Optional[path_type] = None) -> Path: +def data_root_path(data_root: Optional[PathType] = None) -> Path: # Check if the data root has been specified # Might be in an environmental variable if data_root is None: - data_root: path_type | None = getenv(_DATA_PATH_ENVIRONMENT_VAR_NAME, default=None) + data_root: PathType | None = getenv(_DATA_PATH_ENVIRONMENT_VAR_NAME, default=None) # Might have been supplied as an argument if data_root is not None: diff --git a/kymata/datasets/sample.py b/kymata/datasets/sample.py index c002fc3a..6ec2f59e 100644 --- a/kymata/datasets/sample.py +++ b/kymata/datasets/sample.py @@ -6,7 +6,7 @@ from kymata.datasets.data_root import data_root_path from kymata.entities.expression import HexelExpressionSet, SensorExpressionSet -from kymata.io.file import path_type +from kymata.io.file import PathType from kymata.io.nkg import load_expression_set _SAMPLE_DATA_DIR_NAME = "tutorial_nkg_data" @@ -62,7 +62,7 @@ def to_expressionset(self) -> HexelExpressionSet: class KymataMirror2023Q3Dataset(SampleDataset): - def __init__(self, data_root: Optional[path_type] = None, download: bool = True): + def __init__(self, data_root: Optional[PathType] = None, download: bool = True): name = "kymata_mirror_Q3_2023" super().__init__( name=name, @@ -81,7 +81,7 @@ def to_expressionset(self) -> HexelExpressionSet: class TVLInsLoudnessOnlyDataset(SampleDataset): - def __init__(self, data_root: Optional[path_type] = None, download: bool = True): + def __init__(self, data_root: Optional[PathType] = None, download: bool = True): name = "TVL_2020_ins_loudness_only" super().__init__( name=name, @@ -100,7 +100,7 @@ def to_expressionset(self) -> HexelExpressionSet: class TVLDeltaInsTC1LoudnessOnlyDataset(SampleDataset): - def __init__(self, data_root: Optional[path_type] = None, download: bool = True): + def __init__(self, data_root: Optional[PathType] = None, download: bool = True): name = "TVL_2020_delta_ins_tontop_chan1_loudness_only" super().__init__( name=name, @@ -119,7 +119,7 @@ def to_expressionset(self) -> HexelExpressionSet: class TVLDeltaInsTC1LoudnessOnlySensorsDataset(SampleDataset): - def __init__(self, data_root: Optional[path_type] = None, download: bool = True): + def __init__(self, data_root: Optional[PathType] = None, download: bool = True): name = "TVL_2020_delta_ins_tontop_chan1_loudness_only_sensors" super().__init__( name=name, diff --git a/kymata/io/config.py b/kymata/io/config.py index b3437e0c..c5887292 100644 --- a/kymata/io/config.py +++ b/kymata/io/config.py @@ -3,10 +3,10 @@ import yaml from kymata.datasets.data_root import data_root_path -from kymata.io.file import path_type, file_type, open_or_use +from kymata.io.file import PathType, FileType, open_or_use -def load_config(config_location: path_type | file_type): +def load_config(config_location: PathType | FileType): """Load config parameters""" with open_or_use(config_location) as stream: return yaml.safe_load(stream) diff --git a/kymata/io/file.py b/kymata/io/file.py index d4dbe649..bb3a757f 100644 --- a/kymata/io/file.py +++ b/kymata/io/file.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import TextIO, BinaryIO, Union -file_type = Union[TextIO, BinaryIO] -path_type = Union[str, Path] +FileType = Union[TextIO, BinaryIO] +PathType = Union[str, Path] @contextmanager -def open_or_use(path_or_file: path_type | file_type, mode: str = "r") -> file_type: +def open_or_use(path_or_file: PathType | FileType, mode: str = "r") -> FileType: """ If passed a path, will open it and return the file handle, and close when done. if passed a file handle, will keep it open, and return it when done. diff --git a/kymata/io/functions.py b/kymata/io/functions.py index 688a9f84..2a04ee1d 100644 --- a/kymata/io/functions.py +++ b/kymata/io/functions.py @@ -5,10 +5,10 @@ from scipy.io import loadmat from kymata.entities.functions import Function -from kymata.io.file import path_type +from kymata.io.file import PathType -def load_function(function_path_without_suffix: path_type, func_name: str, n_derivatives: int = 0, bruce_neurons: tuple = (0, 10)) -> Function: +def load_function(function_path_without_suffix: PathType, func_name: str, n_derivatives: int = 0, bruce_neurons: tuple = (0, 10)) -> Function: function_path_without_suffix = Path(function_path_without_suffix) func: NDArray if 'neurogram' in func_name: diff --git a/kymata/io/nkg.py b/kymata/io/nkg.py index 602cfbd6..26d51d39 100644 --- a/kymata/io/nkg.py +++ b/kymata/io/nkg.py @@ -16,7 +16,7 @@ SensorExpressionSet from kymata.math.p_values import p_to_logp from kymata.entities.sparse_data import expand_dims -from kymata.io.file import path_type, file_type, open_or_use +from kymata.io.file import PathType, FileType, open_or_use class _Keys(StrEnum): @@ -64,13 +64,13 @@ def block_names(self) -> list[str]: CURRENT_VERSION = "0.4" -def file_version(from_path_or_file: path_type | file_type) -> version.Version: +def file_version(from_path_or_file: PathType | FileType) -> version.Version: with open_or_use(from_path_or_file, mode="rb") as archive, ZipFile(archive, "r") as zf: with TextIOWrapper(zf.open("_metadata/format-version.txt"), encoding="utf-8") as f: return version.parse(str(f.read()).strip()) -def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionSet: +def load_expression_set(from_path_or_file: PathType | FileType) -> ExpressionSet: _v, data_dict = _load_data(from_path_or_file) type_identifier = data_dict[_Keys.expressionset_type] @@ -97,7 +97,7 @@ def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionS def save_expression_set(expression_set: ExpressionSet, - to_path_or_file: path_type | file_type, + to_path_or_file: PathType | FileType, compression=ZIP_LZMA, overwrite: bool = False): """ @@ -128,7 +128,7 @@ def save_expression_set(expression_set: ExpressionSet, zf.writestr(f"/{block_name}/coo-shape.txt", "\n".join(str(x) for x in expression_set._data[block_name].data.shape)) -def _load_data(from_path_or_file: path_type | file_type) -> tuple[version.Version, dict[str, Any]]: +def _load_data(from_path_or_file: PathType | FileType) -> tuple[version.Version, dict[str, Any]]: """ Load an ExpressionSet from an open file, or the file at the specified path. @@ -252,7 +252,7 @@ def _load_data(from_path_or_file: path_type | file_type) -> tuple[version.Versio # noinspection DuplicatedCode -def _load_data_current(from_path_or_file: path_type | file_type) -> dict[str, Any]: +def _load_data_current(from_path_or_file: PathType | FileType) -> dict[str, Any]: """ Load data from current version """ diff --git a/kymata/io/nkg_compatibility.py b/kymata/io/nkg_compatibility.py index f69cbaa4..224c247e 100644 --- a/kymata/io/nkg_compatibility.py +++ b/kymata/io/nkg_compatibility.py @@ -17,11 +17,11 @@ from numpy.typing import NDArray from kymata.entities.datatypes import LatencyDType, FunctionNameDType, HexelDType -from kymata.io.file import path_type, file_type, open_or_use +from kymata.io.file import PathType, FileType, open_or_use # noinspection DuplicatedCode -def _load_data_0_3(from_path_or_file: path_type | file_type) -> dict[str, Any]: +def _load_data_0_3(from_path_or_file: PathType | FileType) -> dict[str, Any]: """ This is a function which loads data format 0.3. @@ -60,7 +60,7 @@ def _load_data_0_3(from_path_or_file: path_type | file_type) -> dict[str, Any]: # noinspection DuplicatedCode -def _load_data_0_2(from_path_or_file: path_type | file_type) -> dict[str, Any]: +def _load_data_0_2(from_path_or_file: PathType | FileType) -> dict[str, Any]: """ This is a function which loads data format 0.2. @@ -99,7 +99,7 @@ def _load_data_0_2(from_path_or_file: path_type | file_type) -> dict[str, Any]: # noinspection DuplicatedCode -def _load_data_0_1(from_path_or_file: path_type | file_type) -> dict[str, Any]: +def _load_data_0_1(from_path_or_file: PathType | FileType) -> dict[str, Any]: """ This is a function which loads data format 0.1.