From 324d2eeac03bca04b86fce50cdf2d18f638acd61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Tue, 16 Apr 2024 19:48:27 -0400 Subject: [PATCH 1/5] ENH: Make models inherit from base model Make models inherit from base model. --- src/eddymotion/model/base.py | 206 +++++++++++++++++++++-------------- 1 file changed, 126 insertions(+), 80 deletions(-) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 38207741..96221db9 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -28,13 +28,34 @@ from dipy.core.gradients import gradient_table from joblib import Parallel, delayed +#: Minimum value when considering the :math:`S_{0}` DWI signal. +DEFAULT_MIN_S0 = 1e-5 + +#: Maximum value when considering the :math:`S_{0}` DWI signal. +DEFAULT_MAX_S0 = 1.0 + +#: Maximum allowed value for the b-value. +DEFAULT_MAX_BVALUE = 1000 + +#: b-value lower bound when considering DWI data. +DEFAULT_LOWB_THRESHOLD = 50 + +#: b-value upper bound when considering DWI data. +DEFAULT_HIGHB_THRESHOLD = 10000 + +#: Percentile clipping value. +DEFAULT_CLIP_PERCENTILE = 75 + +#: Time frame tolerance in seconds. 
+DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 + def _exec_fit(model, data, chunk=None): retval = model.fit(data) return retval, chunk -def _exec_predict(model, gradient, chunk=None, **kwargs): +def _exec_predict_dwi(model, gradient, chunk=None, **kwargs): """Propagate model parameters and call predict.""" return np.squeeze(model.predict(gradient, S0=kwargs.pop("S0", None))), chunk @@ -86,51 +107,18 @@ class BaseModel: __slots__ = ( "_model", "_mask", - "_S0", - "_b_max", "_models", "_datashape", ) _modelargs = () - def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): + def __init__(self, mask=None, **kwargs): """Base initialization.""" - # Setup B0 map - self._S0 = None - if S0 is not None: - self._S0 = np.clip( - S0.astype("float32") / S0.max(), - a_min=1e-5, - a_max=1.0, - ) + self._model = None # Setup brain mask self._mask = mask - if mask is None and S0 is not None: - self._mask = self._S0 > np.percentile(self._S0, 35) - - # Cap b-values, if requested - self._b_max = None - if b_max and b_max > 1000: - # Saturate b-values at b_max, since signal stops dropping - gtab[-1, gtab[-1] > b_max] = b_max - # A possibly good alternative is completely remove very high b-values - # bval_mask = gtab[-1] < b_max - # data = data[..., bval_mask] - # gtab = gtab[:, bval_mask] - self._b_max = b_max - - kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} - - model_str = getattr(self, "_model_class", None) - if not model_str: - raise TypeError("No model defined") - - from importlib import import_module - - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) self._datashape = None self._models = None @@ -166,12 +154,73 @@ def fit(self, data, n_jobs=None, **kwargs): self._model = None # Preempt further actions on the model - def predict(self, gradient, **kwargs): + def predict(self, *args, **kwargs): + pass + + +class BaseDWIModel(BaseModel): + """Interface and default 
methods for DWI models.""" + + __slots__ = ( + "_gtab", + "_S0", + "_b_max", + ) + + def __init__(self, gtab, S0=None, b_max=None, **kwargs): + """Initialization. + + Parameters + ---------- + gtab : :obj:`numpy.ndarray` + An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and + columns are b-vector components and corresponding b-value, respectively. + S0 : :obj:`numpy.ndarray` + :math:`S_{0}` signal. + b_max : :obj:`int` + Maximum value to cap b-values. + """ + + super().__init__(**kwargs) + + # Setup B0 map + self._S0 = None + if S0 is not None: + self._S0 = np.clip( + S0.astype("float32") / S0.max(), + a_min=DEFAULT_MIN_S0, + a_max=DEFAULT_MAX_S0, + ) + + # Cap b-values, if requested + self._gtab = gtab + self._b_max = None + if b_max and b_max > DEFAULT_MAX_BVALUE: + # Saturate b-values at b_max, since signal stops dropping + self._gtab[-1, self._gtab[-1] > b_max] = b_max + # A possibly good alternative is completely remove very high b-values + # bval_mask = gtab[-1] < b_max + # data = data[..., bval_mask] + # gtab = gtab[:, bval_mask] + self._b_max = b_max + + kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + + model_str = getattr(self, "_model_class", None) + if not model_str: + raise TypeError("No model defined") + + from importlib import import_module + + module_name, class_name = model_str.rsplit(".", 1) + self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) + + def predict(self, index, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" if self._b_max is not None: - gradient[-1] = min(gradient[-1], self._b_max) + index[-1] = min(index[-1], self._b_max) - gradient = _rasb2dipy(gradient) + self._gtab = _rasb2dipy(self._gtab) S0 = None if self._S0 is not None: @@ -184,7 +233,7 @@ def predict(self, gradient, **kwargs): n_models = len(self._models) if self._model is None and self._models else 1 if n_models == 1: - predicted, _ = _exec_predict(self._model, 
gradient, S0=S0, **kwargs) + predicted, _ = _exec_predict_dwi(self._model, self._gtab, S0=S0, **kwargs) else: S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models @@ -193,7 +242,7 @@ def predict(self, gradient, **kwargs): # Parallelize process with joblib with Parallel(n_jobs=n_models) as executor: results = executor( - delayed(_exec_predict)(model, gradient, S0=S0[i], chunk=i, **kwargs) + delayed(_exec_predict_dwi)(model, self._gtab, S0=S0[i], chunk=i, **kwargs) for i, model in enumerate(self._models) ) for subprediction, index in results: @@ -210,27 +259,25 @@ def predict(self, gradient, **kwargs): return retval -class TrivialB0Model: +class TrivialB0Model(BaseDWIModel): """A trivial model that returns a *b=0* map always.""" - __slots__ = ("_S0",) - - def __init__(self, S0=None, **kwargs): + def __init__(self, **kwargs): """Implement object initialization.""" - if S0 is None: - raise ValueError("S0 must be provided") + super().__init__(**kwargs) - self._S0 = S0 + if self._S0 is None: + raise ValueError("S0 must be provided") - def fit(self, *args, **kwargs): + def fit(self, data, **kwargs): """Do nothing.""" - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the *b=0* map.""" return self._S0 -class AverageDWModel: +class AverageDWModel(BaseDWIModel): """A trivial model that returns an average map.""" __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat") @@ -241,35 +288,34 @@ def __init__(self, **kwargs): Parameters ---------- - gtab : :obj:`~numpy.ndarray` - An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and - columns are b-vector components and corresponding b-value, respectively. - th_low : :obj:`~numbers.Number` + th_low : :obj:`numbers.Number` A lower bound for the b-value corresponding to the diffusion weighted images that will be averaged. 
- th_high : :obj:`~numbers.Number` + th_high : :obj:`numbers.Number` An upper bound for the b-value corresponding to the diffusion weighted images that will be averaged. bias : :obj:`bool` Whether the overall distribution of each diffusion weighted image will be - standardized and centered around the global 75th percentile. + standardized and centered around the + :data:`src.eddymotion.model.base.DEFAULT_CLIP_PERCENTILE` percentile. stat : :obj:`str` Whether the summary statistic to apply is ``"mean"`` or ``"median"``. """ - self._th_low = kwargs.get("th_low", 50) - self._th_high = kwargs.get("th_high", 10000) + super().__init__(**kwargs) + + self._th_low = kwargs.get("th_low", DEFAULT_LOWB_THRESHOLD) + self._th_high = kwargs.get("th_high", DEFAULT_HIGHB_THRESHOLD) self._bias = kwargs.get("bias", True) self._stat = kwargs.get("stat", "median") self._data = None def fit(self, data, **kwargs): """Calculate the average.""" - gtab = kwargs.pop("gtab", None) # Select the interval of b-values for which DWIs will be averaged b_mask = ( - ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) - if gtab is not None + ((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high)) + if self._gtab is not None else np.ones((data.shape[-1],), dtype=bool) ) shells = data[..., b_mask] @@ -277,7 +323,7 @@ def fit(self, data, **kwargs): # Regress out global signal differences if self._bias: centers = np.median(shells, axis=(0, 1, 2)) - reference = np.percentile(centers[centers >= 1.0], 75) + reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) centers[centers < 1.0] = reference drift = reference / centers shells = shells * drift @@ -287,17 +333,17 @@ def fit(self, data, **kwargs): # Calculate the average self._data = avg_func(shells, axis=-1) - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the average map.""" return self._data -class PETModel: +class PETModel(BaseModel): """A PET imaging realignment model based on 
B-Spline approximation.""" - __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_mask", "_shape", "_n_ctrl") + __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") - def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, **kwargs): + def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): """ Create the B-Spline interpolating matrix. @@ -314,18 +360,19 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, model. """ + super.__init__(**kwargs) + if timepoints is None or xlim is None: raise TypeError("timepoints must be provided in initialization") self._order = order - self._mask = mask self._x = np.array(timepoints, dtype="float32") self._xlim = xlim - if self._x[0] < 1e-2: + if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: raise ValueError("First frame midpoint should not be zero or negative") - if self._x[-1] > (self._xlim - 1e-2): + if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): raise ValueError("Last frame midpoint should not be equal or greater than duration") # Calculate index coordinates in the B-Spline grid @@ -334,10 +381,9 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, # B-Spline knots self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - self._shape = None self._coeff = None - def fit(self, data, *args, **kwargs): + def fit(self, data, **kwargs): """Fit the model.""" from scipy.interpolate import BSpline from scipy.sparse.linalg import cg @@ -347,7 +393,7 @@ def fit(self, data, *args, **kwargs): timepoints = kwargs.get("timepoints", None) or self._x x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl - self._shape = data.shape[:3] + self._datashape = data.shape[:3] # Convert data into V (voxels) x T (timepoints) data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] @@ -368,12 +414,12 @@ def fit(self, data, *args, **kwargs): self._coeff = 
np.array([r[0] for r in results]) - def predict(self, timepoint, **kwargs): - """Return the *b=0* map.""" + def predict(self, index, **kwargs): + """Return the corrected volume using B-spline interpolation.""" from scipy.interpolate import BSpline # Project sample timing into B-Spline coordinates - x = (timepoint / self._xlim) * self._n_ctrl + x = (index / self._xlim) * self._n_ctrl A = BSpline.design_matrix(x, self._t, k=self._order) # A is 1 (num. timepoints) x C (num. coeff) @@ -381,14 +427,14 @@ def predict(self, timepoint, **kwargs): predicted = np.squeeze(A @ self._coeff.T) if self._mask is None: - return predicted.reshape(self._shape) + return predicted.reshape(self._datashape) - retval = np.zeros(self._shape, dtype="float32") + retval = np.zeros(self._datashape, dtype="float32") retval[self._mask] = predicted return retval -class DTIModel(BaseModel): +class DTIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" _modelargs = ( @@ -402,7 +448,7 @@ class DTIModel(BaseModel): _model_class = "dipy.reconst.dti.TensorModel" -class DKIModel(BaseModel): +class DKIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`.""" _modelargs = DTIModel._modelargs From eb4d274eb6350eabbe3f0791d6c2967e318b6575 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sat, 8 Jun 2024 08:50:45 +0200 Subject: [PATCH 2/5] Apply suggestions from code review Improving the documentation of constants. cc/ @jhlegarreta --- src/eddymotion/model/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index ed2b1101..e56962b4 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -29,26 +29,26 @@ from joblib import Parallel, delayed from eddymotion.exceptions import ModelNotFittedError -#: Minimum value when considering the :math:`S_{0}` DWI signal. 
DEFAULT_MIN_S0 = 1e-5 +"""Minimum value when considering the :math:`S_{0}` DWI signal.""" -#: Maximum value when considering the :math:`S_{0}` DWI signal. DEFAULT_MAX_S0 = 1.0 +"""Maximum value when considering the :math:`S_{0}` DWI signal.""" -#: Maximum allowed value for the b-value. DEFAULT_MAX_BVALUE = 1000 +"""Maximum allowed value for the b-value.""" -#: b-value lower bound when considering DWI data. DEFAULT_LOWB_THRESHOLD = 50 +"""The lower bound for the b-value so that the orientation is considered a DW volume.""" -#: b-value upper bound when considering DWI data. DEFAULT_HIGHB_THRESHOLD = 10000 +"""A b-value cap for DWI data.""" -#: Percentile clipping value. DEFAULT_CLIP_PERCENTILE = 75 +"""Upper percentile threshold for intensity clipping.""" -#: Time frame tolerance in seconds. DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 +"""Time frame tolerance in seconds.""" def _exec_fit(model, data, chunk=None): From a30b1dd473b78ff4f96959a6fab0f497ead7d4eb Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sat, 8 Jun 2024 09:07:32 +0200 Subject: [PATCH 3/5] Update src/eddymotion/model/base.py --- src/eddymotion/model/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index e56962b4..7a85613b 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -415,7 +415,6 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): @property def is_fitted(self): return self._coeff is not None - def fit(self, data, **kwargs): """Fit the model.""" from scipy.interpolate import BSpline From a768e7278c911d981cfbc003881fb79fb32025e2 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 12 Jun 2024 12:22:40 -0400 Subject: [PATCH 4/5] Code review of #176 (#6) * enh: revise code * sty: ruff format --- src/eddymotion/model/base.py | 321 +++++++++++++++++++---------------- test/test_model.py | 19 ++- 2 files changed, 191 insertions(+), 149 deletions(-) diff --git 
a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 7a85613b..5b2eff7c 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -27,6 +27,7 @@ import numpy as np from dipy.core.gradients import gradient_table from joblib import Parallel, delayed + from eddymotion.exceptions import ModelNotFittedError DEFAULT_MIN_S0 = 1e-5 @@ -56,9 +57,9 @@ def _exec_fit(model, data, chunk=None): return retval, chunk -def _exec_predict_dwi(model, gradient, chunk=None, **kwargs): +def _exec_predict(model, chunk=None, **kwargs): """Propagate model parameters and call predict.""" - return np.squeeze(model.predict(gradient, S0=kwargs.pop("S0", None))), chunk + return np.squeeze(model.predict(**kwargs)), chunk class ModelFactory: @@ -82,7 +83,7 @@ def init(model="DTI", **kwargs): """ if model.lower() in ("s0", "b0"): - return TrivialB0Model(S0=kwargs.pop("S0")) + return TrivialB0Model(S0=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) if model.lower() in ("avg", "average", "mean"): return AverageDWModel(**kwargs) @@ -111,60 +112,142 @@ class BaseModel: "_models", "_datashape", "_is_fitted", + "_modelargs", ) - _modelargs = () def __init__(self, mask=None, **kwargs): """Base initialization.""" - self._model = None + # Keep model state + self._model = None # "Main" model + self._models = None # For parallel (chunked) execution self._is_fitted = False # Setup brain mask self._mask = mask self._datashape = None - self._models = None self._is_fitted = False + self._modelargs = () + @property def is_fitted(self): return self._is_fitted - def fit(self, data, n_jobs=None, **kwargs): - """Fit the model chunk-by-chunk asynchronously""" - n_jobs = n_jobs or 1 + def fit(self, data, **kwargs): + """Abstract member signature of fit().""" + raise NotImplementedError("Cannot call fit() on a BaseModel instance.") - self._datashape = data.shape + def predict(self, *args, **kwargs): + """Abstract member signature of predict().""" + raise NotImplementedError("Cannot call 
predict() on a BaseModel instance.") - # Select voxels within mask or just unravel 3D if no mask - data = ( - data[self._mask, ...] if self._mask is not None else data.reshape(-1, data.shape[-1]) - ) + +class PETModel(BaseModel): + """A PET imaging realignment model based on B-Spline approximation.""" + + __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") + + def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): + """ + Create the B-Spline interpolating matrix. + + Parameters: + ----------- + timepoints : :obj:`list` + The timing (in sec) of each PET volume. + E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., + 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` + + n_ctrl : :obj:`int` + Number of B-Spline control points. If `None`, then one control point every + six timepoints will be used. The less control points, the smoother is the + model. + + """ + super.__init__(**kwargs) + + if timepoints is None or xlim is None: + raise TypeError("timepoints must be provided in initialization") + + self._order = order + + self._x = np.array(timepoints, dtype="float32") + self._xlim = xlim + + if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: + raise ValueError("First frame midpoint should not be zero or negative") + if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): + raise ValueError("Last frame midpoint should not be equal or greater than duration") + + # Calculate index coordinates in the B-Spline grid + self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 + + # B-Spline knots + self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") + + self._coeff = None + + @property + def is_fitted(self): + return self._coeff is not None + + def fit(self, data, **kwargs): + """Fit the model.""" + from scipy.interpolate import BSpline + from scipy.sparse.linalg import cg + + n_jobs = kwargs.pop("n_jobs", None) or 1 + + timepoints = kwargs.get("timepoints", None) or self._x + x = (np.array(timepoints, 
dtype="float32") / self._xlim) * self._n_ctrl + + self._datashape = data.shape[:3] + + # Convert data into V (voxels) x T (timepoints) + data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] + + # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) + A = BSpline.design_matrix(x, self._t, k=self._order) + AT = A.T + ATdotA = AT @ A # One single CPU - linear execution (full model) if n_jobs == 1: - self._model, _ = _exec_fit(self._model, data) + self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) return - # Split data into chunks of group of slices - data_chunks = np.array_split(data, n_jobs) - - self._models = [None] * n_jobs - # Parallelize process with joblib with Parallel(n_jobs=n_jobs) as executor: - results = executor( - delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) - ) - for submodel, index in results: - self._models[index] = submodel + results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - self._is_fitted = True - self._model = None # Preempt further actions on the model + self._coeff = np.array([r[0] for r in results]) - def predict(self, *args, **kwargs): - pass + def predict(self, index=None, **kwargs): + """Return the corrected volume using B-spline interpolation.""" + from scipy.interpolate import BSpline + + if index is None: + raise ValueError("A timepoint index to be simulated must be provided.") + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + # Project sample timing into B-Spline coordinates + x = (index / self._xlim) * self._n_ctrl + A = BSpline.design_matrix(x, self._t, k=self._order) + + # A is 1 (num. timepoints) x C (num. coeff) + # self._coeff is V (num. 
voxels) x K - 4 + predicted = np.squeeze(A @ self._coeff.T) + + if self._mask is None: + return predicted.reshape(self._datashape) + + retval = np.zeros(self._datashape, dtype="float32") + retval[self._mask] = predicted + return retval class BaseDWIModel(BaseModel): @@ -174,6 +257,8 @@ class BaseDWIModel(BaseModel): "_gtab", "_S0", "_b_max", + "_model_class", # Defining a model class, DIPY models are instantiated automagically + "_modelargs", ) def __init__(self, gtab, S0=None, b_max=None, **kwargs): @@ -188,6 +273,7 @@ def __init__(self, gtab, S0=None, b_max=None, **kwargs): :math:`S_{0}` signal. b_max : :obj:`int` Maximum value to cap b-values. + """ super().__init__(**kwargs) @@ -215,25 +301,64 @@ def __init__(self, gtab, S0=None, b_max=None, **kwargs): kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + # DIPY models (or one with a fully-compliant interface) model_str = getattr(self, "_model_class", None) - if not model_str: - raise TypeError("No model defined") + if model_str: + from importlib import import_module + + module_name, class_name = model_str.rsplit(".", 1) + self._model = getattr( + import_module(module_name), + class_name, + )(_rasb2dipy(gtab), **kwargs) - from importlib import import_module + def fit(self, data, n_jobs=None, **kwargs): + """Fit the model chunk-by-chunk asynchronously""" + n_jobs = n_jobs or 1 - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) + self._datashape = data.shape - def predict(self, index, **kwargs): + # Select voxels within mask or just unravel 3D if no mask + data = ( + data[self._mask, ...] 
if self._mask is not None else data.reshape(-1, data.shape[-1]) + ) + + # One single CPU - linear execution (full model) + if n_jobs == 1: + self._model, _ = _exec_fit(self._model, data) + return + + # Split data into chunks of group of slices + data_chunks = np.array_split(data, n_jobs) + + self._models = [None] * n_jobs + + # Parallelize process with joblib + with Parallel(n_jobs=n_jobs) as executor: + results = executor( + delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) + ) + for submodel, index in results: + self._models[index] = submodel + + self._is_fitted = True + self._model = None # Preempt further actions on the model + + def predict(self, gradient=None, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" + if gradient is None: + raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") + if not self._is_fitted: raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - if self._b_max is not None: - index[-1] = min(index[-1], self._b_max) + gradient = np.array(gradient) # Tuples are immutable + + # Cap the b-value if b_max is defined + gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) - self._gtab = _rasb2dipy(self._gtab) + self._gtab = _rasb2dipy(gradient) S0 = None if self._S0 is not None: @@ -246,7 +371,7 @@ def predict(self, index, **kwargs): n_models = len(self._models) if self._model is None and self._models else 1 if n_models == 1: - predicted, _ = _exec_predict_dwi(self._model, self._gtab, S0=S0, **kwargs) + predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": self._gtab, "S0": S0})) else: S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models @@ -255,7 +380,11 @@ def predict(self, index, **kwargs): # Parallelize process with joblib with Parallel(n_jobs=n_models) as executor: results = executor( - delayed(_exec_predict_dwi)(model, self._gtab, S0=S0[i], chunk=i, **kwargs) + 
delayed(_exec_predict)( + model, + chunk=i, + **(kwargs | {"gtab": self._gtab, "S0": S0[i]}), + ) for i, model in enumerate(self._models) ) for subprediction, index in results: @@ -332,10 +461,14 @@ def __init__(self, **kwargs): def fit(self, data, **kwargs): """Calculate the average.""" + + if (gtab := kwargs.pop("gtab", None)) is None: + raise ValueError("A gradient table must be provided.") + # Select the interval of b-values for which DWIs will be averaged b_mask = ( - ((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high)) - if self._gtab is not None + ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) + if gtab is not None else np.ones((data.shape[-1],), dtype=bool) ) shells = data[..., b_mask] @@ -358,7 +491,7 @@ def fit(self, data, **kwargs): def is_fitted(self): return self._is_fitted - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the average map.""" if not self._is_fitted: @@ -367,108 +500,6 @@ def predict(self, gradient, **kwargs): return self._data -class PETModel(BaseModel): - """A PET imaging realignment model based on B-Spline approximation.""" - - __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") - - def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): - """ - Create the B-Spline interpolating matrix. - - Parameters: - ----------- - timepoints : :obj:`list` - The timing (in sec) of each PET volume. - E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., - 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` - - n_ctrl : :obj:`int` - Number of B-Spline control points. If `None`, then one control point every - six timepoints will be used. The less control points, the smoother is the - model. 
- - """ - super.__init__(**kwargs) - - if timepoints is None or xlim is None: - raise TypeError("timepoints must be provided in initialization") - - self._order = order - - self._x = np.array(timepoints, dtype="float32") - self._xlim = xlim - - if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: - raise ValueError("First frame midpoint should not be zero or negative") - if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): - raise ValueError("Last frame midpoint should not be equal or greater than duration") - - # Calculate index coordinates in the B-Spline grid - self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 - - # B-Spline knots - self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - - self._coeff = None - - @property - def is_fitted(self): - return self._coeff is not None - def fit(self, data, **kwargs): - """Fit the model.""" - from scipy.interpolate import BSpline - from scipy.sparse.linalg import cg - - n_jobs = kwargs.pop("n_jobs", None) or 1 - - timepoints = kwargs.get("timepoints", None) or self._x - x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl - - self._datashape = data.shape[:3] - - # Convert data into V (voxels) x T (timepoints) - data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] - - # A.shape = (T, K - 4); T= n. timepoints, K= n. 
knots (with padding) - A = BSpline.design_matrix(x, self._t, k=self._order) - AT = A.T - ATdotA = AT @ A - - # One single CPU - linear execution (full model) - if n_jobs == 1: - self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) - return - - # Parallelize process with joblib - with Parallel(n_jobs=n_jobs) as executor: - results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - - self._coeff = np.array([r[0] for r in results]) - - def predict(self, index, **kwargs): - """Return the corrected volume using B-spline interpolation.""" - from scipy.interpolate import BSpline - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - # Project sample timing into B-Spline coordinates - x = (index / self._xlim) * self._n_ctrl - A = BSpline.design_matrix(x, self._t, k=self._order) - - # A is 1 (num. timepoints) x C (num. coeff) - # self._coeff is V (num. voxels) x K - 4 - predicted = np.squeeze(A @ self._coeff.T) - - if self._mask is None: - return predicted.reshape(self._datashape) - - retval = np.zeros(self._datashape, dtype="float32") - retval[self._mask] = predicted - return retval - - class DTIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" diff --git a/test/test_model.py b/test/test_model.py index c5f319bf..7c7a906f 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -29,22 +29,32 @@ from eddymotion.data.dmri import DWI from eddymotion.data.splitting import lovo_split from eddymotion.exceptions import ModelNotFittedError +from eddymotion.model.base import DEFAULT_MAX_S0, DEFAULT_MIN_S0 def test_trivial_model(): """Check the implementation of the trivial B0 model.""" + rng = np.random.default_rng(1234) + # Should not allow initialization without a B0 with pytest.raises(ValueError): model.TrivialB0Model(gtab=np.eye(4)) - _S0 = np.random.normal(size=(10, 10, 10)) + _S0 = rng.normal(size=(2, 2, 2)) + + _clipped_S0 = np.clip( + _S0.astype("float32") / _S0.max(), 
+ a_min=DEFAULT_MIN_S0, + a_max=DEFAULT_MAX_S0, + ) - tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_S0) + tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_clipped_S0) - assert tmodel.fit() is None + data = None + assert tmodel.fit(data) is None - assert np.all(_S0 == tmodel.predict((1, 0, 0))) + assert np.all(_clipped_S0 == tmodel.predict((1, 0, 0))) def test_average_model(): @@ -106,6 +116,7 @@ def test_two_initialisations(datadir): # Direct initialisation model1 = model.AverageDWModel( + gtab=data_train[1], S0=dmri_dataset.bzero, th_low=100, th_high=1000, From d5378330f3b00aa3646ef44958b90b4b977f2c6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Thu, 13 Jun 2024 14:02:24 +0200 Subject: [PATCH 5/5] BUG: Do not overwrite the gradient table in prediction Do not overwrite the gradient table in prediction. Co-authored-by: Oscar Esteban --- src/eddymotion/model/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 5b2eff7c..6d7653a7 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -358,7 +358,7 @@ def predict(self, gradient=None, **kwargs): # Cap the b-value if b_max is defined gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) - self._gtab = _rasb2dipy(gradient) + gradient = _rasb2dipy(gradient) S0 = None if self._S0 is not None: @@ -371,7 +371,7 @@ def predict(self, gradient=None, **kwargs): n_models = len(self._models) if self._model is None and self._models else 1 if n_models == 1: - predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": self._gtab, "S0": S0})) + predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) else: S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models @@ -383,7 +383,7 @@ def predict(self, gradient=None, **kwargs): delayed(_exec_predict)( model, chunk=i, - **(kwargs | {"gtab": self._gtab, "S0": S0[i]}), + 
**(kwargs | {"gtab": gradient, "S0": S0[i]}), ) for i, model in enumerate(self._models) )