Skip to content

Commit

Permalink
Fix typing on session_and_runs.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Aug 30, 2023
1 parent 37cd9a1 commit 2887b80
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 9 deletions.
12 changes: 10 additions & 2 deletions spikewrap/data_classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,23 @@ class BaseUserDict(UserDict):

base_path: Path
sub_name: str
sessions_and_runs: Dict
sessions_and_runs: Dict[str, List[str]]

def __post_init__(self) -> None:
self.data: Dict = {}
self.base_path = Path(self.base_path)
self.check_run_names_are_formatted_as_list()

def check_run_names_are_formatted_as_list(self) -> None:
""""""
"""
`sessions_and_runs` is typed as `Dict[str, List[str]]` but the
class will accept `Dict[str, Union[str, List[str]]]` and
cast here. Attempted to type with the latter, or `
MutableMapping[str, [str, Union[str, List[str]]]` but had many issues
such as https://github.com/python/mypy/issues/8136. The main thing
is we can work with `Dict[str, List[str]]` but if `Dict[str, str]` is
passed n general use it will not fail.
"""
for key, value in self.sessions_and_runs.items():
if not isinstance(value, List):
assert isinstance(
Expand Down
2 changes: 1 addition & 1 deletion spikewrap/examples/example_full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
run_full_pipeline(
base_path,
sub_name,
sessions_and_runs, # type: ignore # currently asking on mypy's Gitter
sessions_and_runs,
config_name,
sorter,
concat_sessions_for_sorting=True, # TODO: validate this at the start, in `run_full_pipeline`
Expand Down
2 changes: 1 addition & 1 deletion spikewrap/pipeline/full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def run_full_pipeline(
base_path: Union[Path, str],
sub_name: str,
sessions_and_runs: Dict[str, Union[str, List[str]]],
sessions_and_runs: Dict[str, List[str]],
config_name: str = "default",
sorter: str = "kilosort2_5",
concat_sessions_for_sorting: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions spikewrap/pipeline/load_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Dict, Union
from typing import Dict, List, Union

import spikeinterface.extractors as se

Expand All @@ -12,7 +12,7 @@
def load_data(
base_path: Union[Path, str],
sub_name: str,
sessions_and_runs: Dict,
sessions_and_runs: Dict[str, List[str]],
data_format: str = "spikeglx",
) -> PreprocessingData:
"""
Expand Down
4 changes: 2 additions & 2 deletions spikewrap/pipeline/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from pathlib import Path
from typing import Dict, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union

import spikeinterface.sorters as ss

Expand All @@ -23,7 +23,7 @@
def run_sorting(
base_path: Union[str, Path],
sub_name: str,
sessions_and_runs: Dict,
sessions_and_runs: Dict[str, List[str]],
sorter: str,
concatenate_sessions: bool = False,
concatenate_runs: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion spikewrap/pipeline/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def visualise(
data: Union[PreprocessingData, SortingData],
sessions_and_runs: Dict,
sessions_and_runs: Dict[str, List[str]],
steps: Union[List[str], str] = "all",
mode: str = "auto",
as_subplot: bool = False,
Expand Down

0 comments on commit 2887b80

Please sign in to comment.