def _update_for_splits(files_dict, key, *, single=False):
    """Point a file entry at its on-disk split file(s) if it was split.

    If the path stored under ``key`` exists as-is, nothing is changed.
    Otherwise the entry is replaced by a copy of the path with
    ``split='01'``, which must exist. Unless ``single`` is true, every
    consecutive follow-up split (``'02'``, ``'03'``, ... up to ``'99'``)
    that exists on disk is additionally registered under the key
    ``f'{key}_split-{NN}'``.

    Parameters
    ----------
    files_dict : dict | path-like object
        Mapping of names to path objects; it is updated in place.
        Alternatively a single bare path object, in which case ``key`` must
        be ``None`` and only the return value is meaningful. Path objects
        must provide ``.fpath`` (with ``.exists()``), ``.copy()`` and
        ``.update(split=...)`` -- presumably ``mne_bids.BIDSPath``.
    key : str | None
        The entry of ``files_dict`` to check, or ``None`` when a bare path
        was passed.
    single : bool
        If True, only resolve the first file and do not register further
        splits (sufficient when the path is only going to be read).

    Returns
    -------
    bids_path
        The original path object if it exists, otherwise the ``split='01'``
        copy.

    Raises
    ------
    ValueError
        If a bare path is passed together with a non-``None`` key.
    FileNotFoundError
        If neither the unsplit file nor its ``split='01'`` variant exists.
    """
    if not isinstance(files_dict, dict):
        # A bare path was passed: wrap it in a throw-away dict so that both
        # call styles share the code below.
        if key is not None:
            # Explicit raise rather than assert so that the check survives
            # ``python -O``.
            raise ValueError(
                f'key must be None when a single path is passed, got {key!r}'
            )
        files_dict, key = dict(x=files_dict), 'x'

    bids_path = files_dict[key]
    if bids_path.fpath.exists():
        return bids_path  # no modifications needed

    # Work on a copy so the caller's original path object is left untouched.
    bids_path = bids_path.copy().update(split='01')
    if not bids_path.fpath.exists():
        # Previously an ``assert``; under ``python -O`` a missing file would
        # have been registered silently.
        raise FileNotFoundError(f'Missing file: {bids_path.fpath}')
    files_dict[key] = bids_path

    # If we only need the first file (i.e., when reading), quit now.
    if single:
        return bids_path

    # Register the remaining splits, stopping at the first gap.
    for split in range(2, 100):
        split_key = f'{split:02d}'
        bids_path_next = bids_path.copy().update(split=split_key)
        if not bids_path_next.fpath.exists():
            break
        files_dict[f'{key}_split-{split_key}'] = bids_path_next
    return bids_path
🌻 ({{ gh(572) }} by {{ authors.hoechenberger }}) +- Fix bug with handling of split files during preprocessing + ({{ gh(597) }} by {{ authors.larsoner }}) diff --git a/run.py b/run.py index 421787830..c9bb2c503 100755 --- a/run.py +++ b/run.py @@ -15,6 +15,10 @@ import coloredlogs +# Ensure that the "scripts" that we import from is the correct one +sys.path.insert(0, str(pathlib.Path(__file__).parent)) + + logger = logging.getLogger(__name__) log_level_styles = { diff --git a/scripts/preprocessing/_01_maxfilter.py b/scripts/preprocessing/_01_maxfilter.py index efcb984e1..0c021b19a 100644 --- a/scripts/preprocessing/_01_maxfilter.py +++ b/scripts/preprocessing/_01_maxfilter.py @@ -28,7 +28,8 @@ import config from config import (gen_log_kwargs, on_error, failsafe_run, - import_experimental_data, import_er_data, import_rest_data) + import_experimental_data, import_er_data, import_rest_data, + _update_for_splits) from config import parallel_func logger = logging.getLogger('mne-bids-pipeline') @@ -81,7 +82,6 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None): raise ValueError(f'You cannot set use_maxwell_filter to True ' f'if data have already processed with Maxwell-filter.' f' Got proc={config.proc}.') - bids_path_in = in_files[f"raw_run-{run}"] bids_path_out = bids_path_in.copy().update( processing="sss", @@ -154,13 +154,16 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None): logger.info(**gen_log_kwargs( message=msg, subject=subject, session=session, run=run)) raw_sss.save(out_files['sss_raw'], picks=picks, split_naming='bids', - overwrite=True) + overwrite=True, split_size=cfg._raw_split_size) + # we need to be careful about split files + _update_for_splits(out_files, 'sss_raw') del raw, raw_sss if cfg.interactive: # Load the data we have just written, because it contains only # the relevant channels. 
- raw_sss = mne.io.read_raw_fif(bids_path_out, allow_maxshield=True) + raw_sss = mne.io.read_raw_fif( + out_files['sss_raw'], allow_maxshield=True) raw_sss.plot(n_channels=50, butterfly=True, block=True) del raw_sss @@ -211,7 +214,7 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None): # copy the bad channel selection from the reference run over to # the resting-state recording. - raw_sss = mne.io.read_raw_fif(bids_path_out) + raw_sss = mne.io.read_raw_fif(out_files['sss_raw']) rank_exp = mne.compute_rank(raw_sss, rank='info')['meg'] rank_noise = mne.compute_rank(raw_noise_sss, rank='info')['meg'] @@ -248,8 +251,9 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None): message=msg, subject=subject, session=session, run=run)) raw_noise_sss.save( out_files['sss_noise'], picks=picks, overwrite=True, - split_naming='bids' + split_naming='bids', split_size=cfg._raw_split_size, ) + _update_for_splits(out_files, 'sss_noise') del raw_noise_sss return {key: pth.fpath for key, pth in out_files.items()} @@ -290,6 +294,7 @@ def get_config( min_break_duration=config.min_break_duration, t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501 t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501 + _raw_split_size=config._raw_split_size, ) return cfg diff --git a/scripts/preprocessing/_02_frequency_filter.py b/scripts/preprocessing/_02_frequency_filter.py index 3ab23ecea..1aae83de3 100644 --- a/scripts/preprocessing/_02_frequency_filter.py +++ b/scripts/preprocessing/_02_frequency_filter.py @@ -35,7 +35,8 @@ import config from config import (gen_log_kwargs, on_error, failsafe_run, - import_experimental_data, import_er_data, import_rest_data) + import_experimental_data, import_er_data, import_rest_data, + _update_for_splits) from config import parallel_func @@ -75,11 +76,10 @@ def get_input_fnames_frequency_filter(**kwargs): if 
cfg.use_maxwell_filter: bids_path_in.update(processing="sss") - if bids_path_in.copy().update(split='01').fpath.exists(): - bids_path_in = bids_path_in.update(split='01') in_files = dict() in_files[f'raw_run-{run}'] = bids_path_in + _update_for_splits(in_files, f'raw_run-{run}', single=True) if (cfg.process_er or config.noise_cov == 'rest') and run == cfg.runs[0]: noise_task = "rest" if config.noise_cov == "rest" else "noise" @@ -87,9 +87,8 @@ def get_input_fnames_frequency_filter(**kwargs): raw_noise_fname_in = bids_path_in.copy().update( run=None, task=noise_task ) - if raw_noise_fname_in.copy().update(split='01').fpath.exists(): - raw_noise_fname_in.update(split='01') in_files["raw_noise"] = raw_noise_fname_in + _update_for_splits(in_files, "raw_noise", single=True) else: if config.noise_cov == 'rest': in_files["raw_rest"] = bids_path_in.copy().update( @@ -189,7 +188,7 @@ def filter_data( out_files['raw_filt'] = bids_path.copy().update( root=cfg.deriv_root, processing='filt', extension='.fif', - suffix='raw') + suffix='raw', split=None) raw.load_data() filter( raw=raw, subject=subject, session=session, run=run, @@ -201,7 +200,9 @@ def filter_data( resample(raw=raw, subject=subject, session=session, run=run, sfreq=cfg.resample_sfreq, data_type='experimental') - raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids') + raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids', + split_size=cfg._raw_split_size) + _update_for_splits(out_files, 'raw_filt') if cfg.interactive: # Plot raw data and power spectral density. 
raw.plot(n_channels=50, butterfly=True) @@ -238,7 +239,7 @@ def filter_data( out_files['raw_noise_filt'] = \ bids_path_noise.copy().update( root=cfg.deriv_root, processing='filt', extension='.fif', - suffix='raw') + suffix='raw', split=None) raw_noise.load_data() filter( @@ -252,8 +253,10 @@ def filter_data( sfreq=cfg.resample_sfreq, data_type=data_type) raw_noise.save( - out_files['raw_noise_filt'], overwrite=True, split_naming='bids' + out_files['raw_noise_filt'], overwrite=True, split_naming='bids', + split_size=cfg._raw_split_size, ) + _update_for_splits(out_files, 'raw_noise_filt') if cfg.interactive: # Plot raw data and power spectral density. raw_noise.plot(n_channels=50, butterfly=True) @@ -301,6 +304,7 @@ def get_config( min_break_duration=config.min_break_duration, t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501 t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501 + _raw_split_size=config._raw_split_size, ) return cfg diff --git a/scripts/preprocessing/_03_make_epochs.py b/scripts/preprocessing/_03_make_epochs.py index 7a70b45db..6e5c4198c 100644 --- a/scripts/preprocessing/_03_make_epochs.py +++ b/scripts/preprocessing/_03_make_epochs.py @@ -20,7 +20,7 @@ import config from config import make_epochs, gen_log_kwargs, on_error, failsafe_run -from config import parallel_func +from config import parallel_func, _update_for_splits logger = logging.getLogger('mne-bids-pipeline') @@ -44,10 +44,7 @@ def run_epochs(*, cfg, subject, session=None): for run in cfg.runs: raw_fname_in = bids_path.copy().update(run=run, processing='filt', suffix='raw', check=False) - - if raw_fname_in.copy().update(split='01').fpath.exists(): - raw_fname_in.update(split='01') - + raw_fname_in = _update_for_splits(raw_fname_in, None, single=True) raw_fnames.append(raw_fname_in) # Generate a unique event name -> event code mapping that can be used @@ -191,7 +188,10 @@ def run_epochs(*, cfg, 
subject, session=None): logger.info(**gen_log_kwargs(message=msg, subject=subject, session=session)) epochs_fname = bids_path.copy().update(suffix='epo', check=False) - epochs.save(epochs_fname, overwrite=True, split_naming='bids') + epochs.save( + epochs_fname, overwrite=True, split_naming='bids', + split_size=cfg._epochs_split_size) + # _update_for_splits(out_files, 'epochs') if cfg.interactive: epochs.plot() @@ -228,7 +228,8 @@ def get_config( event_repeated=config.event_repeated, decim=config.decim, ch_types=config.ch_types, - eeg_reference=config.get_eeg_reference() + eeg_reference=config.get_eeg_reference(), + _epochs_split_size=config._epochs_split_size, ) return cfg diff --git a/scripts/preprocessing/_04a_run_ica.py b/scripts/preprocessing/_04a_run_ica.py index 308371c9a..aa5967407 100644 --- a/scripts/preprocessing/_04a_run_ica.py +++ b/scripts/preprocessing/_04a_run_ica.py @@ -33,7 +33,7 @@ import config from config import (make_epochs, gen_log_kwargs, on_error, failsafe_run, - annotations_to_events) + annotations_to_events, _update_for_splits) from config import parallel_func @@ -259,9 +259,7 @@ def run_ica(*, cfg, subject, session=None): raw_fnames = [] for run in cfg.runs: raw_fname.update(run=run) - if raw_fname.copy().update(split='01').fpath.exists(): - raw_fname.update(split='01') - + raw_fname = _update_for_splits(raw_fname, None, single=True) raw_fnames.append(raw_fname.copy()) # Generate a unique event name -> event code mapping that can be used diff --git a/scripts/preprocessing/_04b_run_ssp.py b/scripts/preprocessing/_04b_run_ssp.py index 999b121fb..cd1e6a8c1 100644 --- a/scripts/preprocessing/_04b_run_ssp.py +++ b/scripts/preprocessing/_04b_run_ssp.py @@ -18,7 +18,7 @@ import config from config import gen_log_kwargs, on_error, failsafe_run -from config import parallel_func +from config import parallel_func, _update_for_splits logger = logging.getLogger('mne-bids-pipeline') @@ -49,9 +49,7 @@ def run_ssp(*, cfg, subject, session=None): msg = 
f'Input: {raw_fname_in.basename}, Output: {proj_fname_out.basename}' logger.info(**gen_log_kwargs(message=msg, subject=subject, session=session)) - - if raw_fname_in.copy().update(split='01').fpath.exists(): - raw_fname_in.update(split='01') + raw_fname_in = _update_for_splits(raw_fname_in, None, single=True) raw = mne.io.read_raw_fif(raw_fname_in) msg = 'Computing SSPs for ECG' diff --git a/scripts/preprocessing/_05a_apply_ica.py b/scripts/preprocessing/_05a_apply_ica.py index a0ce08af4..170b15fc4 100644 --- a/scripts/preprocessing/_05a_apply_ica.py +++ b/scripts/preprocessing/_05a_apply_ica.py @@ -93,7 +93,10 @@ def apply_ica(*, cfg, subject, session): msg = 'Saving reconstructed epochs after ICA.' logger.info(**gen_log_kwargs(message=msg, subject=subject, session=session)) - epochs_cleaned.save(fname_epo_out, overwrite=True, split_naming='bids') + epochs_cleaned.save( + fname_epo_out, overwrite=True, split_naming='bids', + split_size=cfg._epochs_split_size) + # _update_for_splits(out_files, 'epochs_cleaned') # Compare ERP/ERF before and after ICA artifact rejection. The evoked # response is calculated across ALL epochs, just like ICA was run on @@ -127,7 +130,8 @@ def get_config( deriv_root=config.get_deriv_root(), interactive=config.interactive, baseline=config.baseline, - ica_reject=config.get_ica_reject() + ica_reject=config.get_ica_reject(), + _epochs_split_size=config._epochs_split_size, ) return cfg diff --git a/scripts/preprocessing/_05b_apply_ssp.py b/scripts/preprocessing/_05b_apply_ssp.py index 9c693cba4..c13c1ad07 100644 --- a/scripts/preprocessing/_05b_apply_ssp.py +++ b/scripts/preprocessing/_05b_apply_ssp.py @@ -62,7 +62,10 @@ def apply_ssp(*, cfg, subject, session=None): msg = 'Saving epochs with projectors.' 
logger.info(**gen_log_kwargs(message=msg, subject=subject, session=session)) - epochs_cleaned.save(fname_out, overwrite=True, split_naming='bids') + epochs_cleaned.save( + fname_out, overwrite=True, split_naming='bids', + split_size=cfg._epochs_split_size) + # _update_for_splits(out_files, 'epochs_cleaned') def get_config( @@ -76,6 +79,7 @@ def get_config( rec=config.rec, space=config.space, deriv_root=config.get_deriv_root(), + _epochs_split_size=config._epochs_split_size, ) return cfg diff --git a/scripts/preprocessing/_06_ptp_reject.py b/scripts/preprocessing/_06_ptp_reject.py index 5c8c4537c..8d8439730 100644 --- a/scripts/preprocessing/_06_ptp_reject.py +++ b/scripts/preprocessing/_06_ptp_reject.py @@ -89,7 +89,9 @@ def drop_ptp(*, cfg, subject, session=None): msg = 'Saving cleaned, baseline-corrected epochs …' epochs.apply_baseline(cfg.baseline) - epochs.save(fname_out, overwrite=True, split_naming='bids') + epochs.save( + fname_out, overwrite=True, split_naming='bids', + split_size=cfg._epochs_split_size) def get_config( @@ -108,7 +110,8 @@ def get_config( spatial_filter=config.spatial_filter, ica_reject=config.get_ica_reject(), deriv_root=config.get_deriv_root(), - decim=config.decim + decim=config.decim, + _epochs_split_size=config._epochs_split_size, ) return cfg diff --git a/scripts/report/_01_make_reports.py b/scripts/report/_01_make_reports.py index 67bd5cfd4..3fee6e83c 100644 --- a/scripts/report/_01_make_reports.py +++ b/scripts/report/_01_make_reports.py @@ -26,7 +26,7 @@ import config from config import ( gen_log_kwargs, on_error, failsafe_run, parallel_func, - get_noise_cov_bids_path + get_noise_cov_bids_path, _update_for_splits, ) @@ -52,10 +52,7 @@ def get_events(cfg, subject, session): for run in cfg.runs: this_raw_fname = raw_fname.copy().update(run=run) - - if this_raw_fname.copy().update(split='01').fpath.exists(): - this_raw_fname.update(split='01') - + this_raw_fname = _update_for_splits(this_raw_fname, None, single=True) raw_filt = 
mne.io.read_raw_fif(this_raw_fname) raws_filt.append(raw_filt) del this_raw_fname @@ -81,10 +78,7 @@ def get_er_path(cfg, subject, session): datatype=cfg.datatype, root=cfg.deriv_root, check=False) - - if raw_fname.copy().update(split='01').fpath.exists(): - raw_fname.update(split='01') - + raw_fname = _update_for_splits(raw_fname, None, single=True) return raw_fname @@ -482,9 +476,7 @@ def run_report_preprocessing( run=run, processing='filt', suffix='raw', check=False ) - if fname.copy().update(split='01').fpath.exists(): - fname.update(split='01') - + fname = _update_for_splits(fname, None, single=True) fnames_raw_filt.append(fname) fname_epo_not_clean = bids_path.copy().update(suffix='epo') diff --git a/tests/configs/config_ds000248.py b/tests/configs/config_ds000248.py index c4b2e9401..098aef3d5 100644 --- a/tests/configs/config_ds000248.py +++ b/tests/configs/config_ds000248.py @@ -24,6 +24,8 @@ find_noisy_channels_meg = True use_maxwell_filter = True process_er = True +_raw_split_size = '60MB' # hits both task-noise and task-audiovisual +_epochs_split_size = '30MB' def noise_cov(bp): diff --git a/tests/configs/config_ds003775.py b/tests/configs/config_ds003775.py index ed5ffa493..b317cc40d 100644 --- a/tests/configs/config_ds003775.py +++ b/tests/configs/config_ds003775.py @@ -37,7 +37,6 @@ parallel_backend = 'loky' dask_open_dashboard = True -on_error = 'continue' log_level = 'info' N_JOBS = 1