diff --git a/spikewrap/__init__.py b/spikewrap/__init__.py index 781354f..90c560d 100644 --- a/spikewrap/__init__.py +++ b/spikewrap/__init__.py @@ -7,7 +7,8 @@ pass from .pipeline.full_pipeline import run_full_pipeline -from .pipeline.preprocess import _preprocess_and_save_all_runs + +# from .pipeline.preprocess import _preprocess_and_save_all_runs from .pipeline.sort import run_sorting from .pipeline.postprocess import run_postprocess diff --git a/spikewrap/data_classes/postprocessing.py b/spikewrap/data_classes/postprocessing.py index 566dbdc..d762f14 100644 --- a/spikewrap/data_classes/postprocessing.py +++ b/spikewrap/data_classes/postprocessing.py @@ -42,6 +42,7 @@ class PostprocessingData: def __init__(self, sorting_path: Union[str, Path]) -> None: self.sorting_path = Path(sorting_path) + self.sorter_output_path = self.sorting_path / "sorter_output" self.sorting_info_path = self.sorting_path / utils.canonical_names( "sorting_yaml" @@ -151,9 +152,7 @@ def get_sorting_extractor_object(self) -> si.SortingExtractor: return sorting_without_excess_spikes def get_postprocessing_path(self) -> Path: - return self.sorting_data.get_postprocessing_path( - self.sorted_ses_name, self.sorted_run_name - ) + return utils.make_postprocessing_path(self.sorting_path) def get_quality_metrics_path(self) -> Path: return self.get_postprocessing_path() / "quality_metrics.csv" diff --git a/spikewrap/data_classes/sorting.py b/spikewrap/data_classes/sorting.py index 6e4dd2f..860df51 100644 --- a/spikewrap/data_classes/sorting.py +++ b/spikewrap/data_classes/sorting.py @@ -186,19 +186,31 @@ def check_ses_or_run_folders_in_datetime_order( # Paths # ---------------------------------------------------------------------------------- - def get_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self._get_base_sorting_path(ses_name, run_name) / "sorting" - - def get_sorter_output_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self.get_sorting_path(ses_name, run_name) / "sorter_output" + def get_sorting_path( + self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None + ) -> Path: + if group_idx is None: + format_group_name = "" + else: + format_group_name = f"group-{group_idx}" - def _get_sorting_info_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self.get_sorting_path(ses_name, run_name) / utils.canonical_names( - "sorting_yaml" + return ( + self.get_base_sorting_path(ses_name, run_name) + / format_group_name + / "sorting" ) - def get_postprocessing_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self._get_base_sorting_path(ses_name, run_name) / "postprocessing" + def get_sorter_output_path( + self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None + ) -> Path: + return self.get_sorting_path(ses_name, run_name, group_idx) / "sorter_output" + + def _get_sorting_info_path( + self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None + ) -> Path: + return self.get_sorting_path( + ses_name, run_name, group_idx + ) / utils.canonical_names("sorting_yaml") def _validate_derivatives_inputs(self): self._validate_inputs( @@ -247,7 +259,9 @@ def _make_run_name_from_multiple_run_names(self, run_names: List[str]) -> str: # Sorting info # ---------------------------------------------------------------------------------- - def save_sorting_info(self, ses_name: str, run_name: str) -> None: + def save_sorting_info( + self, ses_name: str, run_name: str, group_idx: Optional[int] = None 
+ ) -> None: """ Save a sorting_info.yaml file containing a dictionary holding important information on the sorting. This is for provenance. @@ -289,7 +303,7 @@ def save_sorting_info(self, ses_name: str, run_name: str) -> None: sorting_info["datetime_created"] = utils.get_formatted_datetime() utils.dump_dict_to_yaml( - self._get_sorting_info_path(ses_name, run_name), sorting_info + self._get_sorting_info_path(ses_name, run_name, group_idx), sorting_info ) @property @@ -325,7 +339,7 @@ def get_preprocessed_recordings( raise NotImplementedError @abstractmethod - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: raise NotImplementedError @@ -364,7 +378,7 @@ def get_preprocessed_recordings( self.assert_names(ses_name, run_name) return self[self.concat_ses_name()] - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: """""" self.assert_names(ses_name, run_name) @@ -447,7 +461,7 @@ def get_preprocessed_recordings( return self[ses_name][run_name] - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: assert run_name == self.concat_run_name(ses_name) assert run_name is not None @@ -501,7 +515,7 @@ def get_preprocessed_recordings( assert run_name is not None return self[ses_name][run_name] - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: assert run_name is not None # TODO: centralise paths!!# TODO: centralise paths!!# TODO: centralise paths!! return ( diff --git a/spikewrap/examples/example_full_pipeline.py b/spikewrap/examples/example_full_pipeline.py index 41d8194..93dd67a 100644 --- a/spikewrap/examples/example_full_pipeline.py +++ b/spikewrap/examples/example_full_pipeline.py @@ -4,15 +4,19 @@ from spikewrap.pipeline.full_pipeline import run_full_pipeline base_path = Path( - r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data" # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-long_origdata" + r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data", ) sub_name = "sub-001_type-test" -sessions_and_runs = {"all": ["all"]} + +sessions_and_runs = { + "all": ["all"], +} + # sub_name = "1119617" # sessions_and_runs = { -# "ses-001": ["1119617_LSE1_shank12_g0"], +# "ses-001": ["1119617_LSE1_shank12_g0"], # } config_name = "test_default" @@ -28,16 +32,13 @@ "spikeinterface", config_name, sorter, + sort_by_group=True, save_preprocessing_chunk_size=30000, existing_preprocessed_data="overwrite", existing_sorting_output="overwrite", overwrite_postprocessing=True, concat_sessions_for_sorting=False, # TODO: validate this at the start, in `run_full_pipeline` concat_runs_for_sorting=False, - # existing_preprocessed_data="skip_if_exists", # this is kind of confusing... 
- # existing_sorting_output="overwrite", - # overwrite_postprocessing=True, - # slurm_batch=False, ) print(f"TOOK {time.time() - t}") diff --git a/spikewrap/examples/example_preprocess.py b/spikewrap/examples/example_preprocess.py index 7bf88eb..e811f50 100644 --- a/spikewrap/examples/example_preprocess.py +++ b/spikewrap/examples/example_preprocess.py @@ -1,34 +1,46 @@ from pathlib import Path from spikewrap.pipeline.load_data import load_data -from spikewrap.pipeline.preprocess import run_preprocessing +from spikewrap.pipeline.preprocess import PreprocessPipeline base_path = Path( - r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" + "/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code/spikewrap/tests/data/small_toy_data" + # r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data" + # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" # r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises" ) -sub_name = "sub-1119617" +sub_name = "sub-001_type-test" +# sub_name = "sub-1119617" + sessions_and_runs = { - "ses-001": [ - "run-001_1119617_LSE1_shank12_g0", - "run-002_made_up_g0", - ], - "ses-002": [ - "run-001_1119617_pretest1_shank12_g0", - ], - "ses-003": [ - "run-002_1119617_pretest1_shank12_g0", - ], + "ses-001": ["all"], + "ses-002": ["all"], } -loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx") +if False: + sessions_and_runs = { + "ses-001": [ + "run-001_1119617_LSE1_shank12_g0", + "run-002_made_up_g0", + ], + "ses-002": [ + "run-001_1119617_pretest1_shank12_g0", + ], + "ses-003": [ + "run-002_1119617_pretest1_shank12_g0", + ], + } + +loaded_data = load_data( + base_path, sub_name, sessions_and_runs, data_format="spikeinterface" +) -run_preprocessing( +preprocess_pipeline = PreprocessPipeline( loaded_data, pp_steps="default", handle_existing_data="overwrite", preprocess_by_group=True, log=True, - slurm_batch=False, ) +preprocess_pipeline.run(slurm_batch=True) diff --git a/spikewrap/pipeline/full_pipeline.py b/spikewrap/pipeline/full_pipeline.py index f82c379..d633d73 100644 --- a/spikewrap/pipeline/full_pipeline.py +++ b/spikewrap/pipeline/full_pipeline.py @@ -13,7 +13,8 @@ from spikewrap.configs.configs import get_configs from spikewrap.pipeline.load_data import load_data from spikewrap.pipeline.postprocess import run_postprocess -from spikewrap.pipeline.preprocess import run_preprocessing + +# from spikewrap.pipeline.preprocess import run_preprocessing from spikewrap.pipeline.sort import run_sorting from spikewrap.utils import logging_sw, slurm, utils, validate @@ -26,6 +27,7 @@ def run_full_pipeline( config_name: str = "default", sorter: str = "kilosort2_5", preprocess_by_group: bool = False, + sort_by_group: bool = False, concat_sessions_for_sorting: bool = False, concat_runs_for_sorting: bool = False, existing_preprocessed_data: HandleExisting = "fail_if_exists", @@ -53,6 +55,7 @@ def run_full_pipeline( "config_name": config_name, "sorter": sorter, "preprocess_by_group": preprocess_by_group, + "sort_by_group": sort_by_group, "concat_sessions_for_sorting": concat_sessions_for_sorting, "concat_runs_for_sorting": concat_runs_for_sorting, "existing_preprocessed_data": existing_preprocessed_data, @@ -72,6 +75,7 @@ def run_full_pipeline( config_name, sorter, preprocess_by_group, + sort_by_group, concat_sessions_for_sorting, concat_runs_for_sorting, 
existing_preprocessed_data, @@ -91,6 +95,7 @@ def _run_full_pipeline( config_name: str = "default", sorter: str = "kilosort2_5", preprocess_by_group: bool = False, + sort_by_group: bool = False, concat_sessions_for_sorting: bool = False, concat_runs_for_sorting: bool = False, existing_preprocessed_data: HandleExisting = "fail_if_exists", @@ -231,6 +236,7 @@ def _run_full_pipeline( sub_name, sessions_and_runs, sorter, + sort_by_group, concat_sessions_for_sorting, concat_runs_for_sorting, sorter_options, @@ -240,20 +246,26 @@ def _run_full_pipeline( # Run Postprocessing for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): - sorting_path = sorting_data.get_sorting_path(ses_name, run_name) - - postprocess_data = run_postprocess( - sorting_path, - overwrite_postprocessing=overwrite_postprocessing, - existing_waveform_data="fail_if_exists", - waveform_options=waveform_options, - ) + for sorting_path in _get_sorting_paths( + sorting_data, ses_name, run_name, sort_by_group + ): + postprocess_data = run_postprocess( + sorting_path, + overwrite_postprocessing=overwrite_postprocessing, + existing_waveform_data="fail_if_exists", + waveform_options=waveform_options, + ) # Delete intermediate files for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): - handle_delete_intermediate_files( - ses_name, run_name, sorting_data, delete_intermediate_files - ) + for sorting_path in _get_sorting_paths( + sorting_data, ses_name, run_name, sort_by_group + ): + postprocessing_path = utils.make_postprocessing_path(sorting_path) + + handle_delete_intermediate_files( + sorting_path, postprocessing_path, delete_intermediate_files + ) logs.stop_logging() return ( @@ -262,15 +274,37 @@ def _run_full_pipeline( ) +def _get_sorting_paths( + sorting_data: SortingData, ses_name: str, run_name: str, sort_by_group: bool +) -> List[Path]: + """ """ + if sort_by_group: + all_group_paths = sorting_data.get_base_sorting_path(ses_name, run_name).glob( + "group-*" + ) + group_indexes = [ + int(group.name.split("group-")[1]) + for group in all_group_paths + if group.is_dir() + ] # TODO: kind of hacky + all_sorting_paths = [ + sorting_data.get_sorting_path(ses_name, run_name, idx) + for idx in group_indexes + ] + else: + all_sorting_paths = [sorting_data.get_sorting_path(ses_name, run_name)] + + return all_sorting_paths + + # -------------------------------------------------------------------------------------- # Remove Intermediate Files # -------------------------------------------------------------------------------------- def handle_delete_intermediate_files( - ses_name: str, - run_name: Optional[str], - sorting_data: SortingData, + sorting_path: Path, + postprocessing_path: Path, delete_intermediate_files: DeleteIntermediate, ): """ @@ -279,22 +313,13 @@ def handle_delete_intermediate_files( for Kilosort). 
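# --- Editor's sketch, not part of the diff ---
# Illustrates the on-disk layout this refactor introduces: with `sort_by_group=True`
# each channel group is sorted into its own `group-<idx>/sorting` folder, and
# `_get_sorting_paths` rediscovers those folders by globbing. The example
# `base_sorting_path` below is a hypothetical placeholder.
from pathlib import Path

base_sorting_path = Path("derivatives/sub-001/ses-001/ephys/run-001/kilosort2_5")

group_indexes = sorted(
    int(path.name.split("group-")[1])
    for path in base_sorting_path.glob("group-*")
    if path.is_dir()
)
sorting_paths = [
    base_sorting_path / f"group-{idx}" / "sorting" for idx in group_indexes
] or [base_sorting_path / "sorting"]  # non-grouped layout: a single sorting folder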
See `run_full_pipeline` for inputs """ if "recording.dat" in delete_intermediate_files: - if ( - recording_file := sorting_data.get_sorter_output_path(ses_name, run_name) - / "recording.dat" - ).is_file(): + if (recording_file := sorting_path / "recording.dat").is_file(): recording_file.unlink() if "temp_wh.dat" in delete_intermediate_files: - if ( - recording_file := sorting_data.get_sorter_output_path(ses_name, run_name) - / "temp_wh.dat" - ).is_file(): + if (recording_file := sorting_path / "temp_wh.dat").is_file(): recording_file.unlink() if "waveforms" in delete_intermediate_files: - if ( - waveforms_path := sorting_data.get_postprocessing_path(ses_name, run_name) - / "waveforms" - ).is_dir(): + if (waveforms_path := postprocessing_path / "waveforms").is_dir(): shutil.rmtree(waveforms_path) diff --git a/spikewrap/pipeline/preprocess.py b/spikewrap/pipeline/preprocess.py index 5821a0f..1b8abe5 100644 --- a/spikewrap/pipeline/preprocess.py +++ b/spikewrap/pipeline/preprocess.py @@ -15,81 +15,198 @@ # -------------------------------------------------------------------------------------- -def run_preprocessing( - preprocess_data: PreprocessingData, - pp_steps: Union[Dict, str], - handle_existing_data: HandleExisting, - preprocess_by_group: bool, - chunk_size: Optional[int] = None, - slurm_batch: Union[bool, Dict] = False, - log: bool = True, -): - """ - Main entry function to run preprocessing and write to file. Preprocessed - lazy spikeinterface recordings will be added to all sessions / runs in - `preprocess_data` and written to file. - - Parameters - ---------- - - preprocess_data : PreprocessingData - A preprocessing data object that has as attributes the - paths to rawdata. The pp_steps attribute is set on - this class during execution of this function. +class PreprocessPipeline: + """ """ + + def __init__( + self, + preprocess_data: PreprocessingData, + pp_steps: Union[Dict, str], + handle_existing_data: HandleExisting, + preprocess_by_group: bool, + chunk_size: Optional[int] = None, + # slurm_batch: Union[bool, Dict] = False, + log: bool = True, + ): + if isinstance(pp_steps, Dict): + pp_steps_dict = pp_steps + else: + pp_steps_dict, _, _ = configs.get_configs(pp_steps) + # pp_steps_dict = MappingProxyType(pp_steps_dict) + + self.passed_arguments = { # MappingProxyType( + # { + "preprocess_data": preprocess_data, + "pp_steps_dict": pp_steps_dict, + "handle_existing_data": handle_existing_data, + "preprocess_by_group": preprocess_by_group, + "chunk_size": chunk_size, + # "slurm_batch": slurm_batch, + "log": log, + } + # ) + validate.check_function_arguments(self.passed_arguments) + + # TODO: do some check the name is valid + def run(self, slurm_batch: Union[bool, Dict] = False): + """ """ + if slurm_batch: + slurm.run_in_slurm( + slurm_batch, + self._preprocess_and_save_all_runs, + self.passed_arguments, + ), + else: + self._preprocess_and_save_all_runs(**self.passed_arguments) + + # -------------------------------------------------------------------------------------- + # Private Functions + # -------------------------------------------------------------------------------------- + + def _preprocess_and_save_all_runs( + self, + preprocess_data: PreprocessingData, + pp_steps_dict: Dict, + handle_existing_data: HandleExisting, + preprocess_by_group: bool, + chunk_size: Optional[int] = None, + log: bool = True, + ) -> None: + """ + Handle the loading of existing preprocessed data. + See `run_preprocessing()` for details. 
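# --- Editor's sketch, not part of the diff ---
# Minimal usage of the new class-based `PreprocessPipeline` API, mirroring the
# updated examples/example_preprocess.py above. The base path, subject name and
# sessions/runs below are placeholders.
from pathlib import Path

from spikewrap.pipeline.load_data import load_data
from spikewrap.pipeline.preprocess import PreprocessPipeline

loaded_data = load_data(
    Path("/path/to/project"),
    "sub-001",
    {"ses-001": ["all"]},
    data_format="spikeinterface",
)

preprocess_pipeline = PreprocessPipeline(
    loaded_data,
    pp_steps="default",
    handle_existing_data="overwrite",
    preprocess_by_group=True,
    log=True,
)
preprocess_pipeline.run(slurm_batch=False)  # or a dict of SLURM options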
+ + This function validates all input arguments and initialises logging. + Then, it will iterate over every run in `preprocess_data` and + check whether preprocessing needs to be run and saved based on the + `handle_existing_data` option. If so, it will fill the relevant run + with the preprocessed spikeinterface recording object and save to disk. + """ + passed_arguments = locals() + validate.check_function_arguments(passed_arguments) + + if log: + logs = logging_sw.get_started_logger( + utils.get_logging_path( + preprocess_data.base_path, preprocess_data.sub_name + ), + "preprocessing", + ) + utils.show_passed_arguments(passed_arguments, "`run_preprocessing`") - pp_steps: The name of valid preprocessing .yaml file (without the yaml extension). - stored in spikewrap/configs. + for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): + utils.message_user(f"Preprocessing run {run_name}...") - existing_preprocessed_data : custom_types.HandleExisting - Determines how existing preprocessed data (e.g. from a prior pipeline run) - is handled. - "overwrite" : Will overwrite any existing preprocessed data output. - This will delete the 'preprocessed' folder. Therefore, - never save derivative work there. - "skip_if_exists" : will search for existing data and skip preprocesing - if it exists (sorting will run on existing - preprocessed data). - Otherwise, will preprocess and save the current run. - "fail_if_exists" : If existing preprocessed data is found, an error - will be raised. - - slurm_batch : Union[bool, Dict] - see `run_full_pipeline()` for details. - """ - # TOOD: refactor and handle argument groups separately. - # Avoid duplication with logging. - passed_arguments = locals() - validate.check_function_arguments(passed_arguments) + to_save, overwrite = self._handle_existing_data_options( + preprocess_data, ses_name, run_name, handle_existing_data + ) - if isinstance(pp_steps, Dict): - pp_steps_dict = pp_steps - else: - # TODO: do some check the name is valid - pp_steps_dict, _, _ = configs.get_configs(pp_steps) # TODO: call 'config_name' - - if slurm_batch: - slurm.run_in_slurm( - slurm_batch, - _preprocess_and_save_all_runs, - { - "preprocess_data": preprocess_data, - "pp_steps": pp_steps_dict, - "preprocess_by_group": preprocess_by_group, - "chunk_size": chunk_size, - "handle_existing_data": handle_existing_data, - "log": log, - }, - ), - else: - _preprocess_and_save_all_runs( + if to_save: + self._preprocess_and_save_single_run( + preprocess_data, + ses_name, + run_name, + pp_steps_dict, + overwrite, + preprocess_by_group, + chunk_size, + ) + + if log: + logs.stop_logging() + + def _preprocess_and_save_single_run( + self, + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + pp_steps_dict: Dict, + overwrite: bool, + preprocess_by_group: bool, + chunk_size: Optional[int], + ) -> None: + """ + Given a single session and run, fill the entry for this run + on the `preprocess_data` object and write to disk. 
+ """ + _fill_run_data_with_preprocessed_recording( preprocess_data, + ses_name, + run_name, pp_steps_dict, - handle_existing_data, preprocess_by_group, - chunk_size, - log, ) + preprocess_data.save_preprocessed_data( + ses_name, run_name, overwrite, chunk_size + ) + + def _handle_existing_data_options( + self, + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + handle_existing_data: HandleExisting, + ) -> Tuple[bool, bool]: + """ + Determine whether preprocesing for this run needs to be performed based + on the `handle_existing_data setting`. If preprocessing does not exist, + preprocessing + is always run. Otherwise, if it already exists, the behaviour depends on + the `handle_existing_data` setting. + + Returns + ------- + + to_save : bool + Whether the preprocessing needs to be run and saved. + + to_overwrite : bool + If saving, set the `overwrite` flag to confirm existing data should + be overwritten. + """ + preprocess_path = preprocess_data.get_preprocessing_path(ses_name, run_name) + + if handle_existing_data == "skip_if_exists": + if preprocess_path.is_dir(): + utils.message_user( + f"\nSkipping preprocessing, using file at " + f"{preprocess_path} for sorting.\n" + ) + to_save = False + overwrite = False + else: + utils.message_user( + f"No data found at {preprocess_path}, saving preprocessed data." + ) + to_save = True + overwrite = False + + elif handle_existing_data == "overwrite": + if preprocess_path.is_dir(): + utils.message_user(f"Removing existing file at {preprocess_path}\n") + + utils.message_user(f"Saving preprocessed data to {preprocess_path}") + to_save = True + overwrite = True + + elif handle_existing_data == "fail_if_exists": + if preprocess_path.is_dir(): + raise FileExistsError( + f"Preprocessed binary already exists at " + f"{preprocess_path}. " + f"To overwrite, set 'existing_preprocessed_data' to 'overwrite'" + ) + to_save = True + overwrite = False + + return to_save, overwrite + + +# -------------------------------------------------------------------------------------- +# Preprocessing Functions +# -------------------------------------------------------------------------------------- + def fill_all_runs_with_preprocessed_recording( preprocess_data: PreprocessingData, @@ -121,145 +238,6 @@ def fill_all_runs_with_preprocessed_recording( ) -# -------------------------------------------------------------------------------------- -# Private Functions -# -------------------------------------------------------------------------------------- - - -def _preprocess_and_save_all_runs( - preprocess_data: PreprocessingData, - pp_steps_dict: Dict, - handle_existing_data: HandleExisting, - preprocess_by_group: bool, - chunk_size: Optional[int] = None, - log: bool = True, -) -> None: - """ - Handle the loading of existing preprocessed data. - See `run_preprocessing()` for details. - - This function validates all input arguments and initialises logging. - Then, it will iterate over every run in `preprocess_data` and - check whether preprocessing needs to be run and saved based on the - `handle_existing_data` option. If so, it will fill the relevant run - with the preprocessed spikeinterface recording object and save to disk. 
- """ - passed_arguments = locals() - validate.check_function_arguments(passed_arguments) - - if log: - logs = logging_sw.get_started_logger( - utils.get_logging_path(preprocess_data.base_path, preprocess_data.sub_name), - "preprocessing", - ) - utils.show_passed_arguments(passed_arguments, "`run_preprocessing`") - - for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): - utils.message_user(f"Preprocessing run {run_name}...") - - to_save, overwrite = _handle_existing_data_options( - preprocess_data, ses_name, run_name, handle_existing_data - ) - - if to_save: - _preprocess_and_save_single_run( - preprocess_data, - ses_name, - run_name, - pp_steps_dict, - overwrite, - preprocess_by_group, - chunk_size, - ) - - if log: - logs.stop_logging() - - -def _preprocess_and_save_single_run( - preprocess_data: PreprocessingData, - ses_name: str, - run_name: str, - pp_steps_dict: Dict, - overwrite: bool, - preprocess_by_group: bool, - chunk_size: Optional[int], -) -> None: - """ - Given a single session and run, fill the entry for this run - on the `preprocess_data` object and write to disk. - """ - _fill_run_data_with_preprocessed_recording( - preprocess_data, - ses_name, - run_name, - pp_steps_dict, - preprocess_by_group, - ) - - preprocess_data.save_preprocessed_data(ses_name, run_name, overwrite, chunk_size) - - -def _handle_existing_data_options( - preprocess_data: PreprocessingData, - ses_name: str, - run_name: str, - handle_existing_data: HandleExisting, -) -> Tuple[bool, bool]: - """ - Determine whether preprocesing for this run needs to be performed based - on the `handle_existing_data setting`. If preprocessing does not exist, preprocessing - is always run. Otherwise, if it already exists, the behaviour depends on - the `handle_existing_data` setting. - - Returns - ------- - - to_save : bool - Whether the preprocessing needs to be run and saved. - - to_overwrite : bool - If saving, set the `overwrite` flag to confirm existing data should - be overwritten. - """ - preprocess_path = preprocess_data.get_preprocessing_path(ses_name, run_name) - - if handle_existing_data == "skip_if_exists": - if preprocess_path.is_dir(): - utils.message_user( - f"\nSkipping preprocessing, using file at " - f"{preprocess_path} for sorting.\n" - ) - to_save = False - overwrite = False - else: - utils.message_user( - f"No data found at {preprocess_path}, saving preprocessed data." - ) - to_save = True - overwrite = False - - elif handle_existing_data == "overwrite": - if preprocess_path.is_dir(): - utils.message_user(f"Removing existing file at {preprocess_path}\n") - - utils.message_user(f"Saving preprocessed data to {preprocess_path}") - to_save = True - overwrite = True - - elif handle_existing_data == "fail_if_exists": - if preprocess_path.is_dir(): - raise FileExistsError( - f"Preprocessed binary already exists at " - f"{preprocess_path}. 
" - f"To overwrite, set 'existing_preprocessed_data' to 'overwrite'" - ) - to_save = True - overwrite = False - - return to_save, overwrite - - def _fill_run_data_with_preprocessed_recording( preprocess_data: PreprocessingData, ses_name: str, diff --git a/spikewrap/pipeline/sort.py b/spikewrap/pipeline/sort.py index 5680002..5d5b62e 100644 --- a/spikewrap/pipeline/sort.py +++ b/spikewrap/pipeline/sort.py @@ -27,6 +27,7 @@ def run_sorting( sub_name: str, sessions_and_runs: Dict[str, List[str]], sorter: str, + sort_by_group: bool = False, concatenate_sessions: bool = False, concatenate_runs: bool = False, sorter_options: Optional[Dict] = None, @@ -47,6 +48,8 @@ def run_sorting( "sub_name": sub_name, "sessions_and_runs": sessions_and_runs, "concatenate_sessions": concatenate_sessions, + "sorter": sorter, + "sort_by_group": sort_by_group, "concatenate_runs": concatenate_runs, "sorter_options": sorter_options, "existing_sorting_output": existing_sorting_output, @@ -59,6 +62,7 @@ def run_sorting( sub_name, sessions_and_runs, sorter, + sort_by_group, concatenate_sessions, concatenate_runs, sorter_options, @@ -72,6 +76,7 @@ def _run_sorting( sub_name: str, sessions_and_runs: Dict[str, List[str]], sorter: str, + sort_by_group: bool, concatenate_sessions: bool = False, concatenate_runs: bool = False, sorter_options: Optional[Dict] = None, @@ -179,6 +184,7 @@ def _run_sorting( sorting_data, singularity_image, docker_image, + sort_by_group, existing_sorting_output=existing_sorting_output, **sorter_options_dict, ) @@ -223,6 +229,7 @@ def run_sorting_on_all_runs( sorting_data: SortingData, singularity_image: Union[Literal[True], None, str], docker_image: Optional[Literal[True]], + sort_by_group: bool, existing_sorting_output: HandleExisting, **sorter_options_dict, ) -> None: @@ -255,46 +262,65 @@ def run_sorting_on_all_runs( utils.message_user(f"Starting {sorting_data.sorter} sorting...") for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): - sorting_output_path = sorting_data.get_sorting_path(ses_name, run_name) - preprocessed_recording = sorting_data.get_preprocessed_recordings( + utils.message_user(f"Sorting session: {ses_name} \n" f"run: {run_name}...") + + orig_preprocessed_recording = sorting_data.get_preprocessed_recordings( ses_name, run_name ) - utils.message_user( - f"Sorting session: {ses_name} \n" - f"run: {ses_name}..." - # TODO: I think can just use run_name now? - ) + if sort_by_group: + split_preprocessing = orig_preprocessed_recording.split_by("group") - if sorting_output_path.is_dir(): - if existing_sorting_output == "fail_if_exists": + if len(split_preprocessing.keys()) == 1: raise RuntimeError( - f"Sorting output already exists at {sorting_output_path} and" - f"`existing_sorting_output` is set to 'fail_if_exists'." - ) - - elif existing_sorting_output == "skip_if_exists": - utils.message_user( - f"Sorting output already exists at {sorting_output_path}. Nothing " - f"will be done. The existing sorting will be used for " - f"postprocessing " - f"if running with `run_full_pipeline`" + "`sort_by_group` is `True` but the recording only has " + "one channel group. Set `sort_by_group`to `False` " + "for this recording." 
) - continue - - quick_safety_check(existing_sorting_output, sorting_output_path) - - ss.run_sorter( - sorting_data.sorter, - preprocessed_recording, - output_folder=sorting_output_path, - singularity_image=singularity_image, - docker_image=docker_image, - remove_existing_folder=True, - **sorter_options_dict, - ) - sorting_data.save_sorting_info(ses_name, run_name) + group_indexes = list(split_preprocessing.keys()) + all_preprocessed_recordings = list(split_preprocessing.values()) + else: + group_indexes = [None] + all_preprocessed_recordings = [orig_preprocessed_recording] + + for group_idx, prepro_recording in zip( + group_indexes, all_preprocessed_recordings + ): + sorting_output_path = sorting_data.get_sorting_path( + ses_name, run_name, group_idx + ) + + if sorting_output_path.is_dir(): + if existing_sorting_output == "fail_if_exists": + raise RuntimeError( + f"Sorting output already exists at {sorting_output_path} and" + f"`existing_sorting_output` is set to 'fail_if_exists'." + ) + + elif existing_sorting_output == "skip_if_exists": + utils.message_user( + f"Sorting output already exists at {sorting_output_path}. Nothing " + f"will be done. The existing sorting will be used for " + f"postprocessing " + f"if running with `run_full_pipeline`" + ) + continue + + quick_safety_check(existing_sorting_output, sorting_output_path) + + ss.run_sorter( + sorting_data.sorter, + prepro_recording, + output_folder=sorting_output_path, + singularity_image=singularity_image, + docker_image=docker_image, + remove_existing_folder=True, + **sorter_options_dict, + ) + + # TODO: how does this interact with concat sessions and recordings? + sorting_data.save_sorting_info(ses_name, run_name, group_idx) def quick_safety_check( diff --git a/spikewrap/utils/utils.py b/spikewrap/utils/utils.py index f45c2e7..663cf15 100644 --- a/spikewrap/utils/utils.py +++ b/spikewrap/utils/utils.py @@ -166,6 +166,11 @@ def cast_pp_steps_values( # -------------------------------------------------------------------------------------- +# TODO: move +def make_postprocessing_path(sorting_path: Path): + return sorting_path.parent / "postprocessing" + + def get_keys_first_char( data: Union[PreprocessingData, SortingData], as_int: bool = False ) -> Union[List[str], List[int]]: diff --git a/tests/data/small_toy_data/in_container_params.json b/tests/data/small_toy_data/in_container_params.json new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/tests/data/small_toy_data/in_container_params.json @@ -0,0 +1 @@ +{} diff --git a/tests/data/small_toy_data/in_container_recording.json b/tests/data/small_toy_data/in_container_recording.json new file mode 100644 index 0000000..926e962 --- /dev/null +++ b/tests/data/small_toy_data/in_container_recording.json @@ -0,0 +1,743 @@ +{ + "class": "spikeinterface.core.channelslice.ChannelSliceRecording", + "module": "spikeinterface", + "kwargs": { + "parent_recording": { + "class": "spikeinterface.preprocessing.common_reference.CommonReferenceRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.filter.BandpassFilterRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.phase_shift.PhaseShiftRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.astype.AstypeRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.core.binaryfolder.BinaryFolderRecording", + "module": 
"spikeinterface", + "kwargs": { + "folder_path": "/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/rawdata/sub-001_type-test/ses-001/ephys/ses-001_run-001" + }, + "version": "0.100.0.dev0", + "annotations": { + "is_filtered": true + }, + "properties": { + "group": [ + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3 + ], + "location": [ + [ + 0.0, + 0.0 + ], + [ + 0.0, + 40.0 + ], + [ + 0.0, + 80.0 + ], + [ + 0.0, + 120.0 + ], + [ + 0.0, + 160.0 + ], + [ + 0.0, + 200.0 + ], + [ + 0.0, + 240.0 + ], + [ + 0.0, + 280.0 + ], + [ + 0.0, + 320.0 + ], + [ + 0.0, + 360.0 + ], + [ + 0.0, + 400.0 + ], + [ + 0.0, + 440.0 + ], + [ + 0.0, + 480.0 + ], + [ + 0.0, + 520.0 + ], + [ + 0.0, + 560.0 + ], + [ + 0.0, + 600.0 + ] + ], + "gain_to_uV": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "offset_to_uV": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "relative_paths": false + }, + "dtype": " 1 + + for sorting_output_path in sorted_groups: + assert (sorting_output_path / "sorting").is_dir() + assert (sorting_output_path / "postprocessing").is_dir() + ses_path = sub_path / ses_name / "ephys" concat_all_run_names = "".join( diff --git a/tests/test_integration/test_full_pipeline.py b/tests/test_integration/test_full_pipeline.py index 35c0048..373ef0f 100644 --- a/tests/test_integration/test_full_pipeline.py +++ b/tests/test_integration/test_full_pipeline.py @@ -4,7 +4,7 @@ import pytest import spikeinterface as si import spikeinterface.extractors as se -from spikeinterface import concatenate_recordings +from spikeinterface import concatenate_recordings, sorters from spikeinterface.preprocessing import ( astype, bandpass_filter, @@ -12,7 +12,9 @@ phase_shift, ) -from spikewrap.data_classes.postprocessing import load_saved_sorting_output +from spikewrap.data_classes.postprocessing import ( + load_saved_sorting_output, +) from spikewrap.pipeline import full_pipeline, preprocess from spikewrap.pipeline.load_data import load_data from spikewrap.utils import checks, utils @@ -125,7 +127,8 @@ def test_no_concatenation_all_sorters_single_run(self, test_info, sorter): self.check_no_concat_results(test_info, loaded_data, sorting_data, sorter) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_no_concatenation_single_run(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_no_concatenation_single_run(self, test_info, sort_by_group): """ Run the full pipeline for a single session and run, and check preprocessing, sorting and waveforms. 
@@ -135,18 +138,22 @@ def test_no_concatenation_single_run(self, test_info): loaded_data, sorting_data = self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, sorter=DEFAULT_SORTER, concatenate_sessions=False, concatenate_runs=False, ) - self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER) + self.check_correct_folders_exist( + test_info, False, False, DEFAULT_SORTER, sort_by_group=sort_by_group + ) self.check_no_concat_results( - test_info, loaded_data, sorting_data, DEFAULT_SORTER + test_info, loaded_data, sorting_data, DEFAULT_SORTER, sort_by_group ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_no_concatenation_multiple_runs(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_no_concatenation_multiple_runs(self, test_info, sort_by_group): """ For DEFAULT_SORTER, check `full_pipeline` across multiple sessions and runs without concatenation. @@ -156,16 +163,20 @@ def test_no_concatenation_multiple_runs(self, test_info): data_format=DEFAULT_FORMAT, concatenate_sessions=False, concatenate_runs=False, + sort_by_group=sort_by_group, sorter=DEFAULT_SORTER, ) - self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER) - - self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER) - self.check_no_concat_results(test_info, loaded_data, sorting_data) + self.check_correct_folders_exist( + test_info, False, False, DEFAULT_SORTER, sort_by_group=sort_by_group + ) + self.check_no_concat_results( + test_info, loaded_data, sorting_data, DEFAULT_SORTER, sort_by_group + ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_concatenate_runs_but_not_sessions(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_concatenate_runs_but_not_sessions(self, test_info, sort_by_group): """ For DEFAULT_SORTER, check `full_pipeline` across multiple sessions concatenating runs, but not sessions. This results in a single @@ -177,16 +188,24 @@ def test_concatenate_runs_but_not_sessions(self, test_info): data_format=DEFAULT_FORMAT, concatenate_sessions=False, concatenate_runs=True, + sort_by_group=sort_by_group, sorter=DEFAULT_SORTER, ) - self.check_correct_folders_exist(test_info, False, True, DEFAULT_SORTER) + self.check_correct_folders_exist( + test_info, + False, + True, + DEFAULT_SORTER, + sort_by_group=sort_by_group, + ) self.check_concatenate_runs_but_not_sessions( - test_info, loaded_data, sorting_data + test_info, loaded_data, sorting_data, sort_by_group ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_concatenate_sessions_and_runs(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_concatenate_sessions_and_runs(self, test_info, sort_by_group): """ For DEFAULT_SORTER, check `full_pipeline` across multiple sessions concatenating runs and sessions. 
This will lead to a single @@ -198,10 +217,15 @@ def test_concatenate_sessions_and_runs(self, test_info): concatenate_sessions=True, concatenate_runs=True, sorter=DEFAULT_SORTER, + sort_by_group=sort_by_group, ) - self.check_correct_folders_exist(test_info, True, True, DEFAULT_SORTER) - self.check_concatenate_sessions_and_runs(test_info, loaded_data, sorting_data) + self.check_correct_folders_exist( + test_info, True, True, DEFAULT_SORTER, sort_by_group + ) + self.check_concatenate_sessions_and_runs( + test_info, loaded_data, sorting_data, sort_by_group=sort_by_group + ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) def test_ses_concat_no_run_concat(self, test_info): @@ -225,7 +249,8 @@ def test_ses_concat_no_run_concat(self, test_info): ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_existing_output_settings(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_existing_output_settings(self, test_info, sort_by_group): """ In spikewrap existing preprocessed and sorting output data is handled with options `fail_if_exists`, `skip_if_exists` or @@ -245,6 +270,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="fail_if_exists", existing_sorting_output="fail_if_exists", overwrite_postprocessing=False, @@ -252,11 +278,14 @@ def test_existing_output_settings(self, test_info): ) # Test outputs are overwritten if `overwrite` set. - file_paths = self.write_an_empty_file_in_outputs(test_info, ses_name, run_name) + file_paths = self.write_an_empty_file_in_outputs( + test_info, ses_name, run_name, sort_by_group + ) self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="overwrite", existing_sorting_output="overwrite", overwrite_postprocessing=True, @@ -266,13 +295,16 @@ def test_existing_output_settings(self, test_info): for path_ in file_paths: assert not path_.is_file() - file_paths = self.write_an_empty_file_in_outputs(test_info, ses_name, run_name) + file_paths = self.write_an_empty_file_in_outputs( + test_info, ses_name, run_name, sort_by_group + ) # Test outputs are not overwritten if `skip_if_exists`. 
# Postprocessing is always deleted self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="skip_if_exists", existing_sorting_output="skip_if_exists", overwrite_postprocessing=True, @@ -287,6 +319,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="fail_if_exists", existing_sorting_output="skip_if_exists", overwrite_postprocessing=True, @@ -307,6 +340,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="skip_if_exists", existing_sorting_output="fail_if_exists", overwrite_postprocessing=True, @@ -320,6 +354,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="skip_if_exists", existing_sorting_output="skip_if_exists", overwrite_postprocessing=False, @@ -354,7 +389,12 @@ def test_smoke_supply_chunk_size(self, test_info, capsys, specify_chunk_size): # ---------------------------------------------------------------------------------- def check_no_concat_results( - self, test_info, loaded_data, sorting_data, sorter=DEFAULT_SORTER + self, + test_info, + loaded_data, + sorting_data, + sorter=DEFAULT_SORTER, + sort_by_group=False, ): """ After `full_pipeline` is run, check the preprocessing, sorting and postprocessing @@ -410,20 +450,23 @@ def check_no_concat_results( ) paths = self.get_output_paths( - test_info, ses_name, run_name, sorter=sorter + test_info, ses_name, run_name, sort_by_group, sorter=sorter ) - self.check_waveforms( - paths["sorter_output"], - paths["postprocessing"], - recs_to_test=[ - sorting_data[ses_name][run_name], - ], - sorter=sorter, - ) + for sorter_output_path, postprocessing_path in zip( + paths["sorter_output"], paths["postprocessing"] + ): + self.check_waveforms( + sorter_output_path, + postprocessing_path, + recs_to_test=[ + sorting_data[ses_name][run_name], + ], + sorter=sorter, + ) def check_concatenate_runs_but_not_sessions( - self, test_info, loaded_data, sorting_data + self, test_info, loaded_data, sorting_data, sort_by_group ): """ Similar to `check_no_concat_results()`, however now test with @@ -483,27 +526,35 @@ def check_concatenate_runs_but_not_sessions( # Load the recording.dat and check it matches the expected data. # Finally, check the waveforms match the preprocessed data. 
paths = self.get_output_paths( - test_info, ses_name, concat_run_name, concatenate_runs=True + test_info, + ses_name, + concat_run_name, + concatenate_runs=True, + sort_by_group=sort_by_group, ) + for sorter_output_path, postprocessing_path, recording_dat_path in zip( + paths["sorter_output"], paths["postprocessing"], paths["recording_dat"] + ): + if "kilosort" in sorting_data.sorter: + saved_recording = si.read_binary( + recording_dat_path, + sampling_frequency=sorting_data_pp_run.get_sampling_frequency(), + dtype=data_type, + num_channels=sorting_data_pp_run.get_num_channels(), + ) + self.check_recordings_are_the_same( + saved_recording, test_concat_runs, n_split=2 + ) - if "kilosort" in sorting_data.sorter: - saved_recording = si.read_binary( - paths["recording_dat"], - sampling_frequency=sorting_data_pp_run.get_sampling_frequency(), - dtype=data_type, - num_channels=sorting_data_pp_run.get_num_channels(), - ) - self.check_recordings_are_the_same( - saved_recording, test_concat_runs, n_split=2 + self.check_waveforms( + sorter_output_path, + postprocessing_path, + recs_to_test=[sorting_data[ses_name][concat_run_name]], ) - self.check_waveforms( - paths["sorter_output"], - paths["postprocessing"], - recs_to_test=[sorting_data[ses_name][concat_run_name]], - ) - - def check_concatenate_sessions_and_runs(self, test_info, loaded_data, sorting_data): + def check_concatenate_sessions_and_runs( + self, test_info, loaded_data, sorting_data, sort_by_group + ): """ Similar to `check_no_concat_results()` and `check_concatenate_runs_but_not_sessions()`, but now we are checking when `concatenate_sessions=True` and `concatenate_runs=`True`. @@ -549,35 +600,38 @@ def check_concatenate_sessions_and_runs(self, test_info, loaded_data, sorting_da # dtype is converted to original dtype on file writing. test_concat_all = astype(test_concat_all, data_type) + self.check_recordings_are_the_same( + sorted_data_concat_all, test_concat_all, n_split=6 + ) + paths = self.get_output_paths( test_info, + sort_by_group=sort_by_group, ses_name=concat_ses_name, run_name=None, concatenate_sessions=True, concatenate_runs=True, ) + for sorter_output_path, postprocessing_path, recording_dat_path in zip( + paths["sorter_output"], paths["postprocessing"], paths["recording_dat"] + ): + if "kilosort" in sorting_data.sorter: + saved_recording = si.read_binary( + recording_dat_path, + sampling_frequency=sorted_data_concat_all.get_sampling_frequency(), + dtype=data_type, + num_channels=sorted_data_concat_all.get_num_channels(), + ) + self.check_recordings_are_the_same( + saved_recording, test_concat_all, n_split=6 + ) - self.check_recordings_are_the_same( - sorted_data_concat_all, test_concat_all, n_split=6 - ) - - if "kilosort" in sorting_data.sorter: - saved_recording = si.read_binary( - paths["recording_dat"], - sampling_frequency=sorted_data_concat_all.get_sampling_frequency(), - dtype=data_type, - num_channels=sorted_data_concat_all.get_num_channels(), - ) - self.check_recordings_are_the_same( - saved_recording, test_concat_all, n_split=6 + self.check_waveforms( + sorter_output_path, + postprocessing_path, + recs_to_test=[sorting_data[concat_ses_name]], ) - self.check_waveforms( - paths["sorter_output"], - paths["postprocessing"], - recs_to_test=[sorting_data[concat_ses_name]], - ) - def check_recordings_are_the_same(self, rec_1, rec_2, n_split=1): """ Check that two SI recording objects are exactly the same. 
When the @@ -678,23 +732,26 @@ def check_waveforms( assert np.array_equal(data, first_unit_waveforms[0]) def write_an_empty_file_in_outputs( - self, test_info, ses_name, run_name, sorter=DEFAULT_SORTER + self, test_info, ses_name, run_name, sort_by_group, sorter=DEFAULT_SORTER ): """ Write a file called `test_file.txt` with contents `test_file` in the preprocessed, sorting and postprocessing output path for this session / run. """ - paths = self.get_output_paths(test_info, ses_name, run_name, sorter=sorter) + paths = self.get_output_paths( + test_info, ses_name, run_name, sorter=sorter, sort_by_group=sort_by_group + ) + + paths_to_write = [paths["preprocessing"] / "test_file.txt"] - paths_to_write = [] - for output in ["preprocessing", "sorting_path", "postprocessing"]: - paths_to_write.append(paths[output] / "test_file.txt") + for output in ["sorting_path", "postprocessing"]: + for group_path in paths[output]: + paths_to_write.append(group_path / "test_file.txt") for path_ in paths_to_write: - with open(path_, "w") as file: + with open(path_.as_posix(), "w") as file: file.write("test file.") - return paths_to_write def get_output_paths( @@ -702,6 +759,7 @@ def get_output_paths( test_info, ses_name, run_name, + sort_by_group=False, sorter=DEFAULT_SORTER, concatenate_sessions=False, concatenate_runs=False, @@ -746,14 +804,185 @@ def get_output_paths( paths = { "preprocessing": run_path / "preprocessing", - "sorting_path": run_path / sorter / "sorting", - "postprocessing": run_path / sorter / "postprocessing", + "postprocessing": [], + "sorting_path": [], + "sorter_output": [], + "recording_dat": [], } - paths["sorter_output"] = paths["sorting_path"] / "sorter_output" - paths["recording_dat"] = paths["sorter_output"] / "recording.dat" + + if sort_by_group: + all_groups = sorted((run_path / sorter).glob("group-*")) + assert any(all_groups), "Groups output not found." + + for group in all_groups: + sorting_path = group / "sorting" + paths["sorting_path"].append(sorting_path) + paths["postprocessing"].append(group / "postprocessing") + paths["sorter_output"].append(sorting_path / "sorter_output") + paths["recording_dat"].append( + sorting_path / "sorter_output" / "recording.dat" + ) # TODO: this is only for kilosort! 
+ else: + sorting_path = run_path / sorter / "sorting" + paths["sorting_path"] = [sorting_path] + paths["postprocessing"] = [run_path / sorter / "postprocessing"] + paths["sorter_output"] = [sorting_path / "sorter_output"] + paths["recording_dat"] = [sorting_path / "sorter_output" / "recording.dat"] return paths + @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) + def test_sort_by_group_concat_sessions(self, test_info): + preprocess_data, sorting_data = self.run_full_pipeline( + *test_info, + data_format=DEFAULT_FORMAT, + concatenate_runs=True, + concatenate_sessions=True, + sort_by_group=True, + existing_preprocessed_data="overwrite", + existing_sorting_output="overwrite", + overwrite_postprocessing=True, + sorter=DEFAULT_SORTER, + ) + + concat_ses_name = list(sorting_data.keys())[0] + + prepo_recordings = [ + val["3-raw-phase_shift-bandpass_filter-common_reference"] + for ses_name in preprocess_data.keys() + for val in preprocess_data[ses_name].values() + ] + + test_preprocessed = concatenate_recordings(prepo_recordings) + + sorting_output_paths = self.get_output_paths( + test_info, + concat_ses_name, + run_name=None, + sorter=DEFAULT_SORTER, + concatenate_sessions=True, + concatenate_runs=True, + sort_by_group=True, + )["sorting_path"] + + self.check_sorting_is_correct(test_preprocessed, sorting_output_paths) + + @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) + def test_sort_by_group_concat_runs_not_sessions(self, test_info): + preprocess_data, sorting_data = self.run_full_pipeline( + *test_info, + data_format=DEFAULT_FORMAT, + concatenate_runs=True, + concatenate_sessions=False, + sort_by_group=True, + existing_preprocessed_data="overwrite", + existing_sorting_output="overwrite", + overwrite_postprocessing=True, + sorter=DEFAULT_SORTER, + ) + + base_path, sub_name, sessions_and_runs = test_info + + for ses_name in sessions_and_runs.keys(): + concat_run_name = list(sorting_data[ses_name].keys())[0] + + prepo_recordings = [ + val["3-raw-phase_shift-bandpass_filter-common_reference"] + for val in preprocess_data[ses_name].values() + ] + test_preprocessed = concatenate_recordings(prepo_recordings) + + sorting_output_paths = self.get_output_paths( + test_info, + ses_name, + run_name=concat_run_name, + sorter=DEFAULT_SORTER, + concatenate_sessions=False, + concatenate_runs=True, + sort_by_group=True, + )["sorting_path"] + + self.check_sorting_is_correct(test_preprocessed, sorting_output_paths) + + @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) + def test_sort_by_group_no_concat(self, test_info): + self.run_full_pipeline( + *test_info, + data_format=DEFAULT_FORMAT, + concatenate_runs=False, + concatenate_sessions=False, + sort_by_group=True, + existing_preprocessed_data="overwrite", + existing_sorting_output="overwrite", + overwrite_postprocessing=True, + sorter=DEFAULT_SORTER, + ) + + base_path, sub_name, sessions_and_runs = test_info + + for ses_name in sessions_and_runs.keys(): + for run_name in sessions_and_runs[ses_name]: + sorting_output_paths = self.get_output_paths( + test_info, + ses_name, + run_name=run_name, + sorter=DEFAULT_SORTER, + sort_by_group=True, + )["sorting_path"] + + _, test_preprocessed = self.get_test_rawdata_and_preprocessed_data( + base_path, sub_name, ses_name, run_name + ) + self.check_sorting_is_correct(test_preprocessed, sorting_output_paths) + + def check_sorting_is_correct(self, test_preprocessed, sorting_output_paths): + """""" + split_recording = test_preprocessed.split_by("group") + + if 
"kilosort" in DEFAULT_SORTER: + singularity_image = True if platform.system() == "Linux" else False + docker_image = not singularity_image + else: + singularity_image = docker_image = False + + sortings = {} + for group, sub_recording in split_recording.items(): + sorting = sorters.run_sorter( + sorter_name=DEFAULT_SORTER, + recording=sub_recording, + output_folder=None, + docker_image=docker_image, + singularity_image=singularity_image, + remove_existing_folder=True, + **{ + "scheme": "2", + "filter": False, + "whiten": False, + "verbose": True, + }, + ) + + sortings[group] = sorting + + assert len(sorting_output_paths) > 1, "Groups output not found." + + for idx, path_ in enumerate(sorting_output_paths): + group_sorting = load_saved_sorting_output( + path_ / "sorter_output", DEFAULT_SORTER + ) + + assert np.array_equal( + group_sorting.get_unit_ids(), sortings[idx].get_unit_ids() + ) + + for unit in group_sorting.get_unit_ids(): + assert np.allclose( + group_sorting.get_unit_spike_train(unit), + sortings[idx].get_unit_spike_train(unit), + rtol=0, + atol=1e-10, + ), f"{idx}, {group_sorting}, {sortings}" + # ---------------------------------------------------------------------------------- # Getters # ---------------------------------------------------------------------------------- diff --git a/tests/test_integration/test_slurm.py b/tests/test_integration/test_slurm.py index 9a4522c..da6d78a 100644 --- a/tests/test_integration/test_slurm.py +++ b/tests/test_integration/test_slurm.py @@ -28,6 +28,9 @@ class TestSLURM(BaseTest): # TODO: cannot test the actual output. # can test recording at least. + # TODO: this is just a smoke test. Need to test against actual sorting + # to ensure matches as expected. Missed case where sorter was not passed + # and default was used! @pytest.mark.skipif(CAN_SLURM is False, reason="CAN_SLURM is false") @pytest.mark.parametrize( "concatenation", [(False, False), (False, True), (True, True)]