From 01ec7959304288048489e55d9f6d3a50ae56ec35 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 31 Aug 2023 00:07:21 +0100 Subject: [PATCH] Fix typing on `session_and_runs`. --- spikewrap/data_classes/base.py | 12 ++++++++++-- spikewrap/examples/example_full_pipeline.py | 2 +- spikewrap/examples/example_sort.py | 2 +- spikewrap/pipeline/full_pipeline.py | 2 +- spikewrap/pipeline/load_data.py | 4 ++-- spikewrap/pipeline/sort.py | 4 ++-- spikewrap/pipeline/visualise.py | 2 +- 7 files changed, 18 insertions(+), 10 deletions(-) diff --git a/spikewrap/data_classes/base.py b/spikewrap/data_classes/base.py index 629ecfa..16796f8 100644 --- a/spikewrap/data_classes/base.py +++ b/spikewrap/data_classes/base.py @@ -24,7 +24,7 @@ class BaseUserDict(UserDict): base_path: Path sub_name: str - sessions_and_runs: Dict + sessions_and_runs: Dict[str, List[str]] def __post_init__(self) -> None: self.data: Dict = {} @@ -32,7 +32,15 @@ def __post_init__(self) -> None: self.check_run_names_are_formatted_as_list() def check_run_names_are_formatted_as_list(self) -> None: - """""" + """ + `sessions_and_runs` is typed as `Dict[str, List[str]]` but the + class will accept `Dict[str, Union[str, List[str]]]` and + cast here. Attempted to type with the latter, or ` + MutableMapping[str, [str, Union[str, List[str]]]` but had many issues + such as https://github.com/python/mypy/issues/8136. The main thing + is we can work with `Dict[str, List[str]]` but if `Dict[str, str]` is + passed n general use it will not fail. + """ for key, value in self.sessions_and_runs.items(): if not isinstance(value, List): assert isinstance( diff --git a/spikewrap/examples/example_full_pipeline.py b/spikewrap/examples/example_full_pipeline.py index a0ebca7..e34c9e7 100644 --- a/spikewrap/examples/example_full_pipeline.py +++ b/spikewrap/examples/example_full_pipeline.py @@ -36,7 +36,7 @@ run_full_pipeline( base_path, sub_name, - sessions_and_runs, # type: ignore # currently asking on mypy's Gitter + sessions_and_runs, config_name, sorter, concat_sessions_for_sorting=True, # TODO: validate this at the start, in `run_full_pipeline` diff --git a/spikewrap/examples/example_sort.py b/spikewrap/examples/example_sort.py index f25ef8f..8432bd5 100644 --- a/spikewrap/examples/example_sort.py +++ b/spikewrap/examples/example_sort.py @@ -7,7 +7,7 @@ ) sub_name = "sub-1119617" sessions_and_runs = { - "ses-001": "run-001_1119617_LSE1_shank12_g0", + "ses-001": ["run-001_1119617_LSE1_shank12_g0"], } diff --git a/spikewrap/pipeline/full_pipeline.py b/spikewrap/pipeline/full_pipeline.py index 0978e1a..c3e0924 100644 --- a/spikewrap/pipeline/full_pipeline.py +++ b/spikewrap/pipeline/full_pipeline.py @@ -18,7 +18,7 @@ def run_full_pipeline( base_path: Union[Path, str], sub_name: str, - sessions_and_runs: Dict[str, Union[str, List[str]]], + sessions_and_runs: Dict[str, List[str]], config_name: str = "default", sorter: str = "kilosort2_5", concat_sessions_for_sorting: bool = False, diff --git a/spikewrap/pipeline/load_data.py b/spikewrap/pipeline/load_data.py index 8a20834..42ac621 100644 --- a/spikewrap/pipeline/load_data.py +++ b/spikewrap/pipeline/load_data.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Dict, Union +from typing import Dict, List, Union import spikeinterface.extractors as se @@ -12,7 +12,7 @@ def load_data( base_path: Union[Path, str], sub_name: str, - sessions_and_runs: Dict, + sessions_and_runs: Dict[str, List[str]], data_format: str = "spikeglx", ) -> PreprocessingData: """ diff --git a/spikewrap/pipeline/sort.py b/spikewrap/pipeline/sort.py index 560c6f6..03563b4 100644 --- a/spikewrap/pipeline/sort.py +++ b/spikewrap/pipeline/sort.py @@ -2,7 +2,7 @@ import os from pathlib import Path -from typing import Dict, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union import spikeinterface.sorters as ss @@ -23,7 +23,7 @@ def run_sorting( base_path: Union[str, Path], sub_name: str, - sessions_and_runs: Dict, + sessions_and_runs: Dict[str, List[str]], sorter: str, concatenate_sessions: bool = False, concatenate_runs: bool = False, diff --git a/spikewrap/pipeline/visualise.py b/spikewrap/pipeline/visualise.py index 386af35..72096d2 100644 --- a/spikewrap/pipeline/visualise.py +++ b/spikewrap/pipeline/visualise.py @@ -20,7 +20,7 @@ def visualise( data: Union[PreprocessingData, SortingData], - sessions_and_runs: Dict, + sessions_and_runs: Dict[str, List[str]], steps: Union[List[str], str] = "all", mode: str = "auto", as_subplot: bool = False,