Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix bugs with split files #597

Merged
merged 3 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,14 @@ def gen_log_kwargs(
return kwargs


###############################################################################
# Private config vars (not to be set by user)
# -------------------------------------------

# Maximum on-disk size of a single raw FIF file before it is split into
# numbered parts (passed as ``split_size`` to ``raw.save``).
# NOTE(review): presumably chosen to stay at the FIFF format's 2 GB file-size
# limit — confirm against MNE's ``Raw.save`` documentation.
_raw_split_size = '2GB'
# Same as above, but for saved epochs files (``epochs.save``).
_epochs_split_size = '2GB'


###############################################################################
# Retrieve custom configuration options
# -------------------------------------
Expand Down Expand Up @@ -3824,3 +3832,25 @@ def save_logs(logs):
'configuration. Currently the `conditions` parameter is empty. '
'This is only allowed for resting-state analysis.')
raise ValueError(msg)


def _update_for_splits(files_dict, key, *, single=False):
if not isinstance(files_dict, dict): # fake it
assert key is None
files_dict, key = dict(x=files_dict), 'x'
bids_path = files_dict[key]
if bids_path.fpath.exists():
return bids_path # no modifications needed
bids_path = bids_path.copy().update(split='01')
assert bids_path.fpath.exists(), f'Missing file: {bids_path.fpath}'
files_dict[key] = bids_path
# if we only need the first file (i.e., when reading), quit now
if single:
return bids_path
for split in range(2, 100):
split_key = f'{split:02d}'
bids_path_next = bids_path.copy().update(split=split_key)
if not bids_path_next.fpath.exists():
break
files_dict[f'{key}_split-{split_key}'] = bids_path_next
return bids_path
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,5 @@ authors:
- EEG channels couldn't be used as "virtual" EOG channels during ICA artifact
detection. Reported by "fraenni" on the forum. Thank you! 🌻
({{ gh(572) }} by {{ authors.hoechenberger }})
- Fix bug with handling of split files during preprocessing
({{ gh(597) }} by {{ authors.larsoner }})
4 changes: 4 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import coloredlogs


# Ensure that the "scripts" that we import from is the correct one
sys.path.insert(0, str(pathlib.Path(__file__).parent))


logger = logging.getLogger(__name__)

log_level_styles = {
Expand Down
17 changes: 11 additions & 6 deletions scripts/preprocessing/_01_maxfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

import config
from config import (gen_log_kwargs, on_error, failsafe_run,
import_experimental_data, import_er_data, import_rest_data)
import_experimental_data, import_er_data, import_rest_data,
_update_for_splits)
from config import parallel_func

logger = logging.getLogger('mne-bids-pipeline')
Expand Down Expand Up @@ -81,7 +82,6 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
raise ValueError(f'You cannot set use_maxwell_filter to True '
f'if data have already processed with Maxwell-filter.'
f' Got proc={config.proc}.')

bids_path_in = in_files[f"raw_run-{run}"]
bids_path_out = bids_path_in.copy().update(
processing="sss",
Expand Down Expand Up @@ -154,13 +154,16 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
logger.info(**gen_log_kwargs(
message=msg, subject=subject, session=session, run=run))
raw_sss.save(out_files['sss_raw'], picks=picks, split_naming='bids',
overwrite=True)
overwrite=True, split_size=cfg._raw_split_size)
# we need to be careful about split files
_update_for_splits(out_files, 'sss_raw')
del raw, raw_sss

if cfg.interactive:
# Load the data we have just written, because it contains only
# the relevant channels.
raw_sss = mne.io.read_raw_fif(bids_path_out, allow_maxshield=True)
raw_sss = mne.io.read_raw_fif(
out_files['sss_raw'], allow_maxshield=True)
raw_sss.plot(n_channels=50, butterfly=True, block=True)
del raw_sss

Expand Down Expand Up @@ -211,7 +214,7 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
# copy the bad channel selection from the reference run over to
# the resting-state recording.

raw_sss = mne.io.read_raw_fif(bids_path_out)
raw_sss = mne.io.read_raw_fif(out_files['sss_raw'])
rank_exp = mne.compute_rank(raw_sss, rank='info')['meg']
rank_noise = mne.compute_rank(raw_noise_sss, rank='info')['meg']

Expand Down Expand Up @@ -248,8 +251,9 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
message=msg, subject=subject, session=session, run=run))
raw_noise_sss.save(
out_files['sss_noise'], picks=picks, overwrite=True,
split_naming='bids'
split_naming='bids', split_size=cfg._raw_split_size,
)
_update_for_splits(out_files, 'sss_noise')
del raw_noise_sss
return {key: pth.fpath for key, pth in out_files.items()}

Expand Down Expand Up @@ -290,6 +294,7 @@ def get_config(
min_break_duration=config.min_break_duration,
t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501
t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501
_raw_split_size=config._raw_split_size,
)
return cfg

Expand Down
22 changes: 13 additions & 9 deletions scripts/preprocessing/_02_frequency_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

import config
from config import (gen_log_kwargs, on_error, failsafe_run,
import_experimental_data, import_er_data, import_rest_data)
import_experimental_data, import_er_data, import_rest_data,
_update_for_splits)
from config import parallel_func


Expand Down Expand Up @@ -75,21 +76,19 @@ def get_input_fnames_frequency_filter(**kwargs):

if cfg.use_maxwell_filter:
bids_path_in.update(processing="sss")
if bids_path_in.copy().update(split='01').fpath.exists():
bids_path_in = bids_path_in.update(split='01')

in_files = dict()
in_files[f'raw_run-{run}'] = bids_path_in
_update_for_splits(in_files, f'raw_run-{run}', single=True)

if (cfg.process_er or config.noise_cov == 'rest') and run == cfg.runs[0]:
noise_task = "rest" if config.noise_cov == "rest" else "noise"
if cfg.use_maxwell_filter:
raw_noise_fname_in = bids_path_in.copy().update(
run=None, task=noise_task
)
if raw_noise_fname_in.copy().update(split='01').fpath.exists():
raw_noise_fname_in.update(split='01')
in_files["raw_noise"] = raw_noise_fname_in
_update_for_splits(in_files, "raw_noise", single=True)
else:
if config.noise_cov == 'rest':
in_files["raw_rest"] = bids_path_in.copy().update(
Expand Down Expand Up @@ -189,7 +188,7 @@ def filter_data(

out_files['raw_filt'] = bids_path.copy().update(
root=cfg.deriv_root, processing='filt', extension='.fif',
suffix='raw')
suffix='raw', split=None)
raw.load_data()
filter(
raw=raw, subject=subject, session=session, run=run,
Expand All @@ -201,7 +200,9 @@ def filter_data(
resample(raw=raw, subject=subject, session=session, run=run,
sfreq=cfg.resample_sfreq, data_type='experimental')

raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids')
raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids',
split_size=cfg._raw_split_size)
_update_for_splits(out_files, 'raw_filt')
if cfg.interactive:
# Plot raw data and power spectral density.
raw.plot(n_channels=50, butterfly=True)
Expand Down Expand Up @@ -238,7 +239,7 @@ def filter_data(
out_files['raw_noise_filt'] = \
bids_path_noise.copy().update(
root=cfg.deriv_root, processing='filt', extension='.fif',
suffix='raw')
suffix='raw', split=None)

raw_noise.load_data()
filter(
Expand All @@ -252,8 +253,10 @@ def filter_data(
sfreq=cfg.resample_sfreq, data_type=data_type)

raw_noise.save(
out_files['raw_noise_filt'], overwrite=True, split_naming='bids'
out_files['raw_noise_filt'], overwrite=True, split_naming='bids',
split_size=cfg._raw_split_size,
)
_update_for_splits(out_files, 'raw_noise_filt')
if cfg.interactive:
# Plot raw data and power spectral density.
raw_noise.plot(n_channels=50, butterfly=True)
Expand Down Expand Up @@ -301,6 +304,7 @@ def get_config(
min_break_duration=config.min_break_duration,
t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501
t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501
_raw_split_size=config._raw_split_size,
)
return cfg

Expand Down
15 changes: 8 additions & 7 deletions scripts/preprocessing/_03_make_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import config
from config import make_epochs, gen_log_kwargs, on_error, failsafe_run
from config import parallel_func
from config import parallel_func, _update_for_splits

logger = logging.getLogger('mne-bids-pipeline')

Expand All @@ -44,10 +44,7 @@ def run_epochs(*, cfg, subject, session=None):
for run in cfg.runs:
raw_fname_in = bids_path.copy().update(run=run, processing='filt',
suffix='raw', check=False)

if raw_fname_in.copy().update(split='01').fpath.exists():
raw_fname_in.update(split='01')

raw_fname_in = _update_for_splits(raw_fname_in, None, single=True)
raw_fnames.append(raw_fname_in)

# Generate a unique event name -> event code mapping that can be used
Expand Down Expand Up @@ -191,7 +188,10 @@ def run_epochs(*, cfg, subject, session=None):
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))
epochs_fname = bids_path.copy().update(suffix='epo', check=False)
epochs.save(epochs_fname, overwrite=True, split_naming='bids')
epochs.save(
epochs_fname, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)
# _update_for_splits(out_files, 'epochs')

if cfg.interactive:
epochs.plot()
Expand Down Expand Up @@ -228,7 +228,8 @@ def get_config(
event_repeated=config.event_repeated,
decim=config.decim,
ch_types=config.ch_types,
eeg_reference=config.get_eeg_reference()
eeg_reference=config.get_eeg_reference(),
_epochs_split_size=config._epochs_split_size,
)
return cfg

Expand Down
6 changes: 2 additions & 4 deletions scripts/preprocessing/_04a_run_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import config
from config import (make_epochs, gen_log_kwargs, on_error, failsafe_run,
annotations_to_events)
annotations_to_events, _update_for_splits)
from config import parallel_func


Expand Down Expand Up @@ -259,9 +259,7 @@ def run_ica(*, cfg, subject, session=None):
raw_fnames = []
for run in cfg.runs:
raw_fname.update(run=run)
if raw_fname.copy().update(split='01').fpath.exists():
raw_fname.update(split='01')

raw_fname = _update_for_splits(raw_fname, None, single=True)
raw_fnames.append(raw_fname.copy())

# Generate a unique event name -> event code mapping that can be used
Expand Down
6 changes: 2 additions & 4 deletions scripts/preprocessing/_04b_run_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import config
from config import gen_log_kwargs, on_error, failsafe_run
from config import parallel_func
from config import parallel_func, _update_for_splits


logger = logging.getLogger('mne-bids-pipeline')
Expand Down Expand Up @@ -49,9 +49,7 @@ def run_ssp(*, cfg, subject, session=None):
msg = f'Input: {raw_fname_in.basename}, Output: {proj_fname_out.basename}'
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))

if raw_fname_in.copy().update(split='01').fpath.exists():
raw_fname_in.update(split='01')
raw_fname_in = _update_for_splits(raw_fname_in, None, single=True)

raw = mne.io.read_raw_fif(raw_fname_in)
msg = 'Computing SSPs for ECG'
Expand Down
8 changes: 6 additions & 2 deletions scripts/preprocessing/_05a_apply_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def apply_ica(*, cfg, subject, session):
msg = 'Saving reconstructed epochs after ICA.'
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))
epochs_cleaned.save(fname_epo_out, overwrite=True, split_naming='bids')
epochs_cleaned.save(
fname_epo_out, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)
# _update_for_splits(out_files, 'epochs_cleaned')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to cleanup

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were intentionally left in to remind us to do it when working on caching :)


# Compare ERP/ERF before and after ICA artifact rejection. The evoked
# response is calculated across ALL epochs, just like ICA was run on
Expand Down Expand Up @@ -127,7 +130,8 @@ def get_config(
deriv_root=config.get_deriv_root(),
interactive=config.interactive,
baseline=config.baseline,
ica_reject=config.get_ica_reject()
ica_reject=config.get_ica_reject(),
_epochs_split_size=config._epochs_split_size,
)
return cfg

Expand Down
6 changes: 5 additions & 1 deletion scripts/preprocessing/_05b_apply_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def apply_ssp(*, cfg, subject, session=None):
msg = 'Saving epochs with projectors.'
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))
epochs_cleaned.save(fname_out, overwrite=True, split_naming='bids')
epochs_cleaned.save(
fname_out, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)
# _update_for_splits(out_files, 'epochs_cleaned')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to cleanup



def get_config(
Expand All @@ -76,6 +79,7 @@ def get_config(
rec=config.rec,
space=config.space,
deriv_root=config.get_deriv_root(),
_epochs_split_size=config._epochs_split_size,
)
return cfg

Expand Down
7 changes: 5 additions & 2 deletions scripts/preprocessing/_06_ptp_reject.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def drop_ptp(*, cfg, subject, session=None):
msg = 'Saving cleaned, baseline-corrected epochs …'

epochs.apply_baseline(cfg.baseline)
epochs.save(fname_out, overwrite=True, split_naming='bids')
epochs.save(
fname_out, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)


def get_config(
Expand All @@ -108,7 +110,8 @@ def get_config(
spatial_filter=config.spatial_filter,
ica_reject=config.get_ica_reject(),
deriv_root=config.get_deriv_root(),
decim=config.decim
decim=config.decim,
_epochs_split_size=config._epochs_split_size,
)
return cfg

Expand Down
16 changes: 4 additions & 12 deletions scripts/report/_01_make_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import config
from config import (
gen_log_kwargs, on_error, failsafe_run, parallel_func,
get_noise_cov_bids_path
get_noise_cov_bids_path, _update_for_splits,
)


Expand All @@ -52,10 +52,7 @@ def get_events(cfg, subject, session):

for run in cfg.runs:
this_raw_fname = raw_fname.copy().update(run=run)

if this_raw_fname.copy().update(split='01').fpath.exists():
this_raw_fname.update(split='01')

this_raw_fname = _update_for_splits(this_raw_fname, None, single=True)
raw_filt = mne.io.read_raw_fif(this_raw_fname)
raws_filt.append(raw_filt)
del this_raw_fname
Expand All @@ -81,10 +78,7 @@ def get_er_path(cfg, subject, session):
datatype=cfg.datatype,
root=cfg.deriv_root,
check=False)

if raw_fname.copy().update(split='01').fpath.exists():
raw_fname.update(split='01')

raw_fname = _update_for_splits(raw_fname, None, single=True)
return raw_fname


Expand Down Expand Up @@ -482,9 +476,7 @@ def run_report_preprocessing(
run=run, processing='filt',
suffix='raw', check=False
)
if fname.copy().update(split='01').fpath.exists():
fname.update(split='01')

fname = _update_for_splits(fname, None, single=True)
fnames_raw_filt.append(fname)

fname_epo_not_clean = bids_path.copy().update(suffix='epo')
Expand Down
Loading