Skip to content

Commit

Permalink
maint: replace ported sources with shim imports and deprecation warnings
Browse files Browse the repository at this point in the history
Resolves: #787.
  • Loading branch information
oesteban committed Mar 12, 2023
1 parent c2b8e0f commit 1489115
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 3,284 deletions.
234 changes: 11 additions & 223 deletions niworkflows/interfaces/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,229 +21,17 @@
# https://www.nipreps.org/community/licensing/
#
"""Visualization tools."""
import numpy as np
import nibabel as nb

from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
File,
BaseInterfaceInputSpec,
TraitedSpec,
SimpleInterface,
traits,
isdefined,
)
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
from niworkflows.viz.plots import (
fMRIPlot,
compcor_variance_plot,
confounds_correlation_plot,
import warnings
from nireports.interfaces import (
CompCorVariancePlot,
ConfoundsCorrelationPlot,
FMRISummary,
)

__all__ = (
"CompCorVariancePlot",
"ConfoundsCorrelationPlot",
"FMRISummary",
)

class _FMRISummaryInputSpec(BaseInterfaceInputSpec):
in_func = File(exists=True, mandatory=True, desc="")
in_spikes_bg = File(exists=True, desc="")
fd = File(exists=True, desc="")
dvars = File(exists=True, desc="")
outliers = File(exists=True, desc="")
in_segm = File(exists=True, desc="")
tr = traits.Either(None, traits.Float, usedefault=True, desc="the TR")
fd_thres = traits.Float(0.2, usedefault=True, desc="")
drop_trs = traits.Int(0, usedefault=True, desc="dummy scans")


class _FMRISummaryOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="written file path")


class FMRISummary(SimpleInterface):
"""Prepare an fMRI summary plot for the report."""

input_spec = _FMRISummaryInputSpec
output_spec = _FMRISummaryOutputSpec

def _run_interface(self, runtime):
import pandas as pd

self._results["out_file"] = fname_presuffix(
self.inputs.in_func,
suffix="_fmriplot.svg",
use_ext=False,
newpath=runtime.cwd,
)

dataframe = pd.DataFrame({
"outliers": np.loadtxt(self.inputs.outliers, usecols=[0]).tolist(),
# Pick non-standardize dvars (col 1)
# First timepoint is NaN (difference)
"DVARS": [np.nan]
+ np.loadtxt(self.inputs.dvars, skiprows=1, usecols=[1]).tolist(),
# First timepoint is zero (reference volume)
"FD": [0.0]
+ np.loadtxt(self.inputs.fd, skiprows=1, usecols=[0]).tolist(),
}) if (
isdefined(self.inputs.outliers)
and isdefined(self.inputs.dvars)
and isdefined(self.inputs.fd)
) else None

input_data = nb.load(self.inputs.in_func)
seg_file = self.inputs.in_segm if isdefined(self.inputs.in_segm) else None
dataset, segments = (
_cifti_timeseries(input_data)
if isinstance(input_data, nb.Cifti2Image) else
_nifti_timeseries(input_data, seg_file)
)

fig = fMRIPlot(
dataset,
segments=segments,
spikes_files=(
[self.inputs.in_spikes_bg]
if isdefined(self.inputs.in_spikes_bg) else None
),
tr=(
self.inputs.tr if isdefined(self.inputs.tr) else
_get_tr(input_data)
),
confounds=dataframe,
units={"outliers": "%", "FD": "mm"},
vlines={"FD": [self.inputs.fd_thres]},
nskip=self.inputs.drop_trs,
).plot()
fig.savefig(self._results["out_file"], bbox_inches="tight")
return runtime


class _CompCorVariancePlotInputSpec(BaseInterfaceInputSpec):
metadata_files = traits.List(
File(exists=True),
mandatory=True,
desc="List of files containing component " "metadata",
)
metadata_sources = traits.List(
traits.Str,
desc="List of names of decompositions "
"(e.g., aCompCor, tCompCor) yielding "
"the arguments in `metadata_files`",
)
variance_thresholds = traits.Tuple(
traits.Float(0.5),
traits.Float(0.7),
traits.Float(0.9),
usedefault=True,
desc="Levels of explained variance to include in " "plot",
)
out_file = traits.Either(
None, File, value=None, usedefault=True, desc="Path to save plot"
)


class _CompCorVariancePlotOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="Path to saved plot")


class CompCorVariancePlot(SimpleInterface):
"""Plot the number of components necessary to explain the specified levels of variance."""

input_spec = _CompCorVariancePlotInputSpec
output_spec = _CompCorVariancePlotOutputSpec

def _run_interface(self, runtime):
if self.inputs.out_file is None:
self._results["out_file"] = fname_presuffix(
self.inputs.metadata_files[0],
suffix="_compcor.svg",
use_ext=False,
newpath=runtime.cwd,
)
else:
self._results["out_file"] = self.inputs.out_file
compcor_variance_plot(
metadata_files=self.inputs.metadata_files,
metadata_sources=self.inputs.metadata_sources,
output_file=self._results["out_file"],
varexp_thresh=self.inputs.variance_thresholds,
)
return runtime


class _ConfoundsCorrelationPlotInputSpec(BaseInterfaceInputSpec):
confounds_file = File(
exists=True, mandatory=True, desc="File containing confound regressors"
)
out_file = traits.Either(
None, File, value=None, usedefault=True, desc="Path to save plot"
)
reference_column = traits.Str(
"global_signal",
usedefault=True,
desc="Column in the confound file for "
"which all correlation magnitudes "
"should be ranked and plotted",
)
columns = traits.List(
traits.Str,
desc="Filter out all regressors not found in this list."
)
max_dim = traits.Int(
20,
usedefault=True,
desc="Maximum number of regressors to include in "
"plot. Regressors with highest magnitude of "
"correlation with `reference_column` will be "
"selected.",
)


class _ConfoundsCorrelationPlotOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="Path to saved plot")


class ConfoundsCorrelationPlot(SimpleInterface):
"""Plot the correlation among confound regressors."""

input_spec = _ConfoundsCorrelationPlotInputSpec
output_spec = _ConfoundsCorrelationPlotOutputSpec

def _run_interface(self, runtime):
if self.inputs.out_file is None:
self._results["out_file"] = fname_presuffix(
self.inputs.confounds_file,
suffix="_confoundCorrelation.svg",
use_ext=False,
newpath=runtime.cwd,
)
else:
self._results["out_file"] = self.inputs.out_file
confounds_correlation_plot(
confounds_file=self.inputs.confounds_file,
columns=self.inputs.columns if isdefined(self.inputs.columns) else None,
max_dim=self.inputs.max_dim,
output_file=self._results["out_file"],
reference=self.inputs.reference_column,
)
return runtime


def _get_tr(img):
"""
Attempt to extract repetition time from NIfTI/CIFTI header
Examples
--------
>>> _get_tr(nb.load(Path(test_data) /
... 'sub-ds205s03_task-functionallocalizer_run-01_bold_volreg.nii.gz'))
2.2
>>> _get_tr(nb.load(Path(test_data) /
... 'sub-01_task-mixedgamblestask_run-02_space-fsLR_den-91k_bold.dtseries.nii'))
2.0
"""

try:
return img.header.matrix.get_index_map(0).series_step
except AttributeError:
return img.header.get_zooms()[-1]
raise RuntimeError("Could not extract TR - unknown data structure type")
warnings.warn("Please use nireports.interfaces", DeprecationWarning)
Loading

0 comments on commit 1489115

Please sign in to comment.