Commit 679098b (1 parent: 0c7474d), showing 6 changed files with 205 additions and 0 deletions.
@@ -0,0 +1,124 @@
import fnmatch
from collections import UserDict
from collections.abc import ItemsView, KeysView, ValuesView
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import Callable, Dict, List, Literal, Tuple
@dataclass
class BaseUserDict(UserDict):
    """
    Base class for `PreprocessingData` and `SortingData`,
    used for checking and formatting `base_path`, `sub_name`
    and `run_names`. The layout of the `rawdata` and
    `derivatives` folders is identical up to the run
    folder, allowing this class to be used for both
    preprocessing and sorting.

    This base UserDict implements the keys(), values()
    and items() convenience functions.
    """

    base_path: Path
    sub_name: str
    sessions_and_runs: Dict

    def __post_init__(self) -> None:
        self.data: Dict = {}
        self.base_path = Path(self.base_path)
        self.check_run_names_are_formatted_as_list()
    def check_run_names_are_formatted_as_list(self) -> None:
        """
        Ensure the run names for every session are stored as a list,
        wrapping a single run-name string in a list.
        """
        for key, value in self.sessions_and_runs.items():
            if not isinstance(value, List):
                assert isinstance(
                    value, str
                ), "Run names must be string or list of strings"
                self.sessions_and_runs[key] = [value]
    def preprocessing_sessions_and_runs(self) -> List[Tuple[str, str]]:
        """
        Return a list of (session name, run name) tuples covering every
        run, in the order they appear in `sessions_and_runs`.
        """
        ordered_ses_names = list(
            chain(*[[ses] * len(runs) for ses, runs in self.sessions_and_runs.items()])
        )
        ordered_run_names = list(
            chain(*[runs for runs in self.sessions_and_runs.values()])
        )

        return list(zip(ordered_ses_names, ordered_run_names))
    def _validate_inputs(
        self,
        top_level_folder: Literal["rawdata", "derivatives"],
        get_top_level_folder: Callable,
        get_sub_level_folder: Callable,
        get_sub_path: Callable,
        get_run_path: Callable,
    ) -> None:
        """
        Check that the rawdata / derivatives folder, the subject folder
        and every session / run folder in `sessions_and_runs` exist, and
        that each run name contains exactly one SpikeGLX gate index,
        which must be zero.

        Parameters
        ----------
        top_level_folder : Literal["rawdata", "derivatives"]
            Name of the top-level folder being validated.
        get_top_level_folder : Callable
            Returns the path to the top-level folder.
        get_sub_level_folder : Callable
            Returns the path to the subject folder.
        get_sub_path : Callable
            Returns the path to a session folder, given a session name.
        get_run_path : Callable
            Returns the path to a run folder, given a session and run name.
        """
        assert get_top_level_folder().is_dir(), (
            f"Ensure there is a folder in base path called '"
            f"{top_level_folder}'.\n"
            f"No {top_level_folder} directory found at "
            f"{get_top_level_folder()}\n"
            f"where subject-level folders must be placed."
        )

        assert get_sub_level_folder().is_dir(), (
            f"Subject directory not found. {self.sub_name} "
            f"is not a folder in {get_top_level_folder()}"
        )

        for ses_name in self.sessions_and_runs.keys():
            assert (
                ses_path := get_sub_path(ses_name)
            ).is_dir(), f"{ses_name} was not found at folder path {ses_path}"

            for run_name in self.sessions_and_runs[ses_name]:
                assert (run_path := get_run_path(ses_name, run_name)).is_dir(), (
                    f"The run folder {run_path.stem} cannot be found at "
                    f"file path {run_path.parent}."
                )

                gate_str = fnmatch.filter(run_name.split("_"), "g?")

                assert len(gate_str) > 0, (
                    f"The SpikeGLX gate index should be in the run name. "
                    f"It was not found in the name {run_name}."
                    f"\nEnsure the gate number is in the SpikeGLX-output filename."
                )

                assert len(gate_str) == 1, (
                    f"The SpikeGLX gate appears in the name "
                    f"{run_name} more than once"
                )

                assert int(gate_str[0][1:]) == 0, (
                    f"Gate with index larger than 0 is not supported. This is found "
                    f"in run name {run_name}. "
                )
    def keys(self) -> KeysView:
        return self.data.keys()

    def items(self) -> ItemsView:
        return self.data.items()

    def values(self) -> ValuesView:
        return self.data.values()
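As an illustration, a minimal usage sketch for `BaseUserDict`; the base path and session/run names below are hypothetical placeholders, not values from this commit, and the class is assumed to be importable as defined above:

# Hypothetical usage sketch (placeholder path and names).
data = BaseUserDict(
    base_path=Path("/data/my_project"),
    sub_name="sub-001",
    sessions_and_runs={"ses-001": "run-001_g0"},  # a bare string is accepted
)

# __post_init__ wraps the single run name in a list ...
assert data.sessions_and_runs == {"ses-001": ["run-001_g0"]}

# ... and (session, run) pairs are returned in processing order.
assert data.preprocessing_sessions_and_runs() == [("ses-001", "run-001_g0")]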
@@ -0,0 +1,58 @@
import shutil
from dataclasses import dataclass
from typing import Dict

import spikeinterface

from ..utils import utils
from .base import BaseUserDict
@dataclass
class PreprocessingData(BaseUserDict):
    """
    Dictionary to store SpikeInterface preprocessing recordings.

    Details of the preprocessing steps are held in the dictionary keys,
    e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average,
    and the recording objects are held in the values. These are generated
    by the `pipeline.preprocess.run_preprocessing()` function.

    The class manages paths to raw data and preprocessing output,
    and defines methods to dump key information and the SpikeInterface
    binary to disk. Note that SI preprocessing is lazy: preprocessing is
    only run when recording.get_traces() is called or the data are saved
    to binary.

    Parameters
    ----------
    base_path : Union[Path, str]
        Path to the folder containing the rawdata directory with
        subject folders.
    sub_name : str
        'subject' to preprocess. The subject top level dir should
        reside in base_path/rawdata/.
    sessions_and_runs : Dict
        Mapping of session names to a SpikeGLX run name or list of run
        names (each run name must contain the gate index, e.g. `_g0`).
    """
    def __post_init__(self) -> None:
        super().__post_init__()
        self._validate_rawdata_inputs()

        self.sync: Dict = {}

        for ses_name, run_name in self.preprocessing_sessions_and_runs():
            utils.update(self.data, ses_name, run_name, {"0-raw": None})
            utils.update(self.sync, ses_name, run_name, None)

    def _validate_rawdata_inputs(self) -> None:
        self._validate_inputs(
            "rawdata",
            self.get_rawdata_top_level_path,
            self.get_rawdata_sub_path,
            self.get_rawdata_ses_path,
            self.get_rawdata_run_path,
        )
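The `utils.update` helper is not part of this diff; the sketch below shows one plausible behaviour inferred only from how it is called above, as an assumption rather than the actual spikewrap implementation:

# Assumed behaviour of utils.update (inferred from its call sites, not the real code).
def update(data: Dict, ses_name: str, run_name: str, value) -> None:
    data.setdefault(ses_name, {})[run_name] = value

# Under this assumption, after __post_init__ a single-run dataset would hold:
#   self.data == {"ses-001": {"run-001_g0": {"0-raw": None}}}
#   self.sync == {"ses-001": {"run-001_g0": None}}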
@@ -0,0 +1,23 @@
from pathlib import Path

from spikewrap.pipeline.load_data import load_data

base_path = Path(
    r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
)

sub_name = "sub-1119617"
sessions_and_runs = {
    "ses-001": [
        "run-001_1119617_LSE1_shank12_g0",
        "run-002_made_up_g0",
    ],
    "ses-002": [
        "run-001_1119617_pretest1_shank12_g0",
    ],
    "ses-003": [
        "run-002_1119617_pretest1_shank12_g0",
    ],
}

loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx")
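Assuming `load_data` returns a `PreprocessingData` instance built on the classes above, a hypothetical follow-on step could iterate the loaded runs in processing order:

# Hypothetical follow-on usage of the loaded object.
for ses_name, run_name in loaded_data.preprocessing_sessions_and_runs():
    print(ses_name, run_name, list(loaded_data[ses_name][run_name].keys()))
    # e.g. ses-001 run-001_1119617_LSE1_shank12_g0 ['0-raw']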
The remaining three changed files are empty.