Skip to content

Commit

Permalink
Add load_data.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Sep 13, 2023
1 parent 0c7474d commit 679098b
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 0 deletions.
124 changes: 124 additions & 0 deletions spikewrap/data_classes/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import fnmatch
from collections import UserDict
from collections.abc import ItemsView, KeysView, ValuesView
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import Callable, Dict, List, Literal


@dataclass
class BaseUserDict(UserDict):
"""
Base class for `PreprocessingData` and `SortingData`
used for checking and formatting `base_path`, `sub_name`
and `run_names`. The layout of the `rawdata` and
`derivatives` folder is identical up to the run
folder, allowing use of this class for
preprocessing and sorting.
Base UserDict that implements the
keys(), values() and items() convenience functions."""

base_path: Path
sub_name: str
sessions_and_runs: Dict

def __post_init__(self) -> None:
self.data: Dict = {}
self.base_path = Path(self.base_path)
self.check_run_names_are_formatted_as_list()

def check_run_names_are_formatted_as_list(self) -> None:
""""""
for key, value in self.sessions_and_runs.items():
if not isinstance(value, List):
assert isinstance(
value, str
), "Run names must be string or list of strings"
self.sessions_and_runs[key] = [value]

def preprocessing_sessions_and_runs(self): # TODO: type hint
""""""
ordered_ses_names = list(
chain(*[[ses] * len(runs) for ses, runs in self.sessions_and_runs.items()])
)
ordered_run_names = list(
chain(*[runs for runs in self.sessions_and_runs.values()])
)

return list(zip(ordered_ses_names, ordered_run_names))

def _validate_inputs(
self,
top_level_folder: Literal["rawdata", "derivatives"],
get_top_level_folder: Callable,
get_sub_level_folder: Callable,
get_sub_path: Callable,
get_run_path: Callable,
) -> None:
"""
Check the rawdata / derivatives path, subject path exists
and ensure run_names is a list of strings.
Parameters
----------
run_names : List[str]
List of run names to process, in order they should be
processed / concatenated.
Returns
-------
run_names : List[str]
Validated `run_names` as a List.
"""
assert get_top_level_folder().is_dir(), (
f"Ensure there is a folder in base path called '"
f"{top_level_folder}'.\n"
f"No {top_level_folder} directory found at "
f"{get_top_level_folder()}\n"
f"where subject-level folders must be placed."
)

assert get_sub_level_folder().is_dir(), (
f"Subject directory not found. {self.sub_name} "
f"is not a folder in {get_top_level_folder()}"
)

for ses_name in self.sessions_and_runs.keys():
assert (
ses_path := get_sub_path(ses_name)
).is_dir(), f"{ses_name} was not found at folder path {ses_path}"

for run_name in self.sessions_and_runs[ses_name]:
assert (run_path := get_run_path(ses_name, run_name)).is_dir(), (
f"The run folder {run_path.stem} cannot be found at "
f"file path {run_path.parent}."
)

gate_str = fnmatch.filter(run_name.split("_"), "g?")

assert len(gate_str) > 0, (
f"The SpikeGLX gate index should be in the run name. "
f"It was not found in the name {run_name}."
f"\nEnsure the gate number is in the SpikeGLX-output filename."
)

assert len(gate_str) == 1, (
f"The SpikeGLX gate appears in the name "
f"{run_name} more than once"
)

assert int(gate_str[0][1:]) == 0, (
f"Gate with index larger than 0 is not supported. This is found "
f"in run name {run_name}. "
)

def keys(self) -> KeysView:
return self.data.keys()

def items(self) -> ItemsView:
return self.data.items()

def values(self) -> ValuesView:
return self.data.values()
58 changes: 58 additions & 0 deletions spikewrap/data_classes/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import shutil
from dataclasses import dataclass
from typing import Dict

import spikeinterface

from ..utils import utils
from .base import BaseUserDict


@dataclass
class PreprocessingData(BaseUserDict):
"""
Dictionary to store SpikeInterface preprocessing recordings.
Details on the preprocessing steps are held in the dictionary keys e.g.
e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average
and recording objects are held in the value. These are generated
by the `pipeline.preprocess.run_preprocessing()` function.
The class manages paths to raw data and preprocessing output,
as defines methods to dump key information and the SpikeInterface
binary to disk. Note that SI preprocessing is lazy and
preprocessing only run when the recording.get_traces()
is called, or the data is saved to binary.
Parameters
----------
base_path : Union[Path, str]
Path where the rawdata folder containing subjects.
sub_name : str
'subject' to preprocess. The subject top level dir should
reside in base_path/rawdata/.
run_names : Union[List[str], str]
The SpikeGLX run name (i.e. not including the gate index)
or list of run names.
"""

def __post_init__(self) -> None:
super().__post_init__()
self._validate_rawdata_inputs()

self.sync: Dict = {}

for ses_name, run_name in self.preprocessing_sessions_and_runs():
utils.update(self.data, ses_name, run_name, {"0-raw": None})
utils.update(self.sync, ses_name, run_name, None)

def _validate_rawdata_inputs(self) -> None:
self._validate_inputs(
"rawdata",
self.get_rawdata_top_level_path,
self.get_rawdata_sub_path,
self.get_rawdata_ses_path,
self.get_rawdata_run_path,
)
23 changes: 23 additions & 0 deletions spikewrap/examples/load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path

from spikewrap.pipeline.load_data import load_data

base_path = Path(
r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
)

sub_name = "sub-1119617"
sessions_and_runs = {
"ses-001": [
"run-001_1119617_LSE1_shank12_g0",
"run-002_made_up_g0",
],
"ses-002": [
"run-001_1119617_pretest1_shank12_g0",
],
"ses-003": [
"run-002_1119617_pretest1_shank12_g0",
],
}

loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx")
Empty file removed tests/__init__.py
Empty file.
Empty file removed tests/test_integration/__init__.py
Empty file.
Empty file removed tests/test_unit/__init__.py
Empty file.

0 comments on commit 679098b

Please sign in to comment.