diff --git a/spikewrap/data_classes/base.py b/spikewrap/data_classes/base.py new file mode 100644 index 0000000..f8f57ab --- /dev/null +++ b/spikewrap/data_classes/base.py @@ -0,0 +1,124 @@ +import fnmatch +from collections import UserDict +from collections.abc import ItemsView, KeysView, ValuesView +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from typing import Callable, Dict, List, Literal + + +@dataclass +class BaseUserDict(UserDict): + """ + Base class for `PreprocessingData` and `SortingData` + used for checking and formatting `base_path`, `sub_name` + and `run_names`. The layout of the `rawdata` and + `derivatives` folder is identical up to the run + folder, allowing use of this class for + preprocessing and sorting. + + Base UserDict that implements the + keys(), values() and items() convenience functions.""" + + base_path: Path + sub_name: str + sessions_and_runs: Dict + + def __post_init__(self) -> None: + self.data: Dict = {} + self.base_path = Path(self.base_path) + self.check_run_names_are_formatted_as_list() + + def check_run_names_are_formatted_as_list(self) -> None: + """""" + for key, value in self.sessions_and_runs.items(): + if not isinstance(value, List): + assert isinstance( + value, str + ), "Run names must be string or list of strings" + self.sessions_and_runs[key] = [value] + + def preprocessing_sessions_and_runs(self): # TODO: type hint + """""" + ordered_ses_names = list( + chain(*[[ses] * len(runs) for ses, runs in self.sessions_and_runs.items()]) + ) + ordered_run_names = list( + chain(*[runs for runs in self.sessions_and_runs.values()]) + ) + + return list(zip(ordered_ses_names, ordered_run_names)) + + def _validate_inputs( + self, + top_level_folder: Literal["rawdata", "derivatives"], + get_top_level_folder: Callable, + get_sub_level_folder: Callable, + get_sub_path: Callable, + get_run_path: Callable, + ) -> None: + """ + Check the rawdata / derivatives path, subject path exists + and ensure run_names is a list of strings. + + Parameters + ---------- + run_names : List[str] + List of run names to process, in order they should be + processed / concatenated. + + Returns + ------- + run_names : List[str] + Validated `run_names` as a List. + """ + assert get_top_level_folder().is_dir(), ( + f"Ensure there is a folder in base path called '" + f"{top_level_folder}'.\n" + f"No {top_level_folder} directory found at " + f"{get_top_level_folder()}\n" + f"where subject-level folders must be placed." + ) + + assert get_sub_level_folder().is_dir(), ( + f"Subject directory not found. {self.sub_name} " + f"is not a folder in {get_top_level_folder()}" + ) + + for ses_name in self.sessions_and_runs.keys(): + assert ( + ses_path := get_sub_path(ses_name) + ).is_dir(), f"{ses_name} was not found at folder path {ses_path}" + + for run_name in self.sessions_and_runs[ses_name]: + assert (run_path := get_run_path(ses_name, run_name)).is_dir(), ( + f"The run folder {run_path.stem} cannot be found at " + f"file path {run_path.parent}." + ) + + gate_str = fnmatch.filter(run_name.split("_"), "g?") + + assert len(gate_str) > 0, ( + f"The SpikeGLX gate index should be in the run name. " + f"It was not found in the name {run_name}." + f"\nEnsure the gate number is in the SpikeGLX-output filename." + ) + + assert len(gate_str) == 1, ( + f"The SpikeGLX gate appears in the name " + f"{run_name} more than once" + ) + + assert int(gate_str[0][1:]) == 0, ( + f"Gate with index larger than 0 is not supported. This is found " + f"in run name {run_name}. " + ) + + def keys(self) -> KeysView: + return self.data.keys() + + def items(self) -> ItemsView: + return self.data.items() + + def values(self) -> ValuesView: + return self.data.values() diff --git a/spikewrap/data_classes/preprocessing.py b/spikewrap/data_classes/preprocessing.py new file mode 100644 index 0000000..632ca83 --- /dev/null +++ b/spikewrap/data_classes/preprocessing.py @@ -0,0 +1,58 @@ +import shutil +from dataclasses import dataclass +from typing import Dict + +import spikeinterface + +from ..utils import utils +from .base import BaseUserDict + + +@dataclass +class PreprocessingData(BaseUserDict): + """ + Dictionary to store SpikeInterface preprocessing recordings. + + Details on the preprocessing steps are held in the dictionary keys e.g. + e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average + and recording objects are held in the value. These are generated + by the `pipeline.preprocess.run_preprocessing()` function. + + The class manages paths to raw data and preprocessing output, + as defines methods to dump key information and the SpikeInterface + binary to disk. Note that SI preprocessing is lazy and + preprocessing only run when the recording.get_traces() + is called, or the data is saved to binary. + + Parameters + ---------- + base_path : Union[Path, str] + Path where the rawdata folder containing subjects. + + sub_name : str + 'subject' to preprocess. The subject top level dir should + reside in base_path/rawdata/. + + run_names : Union[List[str], str] + The SpikeGLX run name (i.e. not including the gate index) + or list of run names. + """ + + def __post_init__(self) -> None: + super().__post_init__() + self._validate_rawdata_inputs() + + self.sync: Dict = {} + + for ses_name, run_name in self.preprocessing_sessions_and_runs(): + utils.update(self.data, ses_name, run_name, {"0-raw": None}) + utils.update(self.sync, ses_name, run_name, None) + + def _validate_rawdata_inputs(self) -> None: + self._validate_inputs( + "rawdata", + self.get_rawdata_top_level_path, + self.get_rawdata_sub_path, + self.get_rawdata_ses_path, + self.get_rawdata_run_path, + ) diff --git a/spikewrap/examples/load_data.py b/spikewrap/examples/load_data.py new file mode 100644 index 0000000..0b04e2f --- /dev/null +++ b/spikewrap/examples/load_data.py @@ -0,0 +1,23 @@ +from pathlib import Path + +from spikewrap.pipeline.load_data import load_data + +base_path = Path( + r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" +) + +sub_name = "sub-1119617" +sessions_and_runs = { + "ses-001": [ + "run-001_1119617_LSE1_shank12_g0", + "run-002_made_up_g0", + ], + "ses-002": [ + "run-001_1119617_pretest1_shank12_g0", + ], + "ses-003": [ + "run-002_1119617_pretest1_shank12_g0", + ], +} + +loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx") \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_integration/__init__.py b/tests/test_integration/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_unit/__init__.py b/tests/test_unit/__init__.py deleted file mode 100644 index e69de29..0000000