Merge pull request #187 from pni-lab/rcpl-pipeline-for-hcp
Improve cluster script & HCP pipeline
spisakt authored Mar 7, 2024
2 parents d993f82 + dfa6ed1 commit 309f797
Showing 5 changed files with 692 additions and 199 deletions.
38 changes: 29 additions & 9 deletions PUMI/pipelines/func/deconfound.py
@@ -30,17 +30,32 @@ def fieldmap_correction_qc(wf, volume='middle', **kwargs):
"""

def create_montage(vol_1, vol_2, vol_corrected):
def get_cut_cords(func, n_slices=10):
import nibabel as nib
import numpy as np

func_img = nib.load(func)
y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape

slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices)
# slices might contain floats but this is not a problem since nilearn will round floats to the
# nearest integer value!
return slices

def create_montage(vol_1, vol_2, vol_corrected, n_slices=10):
from matplotlib import pyplot as plt
from pathlib import Path
from nilearn import plotting
import os

fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(10, 15))

plotting.plot_anat(vol_1, display_mode='ortho', title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_2, display_mode='ortho', title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='ortho', title='Corrected', black_bg=True, axes=axes[2])
plotting.plot_anat(vol_1, display_mode='y', cut_coords=get_cut_cords(vol_1, n_slices=n_slices),
title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_2, display_mode='y', cut_coords=get_cut_cords(vol_2, n_slices=n_slices),
title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='y', cut_coords=get_cut_cords(vol_corrected, n_slices=n_slices),
title='Corrected', black_bg=True, axes=axes[2])

path = str(Path(os.getcwd() + '/fieldmap_correction_comparison.png'))
plt.savefig(path)
@@ -67,19 +82,21 @@ def create_montage(vol_1, vol_2, vol_corrected):
wf.connect(vol_corrected, 'out_file', montage, 'vol_corrected')

wf.connect(montage, 'out_file', 'outputspec', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'qc_fieldmap_correction')


@FuncPipeline(inputspec_fields=['func_1', 'func_2'],
outputspec_fields=['out_file'])
def fieldmap_correction(wf, encoding_direction=['y-', 'y'], readout_times=[0.08264, 0.08264], tr=0.72, **kwargs):
def fieldmap_correction(wf, encoding_direction=['x-', 'x'], trt=[0.0522, 0.0522], tr=0.72, **kwargs):
"""
Fieldmap correction pipeline.
Parameters:
encoding_direction (list): List of encoding directions (default is left-right and right-left phase encoding).
readout_times (list): List of readout times (default adapted to rsfMRI data of the HCP WU 1200 dataset).
trt (list): List of total readout times (default adapted to rsfMRI data of the HCP WU 1200 dataset).
Default is:
1*(10**(-3))*EchoSpacingMS*EpiFactor = 1*(10**(-3))*0.58*90 = 0.0522 (for LR and RL image)
tr (float): Repetition time (default adapted to rsfMRI data of the HCP WU 1200 dataset).
Inputs:
@@ -92,8 +109,11 @@ def fieldmap_correction(wf, encoding_direction=['y-', 'y'], readout_times=[0.082
Sinking:
- 4d distortion corrected image.
For more information regarding the parameters:
For more information:
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/ExampleTopupFollowedByApplytopup
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/Faq#How_do_I_know_what_phase-encode_vectors_to_put_into_my_--datain_text_file.3F
https://www.humanconnectome.org/storage/app/media/documentation/s1200/HCP_S1200_Release_Appendix_I.pdf
"""

@@ -127,7 +147,7 @@ def fieldmap_correction(wf, encoding_direction=['y-', 'y'], readout_times=[0.082
# Estimate susceptibility induced distortions
topup = Node(fsl.TOPUP(), name='topup')
topup.inputs.encoding_direction = encoding_direction
topup.inputs.readout_times = readout_times
topup.inputs.readout_times = trt
wf.connect(merger, 'merged_file', topup, 'in_file')

# The two original 4D files are also needed inside a list
286 changes: 117 additions & 169 deletions pipelines/hcp_rcpl.py → pipelines/hcp.py
@@ -1,24 +1,20 @@
#!/usr/bin/env python3

import argparse
import glob
from PUMI import globals
from nipype import IdentityInterface, DataGrabber
from nipype.interfaces.fsl import Reorient2Std
from nipype.interfaces import afni
from PUMI.engine import BidsPipeline, NestedNode as Node, FuncPipeline, GroupPipeline, BidsApp
from PUMI.pipelines.anat.anat_proc import anat_proc
from PUMI.pipelines.func.compcor import anat_noise_roi
from PUMI.pipelines.func.compcor import anat_noise_roi, compcor
from PUMI.pipelines.anat.func_to_anat import func2anat
from PUMI.pipelines.func.deconfound import fieldmap_correction
from nipype.interfaces import utility

from PUMI.pipelines.func.deconfound import fieldmap_correction
from PUMI.pipelines.func.func_proc import func_proc_despike_afni
from PUMI.pipelines.func.timeseries_extractor import extract_timeseries_nativespace
from PUMI.pipelines.func.timeseries_extractor import pick_atlas, extract_timeseries_nativespace
from PUMI.utils import mist_modules, mist_labels, get_reference
from PUMI.pipelines.func.func2standard import func2standard
from PUMI.engine import NestedWorkflow as Workflow

from pathlib import Path
from PUMI.pipelines.multimodal.image_manipulation import pick_volume
from PUMI.engine import save_software_versions
import traits
import os

@@ -369,169 +365,121 @@ def merge_predictions(rpn_out_file, rcpl_out_file):
wf.connect(merge_predictions_wf, 'out_file', 'sinker', 'pain_predictions')


parser = argparse.ArgumentParser()
@BidsPipeline(output_query={
'T1w': dict(
datatype='anat',
suffix='T1w',
extension=['nii', 'nii.gz']
),
'bold_lr': dict(
datatype='func',
suffix='bold',
acquisition='LR',
extension=['nii', 'nii.gz']
),
'bold_rl': dict(
datatype='func',
suffix='bold',
acquisition='RL',
extension=['nii', 'nii.gz']
)
})
def hcp(wf, bbr=True, **kwargs):
"""
The HCP pipeline is the RCPL pipeline but with different inputs (two bold images with different phase encodings
instead of one bold image) and with additional fieldmap correction.
parser.add_argument(
'--bids_dir',
required=True,
help='Root directory of the input dataset.'
)
CAUTION: This pipeline assumes that you converted the HCP dataset into the BIDS format!
"""

parser.add_argument(
'--output_dir',
required=True,
help='Directory where the results will be stored.'
print('* bbr:', bbr)

reorient_struct_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_struct_wf")
wf.connect('inputspec', 'T1w', reorient_struct_wf, 'in_file')

reorient_func_lr_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_lr_wf")
wf.connect('inputspec', 'bold_lr', reorient_func_lr_wf, 'in_file')

reorient_func_rl_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_rl_wf")
wf.connect('inputspec', 'bold_rl', reorient_func_rl_wf, 'in_file')

fieldmap_corr = fieldmap_correction('fieldmap_corr')
wf.connect(reorient_func_lr_wf, 'out_file', fieldmap_corr, 'func_1')
wf.connect(reorient_func_rl_wf, 'out_file', fieldmap_corr, 'func_2')

anatomical_preprocessing_wf = anat_proc(name='anatomical_preprocessing_wf', bet_tool='deepbet')
wf.connect(reorient_struct_wf, 'out_file', anatomical_preprocessing_wf, 'in_file')

func2anat_wf = func2anat(name='func2anat_wf', bbr=bbr)
wf.connect(fieldmap_corr, 'out_file', func2anat_wf, 'func')
wf.connect(anatomical_preprocessing_wf, 'brain', func2anat_wf, 'head')
wf.connect(anatomical_preprocessing_wf, 'probmap_wm', func2anat_wf, 'anat_wm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_csf', func2anat_wf, 'anat_csf_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_gm', func2anat_wf, 'anat_gm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_ventricle', func2anat_wf, 'anat_ventricle_segmentation')

compcor_roi_wf = anat_noise_roi('compcor_roi_wf')
wf.connect(func2anat_wf, 'wm_mask_in_funcspace', compcor_roi_wf, 'wm_mask')
wf.connect(func2anat_wf, 'ventricle_mask_in_funcspace', compcor_roi_wf, 'ventricle_mask')

func_proc_wf = func_proc_despike_afni('func_proc_wf', bet_tool='deepbet', deepbet_n_dilate=2)
wf.connect(fieldmap_corr, 'out_file', func_proc_wf, 'func')
wf.connect(compcor_roi_wf, 'out_file', func_proc_wf, 'cc_noise_roi')

pick_atlas_wf = mist_atlas('pick_atlas_wf')
mist_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data_in/atlas/MIST"))
pick_atlas_wf.get_node('inputspec').inputs.labelmap = os.path.join(mist_dir, 'Parcellations/MIST_122.nii.gz')
pick_atlas_wf.get_node('inputspec').inputs.modules = mist_modules(mist_directory=mist_dir, resolution="122")
pick_atlas_wf.get_node('inputspec').inputs.labels = mist_labels(mist_directory=mist_dir, resolution="122")

extract_timeseries = extract_timeseries_nativespace('extract_timeseries')
wf.connect(pick_atlas_wf, 'relabeled_atlas', extract_timeseries, 'atlas')
wf.connect(pick_atlas_wf, 'reordered_labels', extract_timeseries, 'labels')
wf.connect(pick_atlas_wf, 'reordered_modules', extract_timeseries, 'modules')
wf.connect(anatomical_preprocessing_wf, 'brain', extract_timeseries, 'anat')
wf.connect(func2anat_wf, 'anat_to_func_linear_xfm', extract_timeseries, 'inv_linear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'mni2anat_warpfield', extract_timeseries, 'inv_nonlinear_reg_mtrx')
wf.connect(func2anat_wf, 'gm_mask_in_funcspace', extract_timeseries, 'gm_mask')
wf.connect(func_proc_wf, 'func_preprocessed', extract_timeseries, 'func')
wf.connect(func_proc_wf, 'FD', extract_timeseries, 'confounds')

func2std = func2standard('func2std')
wf.connect(anatomical_preprocessing_wf, 'brain', func2std, 'anat')
wf.connect(func2anat_wf, 'func_to_anat_linear_xfm', func2std, 'linear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'anat2mni_warpfield', func2std, 'nonlinear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'std_template', func2std, 'reference_brain')
wf.connect(func_proc_wf, 'func_preprocessed', func2std, 'func')
wf.connect(func_proc_wf, 'mc_ref_vol', func2std, 'bbr2ants_source_file')

calculate_connectivity_wf = calculate_connectivity('calculate_connectivity_wf')
wf.connect(extract_timeseries, 'timeseries', calculate_connectivity_wf, 'ts_files')
wf.connect(func_proc_wf, 'FD', calculate_connectivity_wf, 'fd_files')

predict_pain_sensitivity_rpn_wf = predict_pain_sensitivity_rpn('predict_pain_sensitivity_rpn_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rpn_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rpn_wf, 'in_file')

predict_pain_sensitivity_rcpl_wf = predict_pain_sensitivity_rcpl('predict_pain_sensitivity_rcpl_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rcpl_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rcpl_wf, 'in_file')

collect_pain_predictions_wf = collect_pain_predictions('collect_pain_predictions_wf')
wf.connect(predict_pain_sensitivity_rpn_wf, 'out_file', collect_pain_predictions_wf, 'rpn_out_file')
wf.connect(predict_pain_sensitivity_rcpl_wf, 'out_file', collect_pain_predictions_wf, 'rcpl_out_file')

wf.write_graph('HCP-pipeline.png')
save_software_versions(wf)


hcp_app = BidsApp(
pipeline=hcp,
name='hcp'
)

parser.add_argument(
hcp_app.parser.add_argument(
'--bbr',
default='yes',
type=lambda x: (str(x).lower() == ['true', '1', 'yes']),
help='Use BBR registration: yes/no (default: yes)'
type=lambda x: (str(x).lower() in ['true', '1', 'yes']),
help="Use BBR registration: yes/no (default: yes)"
)

parser.add_argument('--n_procs', type=int,
help='Amount of threads to execute in parallel.'
+ 'If not set, the amount of CPU cores is used.'
+ 'Caution: Does only work with the MultiProc-plugin!')

parser.add_argument('--memory_gb', type=int,
help='Memory limit in GB. If not set, use 90% of the available memory'
+ 'Caution: Does only work with the MultiProc-plugin!')


cli_args = parser.parse_args()

input_dir = cli_args.bids_dir
output_dir = cli_args.output_dir
bbr = cli_args.bbr

plugin_args = {}
if cli_args.n_procs is not None:
plugin_args['n_procs'] = cli_args.n_procs

if cli_args.memory_gb is not None:
plugin_args['memory_gb'] = cli_args.memory_gb

subjects = []
excluded = []
for path in glob.glob(str(input_dir) + '/*'):
id = path.split('/')[-1]

base = path + '/unprocessed/3T/'

t1w_base = str(Path(base + '/T1w_MPR1/' + id + '_3T_T1w_MPR1.nii'))
has_t1w = os.path.isfile(t1w_base) or os.path.isfile(t1w_base + '.gz')

lr_base = str(Path(base + '/rfMRI_REST1_LR/' + id + '_3T_rfMRI_REST1_LR.nii'))
has_lr = os.path.isfile(lr_base) or os.path.isfile(lr_base + '.gz')

rl_base = str(Path(base + '/rfMRI_REST1_RL/' + id + '_3T_rfMRI_REST1_RL.nii'))
has_rl = os.path.isfile(rl_base) or os.path.isfile(rl_base + '.gz')

if has_t1w and has_lr and has_rl:
subjects.append(id)
else:
excluded.append(id)

print('-' * 100)
print(f'Included %d subjects.' % len(subjects))
print(f'Excluded %d subjects.' % len(excluded))
print('-' * 100)


wf = Workflow(name='HCP-RCPL')
wf.base_dir = '.'
globals.cfg_parser.set('SINKING', 'sink_dir', str(Path(os.path.abspath(output_dir + '/derivatives'))))
globals.cfg_parser.set('SINKING', 'qc_dir', str(Path(os.path.abspath(output_dir + '/derivatives/qc'))))


# Create a subroutine (subgraph) for every subject
inputspec = Node(interface=IdentityInterface(fields=['subject']), name='inputspec')
inputspec.iterables = [('subject', subjects)]

T1w_grabber = Node(DataGrabber(infields=['subject'], outfields=['T1w']), name='T1w_grabber')
T1w_grabber.inputs.base_directory = os.path.abspath(input_dir)
T1w_grabber.inputs.template = '%s/unprocessed/3T/T1w_MPR1/*T1w_MPR1.nii*'
T1w_grabber.inputs.sort_filelist = True
wf.connect(inputspec, 'subject', T1w_grabber, 'subject')

bold_lr_grabber = Node(DataGrabber(infields=['subject'], outfields=['bold_lr']), name='bold_lr_grabber')
bold_lr_grabber.inputs.base_directory = os.path.abspath(input_dir)
bold_lr_grabber.inputs.template = '%s/unprocessed/3T/rfMRI_REST1_LR/*_3T_rfMRI_REST1_LR.nii*'
bold_lr_grabber.inputs.sort_filelist = True
wf.connect(inputspec, 'subject', bold_lr_grabber, 'subject')

bold_rl_grabber = Node(DataGrabber(infields=['subject'], outfields=['bold_rl']), name='bold_rl_grabber')
bold_rl_grabber.inputs.base_directory = os.path.abspath(input_dir)
bold_rl_grabber.inputs.template = '%s/unprocessed/3T/rfMRI_REST1_RL/*_3T_rfMRI_REST1_RL.nii*'
bold_rl_grabber.inputs.sort_filelist = True
wf.connect(inputspec, 'subject', bold_rl_grabber, 'subject')

reorient_struct_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_struct_wf")
wf.connect(T1w_grabber, 'T1w', reorient_struct_wf, 'in_file')

reorient_func_lr_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_lr_wf")
wf.connect(bold_lr_grabber, 'bold_lr', reorient_func_lr_wf, 'in_file')

reorient_func_rl_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_rl_wf")
wf.connect(bold_rl_grabber, 'bold_rl', reorient_func_rl_wf, 'in_file')

fieldmap_corr = fieldmap_correction('fieldmap_corr')
wf.connect(reorient_func_lr_wf, 'out_file', fieldmap_corr, 'func_1')
wf.connect(reorient_func_rl_wf, 'out_file', fieldmap_corr, 'func_2')

anatomical_preprocessing_wf = anat_proc(name='anatomical_preprocessing_wf', bet_tool='deepbet')
wf.connect(reorient_struct_wf, 'out_file', anatomical_preprocessing_wf, 'in_file')

func2anat_wf = func2anat(name='func2anat_wf', bbr=bbr)
wf.connect(fieldmap_corr, 'out_file', func2anat_wf, 'func')
wf.connect(anatomical_preprocessing_wf, 'brain', func2anat_wf, 'head')
wf.connect(anatomical_preprocessing_wf, 'probmap_wm', func2anat_wf, 'anat_wm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_csf', func2anat_wf, 'anat_csf_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_gm', func2anat_wf, 'anat_gm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_ventricle', func2anat_wf, 'anat_ventricle_segmentation')

compcor_roi_wf = anat_noise_roi('compcor_roi_wf')
wf.connect(func2anat_wf, 'wm_mask_in_funcspace', compcor_roi_wf, 'wm_mask')
wf.connect(func2anat_wf, 'ventricle_mask_in_funcspace', compcor_roi_wf, 'ventricle_mask')

func_proc_wf = func_proc_despike_afni('func_proc_wf', bet_tool='deepbet', deepbet_n_dilate=2)
wf.connect(fieldmap_corr, 'out_file', func_proc_wf, 'func')
wf.connect(compcor_roi_wf, 'out_file', func_proc_wf, 'cc_noise_roi')

pick_atlas_wf = mist_atlas('pick_atlas_wf')
mist_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data_in/atlas/MIST"))
pick_atlas_wf.get_node('inputspec').inputs.labelmap = os.path.join(mist_dir, 'Parcellations/MIST_122.nii.gz')
pick_atlas_wf.get_node('inputspec').inputs.modules = mist_modules(mist_directory=mist_dir, resolution="122")
pick_atlas_wf.get_node('inputspec').inputs.labels = mist_labels(mist_directory=mist_dir, resolution="122")

extract_timeseries = extract_timeseries_nativespace('extract_timeseries')
wf.connect(pick_atlas_wf, 'relabeled_atlas', extract_timeseries, 'atlas')
wf.connect(pick_atlas_wf, 'reordered_labels', extract_timeseries, 'labels')
wf.connect(pick_atlas_wf, 'reordered_modules', extract_timeseries, 'modules')
wf.connect(anatomical_preprocessing_wf, 'brain', extract_timeseries, 'anat')
wf.connect(func2anat_wf, 'anat_to_func_linear_xfm', extract_timeseries, 'inv_linear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'mni2anat_warpfield', extract_timeseries, 'inv_nonlinear_reg_mtrx')
wf.connect(func2anat_wf, 'gm_mask_in_funcspace', extract_timeseries, 'gm_mask')
wf.connect(func_proc_wf, 'func_preprocessed', extract_timeseries, 'func')
wf.connect(func_proc_wf, 'FD', extract_timeseries, 'confounds')

calculate_connectivity_wf = calculate_connectivity('calculate_connectivity_wf')
wf.connect(extract_timeseries, 'timeseries', calculate_connectivity_wf, 'ts_files')
wf.connect(func_proc_wf, 'FD', calculate_connectivity_wf, 'fd_files')

predict_pain_sensitivity_rpn_wf = predict_pain_sensitivity_rpn('predict_pain_sensitivity_rpn_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rpn_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rpn_wf, 'in_file')

predict_pain_sensitivity_rcpl_wf = predict_pain_sensitivity_rcpl('predict_pain_sensitivity_rcpl_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rcpl_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rcpl_wf, 'in_file')

collect_pain_predictions_wf = collect_pain_predictions('collect_pain_predictions_wf')
wf.connect(predict_pain_sensitivity_rpn_wf, 'out_file', collect_pain_predictions_wf, 'rpn_out_file')
wf.connect(predict_pain_sensitivity_rcpl_wf, 'out_file', collect_pain_predictions_wf, 'rcpl_out_file')

wf.write_graph('Pipeline.png')
wf.run(plugin='MultiProc', plugin_args=plugin_args)
hcp_app.run()