Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing on session_and_runs. #102

Merged
merged 1 commit into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions spikewrap/data_classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,23 @@ class BaseUserDict(UserDict):

base_path: Path
sub_name: str
sessions_and_runs: Dict
sessions_and_runs: Dict[str, List[str]]

def __post_init__(self) -> None:
self.data: Dict = {}
self.base_path = Path(self.base_path)
self.check_run_names_are_formatted_as_list()

def check_run_names_are_formatted_as_list(self) -> None:
""""""
"""
`sessions_and_runs` is typed as `Dict[str, List[str]]` but the
class will accept `Dict[str, Union[str, List[str]]]` and
cast here. Attempted to type with the latter, or `
MutableMapping[str, [str, Union[str, List[str]]]` but had many issues
such as https://github.com/python/mypy/issues/8136. The main thing
is we can work with `Dict[str, List[str]]` but if `Dict[str, str]` is
passed n general use it will not fail.
"""
for key, value in self.sessions_and_runs.items():
if not isinstance(value, List):
assert isinstance(
Expand Down
2 changes: 1 addition & 1 deletion spikewrap/examples/example_full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
run_full_pipeline(
base_path,
sub_name,
sessions_and_runs, # type: ignore # currently asking on mypy's Gitter
sessions_and_runs,
config_name,
sorter,
concat_sessions_for_sorting=True, # TODO: validate this at the start, in `run_full_pipeline`
Expand Down
2 changes: 1 addition & 1 deletion spikewrap/examples/example_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
sub_name = "sub-1119617"
sessions_and_runs = {
"ses-001": "run-001_1119617_LSE1_shank12_g0",
"ses-001": ["run-001_1119617_LSE1_shank12_g0"],
}


Expand Down
2 changes: 1 addition & 1 deletion spikewrap/pipeline/full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def run_full_pipeline(
base_path: Union[Path, str],
sub_name: str,
sessions_and_runs: Dict[str, Union[str, List[str]]],
sessions_and_runs: Dict[str, List[str]],
config_name: str = "default",
sorter: str = "kilosort2_5",
concat_sessions_for_sorting: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions spikewrap/pipeline/load_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Dict, Union
from typing import Dict, List, Union

import spikeinterface.extractors as se

Expand All @@ -12,7 +12,7 @@
def load_data(
base_path: Union[Path, str],
sub_name: str,
sessions_and_runs: Dict,
sessions_and_runs: Dict[str, List[str]],
data_format: str = "spikeglx",
) -> PreprocessingData:
"""
Expand Down
4 changes: 2 additions & 2 deletions spikewrap/pipeline/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from pathlib import Path
from typing import Dict, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union

import spikeinterface.sorters as ss

Expand All @@ -23,7 +23,7 @@
def run_sorting(
base_path: Union[str, Path],
sub_name: str,
sessions_and_runs: Dict,
sessions_and_runs: Dict[str, List[str]],
sorter: str,
concatenate_sessions: bool = False,
concatenate_runs: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion spikewrap/pipeline/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def visualise(
data: Union[PreprocessingData, SortingData],
sessions_and_runs: Dict,
sessions_and_runs: Dict[str, List[str]],
steps: Union[List[str], str] = "all",
mode: str = "auto",
as_subplot: bool = False,
Expand Down